132 lines
5.0 KiB
Python
132 lines
5.0 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
海报生成流程总协调器
|
|||
|
|
"""
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
from typing import Dict, Any, Optional
|
|||
|
|
import sys
|
|||
|
|
from PIL import Image
|
|||
|
|
|
|||
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
# 核心模块依赖
|
|||
|
|
from core.ai import AIAgent
|
|||
|
|
from core.config import PosterConfig
|
|||
|
|
from utils.file_io import OutputManager, ResourceLoader
|
|||
|
|
|
|||
|
|
# 本地模块依赖
|
|||
|
|
from .job import PosterJob
|
|||
|
|
from .templates.vibrant_template import VibrantTemplate
|
|||
|
|
from .templates.business_template import BusinessTemplate
|
|||
|
|
from .templates.collage_template import CollageTemplate
|
|||
|
|
from .text_generator import PosterContentGenerator
|
|||
|
|
from .utils import ImageProcessor
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
class PosterGenerator:
|
|||
|
|
"""
|
|||
|
|
负责根据配置选择和调用不同的海报模板,并协调整个生成流程。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, config: PosterConfig, ai_agent: AIAgent, output_manager: OutputManager):
|
|||
|
|
"""
|
|||
|
|
初始化海报生成器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
config (PosterConfig): 海报生成相关的配置模型
|
|||
|
|
ai_agent (AIAgent): AI代理实例
|
|||
|
|
output_manager (OutputManager): 文件输出处理器
|
|||
|
|
"""
|
|||
|
|
self.config = config
|
|||
|
|
self.ai_agent = ai_agent
|
|||
|
|
self.output_manager = output_manager
|
|||
|
|
self.content_generator = PosterContentGenerator(ai_agent)
|
|||
|
|
self.logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
# 模板注册表
|
|||
|
|
self.templates = {
|
|||
|
|
"vibrant": VibrantTemplate,
|
|||
|
|
"business": BusinessTemplate,
|
|||
|
|
"collage": CollageTemplate,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async def generate(self, run_id: str, topic_index: int, variant_index: int, job: PosterJob) -> Optional[str]:
|
|||
|
|
"""
|
|||
|
|
执行单个海报生成任务
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
run_id (str): 当前运行的ID
|
|||
|
|
topic_index (int): 主题索引
|
|||
|
|
variant_index (int): 变体索引
|
|||
|
|
job (PosterJob): 包含任务详情的配置对象
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Optional[str]: 成功则返回生成海报的路径,否则返回None
|
|||
|
|
"""
|
|||
|
|
self.logger.info(f"开始处理海报任务: template={job.template}, topic={topic_index}, variant={variant_index}")
|
|||
|
|
|
|||
|
|
# 1. 选择模板
|
|||
|
|
template_class = self.templates.get(job.template)
|
|||
|
|
if not template_class:
|
|||
|
|
self.logger.error(f"未知的海报模板: {job.template}")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
template_instance = template_class(size=job.size)
|
|||
|
|
|
|||
|
|
# 2. 准备生成参数
|
|||
|
|
generation_params = job.params.copy()
|
|||
|
|
|
|||
|
|
# 3. (可选) AI生成文本
|
|||
|
|
if job.generate_text:
|
|||
|
|
text_params = job.text_generation_params
|
|||
|
|
system_prompt = ResourceLoader.load_text_file(text_params.system_prompt_path or self.config.poster_system_prompt)
|
|||
|
|
user_prompt = ResourceLoader.load_text_file(text_params.user_prompt_path or self.config.poster_user_prompt)
|
|||
|
|
|
|||
|
|
if not system_prompt or not user_prompt:
|
|||
|
|
self.logger.error("AI文本生成失败:无法加载系统或用户提示词。")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
generated_content = await self.content_generator.generate_text_for_poster(
|
|||
|
|
system_prompt=system_prompt,
|
|||
|
|
user_prompt=user_prompt,
|
|||
|
|
context_data={"content_data": text_params.content_data},
|
|||
|
|
temperature=self.config.model.temperature,
|
|||
|
|
top_p=self.config.model.top_p
|
|||
|
|
)
|
|||
|
|
if not generated_content:
|
|||
|
|
self.logger.error("AI未能生成有效的文本内容。")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 将AI生成的内容合并到参数中
|
|||
|
|
generation_params.update(generated_content)
|
|||
|
|
|
|||
|
|
# 4. 执行模板生成
|
|||
|
|
try:
|
|||
|
|
self.logger.info(f"使用模板 '{job.template}' 生成图像...")
|
|||
|
|
poster_image = template_instance.generate(**generation_params)
|
|||
|
|
|
|||
|
|
if poster_image:
|
|||
|
|
# 5. (可选) 应用哈希干扰
|
|||
|
|
if self.config.anti_duplicate_hash:
|
|||
|
|
self.logger.info(f"为海报应用哈希干扰...")
|
|||
|
|
poster_image = ImageProcessor.apply_strategic_hash_disruption(poster_image)
|
|||
|
|
|
|||
|
|
# 6. 保存最终图像
|
|||
|
|
output_filename = f"poster_{topic_index}_{variant_index}_{job.template}.png"
|
|||
|
|
# 使用OutputManager来获取正确的保存路径
|
|||
|
|
variant_dir = self.output_manager.get_variant_dir(topic_index, variant_index)
|
|||
|
|
output_path = variant_dir / output_filename
|
|||
|
|
|
|||
|
|
ImageProcessor.save_image(poster_image, str(output_path))
|
|||
|
|
self.logger.info(f"成功生成并保存海报: {output_path}")
|
|||
|
|
return str(output_path)
|
|||
|
|
else:
|
|||
|
|
self.logger.error(f"模板 '{job.template}' 未能生成图像。")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
self.logger.critical(f"海报生成过程中发生意外错误: {e}", exc_info=True)
|
|||
|
|
return None
|