diff --git a/examples/test_pipeline_steps.py b/examples/test_pipeline_steps.py
index 3311315..0256a8f 100644
--- a/examples/test_pipeline_steps.py
+++ b/examples/test_pipeline_steps.py
@@ -8,128 +8,229 @@ import time
 from datetime import datetime
 import logging
 
-# Add project root to the Python path
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, project_root)
+# --- Setup Project Root Path ---
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if PROJECT_ROOT not in sys.path:
+    sys.path.append(PROJECT_ROOT)
+# --- End Setup ---
 
-from core.ai_agent import AI_Agent
-from utils.prompt_manager import PromptManager
-from utils.tweet_generator import (
-    run_topic_generation_pipeline,
-    generate_content_for_topic,
-    generate_posters_for_topic
-)
+# --- Imports from the project ---
+try:
+    from utils.tweet_generator import run_topic_generation_pipeline, generate_content_for_topic, generate_posters_for_topic
+    from core.topic_parser import TopicParser
+    from utils.prompt_manager import PromptManager  # Needed for content generation
+    from core.ai_agent import AI_Agent  # Needed for content generation
+    # from utils.tweet_generator import tweetTopicRecord  # No longer needed directly
+    from utils.output_handler import FileSystemOutputHandler, OutputHandler  # Import handlers
+except ImportError as e:
+    logging.critical(f"ImportError: {e}. Ensure all core/utils modules are available and '{PROJECT_ROOT}' is in sys.path.")
+    sys.exit(1)
+# --- End Imports ---
 
-def load_config(config_path="/root/autodl-tmp/TravelContentCreator/poster_gen_config.json"):
-    """Loads configuration relative to the script."""
-    if not os.path.exists(config_path):
-        logging.error(f"Error: Config file '{config_path}' not found.")
-        sys.exit(1)
+def load_config(config_path):
+    """Loads configuration from a JSON file."""
     try:
         with open(config_path, 'r', encoding='utf-8') as f:
             config = json.load(f)
-        logging.info("Configuration loaded successfully.")
+        logging.info(f"Config loaded successfully from {config_path}")
         return config
+    except FileNotFoundError:
+        logging.error(f"Error: Configuration file not found at {config_path}")
+        return None
+    except json.JSONDecodeError:
+        logging.error(f"Error: Could not decode JSON from {config_path}")
+        return None
     except Exception as e:
-        logging.error(f"Error loading configuration: {e}")
-        sys.exit(1)
+        logging.exception(f"An unexpected error occurred loading config {config_path}:")
+        return None
 
 def main_test():
+    # --- Basic Logging Setup ---
     logging.basicConfig(
-        level=logging.INFO,
+        level=logging.INFO,
         format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
         datefmt='%Y-%m-%d %H:%M:%S'
     )
+    # --- End Logging Setup ---
 
-    logging.info("--- Starting Pipeline Step Test ---")
-    config = load_config()
+    logging.info("--- Starting Pipeline Steps Test ---")
 
-    # --- Override config for faster testing ---
-    config['num'] = 1  # Generate only 1 topic
-    config['variants'] = 1  # Generate only 1 content/poster variant
-    logging.info(f"Config overridden for testing: num={config['num']}, variants={config['variants']}")
+    # 1. Load Configuration
+    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")  # Use project root path
+    config = load_config(config_path)
+    if config is None:
+        logging.critical("Failed to load configuration. Exiting test.")
+        sys.exit(1)
 
-    run_id = None
-    tweet_topic_record = None
-    ai_agent_content = None  # Separate agent instance for content/poster
+    # --- Initialize Output Handler ---
+    output_handler: OutputHandler = FileSystemOutputHandler(config.get("output_dir", "result"))
+    logging.info(f"Using Output Handler: {output_handler.__class__.__name__}")
+    # --- End Output Handler Init ---
 
+    # 2. Define a unique Run ID for this test run
+    test_run_id = f"test_pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+    logging.info(f"Using Run ID: {test_run_id}")
+
+    # --- Step 1: Topic Generation ---
+    logging.info("\n--- Testing Step 1: Topic Generation ---")
+    step1_start = time.time()
+    # run_topic_generation_pipeline still takes config internally
+    run_id_step1, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, run_id=test_run_id)
+    step1_end = time.time()
+
+    if run_id_step1 is None or topics_list is None:
+        logging.error("Topic generation (Step 1) failed. Cannot proceed.")
+        sys.exit(1)
+    if run_id_step1 != test_run_id:
+        logging.warning(f"run_id returned from step 1 ({run_id_step1}) differs from the one provided ({test_run_id}). Using the returned one.")
+        test_run_id = run_id_step1  # Use the ID actually used by the pipeline
+
+    logging.info(f"Step 1 finished in {step1_end - step1_start:.2f} seconds.")
+    logging.info(f"Generated {len(topics_list)} topics for Run ID: {test_run_id}")
+    # Use output handler to save topic results (mimics main.py)
+    output_handler.handle_topic_results(test_run_id, topics_list, system_prompt, user_prompt)
+
+    # --- Step 2: Content and Poster Generation ---
+    logging.info("\n--- Testing Step 2: Content and Poster Generation ---")
+    step2_start = time.time()
+
+    # Initialize resources needed for step 2, extracting params from config
+    prompt_manager = None
+    ai_agent = None
+    step2_success_flag = False
     try:
-        # --- Step 1: Test Topic Generation ---
-        logging.info("\n--- Testing Topic Generation ---")
-        run_id, tweet_topic_record = run_topic_generation_pipeline(config)  # run_id generated inside if not passed
-
-        if not run_id or not tweet_topic_record or not tweet_topic_record.topics_list:
-            logging.info("Topic generation failed or produced no topics. Exiting test.")
-            return
-
-        logging.info(f"Topic generation successful. Run ID: {run_id}")
-        logging.info(f"Generated {len(tweet_topic_record.topics_list)} topic(s).")
-        test_topic = tweet_topic_record.topics_list[0]  # Get the first topic for testing
-        logging.info("Test Topic Data:", json.dumps(test_topic, ensure_ascii=False, indent=2))
-
-        # --- Step 2: Test Content Generation (for the first topic) ---
-        logging.info("\n--- Testing Content Generation ---")
-
-        # Initialize resources needed for content generation
-        prompt_manager = PromptManager(config)
-        logging.info("Initializing AI Agent for content...")
+        # --- Create PromptManager ---
+        prompt_manager = PromptManager(
+            topic_system_prompt_path=config.get("topic_system_prompt"),
+            topic_user_prompt_path=config.get("topic_user_prompt"),
+            content_system_prompt_path=config.get("content_system_prompt"),
+            prompts_dir=config.get("prompts_dir"),
+            resource_dir_config=config.get("resource_dir", []),
+            topic_gen_num=config.get("num", 1),
+            topic_gen_date=config.get("date", "")
+        )
+        logging.info("PromptManager instance created for Step 2 test.")
+
+        # --- Create AI Agent ---
+        ai_api_url = config.get("api_url")
+        ai_model = config.get("model")
+        ai_api_key = config.get("api_key")
         request_timeout = config.get("request_timeout", 30)
         max_retries = config.get("max_retries", 3)
-        ai_agent_content = AI_Agent(
-            config["api_url"],
-            config["model"],
-            config["api_key"],
+        if not all([ai_api_url, ai_model, ai_api_key]):
+            raise ValueError("Missing required AI configuration (api_url, model, api_key)")
+        logging.info("Initializing AI Agent for content generation test...")
+        ai_agent = AI_Agent(
+            base_url=ai_api_url,
+            model_name=ai_model,
+            api=ai_api_key,
             timeout=request_timeout,
             max_retries=max_retries
         )
 
-        base_output_dir = config["output_dir"]
-        topic_index = 1  # Testing the first topic (1-based index)
+        total_topics = len(topics_list)
+        logging.info(f"Processing {total_topics} topics for content/posters...")
+        for i, topic_item in enumerate(topics_list):
+            topic_index = topic_item.get('index', i + 1)
+            logging.info(f"--- Processing Topic {topic_index}/{total_topics} ---")
 
-        tweet_content_list = generate_content_for_topic(
-            ai_agent_content, prompt_manager, config, test_topic,
-            base_output_dir, run_id, topic_index
-        )
-
-        if not tweet_content_list:
-            logging.info("Content generation failed or produced no content. Exiting test.")
-            return
-
-        logging.info(f"Content generation successful. Generated {len(tweet_content_list)} variant(s).")
-        logging.info("Generated Content Data (first variant):", json.dumps(tweet_content_list[0], ensure_ascii=False, indent=2))
-
-
-        # --- Step 3: Test Poster Generation (for the first topic/content) ---
-        logging.info("\n--- Testing Poster Generation ---")
-
-        # Poster generation uses its own internal ContentGenerator and PosterGenerator instances
-        # We just need to call the function
-        success = generate_posters_for_topic(
-            config,
-            test_topic,
-            tweet_content_list,  # Pass the list generated above
-            base_output_dir,
-            run_id,
-            topic_index
-        )
-
-        if success:
-            logging.info("Poster generation function executed (check output directory for results).")
-        else:
-            logging.info("Poster generation function reported failure or skipped execution.")
+            # --- Generate Content ---
+            content_variants = config.get("variants", 1)
+            content_temp = config.get("content_temperature", 0.3)
+            content_top_p = config.get("content_top_p", 0.4)
+            content_presence_penalty = config.get("content_presence_penalty", 1.5)
+
+            content_success = generate_content_for_topic(
+                ai_agent=ai_agent,
+                prompt_manager=prompt_manager,
+                topic_item=topic_item,
+                run_id=test_run_id,
+                topic_index=topic_index,
+                output_handler=output_handler,
+                variants=content_variants,
+                temperature=content_temp,
+                top_p=content_top_p,
+                presence_penalty=content_presence_penalty
+            )
+            if content_success:
+                logging.info(f"Content generation successful for Topic {topic_index}.")
+                # --- Generate Posters ---
+                poster_variants = config.get("variants", 1)
+                poster_assets_dir = config.get("poster_assets_base_dir")
+                img_base_dir = config.get("image_base_dir")
+                mod_img_subdir = config.get("modify_image_subdir", "modify")
+                res_dir_config = config.get("resource_dir", [])
+                poster_size = tuple(config.get("poster_target_size", [900, 1200]))
+                txt_possibility = config.get("text_possibility", 0.3)
+                collage_subdir = config.get("output_collage_subdir", "collage_img")
+                poster_subdir = config.get("output_poster_subdir", "poster")
+                poster_filename = config.get("output_poster_filename", "poster.jpg")
+                cam_img_subdir = config.get("camera_image_subdir", "相机")
+
+                if not poster_assets_dir or not img_base_dir:
+                    logging.error(f"Missing critical paths for poster generation. Skipping posters for topic {topic_index}.")
+                    continue
+
+                posters_attempted = generate_posters_for_topic(
+                    topic_item=topic_item,
+                    output_dir=config.get("output_dir", "result"),
+                    run_id=test_run_id,
+                    topic_index=topic_index,
+                    output_handler=output_handler,
+                    variants=poster_variants,
+                    poster_assets_base_dir=poster_assets_dir,
+                    image_base_dir=img_base_dir,
+                    modify_image_subdir=mod_img_subdir,
+                    resource_dir_config=res_dir_config,
+                    poster_target_size=poster_size,
+                    text_possibility=txt_possibility,
+                    output_collage_subdir=collage_subdir,
+                    output_poster_subdir=poster_subdir,
+                    output_poster_filename=poster_filename,
+                    camera_image_subdir=cam_img_subdir
+                )
+                if posters_attempted:
+                    logging.info(f"Poster generation process completed for Topic {topic_index}.")
+                    step2_success_flag = True
+                else:
+                    logging.warning(f"Poster generation skipped/failed for Topic {topic_index}.")
+            else:
+                logging.warning(f"Content generation failed for Topic {topic_index}. Skipping posters.")
+            logging.info(f"--- Finished Topic {topic_index} ---")
+    except ValueError as e:
+        logging.error(f"Configuration error during Step 2 setup: {e}")
+        step2_success_flag = False
     except Exception as e:
-        logging.info(f"\n--- An error occurred during testing ---")
-        logging.error(f"Error: {e}")
-
+        logging.exception("An error occurred during Step 2 processing:")
+        step2_success_flag = False  # Ensure flag is false on error
     finally:
-        # Clean up the content generation AI agent if it was created
-        if ai_agent_content:
-            logging.info("\nClosing content generation AI Agent...")
-            ai_agent_content.close()
-        logging.info("\n--- Test Finished ---")
+        if ai_agent:
+            logging.info("Closing AI Agent for content generation test...")
+            ai_agent.close()
+    # --- End Simulated Step 2 Logic ---
+    step2_end = time.time()
+    if step2_success_flag:
+        logging.info(f"Step 2 finished in {step2_end - step2_start:.2f} seconds.")
+    else:
+        logging.warning(f"Step 2 finished in {step2_end - step2_start:.2f} seconds, but encountered errors or generated no output.")
+
+    # --- Finalize Output ---
+    if test_run_id:
+        output_handler.finalize(test_run_id)
+    # --- End Finalize ---
+
+    # --- Test Summary ---
+    logging.info("\n--- Pipeline Steps Test Summary ---")
+    logging.info(f"Run ID: {test_run_id}")
+    output_location = os.path.join(config["output_dir"], test_run_id)
+    logging.info(f"Check output files in: {output_location}")
+    if os.path.exists(output_location):
+        logging.info("Output directory exists.")
+    else:
+        logging.warning("Output directory NOT found.")
 
 if __name__ == "__main__":
-    main_test()
\ No newline at end of file
+    main_test()
\ No newline at end of file
diff --git a/examples/test_stream.py b/examples/test_stream.py
index eeb093c..1f231ca 100644
--- a/examples/test_stream.py
+++ b/examples/test_stream.py
@@ -7,101 +7,119 @@ import json
 import time
 import logging
 
-# Add project root to the Python path to allow importing modules from core and utils
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, project_root)
+# Determine the project root directory (assuming examples/ is one level down)
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if PROJECT_ROOT not in sys.path:
+    sys.path.append(PROJECT_ROOT)
 
-from core.ai_agent import AI_Agent
+# Now import from core
+try:
+    from core.ai_agent import AI_Agent
+except ImportError as e:
+    logging.critical(f"Failed to import AI_Agent. Ensure '{PROJECT_ROOT}' is in sys.path and core/ai_agent.py exists. Error: {e}")
+    sys.exit(1)
 
-def load_config(config_path="../poster_gen_config.json"):
-    """Loads configuration from a JSON file relative to this script."""
-    if not os.path.exists(config_path):
-        logging.error(f"Error: Configuration file '{config_path}' not found.")
-        logging.error("Make sure you have copied 'example_config.json' to 'poster_gen_config.json' in the project root.")
-        sys.exit(1)
+def load_config(config_path):
+    """Loads configuration from a JSON file."""
     try:
         with open(config_path, 'r', encoding='utf-8') as f:
             config = json.load(f)
-        # Basic validation can be added here if needed
-        logging.info(f"Configuration loaded successfully from {config_path}")
+        logging.info(f"Config loaded successfully from {config_path}")
         return config
+    except FileNotFoundError:
+        logging.error(f"Error: Configuration file not found at {config_path}")
+        return None
+    except json.JSONDecodeError:
+        logging.error(f"Error: Could not decode JSON from {config_path}")
+        return None
     except Exception as e:
-        logging.error(f"Error loading configuration from '{config_path}': {e}")
-        sys.exit(1)
+        logging.exception(f"An unexpected error occurred loading config {config_path}:")
+        return None
 
 def main():
+    # --- Basic Logging Setup ---
     logging.basicConfig(
         level=logging.INFO,
         format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
         datefmt='%Y-%m-%d %H:%M:%S'
     )
+    # --- End Logging Setup ---
 
-    logging.info("--- Testing AI Agent Streaming ---")
-
-    # 1. Load configuration
-    config = load_config()
+    logging.info("Starting AI Agent Stream Test...")
 
-    # 2. Define example prompts (replace with your desired test prompts)
-    test_system_prompt = "You are a helpful assistant. Respond concisely."
-    test_user_prompt = "Tell me a short story about a traveling robot."
-
-    # You can optionally specify a folder with reference files
-    test_file_folder = None  # Or e.g., "../resource/Object"
-
-    # Get generation parameters from config or use defaults
-    temperature = config.get("content_temperature", 0.7)  # Using content params as example
-    top_p = config.get("content_top_p", 0.9)
-    presence_penalty = config.get("content_presence_penalty", 1.0)
+    # Load configuration (adjust path relative to this script)
+    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")
+    config = load_config(config_path)
+    if config is None:
+        logging.critical("Failed to load configuration. Exiting test.")
+        sys.exit(1)
+
+    # Example Prompts
+    system_prompt = "你是一个乐于助人的AI助手,擅长写短篇故事。"
+    user_prompt = "请写一个关于旅行机器人的短篇故事,它在一个充满异国情调的星球上发现了新的生命形式。"
 
-    # 3. Initialize AI Agent
     ai_agent = None
     try:
+        # --- Extract AI Agent parameters from config ---
+        ai_api_url = config.get("api_url")
+        ai_model = config.get("model")
+        ai_api_key = config.get("api_key")
         request_timeout = config.get("request_timeout", 30)
         max_retries = config.get("max_retries", 3)
+        # Check for required AI params
+        if not all([ai_api_url, ai_model, ai_api_key]):
+            logging.critical("Missing required AI configuration (api_url, model, api_key) in config. Exiting test.")
+            sys.exit(1)
+        # --- End Extract AI Agent params ---
+
+        logging.info("Initializing AI Agent for stream test...")
+        # Initialize AI_Agent using extracted parameters
         ai_agent = AI_Agent(
-            config["api_url"],
-            config["model"],
-            config["api_key"],
+            api_url=ai_api_url,  # Use extracted var
+            model=ai_model,  # Use extracted var
+            api_key=ai_api_key,  # Use extracted var
             timeout=request_timeout,
             max_retries=max_retries
         )
-        logging.info("AI Agent initialized.")
-
-        # 4. Call work_stream and process the generator
-        logging.info("\n--- Starting stream generation ---")
+
+        # Example call to work_stream
+        logging.info("Calling ai_agent.work_stream...")
+        # Extract generation parameters from config
+        temperature = config.get("content_temperature", 0.7)  # Use a relevant temperature setting
+        top_p = config.get("content_top_p", 0.9)
+        presence_penalty = config.get("content_presence_penalty", 0.0)
+
         start_time = time.time()
-        stream_generator = ai_agent.work_stream(
-            test_system_prompt,
-            test_user_prompt,
-            test_file_folder,
-            temperature,
-            top_p,
-            presence_penalty
+        stream_generator = ai_agent.work_stream(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+            info_directory=None,  # No extra context folder for this test
+            temperature=temperature,
+            top_p=top_p,
+            presence_penalty=presence_penalty
         )
-
-        full_response_streamed = ""
-        try:
-            for chunk in stream_generator:
-                print(chunk, end="", flush=True)  # Print each chunk as it arrives
-                full_response_streamed += chunk
-        except Exception as e:
-            logging.error(f"\nError while iterating through stream generator: {e}")
-
+
+        # Process the stream
+        logging.info("Processing stream response:")
+        full_response = ""
+        for chunk in stream_generator:
+            print(chunk, end="", flush=True)  # Keep print for stream output
+            full_response += chunk
+
         end_time = time.time()
-        logging.info(f"\n--- Stream finished in {end_time - start_time:.2f} seconds ---")
-        # print(f"Full response received via stream:\n{full_response_streamed}")  # Optionally print the assembled response
-
+        logging.info(f"\n--- Stream Finished ---")
+        logging.info(f"Total time: {end_time - start_time:.2f} seconds")
+        logging.info(f"Total characters received: {len(full_response)}")
+
+    except KeyError as e:
+        logging.error(f"Configuration error: Missing key '{e}'. Please check '{config_path}'.")
     except Exception as e:
-        logging.error(f"\nAn error occurred: {e}")
-        import traceback
-        traceback.print_exc()
-
+        logging.exception("An error occurred during the stream test:")
     finally:
-        # 5. Close the agent
+        # Ensure the agent is closed
         if ai_agent:
-            logging.info("\nClosing AI Agent...")
+            logging.info("Closing AI Agent...")
             ai_agent.close()
             logging.info("AI Agent closed.")
diff --git a/main.py b/main.py
index dc30c2e..ed7fbbc 100644
--- a/main.py
+++ b/main.py
@@ -71,7 +71,23 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
     logging.info(f"Processing {len(topics_list)} topics...")
     success_flag = False
 
-    prompt_manager = PromptManager(config)
+    # --- Create a PromptManager instance (pass explicit parameters) ---
+    try:
+        prompt_manager = PromptManager(
+            topic_system_prompt_path=config.get("topic_system_prompt"),
+            topic_user_prompt_path=config.get("topic_user_prompt"),
+            content_system_prompt_path=config.get("content_system_prompt"),
+            prompts_dir=config.get("prompts_dir"),
+            resource_dir_config=config.get("resource_dir", []),
+            topic_gen_num=config.get("num", 1),  # Topic gen num/date used by topic prompts
+            topic_gen_date=config.get("date", "")
+        )
+        logging.info("PromptManager instance created for Step 2.")
+    except KeyError as e:
+        logging.error(f"Configuration error creating PromptManager: Missing key '{e}'. Cannot proceed with content generation.")
+        return False
+    # --- End PromptManager creation ---
+
     ai_agent = None
     try:
         # --- Initialize AI Agent for Content Generation ---
diff --git a/utils/__pycache__/prompt_manager.cpython-312.pyc b/utils/__pycache__/prompt_manager.cpython-312.pyc
index 1039bfb..d602808 100644
Binary files a/utils/__pycache__/prompt_manager.cpython-312.pyc and b/utils/__pycache__/prompt_manager.cpython-312.pyc differ
diff --git a/utils/prompt_manager.py b/utils/prompt_manager.py
index 383fb69..bcbf72f 100644
--- a/utils/prompt_manager.py
+++ b/utils/prompt_manager.py
@@ -137,36 +137,60 @@ class PromptManager:
             except Exception as e:
                 logging.exception("Error processing Demand description:")
 
-            # 2. 添加Object信息 (based on topic_item['object'])
+            # 2. Object Info - list all available files first, then inject the matched file's content
             try:
-                object_name_base = topic_item['object']  # This might be '景点信息-XXX.txt'
-                object_file_path = None
-                # Find the full path for the object file from config
+                object_name_from_topic = topic_item.get('object')  # e.g., "尚书第建筑群"
+                object_file_basenames = []
+                matched_object_file_path = None
+                matched_object_basename = None
+
+                # Scan the resource dirs for Object files
                 for dir_info in resource_dir_config:
                     if dir_info.get("type") == "Object":
                         for file_path in dir_info.get("file_path", []):
-                            # Match basename, assuming topic_item['object'] is the basename
-                            # if os.path.basename(file_path) == object_name_base:
-                            # Use containment check instead of exact match
-                            if object_name_base in os.path.basename(file_path):
-                                object_file_path = file_path
-                                break
-                        if object_file_path: break
-
-                if object_file_path:
-                    object_content = ResourceLoader.load_file_content(object_file_path)
-                    if object_content:
-                        user_prompt += f"Object Info:\n{object_content}\n"
-                    else:
-                        logging.warning(f"Object file could not be loaded: {object_file_path}")
+                            basename = os.path.basename(file_path)
+                            object_file_basenames.append(basename)
+
+                            # Try to match the current topic's object (only if no match has been found yet)
+                            if object_name_from_topic and not matched_object_file_path:
+                                cleaned_resource_name = basename
+                                if cleaned_resource_name.startswith("景点信息-"):
+                                    cleaned_resource_name = cleaned_resource_name[len("景点信息-"):]
+                                if cleaned_resource_name.endswith(".txt"):
+                                    cleaned_resource_name = cleaned_resource_name[:-len(".txt")]
+
+                                if cleaned_resource_name and cleaned_resource_name in object_name_from_topic:
+                                    matched_object_file_path = file_path
+                                    matched_object_basename = basename
+                                    # Note: no break here; keep collecting all file names
+
+                # Build the prompt - Part 1: list of available files
+                if object_file_basenames:
+                    user_prompt += "Object信息:\n"
+                    # user_prompt += f"{object_file_basenames}\n\n"  # Dumping the raw list may not be clear enough
+                    for fname in object_file_basenames:
+                        user_prompt += f"- {fname}\n"
+                    user_prompt += "\n"  # Add a blank line
+                    logging.info(f"Listed {len(object_file_basenames)} available object files.")
                 else:
-                    # If basename match fails, maybe topic_item['object'] is just 'XXX'?
-                    # Try finding based on substring? This might be ambiguous.
-                    logging.warning(f"Object file path not found in config matching object: {topic_item.get('object')}")
-            except KeyError:
-                logging.warning("Warning: 'object' key missing in topic_item for Object prompt.")
+                    logging.warning("No resource directory entry found with type 'Object', or it has no file paths.")
+
+                # Build the prompt - Part 2: inject the matched file's content
+                if matched_object_file_path:
+                    logging.info(f"Attempting to load content for matched object file: {matched_object_basename}")
+                    matched_object_content = ResourceLoader.load_file_content(matched_object_file_path)
+                    if matched_object_content:
+                        user_prompt += f"{matched_object_basename}:\n{matched_object_content}\n\n"
+                        logging.info(f"Successfully loaded and injected content for: {matched_object_basename}")
+                    else:
+                        logging.warning(f"Object file matched ({matched_object_basename}) but could not be loaded or is empty.")
+                elif object_name_from_topic:  # Warn only when the topic specified an object but no matching file was found
+                    logging.warning(f"Could not find a matching Object resource file to inject content for '{object_name_from_topic}'. Only the list of files was provided.")
+
+            except KeyError:
+                logging.warning("Warning: 'object' key potentially missing in topic_item.")
             except Exception as e:
-                logging.exception("Error processing Object prompt:")
+                logging.exception("Error processing Object prompt section:")
 
             # 3. 添加Product信息 (if applicable)
             try: