# TravelContentCreator/utils/tweet_generator.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
# 2025-04-22 14:16:29 +08:00
import traceback
# 2025-04-22 17:36:29 +08:00
import logging # Add logging
# 2025-04-27 10:38:30 +08:00
import re
# sys.path.append('/root/autodl-tmp') # No longer needed if running as a module or if path is set correctly
# 从本地模块导入
# from TravelContentCreator.core.ai_agent import AI_Agent # Remove project name prefix
# from TravelContentCreator.core.topic_parser import TopicParser # Remove project name prefix
# 2025-04-22 14:19:21 +08:00
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
# from TravelContentCreator.utils.prompt_manager import PromptManager # Remove project name prefix
# from ..core import contentGen as core_contentGen # Change to absolute import
# from ..core import posterGen as core_posterGen # Change to absolute import
# from ..core import simple_collage as core_simple_collage # Change to absolute import
from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser
from utils.prompt_manager import PromptManager # Keep this as it's importing from the same level package 'utils'
from utils import content_generator as core_contentGen
# 2025-04-25 10:11:45 +08:00
from core import poster_gen as core_posterGen
from core import simple_collage as core_simple_collage
from .output_handler import OutputHandler # <-- 添加导入
# 2025-05-10 20:12:49 +08:00
from utils.content_judger import ContentJudger # <-- 添加ContentJudger导入
class tweetTopic:
    """Plain record describing one generated tweet topic.

    Every constructor argument is stored on the instance under the
    same name; the class carries topic metadata only and has no logic.
    """

    def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic):
        # Copy all constructor arguments onto the instance in one shot.
        fields = dict(locals())
        del fields["self"]
        vars(self).update(fields)
class tweetTopicRecord:
    """Couples a parsed topic batch with the prompts and run id that produced it."""

    def __init__(self, topics_list, system_prompt, user_prompt, run_id):
        # Keep everything needed to reproduce or audit a topic-generation run.
        self.run_id = run_id
        self.user_prompt = user_prompt
        self.system_prompt = system_prompt
        self.topics_list = topics_list
class tweetContent:
    """Holds one generated content variant and its parsed JSON payload.

    The raw AI ``result`` string is parsed into ``self.json_data``. On any
    parsing failure a placeholder dict with ``error=True`` is stored instead,
    so callers can always rely on the same keys being present.
    """

    def __init__(self, result, prompt, run_id, article_index, variant_index):
        self.result = result
        self.prompt = prompt
        self.run_id = run_id
        self.article_index = article_index
        self.variant_index = variant_index
        try:
            self.json_data = self.split_content(result)
        except Exception as e:
            # split_content already guards itself, but keep this belt-and-braces
            # handler so __init__ can never raise on malformed AI output.
            logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")
            logging.debug(f"Raw result: {result[:500]}...")  # Log partial raw result
            self.json_data = {
                "title": "",
                "content": "",
                "tag": "",
                "error": True,
                # Bug fix: store the message, not the exception object, so the
                # dict stays JSON-serializable and matches split_content's shape.
                "raw_result": str(e),
                "judge_success": False,
            }

    def split_content(self, result):
        """Parse the raw model output into a JSON dict.

        Strips an optional ``</think>`` reasoning prefix and optional markdown
        code fences before decoding. The returned dict always contains the
        ``error``, ``raw_result`` and ``judge_success`` keys.
        """
        try:
            processed_result = result
            if "</think>" in result:
                # Keep only the text after the reasoning block.
                processed_result = result.split("</think>", 1)[1]
            # Tolerate ```json ... ``` fences around the JSON payload.
            processed_result = re.sub(r"^```(?:json)?\s*|\s*```$", "", processed_result.strip())
            json_data = json.loads(processed_result)
            json_data["error"] = False
            json_data["raw_result"] = None
            # Guarantee the judge_success field exists for downstream consumers.
            if "judge_success" not in json_data:
                json_data["judge_success"] = None
            return json_data
        except Exception as e:
            logging.warning(f"解析内容时出错: {e}, 使用默认空内容")
            # Build a fresh placeholder dict instead of referencing undefined state.
            return {
                "title": "",
                "content": "",
                "error": True,
                "raw_result": str(e),
                "judge_success": False,
            }

    def get_json_data(self):
        """Returns the generated JSON data dictionary."""
        return self.json_data

    def get_prompt(self):
        """Returns the user prompt used to generate this content."""
        return self.prompt

    def get_content(self):
        """Return the parsed content body (empty string if parsing failed)."""
        # Bug fix: self.content never existed; read from the parsed dict instead.
        return self.json_data.get("content", "")

    def get_title(self):
        """Return the parsed title (empty string if parsing failed)."""
        # Bug fix: self.title never existed; read from the parsed dict instead.
        return self.json_data.get("title", "")
def generate_topics(ai_agent, system_prompt, user_prompt, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """Generate the topic list for a run.

    Calls the AI agent once with the topic prompts and delegates parsing of
    the raw output to TopicParser. Persistence is the caller's (OutputHandler)
    responsibility; this function only returns the parsed list.
    """
    logging.info("Starting topic generation...")
    started_at = time.time()
    # Non-streaming call; work() returns (result, tokens, time_cost).
    result, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )
    logging.info(f"Topic generation API call completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
    # Parse the raw model output into structured topics.
    parsed_topics = TopicParser.parse_topics(result, run_id)
    if not parsed_topics:
        logging.warning("Topic parsing resulted in an empty list.")
    # Return the parsed list directly; saving raw output is handled elsewhere.
    return parsed_topics
def generate_single_content(ai_agent, system_prompt, user_prompt, item, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
    """Generates single content variant data. Returns (content_json, user_prompt) or (None, None)."""
    logging.info(f"Generating content for topic {article_index}, variant {variant_index}")
    try:
        # Guard clause: both prompts are mandatory.
        if not system_prompt or not user_prompt:
            logging.error("System or User prompt is empty. Cannot generate content.")
            return None, None
        logging.debug(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
        # Small random delay to avoid hammering the API in tight loops.
        time.sleep(random.random() * 0.5)
        # Non-streaming call; work() returns (result, tokens, time_cost).
        result, tokens, time_cost = ai_agent.work(
            system_prompt, user_prompt, "", temperature, top_p, presence_penalty
        )
        if result is None:
            # The AI call itself failed; hand back a placeholder payload so the
            # caller can still pass something to the output handler.
            logging.error(f"AI agent work failed for {article_index}_{variant_index}. No result returned.")
            return {"title": "", "content": "", "error": True, "judge_success": False}, user_prompt
        logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        # tweetContent performs the JSON parsing and error wrapping; saving is
        # the OutputHandler's responsibility, so we only return the data pair.
        variant = tweetContent(result, user_prompt, run_id, article_index, variant_index)
        return variant.get_json_data(), variant.get_prompt()
    except Exception:
        logging.exception(f"Error generating single content for {article_index}_{variant_index}:")
        return {"title": "", "content": "", "error": True, "judge_success": False}, user_prompt
2025-04-17 16:14:41 +08:00
def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
                     variants=2, temperature=0.3, start_index=0, end_index=None):
    """Generate content variants for a slice of the topic list (legacy batch entry point).

    NOTE(review): this function looks legacy/unused — the active pipeline goes
    through generate_content_for_topic(). The call below does NOT match
    generate_single_content's current signature (no user_prompt argument is
    passed, so `item` lands in the user_prompt slot and the index arguments
    shift by one position). Confirm before relying on this function.
    """
    if not topics:
        logging.warning("没有选题,无法生成内容")
        return
    # Clamp the processing window [start_index, end_index) to the topic list.
    if end_index is None or end_index > len(topics):
        end_index = len(topics)
    topics_to_process = topics[start_index:end_index]
    logging.info(f"准备处理{len(topics_to_process)}个选题...")
    # Summary-file bookkeeping was removed when OutputHandler took over persistence.
    # summary_file = ResourceLoader.create_summary_file(output_dir, run_id, len(topics_to_process))
    # Process each topic in the window.
    processed_results = []
    for i, item in enumerate(topics_to_process):
        logging.info(f"处理第 {i+1}/{len(topics_to_process)} 篇文章")
        # Generate `variants` variants per topic.
        for j in range(variants):
            logging.info(f"正在生成变体 {j+1}/{variants}")
            # NOTE(review): argument order is wrong here — generate_single_content
            # expects (ai_agent, system_prompt, user_prompt, item, run_id,
            # article_index, variant_index, temperature, ...). As written, `item`
            # is consumed as the user prompt, run_id as the item, i+1 as run_id,
            # j+1 as article_index and temperature as variant_index.
            tweet_content, result = generate_single_content(
                ai_agent, system_prompt, item, run_id, i+1, j+1, temperature
            )
            if tweet_content:
                processed_results.append(tweet_content)
            # # Summary update for the first variant only (disabled legacy code):
            # if j == 0:
            #     ResourceLoader.update_summary(summary_file, i+1, user_prompt, result)
    logging.info(f"完成{len(processed_results)}篇文章生成")
    return processed_results
def prepare_topic_generation(prompt_manager: PromptManager,
                             api_url: str,
                             model_name: str,
                             api_key: str,
                             timeout: int,
                             max_retries: int):
    """Prepare the AI agent and prompts for topic generation.

    Args:
        prompt_manager: An initialized PromptManager instance.
        api_url, model_name, api_key, timeout, max_retries: AI_Agent parameters.

    Returns:
        (ai_agent, system_prompt, user_prompt), or (None, None, None) when the
        prompts are missing or the agent cannot be initialized.
    """
    logging.info("Preparing for topic generation (using provided PromptManager)...")
    # Fetch the topic prompts from the provided PromptManager.
    system_prompt, user_prompt = prompt_manager.get_topic_prompts()
    if not (system_prompt and user_prompt):
        logging.error("Failed to get topic generation prompts from PromptManager.")
        return None, None, None
    # Build the AI agent from the passed-in connection parameters.
    try:
        logging.info("Initializing AI Agent for topic generation...")
        agent = AI_Agent(
            api_url,
            model_name,
            api_key,
            timeout=timeout,
            max_retries=max_retries,
        )
    except Exception:
        logging.exception("Error initializing AI Agent for topic generation:")
        return None, None, None
    return agent, system_prompt, user_prompt
2025-04-22 15:39:35 +08:00
def run_topic_generation_pipeline(config, run_id=None):
    """
    Runs the complete topic generation pipeline based on the configuration.

    Builds a PromptManager from `config`, prepares the AI agent, generates the
    topic list and always closes the agent afterwards.

    Returns:
        (run_id, topics_list, system_prompt, user_prompt) on success, or
        (None, None, None, None) on failure.
    """
    logging.info("Starting Step 1: Topic Generation Pipeline...")
    if run_id is None:
        logging.info("No run_id provided, generating one based on timestamp.")
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        logging.info(f"Using provided run_id: {run_id}")
    ai_agent, system_prompt, user_prompt = None, None, None  # Initialize
    topics_list = None
    prompt_manager = None
    try:
        # --- PromptManager construction parameters ---
        topic_sys_prompt_path = config.get("topic_system_prompt")
        topic_user_prompt_path = config.get("topic_user_prompt")
        content_sys_prompt_path = config.get("content_system_prompt")  # unused here, but PromptManager may need it
        prompts_dir_path = config.get("prompts_dir")
        prompts_config = config.get("prompts_config")
        resource_config = config.get("resource_dir", [])
        topic_num = config.get("num", 1)
        topic_date = config.get("date", "")
        prompt_manager = PromptManager(
            topic_system_prompt_path=topic_sys_prompt_path,
            topic_user_prompt_path=topic_user_prompt_path,
            content_system_prompt_path=content_sys_prompt_path,
            prompts_dir=prompts_dir_path,
            prompts_config=prompts_config,
            resource_dir_config=resource_config,
            topic_gen_num=topic_num,
            topic_gen_date=topic_date
        )
        logging.info("PromptManager instance created.")
        # --- AI Agent parameters ---
        ai_api_url = config.get("api_url")
        ai_model = config.get("model")
        ai_api_key = config.get("api_key")
        ai_timeout = config.get("request_timeout", 30)
        ai_max_retries = config.get("max_retries", 3)
        # Fail fast when any mandatory AI parameter is missing.
        if not all([ai_api_url, ai_model, ai_api_key]):
            raise ValueError("Missing required AI configuration (api_url, model, api_key) in config.")
        ai_agent, system_prompt, user_prompt = prepare_topic_generation(
            prompt_manager,
            ai_api_url,
            ai_model,
            ai_api_key,
            ai_timeout,
            ai_max_retries
        )
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
        # --- Generate topics ---
        topics_list = generate_topics(
            ai_agent, system_prompt, user_prompt,
            run_id,
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            config.get("topic_presence_penalty", 1.5)
        )
    except Exception:
        logging.exception("Error during topic generation pipeline execution:")
        # Bug fix: do NOT close the agent here as well — returning from this
        # except still runs the finally block below, so the old code closed the
        # agent twice on the failure path.
        return None, None, None, None  # Signal failure
    finally:
        # Always release the AI agent, on both success and failure paths.
        if ai_agent:
            logging.info("Closing topic generation AI Agent...")
            ai_agent.close()
    if topics_list is None:  # Defensive: generate_topics currently returns a list
        logging.error("Topic generation failed (generate_topics returned None or an error occurred).")
        return None, None, None, None
    elif not topics_list:
        logging.warning(f"Topic generation completed for run {run_id}, but the resulting topic list is empty.")
        # Return the empty list and prompts anyway; the caller decides what to do.
    logging.info(f"Topic generation pipeline completed successfully. Run ID: {run_id}")
    # Return the raw data needed by the OutputHandler
    return run_id, topics_list, system_prompt, user_prompt
2025-04-22 14:16:29 +08:00
# --- Decoupled Functional Units (Moved from main.py) ---
def generate_content_for_topic(ai_agent: AI_Agent,
                               prompt_manager: PromptManager,
                               topic_item: dict,
                               run_id: str,
                               topic_index: int,
                               output_handler: OutputHandler,
                               variants: int,
                               temperature: float,
                               top_p: float,
                               presence_penalty: float,
                               enable_content_judge: bool):
    """Generates all content variants for a single topic item and uses OutputHandler.
    Args:
        ai_agent: Initialized AI_Agent instance.
        prompt_manager: Initialized PromptManager instance.
        topic_item: Dictionary representing the topic.
        run_id: Current run ID.
        topic_index: 1-based index of the topic.
        output_handler: An instance of OutputHandler to process results.
        variants: Number of variants to generate.
        temperature, top_p, presence_penalty: AI generation parameters.
        enable_content_judge: Whether to enable content judging (review/rewrite).
    Returns:
        bool: True if at least one variant was successfully generated and handled, False otherwise.
    """
    logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
    success_flag = False  # Track if any variant succeeded
    # `variants` now arrives as an explicit parameter instead of via config.
    # variants = config.get("variants", 1)
    # When content judging is enabled, gather the product/object reference
    # material the judger needs before generating any variants.
    product_info = None
    content_judger = None
    if enable_content_judge:
        logging.info(f"内容审核功能已启用,准备获取产品资料...")
        # Pull the product and object names off the topic item.
        product_name = topic_item.get("product", "")
        object_name = topic_item.get("object", "")
        # Accumulate both object and product material into one string.
        product_info = ""
        if object_name:
            # Mirror of the object-info lookup in PromptManager.get_content_prompts:
            # scan configured resource dirs for a file whose basename contains the object name.
            found_object_info = False
            all_description_files = []
            # Collect candidate files from all Object/Description resource dirs.
            for dir_info in prompt_manager.resource_dir_config:
                if dir_info.get("type") in ["Object", "Description"]:
                    all_description_files.extend(dir_info.get("file_path", []))
            # Exact (substring) match on the file basename.
            for file_path in all_description_files:
                if object_name in os.path.basename(file_path):
                    from utils.resource_loader import ResourceLoader
                    info = ResourceLoader.load_file_content(file_path)
                    if info:
                        product_info += f"Object: {object_name}\n{info}\n\n"
                        logging.info(f"为内容审核找到对象'{object_name}'的资源文件: {file_path}")
                        found_object_info = True
                        break
            # Missing object material is logged but not fatal.
            if not found_object_info:
                logging.warning(f"未能为内容审核找到对象'{object_name}'的资源文件")
        # Same lookup for product material.
        if product_name:
            found_product_info = False
            all_product_files = []
            for dir_info in prompt_manager.resource_dir_config:
                if dir_info.get("type") == "Product":
                    all_product_files.extend(dir_info.get("file_path", []))
            for file_path in all_product_files:
                if product_name in os.path.basename(file_path):
                    from utils.resource_loader import ResourceLoader
                    info = ResourceLoader.load_file_content(file_path)
                    if info:
                        product_info += f"Product: {product_name}\n{info}\n\n"
                        logging.info(f"为内容审核找到产品'{product_name}'的资源文件: {file_path}")
                        found_product_info = True
                        break
            # Missing product material is logged but not fatal.
            if not found_product_info:
                logging.warning(f"未能为内容审核找到产品'{product_name}'的资源文件")
        # Initialize the ContentJudger only if we found any reference material.
        if product_info:
            logging.info("成功获取产品资料初始化ContentJudger...")
            # Reuses the main ai_agent for judging to avoid a second model client.
            # NOTE(review): reaches into PromptManager's private cache — confirm
            # the "judger_system_prompt" key is actually populated there.
            content_judger_system_prompt_path = prompt_manager._system_prompt_cache.get("judger_system_prompt")
            content_judger = ContentJudger(ai_agent, system_prompt_path=content_judger_system_prompt_path)
        else:
            logging.warning("未能获取产品资料,内容审核功能将被跳过")
            enable_content_judge = False
    for j in range(variants):
        variant_index = j + 1
        logging.info(f" Generating Variant {variant_index}/{variants}...")
        # Prompts are rebuilt per variant from the shared PromptManager.
        content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item)
        if not content_system_prompt or not content_user_prompt:
            logging.warning(f" Skipping Variant {variant_index} due to missing content prompts.")
            continue
        time.sleep(random.random() * 0.5)
        try:
            # Call generate_single_content with the passed-in generation parameters.
            content_json, prompt_data = generate_single_content(
                ai_agent, content_system_prompt, content_user_prompt, topic_item,
                run_id, topic_index, variant_index,
                temperature,
                top_p,
                presence_penalty
            )
            # Any non-None payload is handled — even empty title/content counts
            # as a valid (if degraded) result.
            if content_json is not None:
                # Run the content judge when enabled and properly initialized.
                if enable_content_judge and content_judger and product_info:
                    logging.info(f" 对Topic {topic_index}, Variant {variant_index}进行内容审核...")
                    # Flatten title + content into the judger's expected text format.
                    content_to_judge = f"""title: {content_json.get('title', '')}
content: {content_json.get('content', '')}
"""
                    try:
                        judged_result = content_judger.judge_content(product_info, content_to_judge)
                        if judged_result and isinstance(judged_result, dict):
                            if "title" in judged_result and "content" in judged_result:
                                # Keep the pre-judge title/content for traceability.
                                content_json["original_title"] = content_json.get("title", "")
                                content_json["original_content"] = content_json.get("content", "")
                                # Preserve original tags ("tags" preferred over legacy "tag").
                                original_tags = content_json.get("tags", content_json.get("tag", ""))
                                content_json["original_tags"] = original_tags
                                # Swap in the judged title/content.
                                content_json["title"] = judged_result["title"]
                                content_json["content"] = judged_result["content"]
                                # Keep the original tags and drop the legacy duplicate key.
                                content_json["tags"] = original_tags
                                if "tag" in content_json:
                                    del content_json["tag"]
                                # Mark that judging ran and record its outcome.
                                content_json["judged"] = True
                                content_json["judge_success"] = judged_result.get("judge_success", False)
                                # Optionally keep the judge's analysis text
                                # (key is the judger's Chinese field name).
                                if "不良内容分析" in judged_result:
                                    content_json["judge_analysis"] = judged_result["不良内容分析"]
                            else:
                                logging.warning(f" 审核结果缺少title或content字段保留原内容")
                                content_json["judge_success"] = False
                        else:
                            logging.warning(f" 内容审核返回无效结果,保留原内容")
                            content_json["judge_success"] = False
                    except Exception as judge_err:
                        # Judging failures never discard the generated content.
                        logging.exception(f" 内容审核过程出错: {judge_err},保留原内容")
                        content_json["judge_success"] = False
                else:
                    # Judging disabled (or unavailable): tag the payload accordingly.
                    content_json["judged"] = False
                    content_json["judge_success"] = None
                # Persist/handle the variant through the OutputHandler.
                output_handler.handle_content_variant(
                    run_id, topic_index, variant_index, content_json, prompt_data or ""
                )
                success_flag = True  # Mark success for this topic
                # A payload flagged error=True means the AI responded but its
                # output could not be parsed; it is still handled with empty values.
                if content_json.get("error"):
                    logging.warning(f" Content generation for Topic {topic_index}, Variant {variant_index} succeeded but response parsing had issues. Using empty content values.")
                else:
                    logging.info(f" Successfully generated and handled content for Topic {topic_index}, Variant {variant_index}.")
            else:
                logging.error(f" Content generation failed for Topic {topic_index}, Variant {variant_index}. Skipping handling.")
        except Exception as e:
            logging.exception(f" Error during content generation call or handling for Topic {topic_index}, Variant {variant_index}:")
    # Return the success flag for this topic
    return success_flag
def generate_posters_for_topic(topic_item: dict,
output_dir: str,
run_id: str,
topic_index: int,
output_handler: OutputHandler, # 添加 handler
variants: int,
model_name: str,
base_url: str,
api_key: str,
poster_assets_base_dir: str,
image_base_dir: str,
2025-04-24 19:09:50 +08:00
img_frame_possibility: float,
text_bg_possibility: float,
2025-04-25 17:35:43 +08:00
title_possibility: float,
text_possibility: float,
resource_dir_config: list,
poster_target_size: tuple,
output_collage_subdir: str,
output_poster_subdir: str,
output_poster_filename: str,
system_prompt: str,
collage_style: str,
timeout: int
):
"""Generates all posters for a single topic item, handling image data via OutputHandler.
Args:
topic_item: The dictionary representing a single topic.
output_dir: The base output directory for the entire run (e.g., ./result).
run_id: The ID for the current run.
topic_index: The 1-based index of the current topic.
variants: Number of variants.
poster_assets_base_dir: Path to poster assets (fonts, frames etc.).
image_base_dir: Base path for source images.
img_frame_possibility: Probability of adding image frame.
text_bg_possibility: Probability of adding text background.
resource_dir_config: Configuration for resource directories (used for Description).
poster_target_size: Target size tuple (width, height) for the poster.
text_possibility: Probability of adding secondary text.
output_collage_subdir: Subdirectory name for saving collages.
output_poster_subdir: Subdirectory name for saving posters.
output_poster_filename: Filename for the final poster.
2025-04-24 18:57:05 +08:00
system_prompt: System prompt for content generation.
output_handler: An instance of OutputHandler to process results.
timeout: Timeout for content generation.
Returns:
True if poster generation was attempted (regardless of individual variant success),
False if setup failed before attempting variants.
"""
2025-04-22 17:36:29 +08:00
logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
# --- Load content data from files ---
loaded_content_list = []
logging.info(f"Attempting to load content data for {variants} variants for topic {topic_index}...")
for j in range(variants):
variant_index = j + 1
variant_dir = os.path.join(output_dir, run_id, f"{topic_index}_{variant_index}")
content_path = os.path.join(variant_dir, "article.json")
try:
if os.path.exists(content_path):
with open(content_path, 'r', encoding='utf-8') as f_content:
content_data = json.load(f_content)
if isinstance(content_data, dict) and 'title' in content_data and 'content' in content_data:
loaded_content_list.append(content_data)
logging.debug(f" Successfully loaded content from: {content_path}")
else:
logging.warning(f" Content file {content_path} has invalid format. Skipping.")
else:
logging.warning(f" Content file not found for variant {variant_index}: {content_path}. Skipping.")
except json.JSONDecodeError:
logging.error(f" Error decoding JSON from content file: {content_path}. Skipping.")
except Exception as e:
logging.exception(f" Error loading content file {content_path}: {e}")
if not loaded_content_list:
logging.error(f"No valid content data loaded for topic {topic_index}. Cannot generate posters.")
return False
logging.info(f"Successfully loaded content data for {len(loaded_content_list)} variants.")
# --- End Load content data ---
# Initialize generators using parameters
try:
content_gen_instance = core_contentGen.ContentGenerator(model_name=model_name, base_url=base_url, api_key=api_key)
if not poster_assets_base_dir:
logging.error("Error: 'poster_assets_base_dir' not provided. Cannot generate posters.")
return False
poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir)
2025-04-24 19:09:50 +08:00
poster_gen_instance.set_img_frame_possibility(img_frame_possibility)
poster_gen_instance.set_text_bg_possibility(text_bg_possibility)
except Exception as e:
2025-04-22 17:36:29 +08:00
logging.exception("Error initializing generators for poster creation:")
2025-05-10 20:12:49 +08:00
return False
# --- Setup: Paths and Object Name ---
object_name = topic_item.get("object", "")
if not object_name:
2025-04-22 17:36:29 +08:00
logging.warning("Warning: Topic object name is missing. Cannot generate posters.")
return False
# Clean object name
try:
object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
if not object_name_cleaned:
2025-04-22 17:36:29 +08:00
logging.warning(f"Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.")
return False
object_name = object_name_cleaned
except Exception as e:
2025-04-22 17:36:29 +08:00
logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
return False
# Construct and check INPUT image paths
input_img_dir_path = os.path.join(image_base_dir, object_name)
if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
2025-04-27 10:38:30 +08:00
# 模糊匹配:如果找不到完全匹配的目录,尝试查找包含关键词的目录
logging.info(f"尝试对图片目录进行模糊匹配: {object_name}")
found_dir = None
# 1. 尝试获取image_base_dir下的所有目录
try:
all_dirs = [d for d in os.listdir(image_base_dir)
if os.path.isdir(os.path.join(image_base_dir, d))]
logging.info(f"找到 {len(all_dirs)} 个图片目录可用于模糊匹配")
# 2. 提取对象名称中的关键词
# 例如:"美的鹭湖鹭栖台酒店+盈香心动乐园" -> ["美的", "鹭湖", "酒店", "乐园"]
# 首先通过常见分隔符分割(+、空格、_、-等)
parts = re.split(r'[+\s_\-]', object_name)
keywords = []
for part in parts:
# 只保留长度大于1的有意义关键词
if len(part) > 1:
keywords.append(part)
# 尝试匹配更短的语义单元例如中文的2-3个字的词语
# 对于中文名称可以尝试提取2-3个字的短语
for i in range(len(object_name) - 1):
keyword = object_name[i:i+2] # 提取2个字符
if len(keyword) == 2 and all('\u4e00' <= c <= '\u9fff' for c in keyword):
keywords.append(keyword)
# 3. 对每个目录进行评分
dir_scores = {}
for directory in all_dirs:
score = 0
dir_lower = directory.lower()
# 为每个匹配的关键词增加分数
for keyword in keywords:
if keyword.lower() in dir_lower:
score += 1
# 如果得分大于0至少匹配一个关键词记录该目录
if score > 0:
dir_scores[directory] = score
# 4. 选择得分最高的目录
if dir_scores:
best_match = max(dir_scores.items(), key=lambda x: x[1])
found_dir = best_match[0]
logging.info(f"模糊匹配成功!匹配目录: {found_dir},匹配分数: {best_match[1]}")
# 更新图片目录路径
input_img_dir_path = os.path.join(image_base_dir, found_dir)
logging.info(f"使用模糊匹配的图片目录: {input_img_dir_path}")
else:
logging.warning(f"模糊匹配未找到任何包含关键词的目录")
except Exception as e:
logging.warning(f"模糊匹配过程中出错: {e}")
# 如果仍然无法找到有效目录,则返回错误
if not found_dir or not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
logging.warning(f"Warning: 即使通过模糊匹配也无法找到图片目录: '{input_img_dir_path}'. Skipping posters for this topic.")
return False
# Locate Description File using resource_dir_config parameter
info_directory = []
description_file_path = None
found_description = False
2025-04-27 10:38:30 +08:00
# 准备关键词列表用于模糊匹配
# 与上面图片目录匹配类似,提取对象名称的关键词
parts = re.split(r'[+\s_\-]', object_name)
keywords = []
for part in parts:
if len(part) > 1:
keywords.append(part)
# 尝试提取中文短语作为关键词
for i in range(len(object_name) - 1):
keyword = object_name[i:i+2]
if len(keyword) == 2 and all('\u4e00' <= c <= '\u9fff' for c in keyword):
keywords.append(keyword)
logging.info(f"用于描述文件模糊匹配的关键词: {keywords}")
# 尝试精确匹配
for dir_info in resource_dir_config:
if dir_info.get("type") == "Description":
for file_path in dir_info.get("file_path", []):
if object_name in os.path.basename(file_path):
description_file_path = file_path
if os.path.exists(description_file_path):
info_directory = [description_file_path]
2025-04-27 10:38:30 +08:00
logging.info(f"找到并使用精确匹配的描述文件: {description_file_path}")
found_description = True
else:
2025-04-27 10:38:30 +08:00
logging.warning(f"Warning: 配置中指定的描述文件未找到: {description_file_path}")
break
if found_description:
2025-04-27 10:38:30 +08:00
break
# 如果精确匹配失败,尝试模糊匹配
if not found_description:
logging.info(f"未找到'{object_name}'的精确匹配描述文件,尝试模糊匹配...")
best_score = 0
best_file = None
for dir_info in resource_dir_config:
if dir_info.get("type") == "Description":
for file_path in dir_info.get("file_path", []):
file_name = os.path.basename(file_path)
score = 0
# 计算关键词匹配分数
for keyword in keywords:
if keyword.lower() in file_name.lower():
score += 1
# 如果当前文件得分更高,更新最佳匹配
if score > best_score and os.path.exists(file_path):
best_score = score
best_file = file_path
# 如果找到了最佳匹配文件
if best_file:
description_file_path = best_file
info_directory = [description_file_path]
logging.info(f"模糊匹配找到描述文件: {description_file_path},匹配分数: {best_score}")
found_description = True
if not found_description:
2025-04-27 10:38:30 +08:00
logging.warning(f"未找到对象'{object_name}'的匹配描述文件。")
# Generate Text Configurations for All Variants
try:
poster_text_configs_raw = content_gen_instance.run(info_directory, variants, loaded_content_list, system_prompt, timeout=timeout)
if not poster_text_configs_raw:
2025-04-22 17:36:29 +08:00
logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.")
return False
2025-04-22 17:46:21 +08:00
# --- 使用 OutputHandler 保存 Poster Config ---
output_handler.handle_poster_configs(run_id, topic_index, poster_text_configs_raw)
2025-04-22 21:26:56 +08:00
# --- 结束使用 Handler 保存 ---
# 打印原始配置数据以进行调试
logging.info(f"生成的海报配置数据: {poster_text_configs_raw}")
2025-04-22 17:46:21 +08:00
2025-04-22 21:26:56 +08:00
# 直接使用配置数据,避免通过文件读取
if isinstance(poster_text_configs_raw, list):
poster_configs = poster_text_configs_raw
logging.info(f"直接使用生成的配置列表,包含 {len(poster_configs)} 个配置项")
else:
# 如果不是列表尝试转换或使用PosterConfig类解析
logging.info("生成的配置数据不是列表使用PosterConfig类进行处理")
poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
poster_configs = poster_config_summary.get_config()
except Exception as e:
logging.exception("Error running ContentGenerator or parsing poster configs:")
traceback.print_exc()
return False
# Poster Generation Loop for each variant
poster_num = min(variants, len(poster_configs)) if isinstance(poster_configs, list) else variants
logging.info(f"计划生成 {poster_num} 个海报变体")
any_poster_attempted = False
for j_index in range(poster_num):
variant_index = j_index + 1
logging.info(f"Generating Poster {variant_index}/{poster_num}...")
any_poster_attempted = True
collage_img = None # To store the generated collage PIL Image
poster_img = None # To store the final poster PIL Image
try:
# 获取当前变体的配置
if isinstance(poster_configs, list) and j_index < len(poster_configs):
poster_config = poster_configs[j_index]
logging.info(f"使用配置数据项 {j_index+1}: {poster_config}")
else:
# 回退方案使用PosterConfig类
poster_config = poster_config_summary.get_config_by_index(j_index)
logging.info(f"使用PosterConfig类获取配置项 {j_index+1}")
if not poster_config:
logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.")
continue
# --- Image Collage ---
logging.info(f"Generating collage from: {input_img_dir_path}")
collage_images, used_image_filenames = core_simple_collage.process_directory(
input_img_dir_path,
style=collage_style,
target_size=poster_target_size,
output_count=1
)
if not collage_images: # 检查列表是否为空
2025-04-22 17:36:29 +08:00
logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.")
continue
collage_img = collage_images[0] # 获取第一个 PIL Image
used_image_files = used_image_filenames[0] if used_image_filenames else [] # 获取使用的图片文件名
logging.info(f"Collage image generated successfully (in memory). Used images: {used_image_files}")
logging.info(f"拼贴图使用的图片文件: {used_image_files}")
# --- 使用 Handler 保存 Collage 图片和使用的图片文件信息 ---
output_handler.handle_generated_image(
run_id, topic_index, variant_index,
image_type='collage',
image_data=collage_img,
output_filename='collage.png', # 或者其他期望的文件名
metadata={'used_images': used_image_files} # 添加图片文件信息到元数据
)
# --- 结束保存 Collage ---
# --- Create Poster ---
if random.random() > title_possibility:
text_data = {
"title": poster_config.get('main_title', ''),
"subtitle": "",
"additional_texts": []
}
texts = poster_config.get('texts', [])
if texts:
# 确保文本不为空
if random.random() > text_possibility:
text_data["additional_texts"].append({"text": texts[0], "position": "middle", "size_factor": 0.8})
# for text in texts:
# if random.random() < text_possibility:
# text_data["additional_texts"].append({"text": text, "position": "middle", "size_factor": 0.8})
else:
text_data = {
"title": "",
"subtitle": "",
"additional_texts": []
}
# 打印要发送的文本数据
logging.info(f"文本数据: {text_data}")
# 调用修改后的 create_poster, 接收 PIL Image
poster_img = poster_gen_instance.create_poster(collage_img, text_data)
if poster_img:
logging.info(f"Poster image generated successfully (in memory).")
# --- 使用 Handler 保存 Poster 图片 ---
output_handler.handle_generated_image(
run_id, topic_index, variant_index,
image_type='poster',
image_data=poster_img,
output_filename=output_poster_filename, # 使用参数中的文件名
metadata={'used_collage': True, 'collage_images': used_image_files}
)
# --- 结束保存 Poster ---
else:
logging.warning(f"Warning: Poster generation function returned None for variant {variant_index}.")
except Exception as e:
logging.exception(f"Error during poster generation for Variant {variant_index}:")
traceback.print_exc()
continue
return any_poster_attempted