迁移了选题模块

This commit is contained in:
jinye_huang 2025-04-22 14:16:29 +08:00
parent 00eebd6270
commit 0c9e7f90ae
4 changed files with 139 additions and 79 deletions

View File

@ -5,12 +5,14 @@
""" """
import os import os
import sys import sys
import traceback
# 添加项目根目录到Python路径 # 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 导入所需模块 # 导入所需模块
from main import load_config, generate_topics_step from main import load_config
from utils.tweet_generator import run_topic_generation_pipeline
if __name__ == "__main__": if __name__ == "__main__":
print("==== 阶段 1: 仅生成选题 ====") print("==== 阶段 1: 仅生成选题 ====")
@ -24,7 +26,7 @@ if __name__ == "__main__":
# 2. 执行选题生成 # 2. 执行选题生成
print("\n执行选题生成...") print("\n执行选题生成...")
run_id, tweet_topic_record = generate_topics_step(config) run_id, tweet_topic_record = run_topic_generation_pipeline(config)
if run_id and tweet_topic_record: if run_id and tweet_topic_record:
output_dir = config.get("output_dir", "./result") output_dir = config.get("output_dir", "./result")
@ -40,6 +42,5 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"\n处理过程中出错: {e}") print(f"\n处理过程中出错: {e}")
import traceback
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)

View File

@ -13,6 +13,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 导入所需模块 # 导入所需模块
from main import load_config, generate_topics_step, generate_content_and_posters_step from main import load_config, generate_topics_step, generate_content_and_posters_step
from utils.tweet_generator import run_topic_generation_pipeline
from core.topic_parser import TopicParser
def test_full_workflow(): def test_full_workflow():
"""测试完整的工作流程,从选题生成到海报制作""" """测试完整的工作流程,从选题生成到海报制作"""
@ -26,7 +28,7 @@ def test_full_workflow():
# 2. 执行选题生成 # 2. 执行选题生成
print("\n步骤 2: 生成选题...") print("\n步骤 2: 生成选题...")
run_id, tweet_topic_record = generate_topics_step(config) run_id, tweet_topic_record = run_topic_generation_pipeline(config)
if not run_id or not tweet_topic_record: if not run_id or not tweet_topic_record:
print("选题生成失败,测试终止。") print("选题生成失败,测试终止。")
@ -84,7 +86,7 @@ def test_steps_separately():
# 2. 仅执行选题生成 # 2. 仅执行选题生成
print("\n步骤 2: 仅测试选题生成...") print("\n步骤 2: 仅测试选题生成...")
run_id, tweet_topic_record = generate_topics_step(test_config) run_id, tweet_topic_record = run_topic_generation_pipeline(test_config)
if not run_id or not tweet_topic_record: if not run_id or not tweet_topic_record:
print("选题生成失败,测试终止。") print("选题生成失败,测试终止。")
@ -104,7 +106,6 @@ def test_steps_separately():
# 这部分通常是由main函数中的流程自动处理的 # 这部分通常是由main函数中的流程自动处理的
# 这里为了演示分段流程,模拟手动加载数据并处理 # 这里为了演示分段流程,模拟手动加载数据并处理
from core.topic_parser import TopicParser
if os.path.exists(topics_file): if os.path.exists(topics_file):
with open(topics_file, 'r', encoding='utf-8') as f: with open(topics_file, 'r', encoding='utf-8') as f:

142
main.py
View File

@ -7,15 +7,15 @@ import traceback
import json import json
from core.ai_agent import AI_Agent from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser # from core.topic_parser import TopicParser # No longer needed directly in main?
import core.contentGen as contentGen import core.contentGen as contentGen
import core.posterGen as posterGen import core.posterGen as posterGen
import core.simple_collage as simple_collage import core.simple_collage as simple_collage
from utils.resource_loader import ResourceLoader from utils.resource_loader import ResourceLoader
from utils.tweet_generator import prepare_topic_generation, generate_topics, generate_single_content from utils.tweet_generator import generate_single_content, run_topic_generation_pipeline # Import the new pipeline function
import random import random
TEXT_POSBILITY = 0.3 TEXT_POSBILITY = 0.3 # Consider moving this to config if it varies
def load_config(config_path="poster_gen_config.json"): def load_config(config_path="poster_gen_config.json"):
"""Loads configuration from a JSON file.""" """Loads configuration from a JSON file."""
@ -29,8 +29,8 @@ def load_config(config_path="poster_gen_config.json"):
# Basic validation (can be expanded) # Basic validation (can be expanded)
required_keys = ["api_url", "model", "api_key", "resource_dir", "prompts_dir", "output_dir", "num", "variants", "topic_system_prompt", "topic_user_prompt", "content_system_prompt", "image_base_dir"] required_keys = ["api_url", "model", "api_key", "resource_dir", "prompts_dir", "output_dir", "num", "variants", "topic_system_prompt", "topic_user_prompt", "content_system_prompt", "image_base_dir"]
if not all(key in config for key in required_keys): if not all(key in config for key in required_keys):
print(f"Error: Config file '{config_path}' is missing one or more required keys.") missing_keys = [key for key in required_keys if key not in config]
print(f"Required keys are: {required_keys}") print(f"Error: Config file '{config_path}' is missing required keys: {missing_keys}")
sys.exit(1) sys.exit(1)
# Resolve relative paths based on config location or a defined base path if necessary # Resolve relative paths based on config location or a defined base path if necessary
# For simplicity, assuming paths in config are relative to project root or absolute # For simplicity, assuming paths in config are relative to project root or absolute
@ -42,35 +42,8 @@ def load_config(config_path="poster_gen_config.json"):
print(f"Error loading configuration from '{config_path}': {e}") print(f"Error loading configuration from '{config_path}': {e}")
sys.exit(1) sys.exit(1)
# Removed generate_topics_step function definition from here
def generate_topics_step(config): # Its logic is now in utils.tweet_generator.run_topic_generation_pipeline
"""Generates topics based on the configuration."""
print("Step 1: Generating Topics...")
ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(
config.get("date", datetime.now().strftime("%Y-%m-%d")), # Use current date if not specified
config["num"], config["topic_system_prompt"], config["topic_user_prompt"],
config["api_url"], config["model"], config["api_key"], config["prompts_dir"],
config["resource_dir"], config["output_dir"]
)
run_id, tweet_topic_record = generate_topics(
ai_agent, system_prompt, user_prompt, config["output_dir"],
config.get("topic_temperature", 0.2), config.get("topic_top_p", 0.5), config.get("topic_max_tokens", 1.5) # Added defaults for safety
)
if not run_id or not tweet_topic_record:
print("Topic generation failed. Exiting.")
ai_agent.close()
return None, None
output_dir = os.path.join(config["output_dir"], run_id)
os.makedirs(output_dir, exist_ok=True)
tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
ai_agent.close()
print(f"Topics generated successfully. Run ID: {run_id}")
return run_id, tweet_topic_record
def generate_content_and_posters_step(config, run_id, tweet_topic_record): def generate_content_and_posters_step(config, run_id, tweet_topic_record):
"""Generates content and posters based on generated topics.""" """Generates content and posters based on generated topics."""
@ -78,55 +51,75 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
print("Missing run_id or topics data. Skipping content and poster generation.") print("Missing run_id or topics data. Skipping content and poster generation.")
return return
print("Step 2: Generating Content and Posters...") print("\nStep 2: Generating Content and Posters...")
base_output_dir = config["output_dir"] base_output_dir = config["output_dir"]
output_dir = os.path.join(base_output_dir, run_id) # Directory for this specific run output_dir = os.path.join(base_output_dir, run_id) # Directory for this specific run
# Load content generation system prompt # --- Pre-load resources and initialize shared objects ---
# Load content generation system prompt once
content_system_prompt = ResourceLoader.load_system_prompt(config["content_system_prompt"]) content_system_prompt = ResourceLoader.load_system_prompt(config["content_system_prompt"])
if not content_system_prompt: if not content_system_prompt:
print("Warning: Content generation system prompt is empty. Using default logic if available or might fail.") print("Warning: Content generation system prompt is empty. Using default logic if available or might fail.")
# Potentially load topic system prompt as fallback if needed, or handle error
# content_system_prompt = ResourceLoader.load_system_prompt(config["topic_system_prompt"])
# Initialize AI Agent once for the entire content generation phase
ai_agent = None
try:
print(f"Initializing AI Agent ({config['model']})...")
ai_agent = AI_Agent(config["api_url"], config["model"], config["api_key"])
except Exception as e:
print(f"Error initializing AI Agent: {e}. Cannot proceed with content generation.")
traceback.print_exc()
return # Cannot continue without AI agent
# Check image base directory
image_base_dir = config.get("image_base_dir", None) image_base_dir = config.get("image_base_dir", None)
if not image_base_dir: if not image_base_dir or not os.path.isdir(image_base_dir):
print("Error: 'image_base_dir' not specified in config. Cannot locate images.") print(f"Error: 'image_base_dir' ({image_base_dir}) not specified or not a valid directory in config. Cannot locate images.")
if ai_agent: ai_agent.close() # Close agent if initialized
return return
camera_image_subdir = config.get("camera_image_subdir", "相机") # Default '相机' camera_image_subdir = config.get("camera_image_subdir", "相机") # Default '相机'
modify_image_subdir = config.get("modify_image_subdir", "modify") # Default 'modify' modify_image_subdir = config.get("modify_image_subdir", "modify") # Default 'modify'
# Initialize ContentGenerator and PosterGenerator once if they are stateless
# Assuming they are stateless for now
content_gen = contentGen.ContentGenerator()
poster_gen_instance = posterGen.PosterGenerator()
# --- Process each topic ---
for i, topic in enumerate(tweet_topic_record.topics_list): for i, topic in enumerate(tweet_topic_record.topics_list):
topic_index = i + 1 topic_index = i + 1
print(f"Processing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic.get('title', 'N/A')}") print(f"\nProcessing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic.get('title', 'N/A')}")
tweet_content_list = [] tweet_content_list = []
# --- Content Generation Loop --- # --- Content Generation Loop (using the single AI Agent) ---
for j in range(config["variants"]): for j in range(config["variants"]):
variant_index = j + 1 variant_index = j + 1
print(f" Generating Variant {variant_index}/{config['variants']}...") print(f" Generating Variant {variant_index}/{config['variants']}...")
time.sleep(random.random()) # Keep the random delay? Okay for now. time.sleep(random.random() * 0.5) # Slightly reduced delay
ai_agent = AI_Agent(config["api_url"], config["model"], config["api_key"])
try: try:
# Use the pre-initialized AI Agent
tweet_content, gen_result = generate_single_content( tweet_content, gen_result = generate_single_content(
ai_agent, content_system_prompt, topic, ai_agent, content_system_prompt, topic,
config["prompts_dir"], config["resource_dir"], config["prompts_dir"], config["resource_dir"],
output_dir, run_id, topic_index, variant_index, config.get("content_temperature", 0.3) # Added default output_dir, run_id, topic_index, variant_index, config.get("content_temperature", 0.3)
) )
if tweet_content: if tweet_content:
tweet_content_list.append(tweet_content.get_json_file()) # Assuming this returns the structured data needed later # Assuming get_json_file() returns a dictionary or similar structure
tweet_content_data = tweet_content.get_json_file()
if tweet_content_data:
tweet_content_list.append(tweet_content_data)
else:
print(f" Warning: generate_single_content for Topic {topic_index}, Variant {variant_index} returned empty data.")
else: else:
print(f" Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.") print(f" Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.")
except Exception as e: except Exception as e:
print(f" Error during content generation for Topic {topic_index}, Variant {variant_index}: {e}") print(f" Error during content generation for Topic {topic_index}, Variant {variant_index}: {e}")
traceback.print_exc() # Decide if traceback is needed here, might be too verbose for loop errors
finally: # traceback.print_exc()
ai_agent.close() # Ensure agent is closed # Do NOT close the agent here
if not tweet_content_list: if not tweet_content_list:
print(f" No content generated for Topic {topic_index}. Skipping poster generation.") print(f" No valid content generated for Topic {topic_index}. Skipping poster generation.")
continue continue
# --- Poster Generation Setup --- # --- Poster Generation Setup ---
@ -137,21 +130,22 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
# Clean object name (consider making this a utility function) # Clean object name (consider making this a utility function)
try: try:
object_name = object_name.split(".")[0] # More robust cleaning might be needed depending on actual object name formats
if "景点信息-" in object_name: object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
object_name = object_name.split("景点信息-")[1] if not object_name_cleaned:
# Handle cases like "景点A+景点B"? Needs clearer logic if required. print(f" Warning: Object name '{object_name}' resulted in empty string after cleaning.")
continue
object_name = object_name_cleaned
except Exception as e: except Exception as e:
print(f" Warning: Could not fully clean object name '{object_name}': {e}") print(f" Warning: Could not fully clean object name '{object_name}': {e}")
# Continue with potentially unclean name? Or skip?
# Let's continue for now, path checks below might catch issues.
# Construct and check image paths using config base dir # Construct and check image paths using config base dir
# Path for collage/poster input images (e.g., from 'modify' dir)
input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name) input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name)
# Path for potential description file (e.g., from '相机' dir)
camera_img_dir_path = os.path.join(image_base_dir, camera_image_subdir, object_name) camera_img_dir_path = os.path.join(image_base_dir, camera_image_subdir, object_name)
description_file_path = os.path.join(camera_img_dir_path, "description.txt") description_file_path = os.path.join(camera_img_dir_path, "description.txt")
if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path): if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
print(f" Image directory not found or not a directory: '{input_img_dir_path}'. Skipping poster generation for this topic.") print(f" Image directory not found or not a directory: '{input_img_dir_path}'. Skipping poster generation for this topic.")
continue continue
@ -164,21 +158,22 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
print(f" Description file not found: '{description_file_path}'. Using generated content for poster text.") print(f" Description file not found: '{description_file_path}'. Using generated content for poster text.")
# --- Generate Text Configurations for Posters --- # --- Generate Text Configurations for Posters ---
content_gen = contentGen.ContentGenerator()
try: try:
# Assuming tweet_content_list contains the JSON data needed by content_gen.run # Pass the list of content data directly
poster_text_configs_raw = content_gen.run(info_directory, config["variants"], tweet_content_list) poster_text_configs_raw = content_gen.run(info_directory, config["variants"], tweet_content_list)
print(f" Raw poster text configs: {poster_text_configs_raw}") # For debugging # print(f" Raw poster text configs: {poster_text_configs_raw}") # For debugging
if not poster_text_configs_raw:
print(" Warning: ContentGenerator returned empty configuration data.")
continue # Skip if no text configs generated
poster_config_summary = posterGen.PosterConfig(poster_text_configs_raw) poster_config_summary = posterGen.PosterConfig(poster_text_configs_raw)
except Exception as e: except Exception as e:
print(f" Error running ContentGenerator or parsing poster configs: {e}") print(f" Error running ContentGenerator or parsing poster configs: {e}")
traceback.print_exc() traceback.print_exc()
continue # Skip poster generation for this topic continue # Skip poster generation for this topic
# --- Poster Generation Loop --- # --- Poster Generation Loop ---
poster_num = config["variants"] # Same as content variants poster_num = config["variants"] # Same as content variants
target_size = tuple(config.get("poster_target_size", [900, 1200])) # Add default size target_size = tuple(config.get("poster_target_size", [900, 1200]))
for j_index in range(poster_num): for j_index in range(poster_num):
variant_index = j_index + 1 variant_index = j_index + 1
@ -202,7 +197,6 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
output_count=1, # Assuming 1 collage image per poster variant output_count=1, # Assuming 1 collage image per poster variant
output_dir=collage_output_dir output_dir=collage_output_dir
) )
# print(f" Collage image list: {img_list}") # Debugging
if not img_list or len(img_list) == 0 or not img_list[0].get('path'): if not img_list or len(img_list) == 0 or not img_list[0].get('path'):
print(f" Failed to generate collage image for Variant {variant_index}. Skipping poster.") print(f" Failed to generate collage image for Variant {variant_index}. Skipping poster.")
@ -210,24 +204,18 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
collage_img_path = img_list[0]['path'] collage_img_path = img_list[0]['path']
print(f" Using collage image: {collage_img_path}") print(f" Using collage image: {collage_img_path}")
# --- Create Poster --- # --- Create Poster (using the single poster_gen_instance) ---
poster_gen_instance = posterGen.PosterGenerator() # Renamed to avoid conflict
# Prepare text data (Simplified logic, adjust TEXT_POSBILITY if needed)
# Consider moving text data preparation into posterGen or a dedicated function
text_data = { text_data = {
"title": poster_config.get('main_title', 'Default Title'), "title": poster_config.get('main_title', 'Default Title'),
"subtitle": "", # Subtitle seems unused? "subtitle": "",
"additional_texts": [] "additional_texts": []
} }
texts = poster_config.get('texts', []) texts = poster_config.get('texts', [])
if texts: if texts:
text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5}) text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5})
if len(texts) > 1 and random.random() < TEXT_POSBILITY: # Apply possibility check if len(texts) > 1 and random.random() < TEXT_POSBILITY:
text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5}) text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5})
# print(f" Text data for poster: {text_data}") # Debugging
final_poster_path = os.path.join(poster_output_dir, "poster.jpg") final_poster_path = os.path.join(poster_output_dir, "poster.jpg")
result_path = poster_gen_instance.create_poster(collage_img_path, text_data, final_poster_path) result_path = poster_gen_instance.create_poster(collage_img_path, text_data, final_poster_path)
if result_path: if result_path:
@ -240,13 +228,21 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
traceback.print_exc() traceback.print_exc()
continue # Continue to next variant continue # Continue to next variant
# --- Cleanup ---
# Close the AI Agent after processing all topics
if ai_agent:
print("\nClosing AI Agent...")
ai_agent.close()
def main(): def main():
# No argparse for now, directly load default config # No argparse for now, directly load default config
config = load_config() # Load from poster_gen_config.json config = load_config() # Load from poster_gen_config.json
# Execute steps sequentially # Execute steps sequentially
run_id, tweet_topic_record = generate_topics_step(config) # Step 1: Generate Topics (using the function from utils.tweet_generator)
run_id, tweet_topic_record = run_topic_generation_pipeline(config)
# Step 2: Generate Content and Posters (if Step 1 was successful)
if run_id and tweet_topic_record: if run_id and tweet_topic_record:
generate_content_and_posters_step(config, run_id, tweet_topic_record) generate_content_and_posters_step(config, run_id, tweet_topic_record)
else: else:

View File

@ -8,6 +8,7 @@ import argparse
import json import json
from datetime import datetime from datetime import datetime
import sys import sys
import traceback
sys.path.append('/root/autodl-tmp') sys.path.append('/root/autodl-tmp')
# 从本地模块导入 # 从本地模块导入
from TravelContentCreator.core.ai_agent import AI_Agent from TravelContentCreator.core.ai_agent import AI_Agent
@ -259,6 +260,67 @@ def prepare_topic_generation(
return ai_agent, system_prompt, user_prompt, output_dir return ai_agent, system_prompt, user_prompt, output_dir
def run_topic_generation_pipeline(config):
    """Run the full topic-generation pipeline described by ``config``.

    Prepares prompts plus an AI agent, requests topic generation, persists
    the results under ``<output_dir>/<run_id>/`` and returns
    ``(run_id, tweet_topic_record)``. Returns ``(None, None)`` on any
    failure: preparation error, generation error, empty generation result,
    or an exception while saving the outputs.
    """
    print("Step 1: Generating Topics...")

    # --- Preparation: build prompts and the AI agent -------------------
    # prepare_topic_generation also constructs the AI_Agent instance.
    try:
        agent, sys_prompt, usr_prompt, _base_dir = prepare_topic_generation(
            config.get("date", datetime.now().strftime("%Y-%m-%d")),  # default: today
            config["num"],
            config["topic_system_prompt"],
            config["topic_user_prompt"],
            config["api_url"],
            config["model"],
            config["api_key"],
            config["prompts_dir"],
            config["resource_dir"],
            config["output_dir"],
        )
    except Exception as e:
        print(f"Error during topic generation preparation: {e}")
        traceback.print_exc()
        return None, None

    # --- Generation: call the model via the prepared agent -------------
    try:
        run_id, topic_record = generate_topics(
            agent,
            sys_prompt,
            usr_prompt,
            config["output_dir"],
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            # NOTE(review): 1.5 is an odd default for a "max_tokens" value —
            # this argument may actually be a penalty/other sampling knob;
            # confirm against generate_topics' signature.
            config.get("topic_max_tokens", 1.5),
        )
    except Exception as e:
        print(f"Error during topic generation API call: {e}")
        traceback.print_exc()
        if agent:
            agent.close()  # release the agent on the error path too
        return None, None

    # The agent is no longer needed once generation has returned.
    if agent:
        agent.close()

    if not (run_id and topic_record):
        print("Topic generation failed (no run_id or topics returned).")
        return None, None

    # --- Persistence: write topics and prompt next to this run ---------
    run_dir = os.path.join(config["output_dir"], run_id)
    try:
        os.makedirs(run_dir, exist_ok=True)
        saved_topics = topic_record.save_topics(os.path.join(run_dir, "tweet_topic.json"))
        saved_prompt = topic_record.save_prompt(os.path.join(run_dir, "tweet_prompt.txt"))
        if not (saved_topics and saved_prompt):
            # Non-fatal: warn but still report success below.
            print("Warning: Failed to save topic generation results or prompts.")
    except Exception as e:
        print(f"Error saving topic generation results: {e}")
        traceback.print_exc()
        # Treat an exception while saving as pipeline failure.
        return None, None

    print(f"Topics generated successfully. Run ID: {run_id}")
    return run_id, topic_record
def main(): def main():
"""主函数入口""" """主函数入口"""
config_file = { config_file = {