import os import simplejson as json import logging from abc import ABC, abstractmethod import traceback import base64 # 自定义JSON编码器,强制处理所有可能的JSON序列化问题 class SafeJSONEncoder(json.JSONEncoder): """安全的JSON编码器,可以处理所有类型的字符串""" def encode(self, obj): """重写encode方法,确保任何字符串都能被安全编码""" if isinstance(obj, dict): # 处理字典:递归处理每个值 return '{' + ','.join(f'"{key}":{self.encode(value)}' for key, value in obj.items() if key not in ["error", "raw_result"]) + '}' elif isinstance(obj, list): # 处理列表:递归处理每个项 return '[' + ','.join(self.encode(item) for item in obj) + ']' elif isinstance(obj, str): # 安全处理字符串:移除可能导致问题的字符 safe_str = '' for char in obj: if char in '\n\r\t' or (32 <= ord(char) <= 126): safe_str += char # 跳过所有其他字符 return json.JSONEncoder.encode(self, safe_str) else: # 其他类型:使用默认处理 return json.JSONEncoder.encode(self, obj) def iterencode(self, obj, _one_shot=False): """重写iterencode方法,确保能处理迭代编码""" return self.encode(obj) class OutputHandler(ABC): """Abstract base class for handling the output of the generation pipeline.""" @abstractmethod def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str): """Handles the results from the topic generation step.""" pass @abstractmethod def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str): """Handles the results for a single content variant.""" pass @abstractmethod def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict): """Handles the poster configuration generated for a topic.""" pass @abstractmethod def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None): """Handles a generated image (collage or final poster). Args: image_type: Either 'collage' or 'poster'. image_data: The image data (e.g., PIL Image object or bytes). output_filename: The desired filename for the output (e.g., 'poster.jpg'). metadata: Optional dictionary with additional metadata about the image (e.g., used image files). """ pass @abstractmethod def finalize(self, run_id: str): """Perform any final actions for the run (e.g., close files, upload manifests).""" pass class FileSystemOutputHandler(OutputHandler): """Handles output by saving results to the local file system.""" def __init__(self, base_output_dir: str = "result"): self.base_output_dir = base_output_dir logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}") def _get_run_dir(self, run_id: str) -> str: """Gets the specific directory for a run, creating it if necessary.""" run_dir = os.path.join(self.base_output_dir, run_id) os.makedirs(run_dir, exist_ok=True) return run_dir def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str: """Gets the specific directory for a variant, optionally within a subdirectory (e.g., 'poster'), creating it if necessary.""" run_dir = self._get_run_dir(run_id) variant_base_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}") target_dir = variant_base_dir if subdir: target_dir = os.path.join(variant_base_dir, subdir) os.makedirs(target_dir, exist_ok=True) return target_dir def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str): run_dir = self._get_run_dir(run_id) # Save topics list topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json") try: with open(topics_path, "w", encoding="utf-8") as f: # 不使用自定义编码器,直接使用标准json json.dump(topics_list, f, ensure_ascii=False, indent=4, ignore_nan=True) logging.info(f"Topics list saved successfully to: {topics_path}") # 额外创建txt格式输出 txt_path = os.path.join(run_dir, f"tweet_topic_{run_id}.txt") with open(txt_path, "w", encoding="utf-8") as f: f.write(f"# 选题列表 (run_id: {run_id})\n\n") for topic in topics_list: f.write(f"## 选题 {topic.get('index', 'N/A')}\n") f.write(f"- 日期: {topic.get('date', 'N/A')}\n") f.write(f"- 对象: {topic.get('object', 'N/A')}\n") f.write(f"- 产品: {topic.get('product', 'N/A')}\n") f.write(f"- 产品策略: {topic.get('product_logic', 'N/A')}\n") f.write(f"- 风格: {topic.get('style', 'N/A')}\n") f.write(f"- 风格策略: {topic.get('style_logic', 'N/A')}\n") f.write(f"- 目标受众: {topic.get('target_audience', 'N/A')}\n") f.write(f"- 受众策略: {topic.get('target_audience_logic', 'N/A')}\n") f.write(f"- 逻辑: {topic.get('logic', 'N/A')}\n\n") logging.info(f"选题文本版本已保存到: {txt_path}") except Exception as e: logging.exception(f"Error saving topic JSON file to {topics_path}:") # Save prompts prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt") try: with open(prompt_path, "w", encoding="utf-8") as f: f.write("--- SYSTEM PROMPT ---\n") f.write(system_prompt + "\n\n") f.write("--- USER PROMPT ---\n") f.write(user_prompt + "\n") logging.info(f"Topic prompts saved successfully to: {prompt_path}") except Exception as e: logging.exception(f"Error saving topic prompts file to {prompt_path}:") def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str): """Saves content JSON and prompt for a specific variant.""" variant_dir = self._get_variant_dir(run_id, topic_index, variant_index) # 创建输出数据的副本,避免修改原始数据 import copy input_data = copy.deepcopy(content_data) # 统一使用tags字段,避免tag和tags重复 if "tag" in input_data and "tags" not in input_data: # 只有tag字段存在,复制到tags input_data["tags"] = input_data["tag"] elif "tag" in input_data and "tags" in input_data: # 两个字段都存在,保留tags pass # 确保即使在未启用审核的情况下,字段也保持一致 if not input_data.get("judged", False): input_data["judged"] = False # 添加original字段(临时),值为当前值 if "title" in input_data and "original_title" not in input_data: input_data["original_title"] = input_data["title"] if "content" in input_data and "original_content" not in input_data: input_data["original_content"] = input_data["content"] if "tags" in input_data and "original_tags" not in input_data: input_data["original_tags"] = input_data["tags"] # 保存原始值用于txt文件生成和调试 original_title = input_data.get("title", "") original_content = input_data.get("content", "") original_tags = input_data.get("tags", "") original_judge_analysis = input_data.get("judge_analysis", "") # 创建一个只包含元数据和base64编码的输出数据对象 output_data = { # 保留元数据字段 "judged": input_data.get("judged", False), "judge_success": input_data.get("judge_success", False) } # 为所有内容字段创建base64编码版本 try: # 1. 标题和内容 if "title" in input_data and input_data["title"]: output_data["title_base64"] = base64.b64encode(input_data["title"].encode('utf-8')).decode('ascii') if "content" in input_data and input_data["content"]: output_data["content_base64"] = base64.b64encode(input_data["content"].encode('utf-8')).decode('ascii') # 2. 标签 if "tags" in input_data and input_data["tags"]: output_data["tags_base64"] = base64.b64encode(input_data["tags"].encode('utf-8')).decode('ascii') # 3. 原始内容 if "original_title" in input_data and input_data["original_title"]: output_data["original_title_base64"] = base64.b64encode(input_data["original_title"].encode('utf-8')).decode('ascii') if "original_content" in input_data and input_data["original_content"]: output_data["original_content_base64"] = base64.b64encode(input_data["original_content"].encode('utf-8')).decode('ascii') # 4. 原始标签 if "original_tags" in input_data and input_data["original_tags"]: output_data["original_tags_base64"] = base64.b64encode(input_data["original_tags"].encode('utf-8')).decode('ascii') # 5. 审核分析 if "judge_analysis" in input_data and input_data["judge_analysis"]: output_data["judge_analysis_base64"] = base64.b64encode(input_data["judge_analysis"].encode('utf-8')).decode('ascii') logging.info("成功添加Base64编码内容") except Exception as e: logging.error(f"Base64编码内容时出错: {e}") # 保存可能有用的额外字段 if "error" in input_data: output_data["error"] = input_data["error"] # 保存统一格式的article.json (只包含base64编码和元数据) content_path = os.path.join(variant_dir, "article.json") try: with open(content_path, "w", encoding="utf-8") as f: # 使用标准json json.dump(output_data, f, ensure_ascii=False, indent=4, ignore_nan=True) logging.info(f"Content JSON saved to: {content_path}") except Exception as e: logging.exception(f"Failed to save content JSON to {content_path}: {e}") # 创建一份article.txt文件以便直接查看 txt_path = os.path.join(variant_dir, "article.txt") try: # 使用原始内容,保留所有换行符 with open(txt_path, "w", encoding="utf-8") as f: if original_title: f.write(f"{original_title}\n\n") # 保持原始内容的所有换行符 if original_content: f.write(original_content) if original_tags: f.write(f"\n\n{original_tags}") if original_judge_analysis: f.write(f"\n\n审核分析:\n{original_judge_analysis}") logging.info(f"Article text saved to: {txt_path}") except Exception as e: logging.error(f"Failed to save article.txt: {e}") # 记录调试信息,无论是否成功 (包含原始数据的完整副本以便调试) debug_path = os.path.join(variant_dir, "debug_content.txt") try: with open(debug_path, "w", encoding="utf-8") as f: f.write(f"原始标题: {original_title}\n\n") f.write(f"原始内容: {original_content}\n\n") if original_tags: f.write(f"原始标签: {original_tags}\n\n") if original_judge_analysis: f.write(f"审核分析: {original_judge_analysis}\n\n") f.write("---处理后---\n\n") for key, value in output_data.items(): if isinstance(value, str): f.write(f"{key}: (length: {len(value)})\n") f.write(f"{repr(value[:200])}...\n\n") else: f.write(f"{key}: {type(value)}\n") logging.info(f"调试内容已保存到: {debug_path}") except Exception as debug_err: logging.error(f"保存调试内容失败: {debug_err}") # Save content prompt prompt_path = os.path.join(variant_dir, "tweet_prompt.txt") try: with open(prompt_path, "w", encoding="utf-8") as f: f.write(prompt_data) logging.info(f"Content prompt saved to: {prompt_path}") except Exception as e: logging.error(f"Failed to save content prompt to {prompt_path}: {e}") def _ultra_safe_clean(self, text): """执行最严格的字符清理,确保100%可序列化""" if not isinstance(text, str): return "" return ''.join(c for c in text if 32 <= ord(c) <= 126) def _preprocess_for_json(self, text): """预处理文本,将换行符转换为\\n形式,保证JSON安全""" if not isinstance(text, str): return text # 将所有实际换行符替换为\\n字符串 return text.replace('\n', '\\n').replace('\r', '\\r') def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict): """处理海报配置数据""" # 处理海报配置数据 try: # 创建目标目录 variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_1") os.makedirs(variant_dir, exist_ok=True) # 确保配置数据是可序列化的 processed_configs = [] if isinstance(config_data, list): for config in config_data: processed_config = {} # 处理索引字段 processed_config["index"] = config.get("index", 0) # 处理标题字段,应用JSON预处理 main_title = config.get("main_title", "") processed_config["main_title"] = self._preprocess_for_json(main_title) # 处理文本字段列表,对每个文本应用JSON预处理 texts = config.get("texts", []) processed_texts = [] for text in texts: processed_texts.append(self._preprocess_for_json(text)) processed_config["texts"] = processed_texts processed_configs.append(processed_config) else: # 如果不是列表,可能是字典或其他格式,尝试转换 if isinstance(config_data, dict): # 处理单个配置字典 processed_config = {} processed_config["index"] = config_data.get("index", 0) processed_config["main_title"] = self._preprocess_for_json(config_data.get("main_title", "")) texts = config_data.get("texts", []) processed_texts = [] for text in texts: processed_texts.append(self._preprocess_for_json(text)) processed_config["texts"] = processed_texts processed_configs.append(processed_config) # 保存配置到JSON文件 config_file_path = os.path.join(variant_dir, f"topic_{topic_index}_poster_configs.json") with open(config_file_path, 'w', encoding='utf-8') as f: json.dump(processed_configs, f, ensure_ascii=False, indent=4, cls=self.SafeJSONEncoder) logging.info(f"Successfully saved poster configs to {config_file_path}") except Exception as e: logging.error(f"Error saving poster configs: {e}") traceback.print_exc() def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None): """处理生成的图像,对于笔记图像和额外配图保存到image目录,其他类型保持原有路径结构""" if not image_data: logging.warning(f"传入的{image_type}图像数据为空 (Topic {topic_index}, Variant {variant_index})。跳过保存。") return # 根据图像类型确定保存路径 if image_type == 'note' or image_type == 'additional': # 笔记图像和额外配图保存到image目录 # 创建run_id/i_j/image目录 run_dir = self._get_run_dir(run_id) variant_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}") image_dir = os.path.join(variant_dir, "image") os.makedirs(image_dir, exist_ok=True) # 在输出文件名前加上图像类型前缀 prefixed_filename = f"{image_type}_{output_filename}" if not output_filename.startswith(image_type) else output_filename save_path = os.path.join(image_dir, prefixed_filename) else: # 其他类型图像使用原有的保存路径逻辑 subdir = None if image_type == 'collage': subdir = 'collage_img' # 可配置的子目录名称 elif image_type == 'poster': subdir = 'poster' else: logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。") subdir = None # 如果类型未知,直接保存到变体目录 # 确保目标目录存在 variant_dir = os.path.join(self._get_run_dir(run_id), f"{topic_index}_{variant_index}") if subdir: target_dir = os.path.join(variant_dir, subdir) os.makedirs(target_dir, exist_ok=True) else: target_dir = variant_dir os.makedirs(target_dir, exist_ok=True) save_path = os.path.join(target_dir, output_filename) try: # 保存图片 image_data.save(save_path) logging.info(f"保存{image_type}图像到: {save_path}") # 保存元数据(如果有) if metadata: # 确保元数据文件与图像在同一目录 metadata_filename = os.path.splitext(os.path.basename(save_path))[0] + "_metadata.json" metadata_path = os.path.join(os.path.dirname(save_path), metadata_filename) try: with open(metadata_path, 'w', encoding='utf-8') as f: # 不使用自定义编码器,使用标准json json.dump(metadata, f, ensure_ascii=False, indent=4, ignore_nan=True) logging.info(f"保存{image_type}元数据到: {metadata_path}") except Exception as me: logging.error(f"无法保存{image_type}元数据到{metadata_path}: {me}") traceback.print_exc() except Exception as e: logging.exception(f"无法保存{image_type}图像到{save_path}: {e}") traceback.print_exc() return save_path def finalize(self, run_id: str): logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.") pass # Nothing specific to do for file system finalize def _sanitize_content_for_json(self, data): """对内容进行深度清理,确保可以安全序列化为JSON Args: data: 要处理的数据(字典、列表或基本类型) Returns: 经过处理的数据,确保可以安全序列化 """ if isinstance(data, dict): # 处理字典类型 sanitized_dict = {} for key, value in data.items(): # 移除error标志,我们会在最终验证后重新设置它 if key == "error": continue if key == "raw_result": continue sanitized_dict[key] = self._sanitize_content_for_json(value) return sanitized_dict elif isinstance(data, list): # 处理列表类型 return [self._sanitize_content_for_json(item) for item in data] elif isinstance(data, str): # 处理字符串类型(重点关注) # 1. 首先,替换所有字面的"\n"为真正的换行符 if r'\n' in data: data = data.replace(r'\n', '\n') # 2. 使用更强的处理方式 - 只保留绝对安全的字符 # - ASCII 32-126 (标准可打印ASCII字符) # - 换行、回车、制表符 # - 去除所有其他控制字符和潜在问题字符 safe_chars = [] for char in data: if char in '\n\r\t' or (32 <= ord(char) <= 126): safe_chars.append(char) elif ord(char) > 127: # 非ASCII字符 (包括emoji) # 转换为Unicode转义序列 safe_chars.append(f"\\u{ord(char):04x}".encode().decode('unicode-escape')) cleaned = ''.join(safe_chars) # 3. 验证字符串可以被安全序列化 try: json.dumps(cleaned, ensure_ascii=False) return cleaned except Exception as e: logging.warning(f"字符串清理后仍无法序列化,使用保守处理: {e}") # 最保守的处理 - 只保留ASCII字符 return ''.join(c for c in cleaned if ord(c) < 128) else: # 其他类型(数字、布尔值等)原样返回 return data