409 lines
16 KiB
Python
409 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
通用海报服务层
|
||
支持多种模板类型的海报生成,配置化管理
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import random
|
||
import asyncio
|
||
import logging
|
||
import uuid
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any, Optional, Tuple
|
||
|
||
# 添加项目根目录到路径
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||
|
||
from core.ai.ai_agent import AIAgent
|
||
from api.config.poster_config_manager import get_poster_config_manager
|
||
from api.models.vibrant_poster import TemplateInfo
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class UnifiedPosterService:
    """Unified, configuration-driven poster service.

    Combines the poster configuration manager (templates, prompts, default
    directories) with an AI agent so that a poster for any configured
    template can be rendered either from caller-supplied content or from
    content generated by the AI agent.
    """

    # Accepted image extensions, lower-case and including the dot so that a
    # name such as "notapng" is not mistaken for a PNG file.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    # File that marks a directory as a topic directory in batch mode.
    _ARTICLE_FILENAME = 'article_judged.json'

    # Path separators, characters reserved on Windows and ASCII control
    # characters are replaced by "_" when building output file names.
    _UNSAFE_FILENAME_CHARS = str.maketrans(
        {ch: '_' for ch in '\\/:*?"<>|' + ''.join(map(chr, range(32)))}
    )

    # Upper bound (in characters) for the title part of an output file name,
    # keeping the whole name well below common 255-byte file-name limits.
    _MAX_FILENAME_PART_LENGTH = 50

    def __init__(self, ai_agent: AIAgent):
        """Initialise the service.

        Args:
            ai_agent: AI agent used to generate poster content.
        """
        self.ai_agent = ai_agent
        self.config_manager = get_poster_config_manager()

    def get_available_templates(self) -> List[TemplateInfo]:
        """Return every template known to the configuration manager."""
        template_list = self.config_manager.get_template_list()
        return [TemplateInfo(**template) for template in template_list]

    def get_template_info(self, template_id: str) -> Optional[TemplateInfo]:
        """Return the description of one template, or ``None`` if unknown.

        Missing configuration keys fall back to defaults: the template id as
        name, an empty description, a 900x1200 size and empty field lists.
        """
        template_info = self.config_manager.get_template_info(template_id)
        if not template_info:
            return None

        return TemplateInfo(
            id=template_id,
            name=template_info.get("name", template_id),
            description=template_info.get("description", ""),
            size=template_info.get("size", [900, 1200]),
            required_fields=template_info.get("required_fields", []),
            optional_fields=template_info.get("optional_fields", [])
        )

    async def generate_content(self, template_id: str, source_data: Dict[str, Any],
                               temperature: float = 0.7) -> Dict[str, Any]:
        """Generate poster content with the AI agent.

        Args:
            template_id: Template identifier.
            source_data: Source data, expanded as keyword arguments into the
                template's user prompt.
            temperature: Sampling temperature passed to the AI agent.

        Returns:
            The content dictionary parsed from the AI response. Numeric
            values (also inside lists) are converted to strings.

        Raises:
            ValueError: If the template has no system prompt, the user
                prompt cannot be formatted, or the AI response does not
                contain a parsable JSON object.
        """
        logger.info(f"正在为模板 {template_id} 生成内容...")

        # System prompt
        system_prompt = self.config_manager.get_system_prompt(template_id)
        if not system_prompt:
            raise ValueError(f"模板 {template_id} 没有配置系统提示词")

        # User prompt
        user_prompt = self.config_manager.format_user_prompt(template_id, **source_data)
        if not user_prompt:
            raise ValueError(f"模板 {template_id} 用户提示词格式化失败")

        # Only the AI call itself is guarded here, so that response-format
        # problems are not reported as failures of the AI call.
        try:
            response, _, _, _ = await self.ai_agent.generate_text(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                temperature=temperature,
                stage=f"海报内容生成-{template_id}"
            )
        except Exception as e:
            logger.error(f"调用AI生成内容时发生错误: {e}")
            raise

        return self._parse_ai_response(response)

    @staticmethod
    def _parse_ai_response(response: Any) -> Dict[str, Any]:
        """Extract the JSON object embedded in an AI response.

        The text between the first ``{`` and the last ``}`` is parsed as
        JSON; anything around it is ignored.

        Raises:
            ValueError: If no JSON object can be located or parsed.
        """
        json_start = response.find('{') if isinstance(response, str) else -1
        json_end = response.rfind('}') + 1 if json_start >= 0 else 0

        if json_start < 0 or json_end <= json_start:
            logger.error(f"无法在AI响应中找到JSON对象: {response}")
            raise ValueError("AI响应格式不正确")

        try:
            content_dict = json.loads(response[json_start:json_end])
        except json.JSONDecodeError as e:
            logger.error(f"无法解析AI响应为JSON: {e}")
            raise ValueError("AI响应JSON解析失败") from e

        logger.info(f"AI成功生成内容: {content_dict}")

        # Make every scalar value a string (lists keep their structure).
        for key, value in content_dict.items():
            if isinstance(value, (int, float)):
                content_dict[key] = str(value)
            elif isinstance(value, list):
                content_dict[key] = [
                    str(item) if isinstance(item, (int, float)) else item
                    for item in value
                ]

        return content_dict

    def select_random_image(self, image_dir: Optional[str] = None) -> str:
        """Pick a random image file from a directory.

        Args:
            image_dir: Directory to search. Defaults to the configured
                ``image_dir``.

        Returns:
            Path of the selected image.

        Raises:
            ValueError: If the directory is not configured, does not exist,
                is not a directory, or contains no image files.
        """
        if image_dir is None:
            image_dir = self.config_manager.get_default_config("image_dir")

        # os.listdir(None) would silently list the current working
        # directory, so a missing configuration value is rejected here.
        if not image_dir:
            raise ValueError(f"图片目录不存在: {image_dir}")

        try:
            names = os.listdir(image_dir)
        except (FileNotFoundError, NotADirectoryError) as e:
            raise ValueError(f"图片目录不存在: {image_dir}") from e

        image_files = [
            name for name in names
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(image_dir, name))
        ]
        if not image_files:
            raise ValueError(f"在目录 {image_dir} 中未找到任何图片文件")

        image_path = os.path.join(image_dir, random.choice(image_files))
        logger.info(f"随机选择图片: {image_path}")
        return image_path

    def validate_content(self, template_id: str, content: Dict[str, Any]) -> None:
        """Check that content satisfies the template's requirements.

        Raises:
            ValueError: If the configuration manager reports validation
                errors; the message lists all of them.
        """
        is_valid, errors = self.config_manager.validate_template_content(template_id, content)
        if not is_valid:
            raise ValueError(f"内容验证失败: {', '.join(errors)}")

    @classmethod
    def _safe_filename_part(cls, value: Any, fallback: str = "poster") -> str:
        """Turn an arbitrary value into a safe fragment of a file name.

        Unsafe characters are replaced by ``_``, surrounding spaces and dots
        are stripped and the result is truncated. ``fallback`` is returned
        when nothing usable is left.
        """
        text = str(value).translate(cls._UNSAFE_FILENAME_CHARS).strip(" .")
        return text[:cls._MAX_FILENAME_PART_LENGTH] or fallback

    async def generate_poster(self,
                              template_id: str,
                              content: Optional[Dict[str, Any]] = None,
                              source_data: Optional[Dict[str, Any]] = None,
                              topic_name: Optional[str] = None,
                              image_path: Optional[str] = None,
                              image_dir: Optional[str] = None,
                              output_dir: Optional[str] = None,
                              temperature: float = 0.7) -> Dict[str, Any]:
        """Generate one poster and save it as a PNG file.

        Args:
            template_id: Template identifier.
            content: Ready-to-use content (optional).
            source_data: Source data for AI content generation; used only
                when ``content`` is not given.
            topic_name: Topic name; defaults to a timestamp-based name.
            image_path: Image to use; a random one is picked when omitted.
            image_dir: Directory for the random image selection.
            output_dir: Output directory; defaults to the configured one.
            temperature: Sampling temperature for AI content generation.

        Returns:
            A dictionary with the request id, template id, topic name,
            poster path, the content used and generation metadata.

        Raises:
            ValueError: For an unknown template, missing input, invalid
                content, a missing image, or a rendering/saving failure.
        """
        # perf_counter is monotonic, so the measured duration is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()

        logger.info(f"开始生成海报,模板: {template_id}, 主题: {topic_name}")

        # Request id
        request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"

        # Template information
        template_info = self.get_template_info(template_id)
        if not template_info:
            raise ValueError(f"未知的模板ID: {template_id}")

        # Content: either supplied directly or generated by the AI agent
        if content is None:
            if source_data is None:
                raise ValueError("必须提供content或source_data中的一个")
            content = await self.generate_content(template_id, source_data, temperature)
            generation_method = "ai_generated"
        else:
            generation_method = "direct"

        self.validate_content(template_id, content)

        # Image
        if image_path is None:
            image_path = self.select_random_image(image_dir)
        if not os.path.exists(image_path):
            raise ValueError(f"指定的图片文件不存在: {image_path}")

        # Defaults
        if output_dir is None:
            output_dir = self.config_manager.get_default_config("output_dir")
        if topic_name is None:
            topic_name = f"poster_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Render the poster with the template class
        try:
            template_class = self.config_manager.get_template_class(template_id)
            template_instance = template_class(template_info.size)

            # Set the font directory when the template supports it
            font_dir = self.config_manager.get_default_config("font_dir")
            if hasattr(template_instance, 'set_font_dir') and font_dir:
                template_instance.set_font_dir(font_dir)

            poster = template_instance.generate(image_path=image_path, content=content)
            if not poster:
                raise ValueError("海报生成失败,模板返回了 None")
        except Exception as e:
            logger.error(f"生成海报时发生错误: {e}", exc_info=True)
            raise ValueError(f"海报生成失败: {str(e)}") from e

        # Build the output file name; the title may come from the AI and is
        # therefore sanitised before it becomes part of a path.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        title = self._safe_filename_part(content.get('title') or topic_name)
        output_filename = f"{template_id}_{title}_{timestamp}.png"

        # Save the poster
        try:
            os.makedirs(output_dir, exist_ok=True)
            poster_path = os.path.join(output_dir, output_filename)
            poster.save(poster_path, 'PNG')
        except Exception as e:
            logger.error(f"保存海报失败: {e}", exc_info=True)
            raise ValueError(f"保存海报失败: {str(e)}") from e

        logger.info(f"海报已成功生成并保存至: {poster_path}")

        processing_time = round(time.perf_counter() - start_time, 2)

        return {
            "request_id": request_id,
            "template_id": template_id,
            "topic_name": topic_name,
            "poster_path": poster_path,
            "content": content,
            "metadata": {
                "image_used": image_path,
                "generation_method": generation_method,
                "template_size": template_info.size,
                "processing_time": processing_time,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
        }

    async def batch_generate_posters(self,
                                     template_id: str,
                                     base_path: str,
                                     image_dir: Optional[str] = None,
                                     source_files: Optional[Dict[str, str]] = None,
                                     output_base: str = "result/posters",
                                     parallel_count: int = 3,
                                     temperature: float = 0.7) -> Dict[str, Any]:
        """Generate posters for every topic directory below ``base_path``.

        Args:
            template_id: Template identifier.
            base_path: Directory containing the ``topic_*`` directories.
            image_dir: Directory for the random image selection.
            source_files: Extra source files, as ``{key: file path}``; their
                content is added to the source data of every topic.
            output_base: Base output directory.
            parallel_count: Maximum number of topics processed concurrently
                (values below 1 are treated as 1).
            temperature: Sampling temperature for AI content generation.

        Returns:
            A summary dictionary with counts, the successful and failed
            topics and the detailed per-topic results, in topic order.

        Raises:
            ValueError: If no topic directory is found.
        """
        logger.info(f"开始批量生成海报,模板: {template_id}, 基础路径: {base_path}")

        # Batch id
        batch_request_id = f"batch-{template_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

        # Topic directories
        topic_dirs = self._find_topic_directories(base_path)
        if not topic_dirs:
            raise ValueError("未找到任何包含article_judged.json的topic目录")

        logger.info(f"找到 {len(topic_dirs)} 个topic目录,准备批量生成海报")

        # Output directory
        output_base_dir = os.path.join(output_base, Path(base_path).name)

        # The extra source files are the same for every topic, so they are
        # read once instead of once per topic.
        extra_sources = self._load_extra_sources(source_files)

        # Bound the number of topics processed at the same time.
        semaphore = asyncio.Semaphore(max(1, parallel_count))

        async def process_topic(topic_path: str, topic_name: str) -> Dict[str, Any]:
            """Generate the poster of one topic; failures become a result entry."""
            async with semaphore:
                try:
                    article_path = os.path.join(topic_path, self._ARTICLE_FILENAME)

                    # Article data
                    source_data = self._read_data_file(article_path)
                    if not source_data:
                        raise ValueError(f"无法读取文章文件: {article_path}")

                    result = await self.generate_poster(
                        template_id=template_id,
                        source_data={"tweet_info": source_data, **extra_sources},
                        topic_name=topic_name,
                        image_dir=image_dir,
                        output_dir=os.path.join(output_base_dir, topic_name),
                        temperature=temperature
                    )
                except Exception as e:
                    # One failing topic must not abort the whole batch.
                    error_msg = str(e)
                    logger.error(f"生成海报失败 {topic_name}: {error_msg}")
                    return {"topic": topic_name, "success": False, "error": error_msg}

                logger.info(f"成功生成海报: {topic_name}")
                return {"topic": topic_name, "success": True, "result": result}

        # gather() returns the results in the order of topic_dirs.
        results = list(await asyncio.gather(
            *(process_topic(topic_path, topic_name) for topic_path, topic_name in topic_dirs)
        ))

        successful_topics = [r["topic"] for r in results if r["success"]]
        failed_topics = [{"topic": r["topic"], "error": r["error"]} for r in results if not r["success"]]

        return {
            "request_id": batch_request_id,
            "template_id": template_id,
            "total_topics": len(topic_dirs),
            "successful_count": len(successful_topics),
            "failed_count": len(failed_topics),
            "output_base_dir": output_base_dir,
            "successful_topics": successful_topics,
            "failed_topics": failed_topics,
            "detailed_results": results
        }

    def _load_extra_sources(self, source_files: Optional[Dict[str, str]]) -> Dict[str, Any]:
        """Read the optional extra source files.

        Entries with an empty path, a missing file or unreadable/empty
        content are skipped.
        """
        extra_sources: Dict[str, Any] = {}
        for key, file_path in (source_files or {}).items():
            if file_path and os.path.exists(file_path):
                data = self._read_data_file(file_path)
                if data:
                    extra_sources[key] = data
        return extra_sources

    def _find_topic_directories(self, base_path: str) -> List[Tuple[str, str]]:
        """Find the topic directories directly below ``base_path``.

        A topic directory is a directory whose name starts with ``topic_``
        and which contains an ``article_judged.json`` file.

        Returns:
            ``(directory path, directory name)`` tuples sorted by name; an
            empty list when ``base_path`` is not an existing directory.
        """
        root = Path(base_path)

        # is_dir() also rejects a regular file, on which iterdir() would raise.
        if not root.is_dir():
            return []

        # iterdir() yields entries in arbitrary order; sorting makes the
        # batch order reproducible.
        return [
            (str(item), item.name)
            for item in sorted(root.iterdir(), key=lambda entry: entry.name)
            if item.is_dir()
            and item.name.startswith('topic_')
            and (item / self._ARTICLE_FILENAME).exists()
        ]

    def _read_data_file(self, file_path: str) -> Optional[Dict[str, Any]]:
        """Read a data file (simplified).

        Returns:
            The parsed JSON document when the file holds valid JSON,
            ``{"content": <text>}`` for any other text, or ``None`` when the
            file cannot be read. Note that a JSON document whose top level
            is not an object is returned as parsed.
        """
        # Reading is deliberately best-effort: every failure is logged and
        # reported to the caller as None.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            logger.error(f"读取文件失败 {file_path}: {e}")
            return None

        try:
            return json.loads(content)
        except json.JSONDecodeError:
            # Plain text: wrap it so callers still receive a dictionary
            return {"content": content}

    def reload_config(self):
        """Reload the poster configuration."""
        self.config_manager.reload_config()