# TravelContentCreator/utils/output_handler.py
#
# (Source-viewer header: 466 lines, 22 KiB, Python - Raw | Blame | History.)
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e. Unicode
# characters that might be confused with other characters. If this is
# intentional, the warning can safely be ignored; the viewer's Escape button
# reveals them.

import os
import simplejson as json
import logging
from abc import ABC, abstractmethod
import traceback
import base64
# Custom JSON encoder that aggressively sanitizes strings to avoid serialization problems.
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that strips potentially problematic characters from string values.

    Behaviour:

    * dict entries whose key is "error" or "raw_result" are dropped at every
      nesting level;
    * string values keep only newline, carriage return, tab and printable ASCII
      (code points 32-126); every other character - including CJK text and
      emoji - is removed;
    * dict keys are converted with ``str()`` and JSON-escaped, but not stripped;
    * lists and tuples become JSON arrays;
    * any other value (numbers, booleans, None, ...) is encoded by the base
      ``JSONEncoder`` logic, so unsupported types raise ``TypeError`` via
      ``default()``.

    Output is always compact: ``indent``/``separators`` given to the
    constructor are not applied to containers.
    """

    # Keys removed from every dict before encoding.
    _SKIPPED_KEYS = ("error", "raw_result")

    @staticmethod
    def _clean_string(text):
        """Return ``text`` restricted to newline/CR/tab and printable ASCII."""
        return ''.join(ch for ch in text if ch in '\n\r\t' or 32 <= ord(ch) <= 126)

    def encode(self, obj):
        """Return the JSON text for ``obj`` using the sanitizing rules above."""
        if isinstance(obj, dict):
            # Keys are escaped with the base string encoder so quotes/backslashes
            # in a key cannot break the surrounding JSON.
            members = (
                f'{json.JSONEncoder.encode(self, str(key))}:{self.encode(value)}'
                for key, value in obj.items()
                if key not in self._SKIPPED_KEYS
            )
            return '{' + ','.join(members) + '}'
        if isinstance(obj, (list, tuple)):
            return '[' + ','.join(self.encode(item) for item in obj) + ']'
        if isinstance(obj, str):
            # The base encode() handles str directly without calling iterencode().
            return json.JSONEncoder.encode(self, self._clean_string(obj))
        # Scalars and other objects: use the *base* iterencode explicitly. Calling
        # JSONEncoder.encode here would dispatch back to our overridden iterencode
        # and recurse forever.
        return ''.join(json.JSONEncoder.iterencode(self, obj))

    def iterencode(self, obj, _one_shot=False):
        """Yield the encoded document as a single chunk (used by ``dump``)."""
        return iter((self.encode(obj),))
class OutputHandler(ABC):
"""Abstract base class for handling the output of the generation pipeline."""
@abstractmethod
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
"""Handles the results from the topic generation step."""
pass
@abstractmethod
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Handles the results for a single content variant."""
pass
@abstractmethod
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Handles the poster configuration generated for a topic."""
pass
@abstractmethod
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
"""Handles a generated image (collage or final poster).
Args:
image_type: Either 'collage' or 'poster'.
image_data: The image data (e.g., PIL Image object or bytes).
output_filename: The desired filename for the output (e.g., 'poster.jpg').
metadata: Optional dictionary with additional metadata about the image (e.g., used image files).
"""
pass
@abstractmethod
def finalize(self, run_id: str):
"""Perform any final actions for the run (e.g., close files, upload manifests)."""
pass
class FileSystemOutputHandler(OutputHandler):
    """Handles output by saving results to the local file system.

    Files are written below ``<base_output_dir>/<run_id>/``:

    * ``tweet_topic_<run_id>.json`` / ``.txt`` and ``topic_prompt_<run_id>.txt``;
    * per variant, in ``<topic_index>_<variant_index>/``: ``article.json``,
      ``article.txt``, ``debug_content.txt``, ``tweet_prompt.txt`` and images
      (under ``image/``, ``collage_img/`` or ``poster/``);
    * per topic, ``<topic_index>_1/topic_<topic_index>_poster_configs.json``.

    Directories are created on demand. File writes are best-effort: failures
    are logged instead of being raised to the caller.
    """

    # (label, key) pairs written for each topic in tweet_topic_<run_id>.txt, in order.
    _TOPIC_TXT_FIELDS = (
        ("日期", "date"),
        ("对象", "object"),
        ("产品", "product"),
        ("产品策略", "product_logic"),
        ("风格", "style"),
        ("风格策略", "style_logic"),
        ("目标受众", "target_audience"),
        ("受众策略", "target_audience_logic"),
        ("逻辑", "logic"),
    )

    # Text fields mirrored into article.json as "<field>_base64", in output order.
    _BASE64_FIELDS = (
        "title",
        "content",
        "tags",
        "original_title",
        "original_content",
        "original_tags",
        "judge_analysis",
    )

    # Variant sub-directory for image types that are not stored under "image/".
    _IMAGE_SUBDIRS = {"collage": "collage_img", "poster": "poster"}

    # Keys dropped by _sanitize_content_for_json at every nesting level.
    _SANITIZE_SKIPPED_KEYS = frozenset({"error", "raw_result"})

    def __init__(self, base_output_dir: str = "result"):
        """Create a handler rooted at ``base_output_dir`` (no directories are created yet)."""
        self.base_output_dir = base_output_dir
        logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")

    def _get_run_dir(self, run_id: str) -> str:
        """Return ``<base_output_dir>/<run_id>``, creating it if necessary."""
        run_dir = os.path.join(self.base_output_dir, run_id)
        os.makedirs(run_dir, exist_ok=True)
        return run_dir

    def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
        """Return the variant directory (optionally a sub-directory of it), creating it if necessary."""
        target_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
        if subdir:
            target_dir = os.path.join(target_dir, subdir)
        os.makedirs(target_dir, exist_ok=True)
        return target_dir

    # ------------------------------------------------------------------ topics

    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        """Save the topic list (JSON and a readable text version) plus the prompts.

        Args:
            run_id: Run identifier; selects ``<base_output_dir>/<run_id>``.
            topics_list: Topic dicts. The text export reads ``index`` and the keys
                listed in ``_TOPIC_TXT_FIELDS``, printing ``N/A`` for missing ones.
            system_prompt: Written to ``topic_prompt_<run_id>.txt``.
            user_prompt: Written to the same prompt file.
        """
        run_dir = self._get_run_dir(run_id)

        topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
        try:
            with open(topics_path, "w", encoding="utf-8") as f:
                # Plain json with ensure_ascii=False keeps non-ASCII text verbatim
                # (SafeJSONEncoder would strip it).
                json.dump(topics_list, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Topics list saved successfully to: {topics_path}")
        except Exception:
            logging.exception(f"Error saving topic JSON file to {topics_path}:")

        # Written independently so that a JSON failure does not suppress the text export.
        txt_path = os.path.join(run_dir, f"tweet_topic_{run_id}.txt")
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(f"# 选题列表 (run_id: {run_id})\n\n")
                for topic in topics_list:
                    f.write(f"## 选题 {topic.get('index', 'N/A')}\n")
                    for label, key in self._TOPIC_TXT_FIELDS:
                        f.write(f"- {label}: {topic.get(key, 'N/A')}\n")
                    f.write("\n")
            logging.info(f"选题文本版本已保存到: {txt_path}")
        except Exception:
            logging.exception(f"Error saving topic text file to {txt_path}:")

        prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write("--- SYSTEM PROMPT ---\n")
                f.write(system_prompt + "\n\n")
                f.write("--- USER PROMPT ---\n")
                f.write(user_prompt + "\n")
            logging.info(f"Topic prompts saved successfully to: {prompt_path}")
        except Exception:
            logging.exception(f"Error saving topic prompts file to {prompt_path}:")

    # -------------------------------------------------------- content variants

    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Save article.json, article.txt, debug_content.txt and tweet_prompt.txt for one variant.

        ``content_data`` is deep-copied and never modified. article.json holds only
        the review flags, base64 copies of the text fields and an optional
        ``error`` entry; the raw text goes to article.txt.
        """
        variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)
        input_data = self._normalize_content_data(content_data)

        # Raw values for the human-readable and debug files.
        original_title = input_data.get("title", "")
        original_content = input_data.get("content", "")
        original_tags = input_data.get("tags", "")
        original_judge_analysis = input_data.get("judge_analysis", "")

        output_data = self._build_article_payload(input_data)

        content_path = os.path.join(variant_dir, "article.json")
        try:
            with open(content_path, "w", encoding="utf-8") as f:
                json.dump(output_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Content JSON saved to: {content_path}")
        except Exception as e:
            logging.exception(f"Failed to save content JSON to {content_path}: {e}")

        txt_path = os.path.join(variant_dir, "article.txt")
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                if original_title:
                    f.write(f"{original_title}\n\n")
                if original_content:
                    # Written as-is so the original line breaks are preserved.
                    f.write(original_content)
                if original_tags:
                    f.write(f"\n\n{original_tags}")
                if original_judge_analysis:
                    f.write(f"\n\n审核分析:\n{original_judge_analysis}")
            logging.info(f"Article text saved to: {txt_path}")
        except Exception as e:
            logging.error(f"Failed to save article.txt: {e}")

        self._write_debug_content(
            os.path.join(variant_dir, "debug_content.txt"),
            original_title,
            original_content,
            original_tags,
            original_judge_analysis,
            output_data,
        )

        prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write(prompt_data)
            logging.info(f"Content prompt saved to: {prompt_path}")
        except Exception as e:
            logging.error(f"Failed to save content prompt to {prompt_path}: {e}")

    @staticmethod
    def _normalize_content_data(content_data):
        """Return a deep copy of ``content_data`` with tags/judged/original_* fields filled in."""
        import copy

        # None is treated as an empty record instead of crashing on the `in` checks below.
        data = copy.deepcopy(content_data) if content_data is not None else {}
        # Standardize on "tags": copy the legacy "tag" value only when "tags" is absent.
        if "tag" in data and "tags" not in data:
            data["tags"] = data["tag"]
        # A missing or falsy "judged" value is normalized to False.
        if not data.get("judged", False):
            data["judged"] = False
        # Seed original_* with the current values unless the caller supplied them.
        for field in ("title", "content", "tags"):
            original_key = f"original_{field}"
            if field in data and original_key not in data:
                data[original_key] = data[field]
        return data

    def _build_article_payload(self, input_data):
        """Return the article.json payload: review flags, ``<field>_base64`` entries and optional ``error``."""
        payload = {
            "judged": input_data.get("judged", False),
            "judge_success": input_data.get("judge_success", False),
        }
        all_encoded = True
        # Each field is encoded independently so one bad value (e.g. non-string
        # tags) does not drop the remaining fields.
        for field in self._BASE64_FIELDS:
            value = input_data.get(field)
            if not value:
                continue
            if not isinstance(value, str):
                logging.warning(f"Skipping base64 encoding of non-string field '{field}' ({type(value).__name__})")
                all_encoded = False
                continue
            try:
                payload[f"{field}_base64"] = base64.b64encode(value.encode("utf-8")).decode("ascii")
            except UnicodeEncodeError as e:
                # e.g. lone surrogates, which UTF-8 cannot represent.
                logging.error(f"Base64编码内容时出错: {e}")
                all_encoded = False
        if all_encoded:
            logging.info("成功添加Base64编码内容")
        if "error" in input_data:
            payload["error"] = input_data["error"]
        return payload

    @staticmethod
    def _write_debug_content(debug_path, title, content, tags, judge_analysis, output_data):
        """Write the raw field values plus a per-key summary of ``output_data``; errors are only logged."""
        try:
            with open(debug_path, "w", encoding="utf-8") as f:
                f.write(f"原始标题: {title}\n\n")
                f.write(f"原始内容: {content}\n\n")
                if tags:
                    f.write(f"原始标签: {tags}\n\n")
                if judge_analysis:
                    f.write(f"审核分析: {judge_analysis}\n\n")
                f.write("---处理后---\n\n")
                for key, value in output_data.items():
                    if isinstance(value, str):
                        f.write(f"{key}: (length: {len(value)})\n")
                        f.write(f"{repr(value[:200])}...\n\n")
                    else:
                        f.write(f"{key}: {type(value)}\n")
            logging.info(f"调试内容已保存到: {debug_path}")
        except Exception as debug_err:
            logging.error(f"保存调试内容失败: {debug_err}")

    # ---------------------------------------------------------- string helpers

    def _ultra_safe_clean(self, text):
        """Return ``text`` restricted to printable ASCII (32-126); non-str input yields ''.

        Note that newlines and tabs are removed as well.
        """
        if not isinstance(text, str):
            return ""
        return ''.join(c for c in text if 32 <= ord(c) <= 126)

    def _preprocess_for_json(self, text):
        """Replace real newline / carriage-return characters with two-character backslash sequences.

        Non-str input is returned unchanged.
        NOTE(review): json.dump escapes the backslash again, so a reader of the
        saved file gets a literal backslash-n, not a newline - presumably what
        the poster consumer expects; confirm.
        """
        if not isinstance(text, str):
            return text
        return text.replace('\n', '\\n').replace('\r', '\\r')

    # ---------------------------------------------------------- poster configs

    def _process_poster_config(self, config):
        """Return the index/main_title/texts of one config dict, passing text through _preprocess_for_json."""
        return {
            "index": config.get("index", 0),
            "main_title": self._preprocess_for_json(config.get("main_title", "")),
            # `or []` also covers an explicit None value.
            "texts": [self._preprocess_for_json(text) for text in config.get("texts") or []],
        }

    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Save poster configs to ``<topic_index>_1/topic_<topic_index>_poster_configs.json``.

        Args:
            run_id: Run identifier.
            topic_index: Topic the configs belong to.
            config_data: A list of config dicts or a single config dict; any
                other value results in an empty JSON list being written.
        """
        try:
            variant_dir = self._get_variant_dir(run_id, topic_index, 1)
            if isinstance(config_data, list):
                processed_configs = [self._process_poster_config(config) for config in config_data]
            elif isinstance(config_data, dict):
                processed_configs = [self._process_poster_config(config_data)]
            else:
                processed_configs = []

            config_file_path = os.path.join(variant_dir, f"topic_{topic_index}_poster_configs.json")
            with open(config_file_path, 'w', encoding='utf-8') as f:
                # Standard json, like the other writers: SafeJSONEncoder is a
                # module-level class (self.SafeJSONEncoder raised AttributeError)
                # and it would strip every non-ASCII character from the titles.
                json.dump(processed_configs, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Successfully saved poster configs to {config_file_path}")
        except Exception as e:
            logging.exception(f"Error saving poster configs: {e}")

    # ------------------------------------------------------------------ images

    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
        """Save a generated image (and optional metadata JSON) and return its path.

        'note' and 'additional' images go to ``<variant>/image/`` with the image
        type prefixed to the file name (unless it already starts with it);
        'collage' goes to ``collage_img/``, 'poster' to ``poster/`` and unknown
        types to the variant directory itself.

        Args:
            image_data: An object with a ``save(path)`` method (e.g. a PIL Image)
                or raw encoded ``bytes``/``bytearray``, which are written verbatim.
            metadata: Optional dict saved next to the image as ``<stem>_metadata.json``.

        Returns:
            The target path (also when saving failed and was logged), or None
            when ``image_data`` is empty.
        """
        if not image_data:
            logging.warning(f"传入的{image_type}图像数据为空 (Topic {topic_index}, Variant {variant_index})。跳过保存。")
            return

        if image_type in ('note', 'additional'):
            target_dir = self._get_variant_dir(run_id, topic_index, variant_index, "image")
            filename = output_filename if output_filename.startswith(image_type) else f"{image_type}_{output_filename}"
        else:
            subdir = self._IMAGE_SUBDIRS.get(image_type)
            if subdir is None:
                logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。")
            target_dir = self._get_variant_dir(run_id, topic_index, variant_index, subdir)
            filename = output_filename
        save_path = os.path.join(target_dir, filename)

        try:
            if isinstance(image_data, (bytes, bytearray)):
                # The OutputHandler interface allows raw bytes as image data.
                with open(save_path, "wb") as f:
                    f.write(image_data)
            else:
                image_data.save(save_path)
            logging.info(f"保存{image_type}图像到: {save_path}")
            if metadata:
                self._save_image_metadata(save_path, image_type, metadata)
        except Exception as e:
            logging.exception(f"无法保存{image_type}图像到{save_path}: {e}")
        return save_path

    @staticmethod
    def _save_image_metadata(image_path, image_type, metadata):
        """Write ``metadata`` as ``<image stem>_metadata.json`` next to ``image_path``; errors are logged."""
        metadata_filename = os.path.splitext(os.path.basename(image_path))[0] + "_metadata.json"
        metadata_path = os.path.join(os.path.dirname(image_path), metadata_filename)
        try:
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"保存{image_type}元数据到: {metadata_path}")
        except Exception as me:
            logging.exception(f"无法保存{image_type}元数据到{metadata_path}: {me}")

    # ---------------------------------------------------------------- finalize

    def finalize(self, run_id: str):
        """Log completion; the file-system handler has nothing else to flush."""
        logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")

    # ------------------------------------------------------------ sanitization

    def _sanitize_content_for_json(self, data):
        """Deep-clean ``data`` so it can be serialized to JSON safely.

        * dicts: entries keyed "error" or "raw_result" are dropped, values cleaned recursively;
        * lists: items cleaned recursively;
        * strings: literal backslash-n sequences become real newlines, then only
          newline/CR/tab, printable ASCII (32-126) and non-ASCII characters
          (CJK, emoji, ...) are kept - other control characters and DEL (127)
          are removed;
        * anything else is returned unchanged.

        Args:
            data: A dict, list or scalar value.

        Returns:
            The cleaned copy of ``data`` (scalars are returned as-is).
        """
        if isinstance(data, dict):
            return {
                key: self._sanitize_content_for_json(value)
                for key, value in data.items()
                if key not in self._SANITIZE_SKIPPED_KEYS
            }
        if isinstance(data, list):
            return [self._sanitize_content_for_json(item) for item in data]
        if isinstance(data, str):
            if r'\n' in data:
                data = data.replace(r'\n', '\n')
            # Non-ASCII characters are kept as-is. The previous "\\u%04x" +
            # unicode-escape round-trip was a no-op for BMP characters but split
            # code points above U+FFFF (emoji) into a wrong character plus a digit.
            cleaned = ''.join(
                ch for ch in data
                if ch in '\n\r\t' or 32 <= ord(ch) <= 126 or ord(ch) > 127
            )
            try:
                json.dumps(cleaned, ensure_ascii=False)
                return cleaned
            except Exception as e:
                logging.warning(f"字符串清理后仍无法序列化,使用保守处理: {e}")
                # Most conservative fallback: ASCII only.
                return ''.join(c for c in cleaned if ord(c) < 128)
        return data