TravelContentCreator/api/services/poster_service.py

409 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Generic poster service layer.

Supports poster generation for multiple template types, driven by the
poster configuration manager.
"""
import os
import sys
import json
import random
import asyncio
import logging
import uuid
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
# Add the directory three levels above this file (the project root) to
# sys.path so the `core` and `api` packages imported below can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from core.ai.ai_agent import AIAgent
from api.config.poster_config_manager import get_poster_config_manager
from api.models.vibrant_poster import TemplateInfo
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
class UnifiedPosterService:
    """Unified poster service.

    Ties together the poster configuration manager (templates, prompts,
    default directories), the AI agent (poster copy generation) and the
    template classes (rendering) behind a single API for generating one
    poster or a whole batch of posters.
    """

    # File extensions (lower-case, including the dot) accepted as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    # Translation table replacing control characters and characters that are
    # reserved on common filesystems with an underscore.
    _FILENAME_TRANSLATION = {
        **{code: '_' for code in range(32)},
        **{ord(char): '_' for char in '<>:"/\\|?*'},
    }

    # Upper bound (in characters) for the title part of an output file name,
    # so that multi-byte titles stay well below typical 255-byte name limits.
    _MAX_FILENAME_PART_LENGTH = 50

    def __init__(self, ai_agent: AIAgent):
        """
        Initialise the unified poster service.

        Args:
            ai_agent: AI agent used to generate poster content.
        """
        self.ai_agent = ai_agent
        self.config_manager = get_poster_config_manager()

    def get_available_templates(self) -> List[TemplateInfo]:
        """Return every template known to the configuration manager."""
        template_list = self.config_manager.get_template_list()
        return [TemplateInfo(**template) for template in template_list]

    def get_template_info(self, template_id: str) -> Optional[TemplateInfo]:
        """
        Return the description of a single template.

        Args:
            template_id: Template identifier.

        Returns:
            A ``TemplateInfo`` built from the configured entry (missing keys
            fall back to defaults), or ``None`` when the configuration
            manager has no entry for ``template_id``.
        """
        template_info = self.config_manager.get_template_info(template_id)
        if not template_info:
            return None
        return TemplateInfo(
            id=template_id,
            name=template_info.get("name", template_id),
            description=template_info.get("description", ""),
            size=template_info.get("size", [900, 1200]),
            required_fields=template_info.get("required_fields", []),
            optional_fields=template_info.get("optional_fields", [])
        )

    @staticmethod
    def _extract_json_text(response: Any) -> Optional[str]:
        """Return the outermost ``{...}`` span of ``response``, or ``None``.

        ``None`` is returned when ``response`` is not a string or does not
        contain an opening brace followed by a closing brace.
        """
        if not isinstance(response, str):
            return None
        json_start = response.find('{')
        json_end = response.rfind('}') + 1
        if json_start >= 0 and json_end > json_start:
            return response[json_start:json_end]
        return None

    @staticmethod
    def _stringify_numbers(content_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Convert numeric values (top level and inside lists) to strings.

        The dictionary is modified in place and returned for convenience.
        """
        for key, value in content_dict.items():
            if isinstance(value, (int, float)):
                content_dict[key] = str(value)
            elif isinstance(value, list):
                content_dict[key] = [
                    str(item) if isinstance(item, (int, float)) else item
                    for item in value
                ]
        return content_dict

    async def generate_content(self, template_id: str, source_data: Dict[str, Any],
                               temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate poster content with the AI agent.

        Args:
            template_id: Template identifier.
            source_data: Source data; expanded as keyword arguments when the
                user prompt is formatted.
            temperature: Sampling temperature passed to the AI agent.

        Returns:
            The content dictionary parsed from the AI response, with numeric
            values converted to strings.

        Raises:
            ValueError: If the template has no system prompt, the user prompt
                cannot be formatted, the response contains no JSON object, or
                the JSON cannot be parsed.
        """
        logger.info(f"正在为模板 {template_id} 生成内容...")

        # Fetch the system prompt.
        system_prompt = self.config_manager.get_system_prompt(template_id)
        if not system_prompt:
            raise ValueError(f"模板 {template_id} 没有配置系统提示词")

        # Format the user prompt.
        user_prompt = self.config_manager.format_user_prompt(template_id, **source_data)
        if not user_prompt:
            raise ValueError(f"模板 {template_id} 用户提示词格式化失败")

        try:
            # Only the first element of the 4-tuple (the text) is used here.
            response, _, _, _ = await self.ai_agent.generate_text(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                temperature=temperature,
                stage=f"海报内容生成-{template_id}"
            )

            # The model may wrap the JSON in extra text, so only the span from
            # the first "{" to the last "}" is parsed.
            json_str = self._extract_json_text(response)
            if json_str is None:
                logger.error(f"无法在AI响应中找到JSON对象: {response}")
                raise ValueError("AI响应格式不正确")

            content_dict = json.loads(json_str)
            logger.info(f"AI成功生成内容: {content_dict}")
            # Make sure every scalar value is a string (lists stay lists).
            return self._stringify_numbers(content_dict)
        except json.JSONDecodeError as e:
            logger.error(f"无法解析AI响应为JSON: {e}")
            raise ValueError("AI响应JSON解析失败") from e
        except Exception as e:
            logger.error(f"调用AI生成内容时发生错误: {e}")
            raise

    def select_random_image(self, image_dir: Optional[str] = None) -> str:
        """
        Pick a random image file from a directory.

        Args:
            image_dir: Directory to pick from; when ``None`` the configured
                default ``image_dir`` is used.

        Returns:
            Path of the selected image (``image_dir`` joined with the name).

        Raises:
            ValueError: If no directory is available, the directory does not
                exist or is not a directory, or it contains no image files.
        """
        if image_dir is None:
            image_dir = self.config_manager.get_default_config("image_dir")
        if not image_dir:
            # os.listdir(None) would silently list the current working
            # directory, so a missing configuration is rejected explicitly.
            raise ValueError(f"图片目录不存在: {image_dir}")

        try:
            entries = os.listdir(image_dir)
        except (FileNotFoundError, NotADirectoryError) as e:
            raise ValueError(f"图片目录不存在: {image_dir}") from e

        # Match on the full extension including the dot, case-insensitively.
        image_files = [name for name in entries
                       if name.lower().endswith(self._IMAGE_EXTENSIONS)]
        if not image_files:
            raise ValueError(f"在目录 {image_dir} 中未找到任何图片文件")

        image_path = os.path.join(image_dir, random.choice(image_files))
        logger.info(f"随机选择图片: {image_path}")
        return image_path

    def validate_content(self, template_id: str, content: Dict[str, Any]) -> None:
        """
        Check that ``content`` satisfies the template's requirements.

        Raises:
            ValueError: If the configuration manager reports the content as
                invalid; the message lists the reported errors.
        """
        is_valid, errors = self.config_manager.validate_template_content(template_id, content)
        if not is_valid:
            raise ValueError(f"内容验证失败: {', '.join(errors)}")

    @classmethod
    def _sanitize_filename_part(cls, value: Any) -> str:
        """Turn an arbitrary value into a safe fragment of a file name.

        Control characters and characters reserved on common filesystems are
        replaced with ``_``, surrounding spaces and dots are removed, and the
        result is truncated to ``_MAX_FILENAME_PART_LENGTH`` characters.  An
        empty result becomes ``"_"``.
        """
        text = str(value).translate(cls._FILENAME_TRANSLATION).strip(' .')
        return text[:cls._MAX_FILENAME_PART_LENGTH] or '_'

    async def generate_poster(self,
                              template_id: str,
                              content: Optional[Dict[str, Any]] = None,
                              source_data: Optional[Dict[str, Any]] = None,
                              topic_name: Optional[str] = None,
                              image_path: Optional[str] = None,
                              image_dir: Optional[str] = None,
                              output_dir: Optional[str] = None,
                              temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate a single poster and save it as a PNG file.

        Args:
            template_id: Template identifier.
            content: Poster content supplied directly (optional).
            source_data: Source data used to generate the content with the AI
                agent when ``content`` is not given (optional).
            topic_name: Topic name; defaults to a timestamp-based name.
            image_path: Explicit background image path; when ``None`` a
                random image is selected from ``image_dir``.
            image_dir: Directory to pick a random image from.
            output_dir: Output directory; defaults to the configured
                ``output_dir``.
            temperature: Sampling temperature for AI content generation.

        Returns:
            A dictionary with ``request_id``, ``template_id``, ``topic_name``,
            ``poster_path``, ``content`` and a ``metadata`` dictionary.

        Raises:
            ValueError: If the template is unknown, neither ``content`` nor
                ``source_data`` is given, the content is invalid, the image
                is missing, or rendering / saving the poster fails.
        """
        start_time = time.time()
        logger.info(f"开始生成海报,模板: {template_id}, 主题: {topic_name}")

        # Request id: local timestamp plus a short random suffix.
        request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"

        # Resolve the template.
        template_info = self.get_template_info(template_id)
        if not template_info:
            raise ValueError(f"未知的模板ID: {template_id}")

        # Determine the content: supplied directly or generated by the AI.
        if content is None:
            if source_data is None:
                raise ValueError("必须提供content或source_data中的一个")
            content = await self.generate_content(template_id, source_data, temperature)
            generation_method = "ai_generated"
        else:
            generation_method = "direct"

        # Validate the content against the template.
        self.validate_content(template_id, content)

        # Choose the image.
        if image_path is None:
            image_path = self.select_random_image(image_dir)
        if not os.path.exists(image_path):
            raise ValueError(f"指定的图片文件不存在: {image_path}")

        # Fill in defaults.
        if output_dir is None:
            output_dir = self.config_manager.get_default_config("output_dir")
        if topic_name is None:
            topic_name = f"poster_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Instantiate the template class and render the poster.
        try:
            template_class = self.config_manager.get_template_class(template_id)
            template_instance = template_class(template_info.size)

            # Configure the font directory when the template supports it.
            font_dir = self.config_manager.get_default_config("font_dir")
            if hasattr(template_instance, 'set_font_dir') and font_dir:
                template_instance.set_font_dir(font_dir)

            poster = template_instance.generate(image_path=image_path, content=content)
        except Exception as e:
            logger.error(f"生成海报时发生错误: {e}", exc_info=True)
            raise ValueError(f"海报生成失败: {str(e)}") from e

        # Checked outside the try block so this message is not wrapped a
        # second time by the handler above.
        if not poster:
            logger.error("海报生成失败,模板返回了 None")
            raise ValueError("海报生成失败,模板返回了 None")

        # Build the output file name. The title may come from the AI, so it
        # is sanitised; an empty or missing title falls back to the topic.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        title = self._sanitize_filename_part(content.get('title') or topic_name)
        output_filename = f"{template_id}_{title}_{timestamp}.png"
        poster_path = os.path.join(output_dir, output_filename)

        # Save the poster.
        try:
            os.makedirs(output_dir, exist_ok=True)
            poster.save(poster_path, 'PNG')
        except Exception as e:
            logger.error(f"保存海报失败: {e}", exc_info=True)
            raise ValueError(f"保存海报失败: {str(e)}") from e

        logger.info(f"海报已成功生成并保存至: {poster_path}")
        processing_time = round(time.time() - start_time, 2)

        return {
            "request_id": request_id,
            "template_id": template_id,
            "topic_name": topic_name,
            "poster_path": poster_path,
            "content": content,
            "metadata": {
                "image_used": image_path,
                "generation_method": generation_method,
                "template_size": template_info.size,
                "processing_time": processing_time,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
        }

    def _load_extra_sources(self, source_files: Optional[Dict[str, str]]) -> Dict[str, Any]:
        """Read the optional extra source files into a ``{key: data}`` dict.

        Entries whose path is empty, does not exist, or yields no data are
        skipped.
        """
        extra_data: Dict[str, Any] = {}
        if not source_files:
            return extra_data
        for key, file_path in source_files.items():
            if file_path and os.path.exists(file_path):
                data = self._read_data_file(file_path)
                if data:
                    extra_data[key] = data
        return extra_data

    async def batch_generate_posters(self,
                                     template_id: str,
                                     base_path: str,
                                     image_dir: Optional[str] = None,
                                     source_files: Optional[Dict[str, str]] = None,
                                     output_base: str = "result/posters",
                                     parallel_count: int = 3,
                                     temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate posters for every topic directory under ``base_path``.

        Args:
            template_id: Template identifier.
            base_path: Base path containing the ``topic_*`` directories.
            image_dir: Directory to pick random images from.
            source_files: Mapping of source-data key to file path; each
                readable file is added to the source data of every topic.
            output_base: Base output directory.
            parallel_count: Maximum number of topics processed concurrently
                (values below 1 are treated as 1).
            temperature: Sampling temperature for AI content generation.

        Returns:
            A summary dictionary with counts, the successful and failed
            topics, and the per-topic results in topic order.

        Raises:
            ValueError: If no topic directory containing
                ``article_judged.json`` is found.
        """
        logger.info(f"开始批量生成海报,模板: {template_id}, 基础路径: {base_path}")

        # Batch request id.
        batch_request_id = f"batch-{template_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

        # Locate the topic directories.
        topic_dirs = self._find_topic_directories(base_path)
        if not topic_dirs:
            raise ValueError("未找到任何包含article_judged.json的topic目录")
        logger.info(f"找到 {len(topic_dirs)} 个topic目录准备批量生成海报")

        # Output goes to <output_base>/<name of base_path>/<topic>.
        base_name = Path(base_path).name
        output_base_dir = os.path.join(output_base, base_name)

        # The extra source files are identical for every topic: read them once.
        extra_source_data = self._load_extra_sources(source_files)

        # Bounds how many topics are in flight at the same time.
        semaphore = asyncio.Semaphore(max(1, parallel_count))

        async def process_topic(topic_path: str, topic_name: str) -> Dict[str, Any]:
            """Generate one topic's poster; failures become a result entry."""
            async with semaphore:
                try:
                    article_path = os.path.join(topic_path, 'article_judged.json')
                    topic_output_dir = os.path.join(output_base_dir, topic_name)

                    # Read the article data.
                    source_data = self._read_data_file(article_path)
                    if not source_data:
                        raise ValueError(f"无法读取文章文件: {article_path}")

                    # Extra sources are merged last, as they were assigned
                    # after "tweet_info" in the sequential implementation.
                    final_source_data = {"tweet_info": source_data, **extra_source_data}

                    result = await self.generate_poster(
                        template_id=template_id,
                        source_data=final_source_data,
                        topic_name=topic_name,
                        image_dir=image_dir,
                        output_dir=topic_output_dir,
                        temperature=temperature
                    )
                except Exception as e:
                    # One failing topic must not abort the whole batch.
                    error_msg = str(e)
                    logger.error(f"生成海报失败 {topic_name}: {error_msg}")
                    return {
                        "topic": topic_name,
                        "success": False,
                        "error": error_msg
                    }

                logger.info(f"成功生成海报: {topic_name}")
                return {
                    "topic": topic_name,
                    "success": True,
                    "result": result
                }

        # gather() returns results in the order of topic_dirs.
        results = list(await asyncio.gather(
            *(process_topic(topic_path, topic_name) for topic_path, topic_name in topic_dirs)
        ))

        successful_topics = [r["topic"] for r in results if r["success"]]
        failed_topics = [{"topic": r["topic"], "error": r["error"]} for r in results if not r["success"]]

        return {
            "request_id": batch_request_id,
            "template_id": template_id,
            "total_topics": len(topic_dirs),
            "successful_count": len(successful_topics),
            "failed_count": len(failed_topics),
            "output_base_dir": output_base_dir,
            "successful_topics": successful_topics,
            "failed_topics": failed_topics,
            "detailed_results": results
        }

    def _find_topic_directories(self, base_path: str) -> List[Tuple[str, str]]:
        """
        Find the topic directories directly under ``base_path``.

        A topic directory is a sub-directory whose name starts with
        ``topic_`` and that contains an ``article_judged.json`` file.

        Returns:
            ``(directory path, directory name)`` tuples sorted by name; an
            empty list when ``base_path`` is not an existing directory.
        """
        root = Path(base_path)
        if not root.is_dir():
            return []

        # Sorted so that batch output order does not depend on the
        # filesystem's directory iteration order.
        return [
            (str(item), item.name)
            for item in sorted(root.iterdir(), key=lambda entry: entry.name)
            if item.is_dir()
            and item.name.startswith('topic_')
            and (item / 'article_judged.json').exists()
        ]

    def _read_data_file(self, file_path: str) -> Optional[Dict[str, Any]]:
        """
        Read a data file (simplified).

        Returns:
            The parsed JSON when the file contains valid JSON, otherwise
            ``{"content": <raw text>}``; ``None`` when the file cannot be
            read (the error is logged).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            # Best effort by design: callers treat None as "unreadable".
            logger.error(f"读取文件失败 {file_path}: {e}")
            return None

        try:
            return json.loads(content)
        except json.JSONDecodeError:
            # Not JSON: hand the plain text back under a "content" key.
            return {"content": content}

    def reload_config(self):
        """Reload the poster configuration."""
        self.config_manager.reload_config()