TravelContentCreator/utils/output_handler.py

import os
import copy
import logging
import traceback
from abc import ABC, abstractmethod

import simplejson as json
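
# Note: simplejson (imported above as json) is used instead of the stdlib json
# module for its ignore_nan option, which the json.dump calls in this module rely
# on to write NaN/Infinity values as null.
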
class OutputHandler(ABC):
    """Abstract base class for handling the output of the generation pipeline."""

    @abstractmethod
    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        """Handles the results from the topic generation step."""
        pass

    @abstractmethod
    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Handles the results for a single content variant."""
        pass

    @abstractmethod
    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Handles the poster configuration generated for a topic."""
        pass

    @abstractmethod
    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict | None = None):
        """Handles a generated image (collage, final poster, note image, or additional illustration).

        Args:
            image_type: The kind of image, e.g. 'collage', 'poster', 'note', or 'additional'.
            image_data: The image data (e.g., PIL Image object or bytes).
            output_filename: The desired filename for the output (e.g., 'poster.jpg').
            metadata: Optional dictionary with additional metadata about the image (e.g., used image files).
        """
        pass

    @abstractmethod
    def finalize(self, run_id: str):
        """Perform any final actions for the run (e.g., close files, upload manifests)."""
        pass
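

# Illustrative sketch (not part of the original pipeline): a minimal in-memory
# handler showing how a non-file-system backend could implement the interface
# above, e.g. for unit tests. All names introduced here (InMemoryOutputHandler,
# self.runs) are assumptions, not pipeline requirements.
class InMemoryOutputHandler(OutputHandler):
    """Collects all results for a run in a dictionary instead of writing files."""

    def __init__(self):
        self.runs = {}

    def _run(self, run_id: str) -> dict:
        return self.runs.setdefault(run_id, {"topics": None, "variants": {}, "poster_configs": {}, "images": []})

    def handle_topic_results(self, run_id, topics_list, system_prompt, user_prompt):
        self._run(run_id)["topics"] = topics_list

    def handle_content_variant(self, run_id, topic_index, variant_index, content_data, prompt_data):
        self._run(run_id)["variants"][(topic_index, variant_index)] = content_data

    def handle_poster_configs(self, run_id, topic_index, config_data):
        self._run(run_id)["poster_configs"][topic_index] = config_data

    def handle_generated_image(self, run_id, topic_index, variant_index, image_type, image_data, output_filename, metadata=None):
        self._run(run_id)["images"].append((topic_index, variant_index, image_type, output_filename, metadata))

    def finalize(self, run_id):
        logging.info(f"InMemoryOutputHandler finalizing run: {run_id}")
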
class FileSystemOutputHandler(OutputHandler):
    """Handles output by saving results to the local file system."""

    def __init__(self, base_output_dir: str = "result"):
        self.base_output_dir = base_output_dir
        logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")

    def _get_run_dir(self, run_id: str) -> str:
        """Gets the specific directory for a run, creating it if necessary."""
        run_dir = os.path.join(self.base_output_dir, run_id)
        os.makedirs(run_dir, exist_ok=True)
        return run_dir

    def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
        """Gets the specific directory for a variant, optionally within a subdirectory (e.g., 'poster'), creating it if necessary."""
        run_dir = self._get_run_dir(run_id)
        variant_base_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
        target_dir = variant_base_dir
        if subdir:
            target_dir = os.path.join(variant_base_dir, subdir)
        os.makedirs(target_dir, exist_ok=True)
        return target_dir
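
    # Typical on-disk layout produced by the methods below (with the default base_output_dir):
    #
    #   result/<run_id>/tweet_topic_<run_id>.json          topics list
    #   result/<run_id>/topic_prompt_<run_id>.txt          topic prompts
    #   result/<run_id>/topic_<i>_poster_configs.json      poster configs per topic
    #   result/<run_id>/<i>_<j>/article.json               content variant
    #   result/<run_id>/<i>_<j>/tweet_prompt.txt           content prompt
    #   result/<run_id>/<i>_<j>/image/                     note and additional images
    #   result/<run_id>/<i>_<j>/collage_img/, poster/      collages and posters
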
    def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
        run_dir = self._get_run_dir(run_id)
        # Save topics list
        topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
        try:
            with open(topics_path, "w", encoding="utf-8") as f:
                json.dump(topics_list, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Topics list saved successfully to: {topics_path}")
        except Exception:
            logging.exception(f"Error saving topic JSON file to {topics_path}:")
        # Save prompts
        prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write("--- SYSTEM PROMPT ---\n")
                f.write(system_prompt + "\n\n")
                f.write("--- USER PROMPT ---\n")
                f.write(user_prompt + "\n")
            logging.info(f"Topic prompts saved successfully to: {prompt_path}")
        except Exception:
            logging.exception(f"Error saving topic prompts file to {prompt_path}:")
    def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
        """Saves content JSON and prompt for a specific variant."""
        variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)
        # Work on a copy of the output data to avoid mutating the caller's dict
        output_data = copy.deepcopy(content_data)
        # Normalize on the "tags" field so that "tag" and "tags" are never duplicated
        if "tag" in output_data and "tags" not in output_data:
            # Only "tag" exists: rename it to "tags"
            output_data["tags"] = output_data["tag"]
            del output_data["tag"]
        elif "tag" in output_data and "tags" in output_data:
            # Both fields exist: keep "tags" and drop "tag"
            del output_data["tag"]
        # Keep the fields consistent even when judging (review) is not enabled
        if not output_data.get("judged", False):
            output_data["judged"] = False
            # Add original_title, original_content and judge_analysis with null values
            output_data["original_title"] = None
            output_data["original_content"] = None
            output_data["judge_analysis"] = None
        # Add the original_tags field
        if "tags" in output_data and "original_tags" not in output_data:
            output_data["original_tags"] = output_data["tags"]
        # Deep-clean the content so it can be serialized safely
        try:
            output_data = self._sanitize_content_for_json(output_data)
            logging.info("Content sanitized and safe to serialize")
        except Exception as e:
            logging.error(f"Error while sanitizing content: {e}")
        # Save the normalized article.json
        content_path = os.path.join(variant_dir, "article.json")
        try:
            with open(content_path, "w", encoding="utf-8") as f:
                json.dump(output_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Content JSON saved to: {content_path}")
        except Exception as e:
            logging.exception(f"Failed to save content JSON to {content_path}: {e}")
            # If serialization failed, record a summary of the content for debugging
            debug_path = os.path.join(variant_dir, "debug_content.txt")
            try:
                with open(debug_path, "w", encoding="utf-8") as f:
                    for key, value in output_data.items():
                        if isinstance(value, str):
                            f.write(f"{key}: (length: {len(value)})\n")
                            f.write(f"{repr(value[:200])}...\n\n")
                        else:
                            f.write(f"{key}: {type(value)}\n")
                logging.info(f"Debug content saved to: {debug_path}")
            except Exception as debug_err:
                logging.error(f"Failed to save debug content: {debug_err}")
        # Save the content prompt
        prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
        try:
            with open(prompt_path, "w", encoding="utf-8") as f:
                # prompt_data is assumed to be the user prompt used for this variant
                f.write(prompt_data + "\n")
            logging.info(f"Content prompt saved to: {prompt_path}")
        except Exception as e:
            logging.exception(f"Failed to save content prompt to {prompt_path}: {e}")
    def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
        """Saves the complete poster configuration list/dict for a topic."""
        run_dir = self._get_run_dir(run_id)
        config_path = os.path.join(run_dir, f"topic_{topic_index}_poster_configs.json")
        try:
            with open(config_path, 'w', encoding='utf-8') as f_cfg_topic:
                json.dump(config_data, f_cfg_topic, ensure_ascii=False, indent=4, ignore_nan=True)
            logging.info(f"Saved complete poster configurations for topic {topic_index} to: {config_path}")
        except Exception as save_err:
            logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {config_path}: {save_err}")
    def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict | None = None):
        """Handles a generated image. Note images and additional illustrations are saved
        to the 'image' directory; other types keep the original path structure."""
        if not image_data:
            logging.warning(f"Empty {image_type} image data (Topic {topic_index}, Variant {variant_index}). Skipping save.")
            return
        # Determine the save path based on the image type
        if image_type == 'note' or image_type == 'additional':  # note images and additional illustrations go into the image directory
            # Create the <run_id>/<topic_index>_<variant_index>/image directory
            run_dir = self._get_run_dir(run_id)
            variant_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
            image_dir = os.path.join(variant_dir, "image")
            os.makedirs(image_dir, exist_ok=True)
            # Prefix the output filename with the image type
            prefixed_filename = f"{image_type}_{output_filename}" if not output_filename.startswith(image_type) else output_filename
            save_path = os.path.join(image_dir, prefixed_filename)
        else:
            # Other image types use the original save-path logic
            subdir = None
            if image_type == 'collage':
                subdir = 'collage_img'  # subdirectory name for collages
            elif image_type == 'poster':
                subdir = 'poster'
            else:
                logging.warning(f"Unknown image type '{image_type}'; saving to the variant root directory.")
                subdir = None  # unknown type: save directly into the variant directory
            # Ensure the target directory exists
            variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
            if subdir:
                target_dir = os.path.join(variant_dir, subdir)
                os.makedirs(target_dir, exist_ok=True)
            else:
                target_dir = variant_dir
                os.makedirs(target_dir, exist_ok=True)
            save_path = os.path.join(target_dir, output_filename)
        try:
            # Save the image: raw bytes are written directly; anything else is assumed
            # to expose a PIL-style save() method (see the base class docstring)
            if isinstance(image_data, (bytes, bytearray)):
                with open(save_path, "wb") as f_img:
                    f_img.write(image_data)
            else:
                image_data.save(save_path)
            logging.info(f"Saved {image_type} image to: {save_path}")
            # Save metadata, if provided
            if metadata:
                # Keep the metadata file in the same directory as the image
                metadata_filename = os.path.splitext(os.path.basename(save_path))[0] + "_metadata.json"
                metadata_path = os.path.join(os.path.dirname(save_path), metadata_filename)
                try:
                    with open(metadata_path, 'w', encoding='utf-8') as f:
                        json.dump(metadata, f, ensure_ascii=False, indent=4, ignore_nan=True)
                    logging.info(f"Saved {image_type} metadata to: {metadata_path}")
                except Exception as me:
                    logging.error(f"Failed to save {image_type} metadata to {metadata_path}: {me}")
                    traceback.print_exc()
        except Exception as e:
            logging.exception(f"Failed to save {image_type} image to {save_path}: {e}")
            traceback.print_exc()
        return save_path
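
    # Usage sketch (illustrative only; assumes Pillow, which this module does not
    # import itself, and a handler created elsewhere in the pipeline):
    #
    #   from PIL import Image
    #   img = Image.new("RGB", (1080, 1440), "white")
    #   handler.handle_generated_image(
    #       run_id="run_001", topic_index=0, variant_index=1,
    #       image_type="poster", image_data=img, output_filename="poster.jpg",
    #       metadata={"used_images": ["a.jpg", "b.jpg"]},
    #   )
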
    def finalize(self, run_id: str):
        # Nothing specific to do for the file-system handler
        logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")
    def _sanitize_content_for_json(self, data):
        """Recursively sanitizes content so that it can be safely serialized to JSON.

        Args:
            data: The data to process (dict, list, or primitive type).

        Returns:
            The sanitized data, safe for serialization.
        """
        if isinstance(data, dict):
            # Dicts: sanitize every value
            sanitized_dict = {}
            for key, value in data.items():
                sanitized_dict[key] = self._sanitize_content_for_json(value)
            return sanitized_dict
        elif isinstance(data, list):
            # Lists: sanitize every item
            return [self._sanitize_content_for_json(item) for item in data]
        elif isinstance(data, str):
            # Strings are the main case of interest
            # 1. Replace literal "\n" sequences with real newlines
            if r'\n' in data:
                data = data.replace(r'\n', '\n')
            # 2. Remove control characters (ASCII 0-31) except \n, \r, \t
            cleaned = ''
            for char in data:
                # Allow the common whitespace characters
                if char in '\n\r\t' or ord(char) >= 32:
                    cleaned += char
            # 3. Verify that the string can be serialized safely
            try:
                json.dumps(cleaned, ensure_ascii=False)
                return cleaned
            except Exception as e:
                logging.warning(f"String still not serializable after cleaning; applying stricter cleaning: {e}")
                # Fall back to stricter cleaning: also drop characters outside the Basic Multilingual Plane
                return ''.join(c for c in cleaned if ord(c) < 65536 and (c in '\n\r\t' or ord(c) >= 32))
        else:
            # Other types (numbers, booleans, None, etc.) are returned as-is
            return data
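

# Minimal usage sketch (illustrative only): the run id, topic list, and content
# fields below are assumptions for demonstration, not the pipeline's actual
# schema; only the handler calls themselves come from this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    handler = FileSystemOutputHandler(base_output_dir="result")
    run_id = "demo_run"

    # One topic with a couple of illustrative fields
    topics = [{"topic": "48 hours in Lisbon", "angle": "budget itinerary"}]
    handler.handle_topic_results(run_id, topics, "You are a travel copywriter.", "Generate 1 topic.")

    # One content variant for topic 0, variant 0
    content = {"title": "48 Hours in Lisbon", "content": "Day 1: Alfama...", "tag": ["lisbon", "budget"]}
    handler.handle_content_variant(run_id, 0, 0, content, "Write a post about Lisbon on a budget.")

    handler.finalize(run_id)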