TravelContentCreator/utils/output_handler.py

381 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import simplejson as json
import logging
from abc import ABC, abstractmethod
import traceback
import base64
# Custom JSON encoder: sanitizes every string and drops internal bookkeeping
# keys so that pipeline output can always be serialized.
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that sanitizes strings and strips internal bookkeeping keys.

    Compared with the stock encoder:

    * String values and object keys are cleaned before encoding: ASCII
      control characters other than newline, carriage return and tab are
      removed, as are DEL (U+007F) and lone UTF-16 surrogates (which cannot
      be written to a UTF-8 file). All other characters -- including CJK
      text and emoji -- are kept.
    * Entries whose key is ``"error"`` or ``"raw_result"`` are omitted from
      every object, at any nesting depth.
    * Non-string keys are converted with ``str()`` and then escaped like any
      other JSON string.
    * Tuples are encoded as arrays.
    * The ``indent`` and ``separators`` constructor options are honoured.

    Values that are not dicts, lists, tuples or strings (numbers, booleans,
    ``None``, and anything else the base encoder supports) are encoded by
    the base class.
    """

    # Keys that are never written to the output.
    DROPPED_KEYS = frozenset({"error", "raw_result"})

    def encode(self, obj):
        """Return the JSON text for *obj* (see the class docstring for the rules)."""
        return self._encode_node(obj, 0)

    def iterencode(self, obj, _one_shot=False):
        """Return the whole encoded document as a single chunk.

        ``json.dump`` writes every chunk it gets from here; handing back one
        complete string in a list (rather than the bare string, which would
        be iterated character by character) keeps that to a single write.
        """
        return [self.encode(obj)]

    def _encode_node(self, obj, level):
        """Encode *obj*, which sits *level* containers deep (used for indentation)."""
        if isinstance(obj, dict):
            members = [
                self._encode_string(key if isinstance(key, str) else str(key))
                + self.key_separator
                + self._encode_node(value, level + 1)
                for key, value in obj.items()
                if key not in self.DROPPED_KEYS
            ]
            return self._join_container('{', '}', members, level)
        if isinstance(obj, (list, tuple)):
            items = [self._encode_node(item, level + 1) for item in obj]
            return self._join_container('[', ']', items, level)
        if isinstance(obj, str):
            return self._encode_string(obj)
        # Scalars: delegate to the *base* iterencode. Going through
        # JSONEncoder.encode instead would call the overridden iterencode,
        # which calls encode() again -> infinite recursion.
        return ''.join(json.JSONEncoder.iterencode(self, obj))

    def _encode_string(self, text):
        """Encode *text* as a JSON string literal after removing unsafe characters."""
        cleaned = ''.join(ch for ch in text if self._is_safe_char(ch))
        # The base encode() returns directly for str input (no iterencode call).
        return json.JSONEncoder.encode(self, cleaned)

    @staticmethod
    def _is_safe_char(ch):
        """Return True for characters that _encode_string keeps."""
        if ch in '\n\r\t':
            return True
        code = ord(ch)
        if code < 32 or code == 127:
            return False
        # Lone surrogates are valid in a Python str but not encodable as UTF-8.
        return not 0xD800 <= code <= 0xDFFF

    def _join_container(self, opener, closer, parts, level):
        """Join already-encoded *parts* into an array/object at nesting *level*."""
        if not parts:
            return opener + closer
        indent = self.indent
        if indent is None:
            return opener + self.item_separator.join(parts) + closer
        # The stdlib keeps an int indent as-is; simplejson pre-converts it to str.
        if not isinstance(indent, str):
            indent = ' ' * indent
        inner_break = '\n' + indent * (level + 1)
        return (opener + inner_break
                + (self.item_separator + inner_break).join(parts)
                + '\n' + indent * level + closer)
class OutputHandler(ABC):
"""Abstract base class for handling the output of the generation pipeline."""
@abstractmethod
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
"""Handles the results from the topic generation step."""
pass
@abstractmethod
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Handles the results for a single content variant."""
pass
@abstractmethod
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Handles the poster configuration generated for a topic."""
pass
@abstractmethod
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
"""Handles a generated image (collage or final poster).
Args:
image_type: Either 'collage' or 'poster'.
image_data: The image data (e.g., PIL Image object or bytes).
output_filename: The desired filename for the output (e.g., 'poster.jpg').
metadata: Optional dictionary with additional metadata about the image (e.g., used image files).
"""
pass
@abstractmethod
def finalize(self, run_id: str):
"""Perform any final actions for the run (e.g., close files, upload manifests)."""
pass
class FileSystemOutputHandler(OutputHandler):
    """Persist pipeline results under ``<base_output_dir>/<run_id>/`` on the local file system.

    Layout written by this handler:

    * ``<run_id>/tweet_topic_<run_id>.json`` and ``topic_prompt_<run_id>.txt``
      -- topic results and the prompts that produced them.
    * ``<run_id>/topic_<i>_poster_configs.json`` -- poster configurations.
    * ``<run_id>/<i>_<j>/`` -- one directory per (topic, variant) holding
      ``article.json``, ``article.txt``, ``debug_content.txt``,
      ``tweet_prompt.txt`` and the image sub-directories ``image/``,
      ``collage_img/`` and ``poster/``.

    Individual file writes log their failures instead of raising.
    """

    def __init__(self, base_output_dir: str = "result"):
        self.base_output_dir = base_output_dir
        logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")

    def _get_run_dir(self, run_id: str) -> str:
        """Return ``<base_output_dir>/<run_id>``, creating it if necessary."""
        run_dir = os.path.join(self.base_output_dir, run_id)
        os.makedirs(run_dir, exist_ok=True)
        return run_dir

    def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
        """Return ``<run_dir>/<topic_index>_<variant_index>[/<subdir>]``, creating it if necessary."""
        variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
        target_dir = os.path.join(variant_dir, subdir) if subdir else variant_dir
        os.makedirs(target_dir, exist_ok=True)
        return target_dir

    @staticmethod
    def _write_json(path: str, data) -> None:
        """Serialize *data* with SafeJSONEncoder, then write it to *path* as UTF-8.

        Encoding happens before the file is opened, so an encoding error
        neither creates nor truncates *path*. Exceptions propagate.
        """
        text = json.dumps(data, ensure_ascii=False, indent=4, ignore_nan=True, cls=SafeJSONEncoder)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        """Save the topic list as JSON and the prompts that produced it as text."""
        run_dir = self._get_run_dir(run_id)

        topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
        try:
            self._write_json(topics_path, topics_list)
            logging.info(f"Topics list saved successfully to: {topics_path}")
        except Exception:
            logging.exception(f"Error saving topic JSON file to {topics_path}:")

        prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write("--- SYSTEM PROMPT ---\n")
                f.write(system_prompt + "\n\n")
                f.write("--- USER PROMPT ---\n")
                f.write(user_prompt + "\n")
            logging.info(f"Topic prompts saved successfully to: {prompt_path}")
        except Exception:
            logging.exception(f"Error saving topic prompts file to {prompt_path}:")

    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Save one content variant into its ``<topic>_<variant>`` directory.

        Writes ``article.json`` (normalized, Base64-augmented, sanitized copy
        of *content_data*), ``article.txt`` (raw title and content),
        ``debug_content.txt`` and ``tweet_prompt.txt`` (*prompt_data*).
        *content_data* itself is deep-copied and never modified.
        """
        import copy  # only needed here

        variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)
        output_data = copy.deepcopy(content_data)
        self._normalize_variant_fields(output_data)

        # Unsanitized text, kept for article.txt and the debug dump.
        original_title = output_data.get("title", "")
        original_content = output_data.get("content", "")

        self._add_base64_fields(output_data)

        try:
            judge_success = output_data.get("judge_success", False)
            # The sanitizer also drops the "error" and "raw_result" keys.
            output_data = self._sanitize_content_for_json(output_data)
            # Restore judge_success unsanitized (defaulting to False when absent).
            output_data["judge_success"] = judge_success
            logging.info("内容已经过安全清理,可以序列化")
        except Exception as e:
            logging.error(f"内容清理过程中出错: {e}")

        content_path = os.path.join(variant_dir, "article.json")
        try:
            self._write_json(content_path, output_data)
            logging.info(f"Content JSON saved to: {content_path}")
        except Exception as e:
            logging.exception(f"Failed to save content JSON to {content_path}: {e}")

        txt_path = os.path.join(variant_dir, "article.txt")
        try:
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(f"{original_title}\n\n{original_content}")
            logging.info(f"Article text saved to: {txt_path}")
        except Exception as e:
            logging.error(f"Failed to save article.txt: {e}")

        self._write_debug_dump(variant_dir, original_title, original_content, output_data)

        prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                # Assumes prompt_data is the user prompt used for this variant.
                f.write(prompt_data + "\n")
            logging.info(f"Content prompt saved to: {prompt_path}")
        except Exception as e:
            logging.exception(f"Failed to save content prompt to {prompt_path}: {e}")

    @staticmethod
    def _normalize_variant_fields(output_data: dict) -> None:
        """Normalize the tag/judge fields of a variant dict in place.

        * ``tag`` is folded into ``tags``: its value is used only when
          ``tags`` is absent, and ``tag`` itself is always removed.
        * When the variant was not judged (``judged`` missing or falsy),
          ``judged`` is set to False and ``original_title``,
          ``original_content`` and ``judge_analysis`` are set to None so the
          field set matches judged variants.
        * ``original_tags`` defaults to ``tags`` when missing.
        """
        if "tag" in output_data:
            tag_value = output_data.pop("tag")
            output_data.setdefault("tags", tag_value)
        if not output_data.get("judged", False):
            output_data["judged"] = False
            output_data["original_title"] = None
            output_data["original_content"] = None
            output_data["judge_analysis"] = None
        if "tags" in output_data and "original_tags" not in output_data:
            output_data["original_tags"] = output_data["tags"]

    @staticmethod
    def _add_base64_fields(output_data: dict) -> None:
        """Add Base64 (of UTF-8) copies of the title/content fields in place.

        ``title_base64`` and ``content_base64`` are always added (a missing or
        None value encodes as an empty string); ``original_title_base64`` and
        ``original_content_base64`` only when those fields are non-empty.
        Failures are logged, not raised.
        """
        def to_base64(value) -> str:
            return base64.b64encode(str(value).encode('utf-8')).decode('ascii')

        try:
            # Encode both before storing either, so a failure adds neither.
            title_base64 = to_base64(output_data.get("title") or "")
            content_base64 = to_base64(output_data.get("content") or "")
            output_data["title_base64"] = title_base64
            output_data["content_base64"] = content_base64
            for field in ("original_title", "original_content"):
                if output_data.get(field):
                    output_data[f"{field}_base64"] = to_base64(output_data[field])
            logging.info("成功添加Base64编码内容")
        except Exception as e:
            logging.error(f"Base64编码内容时出错: {e}")

    @staticmethod
    def _write_debug_dump(variant_dir: str, original_title, original_content, output_data: dict) -> None:
        """Write ``debug_content.txt``: raw title/content plus a summary of each processed field."""
        debug_path = os.path.join(variant_dir, "debug_content.txt")
        try:
            with open(debug_path, "w", encoding="utf-8") as f:
                f.write(f"原始标题: {original_title}\n\n")
                f.write(f"原始内容: {original_content}\n\n")
                f.write("---处理后---\n\n")
                for key, value in output_data.items():
                    if isinstance(value, str):
                        f.write(f"{key}: (length: {len(value)})\n")
                        f.write(f"{repr(value[:200])}...\n\n")
                    else:
                        f.write(f"{key}: {type(value)}\n")
            logging.info(f"调试内容已保存到: {debug_path}")
        except Exception as debug_err:
            logging.error(f"保存调试内容失败: {debug_err}")

    def _ultra_safe_clean(self, text):
        """Return *text* reduced to printable ASCII (U+0020..U+007E); non-strings become ""."""
        if not isinstance(text, str):
            return ""
        return ''.join(c for c in text if 32 <= ord(c) <= 126)

    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Save the complete poster configuration of a topic to ``topic_<topic_index>_poster_configs.json``."""
        config_path = os.path.join(self._get_run_dir(run_id), f"topic_{topic_index}_poster_configs.json")
        try:
            self._write_json(config_path, config_data)
            logging.info(f"Saved complete poster configurations for topic {topic_index} to: {config_path}")
        except Exception as save_err:
            logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {config_path}: {save_err}")

    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict | None = None):
        """Save a generated image (and optional metadata JSON) under the variant directory.

        Destination by *image_type*:

        * 'note' / 'additional' -> ``<variant>/image/``; the file name is
          prefixed with ``<image_type>_`` unless it already starts with it.
        * 'collage' -> ``<variant>/collage_img/``
        * 'poster' -> ``<variant>/poster/``
        * anything else -> ``<variant>/`` (a warning is logged).

        *image_data* may be raw bytes or an object with a ``save(path)``
        method (e.g. a PIL Image). When *metadata* is given it is written
        next to the image as ``<image stem>_metadata.json``.

        Returns:
            The path the image was saved to (also returned if saving failed,
            which is logged), or None when *image_data* is empty.
        """
        if not image_data:
            logging.warning(f"传入的{image_type}图像数据为空 (Topic {topic_index}, Variant {variant_index})。跳过保存。")
            return

        if image_type in ('note', 'additional'):
            subdir = 'image'
            if not output_filename.startswith(image_type):
                output_filename = f"{image_type}_{output_filename}"
        elif image_type == 'collage':
            subdir = 'collage_img'
        elif image_type == 'poster':
            subdir = 'poster'
        else:
            logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。")
            subdir = None

        target_dir = self._get_variant_dir(run_id, topic_index, variant_index, subdir)
        save_path = os.path.join(target_dir, output_filename)

        try:
            if isinstance(image_data, (bytes, bytearray, memoryview)):
                # Raw encoded image bytes: write them verbatim.
                with open(save_path, "wb") as f:
                    f.write(image_data)
            else:
                image_data.save(save_path)
            logging.info(f"保存{image_type}图像到: {save_path}")

            if metadata:
                stem = os.path.splitext(os.path.basename(save_path))[0]
                metadata_path = os.path.join(os.path.dirname(save_path), stem + "_metadata.json")
                try:
                    self._write_json(metadata_path, metadata)
                    logging.info(f"保存{image_type}元数据到: {metadata_path}")
                except Exception as me:
                    logging.exception(f"无法保存{image_type}元数据到{metadata_path}: {me}")
        except Exception as e:
            logging.exception(f"无法保存{image_type}图像到{save_path}: {e}")
        return save_path

    def finalize(self, run_id: str):
        """Nothing to flush or close for plain files; only logs the call."""
        logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")

    @staticmethod
    def _is_json_safe_char(ch: str) -> bool:
        """Return True for characters kept by _sanitize_content_for_json."""
        if ch in '\n\r\t':
            return True
        code = ord(ch)
        if code < 32 or code == 127:
            return False
        # Lone surrogates cannot be encoded as UTF-8 and would break file writes.
        return not 0xD800 <= code <= 0xDFFF

    def _sanitize_content_for_json(self, data):
        """Recursively clean *data* so it can be safely serialized to JSON.

        Args:
            data: A dict, list, string or any other value.

        Returns:
            * dicts: a new dict without the ``"error"`` and ``"raw_result"``
              entries, values cleaned recursively;
            * lists: a new list with every item cleaned recursively;
            * strings: literal two-character ``\\n`` sequences turned into
              real newlines, then control characters other than newline,
              carriage return and tab removed, along with DEL and lone
              surrogates. Other non-ASCII characters (CJK text, emoji, ...)
              are kept unchanged;
            * anything else: returned as-is.
        """
        if isinstance(data, dict):
            return {
                key: self._sanitize_content_for_json(value)
                for key, value in data.items()
                if key not in ("error", "raw_result")
            }
        if isinstance(data, list):
            return [self._sanitize_content_for_json(item) for item in data]
        if not isinstance(data, str):
            return data
        # Model output sometimes contains escaped newlines as literal "\n".
        text = data.replace('\\n', '\n')
        return ''.join(ch for ch in text if self._is_json_safe_char(ch))