From 0c9e7f90aef062e56a1edae6956ce21f75fc21d5 Mon Sep 17 00:00:00 2001
From: jinye_huang
Date: Tue, 22 Apr 2025 14:16:29 +0800
Subject: [PATCH] Migrated the topic selection module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 examples/run_step1_topics.py |   7 +-
 examples/test_workflow.py    |   7 +-
 main.py                      | 142 +++++++++++++++++------------------
 utils/tweet_generator.py     |  62 +++++++++++++++
 4 files changed, 139 insertions(+), 79 deletions(-)

diff --git a/examples/run_step1_topics.py b/examples/run_step1_topics.py
index 2f0cf6b..699217d 100644
--- a/examples/run_step1_topics.py
+++ b/examples/run_step1_topics.py
@@ -5,12 +5,14 @@
 """
 import os
 import sys
+import traceback

 # 添加项目根目录到Python路径
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 # 导入所需模块
-from main import load_config, generate_topics_step
+from main import load_config
+from utils.tweet_generator import run_topic_generation_pipeline

 if __name__ == "__main__":
     print("==== 阶段 1: 仅生成选题 ====")
@@ -24,7 +26,7 @@ if __name__ == "__main__":

         # 2. 执行选题生成
         print("\n执行选题生成...")
-        run_id, tweet_topic_record = generate_topics_step(config)
+        run_id, tweet_topic_record = run_topic_generation_pipeline(config)

         if run_id and tweet_topic_record:
             output_dir = config.get("output_dir", "./result")
@@ -40,6 +42,5 @@ if __name__ == "__main__":

     except Exception as e:
         print(f"\n处理过程中出错: {e}")
-        import traceback
         traceback.print_exc()
         sys.exit(1)
\ No newline at end of file
diff --git a/examples/test_workflow.py b/examples/test_workflow.py
index e061308..20ba441 100644
--- a/examples/test_workflow.py
+++ b/examples/test_workflow.py
@@ -13,6 +13,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 # 导入所需模块
 from main import load_config, generate_topics_step, generate_content_and_posters_step
+from utils.tweet_generator import run_topic_generation_pipeline
+from core.topic_parser import TopicParser

 def test_full_workflow():
     """测试完整的工作流程,从选题生成到海报制作"""
@@ -26,7 +28,7 @@ def test_full_workflow():

     # 2. 执行选题生成
     print("\n步骤 2: 生成选题...")
-    run_id, tweet_topic_record = generate_topics_step(config)
+    run_id, tweet_topic_record = run_topic_generation_pipeline(config)

     if not run_id or not tweet_topic_record:
         print("选题生成失败,测试终止。")
@@ -84,7 +86,7 @@ def test_steps_separately():

     # 2. 仅执行选题生成
     print("\n步骤 2: 仅测试选题生成...")
-    run_id, tweet_topic_record = generate_topics_step(test_config)
+    run_id, tweet_topic_record = run_topic_generation_pipeline(test_config)

     if not run_id or not tweet_topic_record:
         print("选题生成失败,测试终止。")
@@ -104,7 +106,6 @@ def test_steps_separately():

     # 这部分通常是由main函数中的流程自动处理的
     # 这里为了演示分段流程,模拟手动加载数据并处理
-    from core.topic_parser import TopicParser

     if os.path.exists(topics_file):
         with open(topics_file, 'r', encoding='utf-8') as f:
diff --git a/main.py b/main.py
index c1ad52c..f097398 100644
--- a/main.py
+++ b/main.py
@@ -7,15 +7,15 @@ import traceback
 import json

 from core.ai_agent import AI_Agent
-from core.topic_parser import TopicParser
+# from core.topic_parser import TopicParser # No longer needed directly in main?
 import core.contentGen as contentGen
 import core.posterGen as posterGen
 import core.simple_collage as simple_collage
 from utils.resource_loader import ResourceLoader
-from utils.tweet_generator import prepare_topic_generation, generate_topics, generate_single_content
+from utils.tweet_generator import generate_single_content, run_topic_generation_pipeline # Import the new pipeline function
 import random

-TEXT_POSBILITY = 0.3
+TEXT_POSBILITY = 0.3 # Consider moving this to config if it varies

 def load_config(config_path="poster_gen_config.json"):
     """Loads configuration from a JSON file."""
@@ -29,8 +29,8 @@ def load_config(config_path="poster_gen_config.json"):
         # Basic validation (can be expanded)
         required_keys = ["api_url", "model", "api_key", "resource_dir", "prompts_dir", "output_dir", "num", "variants", "topic_system_prompt", "topic_user_prompt", "content_system_prompt", "image_base_dir"]
         if not all(key in config for key in required_keys):
-            print(f"Error: Config file '{config_path}' is missing one or more required keys.")
-            print(f"Required keys are: {required_keys}")
+            missing_keys = [key for key in required_keys if key not in config]
+            print(f"Error: Config file '{config_path}' is missing required keys: {missing_keys}")
             sys.exit(1)
         # Resolve relative paths based on config location or a defined base path if necessary
         # For simplicity, assuming paths in config are relative to project root or absolute
@@ -42,35 +42,8 @@ def load_config(config_path="poster_gen_config.json"):
         print(f"Error loading configuration from '{config_path}': {e}")
         sys.exit(1)

-
-def generate_topics_step(config):
-    """Generates topics based on the configuration."""
-    print("Step 1: Generating Topics...")
-    ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(
-        config.get("date", datetime.now().strftime("%Y-%m-%d")), # Use current date if not specified
-        config["num"], config["topic_system_prompt"], config["topic_user_prompt"],
-        config["api_url"], config["model"], config["api_key"], config["prompts_dir"],
-        config["resource_dir"], config["output_dir"]
-    )
-
-    run_id, tweet_topic_record = generate_topics(
-        ai_agent, system_prompt, user_prompt, config["output_dir"],
-        config.get("topic_temperature", 0.2), config.get("topic_top_p", 0.5), config.get("topic_max_tokens", 1.5) # Added defaults for safety
-    )
-
-    if not run_id or not tweet_topic_record:
-        print("Topic generation failed. Exiting.")
-        ai_agent.close()
-        return None, None
-
-    output_dir = os.path.join(config["output_dir"], run_id)
-    os.makedirs(output_dir, exist_ok=True)
-    tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
-    tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
-    ai_agent.close()
-    print(f"Topics generated successfully. Run ID: {run_id}")
-    return run_id, tweet_topic_record
-
+# Removed generate_topics_step function definition from here
+# Its logic is now in utils.tweet_generator.run_topic_generation_pipeline

 def generate_content_and_posters_step(config, run_id, tweet_topic_record):
     """Generates content and posters based on generated topics."""
@@ -78,55 +51,75 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
         print("Missing run_id or topics data. Skipping content and poster generation.")
         return

-    print("Step 2: Generating Content and Posters...")
+    print("\nStep 2: Generating Content and Posters...")
     base_output_dir = config["output_dir"]
     output_dir = os.path.join(base_output_dir, run_id) # Directory for this specific run

-    # Load content generation system prompt
+    # --- Pre-load resources and initialize shared objects ---
+    # Load content generation system prompt once
     content_system_prompt = ResourceLoader.load_system_prompt(config["content_system_prompt"])
     if not content_system_prompt:
         print("Warning: Content generation system prompt is empty. Using default logic if available or might fail.")
-        # Potentially load topic system prompt as fallback if needed, or handle error
-        # content_system_prompt = ResourceLoader.load_system_prompt(config["topic_system_prompt"])

+    # Initialize AI Agent once for the entire content generation phase
+    ai_agent = None
+    try:
+        print(f"Initializing AI Agent ({config['model']})...")
+        ai_agent = AI_Agent(config["api_url"], config["model"], config["api_key"])
+    except Exception as e:
+        print(f"Error initializing AI Agent: {e}. Cannot proceed with content generation.")
+        traceback.print_exc()
+        return # Cannot continue without AI agent
+    # Check image base directory
     image_base_dir = config.get("image_base_dir", None)
-    if not image_base_dir:
-        print("Error: 'image_base_dir' not specified in config. Cannot locate images.")
+    if not image_base_dir or not os.path.isdir(image_base_dir):
+        print(f"Error: 'image_base_dir' ({image_base_dir}) not specified or not a valid directory in config. Cannot locate images.")
+        if ai_agent: ai_agent.close() # Close agent if initialized
         return

     camera_image_subdir = config.get("camera_image_subdir", "相机") # Default '相机'
     modify_image_subdir = config.get("modify_image_subdir", "modify") # Default 'modify'

+    # Initialize ContentGenerator and PosterGenerator once if they are stateless
+    # Assuming they are stateless for now
+    content_gen = contentGen.ContentGenerator()
+    poster_gen_instance = posterGen.PosterGenerator()
+
+    # --- Process each topic ---
     for i, topic in enumerate(tweet_topic_record.topics_list):
         topic_index = i + 1
-        print(f"Processing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic.get('title', 'N/A')}")
+        print(f"\nProcessing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic.get('title', 'N/A')}")
         tweet_content_list = []

-        # --- Content Generation Loop ---
+        # --- Content Generation Loop (using the single AI Agent) ---
         for j in range(config["variants"]):
             variant_index = j + 1
             print(f" Generating Variant {variant_index}/{config['variants']}...")
-            time.sleep(random.random()) # Keep the random delay? Okay for now.
-            ai_agent = AI_Agent(config["api_url"], config["model"], config["api_key"])
+            time.sleep(random.random() * 0.5) # Slightly reduced delay
             try:
+                # Use the pre-initialized AI Agent
                 tweet_content, gen_result = generate_single_content(
                     ai_agent, content_system_prompt, topic, config["prompts_dir"], config["resource_dir"],
-                    output_dir, run_id, topic_index, variant_index, config.get("content_temperature", 0.3) # Added default
+                    output_dir, run_id, topic_index, variant_index, config.get("content_temperature", 0.3)
                 )
                 if tweet_content:
-                    tweet_content_list.append(tweet_content.get_json_file()) # Assuming this returns the structured data needed later
+                    # Assuming get_json_file() returns a dictionary or similar structure
+                    tweet_content_data = tweet_content.get_json_file()
+                    if tweet_content_data:
+                        tweet_content_list.append(tweet_content_data)
+                    else:
+                        print(f" Warning: generate_single_content for Topic {topic_index}, Variant {variant_index} returned empty data.")
                 else:
                     print(f" Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.")
             except Exception as e:
                 print(f" Error during content generation for Topic {topic_index}, Variant {variant_index}: {e}")
-                traceback.print_exc()
-            finally:
-                ai_agent.close() # Ensure agent is closed
+                # Decide if traceback is needed here, might be too verbose for loop errors
+                # traceback.print_exc()
+            # Do NOT close the agent here

         if not tweet_content_list:
-            print(f" No content generated for Topic {topic_index}. Skipping poster generation.")
+            print(f" No valid content generated for Topic {topic_index}. Skipping poster generation.")
             continue

         # --- Poster Generation Setup ---
@@ -137,21 +130,22 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
         # Clean object name (consider making this a utility function)
         try:
-            object_name = object_name.split(".")[0]
-            if "景点信息-" in object_name:
-                object_name = object_name.split("景点信息-")[1]
-            # Handle cases like "景点A+景点B"? Needs clearer logic if required.
+            # More robust cleaning might be needed depending on actual object name formats
+            object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
+            if not object_name_cleaned:
+                print(f" Warning: Object name '{object_name}' resulted in empty string after cleaning.")
+                continue
+            object_name = object_name_cleaned
         except Exception as e:
             print(f" Warning: Could not fully clean object name '{object_name}': {e}")
+            # Continue with potentially unclean name? Or skip?
+            # Let's continue for now, path checks below might catch issues.

         # Construct and check image paths using config base dir
-        # Path for collage/poster input images (e.g., from 'modify' dir)
         input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name)
-        # Path for potential description file (e.g., from '相机' dir)
         camera_img_dir_path = os.path.join(image_base_dir, camera_image_subdir, object_name)
         description_file_path = os.path.join(camera_img_dir_path, "description.txt")
-
         if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
             print(f" Image directory not found or not a directory: '{input_img_dir_path}'. Skipping poster generation for this topic.")
             continue
@@ -164,21 +158,22 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
             print(f" Description file not found: '{description_file_path}'. Using generated content for poster text.")

         # --- Generate Text Configurations for Posters ---
-        content_gen = contentGen.ContentGenerator()
         try:
-            # Assuming tweet_content_list contains the JSON data needed by content_gen.run
-            poster_text_configs_raw = content_gen.run(info_directory, config["variants"], tweet_content_list)
-            print(f" Raw poster text configs: {poster_text_configs_raw}") # For debugging
+            # Pass the list of content data directly
+            poster_text_configs_raw = content_gen.run(info_directory, config["variants"], tweet_content_list)
+            # print(f" Raw poster text configs: {poster_text_configs_raw}") # For debugging
+            if not poster_text_configs_raw:
+                print(" Warning: ContentGenerator returned empty configuration data.")
+                continue # Skip if no text configs generated
             poster_config_summary = posterGen.PosterConfig(poster_text_configs_raw)
         except Exception as e:
             print(f" Error running ContentGenerator or parsing poster configs: {e}")
             traceback.print_exc()
             continue # Skip poster generation for this topic
-
         # --- Poster Generation Loop ---
         poster_num = config["variants"] # Same as content variants
-        target_size = tuple(config.get("poster_target_size", [900, 1200])) # Add default size
+        target_size = tuple(config.get("poster_target_size", [900, 1200]))

         for j_index in range(poster_num):
             variant_index = j_index + 1
@@ -202,7 +197,6 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
                 output_count=1, # Assuming 1 collage image per poster variant
                 output_dir=collage_output_dir
             )
-            # print(f" Collage image list: {img_list}") # Debugging

             if not img_list or len(img_list) == 0 or not img_list[0].get('path'):
                 print(f" Failed to generate collage image for Variant {variant_index}. Skipping poster.")
                 continue

             collage_img_path = img_list[0]['path']
             print(f" Using collage image: {collage_img_path}")

-            # --- Create Poster ---
-            poster_gen_instance = posterGen.PosterGenerator() # Renamed to avoid conflict
-
-            # Prepare text data (Simplified logic, adjust TEXT_POSBILITY if needed)
-            # Consider moving text data preparation into posterGen or a dedicated function
+            # --- Create Poster (using the single poster_gen_instance) ---
             text_data = {
                 "title": poster_config.get('main_title', 'Default Title'),
-                "subtitle": "", # Subtitle seems unused?
+                "subtitle": "",
                 "additional_texts": []
             }
             texts = poster_config.get('texts', [])
             if texts:
                 text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5})
-            if len(texts) > 1 and random.random() < TEXT_POSBILITY: # Apply possibility check
+            if len(texts) > 1 and random.random() < TEXT_POSBILITY:
                 text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5})

-            # print(f" Text data for poster: {text_data}") # Debugging
-
             final_poster_path = os.path.join(poster_output_dir, "poster.jpg")
             result_path = poster_gen_instance.create_poster(collage_img_path, text_data, final_poster_path)
             if result_path:
@@ -240,13 +228,21 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
                 traceback.print_exc()
                 continue # Continue to next variant

+    # --- Cleanup ---
+    # Close the AI Agent after processing all topics
+    if ai_agent:
+        print("\nClosing AI Agent...")
+        ai_agent.close()
+
 def main():
     # No argparse for now, directly load default config
     config = load_config() # Load from poster_gen_config.json

     # Execute steps sequentially
-    run_id, tweet_topic_record = generate_topics_step(config)
+    # Step 1: Generate Topics (using the function from utils.tweet_generator)
+    run_id, tweet_topic_record = run_topic_generation_pipeline(config)

+    # Step 2: Generate Content and Posters (if Step 1 was successful)
     if run_id and tweet_topic_record:
         generate_content_and_posters_step(config, run_id, tweet_topic_record)
     else:
diff --git a/utils/tweet_generator.py b/utils/tweet_generator.py
index 4719fc8..9263da7 100644
--- a/utils/tweet_generator.py
+++ b/utils/tweet_generator.py
@@ -8,6 +8,7 @@ import argparse
 import json
 from datetime import datetime
 import sys
+import traceback
 sys.path.append('/root/autodl-tmp')
 # 从本地模块导入
 from TravelContentCreator.core.ai_agent import AI_Agent
@@ -259,6 +260,67 @@ def prepare_topic_generation(

     return ai_agent, system_prompt, user_prompt, output_dir

+def run_topic_generation_pipeline(config):
+    """Runs the complete topic generation pipeline based on the configuration."""
+    print("Step 1: Generating Topics...")
+
+    # Prepare necessary inputs and the AI agent for topic generation
+    # Note: prepare_topic_generation already initializes an AI_Agent
+    try:
+        ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(
+            config.get("date", datetime.now().strftime("%Y-%m-%d")), # Use current date if not specified
+            config["num"], config["topic_system_prompt"], config["topic_user_prompt"],
+            config["api_url"], config["model"], config["api_key"], config["prompts_dir"],
+            config["resource_dir"], config["output_dir"]
+        )
+    except Exception as e:
+        print(f"Error during topic generation preparation: {e}")
+        traceback.print_exc()
+        return None, None
+
+    # Generate topics using the prepared agent and prompts
+    try:
+        run_id, tweet_topic_record = generate_topics(
+            ai_agent, system_prompt, user_prompt, config["output_dir"],
+            config.get("topic_temperature", 0.2),
+            config.get("topic_top_p", 0.5),
+            config.get("topic_max_tokens", 1.5) # Consider if max_tokens name is accurate here (was presence_penalty?)
+        )
+    except Exception as e:
+        print(f"Error during topic generation API call: {e}")
+        traceback.print_exc()
+        if ai_agent: ai_agent.close() # Ensure agent is closed on error
+        return None, None
+
+    # Ensure the AI agent is closed after generation
+    if ai_agent:
+        ai_agent.close()
+
+    # Process results
+    if not run_id or not tweet_topic_record:
+        print("Topic generation failed (no run_id or topics returned).")
+        return None, None
+
+    output_dir = os.path.join(config["output_dir"], run_id)
+    try:
+        os.makedirs(output_dir, exist_ok=True)
+        # Save topics and prompt details
+        save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
+        save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
+        if not save_topics_success or not save_prompt_success:
+            print("Warning: Failed to save topic generation results or prompts.")
+            # Continue but warn user
+
+    except Exception as e:
+        print(f"Error saving topic generation results: {e}")
+        traceback.print_exc()
+        # Return the generated data even if saving fails, but maybe warn more strongly?
+        # return run_id, tweet_topic_record # Decide if partial success is okay
+        return None, None # Or consider failure if saving is critical
+
+    print(f"Topics generated successfully. Run ID: {run_id}")
+    return run_id, tweet_topic_record
+
 def main():
     """主函数入口"""
     config_file = {