TravelContentCreator/utils/tweet_generator.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
import logging
import re
import base64

# Imports from local modules.
# ResourceLoader is now used implicitly via PromptManager.
from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser
from utils.prompt_manager import PromptManager # Keep this as it's importing from the same level package 'utils'
from utils import content_generator as core_contentGen
from core import poster_gen as core_posterGen
from core import simple_collage as core_simple_collage
from .output_handler import OutputHandler
from utils.content_judger import ContentJudger

class tweetTopic:
    def __init__(self, index, date, logic, object, product, product_logic,
                 style, style_logic, target_audience, target_audience_logic):
        self.index = index
        self.date = date
        self.logic = logic
        self.object = object
        self.product = product
        self.product_logic = product_logic
        self.style = style
        self.style_logic = style_logic
        self.target_audience = target_audience
        self.target_audience_logic = target_audience_logic

class tweetTopicRecord:
    def __init__(self, topics_list, system_prompt, user_prompt, run_id):
        self.topics_list = topics_list
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.run_id = run_id

class tweetContent:
    def __init__(self, result, prompt, run_id, article_index, variant_index):
        self.result = result
        self.prompt = prompt
        self.run_id = run_id
        self.article_index = article_index
        self.variant_index = variant_index
        try:
            self.json_data = self.split_content(result)
        except Exception as e:
            logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")
            logging.debug(f"Raw result: {result[:500]}...")  # Log partial raw result
            self.json_data = {"title": "", "content": "", "tag": "", "error": True, "raw_result": str(e)}

    def split_content(self, result):
        try:
            processed_result = result
            if "</think>" in result:
                # Keep only the part after the model's reasoning block
                processed_result = result.split("</think>", 1)[1]
            # The model is expected to output JSON
            json_data = json.loads(processed_result)
            json_data["error"] = False
            json_data["raw_result"] = None
            json_data["judge_success"] = None
            return json_data
        except Exception as e:
            logging.warning(f"Error parsing content: {e}; falling back to empty content")
            # Build a fresh json_data instead of referencing an undefined variable
            return {
                "title": "",
                "content": "",
                "error": True,
                "raw_result": str(e),
                "judge_success": False
            }

    def get_json_data(self):
        """Returns the generated JSON data dictionary."""
        return self.json_data

    def get_prompt(self):
        """Returns the user prompt used to generate this content."""
        return self.prompt

    def get_content(self):
        # The parsed fields live in json_data, not as instance attributes
        return self.json_data.get("content", "")

    def get_title(self):
        return self.json_data.get("title", "")
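
# Expected raw model output for tweetContent (illustrative): an optional
# <think>...</think> reasoning block followed by a JSON object, e.g.
#   <think>draft, check tone...</think>{"title": "...", "content": "...", "tag": "..."}
# split_content() drops everything up to </think> and json.loads() the rest.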

def generate_topics(ai_agent, system_prompt, user_prompt, run_id,
                    temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """Generates the topic list. run_id is passed in; saving output is the
    responsibility of OutputHandler, so output_dir is no longer an argument."""
    logging.info("Starting topic generation...")
    time_start = time.time()
    # Call the AI agent's work method (returns result, tokens, time_cost)
    result, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )
    logging.info(f"Topic generation API call completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
    # Parse topics
    result_list = TopicParser.parse_topics(result, run_id)
    if not result_list:
        logging.warning("Topic parsing resulted in an empty list.")
        # Saving the raw AI output on parsing failure is the OutputHandler's job.
    # Return the parsed list directly
    return result_list
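
# A parsed topic item is a dict whose keys mirror the tweetTopic fields
# (illustrative shape; the exact contract is defined by TopicParser):
# {"index": 1, "date": "...", "logic": "...", "object": "...", "product": "...",
#  "product_logic": "...", "style": "...", "style_logic": "...",
#  "target_audience": "...", "target_audience_logic": "..."}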

def generate_single_content(ai_agent, system_prompt, user_prompt, item, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4,
                            presence_penalty=1.5, max_retries=3):
    """Generates single content variant data. Returns (content_json, user_prompt) or (None, None)."""
    logging.info(f"Generating content for topic {article_index}, variant {variant_index}")
    if not system_prompt or not user_prompt:
        logging.error("System or User prompt is empty. Cannot generate content.")
        return None, None
    logging.debug(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
    # Retry logic
    retry_count = 0
    last_result = None
    last_tokens = None
    last_time_cost = None
    while retry_count <= max_retries:
        try:
            # Only add a delay and adjust parameters on retries
            if retry_count > 0:
                # Random delay to avoid hammering the API
                delay = 1 + random.random() * 2  # 1-3 seconds
                logging.info(f"Content generation retry ({retry_count}/{max_retries}); waiting {delay:.1f}s before the attempt...")
                time.sleep(delay)
                # Raise the temperature slightly to increase diversity
                adjusted_temperature = min(temperature + (retry_count * 0.1), 0.9)
                logging.info(f"Adjusted temperature to: {adjusted_temperature}")
            else:
                adjusted_temperature = temperature
            # Generate content (non-streaming work returns result, tokens, time_cost)
            result, tokens, time_cost = ai_agent.work(
                system_prompt, user_prompt, "", adjusted_temperature, top_p, presence_penalty
            )
            last_result = result
            last_tokens = tokens
            last_time_cost = time_cost
            if result is None:  # Check whether the AI call failed completely
                logging.error(f"AI agent work failed for {article_index}_{variant_index}. No result returned.")
                retry_count += 1
                continue
            logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
            # --- Create tweetContent object (handles parsing) ---
            tweet_content = tweetContent(result, user_prompt, run_id, article_index, variant_index)
            content_json = tweet_content.get_json_data()
            # Check whether valid content was parsed
            if not content_json.get("error", False) and content_json.get("title") and content_json.get("content"):
                # Valid content obtained
                if retry_count > 0:
                    logging.info(f"Got valid content after {retry_count} retries")
                return content_json, user_prompt
            else:
                logging.warning(
                    f"Content parsing failed or content incomplete. error: {content_json.get('error')}, "
                    f"title length: {len(content_json.get('title', ''))}, "
                    f"content length: {len(content_json.get('content', ''))}"
                )
            # Reaching here means generation or parsing went wrong; retry
            retry_count += 1
        except Exception as e:
            logging.exception(f"Error during content generation attempt {retry_count+1} for {article_index}_{variant_index}:")
            retry_count += 1
            if retry_count <= max_retries:
                logging.info(f"Will attempt retry {retry_count}...")
            else:
                logging.error(f"Reached max retries ({max_retries}); could not generate valid content")
    # All retries failed; fall back to the last result, even if incomplete
    logging.warning(f"No valid content after {max_retries} attempts; returning the last result")
    if last_result:
        try:
            tweet_content = tweetContent(last_result, user_prompt, run_id, article_index, variant_index)
            content_json = tweet_content.get_json_data()
            return content_json, user_prompt
        except Exception as e:
            logging.exception(f"Error processing last result: {e}")
    # Complete failure: return empty content
    return {"title": "", "content": "", "error": True, "judge_success": False}, user_prompt

def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir,
                     resource_dir, variants=2, temperature=0.3, start_index=0, end_index=None):
    """Generates content for each topic (legacy path; the pipeline now uses
    generate_content_for_topic). Summary-file bookkeeping moved to OutputHandler."""
    if not topics:
        logging.warning("No topics; cannot generate content")
        return
    # Determine the slice of topics to process
    if end_index is None or end_index > len(topics):
        end_index = len(topics)
    topics_to_process = topics[start_index:end_index]
    logging.info(f"Preparing to process {len(topics_to_process)} topics...")
    # Process each topic
    processed_results = []
    for i, item in enumerate(topics_to_process):
        logging.info(f"Processing article {i+1}/{len(topics_to_process)}")
        # Generate multiple variants per topic
        for j in range(variants):
            logging.info(f"Generating variant {j+1}/{variants}")
            # This legacy helper has no pre-built per-topic user prompt
            # (generate_content_for_topic builds real prompts via PromptManager);
            # the serialized topic item is used as a stand-in so the arguments
            # line up with generate_single_content's signature.
            user_prompt = json.dumps(item, ensure_ascii=False)
            content_json, _ = generate_single_content(
                ai_agent, system_prompt, user_prompt, item, run_id,
                i + 1, j + 1, temperature
            )
            if content_json:
                processed_results.append(content_json)
    logging.info(f"Finished generating {len(processed_results)} articles")
    return processed_results

def prepare_topic_generation(prompt_manager: PromptManager,
                             api_url: str,
                             model_name: str,
                             api_key: str,
                             timeout: int,
                             max_retries: int):
    """Prepares the environment and parameters for topic generation.
    Returns (agent, system_prompt, user_prompt).
    Args:
        prompt_manager: An initialized PromptManager instance.
        api_url, model_name, api_key, timeout, max_retries: Parameters for AI_Agent.
    """
    logging.info("Preparing for topic generation (using provided PromptManager)...")
    # Get the prompts from the provided PromptManager
    system_prompt, user_prompt = prompt_manager.get_topic_prompts()
    if not system_prompt or not user_prompt:
        logging.error("Failed to get topic generation prompts from PromptManager.")
        return None, None, None
    # Initialize the AI agent with the passed-in parameters
    try:
        logging.info("Initializing AI Agent for topic generation...")
        ai_agent = AI_Agent(
            api_url,
            model_name,
            api_key,
            timeout=timeout,
            max_retries=max_retries
        )
    except Exception as e:
        logging.exception("Error initializing AI Agent for topic generation:")
        return None, None, None
    # Return the agent and prompts
    return ai_agent, system_prompt, user_prompt
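
# Illustrative usage (endpoint, model, and key are placeholders):
# pm = PromptManager(topic_system_prompt_path="...", topic_user_prompt_path="...", ...)
# agent, sys_prompt, user_prompt = prepare_topic_generation(
#     pm, "https://api.example.com/v1", "model-name", "sk-...", timeout=30, max_retries=3)
# if agent is None: ...  # all three return values are None on failure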

def run_topic_generation_pipeline(config, run_id=None):
    """
    Runs the complete topic generation pipeline based on the configuration.
    Returns: (run_id, topics_list, system_prompt, user_prompt) or (None, None, None, None) on failure.
    """
    logging.info("Starting Step 1: Topic Generation Pipeline...")
    if run_id is None:
        logging.info("No run_id provided, generating one based on timestamp.")
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        logging.info(f"Using provided run_id: {run_id}")
    ai_agent, system_prompt, user_prompt = None, None, None  # Initialize
    topics_list = None
    prompt_manager = None  # Initialize prompt_manager
    try:
        # --- Read the parameters PromptManager needs ---
        topic_sys_prompt_path = config.get("topic_system_prompt")
        topic_user_prompt_path = config.get("topic_user_prompt")
        content_sys_prompt_path = config.get("content_system_prompt")  # Unused here, but PromptManager may need it
        prompts_dir_path = config.get("prompts_dir")
        prompts_config = config.get("prompts_config")  # Newly added prompts_config setting
        resource_config = config.get("resource_dir", [])
        topic_num = config.get("num", 1)
        topic_date = config.get("date", "")
        # --- Create the PromptManager instance ---
        prompt_manager = PromptManager(
            topic_system_prompt_path=topic_sys_prompt_path,
            topic_user_prompt_path=topic_user_prompt_path,
            content_system_prompt_path=content_sys_prompt_path,
            prompts_dir=prompts_dir_path,
            prompts_config=prompts_config,
            resource_dir_config=resource_config,
            topic_gen_num=topic_num,
            topic_gen_date=topic_date
        )
        logging.info("PromptManager instance created.")
        # --- Read the parameters the AI agent needs ---
        ai_api_url = config.get("api_url")
        ai_model = config.get("model")
        ai_api_key = config.get("api_key")
        ai_timeout = config.get("request_timeout", 30)
        ai_max_retries = config.get("max_retries", 3)
        # Check that the required AI parameters are present
        if not all([ai_api_url, ai_model, ai_api_key]):
            raise ValueError("Missing required AI configuration (api_url, model, api_key) in config.")
        # --- Prepare the agent and prompts ---
        ai_agent, system_prompt, user_prompt = prepare_topic_generation(
            prompt_manager,  # Pass the instance
            ai_api_url,
            ai_model,
            ai_api_key,
            ai_timeout,
            ai_max_retries
        )
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
        # --- Generate topics ---
        topics_list = generate_topics(
            ai_agent, system_prompt, user_prompt,
            run_id,
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            config.get("topic_presence_penalty", 1.5)
        )
    except Exception as e:
        logging.exception("Error during topic generation pipeline execution:")
        # The finally block below closes the agent; do not close it twice here.
        return None, None, None, None  # Signal failure
    finally:
        # Ensure the AI agent is closed after the generation attempt (if initialized)
        if ai_agent:
            logging.info("Closing topic generation AI Agent...")
            ai_agent.close()
    if topics_list is None:  # generate_topics currently returns a list, but guard anyway
        logging.error("Topic generation failed (generate_topics returned None or an error occurred).")
        return None, None, None, None
    elif not topics_list:  # Check whether the list is empty
        logging.warning(f"Topic generation completed for run {run_id}, but the resulting topic list is empty.")
        # Return the empty list and prompts anyway; let the caller decide
    logging.info(f"Topic generation pipeline completed successfully. Run ID: {run_id}")
    # Return the raw data needed by the OutputHandler
    return run_id, topics_list, system_prompt, user_prompt
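
# Illustrative config for run_topic_generation_pipeline (only the keys read
# above; all values are placeholders):
# config = {
#     "api_url": "https://api.example.com/v1", "model": "model-name", "api_key": "sk-...",
#     "topic_system_prompt": "...", "topic_user_prompt": "...",
#     "content_system_prompt": "...", "prompts_dir": "...", "prompts_config": {...},
#     "resource_dir": [], "num": 5, "date": "2024-05-01",
#     "topic_temperature": 0.2, "topic_top_p": 0.5, "topic_presence_penalty": 1.5,
#     "request_timeout": 30, "max_retries": 3,
# }
# run_id, topics, sys_p, user_p = run_topic_generation_pipeline(config)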
# --- Decoupled Functional Units (Moved from main.py) ---

def generate_content_for_topic(ai_agent: AI_Agent,
                               prompt_manager: PromptManager,
                               topic_item: dict,
                               run_id: str,
                               topic_index: int,
                               output_handler: OutputHandler,
                               variants: int,
                               temperature: float,
                               top_p: float,
                               presence_penalty: float,
                               enable_content_judge: bool):
    """Generates all content variants for a single topic item and uses OutputHandler.
    Args:
        ai_agent: Initialized AI_Agent instance.
        prompt_manager: Initialized PromptManager instance.
        topic_item: Dictionary representing the topic.
        run_id: Current run ID.
        topic_index: 1-based index of the topic.
        output_handler: An instance of OutputHandler to process results.
        variants: Number of variants to generate.
        temperature, top_p, presence_penalty: AI generation parameters.
        enable_content_judge: Whether to enable the content judge.
    Returns:
        bool: True if at least one variant was successfully generated and handled, False otherwise.
    """
    logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
    success_flag = False  # Tracks whether any variant succeeded
    # If content judging is enabled, gather the product material first
    product_info = None
    content_judger = None
    if enable_content_judge:
        logging.info("Content judging is enabled; fetching product material...")
        # Pull the product and object names from the topic item
        product_name = topic_item.get("product", "")
        object_name = topic_item.get("object", "")
        # Assemble the combined product material
        product_info = ""
        # Object information
        if object_name:
            # Mirrors how PromptManager.get_content_prompts builds object_info
            found_object_info = False
            all_description_files = []
            # Collect every candidate resource file from resource_dir_config
            for dir_info in prompt_manager.resource_dir_config:
                if dir_info.get("type") in ["Object", "Description"]:
                    all_description_files.extend(dir_info.get("file_path", []))
            # Try an exact match for the object material
            for file_path in all_description_files:
                if object_name in os.path.basename(file_path):
                    from utils.resource_loader import ResourceLoader
                    info = ResourceLoader.load_file_content(file_path)
                    if info:
                        product_info += f"Object: {object_name}\n{info}\n\n"
                        logging.info(f"Found resource file for object '{object_name}' for content judging: {file_path}")
                        found_object_info = True
                        break
            # If no object material was found, log a warning but continue
            if not found_object_info:
                logging.warning(f"Could not find a resource file for object '{object_name}' for content judging")
        # Product information
        if product_name:
            found_product_info = False
            all_product_files = []
            # Collect every candidate product resource file
            for dir_info in prompt_manager.resource_dir_config:
                if dir_info.get("type") == "Product":
                    all_product_files.extend(dir_info.get("file_path", []))
            # Try an exact match for the product material
            for file_path in all_product_files:
                if product_name in os.path.basename(file_path):
                    from utils.resource_loader import ResourceLoader
                    info = ResourceLoader.load_file_content(file_path)
                    if info:
                        product_info += f"Product: {product_name}\n{info}\n\n"
                        logging.info(f"Found resource file for product '{product_name}' for content judging: {file_path}")
                        found_product_info = True
                        break
            # If no product material was found, log a warning but continue
            if not found_product_info:
                logging.warning(f"Could not find a resource file for product '{product_name}' for content judging")
        # Initialize the ContentJudger if any material was obtained
        if product_info:
            logging.info("Product material obtained; initializing ContentJudger...")
            # The judge reuses the main AI agent to avoid extra resource usage.
            # Its system prompt comes from the PromptManager's prompt cache
            # (a private attribute, accessed here to preserve existing behavior).
            content_judger_system_prompt = prompt_manager._system_prompt_cache.get("judger_system_prompt")
            content_judger = ContentJudger(ai_agent, system_prompt=content_judger_system_prompt)
        else:
            logging.warning("No product material obtained; content judging will be skipped")
            enable_content_judge = False
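    # At this point product_info, when available, is a plain-text block of the
    # form (illustrative):
    #   Object: <object name>
    #   <object description text>
    #
    #   Product: <product name>
    #   <product description text>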
    for j in range(variants):
        variant_index = j + 1
        logging.info(f"  Generating Variant {variant_index}/{variants}...")
        # The PromptManager instance was passed in; call it directly
        content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item)
        if not content_system_prompt or not content_user_prompt:
            logging.warning(f"  Skipping Variant {variant_index} due to missing content prompts.")
            continue
        time.sleep(random.random() * 0.5)
        try:
            # Call generate_single_content with the passed-in parameters
            content_json, prompt_data = generate_single_content(
                ai_agent, content_system_prompt, content_user_prompt, topic_item,
                run_id, topic_index, variant_index,
                temperature, top_p, presence_penalty
            )
            # Simplified check: handle content_json whenever it is not None;
            # even an empty title and content counts as a valid result.
            if content_json is not None:
                # Run the content judge if it is enabled and initialized
                if enable_content_judge and content_judger and product_info:
                    logging.info(f"  Judging content for Topic {topic_index}, Variant {variant_index}...")
                    # Prepare the content to judge
                    content_to_judge = (
                        f"title: {content_json.get('title', '')}\n"
                        f"content: {content_json.get('content', '')}\n"
                    )
                    # Call the ContentJudger
                    try:
                        judged_result = content_judger.judge_content(product_info, content_to_judge)
                        if judged_result and isinstance(judged_result, dict):
                            if "title" in judged_result and "content" in judged_result:
                                # Preserve the original title and content
                                content_json["original_title"] = content_json.get("title", "")
                                content_json["original_content"] = content_json.get("content", "")
                                # Preserve the original tags, preferring "tags" and falling back to "tag"
                                original_tags = content_json.get("tags", content_json.get("tag", ""))
                                content_json["original_tags"] = original_tags
                                # Swap in the judged content
                                content_json["title"] = judged_result["title"]
                                content_json["content"] = judged_result["content"]
                                # Keep the original tags to avoid duplication
                                content_json["tags"] = original_tags
                                # Drop a duplicate "tag" field if present
                                if "tag" in content_json:
                                    del content_json["tag"]
                                # Mark the variant as judged
                                content_json["judged"] = True
                                # Record the judge_success status
                                content_json["judge_success"] = judged_result.get("judge_success", False)
                                # Record the analysis, preferring the "analysis" field with a
                                # fallback to the legacy "不良内容分析" field
                                if "analysis" in judged_result:
                                    content_json["judge_analysis"] = judged_result["analysis"]
                                elif "不良内容分析" in judged_result:
                                    content_json["judge_analysis"] = judged_result["不良内容分析"]
                            else:
                                logging.warning("  Judge result is missing the title or content field; keeping original content")
                                content_json["judge_success"] = False
                        else:
                            logging.warning("  Content judge returned an invalid result; keeping original content")
                            content_json["judge_success"] = False
                    except Exception as judge_err:
                        logging.exception(f"  Error during content judging: {judge_err}; keeping original content")
                        content_json["judge_success"] = False
                else:
                    # Content judging disabled: mark the variant accordingly
                    content_json["judged"] = False
                    content_json["judge_success"] = None
                # Use the output handler to process/save the result
                output_handler.handle_content_variant(
                    run_id, topic_index, variant_index, content_json, prompt_data or ""
                )
                success_flag = True  # Mark success for this topic
                # Check whether the AI result itself flagged a parsing error internally
                if content_json.get("error"):
                    logging.warning(f"  Content generation for Topic {topic_index}, Variant {variant_index} succeeded but response parsing had issues. Using empty content values.")
                else:
                    logging.info(f"  Successfully generated and handled content for Topic {topic_index}, Variant {variant_index}.")
            else:
                logging.error(f"  Content generation failed for Topic {topic_index}, Variant {variant_index}. Skipping handling.")
        except Exception as e:
            logging.exception(f"  Error during content generation call or handling for Topic {topic_index}, Variant {variant_index}:")
    # Return the success flag for this topic
    return success_flag
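
# Shape of a judged variant as handed to OutputHandler (illustrative):
# {"title": "<judged title>", "content": "<judged content>", "tags": "<original tags>",
#  "original_title": "...", "original_content": "...", "original_tags": "...",
#  "judged": True, "judge_success": True, "judge_analysis": "...",
#  "error": False, "raw_result": None}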

def generate_posters_for_topic(topic_item: dict,
                               output_dir: str,
                               run_id: str,
                               topic_index: int,
                               output_handler: OutputHandler,
                               variants: int,
                               model_name: str,
                               base_url: str,
                               api_key: str,
                               poster_assets_base_dir: str,
                               image_base_dir: str,
                               img_frame_possibility: float,
                               text_bg_possibility: float,
                               title_possibility: float,
                               text_possibility: float,
                               resource_dir_config: list,
                               poster_target_size: tuple,
                               output_collage_subdir: str,
                               output_poster_subdir: str,
                               output_poster_filename: str,
                               system_prompt: str,
                               collage_style: str,
                               timeout: int
                               ):
    """Generates all posters for a single topic item, handling image data via OutputHandler.
    Args:
        topic_item: The dictionary representing a single topic.
        output_dir: The base output directory for the entire run (e.g., ./result).
        run_id: The ID for the current run.
        topic_index: The 1-based index of the current topic.
        output_handler: An instance of OutputHandler to process results.
        variants: Number of variants.
        model_name, base_url, api_key: Parameters for the ContentGenerator.
        poster_assets_base_dir: Path to poster assets (fonts, frames, etc.).
        image_base_dir: Base path for source images.
        img_frame_possibility: Probability of adding an image frame.
        text_bg_possibility: Probability of adding a text background.
        title_possibility: Probability of adding the title text.
        text_possibility: Probability of adding secondary text.
        resource_dir_config: Configuration for resource directories (used for Description).
        poster_target_size: Target size tuple (width, height) for the poster.
        output_collage_subdir: Subdirectory name for saving collages.
        output_poster_subdir: Subdirectory name for saving posters.
        output_poster_filename: Filename for the final poster.
        system_prompt: System prompt for poster text generation.
        collage_style: Collage style passed to simple_collage.
        timeout: Timeout for poster text generation.
    Returns:
        True if poster generation was attempted (regardless of individual variant success),
        False if setup failed before attempting variants.
    """
logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
# --- Load content data from files ---
loaded_content_list = []
logging.info(f"Attempting to load content data for {variants} variants for topic {topic_index}...")
for j in range(variants):
variant_index = j + 1
variant_dir = os.path.join(output_dir, run_id, f"{topic_index}_{variant_index}")
content_path = os.path.join(variant_dir, "article.json")
try:
if os.path.exists(content_path):
with open(content_path, 'r', encoding='utf-8') as f_content:
content_data = json.load(f_content)
# 支持Base64编码格式的文件
if 'title_base64' in content_data and 'content_base64' in content_data:
import base64
logging.info(f"检测到Base64编码的内容文件: {content_path}")
# 解码Base64内容
try:
title = base64.b64decode(content_data.get('title_base64', '')).decode('utf-8')
content = base64.b64decode(content_data.get('content_base64', '')).decode('utf-8')
# 创建包含解码内容的新数据对象
decoded_data = {
'title': title,
'content': content,
'judge_success': content_data.get('judge_success', True),
'judged': content_data.get('judged', True)
}
# 如果有标签,也解码
if 'tags_base64' in content_data:
tags = base64.b64decode(content_data.get('tags_base64', '')).decode('utf-8')
decoded_data['tags'] = tags
loaded_content_list.append(decoded_data)
logging.debug(f" 已成功解码并加载Base64内容: {content_path}")
continue
except Exception as decode_error:
logging.error(f" 解码Base64内容时出错: {decode_error},跳过此文件")
continue
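                # Illustrative article.json in the Base64 form handled above
                # (keys as read by this loader; values are toy examples, where
                # base64.b64decode("5qCH6aKY").decode("utf-8") == "标题"):
                # {"title_base64": "5qCH6aKY", "content_base64": "...",
                #  "tags_base64": "...", "judged": true, "judge_success": true}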
                # Regular JSON format check
                if isinstance(content_data, dict) and 'title' in content_data and 'content' in content_data:
                    loaded_content_list.append(content_data)
                    logging.debug(f"  Successfully loaded content from: {content_path}")
                else:
                    logging.warning(f"  Content file {content_path} has invalid format. Skipping.")
            else:
                logging.warning(f"  Content file not found for variant {variant_index}: {content_path}. Skipping.")
        except json.JSONDecodeError:
            logging.error(f"  Error decoding JSON from content file: {content_path}. Skipping.")
        except Exception as e:
            logging.exception(f"  Error loading content file {content_path}: {e}")
    if not loaded_content_list:
        logging.error(f"No valid content data loaded for topic {topic_index}. Cannot generate posters.")
        return False
    logging.info(f"Successfully loaded content data for {len(loaded_content_list)} variants.")
    # --- End loading content data ---
    # Initialize generators using the passed-in parameters
    try:
        content_gen_instance = core_contentGen.ContentGenerator(model_name=model_name, base_url=base_url, api_key=api_key)
        if not poster_assets_base_dir:
            logging.error("Error: 'poster_assets_base_dir' not provided. Cannot generate posters.")
            return False
        poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir)
        poster_gen_instance.set_img_frame_possibility(img_frame_possibility)
        poster_gen_instance.set_text_bg_possibility(text_bg_possibility)
    except Exception as e:
        logging.exception("Error initializing generators for poster creation:")
        return False
    # --- Setup: paths and object name ---
    object_name = topic_item.get("object", "")
    if not object_name:
        logging.warning("Warning: Topic object name is missing. Cannot generate posters.")
        return False
    # Clean the object name (strip any file extension and the "景点信息-" prefix)
    try:
        object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
        if not object_name_cleaned:
            logging.warning(f"Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.")
            return False
        object_name = object_name_cleaned
    except Exception as e:
        logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
        return False
    # Construct and check the INPUT image path
    input_img_dir_path = os.path.join(image_base_dir, object_name)
    if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
        # Fuzzy matching: if no exactly-matching directory exists, look for one
        # containing keywords from the object name
        logging.info(f"Attempting fuzzy match for image directory: {object_name}")
        found_dir = None
        try:
            # 1. List all directories under image_base_dir
            all_dirs = [d for d in os.listdir(image_base_dir)
                        if os.path.isdir(os.path.join(image_base_dir, d))]
            logging.info(f"Found {len(all_dirs)} image directories available for fuzzy matching")
            # 2. Extract keywords from the object name,
            #    e.g. "美的鹭湖鹭栖台酒店+盈香心动乐园" -> ["美的", "鹭湖", "酒店", "乐园"].
            #    First split on common separators (+, whitespace, _, -, ...)
            parts = re.split(r'[+\s_\-]', object_name)
            keywords = []
            for part in parts:
                # Keep only meaningful keywords longer than one character
                if len(part) > 1:
                    keywords.append(part)
            # Also extract shorter semantic units: for Chinese names, try
            # every two-character phrase
            for i in range(len(object_name) - 1):
                keyword = object_name[i:i+2]  # two-character window
                if len(keyword) == 2 and all('\u4e00' <= c <= '\u9fff' for c in keyword):
                    keywords.append(keyword)
            # 3. Score each directory
            dir_scores = {}
            for directory in all_dirs:
                score = 0
                dir_lower = directory.lower()
                # One point per matching keyword
                for keyword in keywords:
                    if keyword.lower() in dir_lower:
                        score += 1
                # Record the directory if it matched at least one keyword
                if score > 0:
                    dir_scores[directory] = score
            # 4. Pick the highest-scoring directory
            if dir_scores:
                best_match = max(dir_scores.items(), key=lambda x: x[1])
                found_dir = best_match[0]
                logging.info(f"Fuzzy match succeeded! Directory: {found_dir}, score: {best_match[1]}")
                # Update the image directory path
                input_img_dir_path = os.path.join(image_base_dir, found_dir)
                logging.info(f"Using fuzzy-matched image directory: {input_img_dir_path}")
            else:
                logging.warning("Fuzzy matching found no directory containing any keyword")
        except Exception as e:
            logging.warning(f"Error during fuzzy matching: {e}")
        # If there is still no valid directory, give up on this topic
        if not found_dir or not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
            logging.warning(f"Warning: No image directory found even via fuzzy matching: '{input_img_dir_path}'. Skipping posters for this topic.")
            return False
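    # Worked example of the keyword scoring (illustrative): for object_name
    # "美的鹭湖鹭栖台酒店+盈香心动乐园", the separator split yields
    # ["美的鹭湖鹭栖台酒店", "盈香心动乐园"] and the bigram pass adds "美的",
    # "的鹭", "鹭湖", ...; a directory named "鹭湖酒店图片" then scores one
    # point per keyword it contains, and the highest-scoring directory wins.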
    # Locate the description file using the resource_dir_config parameter
    info_directory = []
    description_file_path = None
    found_description = False
    # Build the keyword list for fuzzy matching, extracting keywords from the
    # object name just like the image-directory matching above
    parts = re.split(r'[+\s_\-]', object_name)
    keywords = []
    for part in parts:
        if len(part) > 1:
            keywords.append(part)
    # Also extract two-character Chinese phrases as keywords
    for i in range(len(object_name) - 1):
        keyword = object_name[i:i+2]
        if len(keyword) == 2 and all('\u4e00' <= c <= '\u9fff' for c in keyword):
            keywords.append(keyword)
    logging.info(f"Keywords for description-file fuzzy matching: {keywords}")
    # Try an exact match first
    for dir_info in resource_dir_config:
        if dir_info.get("type") == "Description":
            for file_path in dir_info.get("file_path", []):
                if object_name in os.path.basename(file_path):
                    description_file_path = file_path
                    if os.path.exists(description_file_path):
                        info_directory = [description_file_path]
                        logging.info(f"Found and using exactly-matched description file: {description_file_path}")
                        found_description = True
                    else:
                        logging.warning(f"Warning: description file from config not found on disk: {description_file_path}")
                    break
        if found_description:
            break
    # If the exact match failed, fall back to fuzzy matching
    if not found_description:
        logging.info(f"No exact description-file match for '{object_name}'; trying fuzzy matching...")
        best_score = 0
        best_file = None
        for dir_info in resource_dir_config:
            if dir_info.get("type") == "Description":
                for file_path in dir_info.get("file_path", []):
                    file_name = os.path.basename(file_path)
                    score = 0
                    # Keyword-match score
                    for keyword in keywords:
                        if keyword.lower() in file_name.lower():
                            score += 1
                    # Track the best-scoring file that exists on disk
                    if score > best_score and os.path.exists(file_path):
                        best_score = score
                        best_file = file_path
        # Use the best match, if any
        if best_file:
            description_file_path = best_file
            info_directory = [description_file_path]
            logging.info(f"Fuzzy matching found description file: {description_file_path}, score: {best_score}")
            found_description = True
    if not found_description:
        logging.warning(f"No matching description file found for object '{object_name}'.")
    # Generate text configurations for all variants
    try:
        poster_text_configs_raw = content_gen_instance.run(info_directory, variants, loaded_content_list, system_prompt, timeout=timeout)
        if not poster_text_configs_raw:
            logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.")
            return False
        # --- Save the poster config via the OutputHandler ---
        output_handler.handle_poster_configs(run_id, topic_index, poster_text_configs_raw)
        # Log the raw configuration data for debugging
        logging.info(f"Generated poster configuration data: {poster_text_configs_raw}")
        # Use the configuration data directly instead of re-reading it from a file
        if isinstance(poster_text_configs_raw, list):
            poster_configs = poster_text_configs_raw
            logging.info(f"Using the generated config list directly; it contains {len(poster_configs)} items")
        else:
            # Not a list: parse it with the PosterConfig class
            logging.info("Generated config data is not a list; processing it with the PosterConfig class")
            poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
            poster_configs = poster_config_summary.get_config()
    except Exception as e:
        logging.exception("Error running ContentGenerator or parsing poster configs:")
        traceback.print_exc()
        return False
    # Poster generation loop, one pass per variant
    poster_num = min(variants, len(poster_configs)) if isinstance(poster_configs, list) else variants
    logging.info(f"Planning to generate {poster_num} poster variants")
    any_poster_attempted = False
    for j_index in range(poster_num):
        variant_index = j_index + 1
        logging.info(f"Generating Poster {variant_index}/{poster_num}...")
        any_poster_attempted = True
        collage_img = None  # Holds the generated collage PIL Image
        poster_img = None   # Holds the final poster PIL Image
        try:
            # Fetch the config for the current variant
            if isinstance(poster_configs, list) and j_index < len(poster_configs):
                poster_config = poster_configs[j_index]
                logging.info(f"Using config item {j_index+1}: {poster_config}")
            else:
                # Fallback: use the PosterConfig class
                poster_config = poster_config_summary.get_config_by_index(j_index)
                logging.info(f"Fetched config item {j_index+1} via the PosterConfig class")
            if not poster_config:
                logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.")
                continue
            # --- Image collage ---
            logging.info(f"Generating collage from: {input_img_dir_path}")
            collage_images, used_image_filenames = core_simple_collage.process_directory(
                input_img_dir_path,
                style=collage_style,
                target_size=poster_target_size,
                output_count=1
            )
            if not collage_images:  # Check whether the list is empty
                logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.")
                continue
            collage_img = collage_images[0]  # The first PIL Image
            used_image_files = used_image_filenames[0] if used_image_filenames else []  # Filenames of the images used
            logging.info(f"Collage image generated successfully (in memory). Used images: {used_image_files}")
            # --- Save the collage image and its source-image metadata via the handler ---
            output_handler.handle_generated_image(
                run_id, topic_index, variant_index,
                image_type='collage',
                image_data=collage_img,
                output_filename='collage.png',
                metadata={'used_images': used_image_files}  # Record which source images were used
            )
            # --- Create the poster ---
            # Add the title with probability title_possibility
            if random.random() < title_possibility:
                text_data = {
                    "title": poster_config.get('main_title', ''),
                    "subtitle": "",
                    "additional_texts": []
                }
                texts = poster_config.get('texts', [])
                if texts:
                    # Add secondary text with probability text_possibility
                    if random.random() < text_possibility:
                        text_data["additional_texts"].append({"text": texts[0], "position": "middle", "size_factor": 0.8})
            else:
                text_data = {
                    "title": "",
                    "subtitle": "",
                    "additional_texts": []
                }
            # Log the text data being sent
            logging.info(f"Text data: {text_data}")
            # Call create_poster, which returns a PIL Image
            poster_img = poster_gen_instance.create_poster(collage_img, text_data)
            if poster_img:
                logging.info("Poster image generated successfully (in memory).")
                # --- Save the poster image via the handler ---
                output_handler.handle_generated_image(
                    run_id, topic_index, variant_index,
                    image_type='poster',
                    image_data=poster_img,
                    output_filename=output_poster_filename,  # Filename from the parameters
                    metadata={'used_collage': True, 'collage_images': used_image_files}
                )
            else:
                logging.warning(f"Warning: Poster generation function returned None for variant {variant_index}.")
        except Exception as e:
            logging.exception(f"Error during poster generation for Variant {variant_index}:")
            traceback.print_exc()
            continue
    return any_poster_attempted
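
# Illustrative wiring of the decoupled units above (the real orchestration lives
# in main.py; agent, pm, handler, config and the poster kwargs are placeholders):
# run_id, topics, _, _ = run_topic_generation_pipeline(config)
# for idx, topic in enumerate(topics, start=1):
#     generate_content_for_topic(agent, pm, topic, run_id, idx, handler,
#                                variants=2, temperature=0.3, top_p=0.4,
#                                presence_penalty=1.5, enable_content_judge=False)
#     generate_posters_for_topic(topic, output_dir, run_id, idx, handler, ...)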