236 lines
10 KiB
Python
236 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
from datetime import datetime
|
|
import logging
|
|
|
|
# --- Setup Project Root Path ---
# The project root is the parent of the directory containing this script.
# It is inserted at the FRONT of sys.path so the project's generic top-level
# packages (`core`, `utils`) win over any same-named package that happens to
# be installed in site-packages; appending would let those shadow ours.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# --- End Setup ---
|
|
|
|
# --- Imports from the project ---
# Project-local modules, resolved through PROJECT_ROOT on sys.path. Without
# them the test cannot run at all, so log why and exit with status 1.
try:
    from utils.tweet_generator import (
        run_topic_generation_pipeline,
        generate_content_for_topic,
        generate_posters_for_topic,
    )
    from core.topic_parser import TopicParser
    from utils.prompt_manager import PromptManager  # builds prompts for Step 2
    from core.ai_agent import AI_Agent  # LLM client used in Step 2
    from utils.output_handler import FileSystemOutputHandler, OutputHandler
except ImportError as exc:
    logging.critical(
        f"ImportError: {exc}. Ensure all core/utils modules are available "
        f"and '{PROJECT_ROOT}' is in sys.path."
    )
    sys.exit(1)
# --- End Imports ---
|
|
|
|
def load_config(config_path):
    """Load the pipeline configuration from a JSON file.

    Args:
        config_path: Path to a UTF-8 encoded JSON file whose top-level value
            is an object.

    Returns:
        The parsed configuration as a dict, or None when the file is missing,
        is not valid JSON, does not hold a JSON object at the top level, or
        cannot be read for any other reason. Failures are logged rather than
        raised so the caller decides how to abort.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Error: Configuration file not found at {config_path}")
        return None
    except json.JSONDecodeError as e:
        # The decoder message carries line/column info for the broken spot.
        logging.error(f"Error: Could not decode JSON from {config_path}: {e}")
        return None
    except Exception:
        # Anything else (permissions, a directory path, bad encoding, ...):
        # log with traceback and report failure the same way.
        logging.exception(f"An unexpected error occurred loading config {config_path}:")
        return None

    # Callers access the result with config.get(...), so a JSON list/scalar at
    # the top level is unusable and is reported instead of returned.
    if not isinstance(config, dict):
        logging.error(
            f"Error: Configuration in {config_path} must be a JSON object, "
            f"got {type(config).__name__}"
        )
        return None

    logging.info(f"Config loaded successfully from {config_path}")
    return config
|
|
|
|
def _create_prompt_manager(config):
    """Build the PromptManager used by Step 2 from the loaded config dict."""
    return PromptManager(
        topic_system_prompt_path=config.get("topic_system_prompt"),
        topic_user_prompt_path=config.get("topic_user_prompt"),
        content_system_prompt_path=config.get("content_system_prompt"),
        prompts_dir=config.get("prompts_dir"),
        resource_dir_config=config.get("resource_dir", []),
        topic_gen_num=config.get("num", 1),
        topic_gen_date=config.get("date", ""),
    )


def _create_ai_agent(config):
    """Build the AI_Agent used by Step 2 from the loaded config dict.

    Raises:
        ValueError: if any of ``api_url``, ``model`` or ``api_key`` is missing
            or empty in the config.
    """
    ai_api_url = config.get("api_url")
    ai_model = config.get("model")
    ai_api_key = config.get("api_key")
    if not all([ai_api_url, ai_model, ai_api_key]):
        raise ValueError("Missing required AI configuration (api_url, model, api_key)")
    logging.info("Initializing AI Agent for content generation test...")
    return AI_Agent(
        base_url=ai_api_url,
        model_name=ai_model,
        api=ai_api_key,
        timeout=config.get("request_timeout", 30),
        max_retries=config.get("max_retries", 3),
    )


def _poster_kwargs(config, output_dir):
    """Collect keyword arguments for generate_posters_for_topic from config.

    Returns None when ``poster_assets_base_dir`` or ``image_base_dir`` is
    missing/empty, because posters cannot be generated without them.
    """
    poster_assets_dir = config.get("poster_assets_base_dir")
    img_base_dir = config.get("image_base_dir")
    if not poster_assets_dir or not img_base_dir:
        return None
    return {
        "output_dir": output_dir,
        "variants": config.get("variants", 1),
        "poster_assets_base_dir": poster_assets_dir,
        "image_base_dir": img_base_dir,
        "modify_image_subdir": config.get("modify_image_subdir", "modify"),
        "resource_dir_config": config.get("resource_dir", []),
        "poster_target_size": tuple(config.get("poster_target_size", [900, 1200])),
        "text_possibility": config.get("text_possibility", 0.3),
        "output_collage_subdir": config.get("output_collage_subdir", "collage_img"),
        "output_poster_subdir": config.get("output_poster_subdir", "poster"),
        "output_poster_filename": config.get("output_poster_filename", "poster.jpg"),
        "camera_image_subdir": config.get("camera_image_subdir", "相机"),
    }


def _process_topics(config, topics_list, run_id, output_handler,
                    prompt_manager, ai_agent, output_dir):
    """Generate content, then posters, for every topic from Step 1.

    Returns True if generate_posters_for_topic reported a truthy result for at
    least one topic, False otherwise.
    """
    # These settings do not change per topic, so read them once.
    content_kwargs = {
        "variants": config.get("variants", 1),
        "temperature": config.get("content_temperature", 0.3),
        "top_p": config.get("content_top_p", 0.4),
        "presence_penalty": config.get("content_presence_penalty", 1.5),
    }
    poster_kwargs = _poster_kwargs(config, output_dir)

    total_topics = len(topics_list)
    logging.info(f"Processing {total_topics} topics for content/posters...")
    any_posters = False
    for i, topic_item in enumerate(topics_list):
        # Prefer the pipeline-assigned index; fall back to 1-based position.
        topic_index = topic_item.get('index', i + 1)
        logging.info(f"--- Processing Topic {topic_index}/{total_topics} ---")

        content_success = generate_content_for_topic(
            ai_agent=ai_agent,
            prompt_manager=prompt_manager,
            topic_item=topic_item,
            run_id=run_id,
            topic_index=topic_index,
            output_handler=output_handler,
            **content_kwargs,
        )

        if not content_success:
            logging.warning(f"Content generation failed for Topic {topic_index}. Skipping posters.")
        else:
            logging.info(f"Content generation successful for Topic {topic_index}.")
            if poster_kwargs is None:
                logging.error(f"Missing critical paths for poster generation. Skipping posters for topic {topic_index}.")
            elif generate_posters_for_topic(
                topic_item=topic_item,
                run_id=run_id,
                topic_index=topic_index,
                output_handler=output_handler,
                **poster_kwargs,
            ):
                logging.info(f"Poster generation process completed for Topic {topic_index}.")
                any_posters = True
            else:
                logging.warning(f"Poster generation skipped/failed for Topic {topic_index}.")

        # Logged for every topic, including those whose posters were skipped.
        logging.info(f"--- Finished Topic {topic_index} ---")
    return any_posters


def main_test():
    """Run pipeline Step 1 (topics) and Step 2 (content + posters) end to end.

    Loads ``poster_gen_config.json`` from PROJECT_ROOT, writes results through
    a FileSystemOutputHandler and logs a summary. Exits with status 1 when the
    config cannot be loaded or Step 1 fails; Step 2 problems are logged and the
    run still finalizes its output.
    """
    # --- Basic Logging Setup ---
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    logging.info("--- Starting Pipeline Steps Test ---")

    # 1. Load Configuration
    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")
    config = load_config(config_path)
    if config is None:
        logging.critical("Failed to load configuration. Exiting test.")
        sys.exit(1)

    # Read once so the handler, the poster step and the summary all agree on
    # the same directory (the summary used to index config["output_dir"]
    # directly and raised KeyError when the key was absent).
    output_dir = config.get("output_dir", "result")
    output_handler: OutputHandler = FileSystemOutputHandler(output_dir)
    logging.info(f"Using Output Handler: {output_handler.__class__.__name__}")

    # 2. Define a unique Run ID for this test run
    test_run_id = f"test_pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logging.info(f"Using Run ID: {test_run_id}")

    # --- Step 1: Topic Generation ---
    logging.info("\n--- Testing Step 1: Topic Generation ---")
    step1_start = time.perf_counter()  # monotonic; unaffected by clock changes
    run_id_step1, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, run_id=test_run_id)
    step1_end = time.perf_counter()

    if run_id_step1 is None or topics_list is None:
        logging.error("Topic generation (Step 1) failed. Cannot proceed.")
        sys.exit(1)
    if run_id_step1 != test_run_id:
        logging.warning(f"run_id returned from step 1 ({run_id_step1}) differs from the one provided ({test_run_id}). Using the returned one.")
        test_run_id = run_id_step1  # Use the ID actually used by the pipeline

    logging.info(f"Step 1 finished in {step1_end - step1_start:.2f} seconds.")
    logging.info(f"Generated {len(topics_list)} topics for Run ID: {test_run_id}")
    # Save topic results through the output handler (mimics main.py).
    output_handler.handle_topic_results(test_run_id, topics_list, system_prompt, user_prompt)

    # --- Step 2: Content and Poster Generation ---
    logging.info("\n--- Testing Step 2: Content and Poster Generation ---")
    step2_start = time.perf_counter()
    ai_agent = None
    step2_success_flag = False
    try:
        try:
            prompt_manager = _create_prompt_manager(config)
            logging.info("PromptManager instance created for Step 2 test.")
            ai_agent = _create_ai_agent(config)
        except ValueError as e:
            # Only setup problems are configuration errors; a ValueError raised
            # while processing topics is handled below as a runtime failure.
            logging.error(f"Configuration error during Step 2 setup: {e}")
        else:
            step2_success_flag = _process_topics(
                config, topics_list, test_run_id, output_handler,
                prompt_manager, ai_agent, output_dir,
            )
    except Exception:
        logging.exception("An error occurred during Step 2 processing:")
        step2_success_flag = False
    finally:
        # Always release the agent, whatever happened above.
        if ai_agent:
            logging.info("Closing AI Agent for content generation test...")
            ai_agent.close()
    step2_end = time.perf_counter()

    if step2_success_flag:
        logging.info(f"Step 2 finished in {step2_end - step2_start:.2f} seconds.")
    else:
        logging.warning(f"Step 2 finished in {step2_end - step2_start:.2f} seconds, but encountered errors or generated no output.")

    # --- Finalize Output ---
    # Guard kept: the pipeline may hand back an empty run id.
    if test_run_id:
        output_handler.finalize(test_run_id)

    # --- Test Summary ---
    logging.info("\n--- Pipeline Steps Test Summary ---")
    logging.info(f"Run ID: {test_run_id}")
    output_location = os.path.join(output_dir, test_run_id)
    logging.info(f"Check output files in: {output_location}")
    if os.path.exists(output_location):
        logging.info("Output directory exists.")
    else:
        logging.warning("Output directory NOT found.")
|
|
|
|
if __name__ == "__main__":
    # Run the end-to-end pipeline test only when executed as a script,
    # never as a side effect of importing this module.
    main_test()