TravelContentCreator/utils/tweet_generator.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
import logging
import re
# sys.path.append('/root/autodl-tmp') # No longer needed if running as a module or if path is set correctly
# Import from local modules
# from TravelContentCreator.core.ai_agent import AI_Agent # Remove project name prefix
# from TravelContentCreator.core.topic_parser import TopicParser # Remove project name prefix
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
# from TravelContentCreator.utils.prompt_manager import PromptManager # Remove project name prefix
# from ..core import contentGen as core_contentGen # Change to absolute import
# from ..core import posterGen as core_posterGen # Change to absolute import
# from ..core import simple_collage as core_simple_collage # Change to absolute import
from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser
from utils.prompt_manager import PromptManager # Keep this as it's importing from the same level package 'utils'
from utils import content_generator as core_contentGen
from core import poster_gen as core_posterGen
from core import simple_collage as core_simple_collage
from .output_handler import OutputHandler # <-- added import
class tweetTopic:
def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic):
self.index = index
self.date = date
self.logic = logic
self.object = object
self.product = product
self.product_logic = product_logic
self.style = style
self.style_logic = style_logic
self.target_audience = target_audience
self.target_audience_logic = target_audience_logic
class tweetTopicRecord:
def __init__(self, topics_list, system_prompt, user_prompt, run_id):
self.topics_list = topics_list
self.system_prompt = system_prompt
self.user_prompt = user_prompt
self.run_id = run_id
class tweetContent:
def __init__(self, result, prompt, run_id, article_index, variant_index):
self.result = result
self.prompt = prompt
self.run_id = run_id
self.article_index = article_index
self.variant_index = variant_index
try:
self.title, self.content = self.split_content(result)
self.json_data = self.gen_result_json()
except Exception as e:
logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")
logging.debug(f"Raw result: {result[:500]}...") # Log partial raw result
self.title = "" # 改为空字符串,而不是 "[Parsing Error]"
self.content = "" # 改为空字符串,而不是 "[Failed to parse AI content]"
self.json_data = {"title": self.title, "content": self.content, "error": True} # 不再包含raw_result
def split_content(self, result):
# Assuming split logic might still fail, keep it simple or improve with regex/json
# We should ideally switch content generation to JSON output as well.
# For now, keep existing logic but handle errors in __init__.
# Optional: Add basic check before splitting
if not result or "</think>" not in result or "title>" not in result or "content>" not in result:
logging.warning(f"AI result format unexpected: {result[:200]}...")
# Return empty strings instead of raising, so the caller can continue processing
return "", ""
# --- Existing Logic (prone to errors) ---
try:
processed_result = result
if "</think>" in result:
processed_result = result.split("</think>", 1)[1] # Take part after </think>
title = processed_result.split("title>", 1)[1].split("</title>", 1)[0]
content = processed_result.split("content>", 1)[1].split("</content>", 1)[0]
# --- End Existing Logic ---
return title.strip(), content.strip()
except Exception as e:
logging.warning(f"解析内容时出错: {e}, 返回空字符串")
return "", ""
def gen_result_json(self):
# Build the JSON payload from the parsed title and content. On parse failure
# __init__ assigns json_data (with the error flag) directly, so no error
# handling is needed here.
json_file = {
"title": self.title,
"content": self.content
}
return json_file
def get_json_data(self):
"""Returns the generated JSON data dictionary."""
return self.json_data
def get_prompt(self):
"""Returns the user prompt used to generate this content."""
return self.prompt
def get_content(self):
return self.content
def get_title(self):
return self.title
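# Illustrative usage of tweetContent (a minimal sketch; the exact wrapper text
# produced by the model is an assumption, but the <title>/<content> tags and the
# "</think>" marker match what split_content() looks for above):
#
#     raw = "<think>model reasoning...</think><title>Sample Title</title><content>Body text.</content>"
#     tc = tweetContent(raw, prompt="demo user prompt", run_id="demo-run",
#                       article_index=1, variant_index=1)
#     tc.get_title()      # -> "Sample Title"
#     tc.get_content()    # -> "Body text."
#     tc.get_json_data()  # -> {"title": "Sample Title", "content": "Body text."}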
def generate_topics(ai_agent, system_prompt, user_prompt, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
"""生成选题列表 (run_id is now passed in, output_dir removed as argument)"""
logging.info("Starting topic generation...")
time_start = time.time()
# Call AI agent work method (updated return values)
result, tokens, time_cost = ai_agent.work(
system_prompt, user_prompt, "", temperature, top_p, presence_penalty
)
logging.info(f"Topic generation API call completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
# Parse topics
result_list = TopicParser.parse_topics(result, run_id)
if not result_list:
logging.warning("Topic parsing resulted in an empty list.")
# Optionally handle raw result logging here if needed, but saving is responsibility of OutputHandler
# error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt") # output_dir is not available here
# try:
# # ... (save raw output logic) ...
# except Exception as log_err:
# logging.error(f"Failed to save raw AI output on parsing failure: {log_err}")
# Return the parsed list directly
return result_list
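# Minimal usage sketch for topic generation (hedged: the agent and prompt values
# are placeholders; in the real flow they come from prepare_topic_generation /
# run_topic_generation_pipeline further below):
#
#     agent, sys_p, user_p = prepare_topic_generation(
#         prompt_manager, api_url, model_name, api_key, timeout=30, max_retries=3
#     )
#     topics = generate_topics(agent, sys_p, user_p, run_id="2024-01-01_12-00-00")
#     logging.info(f"Parsed {len(topics)} topics")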
def generate_single_content(ai_agent, system_prompt, user_prompt, item, run_id,
article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
"""Generates single content variant data. Returns (content_json, user_prompt) or (None, None)."""
logging.info(f"Generating content for topic {article_index}, variant {variant_index}")
try:
if not system_prompt or not user_prompt:
logging.error("System or User prompt is empty. Cannot generate content.")
return None, None
logging.debug(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
time.sleep(random.random() * 0.5)
# Generate content (non-streaming work returns result, tokens, time_cost)
result, tokens, time_cost = ai_agent.work(
system_prompt, user_prompt, "", temperature, top_p, presence_penalty
)
if result is None: # Check if AI call failed
logging.error(f"AI agent work failed for {article_index}_{variant_index}. No result returned.")
return {"title": "", "content": "", "error": True}, user_prompt # 返回空字段而不是None
logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
# --- Create tweetContent object (handles parsing) ---
# Store user_prompt (rather than the full prompt); it is what the output handler needs later.
tweet_content = tweetContent(result, user_prompt, run_id, article_index, variant_index)
# --- Remove Saving Logic ---
# run_specific_output_dir = os.path.join(output_dir, run_id) # output_dir no longer available
# variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}")
# os.makedirs(variant_result_dir, exist_ok=True)
# content_save_path = os.path.join(variant_result_dir, "article.json")
# prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt")
# tweet_content.save_content(content_save_path) # Method removed
# tweet_content.save_prompt(prompt_save_path) # Method removed
# --- End Remove Saving Logic ---
# Return the data needed by the output handler
content_json = tweet_content.get_json_data()
prompt_data = tweet_content.get_prompt() # Get the stored user prompt
return content_json, prompt_data # Return data pair
except Exception as e:
logging.exception(f"Error generating single content for {article_index}_{variant_index}:")
return {"title": "", "content": "", "error": True}, user_prompt # 返回空字段而不是None
def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
variants=2, temperature=0.3, start_index=0, end_index=None):
"""根据选题生成内容"""
if not topics:
print("没有选题,无法生成内容")
return
# 确定处理范围
if end_index is None or end_index > len(topics):
end_index = len(topics)
topics_to_process = topics[start_index:end_index]
print(f"准备处理{len(topics_to_process)}个选题...")
# 创建汇总文件
# summary_file = ResourceLoader.create_summary_file(output_dir, run_id, len(topics_to_process))
# 处理每个选题
processed_results = []
for i, item in enumerate(topics_to_process):
print(f"处理第 {i+1}/{len(topics_to_process)} 篇文章")
# 为每个选题生成多个变体
for j in range(variants):
print(f"正在生成变体 {j+1}/{variants}")
# 调用单篇文章生成函数
tweet_content, result = generate_single_content(
ai_agent, system_prompt, item, run_id, i+1, j+1, temperature
)
if tweet_content:
processed_results.append(tweet_content)
# # 更新汇总文件 (仅保存第一个变体到汇总文件)
# if j == 0:
# ResourceLoader.update_summary(summary_file, i+1, user_prompt, result)
print(f"完成{len(processed_results)}篇文章生成")
return processed_results
def prepare_topic_generation(prompt_manager: PromptManager,
api_url: str,
model_name: str,
api_key: str,
timeout: int,
max_retries: int):
"""准备选题生成的环境和参数. Returns agent, system_prompt, user_prompt.
Args:
prompt_manager: An initialized PromptManager instance.
api_url, model_name, api_key, timeout, max_retries: Parameters for AI_Agent.
"""
logging.info("Preparing for topic generation (using provided PromptManager)...")
# Get the prompts from the provided PromptManager
system_prompt, user_prompt = prompt_manager.get_topic_prompts()
if not system_prompt or not user_prompt:
logging.error("Failed to get topic generation prompts from PromptManager.")
return None, None, None # Return three Nones
# Initialize the AI Agent with the provided parameters
try:
logging.info("Initializing AI Agent for topic generation...")
ai_agent = AI_Agent(
api_url, # Use passed arg
model_name, # Use passed arg
api_key, # Use passed arg
timeout=timeout, # Use passed arg
max_retries=max_retries # Use passed arg
)
except Exception as e:
logging.exception("Error initializing AI Agent for topic generation:")
return None, None, None # Return three Nones
# Return the agent and the prompts
return ai_agent, system_prompt, user_prompt
def run_topic_generation_pipeline(config, run_id=None):
"""
Runs the complete topic generation pipeline based on the configuration.
Returns: (run_id, topics_list, system_prompt, user_prompt) or (None, None, None, None) on failure.
"""
logging.info("Starting Step 1: Topic Generation Pipeline...")
if run_id is None:
logging.info("No run_id provided, generating one based on timestamp.")
run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
else:
logging.info(f"Using provided run_id: {run_id}")
ai_agent, system_prompt, user_prompt = None, None, None # Initialize
topics_list = None
prompt_manager = None # Initialize prompt_manager
try:
# --- Read the parameters required by PromptManager ---
topic_sys_prompt_path = config.get("topic_system_prompt")
topic_user_prompt_path = config.get("topic_user_prompt")
content_sys_prompt_path = config.get("content_system_prompt") # not used here, but PromptManager may need it
prompts_dir_path = config.get("prompts_dir")
prompts_config = config.get("prompts_config") # also read the prompts_config setting
resource_config = config.get("resource_dir", [])
topic_num = config.get("num", 1)
topic_date = config.get("date", "")
# --- Create the PromptManager instance ---
prompt_manager = PromptManager(
topic_system_prompt_path=topic_sys_prompt_path,
topic_user_prompt_path=topic_user_prompt_path,
content_system_prompt_path=content_sys_prompt_path,
prompts_dir=prompts_dir_path,
prompts_config=prompts_config, # pass the prompts_config setting through
resource_dir_config=resource_config,
topic_gen_num=topic_num,
topic_gen_date=topic_date
)
logging.info("PromptManager instance created.")
# --- Read the parameters required by the AI Agent ---
ai_api_url = config.get("api_url")
ai_model = config.get("model")
ai_api_key = config.get("api_key")
ai_timeout = config.get("request_timeout", 30)
ai_max_retries = config.get("max_retries", 3)
# Check that the required AI parameters are present
if not all([ai_api_url, ai_model, ai_api_key]):
raise ValueError("Missing required AI configuration (api_url, model, api_key) in config.")
# --- Call the updated prepare_topic_generation ---
ai_agent, system_prompt, user_prompt = prepare_topic_generation(
prompt_manager, # Pass instance
ai_api_url,
ai_model,
ai_api_key,
ai_timeout,
ai_max_retries
)
if not ai_agent or not system_prompt or not user_prompt:
raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
# --- Generate topics (unchanged) ---
topics_list = generate_topics(
ai_agent, system_prompt, user_prompt,
run_id, # Pass run_id
config.get("topic_temperature", 0.2),
config.get("topic_top_p", 0.5),
config.get("topic_presence_penalty", 1.5)
)
except Exception as e:
logging.exception("Error during topic generation pipeline execution:")
# Ensure agent is closed even if generation fails mid-way
if ai_agent: ai_agent.close()
return None, None, None, None # Signal failure
finally:
# Ensure the AI agent is closed after generation attempt (if initialized)
if ai_agent:
logging.info("Closing topic generation AI Agent...")
ai_agent.close()
if topics_list is None: # Check if generate_topics returned None (though it currently returns list)
logging.error("Topic generation failed (generate_topics returned None or an error occurred).")
return None, None, None, None
elif not topics_list: # Check if list is empty
logging.warning(f"Topic generation completed for run {run_id}, but the resulting topic list is empty.")
# Return empty list and prompts anyway, let caller decide
# --- Saving logic removed previously ---
logging.info(f"Topic generation pipeline completed successfully. Run ID: {run_id}")
# Return the raw data needed by the OutputHandler
return run_id, topics_list, system_prompt, user_prompt
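# Hypothetical driver sketch for the topic pipeline (all values are placeholders;
# the config keys themselves are the ones read by run_topic_generation_pipeline above):
#
#     config = {
#         "api_url": "https://api.example.com/v1",   # placeholder
#         "model": "example-model",                  # placeholder
#         "api_key": "sk-...",                       # placeholder
#         "topic_system_prompt": "prompts/topic_system.txt",
#         "topic_user_prompt": "prompts/topic_user.txt",
#         "content_system_prompt": "prompts/content_system.txt",
#         "prompts_dir": "prompts",
#         "resource_dir": [],
#         "num": 5,
#         "date": "2024-01-01",
#     }
#     run_id, topics, sys_p, user_p = run_topic_generation_pipeline(config)
#     if topics:
#         logging.info(f"Run {run_id}: {len(topics)} topics generated")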
# --- Decoupled Functional Units (Moved from main.py) ---
def generate_content_for_topic(ai_agent: AI_Agent,
prompt_manager: PromptManager,
topic_item: dict,
run_id: str,
topic_index: int,
output_handler: OutputHandler, # Changed name to match convention
# explicit parameters; config and output_dir are no longer passed
variants: int,
temperature: float,
top_p: float,
presence_penalty: float):
"""Generates all content variants for a single topic item and uses OutputHandler.
Args:
ai_agent: Initialized AI_Agent instance.
prompt_manager: Initialized PromptManager instance.
topic_item: Dictionary representing the topic.
run_id: Current run ID.
topic_index: 1-based index of the topic.
output_handler: An instance of OutputHandler to process results.
variants: Number of variants to generate.
temperature, top_p, presence_penalty: AI generation parameters.
Returns:
bool: True if at least one variant was successfully generated and handled, False otherwise.
"""
logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
success_flag = False # Track if any variant succeeded
# Use the variants parameter passed in
# variants = config.get("variants", 1)
for j in range(variants):
variant_index = j + 1
logging.info(f" Generating Variant {variant_index}/{variants}...")
# The PromptManager instance was passed in; call it directly
content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item)
if not content_system_prompt or not content_user_prompt:
logging.warning(f" Skipping Variant {variant_index} due to missing content prompts.")
continue
time.sleep(random.random() * 0.5)
try:
# Call generate_single_content with passed-in parameters
content_json, prompt_data = generate_single_content(
ai_agent, content_system_prompt, content_user_prompt, topic_item,
run_id, topic_index, variant_index,
temperature, # passed-in parameter
top_p, # passed-in parameter
presence_penalty # passed-in parameter
)
# Simplified check: process content_json as long as it is not None;
# even an empty title and content count as a valid result
if content_json is not None:
# Use the output handler to process/save the result
output_handler.handle_content_variant(
run_id, topic_index, variant_index, content_json, prompt_data or ""
)
success_flag = True # Mark success for this topic
# Check specifically if the AI result itself indicated a parsing error internally
if content_json.get("error"):
logging.warning(f" Content generation for Topic {topic_index}, Variant {variant_index} succeeded but response parsing had issues. Using empty content values.")
else:
logging.info(f" Successfully generated and handled content for Topic {topic_index}, Variant {variant_index}.")
else:
logging.error(f" Content generation failed for Topic {topic_index}, Variant {variant_index}. Skipping handling.")
except Exception as e:
logging.exception(f" Error during content generation call or handling for Topic {topic_index}, Variant {variant_index}:")
# Return the success flag for this topic
return success_flag
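# Hypothetical usage sketch (assumes ai_agent, prompt_manager, output_handler and
# the topics list are initialized by the caller, e.g. main.py):
#
#     for idx, topic in enumerate(topics, start=1):
#         ok = generate_content_for_topic(
#             ai_agent, prompt_manager, topic, run_id, idx, output_handler,
#             variants=2, temperature=0.3, top_p=0.4, presence_penalty=1.5
#         )
#         if not ok:
#             logging.warning(f"No content variant succeeded for topic {idx}")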
def generate_posters_for_topic(topic_item: dict,
output_dir: str,
run_id: str,
topic_index: int,
output_handler: OutputHandler, # handler for persisting results
variants: int,
poster_assets_base_dir: str,
image_base_dir: str,
img_frame_possibility: float,
text_bg_possibility: float,
title_possibility: float,
text_possibility: float,
resource_dir_config: list,
poster_target_size: tuple,
output_collage_subdir: str,
output_poster_subdir: str,
output_poster_filename: str,
system_prompt: str,
collage_style: str,
timeout: int
):
"""Generates all posters for a single topic item, handling image data via OutputHandler.
Args:
topic_item: The dictionary representing a single topic.
output_dir: The base output directory for the entire run (e.g., ./result).
run_id: The ID for the current run.
topic_index: The 1-based index of the current topic.
variants: Number of variants.
poster_assets_base_dir: Path to poster assets (fonts, frames etc.).
image_base_dir: Base path for source images.
img_frame_possibility: Probability of adding image frame.
text_bg_possibility: Probability of adding text background.
resource_dir_config: Configuration for resource directories (used for Description).
poster_target_size: Target size tuple (width, height) for the poster.
text_possibility: Probability of adding secondary text.
output_collage_subdir: Subdirectory name for saving collages.
output_poster_subdir: Subdirectory name for saving posters.
output_poster_filename: Filename for the final poster.
system_prompt: System prompt for content generation.
output_handler: An instance of OutputHandler to process results.
timeout: Timeout for content generation.
Returns:
True if poster generation was attempted (regardless of individual variant success),
False if setup failed before attempting variants.
"""
logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
# --- Load content data from files ---
loaded_content_list = []
logging.info(f"Attempting to load content data for {variants} variants for topic {topic_index}...")
for j in range(variants):
variant_index = j + 1
variant_dir = os.path.join(output_dir, run_id, f"{topic_index}_{variant_index}")
content_path = os.path.join(variant_dir, "article.json")
try:
if os.path.exists(content_path):
with open(content_path, 'r', encoding='utf-8') as f_content:
content_data = json.load(f_content)
if isinstance(content_data, dict) and 'title' in content_data and 'content' in content_data:
loaded_content_list.append(content_data)
logging.debug(f" Successfully loaded content from: {content_path}")
else:
logging.warning(f" Content file {content_path} has invalid format. Skipping.")
else:
logging.warning(f" Content file not found for variant {variant_index}: {content_path}. Skipping.")
except json.JSONDecodeError:
logging.error(f" Error decoding JSON from content file: {content_path}. Skipping.")
except Exception as e:
logging.exception(f" Error loading content file {content_path}: {e}")
if not loaded_content_list:
logging.error(f"No valid content data loaded for topic {topic_index}. Cannot generate posters.")
return False
logging.info(f"Successfully loaded content data for {len(loaded_content_list)} variants.")
# --- End Load content data ---
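# Expected shape of each loaded article.json, as produced from
# tweetContent.get_json_data() earlier in the pipeline (values are illustrative;
# the "error" key only appears when response parsing failed):
#
#     {"title": "Sample title", "content": "Body text", "error": true}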
# Initialize generators using parameters
try:
content_gen_instance = core_contentGen.ContentGenerator()
if not poster_assets_base_dir:
logging.error("Error: 'poster_assets_base_dir' not provided. Cannot generate posters.")
return False
poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir)
poster_gen_instance.set_img_frame_possibility(img_frame_possibility)
poster_gen_instance.set_text_bg_possibility(text_bg_possibility)
except Exception as e:
logging.exception("Error initializing generators for poster creation:")
return False
# --- Setup: Paths and Object Name ---
object_name = topic_item.get("object", "")
if not object_name:
logging.warning("Warning: Topic object name is missing. Cannot generate posters.")
return False
# Clean object name
try:
object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
if not object_name_cleaned:
logging.warning(f"Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.")
return False
object_name = object_name_cleaned
except Exception as e:
logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
return False
# Construct and check INPUT image paths
input_img_dir_path = os.path.join(image_base_dir, object_name)
if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
# Fuzzy matching: if no exactly matching directory is found, look for a directory containing keywords from the object name
logging.info(f"Attempting fuzzy match for the image directory: {object_name}")
found_dir = None
# 1. List all directories under image_base_dir
try:
all_dirs = [d for d in os.listdir(image_base_dir)
if os.path.isdir(os.path.join(image_base_dir, d))]
logging.info(f"Found {len(all_dirs)} image directories available for fuzzy matching")
# 2. Extract keywords from the object name
# e.g. "美的鹭湖鹭栖台酒店+盈香心动乐园" -> ["美的", "鹭湖", "酒店", "乐园"]
# First split on common separators (+, whitespace, _, -, etc.)
parts = re.split(r'[+\s_\-]', object_name)
keywords = []
for part in parts:
# Keep only meaningful keywords longer than one character
if len(part) > 1:
keywords.append(part)
# Also try shorter semantic units, e.g. 2-3 character Chinese words
# For Chinese names, extract 2-character phrases
for i in range(len(object_name) - 1):
keyword = object_name[i:i+2] # extract a 2-character window
if len(keyword) == 2 and all('\u4e00' <= c <= '\u9fff' for c in keyword):
keywords.append(keyword)
# 3. Score each directory
dir_scores = {}
for directory in all_dirs:
score = 0
dir_lower = directory.lower()
# Add one point for each matching keyword
for keyword in keywords:
if keyword.lower() in dir_lower:
score += 1
# Record the directory if it matched at least one keyword (score > 0)
if score > 0:
dir_scores[directory] = score
# 4. Pick the directory with the highest score
if dir_scores:
best_match = max(dir_scores.items(), key=lambda x: x[1])
found_dir = best_match[0]
logging.info(f"Fuzzy match succeeded. Matched directory: {found_dir}, score: {best_match[1]}")
# Update the image directory path
input_img_dir_path = os.path.join(image_base_dir, found_dir)
logging.info(f"Using fuzzy-matched image directory: {input_img_dir_path}")
else:
logging.warning(f"Fuzzy matching found no directory containing any of the keywords")
except Exception as e:
logging.warning(f"Error during fuzzy matching: {e}")
# If a valid directory still cannot be found, give up on posters for this topic
if not found_dir or not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
logging.warning(f"Warning: image directory could not be found even with fuzzy matching: '{input_img_dir_path}'. Skipping posters for this topic.")
return False
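# Minimal illustration of the keyword-scoring logic above (directory names are
# made up; scoring is one point per keyword contained in the directory name):
#
#     keywords = ["鹭湖", "酒店"]
#     candidate_dirs = ["鹭湖森林度假区", "某某酒店", "unrelated_dir"]
#     scores = {d: sum(1 for k in keywords if k.lower() in d.lower()) for d in candidate_dirs}
#     # -> {"鹭湖森林度假区": 1, "某某酒店": 1, "unrelated_dir": 0}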
# Locate Description File using resource_dir_config parameter
info_directory = []
description_file_path = None
found_description = False
# Prepare the keyword list for fuzzy matching
# Similar to the image-directory matching above: extract keywords from the object name
parts = re.split(r'[+\s_\-]', object_name)
keywords = []
for part in parts:
if len(part) > 1:
keywords.append(part)
# Also extract short Chinese phrases as keywords
for i in range(len(object_name) - 1):
keyword = object_name[i:i+2]
if len(keyword) == 2 and all('\u4e00' <= c <= '\u9fff' for c in keyword):
keywords.append(keyword)
logging.info(f"Keywords for fuzzy-matching the description file: {keywords}")
# Try an exact match first
for dir_info in resource_dir_config:
if dir_info.get("type") == "Description":
for file_path in dir_info.get("file_path", []):
if object_name in os.path.basename(file_path):
description_file_path = file_path
if os.path.exists(description_file_path):
info_directory = [description_file_path]
logging.info(f"找到并使用精确匹配的描述文件: {description_file_path}")
found_description = True
else:
logging.warning(f"Warning: 配置中指定的描述文件未找到: {description_file_path}")
break
if found_description:
break
# If exact matching failed, fall back to fuzzy matching
if not found_description:
logging.info(f"No exact-match description file found for '{object_name}'; trying fuzzy matching...")
best_score = 0
best_file = None
for dir_info in resource_dir_config:
if dir_info.get("type") == "Description":
for file_path in dir_info.get("file_path", []):
file_name = os.path.basename(file_path)
score = 0
# Compute the keyword match score
for keyword in keywords:
if keyword.lower() in file_name.lower():
score += 1
# Update the best match if this file scores higher
if score > best_score and os.path.exists(file_path):
best_score = score
best_file = file_path
# If a best-match file was found
if best_file:
description_file_path = best_file
info_directory = [description_file_path]
logging.info(f"模糊匹配找到描述文件: {description_file_path},匹配分数: {best_score}")
found_description = True
if not found_description:
logging.warning(f"未找到对象'{object_name}'的匹配描述文件。")
# Generate Text Configurations for All Variants
try:
poster_text_configs_raw = content_gen_instance.run(info_directory, variants, loaded_content_list, system_prompt, timeout=timeout)
if not poster_text_configs_raw:
logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.")
return False
# --- Save the poster configs via the OutputHandler ---
output_handler.handle_poster_configs(run_id, topic_index, poster_text_configs_raw)
# --- End handler save ---
# Log the raw configuration data for debugging
logging.info(f"Generated poster configuration data: {poster_text_configs_raw}")
# Use the configuration data directly instead of re-reading it from a file
if isinstance(poster_text_configs_raw, list):
poster_configs = poster_text_configs_raw
logging.info(f"Using the generated config list directly; it contains {len(poster_configs)} config items")
else:
# Not a list: parse it with the PosterConfig class
logging.info("Generated config data is not a list; processing it with the PosterConfig class")
poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
poster_configs = poster_config_summary.get_config()
except Exception as e:
logging.exception("Error running ContentGenerator or parsing poster configs:")
traceback.print_exc()
return False
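# Illustrative shape of a single poster config item as consumed by the loop below
# (the 'main_title' and 'texts' keys are inferred from the .get() calls; the actual
# structure depends on ContentGenerator's output):
#
#     {"main_title": "Example headline", "texts": ["secondary line 1", "secondary line 2"]}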
# Poster Generation Loop for each variant
poster_num = min(variants, len(poster_configs)) if isinstance(poster_configs, list) else variants
logging.info(f"计划生成 {poster_num} 个海报变体")
any_poster_attempted = False
for j_index in range(poster_num):
variant_index = j_index + 1
logging.info(f"Generating Poster {variant_index}/{poster_num}...")
any_poster_attempted = True
collage_img = None # To store the generated collage PIL Image
poster_img = None # To store the final poster PIL Image
try:
# Get the configuration for the current variant
if isinstance(poster_configs, list) and j_index < len(poster_configs):
poster_config = poster_configs[j_index]
logging.info(f"Using config item {j_index+1}: {poster_config}")
else:
# Fallback: use the PosterConfig class
poster_config = poster_config_summary.get_config_by_index(j_index)
logging.info(f"Fetching config item {j_index+1} via the PosterConfig class")
if not poster_config:
logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.")
continue
# --- Image Collage ---
logging.info(f"Generating collage from: {input_img_dir_path}")
collage_images, used_image_filenames = core_simple_collage.process_directory(
input_img_dir_path,
style=collage_style,
target_size=poster_target_size,
output_count=1
)
if not collage_images: # check whether the list is empty
logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.")
continue
collage_img = collage_images[0] # take the first PIL Image
used_image_files = used_image_filenames[0] if used_image_filenames else [] # filenames of the images used
logging.info(f"Collage image generated successfully (in memory). Used images: {used_image_files}")
print(f"Collage source image files: {used_image_files}")
# --- Save the collage image and used-image info via the handler ---
output_handler.handle_generated_image(
run_id, topic_index, variant_index,
image_type='collage',
image_data=collage_img,
output_filename='collage.png', # or another desired filename
metadata={'used_images': used_image_files} # add the image file info to the metadata
)
# --- End collage save ---
# --- Create Poster ---
if random.random() > title_possibility:
text_data = {
"title": poster_config.get('main_title', ''),
"subtitle": "",
"additional_texts": []
}
texts = poster_config.get('texts', [])
if texts:
# Make sure the text is not empty
if random.random() > text_possibility:
text_data["additional_texts"].append({"text": texts[0], "position": "middle", "size_factor": 0.8})
# for text in texts:
# if random.random() < text_possibility:
# text_data["additional_texts"].append({"text": text, "position": "middle", "size_factor": 0.8})
else:
text_data = {
"title": "",
"subtitle": "",
"additional_texts": []
}
# Log the text data that will be used
logging.info(f"Text data: {text_data}")
# Call the updated create_poster, which returns a PIL Image
poster_img = poster_gen_instance.create_poster(collage_img, text_data)
if poster_img:
logging.info(f"Poster image generated successfully (in memory).")
# --- Save the poster image via the handler ---
output_handler.handle_generated_image(
run_id, topic_index, variant_index,
image_type='poster',
image_data=poster_img,
output_filename=output_poster_filename, # use the filename passed in as a parameter
metadata={'used_collage': True, 'collage_images': used_image_files}
)
# --- End poster save ---
else:
logging.warning(f"Warning: Poster generation function returned None for variant {variant_index}.")
except Exception as e:
logging.exception(f"Error during poster generation for Variant {variant_index}:")
traceback.print_exc()
continue
return any_poster_attempted
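# Hypothetical end-to-end call for poster generation (every path, probability and
# size below is a placeholder; the keyword names match the signature above):
#
#     generate_posters_for_topic(
#         topic_item=topics[0],
#         output_dir="./result",
#         run_id=run_id,
#         topic_index=1,
#         output_handler=output_handler,
#         variants=2,
#         poster_assets_base_dir="./assets/poster",
#         image_base_dir="./images",
#         img_frame_possibility=0.5,
#         text_bg_possibility=0.5,
#         title_possibility=0.5,
#         text_possibility=0.5,
#         resource_dir_config=config.get("resource_dir", []),
#         poster_target_size=(1080, 1440),
#         output_collage_subdir="collage",
#         output_poster_subdir="poster",
#         output_poster_filename="poster.png",
#         system_prompt=content_system_prompt,
#         collage_style="grid",
#         timeout=60,
#     )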
# The main() entry point is no longer used; it is kept here commented out.
# def main():
# """Main entry point."""
# # ... (old main() logic)
#
# if __name__ == "__main__":
# main()