TravelContentCreator/utils/output_handler.py

493 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import simplejson as json
import logging
from abc import ABC, abstractmethod
import traceback
import base64
# 自定义JSON编码器强制处理所有可能的JSON序列化问题
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that aggressively strips characters from string values.

    String values keep only printable ASCII (code points 32-126) plus ``\\n``,
    ``\\r`` and ``\\t``; every other character (including all non-ASCII text)
    is dropped. Dict entries whose key is ``"error"`` or ``"raw_result"`` are
    omitted at every nesting level. Dict keys are escaped but not stripped.
    Output is always compact (``,`` and ``:`` separators, no indentation).
    """

    # Keys removed from every dict before encoding.
    _SKIPPED_KEYS = frozenset({"error", "raw_result"})

    @staticmethod
    def _strip_unsafe(text):
        """Return *text* keeping only printable ASCII and ``\\n``, ``\\r``, ``\\t``."""
        return "".join(
            ch for ch in text if ch in "\n\r\t" or 32 <= ord(ch) <= 126
        )

    def encode(self, obj):
        """Encode *obj* as a JSON string, sanitising string values recursively."""
        if isinstance(obj, dict):
            members = (
                # Escape the key through the base encoder so quotes or
                # backslashes inside a key cannot break the document.
                json.JSONEncoder.encode(self, str(key)) + ":" + self.encode(value)
                for key, value in obj.items()
                if key not in self._SKIPPED_KEYS
            )
            return "{" + ",".join(members) + "}"
        if isinstance(obj, (list, tuple)):
            return "[" + ",".join(self.encode(item) for item in obj) + "]"
        if isinstance(obj, str):
            # The base class encodes plain strings directly (without calling
            # iterencode), so this delegation cannot recurse.
            return json.JSONEncoder.encode(self, self._strip_unsafe(obj))
        # Scalars (numbers, booleans, None) and other types: use the base
        # class's iterencode. Calling the base ``encode`` here would dispatch
        # back into the overridden ``iterencode`` and recurse forever.
        return "".join(json.JSONEncoder.iterencode(self, obj))

    def iterencode(self, obj, _one_shot=False):
        """Yield the complete encoding of *obj* as a single chunk."""
        yield self.encode(obj)
class OutputHandler(ABC):
"""Abstract base class for handling the output of the generation pipeline."""
@abstractmethod
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
"""Handles the results from the topic generation step."""
pass
@abstractmethod
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Handles the results for a single content variant."""
pass
@abstractmethod
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Handles the poster configuration generated for a topic."""
pass
@abstractmethod
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
"""Handles a generated image (collage or final poster).
Args:
image_type: Either 'collage' or 'poster'.
image_data: The image data (e.g., PIL Image object or bytes).
output_filename: The desired filename for the output (e.g., 'poster.jpg').
metadata: Optional dictionary with additional metadata about the image (e.g., used image files).
"""
pass
@abstractmethod
def finalize(self, run_id: str):
"""Perform any final actions for the run (e.g., close files, upload manifests)."""
pass
class FileSystemOutputHandler(OutputHandler):
    """Output handler that writes every pipeline result to the local file system.

    Files are laid out under ``<base_output_dir>/<run_id>/``:

    * ``tweet_topic_<run_id>.json`` / ``tweet_topic_<run_id>.txt`` and
      ``topic_prompt_<run_id>.txt`` -- topic results and their prompts;
    * ``<topic_index>_<variant_index>/`` -- one directory per content variant
      with ``article.json``, ``article.txt``, ``debug_content.txt`` and
      ``tweet_prompt.txt``, plus ``image/``, ``collage_img/`` and ``poster/``
      subdirectories for generated images;
    * ``<topic_index>_1/topic_<topic_index>_poster_configs.json`` -- poster
      configurations (always stored under variant 1 of the topic).

    Failures while writing individual files are logged rather than raised;
    creating the run/variant directories themselves may still raise ``OSError``.
    """

    # Characters allowed in a standard (non URL-safe) base64 string.
    _BASE64_CHARS = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
    )

    # (topic key, label) pairs written to the topic text summary, in order.
    _TOPIC_TEXT_FIELDS = (
        ("date", "日期"),
        ("object", "对象"),
        ("product", "产品"),
        ("product_logic", "产品策略"),
        ("style", "风格"),
        ("style_logic", "风格策略"),
        ("target_audience", "目标受众"),
        ("target_audience_logic", "受众策略"),
        ("logic", "逻辑"),
    )

    # Subdirectory of the variant directory for image types that are not
    # stored in "image/"; any other type is saved in the variant directory.
    _IMAGE_SUBDIRS = {"collage": "collage_img", "poster": "poster"}

    def __init__(self, base_output_dir: str = "result"):
        """Remember the root directory under which run directories are created."""
        self.base_output_dir = base_output_dir
        logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")

    def _get_run_dir(self, run_id: str) -> str:
        """Return ``<base_output_dir>/<run_id>``, creating it if necessary."""
        run_dir = os.path.join(self.base_output_dir, run_id)
        os.makedirs(run_dir, exist_ok=True)
        return run_dir

    def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
        """Return ``<run_dir>/<topic_index>_<variant_index>[/<subdir>]``, creating it if necessary."""
        target_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
        if subdir:
            target_dir = os.path.join(target_dir, subdir)
        os.makedirs(target_dir, exist_ok=True)
        return target_dir

    # ------------------------------------------------------------------ topics

    @classmethod
    def _format_topic(cls, topic) -> str:
        """Render one topic as a text section; missing keys (or a non-dict topic) show ``N/A``."""
        fields = topic if isinstance(topic, dict) else {}
        lines = [f"## 选题 {fields.get('index', 'N/A')}"]
        lines += [f"- {label}: {fields.get(key, 'N/A')}" for key, label in cls._TOPIC_TEXT_FIELDS]
        return "\n".join(lines) + "\n\n"

    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        """Save the topic list as JSON and as a text summary, plus the prompts used.

        The JSON file, the text summary and the prompt file are written
        independently, so a (logged) failure in one does not prevent the others.
        """
        run_dir = self._get_run_dir(run_id)

        topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
        try:
            with open(topics_path, "w", encoding="utf-8") as f:
                # Plain (simplejson) encoder; ignore_nan writes NaN/Infinity as null.
                json.dump(topics_list, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Topics list saved successfully to: {topics_path}")
        except Exception:
            logging.exception(f"Error saving topic JSON file to {topics_path}:")

        txt_path = os.path.join(run_dir, f"tweet_topic_{run_id}.txt")
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(f"# 选题列表 (run_id: {run_id})\n\n")
                for topic in topics_list or []:
                    f.write(self._format_topic(topic))
            logging.info(f"选题文本版本已保存到: {txt_path}")
        except Exception:
            logging.exception(f"Error saving topic text file to {txt_path}:")

        prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write("--- SYSTEM PROMPT ---\n")
                f.write(f"{system_prompt or ''}\n\n")
                f.write("--- USER PROMPT ---\n")
                f.write(f"{user_prompt or ''}\n")
            logging.info(f"Topic prompts saved successfully to: {prompt_path}")
        except Exception:
            logging.exception(f"Error saving topic prompts file to {prompt_path}:")

    # --------------------------------------------------------- content variant

    @classmethod
    def _looks_like_base64(cls, value) -> bool:
        """Heuristically decide whether *value* is base64-encoded UTF-8 text.

        True for a non-empty ``str`` whose length is a multiple of 4, that uses
        only base64 alphabet characters and that strictly decodes to valid
        UTF-8. Short ASCII words (e.g. ``"AAAA"``) can pass by accident, so
        this is a heuristic, not proof.
        """
        if not isinstance(value, str) or not value or len(value) % 4:
            return False
        if not set(value) <= cls._BASE64_CHARS:
            return False
        try:
            base64.b64decode(value, validate=True).decode("utf-8")
        except ValueError:  # binascii.Error and UnicodeDecodeError subclass ValueError
            return False
        return True

    def _to_base64_pair(self, value):
        """Return ``(base64_text, plain_text)`` for one content field value.

        A value that already looks like base64 is kept verbatim and decoded
        for display; any other value is converted with ``str`` and encoded as
        UTF-8 base64.
        """
        if self._looks_like_base64(value):
            return value, base64.b64decode(value, validate=True).decode("utf-8")
        text = value if isinstance(value, str) else str(value)
        return base64.b64encode(text.encode("utf-8")).decode("ascii"), text

    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Save one content variant into its ``<topic_index>_<variant_index>`` directory.

        Writes:

        * ``article.json`` -- the ``judged``/``judge_success``/``error`` flags
          plus ``title_base64``, ``content_base64``, ``tags_base64`` and
          ``analysis_base64`` for the fields that are present;
        * ``article.txt`` -- the decoded article text; a rejected article is
          prefixed with ``审核失败`` and, for any article that was not
          approved, the review analysis is appended;
        * ``debug_content.txt`` -- previews of the input and output fields;
        * ``tweet_prompt.txt`` -- *prompt_data* verbatim.

        Each field is encoded independently, so a failure on one field is
        logged and does not affect the others.
        """
        variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)
        # The input is only read, never mutated, so a shallow copy is enough
        # (and it tolerates content_data being None).
        input_data = dict(content_data or {})

        output_data = {
            "judged": input_data.get("judged", False),
            "judge_success": input_data.get("judge_success", False),
            "error": input_data.get("error", False),
        }

        raw_fields = {
            "title": input_data.get("title"),
            "content": input_data.get("content"),
            # "tag" / "judge_analysis" are accepted as fallback key names.
            "tags": input_data.get("tags", input_data.get("tag", "")),
            "analysis": input_data.get("analysis", input_data.get("judge_analysis", "")),
        }
        # Plain-text versions used by article.txt and debug_content.txt; the
        # defaults remain when a field is missing or cannot be encoded.
        display = {"title": "未找到标题", "content": "未找到内容", "tags": "", "analysis": ""}
        all_ok = True
        for name, raw in raw_fields.items():
            # title/content are stored even when empty; tags/analysis only when non-empty.
            if raw is None or (name in ("tags", "analysis") and not raw):
                continue
            try:
                encoded, text = self._to_base64_pair(raw)
            except Exception as e:
                all_ok = False
                logging.error(f"处理内容或Base64编码时出错 ({name}): {e}")
                continue
            output_data[f"{name}_base64"] = encoded
            display[name] = text
        if all_ok:
            logging.info("成功处理内容并添加Base64编码")

        content_path = os.path.join(variant_dir, "article.json")
        try:
            with open(content_path, "w", encoding="utf-8") as f:
                json.dump(output_data, f, ensure_ascii=False, indent=4)
            logging.info(f"Content JSON saved to: {content_path}")
        except Exception as e:
            logging.exception(f"Failed to save content JSON to {content_path}: {e}")

        txt_path = os.path.join(variant_dir, "article.txt")
        is_judged = output_data["judged"]
        is_judge_success = output_data["judge_success"]
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                if is_judged and not is_judge_success:
                    # Reviewed and rejected: flag it at the top of the file.
                    f.write("审核失败\n\n")
                f.write(f"{display['title']}\n\n")
                f.write(display["content"])
                if display["tags"]:
                    f.write(f"\n\n{display['tags']}")
                # The review analysis is only included when the article was
                # not approved (unreviewed or rejected).
                if not (is_judged and is_judge_success) and display["analysis"]:
                    f.write(f"\n\n=== 审核分析 ===\n{display['analysis']}")
            logging.info(f"Article text saved to: {txt_path}")
        except Exception as e:
            logging.error(f"Failed to save article.txt: {e}")

        debug_path = os.path.join(variant_dir, "debug_content.txt")
        try:
            with open(debug_path, "w", encoding="utf-8") as f:
                f.write("处理前内容信息:\n")
                # str() so non-string inputs cannot break the preview slicing.
                f.write(f"标题: {str(input_data.get('title', '未提供'))[:200]}...\n\n")
                f.write(f"内容: {str(input_data.get('content', '未提供'))[:200]}...\n\n")
                f.write(f"标签: {input_data.get('tags', input_data.get('tag', '未提供'))}\n\n")
                f.write(f"审核状态: judged={input_data.get('judged', False)}, judge_success={input_data.get('judge_success', False)}\n\n")
                f.write("处理后JSON输出字段:\n")
                for key, value in output_data.items():
                    preview = value[:100] + "..." if isinstance(value, str) and len(value) > 100 else value
                    f.write(f"{key}: {preview}\n")
                f.write("\n解码后文本内容:\n")
                f.write(f"标题: {display['title']}\n\n")
                f.write(f"内容: {display['content'][:200]}...\n")
            logging.info(f"调试内容已保存到: {debug_path}")
        except Exception as debug_err:
            logging.error(f"保存调试内容失败: {debug_err}")

        prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write(prompt_data)
            logging.info(f"Content prompt saved to: {prompt_path}")
        except Exception as e:
            logging.error(f"Failed to save content prompt to {prompt_path}: {e}")

    # ---------------------------------------------------------- text helpers

    def _ultra_safe_clean(self, text):
        """Return only the printable-ASCII characters (32-126) of *text*; non-strings give ``""``."""
        if not isinstance(text, str):
            return ""
        return ''.join(c for c in text if 32 <= ord(c) <= 126)

    def _preprocess_for_json(self, text):
        """Replace real LF/CR characters with the two-character escape text ``\\n``/``\\r``.

        Non-string values are returned unchanged. ``json.dump`` already escapes
        newlines, so the saved JSON contains a literal backslash followed by
        ``n`` rather than a newline; consumers presumably rely on this format.
        """
        if not isinstance(text, str):
            return text
        return text.replace('\n', '\\n').replace('\r', '\\r')

    # ---------------------------------------------------------- poster configs

    def _normalize_poster_config(self, config: dict) -> dict:
        """Reduce one poster config dict to ``index``, ``main_title`` and ``texts``."""
        return {
            "index": config.get("index", 0),
            "main_title": self._preprocess_for_json(config.get("main_title", "")),
            # A missing or null "texts" becomes an empty list.
            "texts": [self._preprocess_for_json(text) for text in config.get("texts") or []],
        }

    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Save normalised poster configs to ``<topic_index>_1/topic_<topic_index>_poster_configs.json``.

        Accepts a list of config dicts or a single config dict. Non-dict list
        entries are skipped with a warning; any other ``config_data`` type is
        saved as an empty list (also with a warning). Errors are logged.
        """
        try:
            # Poster configs are per topic and always stored with variant 1.
            variant_dir = self._get_variant_dir(run_id, topic_index, 1)
            if isinstance(config_data, list):
                raw_configs = config_data
            elif isinstance(config_data, dict):
                raw_configs = [config_data]
            else:
                logging.warning(f"Unsupported poster config type {type(config_data).__name__}; saving an empty list.")
                raw_configs = []

            processed_configs = []
            for config in raw_configs:
                if not isinstance(config, dict):
                    logging.warning(f"Skipping non-dict poster config entry: {config!r}")
                    continue
                processed_configs.append(self._normalize_poster_config(config))

            config_file_path = os.path.join(variant_dir, f"topic_{topic_index}_poster_configs.json")
            with open(config_file_path, 'w', encoding='utf-8') as f:
                json.dump(processed_configs, f, ensure_ascii=False, indent=4)
            logging.info(f"Successfully saved poster configs to {config_file_path}")
        except Exception as e:
            logging.exception(f"Error saving poster configs: {e}")

    # ------------------------------------------------------------------ images

    def _save_image_metadata(self, image_path: str, image_type: str, metadata: dict):
        """Write *metadata* as ``<image stem>_metadata.json`` next to *image_path*; log failures."""
        stem = os.path.splitext(os.path.basename(image_path))[0]
        metadata_path = os.path.join(os.path.dirname(image_path), f"{stem}_metadata.json")
        try:
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"保存{image_type}元数据到: {metadata_path}")
        except Exception as me:
            logging.exception(f"无法保存{image_type}元数据到{metadata_path}: {me}")

    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
        """Save a generated image (and optional metadata JSON) and return its path.

        ``note`` and ``additional`` images go to ``<variant>/image/`` with the
        image type prefixed to the file name unless it already starts with it;
        ``collage`` goes to ``<variant>/collage_img/``, ``poster`` to
        ``<variant>/poster/`` and unknown types to the variant directory.

        Args:
            image_data: An object with a ``save(path)`` method (e.g. a PIL
                image) or raw ``bytes``/``bytearray``, written unchanged.
            metadata: Optional dict saved as ``<image stem>_metadata.json``.

        Returns:
            The target path, or ``None`` when *image_data* is empty. The path
            is returned even if writing failed (the failure is logged).
        """
        if not image_data:
            logging.warning(f"传入的{image_type}图像数据为空 (Topic {topic_index}, Variant {variant_index})。跳过保存。")
            return None

        variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
        if image_type in ('note', 'additional'):
            target_dir = os.path.join(variant_dir, "image")
            filename = output_filename if output_filename.startswith(image_type) else f"{image_type}_{output_filename}"
        else:
            subdir = self._IMAGE_SUBDIRS.get(image_type)
            if subdir is None:
                logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。")
            target_dir = os.path.join(variant_dir, subdir) if subdir else variant_dir
            filename = output_filename
        os.makedirs(target_dir, exist_ok=True)
        save_path = os.path.join(target_dir, filename)

        try:
            if isinstance(image_data, (bytes, bytearray)):
                # Already-encoded image bytes: write them as-is.
                with open(save_path, "wb") as f:
                    f.write(image_data)
            else:
                image_data.save(save_path)
            logging.info(f"保存{image_type}图像到: {save_path}")
            if metadata:
                self._save_image_metadata(save_path, image_type, metadata)
        except Exception as e:
            logging.exception(f"无法保存{image_type}图像到{save_path}: {e}")
        return save_path

    def finalize(self, run_id: str):
        """Log the end of the run; every file is written eagerly, so nothing else is needed."""
        logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")

    # ----------------------------------------------------------- sanitising

    def _sanitize_content_for_json(self, data):
        """Recursively clean *data* so it can be serialised to JSON and written as UTF-8.

        * dict: ``"error"`` and ``"raw_result"`` entries are dropped, values cleaned;
        * list: every item cleaned;
        * str: literal ``\\n`` sequences become real newlines; control
          characters other than ``\\n``/``\\r``/``\\t`` (and DEL) are removed;
          printable ASCII and non-ASCII characters are kept, except lone
          surrogates, which cannot be encoded as UTF-8;
        * anything else is returned unchanged.
        """
        if isinstance(data, dict):
            # Per the original note, "error" is re-set by the caller after final
            # validation; "raw_result" is never carried over.
            return {
                key: self._sanitize_content_for_json(value)
                for key, value in data.items()
                if key not in ("error", "raw_result")
            }
        if isinstance(data, list):
            return [self._sanitize_content_for_json(item) for item in data]
        if not isinstance(data, str):
            return data

        # 1. Turn literal backslash-n sequences into real newlines.
        text = data.replace('\\n', '\n')
        # 2. Keep whitespace controls, printable ASCII and all non-ASCII
        #    characters (emoji included, kept intact) except lone surrogates.
        cleaned = ''.join(
            ch for ch in text
            if ch in '\n\r\t'
            or 32 <= ord(ch) <= 126
            or (ord(ch) > 127 and not 0xD800 <= ord(ch) <= 0xDFFF)
        )
        # 3. Verify the result serialises; otherwise fall back to pure ASCII.
        try:
            json.dumps(cleaned, ensure_ascii=False)
            return cleaned
        except Exception as e:
            logging.warning(f"字符串清理后仍无法序列化,使用保守处理: {e}")
            return ''.join(c for c in cleaned if ord(c) < 128)