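"""Travel Content Creator Pipeline entry point.

Step 1 generates tweet topics via run_topic_generation_pipeline (or loads them
from an existing topics JSON file passed with --topics_file); Step 2 generates
content variants and posters for each topic. Configuration is read from a JSON
file (default: poster_gen_config.json), and results are written through an
OutputHandler under the configured output directory, grouped by run_id.
"""
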
import os
import time
from datetime import datetime
import argparse
import sys
import traceback
import json
import logging

from core.ai_agent import AI_Agent
# from core.topic_parser import TopicParser # No longer needed directly in main?
import core.contentGen as contentGen
import core.posterGen as posterGen
import core.simple_collage as simple_collage
from utils.resource_loader import ResourceLoader
from utils.tweet_generator import ( # Import the moved functions
    run_topic_generation_pipeline,
    generate_content_for_topic,
    generate_posters_for_topic
)
from utils.prompt_manager import PromptManager # Import PromptManager
import random
# Import Output Handlers
from utils.output_handler import FileSystemOutputHandler, OutputHandler
from core.topic_parser import TopicParser
from utils.tweet_generator import tweetTopicRecord # Needed only if loading old topics files?

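# For reference, an illustrative shape for poster_gen_config.json covering only the
# required keys validated below. The values are placeholders, not real defaults, and
# the actual config may contain additional optional keys (timeouts, content/poster
# tuning parameters, etc.):
#
# {
#     "api_url": "https://api.example.com/v1",
#     "model": "model-name",
#     "api_key": "YOUR_API_KEY",
#     "resource_dir": [],
#     "prompts_dir": "prompts",
#     "output_dir": "result",
#     "num": 1,
#     "variants": 1,
#     "topic_system_prompt": "prompts/topic_system.txt",
#     "topic_user_prompt": "prompts/topic_user.txt",
#     "content_system_prompt": "prompts/content_system.txt",
#     "image_base_dir": "images"
# }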
def load_config(config_path="poster_gen_config.json"):
    """Loads configuration from a JSON file."""
    if not os.path.exists(config_path):
        print(f"Error: Configuration file '{config_path}' not found.")
        print("Please copy 'example_config.json' to 'poster_gen_config.json' and customize it.")
        sys.exit(1)
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Basic validation (can be expanded)
        required_keys = ["api_url", "model", "api_key", "resource_dir", "prompts_dir", "output_dir", "num", "variants", "topic_system_prompt", "topic_user_prompt", "content_system_prompt", "image_base_dir"]
        if not all(key in config for key in required_keys):
            missing_keys = [key for key in required_keys if key not in config]
            print(f"Error: Config file '{config_path}' is missing required keys: {missing_keys}")
            sys.exit(1)
        # Resolve relative paths based on config location or a defined base path if necessary
        # For simplicity, assuming paths in config are relative to project root or absolute
        return config
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from '{config_path}'. Check the file format.")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration from '{config_path}': {e}")
        sys.exit(1)


# Removed generate_topics_step function definition from here
# Its logic is now in utils.tweet_generator.run_topic_generation_pipeline


# --- Main Orchestration Step (Remains in main.py) ---
def generate_content_and_posters_step(config, run_id, topics_list, output_handler):
    """
    Step 2: Generates content and posters for each topic in the record.
    Returns True if successful (at least partially), False otherwise.
    """
    if not topics_list:
        # print("Skipping content/poster generation: No valid topics found in the record.")
        logging.warning("Skipping content/poster generation: No valid topics found in the record.")
        return False

    # print(f"\n--- Starting Step 2: Content and Poster Generation for run_id: {run_id} ---")
    logging.info(f"Starting Step 2: Content and Poster Generation for run_id: {run_id}")
    # print(f"Processing {len(topics_list)} topics...")
    logging.info(f"Processing {len(topics_list)} topics...")

    success_flag = False
    # --- Create PromptManager instance (pass explicit parameters) ---
    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_dir=config.get("prompts_dir"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1), # Topic gen num/date used by topic prompts
            topic_gen_date=config.get("date", "")
        )
        logging.info("PromptManager instance created for Step 2.")
    except KeyError as e:
        logging.error(f"Configuration error creating PromptManager: Missing key '{e}'. Cannot proceed with content generation.")
        return False
    # --- End PromptManager creation ---

    ai_agent = None
    try:
        # --- Initialize AI Agent for Content Generation ---
        request_timeout = config.get("request_timeout", 30) # Default 30 seconds
        max_retries = config.get("max_retries", 3) # Default 3 retries
        stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Default 60 seconds for stream chunk
        ai_agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=request_timeout,
            max_retries=max_retries,
            stream_chunk_timeout=stream_chunk_timeout
        )
        logging.info("AI Agent for content generation initialized.")

        # --- Iterate through Topics ---
        for i, topic_item in enumerate(topics_list):
            topic_index = topic_item.get('index', i + 1) # Use parsed index if available
            logging.info(f"--- Processing Topic {topic_index}/{len(topics_list)}: {topic_item.get('object', 'N/A')} ---") # Make it stand out

            # --- Generate Content Variants ---
            # Read the parameters needed for content generation
            content_variants = config.get("variants", 1)
            content_temp = config.get("content_temperature", 0.3)
            content_top_p = config.get("content_top_p", 0.4)
            content_presence_penalty = config.get("content_presence_penalty", 1.5)

            # Call the refactored generate_content_for_topic
            content_success = generate_content_for_topic(
                ai_agent,
                prompt_manager, # Pass PromptManager instance
                topic_item,
                run_id,
                topic_index,
                output_handler, # Pass OutputHandler instance
                # Pass explicit parameters
                variants=content_variants,
                temperature=content_temp,
                top_p=content_top_p,
                presence_penalty=content_presence_penalty
            )

            # if tweet_content_list: # generate_content_for_topic now returns a bool
            if content_success:
                logging.info(f"Content generation successful for Topic {topic_index}.")
                # --- Generate Posters ---
                # TODO: Refactor generate_posters_for_topic to remove its dependency on config
                # TODO: Decide how to pass the content data to the poster step (resolved: the function reads it internally)
                # Interim idea: read the content files saved by the output_handler here,
                # or change generate_content_for_topic to return the collected content data list (option 1).
                # The poster generation call was temporarily skipped until the approach was settled.

                # --- Re-enable the poster generation call ---
                logging.info(f"Proceeding to poster generation for Topic {topic_index}...")

                # --- Read the parameters required for poster generation ---
                poster_variants = config.get("variants", 1) # Usually the same as the content variants
                poster_assets_dir = config.get("poster_assets_base_dir")
                img_base_dir = config.get("image_base_dir")
                mod_img_subdir = config.get("modify_image_subdir", "modify")
                res_dir_config = config.get("resource_dir", [])
                poster_size = tuple(config.get("poster_target_size", [900, 1200]))
                txt_possibility = config.get("text_possibility", 0.3)
                img_frame_possibility = config.get("img_frame_possibility", 0.7)
                text_bg_possibility = config.get("text_bg_possibility", 0)
                collage_subdir = config.get("output_collage_subdir", "collage_img")
                poster_subdir = config.get("output_poster_subdir", "poster")
                poster_filename = config.get("output_poster_filename", "poster.jpg")
                cam_img_subdir = config.get("camera_image_subdir", "相机")
                poster_content_system_prompt = config.get("poster_content_system_prompt", None)

                # Check that the critical paths are configured
                if not poster_assets_dir or not img_base_dir:
                    logging.error(f"Missing critical paths for poster generation (poster_assets_base_dir or image_base_dir) in config. Skipping posters for topic {topic_index}.")
                    continue # Skip poster generation for this topic
                # --- End parameter reads ---

                # Load the poster content system prompt text (the config value is a file path)
                if poster_content_system_prompt:
                    with open(poster_content_system_prompt, "r", encoding="utf-8") as f:
                        poster_content_system_prompt = f.read()
                else:
                    logging.warning(f"No 'poster_content_system_prompt' configured; passing None for topic {topic_index}.")

                posters_attempted = generate_posters_for_topic(
                    topic_item=topic_item,
                    output_dir=config["output_dir"], # Base output dir is still needed
                    run_id=run_id,
                    topic_index=topic_index,
                    output_handler=output_handler, # Pass the Output Handler
                    # Pass explicit parameters
                    variants=poster_variants,
                    poster_assets_base_dir=poster_assets_dir,
                    image_base_dir=img_base_dir,
                    modify_image_subdir=mod_img_subdir,
                    resource_dir_config=res_dir_config,
                    poster_target_size=poster_size,
                    text_possibility=txt_possibility,
                    img_frame_possibility=img_frame_possibility,
                    text_bg_possibility=text_bg_possibility,
                    output_collage_subdir=collage_subdir,
                    output_poster_subdir=poster_subdir,
                    output_poster_filename=poster_filename,
                    camera_image_subdir=cam_img_subdir,
                    system_prompt=poster_content_system_prompt
                )
                if posters_attempted:
                    logging.info(f"Poster generation process completed for Topic {topic_index}.")
                    success_flag = True # Mark overall success if content AND poster attempts were made
                else:
                    logging.warning(f"Poster generation skipped or failed early for Topic {topic_index}.")
                    # Even if posters fail, does content success alone count as partial success?
                    # Decide success_flag based on the requirements:
                    # success_flag = True # depends on whether content success alone is considered enough
                # logging.warning(f"Skipping poster generation for Topic {topic_index} pending refactor and data passing strategy.")
                # Mark overall success if content generation succeeded
                # success_flag = True
            else:
                logging.warning(f"Content generation failed or yielded no valid results for Topic {topic_index}. Skipping posters.")
            logging.info(f"--- Finished Topic {topic_index} ---")

    except KeyError as e:
        # print(f"\nError: Configuration error during content/poster generation: Missing key '{e}'")
        logging.error(f"Configuration error during content/poster generation: Missing key '{e}'")
        traceback.print_exc()
        return False # Indicate failure due to config error
    except Exception as e:
        # print(f"\nAn unexpected error occurred during content and poster generation: {e}")
        # traceback.print_exc()
        logging.exception("An unexpected error occurred during content and poster generation:")
        return False # Indicate general failure
    finally:
        # Ensure the AI agent is closed
        if ai_agent:
            # print("Closing content generation AI Agent...")
            logging.info("Closing content generation AI Agent...")
            ai_agent.close()

    if success_flag:
        # print("\n--- Step 2: Content and Poster Generation completed (at least partially). ---")
        logging.info("Step 2: Content and Poster Generation completed (at least partially).")
    else:
        # print("\n--- Step 2: Content and Poster Generation completed, but no content/posters were successfully generated or attempted. ---")
        logging.warning("Step 2: Content and Poster Generation completed, but no content/posters were successfully generated or attempted.")
    return success_flag # Return True if at least some steps were attempted


def main():
    # --- Basic Logging Setup ---
    logging.basicConfig(
        level=logging.INFO, # Default level
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    # --- End Logging Setup ---

    parser = argparse.ArgumentParser(description="Travel Content Creator Pipeline")
    parser.add_argument(
        "--config",
        type=str,
        default="poster_gen_config.json",
        help="Path to the configuration file (e.g., poster_gen_config.json)"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None, # Default to None, let the pipeline generate it
        help="Optional specific run ID (e.g., 'my_test_run_01'). If not provided, a timestamp-based ID will be generated."
    )
    parser.add_argument(
        "--topics_file",
        type=str,
        default=None,
        help="Optional path to a pre-generated topics JSON file. If provided, skips topic generation (Step 1)."
    )
    parser.add_argument(
        "--debug",
        action='store_true', # Add debug flag
        help="Enable debug level logging."
    )
    args = parser.parse_args()

    # --- Adjust Logging Level if Debug ---
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("Debug logging enabled.")
    # --- End Debug Level Adjustment ---

    logging.info("Starting Travel Content Creator Pipeline...")
    logging.info(f"Using configuration file: {args.config}")
    if args.run_id:
        logging.info(f"Using specific run_id: {args.run_id}")
    if args.topics_file:
        logging.info(f"Using existing topics file: {args.topics_file}")

    config = load_config(args.config)
    if config is None:
        logging.critical("Failed to load configuration. Exiting.")
        sys.exit(1)

    # --- Initialize Output Handler ---
    # For now, always use FileSystemOutputHandler. Later, this could be configurable.
    output_handler: OutputHandler = FileSystemOutputHandler(config.get("output_dir", "result"))
    logging.info(f"Using Output Handler: {output_handler.__class__.__name__}")
    # --- End Output Handler Init ---

    run_id = args.run_id
    # tweet_topic_record = None # No longer the primary way to pass data
    topics_list = None
    system_prompt = None
    user_prompt = None
    pipeline_start_time = time.time()

    # --- Step 1: Topic Generation (or Load Existing) ---
    if args.topics_file:
        logging.info(f"Skipping Topic Generation (Step 1) - Loading topics from: {args.topics_file}")
        topics_list = TopicParser.load_topics_from_json(args.topics_file)
        if topics_list:
            # Try to infer run_id from the filename if not provided
            if not run_id:
                try:
                    base = os.path.basename(args.topics_file)
                    # Assuming format "tweet_topic_{run_id}.json" or "tweet_topic.json"
                    if base.startswith("tweet_topic_") and base.endswith(".json"):
                        run_id = base[len("tweet_topic_"): -len(".json")]
                        # print(f"Inferred run_id from topics filename: {run_id}")
                        logging.info(f"Inferred run_id from topics filename: {run_id}")
                    elif base == "tweet_topic.json": # Handle the default name?
                        # Decide how to handle this - maybe use the parent dir name or generate a new ID?
                        # For now, a new one is generated if run_id is still None below
                        logging.warning(f"Loaded topics from default filename '{base}'. Run ID not inferred.")
                    else:
                        # print(f"Warning: Could not infer run_id from topics filename: {base}")
                        logging.warning(f"Could not infer run_id from topics filename: {base}")
                except Exception as e:
                    # print(f"Warning: Error trying to infer run_id from topics filename: {e}")
                    logging.warning(f"Error trying to infer run_id from topics filename: {e}")

            # If run_id is still None after trying to infer, generate one
            if run_id is None:
                run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_loaded")
                # print(f"Generated run_id for loaded topics: {run_id}")
                logging.info(f"Generated run_id for loaded topics: {run_id}")

            # Prompts are missing when loading from a file; handle this if needed later
            system_prompt = "" # Placeholder
            user_prompt = "" # Placeholder
            logging.info(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}. Prompts are not available.")
            # Optionally, save the loaded topics using the handler?
            # output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
        else:
            # print(f"Error: Failed to load topics from {args.topics_file}. Cannot proceed.")
            logging.error(f"Failed to load topics from {args.topics_file}. Cannot proceed.")
            sys.exit(1)
    else:
        # print("--- Executing Topic Generation (Step 1) ---")
        logging.info("Executing Topic Generation (Step 1)...")
        step1_start = time.time()
        # Call the updated function, receive raw data
        run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, args.run_id)
        step1_end = time.time()
        if run_id is not None and topics_list is not None: # Check if step succeeded
            # print(f"Step 1 completed successfully in {step1_end - step1_start:.2f} seconds. Run ID: {run_id}")
            logging.info(f"Step 1 completed successfully in {step1_end - step1_start:.2f} seconds. Run ID: {run_id}")
            # --- Use Output Handler to save results ---
            output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
        else:
            # print("Critical: Topic Generation (Step 1) failed. Exiting.")
            logging.critical("Topic Generation (Step 1) failed. Exiting.")
            sys.exit(1)

# --- Step 2: Content & Poster Generation ---
|
|
if run_id is not None and topics_list is not None:
|
|
# print("\n--- Executing Content and Poster Generation (Step 2) ---")
|
|
logging.info("Executing Content and Poster Generation (Step 2)...")
|
|
step2_start = time.time()
|
|
# TODO: Refactor generate_content_and_posters_step to accept topics_list
|
|
# and use the output_handler instead of saving files directly.
|
|
# For now, we might need to pass topics_list and handler, or adapt it.
|
|
# Let's tentatively adapt the call signature, assuming the function will be refactored.
|
|
success = generate_content_and_posters_step(config, run_id, topics_list, output_handler)
|
|
step2_end = time.time()
|
|
if success:
|
|
# print(f"Step 2 completed in {step2_end - step2_start:.2f} seconds.")
|
|
logging.info(f"Step 2 completed in {step2_end - step2_start:.2f} seconds.")
|
|
else:
|
|
# print("Warning: Step 2 finished, but may have encountered errors or generated no output.")
|
|
logging.warning("Step 2 finished, but may have encountered errors or generated no output.")
|
|
else:
|
|
# print("Error: Cannot proceed to Step 2: Invalid run_id or topics_list from Step 1.")
|
|
logging.error("Cannot proceed to Step 2: Invalid run_id or topics_list from Step 1.")
|
|
|
|
    # --- Finalize Output ---
    if run_id:
        output_handler.finalize(run_id)
    # --- End Finalize ---

    pipeline_end_time = time.time()
    # print(f"\nPipeline finished. Total execution time: {pipeline_end_time - pipeline_start_time:.2f} seconds.")
    logging.info(f"Pipeline finished. Total execution time: {pipeline_end_time - pipeline_start_time:.2f} seconds.")
    # print(f"Results for run_id '{run_id}' are in: {os.path.join(config.get('output_dir', 'result'), run_id)}")
    logging.info(f"Results for run_id '{run_id}' are in: {os.path.join(config.get('output_dir', 'result'), run_id)}")

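
# Example invocations (the entry-point filename "main.py" is an assumption; use this
# file's actual name):
#   python main.py --config poster_gen_config.json
#   python main.py --run_id my_test_run_01 --debug
#   python main.py --topics_file path/to/tweet_topic_<run_id>.json
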
if __name__ == "__main__":
    main()