TravelContentCreator/utils/poster/poster_generator.py

132 lines
5.0 KiB
Python
Raw Normal View History

2025-07-10 10:08:03 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报生成流程总协调器
"""
import logging
import os
from typing import Dict, Any, Optional
import sys
from PIL import Image
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 核心模块依赖
from core.ai import AIAgent
from core.config import PosterConfig
from utils.file_io import OutputManager, ResourceLoader
# 本地模块依赖
from .job import PosterJob
from .templates.vibrant_template import VibrantTemplate
from .templates.business_template import BusinessTemplate
from .templates.collage_template import CollageTemplate
from .text_generator import PosterContentGenerator
from .utils import ImageProcessor
logger = logging.getLogger(__name__)
class PosterGenerator:
"""
负责根据配置选择和调用不同的海报模板并协调整个生成流程
"""
def __init__(self, config: PosterConfig, ai_agent: AIAgent, output_manager: OutputManager):
"""
初始化海报生成器
Args:
config (PosterConfig): 海报生成相关的配置模型
ai_agent (AIAgent): AI代理实例
output_manager (OutputManager): 文件输出处理器
"""
self.config = config
self.ai_agent = ai_agent
self.output_manager = output_manager
self.content_generator = PosterContentGenerator(ai_agent)
self.logger = logging.getLogger(__name__)
# 模板注册表
self.templates = {
"vibrant": VibrantTemplate,
"business": BusinessTemplate,
"collage": CollageTemplate,
}
async def generate(self, run_id: str, topic_index: int, variant_index: int, job: PosterJob) -> Optional[str]:
"""
执行单个海报生成任务
Args:
run_id (str): 当前运行的ID
topic_index (int): 主题索引
variant_index (int): 变体索引
job (PosterJob): 包含任务详情的配置对象
Returns:
Optional[str]: 成功则返回生成海报的路径否则返回None
"""
self.logger.info(f"开始处理海报任务: template={job.template}, topic={topic_index}, variant={variant_index}")
# 1. 选择模板
template_class = self.templates.get(job.template)
if not template_class:
self.logger.error(f"未知的海报模板: {job.template}")
return None
template_instance = template_class(size=job.size)
# 2. 准备生成参数
generation_params = job.params.copy()
# 3. (可选) AI生成文本
if job.generate_text:
text_params = job.text_generation_params
system_prompt = ResourceLoader.load_text_file(text_params.system_prompt_path or self.config.poster_system_prompt)
user_prompt = ResourceLoader.load_text_file(text_params.user_prompt_path or self.config.poster_user_prompt)
if not system_prompt or not user_prompt:
self.logger.error("AI文本生成失败无法加载系统或用户提示词。")
return None
generated_content = await self.content_generator.generate_text_for_poster(
system_prompt=system_prompt,
user_prompt=user_prompt,
context_data={"content_data": text_params.content_data},
temperature=self.config.model.temperature,
top_p=self.config.model.top_p
)
if not generated_content:
self.logger.error("AI未能生成有效的文本内容。")
return None
# 将AI生成的内容合并到参数中
generation_params.update(generated_content)
# 4. 执行模板生成
try:
self.logger.info(f"使用模板 '{job.template}' 生成图像...")
poster_image = template_instance.generate(**generation_params)
if poster_image:
# 5. (可选) 应用哈希干扰
if self.config.anti_duplicate_hash:
self.logger.info(f"为海报应用哈希干扰...")
poster_image = ImageProcessor.apply_strategic_hash_disruption(poster_image)
# 6. 保存最终图像
output_filename = f"poster_{topic_index}_{variant_index}_{job.template}.png"
# 使用OutputManager来获取正确的保存路径
variant_dir = self.output_manager.get_variant_dir(topic_index, variant_index)
output_path = variant_dir / output_filename
ImageProcessor.save_image(poster_image, str(output_path))
self.logger.info(f"成功生成并保存海报: {output_path}")
return str(output_path)
else:
self.logger.error(f"模板 '{job.template}' 未能生成图像。")
return None
except Exception as e:
self.logger.critical(f"海报生成过程中发生意外错误: {e}", exc_info=True)
return None