TravelContentCreator/utils/output_handler.py

220 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import logging
from abc import ABC, abstractmethod
import traceback
class OutputHandler(ABC):
"""Abstract base class for handling the output of the generation pipeline."""
@abstractmethod
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
"""Handles the results from the topic generation step."""
pass
@abstractmethod
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Handles the results for a single content variant."""
pass
@abstractmethod
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Handles the poster configuration generated for a topic."""
pass
@abstractmethod
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
"""Handles a generated image (collage or final poster).
Args:
image_type: Either 'collage' or 'poster'.
image_data: The image data (e.g., PIL Image object or bytes).
output_filename: The desired filename for the output (e.g., 'poster.jpg').
metadata: Optional dictionary with additional metadata about the image (e.g., used image files).
"""
pass
@abstractmethod
def finalize(self, run_id: str):
"""Perform any final actions for the run (e.g., close files, upload manifests)."""
pass
class FileSystemOutputHandler(OutputHandler):
"""Handles output by saving results to the local file system."""
def __init__(self, base_output_dir: str = "result"):
self.base_output_dir = base_output_dir
logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")
def _get_run_dir(self, run_id: str) -> str:
"""Gets the specific directory for a run, creating it if necessary."""
run_dir = os.path.join(self.base_output_dir, run_id)
os.makedirs(run_dir, exist_ok=True)
return run_dir
def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
"""Gets the specific directory for a variant, optionally within a subdirectory (e.g., 'poster'), creating it if necessary."""
run_dir = self._get_run_dir(run_id)
variant_base_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
target_dir = variant_base_dir
if subdir:
target_dir = os.path.join(variant_base_dir, subdir)
os.makedirs(target_dir, exist_ok=True)
return target_dir
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
run_dir = self._get_run_dir(run_id)
# Save topics list
topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
try:
with open(topics_path, "w", encoding="utf-8") as f:
json.dump(topics_list, f, ensure_ascii=False, indent=4)
logging.info(f"Topics list saved successfully to: {topics_path}")
except Exception as e:
logging.exception(f"Error saving topic JSON file to {topics_path}:")
# Save prompts
prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
try:
with open(prompt_path, "w", encoding="utf-8") as f:
f.write("--- SYSTEM PROMPT ---\n")
f.write(system_prompt + "\n\n")
f.write("--- USER PROMPT ---\n")
f.write(user_prompt + "\n")
logging.info(f"Topic prompts saved successfully to: {prompt_path}")
except Exception as e:
logging.exception(f"Error saving topic prompts file to {prompt_path}:")
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Saves content JSON and prompt for a specific variant."""
variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)
# 检查内容是否经过审核
if content_data.get("judged", False):
# 保存原始内容到raw_data.json
raw_content = {
"title": content_data.get("original_title", content_data.get("title", "")),
"content": content_data.get("original_content", content_data.get("content", "")),
"tag": content_data.get("tag", ""), # 确保保留tag字段
"error": content_data.get("error", False)
}
# 如果内容里没有original_字段则认为当前内容就是原始内容
if "original_title" not in content_data and "original_content" not in content_data:
# 先深拷贝当前内容数据以避免修改原数据
import copy
raw_content = copy.deepcopy(content_data)
# 移除审核相关字段
raw_content.pop("judged", None)
raw_content.pop("judge_analysis", None)
raw_data_path = os.path.join(variant_dir, "raw_data.json")
try:
with open(raw_data_path, "w", encoding="utf-8") as f:
json.dump(raw_content, f, ensure_ascii=False, indent=4)
logging.info(f"原始内容保存到: {raw_data_path}")
except Exception as e:
logging.exception(f"保存原始内容到 {raw_data_path} 失败: {e}")
# Save content JSON
content_path = os.path.join(variant_dir, "article.json")
try:
# 确保tag字段存在
if "tag" not in content_data and content_data.get("tags"):
content_data["tag"] = content_data["tags"]
with open(content_path, "w", encoding="utf-8") as f:
json.dump(content_data, f, ensure_ascii=False, indent=4)
logging.info(f"Content JSON saved to: {content_path}")
except Exception as e:
logging.exception(f"Failed to save content JSON to {content_path}: {e}")
# Save content prompt
prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
try:
with open(prompt_path, "w", encoding="utf-8") as f:
# Assuming prompt_data is the user prompt used for this variant
f.write(prompt_data + "\n")
logging.info(f"Content prompt saved to: {prompt_path}")
except Exception as e:
logging.exception(f"Failed to save content prompt to {prompt_path}: {e}")
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Saves the complete poster configuration list/dict for a topic."""
run_dir = self._get_run_dir(run_id)
config_path = os.path.join(run_dir, f"topic_{topic_index}_poster_configs.json")
try:
with open(config_path, 'w', encoding='utf-8') as f_cfg_topic:
json.dump(config_data, f_cfg_topic, ensure_ascii=False, indent=4)
logging.info(f"Saved complete poster configurations for topic {topic_index} to: {config_path}")
except Exception as save_err:
logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {config_path}: {save_err}")
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
"""处理生成的图像对于笔记图像和额外配图保存到image目录其他类型保持原有路径结构"""
if not image_data:
logging.warning(f"传入的{image_type}图像数据为空 (Topic {topic_index}, Variant {variant_index})。跳过保存。")
return
# 根据图像类型确定保存路径
if image_type == 'note' or image_type == 'additional': # 笔记图像和额外配图保存到image目录
# 创建run_id/i_j/image目录
run_dir = self._get_run_dir(run_id)
variant_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
image_dir = os.path.join(variant_dir, "image")
os.makedirs(image_dir, exist_ok=True)
# 在输出文件名前加上图像类型前缀
prefixed_filename = f"{image_type}_{output_filename}" if not output_filename.startswith(image_type) else output_filename
save_path = os.path.join(image_dir, prefixed_filename)
else:
# 其他类型图像使用原有的保存路径逻辑
subdir = None
if image_type == 'collage':
subdir = 'collage_img' # 可配置的子目录名称
elif image_type == 'poster':
subdir = 'poster'
else:
logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。")
subdir = None # 如果类型未知,直接保存到变体目录
# 确保目标目录存在
variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}")
if subdir:
target_dir = os.path.join(variant_dir, subdir)
os.makedirs(target_dir, exist_ok=True)
else:
target_dir = variant_dir
os.makedirs(target_dir, exist_ok=True)
save_path = os.path.join(target_dir, output_filename)
try:
# 保存图片
image_data.save(save_path)
logging.info(f"保存{image_type}图像到: {save_path}")
# 保存元数据(如果有)
if metadata:
# 确保元数据文件与图像在同一目录
metadata_filename = os.path.splitext(os.path.basename(save_path))[0] + "_metadata.json"
metadata_path = os.path.join(os.path.dirname(save_path), metadata_filename)
try:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)
logging.info(f"保存{image_type}元数据到: {metadata_path}")
except Exception as me:
logging.error(f"无法保存{image_type}元数据到{metadata_path}: {me}")
traceback.print_exc()
except Exception as e:
logging.exception(f"无法保存{image_type}图像到{save_path}: {e}")
traceback.print_exc()
return save_path
def finalize(self, run_id: str):
logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")
pass # Nothing specific to do for file system finalize