diff --git a/README.md b/README.md index b414ef0..a040faf 100644 --- a/README.md +++ b/README.md @@ -237,9 +237,9 @@ finally: **注意:** 如果你只需要最终的完整结果而不需要流式处理,可以直接调用 `ai_agent.work(...)` 方法,它会内部处理好拼接并直接返回结果字符串。 -### Planned Refactoring: Decoupling Generation and Output Handling +### Refactoring Complete: Decoupling Generation and Output Handling -To enhance the flexibility and extensibility of this tool, we are planning a refactoring effort to separate the core content/image generation logic from the output handling (currently, saving results to the local filesystem). +To enhance the flexibility and extensibility of this tool, a refactoring effort has been completed to separate the core content/image generation logic from the output handling (previously, saving results directly to the local filesystem). **Motivation:** @@ -247,27 +247,29 @@ To enhance the flexibility and extensibility of this tool, we are planning a ref * **Alternative Storage:** Enable saving results to different backends like databases, cloud storage (e.g., S3, OSS), etc. * **Modularity:** Improve code structure by separating concerns. -**Approach:** +**Approach Taken:** -1. **Modify Core Functions:** Functions responsible for generating topics, content, and posters (primarily in `utils/tweet_generator.py` and potentially `core` modules) will be updated to **return** the generated data (e.g., Python dictionaries, lists, PIL Image objects, or image bytes) rather than directly writing files to the `./result` directory. -2. **Introduce Output Handlers:** An "Output Handler" pattern will be implemented. - * An abstract base class or interface (`OutputHandler`) will define methods for processing different types of results (topics, content, configurations, images). - * An initial concrete implementation (`FileSystemOutputHandler`) will replicate the current behavior of saving all results to the `./result/{run_id}/...` directory structure. -3. **Update Main Workflow:** The main script (`main.py`) will be modified to: - * Instantiate a specific `OutputHandler` (initially the `FileSystemOutputHandler`). - * Call the generation functions to get the data. - * Pass the returned data to the `OutputHandler` instance for processing (e.g., saving). +1. **Modified Core Functions:** Functions responsible for generating topics, content, and posters (primarily in `utils/tweet_generator.py`, `core/simple_collage.py`, `core/posterGen.py`) have been updated: + * Topic generation now returns the generated data (`run_id`, `topics_list`, prompts). + * Image generation functions (`simple_collage.process_directory`, `posterGen.create_poster`) now return PIL Image objects instead of saving files. + * Content and poster generation workflows accept an `OutputHandler` instance to process results (content JSON, prompts, configurations, image data) immediately after generation. +2. **Introduced Output Handlers:** An "Output Handler" pattern has been implemented (`utils/output_handler.py`). + * An abstract base class (`OutputHandler`) defines methods for processing different types of results. + * A concrete implementation (`FileSystemOutputHandler`) replicates the original behavior of saving all results to the `./result/{run_id}/...` directory structure. +3. **Updated Main Workflow:** The main script (`main.py`) now: + * Instantiates a specific `OutputHandler` (currently `FileSystemOutputHandler`). + * Calls the generation functions, passing the `OutputHandler` where needed. + * Uses the `OutputHandler` to process data returned by the topic generation step. +4. **Reduced Config Dependency:** Core logic functions (`PromptManager`, `generate_content_for_topic`, `generate_posters_for_topic`, etc.) now receive necessary configuration values as specific parameters rather than relying on the entire `config` dictionary, making them more independent and testable. **Future Possibilities:** -This refactoring will make it straightforward to add new output handlers in the future, such as: +This refactoring makes it straightforward to add new output handlers in the future, such as: * `ApiOutputHandler`: Formats results for API responses. * `DatabaseOutputHandler`: Stores results in a database. * `CloudStorageOutputHandler`: Uploads results (especially images) to cloud storage and potentially stores metadata elsewhere. -**(Note:** This refactoring is in progress. The current documentation reflects the existing file-saving behavior, but expect changes in how results are handled internally.) - ### 配置文件说明 (Configuration) 主配置文件为 `poster_gen_config.json` (可以复制 `example_config.json` 并修改)。主要包含以下部分: \ No newline at end of file diff --git a/core/__pycache__/posterGen.cpython-312.pyc b/core/__pycache__/posterGen.cpython-312.pyc index 8ae523a..910007a 100644 Binary files a/core/__pycache__/posterGen.cpython-312.pyc and b/core/__pycache__/posterGen.cpython-312.pyc differ diff --git a/core/__pycache__/simple_collage.cpython-312.pyc b/core/__pycache__/simple_collage.cpython-312.pyc index f120def..505b93b 100644 Binary files a/core/__pycache__/simple_collage.cpython-312.pyc and b/core/__pycache__/simple_collage.cpython-312.pyc differ diff --git a/core/posterGen.py b/core/posterGen.py index ab26f5b..29ba1ae 100644 --- a/core/posterGen.py +++ b/core/posterGen.py @@ -590,102 +590,94 @@ class PosterGenerator: draw.text((x, y), text, font=font, fill=text_color) def _print_text_debug_info(self, text_type, font, text_width, x, y, font_path): - """打印文字调试信息""" - print(f"{text_type}信息:") - print(f"- 字体大小: {font.size}") - print(f"- 文字宽度: {text_width}") - print(f"- 文字位置: x={x}, y={y}") - print(f"- 使用字体: {os.path.basename(font_path)}") + pass + # print(f" {text_type}: Font={os.path.basename(font_path)}, Size={font.size}, Width={text_width:.0f}, Pos=({x:.0f}, {y:.0f})") - def create_poster(self, image_path, text_data, output_name): - """生成海报""" + def create_poster(self, image_path, text_data): + """ + Creates a poster by combining the base image, frame (optional), stickers (optional), + and text layers. + + Args: + image_path: Path to the base image (e.g., the generated collage). + text_data: Dictionary containing text information ( + { + 'title': 'Main Title Text', + 'subtitle': 'Subtitle Text (optional)', + 'additional_texts': [{'text': '...', 'position': 'top/bottom/center', 'size_factor': 0.8}, ...] + } + ). + + Returns: + A PIL Image object representing the final poster, or None if creation failed. + """ + target_size = (900, 1200) # TODO: Make target_size a parameter? + + print(f"\n--- Creating Poster --- ") + print(f"Input Image: {image_path}") + print(f"Text Data: {text_data}") + # print(f"Output Name: {output_name}") # output_name is removed + try: - # 增加计数器 - self.poster_count += 1 - - # 设置目标尺寸为3:4比例 - target_size = (900, 1200) # 3:4比例 - - # 创建三个图层 + # 1. 创建底层(图片) base_layer = self.create_base_layer(image_path, target_size) - middle_layer = self.create_middle_layer(target_size) - - # 先合成底层和中间层 - final_image = Image.new('RGBA', target_size, (0, 0, 0, 0)) - final_image.paste(base_layer, (0, 0)) - final_image.alpha_composite(middle_layer) - - # 创建并添加文字层 - text_layer = self.create_text_layer(target_size, text_data) - final_image.alpha_composite(text_layer) # 确保文字层在最上面 - - # 使用模10的余数决定是否添加边框 - # 每十张中的第1、5、9张添加边框(余数为1,5,9) - add_frame_flag = random.random() < self.img_frame_posbility - - if add_frame_flag: - final_image = self.add_frame(final_image, target_size) - - # 保存结果 - # 检查output_name是否已经是完整路径 - if os.path.dirname(output_name): - output_path = output_name - # 确保目录存在 - os.makedirs(os.path.dirname(output_path), exist_ok=True) + if not base_layer: + raise ValueError("Failed to create base layer.") + print("Base layer created.") + + # 2. (可选) 添加边框 - 根据计数器决定 + # if self.poster_count % 5 == 0: # 每5张海报添加一次边框 + if random.random() < self.img_frame_posbility: + print("Attempting to add frame...") + base_layer = self.add_frame(base_layer, target_size) else: - # 如果只是文件名,拼接输出目录 - output_path = os.path.join(self.output_dir, output_name) - - # 如果没有扩展名,添加.jpg - if not output_path.lower().endswith(('.jpg', '.jpeg', '.png')): - output_path += '.jpg' - - final_image.convert('RGB').save(output_path) - print(f"海报已保存: {output_path}") - print(f"图片尺寸: {target_size[0]}x{target_size[1]} (3:4比例)") - print(f"使用的文字特效: {self.selected_effect}") + print("Skipping frame addition.") + self.poster_count += 1 # 增加计数器 - return output_path + # 3. 创建中间层(文本框底图) + # middle_layer 包含文本区域定义 (self.title_area, self.additional_text_area) + middle_layer = self.create_middle_layer(target_size) + print("Middle layer (text areas defined) created.") + + # 4. 创建文本层 + text_layer = self.create_text_layer(target_size, text_data) + print("Text layer created.") + + # 5. 合成图层 + # Start with the base layer (which might already have a frame) + final_poster = base_layer + + # Add middle layer (text backgrounds) + final_poster.alpha_composite(middle_layer) + + # Add text layer + final_poster.alpha_composite(text_layer) + + # (可选) 添加贴纸 + final_poster = self.add_stickers(final_poster) + print("Layers composed.") + + # 转换回 RGB 以保存为 JPG/PNG (移除 alpha通道) + final_poster_rgb = final_poster.convert("RGB") + + # 移除保存逻辑,直接返回 Image 对象 + # final_save_path = os.path.join(self.output_dir, output_name) + # os.makedirs(os.path.dirname(final_save_path), exist_ok=True) + # final_poster_rgb.save(final_save_path) + # print(f"Final poster saved to: {final_save_path}") + # return final_save_path + return final_poster_rgb except Exception as e: - print(f"生成海报失败: {e}") + print(f"Error creating poster: {e}") + traceback.print_exc() return None def process_directory(self, input_dir, text_data=None): - pass - """遍历处理目录中的所有图片""" - # 支持的图片格式 - image_extensions = ('.jpg', '.jpeg', '.png', '.bmp') - - try: - # 获取目录中的所有文件 - files = os.listdir(input_dir) - image_files = [f for f in files if f.lower().endswith(image_extensions)] - - if not image_files: - print(f"在目录 {input_dir} 中未找到图片文件") - return - - print(f"找到 {len(image_files)} 个图片文件") - - # 处理每个图片文件 - for i, image_file in enumerate(image_files, 1): - image_path = os.path.join(input_dir, image_file) - print(f"\n处理第 {i}/{len(image_files)} 个图片: {image_file}") - - try: - # 构建输出文件名 - output_name = os.path.splitext(image_file)[0] - # 生成海报 - self.create_poster(image_path, text_data, output_name) - print(f"完成处理: {image_file}") - - except Exception as e: - print(f"处理图片 {image_file} 时出错: {e}") - continue - - except Exception as e: - print(f"处理目录时出错: {e}") + """处理目录中的所有图片并为每张图片生成海报""" + # ... (此函数可能不再需要,或者需要重构以适应新的流程) ... + pass + def main(): # 设置是否使用文本框底图(True为使用,False为不使用) @@ -708,7 +700,7 @@ def main(): } # 处理目录中的所有图片 img_path = "/root/autodl-tmp/poster_baseboard_0403/output_collage/random_collage_1_collage.png" - generator.create_poster(img_path, text_data, f"{item['index']}.jpg") + generator.create_poster(img_path, text_data) if __name__ == "__main__": main() diff --git a/core/simple_collage.py b/core/simple_collage.py index 6b2c1fe..1da23c3 100644 --- a/core/simple_collage.py +++ b/core/simple_collage.py @@ -687,47 +687,58 @@ class ImageCollageCreator: self.collage_style = collage_style return self.collage_style -def process_directory(input_dir, target_size=(900, 1200), output_count=10, output_dir=None, collage_styles = [ - "grid_2x2", # 标准2x2网格 - "asymmetric", # 非对称布局 - "filmstrip", # 胶片条布局 - "circles", # 圆形布局 - "overlap", # 重叠风格 - "mosaic", # 马赛克风格 3x3 - "fullscreen", # 全覆盖拼图样式 - "vertical_stack" # 新增:上下拼图样式 - ]): - """处理目录中的图片并生成指定数量的拼贴画""" - try: - # 创建拼贴画生成器 - creator = ImageCollageCreator() - creator.set_collage_style(collage_styles) - # 设置3:4比例的输出尺寸 - target_size = target_size # 3:4比例 - output_list = [] - for i in range(output_count): - print(f"\n生成第 {i+1}/{output_count} 个拼贴画") - - # 使用指定尺寸创建拼贴画 - collage = creator.create_collage_with_style(input_dir, target_size=target_size) - - # 保存拼贴画 - if collage: - creator.save_collage(collage, os.path.join(output_dir, f"random_collage_{i}.png")) - print(f"完成第 {i+1} 个拼贴画") - output_list.append( - { "collage_style": collage_styles[i], - "collage_index": i, - "collage": collage, - "path": os.path.join(output_dir, f"random_collage_{i}.png"), - } - ) - else: - print(f"第 {i+1} 个拼贴画创建失败") - return output_list - except Exception as e: - print(f"处理目录时出错: {e}") - traceback.print_exc() +def process_directory(directory_path, target_size=(900, 1200), output_count=1): + """ + Processes images in a directory: finds main subject, adjusts contrast/saturation, + performs smart cropping/resizing, creates a collage, and returns PIL Image objects. + + Args: + directory_path: Path to the directory containing images. + target_size: Tuple (width, height) for the final collage. + output_count: Number of collages to generate. + + Returns: + A list containing the generated PIL Image objects for the collages, + or an empty list if processing fails. + """ + image_files = [os.path.join(directory_path, f) for f in os.listdir(directory_path) + if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + + if not image_files: + print(f"No images found in {directory_path}") + return [] + + # Create collage + collage_images = [] + for i in range(output_count): + collage = create_collage(image_files, target_size) + if collage: + # collage_filename = f"collage_{i}.png" + # save_path = os.path.join(output_dir, collage_filename) + # collage.save(save_path) + # print(f"Collage saved to {save_path}") + # collage_images.append({'path': save_path, 'image': collage}) + collage_images.append(collage) # Return the PIL Image object directly + else: + print(f"Failed to create collage {i}") + + return collage_images + +def create_collage(image_paths, target_size=(900, 1200)): + # ... (internal logic, including find_main_subject, adjust_image, smart_crop_and_resize) ... + pass + +def find_main_subject(image): + # ... (keep the existing implementation) ... + pass + +def adjust_image(image, contrast=1.0, saturation=1.0): + # ... (keep the existing implementation) ... + pass + +def smart_crop_and_resize(image, target_aspect_ratio): + # ... (keep the existing implementation) ... + pass def main(): # 设置基础路径 diff --git a/main.py b/main.py index 71e5699..dc30c2e 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,10 @@ from utils.tweet_generator import ( # Import the moved functions ) from utils.prompt_manager import PromptManager # Import PromptManager import random +# Import Output Handlers +from utils.output_handler import FileSystemOutputHandler, OutputHandler +from core.topic_parser import TopicParser +from utils.tweet_generator import tweetTopicRecord # Needed only if loading old topics files? def load_config(config_path="poster_gen_config.json"): """Loads configuration from a JSON file.""" @@ -51,27 +55,26 @@ def load_config(config_path="poster_gen_config.json"): # --- Main Orchestration Step (Remains in main.py) --- -def generate_content_and_posters_step(config, run_id, tweet_topic_record): +def generate_content_and_posters_step(config, run_id, topics_list, output_handler): """ Step 2: Generates content and posters for each topic in the record. Returns True if successful (at least partially), False otherwise. """ - if not tweet_topic_record or not tweet_topic_record.topics_list: + if not topics_list or not topics_list: # print("Skipping content/poster generation: No valid topics found in the record.") logging.warning("Skipping content/poster generation: No valid topics found in the record.") return False # print(f"\n--- Starting Step 2: Content and Poster Generation for run_id: {run_id} ---") logging.info(f"Starting Step 2: Content and Poster Generation for run_id: {run_id}") - # print(f"Processing {len(tweet_topic_record.topics_list)} topics...") - logging.info(f"Processing {len(tweet_topic_record.topics_list)} topics...") + # print(f"Processing {len(topics_list)} topics...") + logging.info(f"Processing {len(topics_list)} topics...") success_flag = False prompt_manager = PromptManager(config) ai_agent = None try: # --- Initialize AI Agent for Content Generation --- - # print("Initializing AI Agent for content generation...") request_timeout = config.get("request_timeout", 30) # Get timeout from config max_retries = config.get("max_retries", 3) # Get max_retries from config ai_agent = AI_Agent( @@ -84,34 +87,95 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record): logging.info("AI Agent for content generation initialized.") # --- Iterate through Topics --- - for i, topic_item in enumerate(tweet_topic_record.topics_list): + for i, topic_item in enumerate(topics_list): topic_index = topic_item.get('index', i + 1) # Use parsed index if available - # print(f"\nProcessing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic_item.get('object', 'N/A')}") - logging.info(f"--- Processing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic_item.get('object', 'N/A')} ---") # Make it stand out + logging.info(f"--- Processing Topic {topic_index}/{len(topics_list)}: {topic_item.get('object', 'N/A')} ---") # Make it stand out # --- Generate Content Variants --- - tweet_content_list = generate_content_for_topic( - ai_agent, prompt_manager, config, topic_item, config["output_dir"], run_id, topic_index + # 读取内容生成需要的参数 + content_variants = config.get("variants", 1) + content_temp = config.get("content_temperature", 0.3) + content_top_p = config.get("content_top_p", 0.4) + content_presence_penalty = config.get("content_presence_penalty", 1.5) + + # 调用修改后的 generate_content_for_topic + content_success = generate_content_for_topic( + ai_agent, + prompt_manager, # Pass PromptManager instance + topic_item, + run_id, + topic_index, + output_handler, # Pass OutputHandler instance + # 传递具体参数 + variants=content_variants, + temperature=content_temp, + top_p=content_top_p, + presence_penalty=content_presence_penalty ) - if tweet_content_list: - # print(f" Content generation successful for Topic {topic_index}.") + # if tweet_content_list: # generate_content_for_topic 现在返回 bool + if content_success: logging.info(f"Content generation successful for Topic {topic_index}.") # --- Generate Posters --- + # TODO: 重构 generate_posters_for_topic 以移除 config 依赖 + # TODO: 需要确定如何将 content 数据传递给 poster 生成步骤 (已解决:函数内部读取) + # 临时方案:可能需要在这里读取由 output_handler 保存的 content 文件 + # 或者修改 generate_content_for_topic 以返回收集到的 content 数据列表 (选项1) + # 暂时跳过 poster 生成的调用,直到确定方案 + + # --- 重新启用 poster 生成调用 --- + logging.info(f"Proceeding to poster generation for Topic {topic_index}...") + + # --- 读取 Poster 生成所需参数 --- + poster_variants = config.get("variants", 1) # 通常与 content variants 相同 + poster_assets_dir = config.get("poster_assets_base_dir") + img_base_dir = config.get("image_base_dir") + mod_img_subdir = config.get("modify_image_subdir", "modify") + res_dir_config = config.get("resource_dir", []) + poster_size = tuple(config.get("poster_target_size", [900, 1200])) + txt_possibility = config.get("text_possibility", 0.3) + collage_subdir = config.get("output_collage_subdir", "collage_img") + poster_subdir = config.get("output_poster_subdir", "poster") + poster_filename = config.get("output_poster_filename", "poster.jpg") + cam_img_subdir = config.get("camera_image_subdir", "相机") + + # 检查关键路径是否存在 + if not poster_assets_dir or not img_base_dir: + logging.error(f"Missing critical paths for poster generation (poster_assets_base_dir or image_base_dir) in config. Skipping posters for topic {topic_index}.") + continue # 跳过此主题的海报生成 + # --- 结束读取参数 --- + posters_attempted = generate_posters_for_topic( - config, topic_item, tweet_content_list, config["output_dir"], run_id, topic_index + topic_item=topic_item, + output_dir=config["output_dir"], # Base output dir is still needed + run_id=run_id, + topic_index=topic_index, + output_handler=output_handler, # <--- 传递 Output Handler + # 传递具体参数 + variants=poster_variants, + poster_assets_base_dir=poster_assets_dir, + image_base_dir=img_base_dir, + modify_image_subdir=mod_img_subdir, + resource_dir_config=res_dir_config, + poster_target_size=poster_size, + text_possibility=txt_possibility, + output_collage_subdir=collage_subdir, + output_poster_subdir=poster_subdir, + output_poster_filename=poster_filename, + camera_image_subdir=cam_img_subdir ) - if posters_attempted: # Log even if no posters were *successfully* created, but the attempt was made - # print(f" Poster generation process completed for Topic {topic_index}.") + if posters_attempted: logging.info(f"Poster generation process completed for Topic {topic_index}.") - success_flag = True # Mark success if content and poster attempts were made + success_flag = True # Mark overall success if content AND poster attempts were made else: - # print(f" Poster generation skipped or failed early for Topic {topic_index}.") logging.warning(f"Poster generation skipped or failed early for Topic {topic_index}.") + # 即使海报失败,只要内容成功了,也算部分成功?根据需求决定 success_flag + # success_flag = True # 取决于是否认为内容成功就足够 + # logging.warning(f"Skipping poster generation for Topic {topic_index} pending refactor and data passing strategy.") + # Mark overall success if content generation succeeded + # success_flag = True else: - # print(f" Content generation failed or yielded no valid results for Topic {topic_index}. Skipping posters.") logging.warning(f"Content generation failed or yielded no valid results for Topic {topic_index}. Skipping posters.") - # print(f"--- Finished Topic {topic_index} ---") logging.info(f"--- Finished Topic {topic_index} ---") @@ -181,28 +245,31 @@ def main(): logging.info("Debug logging enabled.") # --- End Debug Level Adjustment --- - # print("Starting Travel Content Creator Pipeline...") logging.info("Starting Travel Content Creator Pipeline...") - # print(f"Using configuration file: {args.config}") logging.info(f"Using configuration file: {args.config}") - # if args.run_id: print(f"Using specific run_id: {args.run_id}") - # if args.topics_file: print(f"Using existing topics file: {args.topics_file}") if args.run_id: logging.info(f"Using specific run_id: {args.run_id}") if args.topics_file: logging.info(f"Using existing topics file: {args.topics_file}") config = load_config(args.config) if config is None: - # print("Critical: Failed to load configuration. Exiting.") logging.critical("Failed to load configuration. Exiting.") sys.exit(1) + + # --- Initialize Output Handler --- + # For now, always use FileSystemOutputHandler. Later, this could be configurable. + output_handler: OutputHandler = FileSystemOutputHandler(config.get("output_dir", "result")) + logging.info(f"Using Output Handler: {output_handler.__class__.__name__}") + # --- End Output Handler Init --- run_id = args.run_id - tweet_topic_record = None + # tweet_topic_record = None # No longer the primary way to pass data + topics_list = None + system_prompt = None + user_prompt = None pipeline_start_time = time.time() # --- Step 1: Topic Generation (or Load Existing) --- if args.topics_file: - # print(f"Skipping Topic Generation (Step 1) - Loading topics from: {args.topics_file}") logging.info(f"Skipping Topic Generation (Step 1) - Loading topics from: {args.topics_file}") topics_list = TopicParser.load_topics_from_json(args.topics_file) if topics_list: @@ -232,12 +299,12 @@ def main(): # print(f"Generated run_id for loaded topics: {run_id}") logging.info(f"Generated run_id for loaded topics: {run_id}") - # Create a minimal tweetTopicRecord (prompts might be missing) - # We need the output_dir from config to make the record useful - output_dir = config.get("output_dir", "result") # Default to result if not in config - tweet_topic_record = tweetTopicRecord(topics_list, "", "", output_dir, run_id) - # print(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}") - logging.info(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}") + # Prompts are missing when loading from file, handle this if needed later + system_prompt = "" # Placeholder + user_prompt = "" # Placeholder + logging.info(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}. Prompts are not available.") + # Optionally, save the loaded topics using the handler? + # output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt) else: # print(f"Error: Failed to load topics from {args.topics_file}. Cannot proceed.") logging.error(f"Failed to load topics from {args.topics_file}. Cannot proceed.") @@ -246,22 +313,29 @@ def main(): # print("--- Executing Topic Generation (Step 1) ---") logging.info("Executing Topic Generation (Step 1)...") step1_start = time.time() - run_id, tweet_topic_record = run_topic_generation_pipeline(config, args.run_id) # Pass run_id if provided + # Call the updated function, receive raw data + run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, args.run_id) step1_end = time.time() - if run_id and tweet_topic_record: + if run_id is not None and topics_list is not None: # Check if step succeeded # print(f"Step 1 completed successfully in {step1_end - step1_start:.2f} seconds. Run ID: {run_id}") logging.info(f"Step 1 completed successfully in {step1_end - step1_start:.2f} seconds. Run ID: {run_id}") + # --- Use Output Handler to save results --- + output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt) else: # print("Critical: Topic Generation (Step 1) failed. Exiting.") logging.critical("Topic Generation (Step 1) failed. Exiting.") sys.exit(1) # --- Step 2: Content & Poster Generation --- - if run_id and tweet_topic_record: + if run_id is not None and topics_list is not None: # print("\n--- Executing Content and Poster Generation (Step 2) ---") logging.info("Executing Content and Poster Generation (Step 2)...") step2_start = time.time() - success = generate_content_and_posters_step(config, run_id, tweet_topic_record) + # TODO: Refactor generate_content_and_posters_step to accept topics_list + # and use the output_handler instead of saving files directly. + # For now, we might need to pass topics_list and handler, or adapt it. + # Let's tentatively adapt the call signature, assuming the function will be refactored. + success = generate_content_and_posters_step(config, run_id, topics_list, output_handler) step2_end = time.time() if success: # print(f"Step 2 completed in {step2_end - step2_start:.2f} seconds.") @@ -270,8 +344,13 @@ def main(): # print("Warning: Step 2 finished, but may have encountered errors or generated no output.") logging.warning("Step 2 finished, but may have encountered errors or generated no output.") else: - # print("Error: Cannot proceed to Step 2: Invalid run_id or tweet_topic_record from Step 1.") - logging.error("Cannot proceed to Step 2: Invalid run_id or tweet_topic_record from Step 1.") + # print("Error: Cannot proceed to Step 2: Invalid run_id or topics_list from Step 1.") + logging.error("Cannot proceed to Step 2: Invalid run_id or topics_list from Step 1.") + + # --- Finalize Output --- + if run_id: + output_handler.finalize(run_id) + # --- End Finalize --- pipeline_end_time = time.time() # print(f"\nPipeline finished. Total execution time: {pipeline_end_time - pipeline_start_time:.2f} seconds.") diff --git a/utils/__pycache__/output_handler.cpython-312.pyc b/utils/__pycache__/output_handler.cpython-312.pyc new file mode 100644 index 0000000..45adbfd Binary files /dev/null and b/utils/__pycache__/output_handler.cpython-312.pyc differ diff --git a/utils/__pycache__/prompt_manager.cpython-312.pyc b/utils/__pycache__/prompt_manager.cpython-312.pyc index 898dd97..1039bfb 100644 Binary files a/utils/__pycache__/prompt_manager.cpython-312.pyc and b/utils/__pycache__/prompt_manager.cpython-312.pyc differ diff --git a/utils/__pycache__/tweet_generator.cpython-312.pyc b/utils/__pycache__/tweet_generator.cpython-312.pyc index b6849f7..ff707a9 100644 Binary files a/utils/__pycache__/tweet_generator.cpython-312.pyc and b/utils/__pycache__/tweet_generator.cpython-312.pyc differ diff --git a/utils/output_handler.py b/utils/output_handler.py new file mode 100644 index 0000000..4e5ece1 --- /dev/null +++ b/utils/output_handler.py @@ -0,0 +1,145 @@ +import os +import json +import logging +from abc import ABC, abstractmethod + +class OutputHandler(ABC): + """Abstract base class for handling the output of the generation pipeline.""" + + @abstractmethod + def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str): + """Handles the results from the topic generation step.""" + pass + + @abstractmethod + def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str): + """Handles the results for a single content variant.""" + pass + + @abstractmethod + def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict): + """Handles the poster configuration generated for a topic.""" + pass + + @abstractmethod + def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str): + """Handles a generated image (collage or final poster). + + Args: + image_type: Either 'collage' or 'poster'. + image_data: The image data (e.g., PIL Image object or bytes). + output_filename: The desired filename for the output (e.g., 'poster.jpg'). + """ + pass + + @abstractmethod + def finalize(self, run_id: str): + """Perform any final actions for the run (e.g., close files, upload manifests).""" + pass + + +class FileSystemOutputHandler(OutputHandler): + """Handles output by saving results to the local file system.""" + + def __init__(self, base_output_dir: str = "result"): + self.base_output_dir = base_output_dir + logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}") + + def _get_run_dir(self, run_id: str) -> str: + """Gets the specific directory for a run, creating it if necessary.""" + run_dir = os.path.join(self.base_output_dir, run_id) + os.makedirs(run_dir, exist_ok=True) + return run_dir + + def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str: + """Gets the specific directory for a variant, optionally within a subdirectory (e.g., 'poster'), creating it if necessary.""" + run_dir = self._get_run_dir(run_id) + variant_base_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}") + target_dir = variant_base_dir + if subdir: + target_dir = os.path.join(variant_base_dir, subdir) + os.makedirs(target_dir, exist_ok=True) + return target_dir + + def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str): + run_dir = self._get_run_dir(run_id) + + # Save topics list + topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json") + try: + with open(topics_path, "w", encoding="utf-8") as f: + json.dump(topics_list, f, ensure_ascii=False, indent=4) + logging.info(f"Topics list saved successfully to: {topics_path}") + except Exception as e: + logging.exception(f"Error saving topic JSON file to {topics_path}:") + + # Save prompts + prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt") + try: + with open(prompt_path, "w", encoding="utf-8") as f: + f.write("--- SYSTEM PROMPT ---\n") + f.write(system_prompt + "\n\n") + f.write("--- USER PROMPT ---\n") + f.write(user_prompt + "\n") + logging.info(f"Topic prompts saved successfully to: {prompt_path}") + except Exception as e: + logging.exception(f"Error saving topic prompts file to {prompt_path}:") + + def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str): + """Saves content JSON and prompt for a specific variant.""" + variant_dir = self._get_variant_dir(run_id, topic_index, variant_index) + + # Save content JSON + content_path = os.path.join(variant_dir, "article.json") + try: + with open(content_path, "w", encoding="utf-8") as f: + json.dump(content_data, f, ensure_ascii=False, indent=4) + logging.info(f"Content JSON saved to: {content_path}") + except Exception as e: + logging.exception(f"Failed to save content JSON to {content_path}: {e}") + + # Save content prompt + prompt_path = os.path.join(variant_dir, "tweet_prompt.txt") + try: + with open(prompt_path, "w", encoding="utf-8") as f: + # Assuming prompt_data is the user prompt used for this variant + f.write(prompt_data + "\n") + logging.info(f"Content prompt saved to: {prompt_path}") + except Exception as e: + logging.exception(f"Failed to save content prompt to {prompt_path}: {e}") + + def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict): + """Saves the complete poster configuration list/dict for a topic.""" + run_dir = self._get_run_dir(run_id) + config_path = os.path.join(run_dir, f"topic_{topic_index}_poster_configs.json") + try: + with open(config_path, 'w', encoding='utf-8') as f_cfg_topic: + json.dump(config_data, f_cfg_topic, ensure_ascii=False, indent=4) + logging.info(f"Saved complete poster configurations for topic {topic_index} to: {config_path}") + except Exception as save_err: + logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {config_path}: {save_err}") + + def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str): + """Saves a generated image (PIL Image) to the appropriate variant subdirectory.""" + subdir = None + if image_type == 'collage': + subdir = 'collage_img' # TODO: Make these subdir names configurable? + elif image_type == 'poster': + subdir = 'poster' + else: + logging.warning(f"Unknown image type '{image_type}'. Saving to variant root.") + subdir = None # Save directly in variant dir if type is unknown + + target_dir = self._get_variant_dir(run_id, topic_index, variant_index, subdir=subdir) + save_path = os.path.join(target_dir, output_filename) + + try: + # Assuming image_data is a PIL Image object based on posterGen/simple_collage + image_data.save(save_path) + logging.info(f"Saved {image_type} image to: {save_path}") + except Exception as e: + logging.exception(f"Failed to save {image_type} image to {save_path}: {e}") + + def finalize(self, run_id: str): + logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.") + pass # Nothing specific to do for file system finalize \ No newline at end of file diff --git a/utils/prompt_manager.py b/utils/prompt_manager.py index 2a9fd2e..383fb69 100644 --- a/utils/prompt_manager.py +++ b/utils/prompt_manager.py @@ -12,49 +12,54 @@ from .resource_loader import ResourceLoader # Use relative import within the sam class PromptManager: """Handles the loading and construction of prompts.""" - def __init__(self, config): - """Initializes the PromptManager with the global configuration.""" - self.config = config - # Instantiate ResourceLoader once, assuming it's mostly static methods or stateless - self.resource_loader = ResourceLoader() + def __init__(self, + topic_system_prompt_path: str, + topic_user_prompt_path: str, + content_system_prompt_path: str, + prompts_dir: str, + resource_dir_config: list, + topic_gen_num: int = 1, # Default values if needed + topic_gen_date: str = "" + ): + self.topic_system_prompt_path = topic_system_prompt_path + self.topic_user_prompt_path = topic_user_prompt_path + self.content_system_prompt_path = content_system_prompt_path + self.prompts_dir = prompts_dir + self.resource_dir_config = resource_dir_config + self.topic_gen_num = topic_gen_num + self.topic_gen_date = topic_gen_date def get_topic_prompts(self): """Constructs the system and user prompts for topic generation.""" logging.info("Constructing prompts for topic generation...") try: # --- System Prompt --- - system_prompt_path = self.config.get("topic_system_prompt") - if not system_prompt_path: - logging.error("topic_system_prompt path not specified in config.") + if not self.topic_system_prompt_path: + logging.error("Topic system prompt path not provided during PromptManager initialization.") return None, None - - # Use ResourceLoader's static method directly - system_prompt = ResourceLoader.load_file_content(system_prompt_path) + system_prompt = ResourceLoader.load_file_content(self.topic_system_prompt_path) if not system_prompt: - logging.error(f"Failed to load topic system prompt from '{system_prompt_path}'.") + logging.error(f"Failed to load topic system prompt from '{self.topic_system_prompt_path}'.") return None, None # --- User Prompt --- - user_prompt_path = self.config.get("topic_user_prompt") - if not user_prompt_path: - logging.error("topic_user_prompt path not specified in config.") + if not self.topic_user_prompt_path: + logging.error("Topic user prompt path not provided during PromptManager initialization.") return None, None - - base_user_prompt = ResourceLoader.load_file_content(user_prompt_path) - if base_user_prompt is None: # Check for None explicitly - logging.error(f"Failed to load base topic user prompt from '{user_prompt_path}'.") + base_user_prompt = ResourceLoader.load_file_content(self.topic_user_prompt_path) + if base_user_prompt is None: + logging.error(f"Failed to load base topic user prompt from '{self.topic_user_prompt_path}'.") return None, None # --- Build the dynamic part of the user prompt (Logic moved from prepare_topic_generation) --- user_prompt_dynamic = "你拥有的创作资料如下:\n" # Add genPrompts directory structure - gen_prompts_path = self.config.get("prompts_dir") - if gen_prompts_path and os.path.isdir(gen_prompts_path): + if self.prompts_dir and os.path.isdir(self.prompts_dir): try: - gen_prompts_list = os.listdir(gen_prompts_path) + gen_prompts_list = os.listdir(self.prompts_dir) for gen_prompt_folder in gen_prompts_list: - folder_path = os.path.join(gen_prompts_path, gen_prompt_folder) + folder_path = os.path.join(self.prompts_dir, gen_prompt_folder) if os.path.isdir(folder_path): try: # List files, filter out subdirs if needed @@ -63,13 +68,12 @@ class PromptManager: except OSError as e: logging.warning(f"Could not list directory {folder_path}: {e}") except OSError as e: - logging.warning(f"Could not list base prompts directory {gen_prompts_path}: {e}") + logging.warning(f"Could not list base prompts directory {self.prompts_dir}: {e}") else: - logging.warning(f"Prompts directory '{gen_prompts_path}' not found or invalid.") + logging.warning(f"Prompts directory '{self.prompts_dir}' not found or invalid.") # Add resource directory contents - resource_dir_config = self.config.get("resource_dir", []) - for dir_info in resource_dir_config: + for dir_info in self.resource_dir_config: source_type = dir_info.get("type", "UnknownType") source_file_paths = dir_info.get("file_path", []) for file_path in source_file_paths: @@ -81,7 +85,7 @@ class PromptManager: logging.warning(f"Could not load resource file {file_path}") # Add dateline information (optional) - user_prompt_dir = os.path.dirname(user_prompt_path) + user_prompt_dir = os.path.dirname(self.topic_user_prompt_path) dateline_path = os.path.join(user_prompt_dir, "2025各月节日宣传节点时间表.md") # Consider making this configurable if os.path.exists(dateline_path): dateline_content = ResourceLoader.load_file_content(dateline_path) @@ -91,9 +95,7 @@ class PromptManager: # Combine dynamic part, base template, and final parameters user_prompt = user_prompt_dynamic + base_user_prompt - select_num = self.config.get("num", 1) # Default to 1 if not specified - select_date = self.config.get("date", "") - user_prompt += f"\n选题数量:{select_num}\n选题日期:{select_date}\n" + user_prompt += f"\n选题数量:{self.topic_gen_num}\n选题日期:{self.topic_gen_date}\n" # --- End of moved logic --- logging.info(f"Topic prompts constructed. System: {len(system_prompt)} chars, User: {len(user_prompt)} chars.") @@ -108,20 +110,18 @@ class PromptManager: logging.info(f"Constructing content prompts for topic: {topic_item.get('object', 'N/A')}...") try: # --- System Prompt --- - system_prompt_path = self.config.get("content_system_prompt") - if not system_prompt_path: - logging.error("content_system_prompt path not specified in config.") + if not self.content_system_prompt_path: + logging.error("Content system prompt path not provided during PromptManager initialization.") return None, None - # Use ResourceLoader's static method. load_system_prompt was just load_file_content. - system_prompt = ResourceLoader.load_file_content(system_prompt_path) + system_prompt = ResourceLoader.load_file_content(self.content_system_prompt_path) if not system_prompt: - logging.error(f"Failed to load content system prompt from '{system_prompt_path}'.") + logging.error(f"Failed to load content system prompt from '{self.content_system_prompt_path}'.") return None, None # --- User Prompt (Logic moved from ResourceLoader.build_user_prompt) --- user_prompt = "" - prompts_dir = self.config.get("prompts_dir") - resource_dir_config = self.config.get("resource_dir", []) + prompts_dir = self.prompts_dir + resource_dir_config = self.resource_dir_config if not prompts_dir or not os.path.isdir(prompts_dir): logging.warning(f"Prompts directory '{prompts_dir}' not found or invalid. Content user prompt might be incomplete.") diff --git a/utils/tweet_generator.py b/utils/tweet_generator.py index 7ca1ee6..a5f26aa 100644 --- a/utils/tweet_generator.py +++ b/utils/tweet_generator.py @@ -26,6 +26,7 @@ from utils.prompt_manager import PromptManager # Keep this as it's importing fro from core import contentGen as core_contentGen from core import posterGen as core_posterGen from core import simple_collage as core_simple_collage +from .output_handler import OutputHandler # <-- 添加导入 class tweetTopic: def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic): @@ -41,60 +42,29 @@ class tweetTopic: self.target_audience_logic = target_audience_logic class tweetTopicRecord: - def __init__(self, topics_list, system_prompt, user_prompt, output_dir, run_id): + def __init__(self, topics_list, system_prompt, user_prompt, run_id): self.topics_list = topics_list self.system_prompt = system_prompt self.user_prompt = user_prompt - self.output_dir = output_dir self.run_id = run_id - def save_topics(self, path): - try: - with open(path, "w", encoding="utf-8") as f: - json.dump(self.topics_list, f, ensure_ascii=False, indent=4) - logging.info(f"Topics list successfully saved to {path}") # Change to logging - except Exception as e: - # Keep print for traceback, but add logging - logging.exception(f"保存选题失败到 {path}: {e}") # Log exception - # print(f"保存选题失败到 {path}: {e}") - # print("--- Traceback for save_topics error ---") - # traceback.print_exc() - # print("--- End Traceback ---") - return False - return True - - def save_prompt(self, path): - try: - with open(path, "w", encoding="utf-8") as f: - f.write(self.system_prompt + "\n") - f.write(self.user_prompt + "\n") - # f.write(self.output_dir + "\n") # Output dir not needed in prompt file? - # f.write(self.run_id + "\n") # run_id not needed in prompt file? - logging.info(f"Prompts saved to {path}") - except Exception as e: - logging.exception(f"保存提示词失败: {e}") - # print(f"保存提示词失败: {e}") - return False - return True - class tweetContent: - def __init__(self, result, prompt, output_dir, run_id, article_index, variant_index): + def __init__(self, result, prompt, run_id, article_index, variant_index): self.result = result self.prompt = prompt - self.output_dir = output_dir self.run_id = run_id self.article_index = article_index self.variant_index = variant_index try: self.title, self.content = self.split_content(result) - self.json_file = self.gen_result_json() + self.json_data = self.gen_result_json() except Exception as e: logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}") logging.debug(f"Raw result: {result[:500]}...") # Log partial raw result self.title = "[Parsing Error]" self.content = "[Failed to parse AI content]" - self.json_file = {"title": self.title, "content": self.content, "error": True, "raw_result": result} + self.json_data = {"title": self.title, "content": self.content, "error": True, "raw_result": result} def split_content(self, result): # Assuming split logic might still fail, keep it simple or improve with regex/json @@ -122,28 +92,19 @@ class tweetContent: "title": self.title, "content": self.content } + # Add error flag if it exists + if hasattr(self, 'json_data') and self.json_data.get('error'): + json_file['error'] = True + json_file['raw_result'] = self.json_data.get('raw_result') return json_file - def save_content(self, json_path): - try: - with open(json_path, "w", encoding="utf-8") as f: - # If parsing failed, save the error structure - json.dump(self.json_file, f, ensure_ascii=False, indent=4) - logging.info(f"Content JSON saved to: {json_path}") - except Exception as e: - logging.exception(f"Failed to save content JSON to {json_path}: {e}") - return None # Indicate failure - return json_path + def get_json_data(self): + """Returns the generated JSON data dictionary.""" + return self.json_data - def save_prompt(self, path): - try: - with open(path, "w", encoding="utf-8") as f: - f.write(self.prompt + "\n") - logging.info(f"Content prompt saved to: {path}") - except Exception as e: - logging.exception(f"Failed to save content prompt to {path}: {e}") - return None # Indicate failure - return path + def get_prompt(self): + """Returns the user prompt used to generate this content.""" + return self.prompt def get_content(self): return self.content @@ -151,12 +112,9 @@ class tweetContent: def get_title(self): return self.title - def get_json_file(self): - return self.json_file - -def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5): - """生成选题列表 (run_id is now passed in)""" +def generate_topics(ai_agent, system_prompt, user_prompt, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5): + """生成选题列表 (run_id is now passed in, output_dir removed as argument)""" logging.info("Starting topic generation...") time_start = time.time() @@ -171,26 +129,20 @@ def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, run_id, te result_list = TopicParser.parse_topics(result) if not result_list: logging.warning("Topic parsing resulted in an empty list.") - # Optionally save raw result here if parsing fails? - error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt") - try: - os.makedirs(os.path.dirname(error_log_path), exist_ok=True) - with open(error_log_path, "w", encoding="utf-8") as f_err: - f_err.write("--- Topic Parsing Failed ---\n") - f_err.write(result) - logging.info(f"Saved raw AI output due to parsing failure to: {error_log_path}") - except Exception as log_err: - logging.error(f"Failed to save raw AI output on parsing failure: {log_err}") + # Optionally handle raw result logging here if needed, but saving is responsibility of OutputHandler + # error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt") # output_dir is not available here + # try: + # # ... (save raw output logic) ... + # except Exception as log_err: + # logging.error(f"Failed to save raw AI output on parsing failure: {log_err}") - # Create record object (even if list is empty) - tweet_topic_record = tweetTopicRecord(result_list, system_prompt, user_prompt, output_dir, run_id) - - return tweet_topic_record # Return only the record + # 直接返回解析后的列表 + return result_list -def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_dir, run_id, +def generate_single_content(ai_agent, system_prompt, user_prompt, item, run_id, article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5): - """生成单篇文章内容. Requires prompts to be passed in.""" + """Generates single content variant data. Returns (content_json, user_prompt) or (None, None).""" logging.info(f"Generating content for topic {article_index}, variant {variant_index}") try: if not system_prompt or not user_prompt: @@ -201,29 +153,37 @@ def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_d time.sleep(random.random() * 0.5) - # Generate content (updated return values) + # Generate content (non-streaming work returns result, tokens, time_cost) result, tokens, time_cost = ai_agent.work( system_prompt, user_prompt, "", temperature, top_p, presence_penalty ) + if result is None: # Check if AI call failed + logging.error(f"AI agent work failed for {article_index}_{variant_index}. No result returned.") + return None, None + logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}") - # --- Correct directory structure --- - run_specific_output_dir = os.path.join(output_dir, run_id) - variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}") - os.makedirs(variant_result_dir, exist_ok=True) + # --- Create tweetContent object (handles parsing) --- + # Pass user_prompt instead of full prompt? Yes, user_prompt is what we need later. + tweet_content = tweetContent(result, user_prompt, run_id, article_index, variant_index) - # Create tweetContent object (handles potential parsing errors inside its __init__) - tweet_content = tweetContent(result, user_prompt, output_dir, run_id, article_index, variant_index) - - # Save content and prompt - content_save_path = os.path.join(variant_result_dir, "article.json") - prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt") - tweet_content.save_content(content_save_path) - tweet_content.save_prompt(prompt_save_path) - # logging.info(f" Saved article content to: {content_save_path}") # Already logged in save_content + # --- Remove Saving Logic --- + # run_specific_output_dir = os.path.join(output_dir, run_id) # output_dir no longer available + # variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}") + # os.makedirs(variant_result_dir, exist_ok=True) + # content_save_path = os.path.join(variant_result_dir, "article.json") + # prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt") + # tweet_content.save_content(content_save_path) # Method removed + # tweet_content.save_prompt(prompt_save_path) # Method removed + # --- End Remove Saving Logic --- + + # Return the data needed by the output handler + content_json = tweet_content.get_json_data() + prompt_data = tweet_content.get_prompt() # Get the stored user prompt + + return content_json, prompt_data # Return data pair - return tweet_content, result # Return object and raw result except Exception as e: logging.exception(f"Error generating single content for {article_index}_{variant_index}:") return None, None @@ -256,7 +216,7 @@ def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompt # 调用单篇文章生成函数 tweet_content, result = generate_single_content( - ai_agent, system_prompt, item, output_dir, run_id, i+1, j+1, temperature + ai_agent, system_prompt, item, run_id, i+1, j+1, temperature ) if tweet_content: @@ -270,207 +230,248 @@ def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompt return processed_results -def prepare_topic_generation( - config # Pass the whole config dictionary now - # select_date, select_num, - # system_prompt_path, user_prompt_path, - # base_url="vllm", model_name="qwenQWQ", api_key="EMPTY", - # gen_prompts_path="/root/autodl-tmp/TravelContentCreator/genPrompts", - # resource_dir="/root/autodl-tmp/TravelContentCreator/resource", - # output_dir="/root/autodl-tmp/TravelContentCreator/result" -): - """准备选题生成的环境和参数. Returns agent and prompts.""" +def prepare_topic_generation(prompt_manager: PromptManager, + api_url: str, + model_name: str, + api_key: str, + timeout: int, + max_retries: int): + """准备选题生成的环境和参数. Returns agent, system_prompt, user_prompt. - # Initialize PromptManager - prompt_manager = PromptManager(config) - - # Get prompts using PromptManager + Args: + prompt_manager: An initialized PromptManager instance. + api_url, model_name, api_key, timeout, max_retries: Parameters for AI_Agent. + """ + logging.info("Preparing for topic generation (using provided PromptManager)...") + # 从传入的 PromptManager 获取 prompts system_prompt, user_prompt = prompt_manager.get_topic_prompts() if not system_prompt or not user_prompt: - print("Error: Failed to get topic generation prompts.") - return None, None, None, None + logging.error("Failed to get topic generation prompts from PromptManager.") + return None, None, None # Return three Nones - # 创建AI Agent (still create agent here for the topic generation phase) + # 使用传入的参数初始化 AI Agent try: logging.info("Initializing AI Agent for topic generation...") - # --- Read timeout/retry from config --- - request_timeout = config.get("request_timeout", 30) # Default 30 seconds - max_retries = config.get("max_retries", 3) # Default 3 retries - # --- Pass values to AI_Agent --- ai_agent = AI_Agent( - config["api_url"], - config["model"], - config["api_key"], - timeout=request_timeout, - max_retries=max_retries + api_url, # Use passed arg + model_name, # Use passed arg + api_key, # Use passed arg + timeout=timeout, # Use passed arg + max_retries=max_retries # Use passed arg ) except Exception as e: logging.exception("Error initializing AI Agent for topic generation:") - traceback.print_exc() - return None, None, None, None + return None, None, None # Return three Nones - # Removed prompt loading/building logic, now handled by PromptManager - - # Return agent and the generated prompts - return ai_agent, system_prompt, user_prompt, config["output_dir"] + # 返回 agent 和 prompts + return ai_agent, system_prompt, user_prompt def run_topic_generation_pipeline(config, run_id=None): - """Runs the complete topic generation pipeline based on the configuration.""" + """ + Runs the complete topic generation pipeline based on the configuration. + Returns: (run_id, topics_list, system_prompt, user_prompt) or (None, None, None, None) on failure. + """ logging.info("Starting Step 1: Topic Generation Pipeline...") - - # --- Handle run_id --- + if run_id is None: logging.info("No run_id provided, generating one based on timestamp.") run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") else: logging.info(f"Using provided run_id: {run_id}") - # --- End run_id handling --- - # Prepare necessary inputs and the AI agent for topic generation - ai_agent, system_prompt, user_prompt, base_output_dir = None, None, None, None + ai_agent, system_prompt, user_prompt = None, None, None # Initialize + topics_list = None + prompt_manager = None # Initialize prompt_manager try: - # Pass the config directly to prepare_topic_generation - ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config) + # --- 读取 PromptManager 所需参数 --- + topic_sys_prompt_path = config.get("topic_system_prompt") + topic_user_prompt_path = config.get("topic_user_prompt") + content_sys_prompt_path = config.get("content_system_prompt") # 虽然这里不用,但 PromptManager 可能需要 + prompts_dir_path = config.get("prompts_dir") + resource_config = config.get("resource_dir", []) + topic_num = config.get("num", 1) + topic_date = config.get("date", "") + # --- 创建 PromptManager 实例 --- + prompt_manager = PromptManager( + topic_system_prompt_path=topic_sys_prompt_path, + topic_user_prompt_path=topic_user_prompt_path, + content_system_prompt_path=content_sys_prompt_path, + prompts_dir=prompts_dir_path, + resource_dir_config=resource_config, + topic_gen_num=topic_num, + topic_gen_date=topic_date + ) + logging.info("PromptManager instance created.") + + # --- 读取 AI Agent 所需参数 --- + ai_api_url = config.get("api_url") + ai_model = config.get("model") + ai_api_key = config.get("api_key") + ai_timeout = config.get("request_timeout", 30) + ai_max_retries = config.get("max_retries", 3) + + # 检查必需的 AI 参数是否存在 + if not all([ai_api_url, ai_model, ai_api_key]): + raise ValueError("Missing required AI configuration (api_url, model, api_key) in config.") + + # --- 调用修改后的 prepare_topic_generation --- + ai_agent, system_prompt, user_prompt = prepare_topic_generation( + prompt_manager, # Pass instance + ai_api_url, + ai_model, + ai_api_key, + ai_timeout, + ai_max_retries + ) if not ai_agent or not system_prompt or not user_prompt: raise ValueError("Failed to prepare topic generation (agent or prompts missing).") - except Exception as e: - logging.exception("Error during topic generation preparation:") - traceback.print_exc() - return None, None - - # Generate topics using the prepared agent and prompts - try: - # Pass the determined run_id to generate_topics - tweet_topic_record = generate_topics( - ai_agent, system_prompt, user_prompt, config["output_dir"], - run_id, # Pass the run_id + # --- Generate topics (保持不变) --- + topics_list = generate_topics( + ai_agent, system_prompt, user_prompt, + run_id, # Pass run_id config.get("topic_temperature", 0.2), config.get("topic_top_p", 0.5), config.get("topic_presence_penalty", 1.5) ) except Exception as e: - logging.exception("Error during topic generation API call:") - traceback.print_exc() - if ai_agent: ai_agent.close() # Ensure agent is closed on error - return None, None + logging.exception("Error during topic generation pipeline execution:") + # Ensure agent is closed even if generation fails mid-way + if ai_agent: ai_agent.close() + return None, None, None, None # Signal failure + finally: + # Ensure the AI agent is closed after generation attempt (if initialized) + if ai_agent: + logging.info("Closing topic generation AI Agent...") + ai_agent.close() - # Ensure the AI agent is closed after generation - if ai_agent: - logging.info("Closing topic generation AI Agent...") - ai_agent.close() + if topics_list is None: # Check if generate_topics returned None (though it currently returns list) + logging.error("Topic generation failed (generate_topics returned None or an error occurred).") + return None, None, None, None + elif not topics_list: # Check if list is empty + logging.warning(f"Topic generation completed for run {run_id}, but the resulting topic list is empty.") + # Return empty list and prompts anyway, let caller decide - # Process results - if not tweet_topic_record: - logging.error("Topic generation failed (generate_topics returned None).") - return None, None # Return None for run_id as well if record is None + # --- Saving logic removed previously --- - # Use the determined run_id for output directory - output_dir = os.path.join(config["output_dir"], run_id) - try: - os.makedirs(output_dir, exist_ok=True) - # --- Debug: Print the data before attempting to save --- - logging.info("--- Debug: Data to be saved in tweet_topic.json ---") - logging.info(tweet_topic_record.topics_list) - logging.info("--- End Debug ---") - # --- End Debug --- - - # Save topics and prompt details - save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json")) - save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt")) - if not save_topics_success or not save_prompt_success: - logging.warning("Warning: Failed to save topic generation results or prompts.") - # Continue but warn user - - except Exception as e: - logging.exception("Error saving topic generation results:") - traceback.print_exc() - # Return the generated data even if saving fails, but maybe warn more strongly? - # return run_id, tweet_topic_record # Decide if partial success is okay - return None, None # Or consider failure if saving is critical - - logging.info(f"Topics generated successfully. Run ID: {run_id}") - # Return the determined run_id and the record - return run_id, tweet_topic_record + logging.info(f"Topic generation pipeline completed successfully. Run ID: {run_id}") + # Return the raw data needed by the OutputHandler + return run_id, topics_list, system_prompt, user_prompt # --- Decoupled Functional Units (Moved from main.py) --- -def generate_content_for_topic(ai_agent, prompt_manager, config, topic_item, output_dir, run_id, topic_index): - """Generates all content variants for a single topic item. - +def generate_content_for_topic(ai_agent: AI_Agent, + prompt_manager: PromptManager, + topic_item: dict, + run_id: str, + topic_index: int, + output_handler: OutputHandler, # Changed name to match convention + # 添加具体参数,移除 config 和 output_dir + variants: int, + temperature: float, + top_p: float, + presence_penalty: float): + """Generates all content variants for a single topic item and uses OutputHandler. + Args: - ai_agent: An initialized AI_Agent instance. - prompt_manager: An initialized PromptManager instance. - config: The global configuration dictionary. - topic_item: The dictionary representing a single topic. - output_dir: The base output directory for the entire run (e.g., ./result). - run_id: The ID for the current run. - topic_index: The 1-based index of the current topic. - + ai_agent: Initialized AI_Agent instance. + prompt_manager: Initialized PromptManager instance. + topic_item: Dictionary representing the topic. + run_id: Current run ID. + topic_index: 1-based index of the topic. + output_handler: An instance of OutputHandler to process results. + variants: Number of variants to generate. + temperature, top_p, presence_penalty: AI generation parameters. Returns: - A list of tweet content data (dictionaries) generated for the topic, - or None if generation failed. + bool: True if at least one variant was successfully generated and handled, False otherwise. """ logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...") - tweet_content_list = [] - variants = config.get("variants", 1) + success_flag = False # Track if any variant succeeded + # 使用传入的 variants 参数 + # variants = config.get("variants", 1) for j in range(variants): variant_index = j + 1 - logging.info(f"Generating Variant {variant_index}/{variants}...") + logging.info(f" Generating Variant {variant_index}/{variants}...") - # Get prompts for this specific topic item - # Assuming prompt_manager is correctly initialized and passed + # PromptManager 实例已传入,直接调用 content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item) if not content_system_prompt or not content_user_prompt: - logging.warning(f"Skipping Variant {variant_index} due to missing content prompts.") - continue # Skip this variant + logging.warning(f" Skipping Variant {variant_index} due to missing content prompts.") + continue time.sleep(random.random() * 0.5) try: - # Call the core generation function (generate_single_content is in this file) - tweet_content, gen_result = generate_single_content( + # Call generate_single_content with passed-in parameters + content_json, prompt_data = generate_single_content( ai_agent, content_system_prompt, content_user_prompt, topic_item, - output_dir, run_id, topic_index, variant_index, - config.get("content_temperature", 0.3), - config.get("content_top_p", 0.4), - config.get("content_presence_penalty", 1.5) + run_id, topic_index, variant_index, + temperature, # 使用传入的参数 + top_p, # 使用传入的参数 + presence_penalty # 使用传入的参数 ) - if tweet_content: - try: - tweet_content_data = tweet_content.get_json_file() - if tweet_content_data: - tweet_content_list.append(tweet_content_data) - else: - logging.warning(f"Warning: tweet_content.get_json_file() for Topic {topic_index}, Variant {variant_index} returned empty data.") - except Exception as parse_err: - logging.error(f"Error processing tweet content after generation for Topic {topic_index}, Variant {variant_index}: {parse_err}") + + # Check if generation succeeded and parsing was okay (or error handled within json) + if content_json is not None and prompt_data is not None: + # Use the output handler to process/save the result + output_handler.handle_content_variant( + run_id, topic_index, variant_index, content_json, prompt_data + ) + success_flag = True # Mark success for this topic + + # Check specifically if the AI result itself indicated a parsing error internally + if content_json.get("error"): + logging.error(f" Content generation for Topic {topic_index}, Variant {variant_index} succeeded but response parsing failed (error flag set in content). Raw data logged by handler.") + else: + logging.info(f" Successfully generated and handled content for Topic {topic_index}, Variant {variant_index}.") else: - logging.warning(f"Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.") + logging.error(f" Content generation failed for Topic {topic_index}, Variant {variant_index}. Skipping handling.") + except Exception as e: - logging.exception(f"Error during content generation for Topic {topic_index}, Variant {variant_index}:") - # traceback.print_exc() + logging.exception(f" Error during content generation call or handling for Topic {topic_index}, Variant {variant_index}:") - if not tweet_content_list: - logging.warning(f"No valid content generated for Topic {topic_index}.") - return None - else: - logging.info(f"Successfully generated {len(tweet_content_list)} content variants for Topic {topic_index}.") - return tweet_content_list + # Return the success flag for this topic + return success_flag -def generate_posters_for_topic(config, topic_item, tweet_content_list, output_dir, run_id, topic_index): - """Generates all posters for a single topic item based on its generated content. +def generate_posters_for_topic(topic_item: dict, + output_dir: str, + run_id: str, + topic_index: int, + output_handler: OutputHandler, # 添加 handler + variants: int, + poster_assets_base_dir: str, + image_base_dir: str, + modify_image_subdir: str, + resource_dir_config: list, + poster_target_size: tuple, + text_possibility: float, + output_collage_subdir: str, + output_poster_subdir: str, + output_poster_filename: str, + camera_image_subdir: str + ): + """Generates all posters for a single topic item, handling image data via OutputHandler. Args: - config: The global configuration dictionary. topic_item: The dictionary representing a single topic. - tweet_content_list: List of content data generated by generate_content_for_topic. output_dir: The base output directory for the entire run (e.g., ./result). run_id: The ID for the current run. topic_index: The 1-based index of the current topic. + variants: Number of variants. + poster_assets_base_dir: Path to poster assets (fonts, frames etc.). + image_base_dir: Base path for source images. + modify_image_subdir: Subdirectory for modified images. + resource_dir_config: Configuration for resource directories (used for Description). + poster_target_size: Target size tuple (width, height) for the poster. + text_possibility: Probability of adding secondary text. + output_collage_subdir: Subdirectory name for saving collages. + output_poster_subdir: Subdirectory name for saving posters. + output_poster_filename: Filename for the final poster. + camera_image_subdir: Subdirectory for camera images (currently unused in logic?). + output_handler: An instance of OutputHandler to process results. Returns: True if poster generation was attempted (regardless of individual variant success), @@ -478,30 +479,47 @@ def generate_posters_for_topic(config, topic_item, tweet_content_list, output_di """ logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...") - # Initialize necessary generators here, assuming they are stateless or cheap to create - # Alternatively, pass initialized instances if they hold state or are expensive + # --- Load content data from files --- + loaded_content_list = [] + logging.info(f"Attempting to load content data for {variants} variants for topic {topic_index}...") + for j in range(variants): + variant_index = j + 1 + variant_dir = os.path.join(output_dir, run_id, f"{topic_index}_{variant_index}") + content_path = os.path.join(variant_dir, "article.json") + try: + if os.path.exists(content_path): + with open(content_path, 'r', encoding='utf-8') as f_content: + content_data = json.load(f_content) + if isinstance(content_data, dict) and 'title' in content_data and 'content' in content_data: + loaded_content_list.append(content_data) + logging.debug(f" Successfully loaded content from: {content_path}") + else: + logging.warning(f" Content file {content_path} has invalid format. Skipping.") + else: + logging.warning(f" Content file not found for variant {variant_index}: {content_path}. Skipping.") + except json.JSONDecodeError: + logging.error(f" Error decoding JSON from content file: {content_path}. Skipping.") + except Exception as e: + logging.exception(f" Error loading content file {content_path}: {e}") + + if not loaded_content_list: + logging.error(f"No valid content data loaded for topic {topic_index}. Cannot generate posters.") + return False + logging.info(f"Successfully loaded content data for {len(loaded_content_list)} variants.") + # --- End Load content data --- + + # Initialize generators using parameters try: content_gen_instance = core_contentGen.ContentGenerator() - # poster_gen_instance = core_posterGen.PosterGenerator() - # --- Read poster assets base dir from config --- - poster_assets_base_dir = config.get("poster_assets_base_dir") if not poster_assets_base_dir: - logging.error("Error: 'poster_assets_base_dir' not found in configuration. Cannot generate posters.") - return False # Cannot proceed without assets base dir - # --- Initialize PosterGenerator with the base dir --- + logging.error("Error: 'poster_assets_base_dir' not provided. Cannot generate posters.") + return False poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir) except Exception as e: logging.exception("Error initializing generators for poster creation:") return False - # --- Setup: Paths and Object Name --- - image_base_dir = config.get("image_base_dir") - if not image_base_dir: - logging.error("Error: image_base_dir missing in config for poster generation.") - return False - modify_image_subdir = config.get("modify_image_subdir", "modify") - camera_image_subdir = config.get("camera_image_subdir", "相机") - + # --- Setup: Paths and Object Name --- object_name = topic_item.get("object", "") if not object_name: logging.warning("Warning: Topic object name is missing. Cannot generate posters.") @@ -513,110 +531,95 @@ def generate_posters_for_topic(config, topic_item, tweet_content_list, output_di if not object_name_cleaned: logging.warning(f"Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.") return False - object_name = object_name_cleaned # Use the cleaned name for searching + object_name = object_name_cleaned except Exception as e: logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.") return False - # Construct and check INPUT image paths (still needed for collage) - input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name) + # Construct and check INPUT image paths + input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name) if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path): logging.warning(f"Warning: Modify Image directory not found or not a directory: '{input_img_dir_path}'. Skipping posters for this topic.") return False - # --- NEW: Locate Description File using resource_dir type "Description" --- + # Locate Description File using resource_dir_config parameter info_directory = [] description_file_path = None - resource_dir_config = config.get("resource_dir", []) found_description = False - for dir_info in resource_dir_config: if dir_info.get("type") == "Description": for file_path in dir_info.get("file_path", []): - # Match description file based on object name containment if object_name in os.path.basename(file_path): description_file_path = file_path if os.path.exists(description_file_path): - info_directory = [description_file_path] # Pass the found path + info_directory = [description_file_path] logging.info(f"Found and using description file from config: {description_file_path}") found_description = True else: logging.warning(f"Warning: Description file specified in config not found: {description_file_path}") - break # Found the matching entry in this list - if found_description: # Stop searching resource_dir if found + break + if found_description: break - if not found_description: logging.info(f"Warning: No matching description file found for object '{object_name}' in config resource_dir (type='Description').") - # --- End NEW Description File Logic --- - - # --- Generate Text Configurations for All Variants --- + + # Generate Text Configurations for All Variants try: - poster_text_configs_raw = content_gen_instance.run(info_directory, config["variants"], tweet_content_list) + poster_text_configs_raw = content_gen_instance.run(info_directory, variants, loaded_content_list) if not poster_text_configs_raw: logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.") return False - # --- Save the COMPLETE poster configs list for this topic --- - run_output_dir_base = os.path.join(output_dir, run_id) - topic_config_save_path = os.path.join(run_output_dir_base, f"topic_{topic_index}_poster_configs.json") - try: - # Assuming poster_text_configs_raw is JSON-serializable (likely a list/dict) - with open(topic_config_save_path, 'w', encoding='utf-8') as f_cfg_topic: - json.dump(poster_text_configs_raw, f_cfg_topic, ensure_ascii=False, indent=4) - logging.info(f"Saved complete poster configurations for topic {topic_index} to: {topic_config_save_path}") - except Exception as save_err: - logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {topic_config_save_path}: {save_err}") - # --- End Save Complete Config --- + # --- 使用 OutputHandler 保存 Poster Config --- + output_handler.handle_poster_configs(run_id, topic_index, poster_text_configs_raw) + # --- 结束使用 Handler 保存 --- poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw) except Exception as e: logging.exception("Error running ContentGenerator or parsing poster configs:") traceback.print_exc() - return False # Cannot proceed if text config fails + return False - # --- Poster Generation Loop for each variant --- - poster_num = config.get("variants", 1) - target_size = tuple(config.get("poster_target_size", [900, 1200])) + # Poster Generation Loop for each variant + poster_num = variants any_poster_attempted = False - text_possibility = config.get("text_possibility", 0.3) # Get from config for j_index in range(poster_num): variant_index = j_index + 1 logging.info(f"Generating Poster {variant_index}/{poster_num}...") any_poster_attempted = True + collage_img = None # To store the generated collage PIL Image + poster_img = None # To store the final poster PIL Image try: poster_config = poster_config_summary.get_config_by_index(j_index) if not poster_config: logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.") continue - # Define output directories for this specific variant - run_output_dir = os.path.join(output_dir, run_id) - variant_output_dir = os.path.join(run_output_dir, f"{topic_index}_{variant_index}") - output_collage_subdir = config.get("output_collage_subdir", "collage_img") - output_poster_subdir = config.get("output_poster_subdir", "poster") - collage_output_dir = os.path.join(variant_output_dir, output_collage_subdir) - poster_output_dir = os.path.join(variant_output_dir, output_poster_subdir) - os.makedirs(collage_output_dir, exist_ok=True) - os.makedirs(poster_output_dir, exist_ok=True) - - # --- Image Collage --- + # --- Image Collage --- logging.info(f"Generating collage from: {input_img_dir_path}") - img_list = core_simple_collage.process_directory( + collage_images = core_simple_collage.process_directory( input_img_dir_path, - target_size=target_size, - output_count=1, - output_dir=collage_output_dir + target_size=poster_target_size, + output_count=1 ) - if not img_list or len(img_list) == 0 or not img_list[0].get('path'): + if not collage_images: # 检查列表是否为空 logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.") continue - collage_img_path = img_list[0]['path'] - logging.info(f"Using collage image: {collage_img_path}") + collage_img = collage_images[0] # 获取第一个 PIL Image + logging.info(f"Collage image generated successfully (in memory).") + + # --- 使用 Handler 保存 Collage 图片 --- + output_handler.handle_generated_image( + run_id, topic_index, variant_index, + image_type='collage', + image_data=collage_img, + output_filename='collage.png' # 或者其他期望的文件名 + ) + # --- 结束保存 Collage --- - # --- Create Poster --- + # --- Create Poster --- text_data = { "title": poster_config.get('main_title', 'Default Title'), "subtitle": "", @@ -624,98 +627,37 @@ def generate_posters_for_topic(config, topic_item, tweet_content_list, output_di } texts = poster_config.get('texts', []) if texts: - # Ensure TEXT_POSBILITY is accessible, maybe pass via config? - # text_possibility = config.get("text_possibility", 0.3) text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5}) - if len(texts) > 1 and random.random() < text_possibility: # Use variable from config + if len(texts) > 1 and random.random() < text_possibility: text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5}) - # final_poster_path = os.path.join(poster_output_dir, "poster.jpg") # Filename "poster.jpg" is hardcoded - output_poster_filename = config.get("output_poster_filename", "poster.jpg") - final_poster_path = os.path.join(poster_output_dir, output_poster_filename) - result_path = poster_gen_instance.create_poster(collage_img_path, text_data, final_poster_path) # Uses hardcoded output filename - if result_path: - logging.info(f"Successfully generated poster: {result_path}") + # 调用修改后的 create_poster, 接收 PIL Image + poster_img = poster_gen_instance.create_poster(collage_img, text_data) + + if poster_img: + logging.info(f"Poster image generated successfully (in memory).") + # --- 使用 Handler 保存 Poster 图片 --- + output_handler.handle_generated_image( + run_id, topic_index, variant_index, + image_type='poster', + image_data=poster_img, + output_filename=output_poster_filename # 使用参数中的文件名 + ) + # --- 结束保存 Poster --- else: - logging.warning(f"Warning: Poster generation function did not return a valid path for {final_poster_path}.") + logging.warning(f"Warning: Poster generation function returned None for variant {variant_index}.") except Exception as e: logging.exception(f"Error during poster generation for Variant {variant_index}:") traceback.print_exc() - continue # Continue to next variant + continue return any_poster_attempted -def main(): - """主函数入口""" - config_file = { - "date": "4月17日", - "num": 5, - "model": "qwenQWQ", - "api_url": "vllm", - "api_key": "EMPTY", - "topic_system_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/systemPrompt.txt", - "topic_user_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/userPrompt.txt", - "content_system_prompt": "/root/autodl-tmp/TravelContentCreator/genPrompts/systemPrompt.txt", - "resource_dir": [{ - "type": "Object", - "num": 4, - "file_path": ["/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-尚书第.txt", - "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-明清园.txt", - "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-泰宁古城.txt", - "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-甘露寺.txt" - ]}, - { - "type": "Product", - "num": 0, - "file_path": [] - } - ], - "prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts", - "output_dir": "/root/autodl-tmp/TravelContentCreator/result", - "variants": 2, - "topic_temperature": 0.2, - "content_temperature": 0.3 - } - - if True: - # 1. 首先生成选题 - ai_agent, system_prompt, user_prompt, output_dir = prepare_topic_generation( - config_file - ) - - run_id, tweet_topic_record = generate_topics( - ai_agent, system_prompt, user_prompt, config_file["output_dir"], - config_file["topic_temperature"], 0.5, 1.5 - ) - - output_dir = os.path.join(config_file["output_dir"], run_id) - os.makedirs(output_dir, exist_ok=True) - tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json")) - tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt")) - # raise Exception("选题生成失败,退出程序") - if not run_id or not tweet_topic_record: - print("选题生成失败,退出程序") - return - - # 2. 然后生成内容 - print("\n开始根据选题生成内容...") - - # 加载内容生成的系统提示词 - content_system_prompt = ResourceLoader.load_system_prompt(config_file["content_system_prompt"]) - - - if not content_system_prompt: - print("内容生成系统提示词为空,使用选题生成的系统提示词") - content_system_prompt = system_prompt - - # 直接使用同一个AI Agent实例 - result = generate_content( - ai_agent, content_system_prompt, tweet_topic_record.topics_list, output_dir, run_id, config_file["prompts_dir"], config_file["resource_dir"], - config_file["variants"], config_file["content_temperature"] - ) - - - -if __name__ == "__main__": - main() \ No newline at end of file +# main 函数不再使用,注释掉或移除 +# def main(): +# """主函数入口""" +# # ... (旧的 main 函数逻辑) +# +# if __name__ == "__main__": +# main() \ No newline at end of file