# TravelContentCreator/utils/output_handler.py

import os
import simplejson as json
import logging
from abc import ABC, abstractmethod
import traceback
import base64
# 自定义JSON编码器强制处理所有可能的JSON序列化问题
class SafeJSONEncoder(json.JSONEncoder):
"""安全的JSON编码器可以处理所有类型的字符串"""
def encode(self, obj):
"""重写encode方法确保任何字符串都能被安全编码"""
if isinstance(obj, dict):
# 处理字典:递归处理每个值
return '{' + ','.join(f'"{key}":{self.encode(value)}'
for key, value in obj.items()
if key not in ["error", "raw_result"]) + '}'
elif isinstance(obj, list):
# 处理列表:递归处理每个项
return '[' + ','.join(self.encode(item) for item in obj) + ']'
elif isinstance(obj, str):
# 安全处理字符串:移除可能导致问题的字符
safe_str = ''
for char in obj:
if char in '\n\r\t' or (32 <= ord(char) <= 126):
safe_str += char
# 跳过所有其他字符
return json.JSONEncoder.encode(self, safe_str)
else:
# 其他类型:使用默认处理
return json.JSONEncoder.encode(self, obj)
def iterencode(self, obj, _one_shot=False):
"""重写iterencode方法确保能处理迭代编码"""
return self.encode(obj)
class OutputHandler(ABC):
    """Abstract base class for handling the output of the generation pipeline.

    Concrete implementations decide where artifacts go (e.g. the local file
    system); the pipeline calls these hooks as each stage completes.
    """

    @abstractmethod
    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        """Handles the results from the topic generation step.

        Args:
            run_id: Identifier of the current pipeline run.
            topics_list: Generated topic dicts for this run.
            system_prompt: System prompt used for topic generation.
            user_prompt: User prompt used for topic generation.
        """
        pass

    @abstractmethod
    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Handles the results for a single content variant.

        Args:
            run_id: Identifier of the current pipeline run.
            topic_index: Index of the topic the variant belongs to.
            variant_index: Index of the variant within the topic.
            content_data: Generated content fields for the variant.
            prompt_data: Prompt text used to generate the variant.
        """
        pass

    @abstractmethod
    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Handles the poster configuration generated for a topic."""
        pass

    @abstractmethod
    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
        """Handles a generated image (collage or final poster).

        Args:
            image_type: Image category; at minimum 'collage' or 'poster'
                (the file-system implementation in this module also accepts
                'note' and 'additional').
            image_data: The image data (e.g., PIL Image object or bytes).
            output_filename: The desired filename for the output (e.g., 'poster.jpg').
            metadata: Optional dictionary with additional metadata about the image (e.g., used image files).
        """
        pass

    @abstractmethod
    def finalize(self, run_id: str):
        """Perform any final actions for the run (e.g., close files, upload manifests)."""
        pass
class FileSystemOutputHandler(OutputHandler):
    """Handles output by saving results to the local file system.

    Directory layout:
        <base_output_dir>/<run_id>/                      -- per-run root
        <base_output_dir>/<run_id>/<topic>_<variant>/    -- per-variant artifacts
    """

    def __init__(self, base_output_dir: str = "result"):
        # Root directory under which every run's artifacts are stored.
        self.base_output_dir = base_output_dir
        logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")

    def _get_run_dir(self, run_id: str) -> str:
        """Gets the specific directory for a run, creating it if necessary."""
        run_dir = os.path.join(self.base_output_dir, run_id)
        os.makedirs(run_dir, exist_ok=True)
        return run_dir

    def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
        """Gets the directory for a variant, optionally within a subdirectory
        (e.g. 'poster'), creating it if necessary."""
        run_dir = self._get_run_dir(run_id)
        variant_base_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
        target_dir = variant_base_dir
        if subdir:
            target_dir = os.path.join(variant_base_dir, subdir)
        os.makedirs(target_dir, exist_ok=True)
        return target_dir

    @staticmethod
    def _b64(text: str) -> str:
        """UTF-8 encode *text* and return its Base64 representation as ASCII."""
        return base64.b64encode(text.encode('utf-8')).decode('ascii')

    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        """Save the topic list (JSON + readable .txt) and the prompts used."""
        run_dir = self._get_run_dir(run_id)
        # Save topics list as JSON.
        topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
        try:
            with open(topics_path, "w", encoding="utf-8") as f:
                # Plain (simple)json dump — no custom encoder here.
                json.dump(topics_list, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Topics list saved successfully to: {topics_path}")
            # Also emit a human-readable .txt rendering of the topics.
            txt_path = os.path.join(run_dir, f"tweet_topic_{run_id}.txt")
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(f"# 选题列表 (run_id: {run_id})\n\n")
                for topic in topics_list:
                    f.write(f"## 选题 {topic.get('index', 'N/A')}\n")
                    f.write(f"- 日期: {topic.get('date', 'N/A')}\n")
                    f.write(f"- 对象: {topic.get('object', 'N/A')}\n")
                    f.write(f"- 产品: {topic.get('product', 'N/A')}\n")
                    f.write(f"- 产品策略: {topic.get('product_logic', 'N/A')}\n")
                    f.write(f"- 风格: {topic.get('style', 'N/A')}\n")
                    f.write(f"- 风格策略: {topic.get('style_logic', 'N/A')}\n")
                    f.write(f"- 目标受众: {topic.get('target_audience', 'N/A')}\n")
                    f.write(f"- 受众策略: {topic.get('target_audience_logic', 'N/A')}\n")
                    f.write(f"- 逻辑: {topic.get('logic', 'N/A')}\n\n")
            logging.info(f"选题文本版本已保存到: {txt_path}")
        except Exception:
            logging.exception(f"Error saving topic JSON file to {topics_path}:")
        # Save the prompts used for topic generation.
        prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write("--- SYSTEM PROMPT ---\n")
                f.write(system_prompt + "\n\n")
                f.write("--- USER PROMPT ---\n")
                f.write(user_prompt + "\n")
            logging.info(f"Topic prompts saved successfully to: {prompt_path}")
        except Exception:
            logging.exception(f"Error saving topic prompts file to {prompt_path}:")

    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Save all artifacts for one content variant.

        Writes article.json (metadata + Base64-encoded content fields only),
        article.txt (human-readable, original line breaks preserved),
        debug_content.txt (raw values + processed summary) and tweet_prompt.txt.
        """
        variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)

        # Work on a deep copy so the caller's dict is never mutated.
        import copy
        input_data = copy.deepcopy(content_data)

        # Normalize onto the "tags" field so "tag" and "tags" never diverge.
        if "tag" in input_data and "tags" not in input_data:
            # Only "tag" exists: copy it over to "tags".
            input_data["tags"] = input_data["tag"]
        elif "tag" in input_data and "tags" in input_data:
            # Both exist: keep "tags" untouched.
            pass

        # Keep the field set consistent even when judging was not enabled.
        if not input_data.get("judged", False):
            input_data["judged"] = False
        # Backfill original_* fields with the current values when absent.
        if "title" in input_data and "original_title" not in input_data:
            input_data["original_title"] = input_data["title"]
        if "content" in input_data and "original_content" not in input_data:
            input_data["original_content"] = input_data["content"]
        if "tags" in input_data and "original_tags" not in input_data:
            input_data["original_tags"] = input_data["tags"]

        # Keep plain-text copies for article.txt generation and debugging.
        original_title = input_data.get("title", "")
        original_content = input_data.get("content", "")
        original_tags = input_data.get("tags", "")
        original_judge_analysis = input_data.get("judge_analysis", "")

        # Output object: metadata plus Base64-encoded content fields only.
        output_data = {
            "judged": input_data.get("judged", False),
            "judge_success": input_data.get("judge_success", False)
        }

        try:
            # Base64-encode every non-empty content field: title/content/tags,
            # their original_* counterparts, and the judge analysis.
            for field in ("title", "content", "tags",
                          "original_title", "original_content",
                          "original_tags", "judge_analysis"):
                value = input_data.get(field)
                if value:
                    # NOTE(review): assumes the value is a str; a non-str value
                    # (e.g. a tags list) raises and is logged below — confirm
                    # upstream field types.
                    output_data[f"{field}_base64"] = self._b64(value)
            logging.info("成功添加Base64编码内容")
        except Exception as e:
            logging.error(f"Base64编码内容时出错: {e}")

        # Carry the error marker through if present.
        if "error" in input_data:
            output_data["error"] = input_data["error"]

        # Save the unified article.json (Base64 payload + metadata only).
        content_path = os.path.join(variant_dir, "article.json")
        try:
            with open(content_path, "w", encoding="utf-8") as f:
                json.dump(output_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Content JSON saved to: {content_path}")
        except Exception as e:
            logging.exception(f"Failed to save content JSON to {content_path}: {e}")

        # article.txt: human-readable version, original line breaks preserved.
        txt_path = os.path.join(variant_dir, "article.txt")
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                if original_title:
                    f.write(f"{original_title}\n\n")
                if original_content:
                    f.write(original_content)
                if original_tags:
                    f.write(f"\n\n{original_tags}")
                if original_judge_analysis:
                    f.write(f"\n\n审核分析:\n{original_judge_analysis}")
            logging.info(f"Article text saved to: {txt_path}")
        except Exception as e:
            logging.error(f"Failed to save article.txt: {e}")

        # debug_content.txt: raw inputs plus a summary of the processed output.
        debug_path = os.path.join(variant_dir, "debug_content.txt")
        try:
            with open(debug_path, "w", encoding="utf-8") as f:
                f.write(f"原始标题: {original_title}\n\n")
                f.write(f"原始内容: {original_content}\n\n")
                if original_tags:
                    f.write(f"原始标签: {original_tags}\n\n")
                if original_judge_analysis:
                    f.write(f"审核分析: {original_judge_analysis}\n\n")
                f.write("---处理后---\n\n")
                for key, value in output_data.items():
                    if isinstance(value, str):
                        f.write(f"{key}: (length: {len(value)})\n")
                        f.write(f"{repr(value[:200])}...\n\n")
                    else:
                        f.write(f"{key}: {type(value)}\n")
            logging.info(f"调试内容已保存到: {debug_path}")
        except Exception as debug_err:
            logging.error(f"保存调试内容失败: {debug_err}")

        # Save the prompt used to generate this variant.
        prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write(prompt_data)
            logging.info(f"Content prompt saved to: {prompt_path}")
        except Exception as e:
            logging.error(f"Failed to save content prompt to {prompt_path}: {e}")

    def _ultra_safe_clean(self, text) -> str:
        """Strictest cleaning: keep only printable ASCII (32-126).

        Returns "" for non-str input. Guarantees the result is serializable.
        """
        if not isinstance(text, str):
            return ""
        return ''.join(c for c in text if 32 <= ord(c) <= 126)

    def _preprocess_for_json(self, text):
        """Replace literal newline/CR characters with the two-character
        sequences ``\\n`` / ``\\r`` so they survive JSON round-trips as
        visible escapes. Non-str input is returned unchanged."""
        if not isinstance(text, str):
            return text
        return text.replace('\n', '\\n').replace('\r', '\\r')

    def _process_poster_config(self, config: dict) -> dict:
        """Normalize one poster config: keep index, preprocess title and texts."""
        return {
            "index": config.get("index", 0),
            "main_title": self._preprocess_for_json(config.get("main_title", "")),
            "texts": [self._preprocess_for_json(t) for t in config.get("texts", [])],
        }

    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Persist poster configuration(s) for a topic.

        Output: <run>/<topic>_1/topic_<topic>_poster_configs.json, always a
        JSON list (a single dict input is wrapped; other types produce []).
        """
        try:
            # Poster configs always live under the first variant's directory.
            variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_1")
            os.makedirs(variant_dir, exist_ok=True)
            # Normalize into a list of serializable config dicts.
            if isinstance(config_data, list):
                processed_configs = [self._process_poster_config(c) for c in config_data]
            elif isinstance(config_data, dict):
                processed_configs = [self._process_poster_config(config_data)]
            else:
                processed_configs = []
            config_file_path = os.path.join(variant_dir, f"topic_{topic_index}_poster_configs.json")
            with open(config_file_path, 'w', encoding='utf-8') as f:
                # BUG FIX: was cls=self.SafeJSONEncoder — the encoder is a
                # module-level class, not an attribute, so every call raised
                # AttributeError and fell into the except branch.
                json.dump(processed_configs, f, ensure_ascii=False, indent=4, cls=SafeJSONEncoder)
            logging.info(f"Successfully saved poster configs to {config_file_path}")
        except Exception as e:
            logging.error(f"Error saving poster configs: {e}")
            traceback.print_exc()

    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
        """Save a generated image and, optionally, a sibling metadata JSON.

        'note' and 'additional' images go to <variant>/image/ with a type
        prefix on the filename; 'collage' and 'poster' keep their legacy
        sub-directories; unknown types land in the variant root.

        Returns:
            The path the image was saved to, or None when image_data is empty.
        """
        if not image_data:
            logging.warning(f"传入的{image_type}图像数据为空 (Topic {topic_index}, Variant {variant_index})。跳过保存。")
            return

        if image_type in ('note', 'additional'):
            # Note images and extra illustrations: <run>/<i>_<j>/image/.
            run_dir = self._get_run_dir(run_id)
            variant_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
            image_dir = os.path.join(variant_dir, "image")
            os.makedirs(image_dir, exist_ok=True)
            # Prefix the filename with the type unless it already starts with it.
            prefixed_filename = f"{image_type}_{output_filename}" if not output_filename.startswith(image_type) else output_filename
            save_path = os.path.join(image_dir, prefixed_filename)
        else:
            # Legacy path layout for the remaining image types.
            if image_type == 'collage':
                subdir = 'collage_img'
            elif image_type == 'poster':
                subdir = 'poster'
            else:
                logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。")
                subdir = None
            variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
            target_dir = os.path.join(variant_dir, subdir) if subdir else variant_dir
            os.makedirs(target_dir, exist_ok=True)
            save_path = os.path.join(target_dir, output_filename)

        try:
            # NOTE(review): assumes image_data has a PIL-style .save(path);
            # raw bytes would fail here — confirm caller contract.
            image_data.save(save_path)
            logging.info(f"保存{image_type}图像到: {save_path}")
            if metadata:
                # The metadata JSON sits next to the image it describes.
                metadata_filename = os.path.splitext(os.path.basename(save_path))[0] + "_metadata.json"
                metadata_path = os.path.join(os.path.dirname(save_path), metadata_filename)
                try:
                    with open(metadata_path, 'w', encoding='utf-8') as f:
                        json.dump(metadata, f, ensure_ascii=False, indent=4, ignore_nan=True)
                    logging.info(f"保存{image_type}元数据到: {metadata_path}")
                except Exception as me:
                    logging.error(f"无法保存{image_type}元数据到{metadata_path}: {me}")
                    traceback.print_exc()
        except Exception as e:
            logging.exception(f"无法保存{image_type}图像到{save_path}: {e}")
            traceback.print_exc()
        return save_path

    def finalize(self, run_id: str):
        """File-system runs need no teardown; log and return."""
        logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")

    def _sanitize_content_for_json(self, data):
        """Deep-clean *data* so it can be safely serialized to JSON.

        Args:
            data: dict, list, str or scalar to sanitize recursively.

        Returns:
            The sanitized structure; "error" and "raw_result" dict keys are
            dropped at every level, strings are reduced to safe characters.
        """
        if isinstance(data, dict):
            # Dicts: drop error markers (re-set after final validation) and
            # raw results, recurse into everything else.
            sanitized_dict = {}
            for key, value in data.items():
                if key == "error":
                    continue
                if key == "raw_result":
                    continue
                sanitized_dict[key] = self._sanitize_content_for_json(value)
            return sanitized_dict
        elif isinstance(data, list):
            # Lists: sanitize each element.
            return [self._sanitize_content_for_json(item) for item in data]
        elif isinstance(data, str):
            # Strings (the interesting case):
            # 1. Turn literal "\n" sequences back into real newlines.
            if r'\n' in data:
                data = data.replace(r'\n', '\n')
            # 2. Keep only safe characters:
            #    - printable ASCII 32-126,
            #    - newline / CR / tab,
            #    - any non-ASCII character as-is.
            safe_chars = []
            for char in data:
                if char in '\n\r\t' or (32 <= ord(char) <= 126):
                    safe_chars.append(char)
                elif ord(char) > 127:
                    # BUG FIX: the old round-trip through
                    # f"\\u{ord(char):04x}".encode().decode('unicode-escape')
                    # was an identity for BMP characters but corrupted code
                    # points above U+FFFF (emoji: >4 hex digits mis-decode).
                    # Keeping the character directly is equivalent for BMP
                    # and correct for astral planes.
                    safe_chars.append(char)
            cleaned = ''.join(safe_chars)
            # 3. Verify the result actually serializes; fall back to ASCII-only.
            try:
                json.dumps(cleaned, ensure_ascii=False)
                return cleaned
            except Exception as e:
                logging.warning(f"字符串清理后仍无法序列化,使用保守处理: {e}")
                return ''.join(c for c in cleaned if ord(c) < 128)
        else:
            # Scalars (numbers, bools, None): return unchanged.
            return data