321 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报服务层 - 重构版本
封装核心功能,支持基于模板的动态内容生成和海报创建
"""
import logging
import uuid
import time
import json
import importlib
import base64
from io import BytesIO
from typing import List, Dict, Any, Optional, Type, Union, cast
from datetime import datetime
from pathlib import Path
from PIL import Image
from core.config import ConfigManager, PosterConfig
from core.ai import AIAgent
from utils.file_io import OutputManager
from utils.image_processor import ImageProcessor
from poster.templates.base_template import BaseTemplate
from api.services.database_service import DatabaseService
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class PosterService:
    """Poster service.

    Loads poster template configuration (from the database, falling back to a
    built-in default set), lazily imports and caches one handler instance per
    template, obtains poster content (caller-supplied or LLM-generated from the
    template's prompts), renders posters through the handler and saves them.
    """

    def __init__(self, ai_agent: AIAgent, config_manager: ConfigManager, output_manager: OutputManager):
        """Initialise the service and load the template configuration.

        Args:
            ai_agent: Agent used to generate poster content with the LLM.
            config_manager: Configuration source; also passed to DatabaseService.
            output_manager: Resolves the output directory for saved posters.
        """
        self.ai_agent = ai_agent
        self.config_manager = config_manager
        self.output_manager = output_manager
        self.db_service = DatabaseService(config_manager)
        # template_id -> template configuration dict
        self._templates: Dict[str, Dict[str, Any]] = {}
        # template_id -> instantiated handler (filled lazily by _load_template_handler)
        self._template_instances: Dict[str, BaseTemplate] = {}
        # Not read by any method shown in this class; kept for compatibility.
        self._image_usage_tracker: Dict[Any, Any] = {}
        self._init_templates()

    def _init_templates(self):
        """Load template configuration from the database.

        Falls back to the built-in defaults when the database returns no
        templates or the lookup raises.
        """
        try:
            db_templates = self.db_service.get_active_poster_templates()
            if db_templates:
                self._templates = {t['id']: t for t in db_templates}
                logger.info(f"从数据库加载了 {len(self._templates)} 个模板")
            else:
                self._load_default_templates()
                logger.info("数据库无模板,使用默认模板配置")
        except Exception as e:
            logger.error(f"从数据库加载模板失败: {e}", exc_info=True)
            self._load_default_templates()

    def _load_default_templates(self):
        """Replace the template configuration with the built-in defaults."""
        self._templates = {
            'vibrant': {
                'id': 'vibrant',
                'name': '活力风格',
                'handler_path': 'poster.templates.vibrant_template',
                'class_name': 'VibrantTemplate',
                'description': '适合景点、活动等充满活力的场景',
                'is_active': True
            },
            'business': {
                'id': 'business',
                'name': '商务风格',
                'handler_path': 'poster.templates.business_template',
                'class_name': 'BusinessTemplate',
                'description': '适合酒店、房地产等商务场景',
                'is_active': True
            }
        }

    def _load_template_handler(self, template_id: str) -> Optional[BaseTemplate]:
        """Import, instantiate and cache the handler class for a template.

        The handler is located via the template's ``handler_path`` (module)
        and ``class_name`` entries. When the ``poster`` config provides a
        ``font_dir`` and the handler has ``set_font_dir``, it is applied.

        Args:
            template_id: Key into the loaded template configuration.

        Returns:
            The (cached) handler instance, or ``None`` if the template is
            unknown, lacks ``handler_path``/``class_name``, or cannot be
            imported/instantiated.
        """
        if template_id not in self._templates:
            logger.error(f"未找到模板: {template_id}")
            return None
        # Reuse the instance created on an earlier call.
        if template_id in self._template_instances:
            return self._template_instances[template_id]

        template_info = self._templates[template_id]
        handler_path = template_info.get('handler_path')
        class_name = template_info.get('class_name')
        if not handler_path or not class_name:
            logger.error(f"模板 {template_id} 缺少 handler_path 或 class_name")
            return None

        try:
            module = importlib.import_module(handler_path)
            template_class = getattr(module, class_name)
            template_instance = template_class()
            # PosterConfig is imported at module level; no local re-import needed.
            poster_config = self.config_manager.get_config('poster', PosterConfig)
            if poster_config:
                font_dir = poster_config.font_dir
                if font_dir and hasattr(template_instance, 'set_font_dir'):
                    template_instance.set_font_dir(font_dir)
            self._template_instances[template_id] = template_instance
            logger.info(f"成功加载模板处理器: {template_id} ({handler_path}.{class_name})")
            return template_instance
        except (ImportError, AttributeError, TypeError) as e:
            # TypeError covers a handler class whose constructor needs arguments;
            # treat it like any other load failure so callers just see None.
            logger.error(f"加载模板处理器失败: {e}", exc_info=True)
            return None

    def reload_templates(self):
        """Reload template configuration and drop all cached handler instances."""
        logger.info("重新加载模板信息...")
        self._init_templates()
        # Cached handlers may belong to changed/removed templates; rebuild lazily.
        self._template_instances = {}

    def get_available_templates(self) -> List[Dict[str, Any]]:
        """Return summary info (see get_template_info) for every active template."""
        result = []
        for tid, template in self._templates.items():
            if not template.get('is_active'):
                continue
            template_info = self.get_template_info(tid)
            if template_info:
                result.append(template_info)
        return result

    def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return a simplified, API-friendly view of one template.

        Missing optional fields get defaults instead of raising, since rows
        loaded from the database may not carry every key.

        Returns:
            The summary dict, or ``None`` if the template is unknown.
        """
        template = self._templates.get(template_id)
        if not template:
            return None
        return {
            "id": template.get("id", template_id),
            "name": template.get("name", template_id),
            "description": template.get("description", ""),
            "has_prompts": bool(template.get("system_prompt") and template.get("user_prompt_template")),
            "input_format": template.get("input_format", {}),
            "output_format": template.get("output_format", {}),
            "is_active": template.get("is_active", False)
        }

    async def generate_poster(self,
                              template_id: str,
                              poster_content: Optional[Dict[str, Any]],
                              content_id: Optional[int],
                              product_id: Optional[int],
                              scenic_spot_id: Optional[int],
                              image_ids: Optional[List[int]],
                              num_variations: int = 1,
                              force_llm_generation: bool = False) -> Dict[str, Any]:
        """Unified poster generation entry point.

        Args:
            template_id: Template to render with.
            poster_content: Caller-supplied poster content (optional).
            content_id: Content ID used to fetch content from the database (optional).
            product_id: Product ID used to fetch product info from the database (optional).
            scenic_spot_id: Scenic-spot ID used to fetch its info from the database (optional).
            image_ids: IDs of images to fetch from the database (optional).
            num_variations: Number of poster variations to request from the template.
            force_llm_generation: Generate content with the LLM even if
                ``poster_content`` is provided.

        Returns:
            Dict with ``request_id``, ``template_id``, ``variations`` (one entry
            per successfully saved poster: ``variation_id``, ``poster_path``,
            ``base64``) and ``metadata``.

        Raises:
            ValueError: if the handler cannot be loaded, no content or images
                can be obtained, or rendering fails.
        """
        start_time = time.time()
        # One timestamp + random suffix per request, shared by request_id and the
        # output directory: all variations of a request land in the same folder,
        # and concurrent requests within the same second cannot overwrite each
        # other's files.
        now = datetime.now()
        short_id = str(uuid.uuid4())[:8]
        request_id = f"poster-{now.strftime('%Y%m%d-%H%M%S')}-{short_id}"
        topic_id = f"poster_{template_id}_{now.strftime('%Y%m%d_%H%M%S')}_{short_id}"

        # 1. Load the template handler.
        template_handler = self._load_template_handler(template_id)
        if template_handler is None:
            raise ValueError(f"无法为模板ID '{template_id}' 加载处理器。")

        # 2. Prepare content (LLM-generated or caller-supplied).
        use_llm = force_llm_generation or not poster_content
        final_content = poster_content
        if use_llm:
            logger.info(f"为模板 {template_id} 按需生成内容...")
            final_content = await self._generate_content_with_llm(template_id, content_id, product_id, scenic_spot_id)
        if not final_content:
            raise ValueError("无法获取用于生成海报的内容")

        # 3. Prepare images.
        images: List[Any] = []
        if image_ids:
            images = self.db_service.get_images_by_ids(image_ids)
            if not images:
                raise ValueError("无法获取指定的图片")

        # 4. Render and save. Everything that can fail happens inside the try, so
        #    usage stats are recorded exactly once per request.
        try:
            posters = template_handler.generate(
                content=final_content,
                images=images,
                num_variations=num_variations
            )
            if not posters:
                raise ValueError("模板未能生成有效的海报")

            variations = []
            for i, poster in enumerate(posters):
                output_path = self._save_poster(poster, template_id, i, topic_id=topic_id)
                if output_path:
                    variations.append({
                        "variation_id": i,
                        "poster_path": str(output_path),
                        "base64": self._image_to_base64(poster)
                    })

            model_used = self.ai_agent.config.model if use_llm else None
        except Exception as e:
            logger.error(f"生成海报时发生错误: {e}", exc_info=True)
            self._update_template_stats(template_id, False, time.time() - start_time)
            raise ValueError(f"生成海报失败: {str(e)}") from e

        # 5. Record usage and return the result.
        self._update_template_stats(template_id, bool(variations), time.time() - start_time)
        return {
            "request_id": request_id,
            "template_id": template_id,
            "variations": variations,
            "metadata": {
                "generation_time": f"{time.time() - start_time:.2f}s",
                "model_used": model_used,
                "num_variations": len(variations)
            }
        }

    def _save_poster(self, poster: Image.Image, template_id: str, variation_id: int,
                     topic_id: Optional[str] = None) -> Optional[Path]:
        """Save a poster as PNG under the output manager's topic directory.

        Args:
            poster: Rendered poster image.
            template_id: Template ID, used in the file name.
            variation_id: Variation index, used in the file name.
            topic_id: Output topic (directory key). When omitted, one is derived
                from the template ID and the current second, as before.

        Returns:
            Path of the written file, or ``None`` if saving failed (logged).
        """
        try:
            if topic_id is None:
                topic_id = f"poster_{template_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            # Normalise to Path in case the output manager hands back a str.
            output_dir = Path(self.output_manager.get_topic_dir(topic_id))
            file_path = output_dir / f"{template_id}_v{variation_id}.png"
            poster.save(file_path, format="PNG")
            logger.info(f"海报已保存: {file_path}")
            return file_path
        except Exception as e:
            logger.error(f"保存海报失败: {e}", exc_info=True)
            return None

    def _image_to_base64(self, image: Image.Image) -> str:
        """Encode a PIL image as a base64 PNG string."""
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def _update_template_stats(self, template_id: str, success: bool, duration: float):
        """Best-effort update of template usage stats; failures are only logged."""
        try:
            self.db_service.update_template_usage_stats(
                template_id=template_id,
                success=success,
                processing_time=duration
            )
        except Exception as e:
            logger.warning(f"更新模板统计失败: {e}")

    async def _generate_content_with_llm(self, template_id: str, content_id: Optional[int],
                                         product_id: Optional[int], scenic_spot_id: Optional[int]) -> Optional[Dict[str, Any]]:
        """Generate poster content with the LLM using the template's prompts.

        The user prompt template is filled via ``str.format`` with the fetched
        records available as ``content``, ``product`` and ``scenic_spot``. If
        formatting fails, the raw template is used with the data appended as JSON.

        Returns:
            The JSON object parsed from the LLM response (outermost ``{...}``
            span), or ``None`` if the template has no prompts, no JSON object is
            found, or generation/parsing fails.
        """
        template_info = self._templates.get(template_id, {})
        system_prompt = template_info.get('system_prompt', "")
        user_prompt_template = template_info.get('user_prompt_template', "")
        if not system_prompt or not user_prompt_template:
            logger.error(f"模板 {template_id} 缺少提示词配置")
            return None

        data: Dict[str, Any] = {}
        if content_id:
            data['content'] = self.db_service.get_content_by_id(content_id)
        if product_id:
            data['product'] = self.db_service.get_product_by_id(product_id)
        if scenic_spot_id:
            data['scenic_spot'] = self.db_service.get_scenic_spot_by_id(scenic_spot_id)

        try:
            user_prompt = user_prompt_template.format(**data)
        except (KeyError, IndexError, ValueError) as e:
            # KeyError: unknown named field; IndexError: positional "{}";
            # ValueError: stray "{" / "}" (common when prompts embed JSON examples).
            logger.warning(f"格式化提示词时缺少键: {e}")
            # default=str: DB rows may contain datetime/Decimal values; this text
            # is only shown to the LLM, so stringifying them is intended.
            user_prompt = user_prompt_template + f"\n可用数据: {json.dumps(data, ensure_ascii=False, default=str)}"

        try:
            response, _, _, _ = await self.ai_agent.generate_text(system_prompt=system_prompt, user_prompt=user_prompt)
            json_start = response.find('{')
            json_end = response.rfind('}')
            # Both braces must exist and the closing one must follow the opening one.
            if json_start != -1 and json_end > json_start:
                return json.loads(response[json_start:json_end + 1])
            logger.error(f"LLM响应中未找到JSON: {response}")
            return None
        except Exception as e:
            logger.error(f"生成内容时发生错误: {e}", exc_info=True)
            return None