#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
import logging

# sys.path.append('/root/autodl-tmp')  # No longer needed if running as a module or if the path is set correctly

# Imports from local modules
# from TravelContentCreator.core.ai_agent import AI_Agent  # Remove project name prefix
# from TravelContentCreator.core.topic_parser import TopicParser  # Remove project name prefix
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
# from TravelContentCreator.utils.prompt_manager import PromptManager  # Remove project name prefix
# from ..core import contentGen as core_contentGen  # Change to absolute import
# from ..core import posterGen as core_posterGen  # Change to absolute import
# from ..core import simple_collage as core_simple_collage  # Change to absolute import
from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser
from utils.prompt_manager import PromptManager  # Imported from the sibling 'utils' package
from core import contentGen as core_contentGen
from core import poster_gen as core_posterGen
from core import simple_collage as core_simple_collage
from .output_handler import OutputHandler


class tweetTopic:
    def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic):
        self.index = index
        self.date = date
        self.logic = logic
        self.object = object
        self.product = product
        self.product_logic = product_logic
        self.style = style
        self.style_logic = style_logic
        self.target_audience = target_audience
        self.target_audience_logic = target_audience_logic


class tweetTopicRecord:
    def __init__(self, topics_list, system_prompt, user_prompt, run_id):
        self.topics_list = topics_list
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.run_id = run_id


class tweetContent:
    def __init__(self, result, prompt, run_id, article_index, variant_index):
        self.result = result
        self.prompt = prompt
        self.run_id = run_id
        self.article_index = article_index
        self.variant_index = variant_index

        try:
            self.title, self.content = self.split_content(result)
            self.json_data = self.gen_result_json()
        except Exception as e:
            logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")
            logging.debug(f"Raw result: {result[:500]}...")  # Log part of the raw result
            self.title = "[Parsing Error]"
            self.content = "[Failed to parse AI content]"
            self.json_data = {"title": self.title, "content": self.content, "error": True, "raw_result": result}

    def split_content(self, result):
        # The split logic below can still fail; ideally content generation would switch to JSON
        # output as well. For now, keep the existing tag-based logic and let __init__ handle errors.

        # Basic sanity check before splitting
        if not result or "</think>" not in result or "title>" not in result or "content>" not in result:
            logging.warning(f"AI result format unexpected: {result[:200]}...")
            # Raise an error to be caught in __init__
            raise ValueError("AI result missing expected tags or is empty")

        # --- Existing tag-splitting logic (prone to errors) ---
        processed_result = result
        if "</think>" in result:
            processed_result = result.split("</think>", 1)[1]  # Keep only the part after </think>

        title = processed_result.split("title>", 1)[1].split("</title>", 1)[0]
        content = processed_result.split("content>", 1)[1].split("</content>", 1)[0]
        # --- End existing logic ---
        return title.strip(), content.strip()

    def gen_result_json(self):
        json_file = {
            "title": self.title,
            "content": self.content
        }
        # Propagate the error flag if it was set
        if hasattr(self, 'json_data') and self.json_data.get('error'):
            json_file['error'] = True
            json_file['raw_result'] = self.json_data.get('raw_result')
        return json_file

    def get_json_data(self):
        """Returns the generated JSON data dictionary."""
        return self.json_data

    def get_prompt(self):
        """Returns the user prompt used to generate this content."""
        return self.prompt

    def get_content(self):
        return self.content

    def get_title(self):
        return self.title

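# Illustrative note (a sketch, not executed anywhere in this module): split_content expects the
# raw AI response to contain a "</think>" marker followed by <title>...</title> and
# <content>...</content> tags, for example:
#
#     raw = "reasoning...</think><title>Sample title</title><content>Sample body</content>"
#     tc = tweetContent(raw, "user prompt text", "2024-01-01_00-00-00", 1, 1)
#     tc.get_title()    # -> "Sample title"
#     tc.get_content()  # -> "Sample body"
#
# Any other shape raises in split_content; __init__ then falls back to an error payload that
# keeps the raw result for debugging.
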

def generate_topics(ai_agent, system_prompt, user_prompt, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """Generate the topic list (run_id is now passed in; output_dir removed as an argument)."""
    logging.info("Starting topic generation...")
    time_start = time.time()

    # Call the AI agent work method (updated return values)
    result, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )

    logging.info(f"Topic generation API call completed in {time_cost:.2f}s. Estimated tokens: {tokens}")

    # Parse topics
    result_list = TopicParser.parse_topics(result)
    if not result_list:
        logging.warning("Topic parsing resulted in an empty list.")
        # Raw-result logging could be added here if needed; saving is the responsibility of OutputHandler.
        # error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt")  # output_dir is not available here
        # try:
        #     ...  # (save raw output logic)
        # except Exception as log_err:
        #     logging.error(f"Failed to save raw AI output on parsing failure: {log_err}")

    # Return the parsed list directly
    return result_list

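# Illustrative sketch only: the exact dictionary shape returned by TopicParser.parse_topics is
# defined in core/topic_parser.py and is not shown here. Based on the tweetTopic fields above and
# the topic_item.get("object") lookups below, a parsed item is assumed to look roughly like:
#
#     {
#         "index": 1,
#         "date": "2024-01-01",
#         "object": "景点信息-Sample Attraction",
#         "product": "...", "product_logic": "...",
#         "style": "...", "style_logic": "...",
#         "target_audience": "...", "target_audience_logic": "...",
#         "logic": "...",
#     }
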

def generate_single_content(ai_agent, system_prompt, user_prompt, item, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
    """Generates a single content variant. Returns (content_json, user_prompt) or (None, None)."""
    logging.info(f"Generating content for topic {article_index}, variant {variant_index}")
    try:
        if not system_prompt or not user_prompt:
            logging.error("System or User prompt is empty. Cannot generate content.")
            return None, None

        logging.debug(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")

        time.sleep(random.random() * 0.5)

        # Generate content (non-streaming work returns result, tokens, time_cost)
        result, tokens, time_cost = ai_agent.work(
            system_prompt, user_prompt, "", temperature, top_p, presence_penalty
        )

        if result is None:  # Check whether the AI call failed
            logging.error(f"AI agent work failed for {article_index}_{variant_index}. No result returned.")
            return None, None

        logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}")

        # --- Create tweetContent object (handles parsing) ---
        # Only the user prompt is stored, since that is what the output handler needs later.
        tweet_content = tweetContent(result, user_prompt, run_id, article_index, variant_index)

        # --- Saving logic removed; persistence is handled by OutputHandler ---
        # run_specific_output_dir = os.path.join(output_dir, run_id)  # output_dir no longer available
        # variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}")
        # os.makedirs(variant_result_dir, exist_ok=True)
        # content_save_path = os.path.join(variant_result_dir, "article.json")
        # prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt")
        # tweet_content.save_content(content_save_path)  # Method removed
        # tweet_content.save_prompt(prompt_save_path)  # Method removed

        # Return the data needed by the output handler
        content_json = tweet_content.get_json_data()
        prompt_data = tweet_content.get_prompt()  # The stored user prompt

        return content_json, prompt_data

    except Exception as e:
        logging.exception(f"Error generating single content for {article_index}_{variant_index}:")
        return None, None

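# For reference: on success the returned content_json mirrors tweetContent.gen_result_json(),
# i.e. {"title": "...", "content": "..."}; when response parsing fails it instead carries
# {"title": "[Parsing Error]", "content": "[Failed to parse AI content]", "error": True,
#  "raw_result": "<raw model output>"} so the caller can still persist and inspect it.
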

def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
                     variants=2, temperature=0.3, start_index=0, end_index=None):
    """Generate content for the given topic list."""
    if not topics:
        print("No topics available; cannot generate content")
        return

    # Determine the processing range
    if end_index is None or end_index > len(topics):
        end_index = len(topics)

    topics_to_process = topics[start_index:end_index]
    print(f"Preparing to process {len(topics_to_process)} topics...")

    # Create summary file
    # summary_file = ResourceLoader.create_summary_file(output_dir, run_id, len(topics_to_process))

    # Process each topic
    processed_results = []
    for i, item in enumerate(topics_to_process):
        print(f"Processing article {i+1}/{len(topics_to_process)}")

        # Generate multiple variants per topic
        for j in range(variants):
            print(f"Generating variant {j+1}/{variants}")

            # Call the single-article generator. NOTE: this legacy call predates the current
            # generate_single_content signature (which expects an explicit user_prompt) and is
            # kept here for reference.
            content_json, prompt_data = generate_single_content(
                ai_agent, system_prompt, item, run_id, i+1, j+1, temperature
            )

            if content_json:
                processed_results.append(content_json)

            # # Update summary file (only the first variant is saved to the summary)
            # if j == 0:
            #     ResourceLoader.update_summary(summary_file, i+1, user_prompt, result)

    print(f"Finished generating {len(processed_results)} articles")
    return processed_results


def prepare_topic_generation(prompt_manager: PromptManager,
                             api_url: str,
                             model_name: str,
                             api_key: str,
                             timeout: int,
                             max_retries: int):
    """Prepares the environment and parameters for topic generation. Returns (agent, system_prompt, user_prompt).

    Args:
        prompt_manager: An initialized PromptManager instance.
        api_url, model_name, api_key, timeout, max_retries: Parameters for AI_Agent.
    """
    logging.info("Preparing for topic generation (using provided PromptManager)...")
    # Get the prompts from the provided PromptManager
    system_prompt, user_prompt = prompt_manager.get_topic_prompts()

    if not system_prompt or not user_prompt:
        logging.error("Failed to get topic generation prompts from PromptManager.")
        return None, None, None

    # Initialize the AI Agent with the provided parameters
    try:
        logging.info("Initializing AI Agent for topic generation...")
        ai_agent = AI_Agent(
            api_url,
            model_name,
            api_key,
            timeout=timeout,
            max_retries=max_retries
        )
    except Exception as e:
        logging.exception("Error initializing AI Agent for topic generation:")
        return None, None, None

    # Return the agent and the prompts
    return ai_agent, system_prompt, user_prompt


def run_topic_generation_pipeline(config, run_id=None):
    """
    Runs the complete topic generation pipeline based on the configuration.
    Returns: (run_id, topics_list, system_prompt, user_prompt) or (None, None, None, None) on failure.
    """
    logging.info("Starting Step 1: Topic Generation Pipeline...")

    if run_id is None:
        logging.info("No run_id provided, generating one based on timestamp.")
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        logging.info(f"Using provided run_id: {run_id}")

    ai_agent, system_prompt, user_prompt = None, None, None  # Initialize
    topics_list = None
    prompt_manager = None  # Initialize prompt_manager
    try:
        # --- Read the parameters required by PromptManager ---
        topic_sys_prompt_path = config.get("topic_system_prompt")
        topic_user_prompt_path = config.get("topic_user_prompt")
        content_sys_prompt_path = config.get("content_system_prompt")  # Not used here, but PromptManager may need it
        prompts_dir_path = config.get("prompts_dir")
        prompts_config = config.get("prompts_config")  # New: read the prompts_config settings
        resource_config = config.get("resource_dir", [])
        topic_num = config.get("num", 1)
        topic_date = config.get("date", "")

        # --- Create the PromptManager instance ---
        prompt_manager = PromptManager(
            topic_system_prompt_path=topic_sys_prompt_path,
            topic_user_prompt_path=topic_user_prompt_path,
            content_system_prompt_path=content_sys_prompt_path,
            prompts_dir=prompts_dir_path,
            prompts_config=prompts_config,  # New: pass the prompts_config settings
            resource_dir_config=resource_config,
            topic_gen_num=topic_num,
            topic_gen_date=topic_date
        )
        logging.info("PromptManager instance created.")

        # --- Read the parameters required by the AI Agent ---
        ai_api_url = config.get("api_url")
        ai_model = config.get("model")
        ai_api_key = config.get("api_key")
        ai_timeout = config.get("request_timeout", 30)
        ai_max_retries = config.get("max_retries", 3)

        # Check that the required AI parameters are present
        if not all([ai_api_url, ai_model, ai_api_key]):
            raise ValueError("Missing required AI configuration (api_url, model, api_key) in config.")

        # --- Call the updated prepare_topic_generation ---
        ai_agent, system_prompt, user_prompt = prepare_topic_generation(
            prompt_manager,  # Pass instance
            ai_api_url,
            ai_model,
            ai_api_key,
            ai_timeout,
            ai_max_retries
        )
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")

        # --- Generate topics (unchanged) ---
        topics_list = generate_topics(
            ai_agent, system_prompt, user_prompt,
            run_id,
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            config.get("topic_presence_penalty", 1.5)
        )
    except Exception as e:
        logging.exception("Error during topic generation pipeline execution:")
        return None, None, None, None  # Signal failure (the finally block closes the agent)
    finally:
        # Ensure the AI agent is closed after the generation attempt (if initialized)
        if ai_agent:
            logging.info("Closing topic generation AI Agent...")
            ai_agent.close()

    if topics_list is None:  # generate_topics currently returns a list, but guard anyway
        logging.error("Topic generation failed (generate_topics returned None or an error occurred).")
        return None, None, None, None
    elif not topics_list:  # Empty list
        logging.warning(f"Topic generation completed for run {run_id}, but the resulting topic list is empty.")
        # Return the empty list and prompts anyway; let the caller decide

    logging.info(f"Topic generation pipeline completed successfully. Run ID: {run_id}")
    # Return the raw data needed by the OutputHandler
    return run_id, topics_list, system_prompt, user_prompt

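# Illustrative configuration sketch (the keys are taken from the config.get(...) calls above; the
# values shown are placeholders, not real project paths or credentials):
#
#     config = {
#         "api_url": "https://example.com/v1/chat/completions",
#         "model": "model-name",
#         "api_key": "sk-...",
#         "request_timeout": 30,
#         "max_retries": 3,
#         "topic_system_prompt": "path/to/topic_system_prompt.txt",
#         "topic_user_prompt": "path/to/topic_user_prompt.txt",
#         "content_system_prompt": "path/to/content_system_prompt.txt",
#         "prompts_dir": "path/to/prompts",
#         "prompts_config": {},        # structure defined by PromptManager
#         "resource_dir": [],          # list of resource directory entries
#         "num": 5,                    # number of topics to generate
#         "date": "2024-01-01",
#         "topic_temperature": 0.2,
#         "topic_top_p": 0.5,
#         "topic_presence_penalty": 1.5,
#     }
#     run_id, topics, sys_prompt, user_prompt = run_topic_generation_pipeline(config)
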

# --- Decoupled Functional Units (Moved from main.py) ---


def generate_content_for_topic(ai_agent: AI_Agent,
                               prompt_manager: PromptManager,
                               topic_item: dict,
                               run_id: str,
                               topic_index: int,
                               output_handler: OutputHandler,
                               # Explicit parameters; config and output_dir removed
                               variants: int,
                               temperature: float,
                               top_p: float,
                               presence_penalty: float):
    """Generates all content variants for a single topic item and uses OutputHandler.

    Args:
        ai_agent: Initialized AI_Agent instance.
        prompt_manager: Initialized PromptManager instance.
        topic_item: Dictionary representing the topic.
        run_id: Current run ID.
        topic_index: 1-based index of the topic.
        output_handler: An instance of OutputHandler to process results.
        variants: Number of variants to generate.
        temperature, top_p, presence_penalty: AI generation parameters.
    Returns:
        bool: True if at least one variant was successfully generated and handled, False otherwise.
    """
    logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
    success_flag = False  # Track whether any variant succeeded
    # The number of variants is now an explicit parameter
    # variants = config.get("variants", 1)

    for j in range(variants):
        variant_index = j + 1
        logging.info(f"  Generating Variant {variant_index}/{variants}...")

        # The PromptManager instance is passed in; call it directly
        content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item)

        if not content_system_prompt or not content_user_prompt:
            logging.warning(f"  Skipping Variant {variant_index} due to missing content prompts.")
            continue

        time.sleep(random.random() * 0.5)
        try:
            # Call generate_single_content with the passed-in parameters
            content_json, prompt_data = generate_single_content(
                ai_agent, content_system_prompt, content_user_prompt, topic_item,
                run_id, topic_index, variant_index,
                temperature,
                top_p,
                presence_penalty
            )

            # Check whether generation succeeded and parsing was okay (or the error was captured in the JSON)
            if content_json is not None and prompt_data is not None:
                # Use the output handler to process/save the result
                output_handler.handle_content_variant(
                    run_id, topic_index, variant_index, content_json, prompt_data
                )
                success_flag = True  # Mark success for this topic

                # Check specifically whether the AI result itself indicated a parsing error
                if content_json.get("error"):
                    logging.error(f"  Content generation for Topic {topic_index}, Variant {variant_index} succeeded but response parsing failed (error flag set in content). Raw data logged by handler.")
                else:
                    logging.info(f"  Successfully generated and handled content for Topic {topic_index}, Variant {variant_index}.")
            else:
                logging.error(f"  Content generation failed for Topic {topic_index}, Variant {variant_index}. Skipping handling.")

        except Exception as e:
            logging.exception(f"  Error during content generation call or handling for Topic {topic_index}, Variant {variant_index}:")

    # Return the success flag for this topic
    return success_flag

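# Illustrative wiring sketch (assumes an AI_Agent, PromptManager and OutputHandler have already
# been constructed by the caller, e.g. in main.py; "variants" mirrors the commented-out
# config.get("variants", 1) above, the other defaults mirror generate_single_content):
#
#     for idx, topic_item in enumerate(topics_list, start=1):
#         generate_content_for_topic(
#             ai_agent, prompt_manager, topic_item, run_id, idx, output_handler,
#             variants=config.get("variants", 1),
#             temperature=0.3, top_p=0.4, presence_penalty=1.5,
#         )
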

def generate_posters_for_topic(topic_item: dict,
                               output_dir: str,
                               run_id: str,
                               topic_index: int,
                               output_handler: OutputHandler,
                               variants: int,
                               poster_assets_base_dir: str,
                               image_base_dir: str,
                               img_frame_possibility: float,
                               text_bg_possibility: float,
                               resource_dir_config: list,
                               poster_target_size: tuple,
                               text_possibility: float,
                               output_collage_subdir: str,
                               output_poster_subdir: str,
                               output_poster_filename: str,
                               system_prompt: str
                               ):
    """Generates all posters for a single topic item, handling image data via OutputHandler.

    Args:
        topic_item: The dictionary representing a single topic.
        output_dir: The base output directory for the entire run (e.g., ./result).
        run_id: The ID for the current run.
        topic_index: The 1-based index of the current topic.
        output_handler: An instance of OutputHandler to process results.
        variants: Number of variants.
        poster_assets_base_dir: Path to poster assets (fonts, frames, etc.).
        image_base_dir: Base path for source images.
        img_frame_possibility: Probability of adding an image frame.
        text_bg_possibility: Probability of adding a text background.
        resource_dir_config: Configuration for resource directories (used for Description files).
        poster_target_size: Target size tuple (width, height) for the poster.
        text_possibility: Probability of adding secondary text.
        output_collage_subdir: Subdirectory name for saving collages.
        output_poster_subdir: Subdirectory name for saving posters.
        output_poster_filename: Filename for the final poster.
        system_prompt: System prompt for content generation.

    Returns:
        True if poster generation was attempted (regardless of individual variant success),
        False if setup failed before attempting variants.
    """
    logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")

    # --- Load content data from files ---
    loaded_content_list = []
    logging.info(f"Attempting to load content data for {variants} variants for topic {topic_index}...")
    for j in range(variants):
        variant_index = j + 1
        variant_dir = os.path.join(output_dir, run_id, f"{topic_index}_{variant_index}")
        content_path = os.path.join(variant_dir, "article.json")
        try:
            if os.path.exists(content_path):
                with open(content_path, 'r', encoding='utf-8') as f_content:
                    content_data = json.load(f_content)
                if isinstance(content_data, dict) and 'title' in content_data and 'content' in content_data:
                    loaded_content_list.append(content_data)
                    logging.debug(f"  Successfully loaded content from: {content_path}")
                else:
                    logging.warning(f"  Content file {content_path} has invalid format. Skipping.")
            else:
                logging.warning(f"  Content file not found for variant {variant_index}: {content_path}. Skipping.")
        except json.JSONDecodeError:
            logging.error(f"  Error decoding JSON from content file: {content_path}. Skipping.")
        except Exception as e:
            logging.exception(f"  Error loading content file {content_path}: {e}")

    if not loaded_content_list:
        logging.error(f"No valid content data loaded for topic {topic_index}. Cannot generate posters.")
        return False
    logging.info(f"Successfully loaded content data for {len(loaded_content_list)} variants.")
    # --- End load content data ---

    # Initialize generators using parameters
    try:
        content_gen_instance = core_contentGen.ContentGenerator()
        if not poster_assets_base_dir:
            logging.error("Error: 'poster_assets_base_dir' not provided. Cannot generate posters.")
            return False
        poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir)
        poster_gen_instance.set_img_frame_possibility(img_frame_possibility)
        poster_gen_instance.set_text_bg_possibility(text_bg_possibility)
    except Exception as e:
        logging.exception("Error initializing generators for poster creation:")
        return False

    # --- Setup: paths and object name ---
    object_name = topic_item.get("object", "")
    if not object_name:
        logging.warning("Warning: Topic object name is missing. Cannot generate posters.")
        return False

    # Clean the object name
    try:
        object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
        if not object_name_cleaned:
            logging.warning(f"Warning: Object name '{object_name}' resulted in an empty string after cleaning. Skipping posters.")
            return False
        object_name = object_name_cleaned
    except Exception as e:
        logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
        return False

    # Construct and check the input image path
    input_img_dir_path = os.path.join(image_base_dir, object_name)
    if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
        logging.warning(f"Warning: Image directory not found or not a directory: '{input_img_dir_path}'. Skipping posters for this topic.")
        return False

    # Locate the description file using the resource_dir_config parameter
    info_directory = []
    description_file_path = None
    found_description = False
    for dir_info in resource_dir_config:
        if dir_info.get("type") == "Description":
            for file_path in dir_info.get("file_path", []):
                if object_name in os.path.basename(file_path):
                    description_file_path = file_path
                    if os.path.exists(description_file_path):
                        info_directory = [description_file_path]
                        logging.info(f"Found and using description file from config: {description_file_path}")
                        found_description = True
                    else:
                        logging.warning(f"Warning: Description file specified in config not found: {description_file_path}")
                    break
        if found_description:
            break
    if not found_description:
        logging.info(f"Warning: No matching description file found for object '{object_name}' in config resource_dir (type='Description').")

    # Generate text configurations for all variants
    try:
        poster_text_configs_raw = content_gen_instance.run(info_directory, variants, loaded_content_list, system_prompt)
        if not poster_text_configs_raw:
            logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.")
            return False

        # --- Save the poster config via OutputHandler ---
        output_handler.handle_poster_configs(run_id, topic_index, poster_text_configs_raw)

        # Log the raw configuration data for debugging
        logging.info(f"Generated poster configuration data: {poster_text_configs_raw}")

        # Use the configuration data directly instead of reading it back from a file
        if isinstance(poster_text_configs_raw, list):
            poster_configs = poster_text_configs_raw
            logging.info(f"Using the generated configuration list directly; it contains {len(poster_configs)} items")
        else:
            # If it is not a list, fall back to parsing it with the PosterConfig class
            logging.info("Generated configuration data is not a list; processing it with the PosterConfig class")
            poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
            poster_configs = poster_config_summary.get_config()
    except Exception as e:
        logging.exception("Error running ContentGenerator or parsing poster configs:")
        traceback.print_exc()
        return False

    # Poster generation loop for each variant
    poster_num = min(variants, len(poster_configs)) if isinstance(poster_configs, list) else variants
    logging.info(f"Planning to generate {poster_num} poster variants")
    any_poster_attempted = False

    for j_index in range(poster_num):
        variant_index = j_index + 1
        logging.info(f"Generating Poster {variant_index}/{poster_num}...")
        any_poster_attempted = True
        collage_img = None  # The generated collage PIL Image
        poster_img = None  # The final poster PIL Image
        try:
            # Get the configuration for the current variant
            if isinstance(poster_configs, list) and j_index < len(poster_configs):
                poster_config = poster_configs[j_index]
                logging.info(f"Using configuration item {j_index+1}: {poster_config}")
            else:
                # Fallback: use the PosterConfig class
                poster_config = poster_config_summary.get_config_by_index(j_index)
                logging.info(f"Using the PosterConfig class to fetch configuration item {j_index+1}")

            if not poster_config:
                logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.")
                continue

            # --- Image collage ---
            logging.info(f"Generating collage from: {input_img_dir_path}")
            collage_images = core_simple_collage.process_directory(
                input_img_dir_path,
                target_size=poster_target_size,
                output_count=1
            )

            if not collage_images:  # Check whether the list is empty
                logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.")
                continue
            collage_img = collage_images[0]  # Take the first PIL Image
            logging.info("Collage image generated successfully (in memory).")

            # --- Save the collage image via OutputHandler ---
            output_handler.handle_generated_image(
                run_id, topic_index, variant_index,
                image_type='collage',
                image_data=collage_img,
                output_filename='collage.png'
            )

            # --- Create poster ---
            text_data = {
                "title": poster_config.get('main_title', 'Default Title'),
                "subtitle": "",
                "additional_texts": []
            }
            texts = poster_config.get('texts', [])
            if texts:
                # Make sure the text is not empty
                if texts[0]:
                    text_data["additional_texts"].append({"text": texts[0], "position": "middle", "size_factor": 0.8})

                # Add a second text if present and the random condition is met
                if len(texts) > 1 and texts[1] and random.random() < text_possibility:
                    text_data["additional_texts"].append({"text": texts[1], "position": "middle", "size_factor": 0.8})

            # Log the text data to be used
            logging.info(f"Text data: {text_data}")

            # Call the updated create_poster, which returns a PIL Image
            poster_img = poster_gen_instance.create_poster(collage_img, text_data)

            if poster_img:
                logging.info("Poster image generated successfully (in memory).")
                # --- Save the poster image via OutputHandler ---
                output_handler.handle_generated_image(
                    run_id, topic_index, variant_index,
                    image_type='poster',
                    image_data=poster_img,
                    output_filename=output_poster_filename  # Filename from the parameter
                )
            else:
                logging.warning(f"Warning: Poster generation function returned None for variant {variant_index}.")

        except Exception as e:
            logging.exception(f"Error during poster generation for Variant {variant_index}:")
            traceback.print_exc()
            continue

    return any_poster_attempted

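# For reference, generate_posters_for_topic reads the per-variant content written during the
# content step from <output_dir>/<run_id>/<topic_index>_<variant_index>/article.json (see the path
# construction at the top of the function). The collage and poster images are handed to
# OutputHandler rather than written here; their final location (e.g. a 'collage.png' alongside
# article.json) is an assumption that depends on how OutputHandler.handle_generated_image is
# implemented.
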

# The main() function is no longer used here; kept commented out pending removal.
# def main():
#     """Main entry point"""
#     # ... (old main() logic)
#
# if __name__ == "__main__":
#     main()
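
# End-to-end driver sketch (illustrative only): after run_topic_generation_pipeline(config) has
# produced run_id and topics (see the configuration sketch above), a caller such as main.py is
# assumed to hold the agent, prompt manager, output handler and poster settings, and to loop
# roughly like this:
#
#     for idx, topic in enumerate(topics or [], start=1):
#         generate_content_for_topic(ai_agent, prompt_manager, topic, run_id, idx, output_handler,
#                                    variants, temperature, top_p, presence_penalty)
#         generate_posters_for_topic(topic, output_dir, run_id, idx, output_handler, variants,
#                                    poster_assets_base_dir, image_base_dir, img_frame_possibility,
#                                    text_bg_possibility, resource_dir_config, poster_target_size,
#                                    text_possibility, output_collage_subdir, output_poster_subdir,
#                                    output_poster_filename, content_system_prompt)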