TravelContentCreator/utils/poster/poster_generator.py

132 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报生成流程总协调器
"""
import logging
import os
from typing import Dict, Any, Optional
import sys
from PIL import Image
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 核心模块依赖
from core.ai import AIAgent
from core.config import PosterConfig
from utils.file_io import OutputManager, ResourceLoader
# 本地模块依赖
from .job import PosterJob
from .templates.vibrant_template import VibrantTemplate
from .templates.business_template import BusinessTemplate
from .templates.collage_template import CollageTemplate
from .text_generator import PosterContentGenerator
from .utils import ImageProcessor
# Module-level logger. It is not referenced by PosterGenerator, which logs through
# self.logger; logging.getLogger(__name__) returns this same Logger object there.
logger: logging.Logger = logging.getLogger(__name__)
class PosterGenerator:
    """
    Coordinate the poster-generation pipeline.

    Selects a template class from a registry according to the job, optionally asks
    the AI content generator for poster text, renders the image, optionally applies
    hash disruption, and saves the result under the directory provided by the
    OutputManager.
    """

    def __init__(self, config: PosterConfig, ai_agent: AIAgent, output_manager: OutputManager):
        """
        Initialize the poster generator.

        Args:
            config (PosterConfig): Poster-generation configuration model (prompt paths,
                model sampling settings, anti-duplicate-hash flag).
            ai_agent (AIAgent): AI agent; used to build the PosterContentGenerator.
            output_manager (OutputManager): Resolves the per-topic/variant output directory.
        """
        self.config = config
        self.ai_agent = ai_agent
        self.output_manager = output_manager
        self.content_generator = PosterContentGenerator(ai_agent)
        self.logger = logging.getLogger(__name__)
        # Template registry: maps PosterJob.template names to template classes.
        # It is a plain per-instance dict, so entries can be added at runtime.
        self.templates = {
            "vibrant": VibrantTemplate,
            "business": BusinessTemplate,
            "collage": CollageTemplate,
        }

    async def generate(self, run_id: str, topic_index: int, variant_index: int, job: PosterJob) -> Optional[str]:
        """
        Run a single poster-generation job.

        Every failure is logged and reported as ``None`` rather than raised. Failures
        include an unknown template, unloadable prompts, empty AI output, a template
        that produces no image, and any unexpected exception raised while building
        the template, generating text, rendering or saving.

        Args:
            run_id (str): ID of the current run (not referenced in the method body).
            topic_index (int): Topic index; used for the output directory and file name.
            variant_index (int): Variant index; used for the output directory and file name.
            job (PosterJob): Job description (template name, size, params and
                text-generation settings).

        Returns:
            Optional[str]: Path of the saved poster on success, otherwise None.
        """
        self.logger.info(f"开始处理海报任务: template={job.template}, topic={topic_index}, variant={variant_index}")

        # 1. Select the template class.
        template_class = self.templates.get(job.template)
        if not template_class:
            self.logger.error(f"未知的海报模板: {job.template}")
            return None

        # Everything below runs inside a single error boundary. As a result, template
        # construction and AI text generation honour the documented "None on failure"
        # contract, just like rendering and saving.
        try:
            template_instance = template_class(size=job.size)

            # 2. Shallow copy so the update() below does not modify job.params.
            generation_params = job.params.copy()

            # 3. (Optional) AI-generated text; its keys override the static params.
            if job.generate_text:
                generated_content = await self._generate_text_content(job)
                if not generated_content:
                    return None
                generation_params.update(generated_content)

            # 4. Render the poster.
            # NOTE(review): rendering and saving are synchronous calls inside this
            # coroutine, so they block the event loop while they run.
            self.logger.info(f"使用模板 '{job.template}' 生成图像...")
            poster_image = template_instance.generate(**generation_params)
            if not poster_image:
                self.logger.error(f"模板 '{job.template}' 未能生成图像。")
                return None

            # 5. (Optional) Hash disruption.
            if self.config.anti_duplicate_hash:
                self.logger.info("为海报应用哈希干扰...")
                poster_image = ImageProcessor.apply_strategic_hash_disruption(poster_image)

            # 6. Save the final image.
            return self._save_poster(poster_image, topic_index, variant_index, job.template)
        except Exception as e:
            self.logger.critical(f"海报生成过程中发生意外错误: {e}", exc_info=True)
            return None

    async def _generate_text_content(self, job: PosterJob) -> Optional[Dict[str, Any]]:
        """
        Load the prompts for ``job`` and ask the AI content generator for poster text.

        Prompt paths fall back to the config's poster prompts when the job leaves
        them empty. The result is merged into the template's keyword arguments, so
        its keys are expected to be strings.

        Returns:
            The generator's output, or None (after logging an error) when either
            prompt cannot be loaded or the generator returns an empty result.
        """
        text_params = job.text_generation_params
        system_prompt = ResourceLoader.load_text_file(text_params.system_prompt_path or self.config.poster_system_prompt)
        user_prompt = ResourceLoader.load_text_file(text_params.user_prompt_path or self.config.poster_user_prompt)
        if not system_prompt or not user_prompt:
            self.logger.error("AI文本生成失败无法加载系统或用户提示词。")
            return None

        generated_content = await self.content_generator.generate_text_for_poster(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            context_data={"content_data": text_params.content_data},
            temperature=self.config.model.temperature,
            top_p=self.config.model.top_p,
        )
        if not generated_content:
            self.logger.error("AI未能生成有效的文本内容。")
            return None
        return generated_content

    def _save_poster(self, poster_image, topic_index: int, variant_index: int, template_name: str) -> str:
        """
        Save ``poster_image`` under the topic/variant output directory.

        The file is named ``poster_<topic>_<variant>_<template>.png``.

        Returns:
            str: Path of the saved file.
        """
        output_filename = f"poster_{topic_index}_{variant_index}_{template_name}.png"
        # The OutputManager decides where each topic/variant's files live.
        variant_dir = self.output_manager.get_variant_dir(topic_index, variant_index)
        output_path = variant_dir / output_filename
        ImageProcessor.save_image(poster_image, str(output_path))
        self.logger.info(f"成功生成并保存海报: {output_path}")
        return str(output_path)