新建了poster接口
This commit is contained in:
parent
3aa2775ec1
commit
5637305a36
Binary file not shown.
BIN
api/config/__pycache__/poster_config_manager.cpython-312.pyc
Normal file
BIN
api/config/__pycache__/poster_config_manager.cpython-312.pyc
Normal file
Binary file not shown.
214
api/config/poster_config_manager.py
Normal file
214
api/config/poster_config_manager.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
海报配置管理器
|
||||||
|
负责加载和管理海报相关的配置信息
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PosterConfigManager:
|
||||||
|
"""海报配置管理器"""
|
||||||
|
|
||||||
|
def __init__(self, config_file: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
初始化配置管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file: 配置文件路径,如果为空则使用默认路径
|
||||||
|
"""
|
||||||
|
if config_file is None:
|
||||||
|
config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml")
|
||||||
|
|
||||||
|
self.config_file = config_file
|
||||||
|
self.config = {}
|
||||||
|
self.load_config()
|
||||||
|
|
||||||
|
def load_config(self):
|
||||||
|
"""加载配置文件"""
|
||||||
|
try:
|
||||||
|
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||||
|
self.config = yaml.safe_load(f)
|
||||||
|
logger.info(f"成功加载配置文件: {self.config_file}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"配置文件不存在: {self.config_file}")
|
||||||
|
self.config = self._get_default_config()
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
logger.error(f"解析配置文件失败: {e}")
|
||||||
|
self.config = self._get_default_config()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载配置文件时发生未知错误: {e}")
|
||||||
|
self.config = self._get_default_config()
|
||||||
|
|
||||||
|
def _get_default_config(self) -> Dict[str, Any]:
|
||||||
|
"""获取默认配置"""
|
||||||
|
return {
|
||||||
|
"poster_prompts": {},
|
||||||
|
"templates": {},
|
||||||
|
"defaults": {
|
||||||
|
"template": "vibrant",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"output_dir": "result/posters",
|
||||||
|
"image_dir": "/root/TravelContentCreator/data/images",
|
||||||
|
"font_dir": "/root/TravelContentCreator/assets/font"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_template_list(self) -> List[Dict[str, Any]]:
|
||||||
|
"""获取所有可用的模板列表"""
|
||||||
|
templates = self.config.get("templates", {})
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": template_id,
|
||||||
|
"name": template_info.get("name", template_id),
|
||||||
|
"description": template_info.get("description", ""),
|
||||||
|
"size": template_info.get("size", [900, 1200]),
|
||||||
|
"required_fields": template_info.get("required_fields", []),
|
||||||
|
"optional_fields": template_info.get("optional_fields", [])
|
||||||
|
}
|
||||||
|
for template_id, template_info in templates.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取指定模板的详细信息"""
|
||||||
|
return self.config.get("templates", {}).get(template_id)
|
||||||
|
|
||||||
|
def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取指定模板的提示词配置"""
|
||||||
|
template_info = self.get_template_info(template_id)
|
||||||
|
if not template_info:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prompt_key = template_info.get("prompt_key", template_id)
|
||||||
|
return self.config.get("poster_prompts", {}).get(prompt_key)
|
||||||
|
|
||||||
|
def get_system_prompt(self, template_id: str) -> Optional[str]:
|
||||||
|
"""获取系统提示词"""
|
||||||
|
prompt_config = self.get_prompt_config(template_id)
|
||||||
|
if prompt_config:
|
||||||
|
return prompt_config.get("system_prompt")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_user_prompt_template(self, template_id: str) -> Optional[str]:
|
||||||
|
"""获取用户提示词模板"""
|
||||||
|
prompt_config = self.get_prompt_config(template_id)
|
||||||
|
if prompt_config:
|
||||||
|
return prompt_config.get("user_prompt_template")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
格式化用户提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
**kwargs: 用于格式化的参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化后的用户提示词
|
||||||
|
"""
|
||||||
|
template = self.get_user_prompt_template(template_id)
|
||||||
|
if not template:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 确保参数中的字典类型转为JSON字符串
|
||||||
|
formatted_kwargs = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
formatted_kwargs[key] = json.dumps(value, ensure_ascii=False, indent=2)
|
||||||
|
else:
|
||||||
|
formatted_kwargs[key] = str(value)
|
||||||
|
|
||||||
|
return template.format(**formatted_kwargs)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(f"格式化提示词失败,缺少参数: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"格式化提示词时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_default_config(self, key: str) -> Any:
|
||||||
|
"""获取默认配置值"""
|
||||||
|
return self.config.get("defaults", {}).get(key)
|
||||||
|
|
||||||
|
def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> tuple[bool, List[str]]:
|
||||||
|
"""
|
||||||
|
验证模板内容是否符合要求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
content: 要验证的内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否有效, 错误信息列表)
|
||||||
|
"""
|
||||||
|
template_info = self.get_template_info(template_id)
|
||||||
|
if not template_info:
|
||||||
|
return False, [f"未知的模板ID: {template_id}"]
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
required_fields = template_info.get("required_fields", [])
|
||||||
|
|
||||||
|
# 检查必填字段
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in content:
|
||||||
|
errors.append(f"缺少必填字段: {field}")
|
||||||
|
elif not content[field]:
|
||||||
|
errors.append(f"必填字段 {field} 不能为空")
|
||||||
|
|
||||||
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
|
def get_template_class(self, template_id: str):
|
||||||
|
"""
|
||||||
|
动态获取模板类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板类
|
||||||
|
"""
|
||||||
|
template_mapping = {
|
||||||
|
"vibrant": "poster.templates.vibrant_template.VibrantTemplate",
|
||||||
|
"business": "poster.templates.business_template.BusinessTemplate"
|
||||||
|
}
|
||||||
|
|
||||||
|
class_path = template_mapping.get(template_id)
|
||||||
|
if not class_path:
|
||||||
|
raise ValueError(f"未支持的模板类型: {template_id}")
|
||||||
|
|
||||||
|
# 动态导入模板类
|
||||||
|
module_path, class_name = class_path.rsplit(".", 1)
|
||||||
|
try:
|
||||||
|
module = __import__(module_path, fromlist=[class_name])
|
||||||
|
return getattr(module, class_name)
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"导入模板类失败: {e}")
|
||||||
|
raise ValueError(f"无法加载模板类: {template_id}")
|
||||||
|
|
||||||
|
def reload_config(self):
|
||||||
|
"""重新加载配置"""
|
||||||
|
logger.info("正在重新加载配置...")
|
||||||
|
self.load_config()
|
||||||
|
|
||||||
|
|
||||||
|
# 全局配置管理器实例
|
||||||
|
_config_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_poster_config_manager() -> PosterConfigManager:
|
||||||
|
"""获取全局配置管理器实例"""
|
||||||
|
global _config_manager
|
||||||
|
if _config_manager is None:
|
||||||
|
_config_manager = PosterConfigManager()
|
||||||
|
return _config_manager
|
||||||
102
api/config/poster_prompts.yaml
Normal file
102
api/config/poster_prompts.yaml
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
# 海报生成提示词配置文件
|
||||||
|
poster_prompts:
|
||||||
|
vibrant:
|
||||||
|
system_prompt: |
|
||||||
|
你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的信息,生成适合Vibrant模板的海报内容。
|
||||||
|
|
||||||
|
## Vibrant模板特点:
|
||||||
|
- 单图背景,毛玻璃渐变效果
|
||||||
|
- 两栏布局(左栏内容,右栏价格)
|
||||||
|
- 适合展示套餐内容和价格信息
|
||||||
|
|
||||||
|
## 你需要生成的数据结构包含以下字段:
|
||||||
|
|
||||||
|
**必填字段:**
|
||||||
|
1. `title`: 主标题(8-12字符,体现产品特色)
|
||||||
|
2. `slogan`: 副标题/宣传语(10-20字符,吸引人的描述)
|
||||||
|
3. `price`: 价格数字(纯数字,不含符号),如果材料中没有价格,则用"欢迎咨询"替代
|
||||||
|
4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等)
|
||||||
|
5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等)
|
||||||
|
6. `content_items`: 套餐内容列表(3-5个项目,每项5-15字符,不要只包含项目名称,要做合适的美化,可以适当省略)
|
||||||
|
|
||||||
|
**可选字段:**
|
||||||
|
7. `remarks`: 备注信息(1-3条,每条10-20字符)
|
||||||
|
8. `tag`: 标签(1条, 如"#限时优惠"等)
|
||||||
|
9. `pagination`: 分页信息(如"1/3",可为空)
|
||||||
|
|
||||||
|
## 内容创作要求:
|
||||||
|
1. 套餐内容要具体实用:明确说明包含的服务、时间、数量
|
||||||
|
2. 价格要有吸引力:突出性价比和优惠信息
|
||||||
|
|
||||||
|
## 输出格式:
|
||||||
|
请严格按照JSON格式输出,不要有任何额外内容。
|
||||||
|
|
||||||
|
user_prompt_template: |
|
||||||
|
请根据以下信息,生成适合在旅游海报上展示的文案:
|
||||||
|
|
||||||
|
## 景区信息
|
||||||
|
{scenic_info}
|
||||||
|
|
||||||
|
## 产品信息
|
||||||
|
{product_info}
|
||||||
|
|
||||||
|
## 推文信息
|
||||||
|
{tweet_info}
|
||||||
|
|
||||||
|
请提取关键信息并整合成一个JSON对象,包含title、slogan、price、ticket_type、content_items、remarks和tag字段。
|
||||||
|
|
||||||
|
business:
|
||||||
|
system_prompt: |
|
||||||
|
你是一名专业的商务海报设计师。你需要根据提供的信息生成适合Business模板的海报内容。
|
||||||
|
|
||||||
|
## Business模板特点:
|
||||||
|
- 商务风格,简洁专业
|
||||||
|
- 突出核心信息和价值主张
|
||||||
|
- 适合企业服务推广
|
||||||
|
|
||||||
|
## 生成字段结构:
|
||||||
|
1. `title`: 服务标题(6-10字符)
|
||||||
|
2. `subtitle`: 副标题(12-20字符)
|
||||||
|
3. `features`: 核心特性列表(3-4个特性)
|
||||||
|
4. `price`: 价格信息
|
||||||
|
5. `contact`: 联系方式信息
|
||||||
|
|
||||||
|
请以JSON格式输出结果。
|
||||||
|
|
||||||
|
user_prompt_template: |
|
||||||
|
请为以下商务信息生成海报内容:
|
||||||
|
|
||||||
|
## 服务信息
|
||||||
|
{service_info}
|
||||||
|
|
||||||
|
## 目标客户
|
||||||
|
{target_audience}
|
||||||
|
|
||||||
|
## 核心卖点
|
||||||
|
{key_points}
|
||||||
|
|
||||||
|
# 模板配置
|
||||||
|
templates:
|
||||||
|
vibrant:
|
||||||
|
name: "Vibrant活力模板"
|
||||||
|
description: "适合旅游景点、娱乐活动的活力海报模板"
|
||||||
|
size: [900, 1200]
|
||||||
|
required_fields: ["title", "slogan", "price", "ticket_type", "content_items"]
|
||||||
|
optional_fields: ["remarks", "tag", "pagination"]
|
||||||
|
prompt_key: "vibrant"
|
||||||
|
|
||||||
|
business:
|
||||||
|
name: "Business商务模板"
|
||||||
|
description: "适合企业服务、B2B推广的商务海报模板"
|
||||||
|
size: [1080, 1920]
|
||||||
|
required_fields: ["title", "subtitle", "features", "price"]
|
||||||
|
optional_fields: ["contact"]
|
||||||
|
prompt_key: "business"
|
||||||
|
|
||||||
|
# 默认配置
|
||||||
|
defaults:
|
||||||
|
template: "vibrant"
|
||||||
|
temperature: 0.7
|
||||||
|
output_dir: "result/posters"
|
||||||
|
image_dir: "/root/TravelContentCreator/data/images"
|
||||||
|
font_dir: "/root/TravelContentCreator/assets/font"
|
||||||
@ -56,11 +56,12 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 导入路由
|
# 导入路由
|
||||||
from api.routers import tweet, poster, prompt, document, data, integration, content_integration
|
from api.routers import tweet, poster, poster_unified, prompt, document, data, integration, content_integration
|
||||||
|
|
||||||
# 包含路由
|
# 包含路由
|
||||||
app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"])
|
app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"])
|
||||||
app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"])
|
app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"])
|
||||||
|
app.include_router(poster_unified.router, prefix="/api/v2/poster", tags=["poster-unified"])
|
||||||
app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"])
|
app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"])
|
||||||
app.include_router(document.router, prefix="/api/v1/document", tags=["document"])
|
app.include_router(document.router, prefix="/api/v1/document", tags=["document"])
|
||||||
app.include_router(data.router, prefix="/api/v1", tags=["data"])
|
app.include_router(data.router, prefix="/api/v1", tags=["data"])
|
||||||
|
|||||||
Binary file not shown.
BIN
api/models/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
BIN
api/models/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
Binary file not shown.
331
api/models/vibrant_poster.py
Normal file
331
api/models/vibrant_poster.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
海报API通用模型定义
|
||||||
|
支持多种模板类型的海报生成
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# 基础模板信息
|
||||||
|
class TemplateInfo(BaseModel):
|
||||||
|
"""模板信息模型"""
|
||||||
|
id: str = Field(..., description="模板ID")
|
||||||
|
name: str = Field(..., description="模板名称")
|
||||||
|
description: str = Field(..., description="模板描述")
|
||||||
|
size: List[int] = Field(..., description="模板尺寸 [宽, 高]")
|
||||||
|
required_fields: List[str] = Field(..., description="必填字段列表")
|
||||||
|
optional_fields: List[str] = Field(default=[], description="可选字段列表")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"id": "vibrant",
|
||||||
|
"name": "Vibrant活力模板",
|
||||||
|
"description": "适合旅游景点、娱乐活动的活力海报模板",
|
||||||
|
"size": [900, 1200],
|
||||||
|
"required_fields": ["title", "slogan", "price", "ticket_type", "content_items"],
|
||||||
|
"optional_fields": ["remarks", "tag", "pagination"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 通用内容模型(支持任意字段)
|
||||||
|
class PosterContent(BaseModel):
|
||||||
|
"""通用海报内容模型,支持动态字段"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow" # 允许额外字段
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"title": "海洋奇幻世界",
|
||||||
|
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||||||
|
"price": "299",
|
||||||
|
"ticket_type": "成人票",
|
||||||
|
"content_items": [
|
||||||
|
"海洋馆门票1张(含所有展区)",
|
||||||
|
"海豚表演VIP座位"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 保持向后兼容的Vibrant内容模型
|
||||||
|
class VibrantPosterContent(PosterContent):
|
||||||
|
"""Vibrant海报内容模型(向后兼容)"""
|
||||||
|
title: Optional[str] = Field(None, description="主标题(8-12字符)")
|
||||||
|
slogan: Optional[str] = Field(None, description="副标题/宣传语(10-20字符)")
|
||||||
|
price: Optional[str] = Field(None, description="价格,如果没有价格则用'欢迎咨询'")
|
||||||
|
ticket_type: Optional[str] = Field(None, description="票种类型(如成人票、套餐票等)")
|
||||||
|
content_button: Optional[str] = Field(None, description="内容按钮文字")
|
||||||
|
content_items: Optional[List[str]] = Field(None, description="套餐内容列表(3-5个项目)")
|
||||||
|
remarks: Optional[List[str]] = Field(None, description="备注信息(1-3条)")
|
||||||
|
tag: Optional[str] = Field(None, description="标签")
|
||||||
|
pagination: Optional[str] = Field(None, description="分页信息")
|
||||||
|
|
||||||
|
|
||||||
|
# 通用海报生成请求
|
||||||
|
class PosterGenerationRequest(BaseModel):
|
||||||
|
"""通用海报生成请求模型"""
|
||||||
|
# 模板相关
|
||||||
|
template_id: str = Field(..., description="模板ID")
|
||||||
|
|
||||||
|
# 内容相关(二选一)
|
||||||
|
content: Optional[Dict[str, Any]] = Field(None, description="直接提供的海报内容")
|
||||||
|
source_data: Optional[Dict[str, Any]] = Field(None, description="源数据,用于AI生成内容")
|
||||||
|
|
||||||
|
# 生成参数
|
||||||
|
topic_name: Optional[str] = Field(None, description="主题名称,用于文件命名")
|
||||||
|
image_path: Optional[str] = Field(None, description="指定图片路径,如果不提供则随机选择")
|
||||||
|
image_dir: Optional[str] = Field(None, description="图片目录,如果不提供使用默认目录")
|
||||||
|
output_dir: Optional[str] = Field(None, description="输出目录,如果不提供使用默认目录")
|
||||||
|
|
||||||
|
# AI参数
|
||||||
|
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"template_id": "vibrant",
|
||||||
|
"source_data": {
|
||||||
|
"scenic_info": {
|
||||||
|
"name": "天津冒险湾",
|
||||||
|
"location": "天津市",
|
||||||
|
"type": "水上乐园"
|
||||||
|
},
|
||||||
|
"product_info": {
|
||||||
|
"name": "冒险湾门票",
|
||||||
|
"price": 299,
|
||||||
|
"type": "成人票"
|
||||||
|
},
|
||||||
|
"tweet_info": {
|
||||||
|
"title": "夏日清凉首选",
|
||||||
|
"content": "天津冒险湾水上乐园,多种刺激项目等你来挑战..."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"topic_name": "天津冒险湾",
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 内容生成请求
|
||||||
|
class ContentGenerationRequest(BaseModel):
|
||||||
|
"""内容生成请求模型"""
|
||||||
|
template_id: str = Field(..., description="模板ID")
|
||||||
|
source_data: Dict[str, Any] = Field(..., description="源数据,用于AI生成内容")
|
||||||
|
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"template_id": "vibrant",
|
||||||
|
"source_data": {
|
||||||
|
"scenic_info": {"name": "天津冒险湾", "type": "水上乐园"},
|
||||||
|
"product_info": {"name": "门票", "price": 299},
|
||||||
|
"tweet_info": {"title": "夏日清凉", "content": "水上乐园体验"}
|
||||||
|
},
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 保持向后兼容
|
||||||
|
class VibrantPosterRequest(PosterGenerationRequest):
|
||||||
|
"""Vibrant海报生成请求模型(向后兼容)"""
|
||||||
|
template_id: str = Field(default="vibrant", description="模板ID,默认为vibrant")
|
||||||
|
|
||||||
|
# 兼容旧的字段结构
|
||||||
|
scenic_info: Optional[Dict[str, Any]] = Field(None, description="景区信息")
|
||||||
|
product_info: Optional[Dict[str, Any]] = Field(None, description="产品信息")
|
||||||
|
tweet_info: Optional[Dict[str, Any]] = Field(None, description="推文信息")
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
# 转换旧格式到新格式
|
||||||
|
if 'scenic_info' in data or 'product_info' in data or 'tweet_info' in data:
|
||||||
|
source_data = {}
|
||||||
|
for key in ['scenic_info', 'product_info', 'tweet_info']:
|
||||||
|
if key in data and data[key] is not None:
|
||||||
|
source_data[key] = data.pop(key)
|
||||||
|
if source_data and 'source_data' not in data:
|
||||||
|
data['source_data'] = source_data
|
||||||
|
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
|
|
||||||
|
# 标准API响应基础类
|
||||||
|
class BaseAPIResponse(BaseModel):
|
||||||
|
"""API响应基础模型"""
|
||||||
|
success: bool = Field(..., description="操作是否成功")
|
||||||
|
message: str = Field(default="", description="响应消息")
|
||||||
|
request_id: str = Field(..., description="请求ID")
|
||||||
|
timestamp: str = Field(..., description="响应时间戳")
|
||||||
|
|
||||||
|
|
||||||
|
# 通用海报生成响应
|
||||||
|
class PosterGenerationResponse(BaseAPIResponse):
|
||||||
|
"""通用海报生成响应模型"""
|
||||||
|
data: Optional[Dict[str, Any]] = Field(None, description="响应数据")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"success": True,
|
||||||
|
"message": "海报生成成功",
|
||||||
|
"request_id": "poster-20240715-123456-a1b2c3d4",
|
||||||
|
"timestamp": "2024-07-15T12:34:56Z",
|
||||||
|
"data": {
|
||||||
|
"template_id": "vibrant",
|
||||||
|
"topic_name": "天津冒险湾",
|
||||||
|
"poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png",
|
||||||
|
"content": {
|
||||||
|
"title": "海洋奇幻世界",
|
||||||
|
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||||||
|
"price": "299"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"image_used": "/data/images/ocean_park.jpg",
|
||||||
|
"generation_method": "ai_generated",
|
||||||
|
"template_size": [900, 1200],
|
||||||
|
"processing_time": 3.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 内容生成响应
|
||||||
|
class ContentGenerationResponse(BaseAPIResponse):
|
||||||
|
"""内容生成响应模型"""
|
||||||
|
data: Optional[Dict[str, Any]] = Field(None, description="生成的内容")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"success": True,
|
||||||
|
"message": "内容生成成功",
|
||||||
|
"request_id": "content-20240715-123456-a1b2c3d4",
|
||||||
|
"timestamp": "2024-07-15T12:34:56Z",
|
||||||
|
"data": {
|
||||||
|
"template_id": "vibrant",
|
||||||
|
"content": {
|
||||||
|
"title": "冒险湾水世界",
|
||||||
|
"slogan": "夏日激情体验,清凉无限乐趣",
|
||||||
|
"price": "299",
|
||||||
|
"ticket_type": "成人票"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"generation_method": "ai_generated",
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 模板列表响应
|
||||||
|
class TemplateListResponse(BaseAPIResponse):
|
||||||
|
"""模板列表响应模型"""
|
||||||
|
data: Optional[List[TemplateInfo]] = Field(None, description="模板列表")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"success": True,
|
||||||
|
"message": "获取模板列表成功",
|
||||||
|
"request_id": "templates-20240715-123456",
|
||||||
|
"timestamp": "2024-07-15T12:34:56Z",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "vibrant",
|
||||||
|
"name": "Vibrant活力模板",
|
||||||
|
"description": "适合旅游景点、娱乐活动的活力海报模板",
|
||||||
|
"size": [900, 1200],
|
||||||
|
"required_fields": ["title", "slogan", "price"],
|
||||||
|
"optional_fields": ["remarks", "tag"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 保持向后兼容的响应模型
|
||||||
|
class VibrantPosterResponse(BaseModel):
|
||||||
|
"""Vibrant海报生成响应模型(向后兼容)"""
|
||||||
|
request_id: str = Field(..., description="请求ID")
|
||||||
|
topic_name: str = Field(..., description="主题名称")
|
||||||
|
poster_path: str = Field(..., description="生成的海报文件路径")
|
||||||
|
generated_content: Dict[str, Any] = Field(..., description="生成或使用的海报内容")
|
||||||
|
image_used: str = Field(..., description="使用的图片路径")
|
||||||
|
generation_method: str = Field(..., description="生成方式:direct(直接使用提供内容)或ai_generated(AI生成)")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"request_id": "vibrant-20240715-123456-a1b2c3d4",
|
||||||
|
"topic_name": "天津冒险湾",
|
||||||
|
"poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png",
|
||||||
|
"generated_content": {
|
||||||
|
"title": "海洋奇幻世界",
|
||||||
|
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||||||
|
"price": "299",
|
||||||
|
"ticket_type": "成人票",
|
||||||
|
"content_items": ["海洋馆门票1张", "海豚表演VIP座位"]
|
||||||
|
},
|
||||||
|
"image_used": "/data/images/ocean_park.jpg",
|
||||||
|
"generation_method": "ai_generated"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BatchVibrantPosterRequest(BaseModel):
|
||||||
|
"""批量Vibrant海报生成请求模型"""
|
||||||
|
base_path: str = Field(..., description="包含多个topic目录的基础路径")
|
||||||
|
image_dir: Optional[str] = Field(None, description="图片目录")
|
||||||
|
scenic_info_file: Optional[str] = Field(None, description="景区信息文件路径")
|
||||||
|
product_info_file: Optional[str] = Field(None, description="产品信息文件路径")
|
||||||
|
output_base: Optional[str] = Field(default="result/posters", description="输出基础目录")
|
||||||
|
parallel_count: Optional[int] = Field(default=3, description="并发处理数量")
|
||||||
|
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"base_path": "/root/TravelContentCreator/result/run_20250710_165327",
|
||||||
|
"image_dir": "/root/TravelContentCreator/data/images",
|
||||||
|
"scenic_info_file": "/root/TravelContentCreator/resource/data/Object/天津冒险湾.txt",
|
||||||
|
"product_info_file": "/root/TravelContentCreator/resource/data/Product/product.bak",
|
||||||
|
"output_base": "result/posters",
|
||||||
|
"parallel_count": 3,
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BatchVibrantPosterResponse(BaseModel):
|
||||||
|
"""批量Vibrant海报生成响应模型"""
|
||||||
|
request_id: str = Field(..., description="批量处理请求ID")
|
||||||
|
total_topics: int = Field(..., description="总共处理的topic数量")
|
||||||
|
successful_count: int = Field(..., description="成功生成的海报数量")
|
||||||
|
failed_count: int = Field(..., description="失败的海报数量")
|
||||||
|
output_base_dir: str = Field(..., description="输出基础目录")
|
||||||
|
successful_topics: List[str] = Field(..., description="成功处理的topic列表")
|
||||||
|
failed_topics: List[Dict[str, str]] = Field(..., description="失败的topic及错误信息")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"request_id": "batch-vibrant-20240715-123456",
|
||||||
|
"total_topics": 5,
|
||||||
|
"successful_count": 4,
|
||||||
|
"failed_count": 1,
|
||||||
|
"output_base_dir": "/result/posters/run_20250710_165327",
|
||||||
|
"successful_topics": ["topic_1", "topic_2", "topic_3", "topic_4"],
|
||||||
|
"failed_topics": [
|
||||||
|
{"topic": "topic_5", "error": "图片文件不存在"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
BIN
api/routers/__pycache__/poster_unified.cpython-312.pyc
Normal file
BIN
api/routers/__pycache__/poster_unified.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
api/routers/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
BIN
api/routers/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
Binary file not shown.
314
api/routers/poster_unified.py
Normal file
314
api/routers/poster_unified.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
统一海报API路由
|
||||||
|
支持多种模板类型的海报生成,配置化管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
from core.ai import AIAgent
|
||||||
|
from api.services.poster_service import UnifiedPosterService
|
||||||
|
from api.models.vibrant_poster import (
|
||||||
|
PosterGenerationRequest, PosterGenerationResponse,
|
||||||
|
ContentGenerationRequest, ContentGenerationResponse,
|
||||||
|
TemplateListResponse, TemplateInfo,
|
||||||
|
BaseAPIResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从依赖注入模块导入依赖
|
||||||
|
from api.dependencies import get_ai_agent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 创建路由
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 依赖注入函数
|
||||||
|
def get_unified_poster_service(
|
||||||
|
ai_agent: AIAgent = Depends(get_ai_agent)
|
||||||
|
) -> UnifiedPosterService:
|
||||||
|
"""获取统一海报服务"""
|
||||||
|
return UnifiedPosterService(ai_agent)
|
||||||
|
|
||||||
|
|
||||||
|
def create_response(success: bool, message: str, data: Any = None, request_id: str = None) -> Dict[str, Any]:
|
||||||
|
"""创建标准响应"""
|
||||||
|
if request_id is None:
|
||||||
|
request_id = f"req-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": success,
|
||||||
|
"message": message,
|
||||||
|
"request_id": request_id,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"data": data
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/templates", response_model=TemplateListResponse, summary="获取所有可用模板")
|
||||||
|
async def get_templates(
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取所有可用的海报模板列表
|
||||||
|
|
||||||
|
返回每个模板的详细信息,包括:
|
||||||
|
- 模板ID和名称
|
||||||
|
- 模板描述
|
||||||
|
- 模板尺寸
|
||||||
|
- 必填字段和可选字段
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
templates = service.get_available_templates()
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message="获取模板列表成功",
|
||||||
|
data=templates
|
||||||
|
)
|
||||||
|
return TemplateListResponse(**response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取模板列表失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/templates/{template_id}", response_model=BaseAPIResponse, summary="获取指定模板信息")
|
||||||
|
async def get_template_info(
|
||||||
|
template_id: str,
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取指定模板的详细信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **template_id**: 模板ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
template_info = service.get_template_info(template_id)
|
||||||
|
if not template_info:
|
||||||
|
raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
|
||||||
|
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message="获取模板信息成功",
|
||||||
|
data=template_info.dict()
|
||||||
|
)
|
||||||
|
return BaseAPIResponse(**response_data)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取模板信息失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/content/generate", response_model=ContentGenerationResponse, summary="生成海报内容")
|
||||||
|
async def generate_content(
|
||||||
|
request: ContentGenerationRequest,
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据源数据生成海报内容,不生成实际图片
|
||||||
|
|
||||||
|
用于:
|
||||||
|
1. 预览生成的内容
|
||||||
|
2. 调试和测试内容生成
|
||||||
|
3. 分步骤生成(先生成内容,再生成图片)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **template_id**: 模板ID
|
||||||
|
- **source_data**: 源数据,用于AI生成内容
|
||||||
|
- **temperature**: AI生成温度参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
content = await service.generate_content(
|
||||||
|
template_id=request.template_id,
|
||||||
|
source_data=request.source_data,
|
||||||
|
temperature=request.temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message="内容生成成功",
|
||||||
|
data={
|
||||||
|
"template_id": request.template_id,
|
||||||
|
"content": content,
|
||||||
|
"metadata": {
|
||||||
|
"generation_method": "ai_generated",
|
||||||
|
"temperature": request.temperature
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ContentGenerationResponse(**response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成内容失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
|
||||||
|
async def generate_poster(
|
||||||
|
request: PosterGenerationRequest,
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
生成海报图片
|
||||||
|
|
||||||
|
支持两种模式:
|
||||||
|
1. 直接提供内容(content字段)
|
||||||
|
2. 提供源数据让AI生成内容(source_data字段)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **template_id**: 模板ID
|
||||||
|
- **content**: 直接提供的海报内容(可选)
|
||||||
|
- **source_data**: 源数据,用于AI生成内容(可选)
|
||||||
|
- **topic_name**: 主题名称,用于文件命名
|
||||||
|
- **image_path**: 指定图片路径(可选)
|
||||||
|
- **image_dir**: 图片目录(可选)
|
||||||
|
- **output_dir**: 输出目录(可选)
|
||||||
|
- **temperature**: AI生成温度参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await service.generate_poster(
|
||||||
|
template_id=request.template_id,
|
||||||
|
content=request.content,
|
||||||
|
source_data=request.source_data,
|
||||||
|
topic_name=request.topic_name,
|
||||||
|
image_path=request.image_path,
|
||||||
|
image_dir=request.image_dir,
|
||||||
|
output_dir=request.output_dir,
|
||||||
|
temperature=request.temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message="海报生成成功",
|
||||||
|
data=result,
|
||||||
|
request_id=result["request_id"]
|
||||||
|
)
|
||||||
|
return PosterGenerationResponse(**response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成海报失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch", response_model=BaseAPIResponse, summary="批量生成海报")
|
||||||
|
async def batch_generate_posters(
|
||||||
|
template_id: str,
|
||||||
|
base_path: str,
|
||||||
|
image_dir: str = None,
|
||||||
|
source_files: Dict[str, str] = None,
|
||||||
|
output_base: str = "result/posters",
|
||||||
|
parallel_count: int = 3,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
批量生成海报
|
||||||
|
|
||||||
|
自动扫描指定目录下的topic文件夹,为每个topic生成海报。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **template_id**: 模板ID
|
||||||
|
- **base_path**: 包含多个topic目录的基础路径
|
||||||
|
- **image_dir**: 图片目录(可选)
|
||||||
|
- **source_files**: 源文件配置字典(可选)
|
||||||
|
- **output_base**: 输出基础目录
|
||||||
|
- **parallel_count**: 并发处理数量
|
||||||
|
- **temperature**: AI生成温度参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await service.batch_generate_posters(
|
||||||
|
template_id=template_id,
|
||||||
|
base_path=base_path,
|
||||||
|
image_dir=image_dir,
|
||||||
|
source_files=source_files or {},
|
||||||
|
output_base=output_base,
|
||||||
|
parallel_count=parallel_count,
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message=f"批量生成完成,成功: {result['successful_count']}, 失败: {result['failed_count']}",
|
||||||
|
data=result,
|
||||||
|
request_id=result["request_id"]
|
||||||
|
)
|
||||||
|
return BaseAPIResponse(**response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量生成海报失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"批量生成海报失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config/reload", response_model=BaseAPIResponse, summary="重新加载配置")
|
||||||
|
async def reload_config(
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
重新加载海报配置
|
||||||
|
|
||||||
|
用于在不重启服务的情况下更新配置,包括:
|
||||||
|
- 提示词模板
|
||||||
|
- 模板配置
|
||||||
|
- 默认参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service.reload_config()
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message="配置重新加载成功"
|
||||||
|
)
|
||||||
|
return BaseAPIResponse(**response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重新加载配置失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"重新加载配置失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health", summary="健康检查")
|
||||||
|
async def health_check():
|
||||||
|
"""服务健康检查"""
|
||||||
|
return create_response(
|
||||||
|
success=True,
|
||||||
|
message="统一海报服务运行正常",
|
||||||
|
data={
|
||||||
|
"service": "unified_poster",
|
||||||
|
"status": "healthy",
|
||||||
|
"version": "2.0.0"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config", summary="获取服务配置")
|
||||||
|
async def get_service_config(
|
||||||
|
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||||
|
):
|
||||||
|
"""获取服务配置信息"""
|
||||||
|
try:
|
||||||
|
config_info = {
|
||||||
|
"default_image_dir": service.config_manager.get_default_config("image_dir"),
|
||||||
|
"default_output_dir": service.config_manager.get_default_config("output_dir"),
|
||||||
|
"default_font_dir": service.config_manager.get_default_config("font_dir"),
|
||||||
|
"default_template": service.config_manager.get_default_config("template"),
|
||||||
|
"supported_image_formats": ["png", "jpg", "jpeg", "webp"],
|
||||||
|
"available_templates": len(service.get_available_templates())
|
||||||
|
}
|
||||||
|
|
||||||
|
response_data = create_response(
|
||||||
|
success=True,
|
||||||
|
message="获取配置成功",
|
||||||
|
data=config_info
|
||||||
|
)
|
||||||
|
return BaseAPIResponse(**response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取配置失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}")
|
||||||
BIN
api/services/__pycache__/poster_service.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/poster_service.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
api/services/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
Binary file not shown.
409
api/services/poster_service.py
Normal file
409
api/services/poster_service.py
Normal file
@ -0,0 +1,409 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
通用海报服务层
|
||||||
|
支持多种模板类型的海报生成,配置化管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
|
from core.ai.ai_agent import AIAgent
|
||||||
|
from api.config.poster_config_manager import get_poster_config_manager
|
||||||
|
from api.models.vibrant_poster import TemplateInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedPosterService:
|
||||||
|
"""统一海报服务类"""
|
||||||
|
|
||||||
|
def __init__(self, ai_agent: AIAgent):
|
||||||
|
"""
|
||||||
|
初始化统一海报服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ai_agent: AI代理
|
||||||
|
"""
|
||||||
|
self.ai_agent = ai_agent
|
||||||
|
self.config_manager = get_poster_config_manager()
|
||||||
|
|
||||||
|
def get_available_templates(self) -> List[TemplateInfo]:
|
||||||
|
"""获取所有可用的模板列表"""
|
||||||
|
template_list = self.config_manager.get_template_list()
|
||||||
|
return [TemplateInfo(**template) for template in template_list]
|
||||||
|
|
||||||
|
def get_template_info(self, template_id: str) -> Optional[TemplateInfo]:
|
||||||
|
"""获取指定模板的信息"""
|
||||||
|
template_info = self.config_manager.get_template_info(template_id)
|
||||||
|
if template_info:
|
||||||
|
return TemplateInfo(
|
||||||
|
id=template_id,
|
||||||
|
name=template_info.get("name", template_id),
|
||||||
|
description=template_info.get("description", ""),
|
||||||
|
size=template_info.get("size", [900, 1200]),
|
||||||
|
required_fields=template_info.get("required_fields", []),
|
||||||
|
optional_fields=template_info.get("optional_fields", [])
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def generate_content(self, template_id: str, source_data: Dict[str, Any],
|
||||||
|
temperature: float = 0.7) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成海报内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
source_data: 源数据
|
||||||
|
temperature: AI生成温度参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的内容字典
|
||||||
|
"""
|
||||||
|
logger.info(f"正在为模板 {template_id} 生成内容...")
|
||||||
|
|
||||||
|
# 获取系统提示词
|
||||||
|
system_prompt = self.config_manager.get_system_prompt(template_id)
|
||||||
|
if not system_prompt:
|
||||||
|
raise ValueError(f"模板 {template_id} 没有配置系统提示词")
|
||||||
|
|
||||||
|
# 格式化用户提示词
|
||||||
|
user_prompt = self.config_manager.format_user_prompt(template_id, **source_data)
|
||||||
|
if not user_prompt:
|
||||||
|
raise ValueError(f"模板 {template_id} 用户提示词格式化失败")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用AI生成内容
|
||||||
|
response, _, _, _ = await self.ai_agent.generate_text(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
stage=f"海报内容生成-{template_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析JSON响应
|
||||||
|
json_start = response.find('{')
|
||||||
|
json_end = response.rfind('}') + 1
|
||||||
|
|
||||||
|
if json_start >= 0 and json_end > json_start:
|
||||||
|
json_str = response[json_start:json_end]
|
||||||
|
content_dict = json.loads(json_str)
|
||||||
|
logger.info(f"AI成功生成内容: {content_dict}")
|
||||||
|
|
||||||
|
# 确保所有值都是字符串类型(除了列表)
|
||||||
|
for key, value in content_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
content_dict[key] = str(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value]
|
||||||
|
|
||||||
|
return content_dict
|
||||||
|
else:
|
||||||
|
logger.error(f"无法在AI响应中找到JSON对象: {response}")
|
||||||
|
raise ValueError("AI响应格式不正确")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"无法解析AI响应为JSON: {e}")
|
||||||
|
raise ValueError("AI响应JSON解析失败")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调用AI生成内容时发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def select_random_image(self, image_dir: Optional[str] = None) -> str:
|
||||||
|
"""从指定目录随机选择一张图片"""
|
||||||
|
if image_dir is None:
|
||||||
|
image_dir = self.config_manager.get_default_config("image_dir")
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_files = [f for f in os.listdir(image_dir)
|
||||||
|
if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))]
|
||||||
|
if not image_files:
|
||||||
|
raise ValueError(f"在目录 {image_dir} 中未找到任何图片文件")
|
||||||
|
|
||||||
|
random_image_name = random.choice(image_files)
|
||||||
|
image_path = os.path.join(image_dir, random_image_name)
|
||||||
|
logger.info(f"随机选择图片: {image_path}")
|
||||||
|
return image_path
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise ValueError(f"图片目录不存在: {image_dir}")
|
||||||
|
|
||||||
|
def validate_content(self, template_id: str, content: Dict[str, Any]) -> None:
|
||||||
|
"""验证内容是否符合模板要求"""
|
||||||
|
is_valid, errors = self.config_manager.validate_template_content(template_id, content)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(f"内容验证失败: {', '.join(errors)}")
|
||||||
|
|
||||||
|
async def generate_poster(self,
|
||||||
|
template_id: str,
|
||||||
|
content: Optional[Dict[str, Any]] = None,
|
||||||
|
source_data: Optional[Dict[str, Any]] = None,
|
||||||
|
topic_name: Optional[str] = None,
|
||||||
|
image_path: Optional[str] = None,
|
||||||
|
image_dir: Optional[str] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
temperature: float = 0.7) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成海报
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
content: 直接提供的内容(可选)
|
||||||
|
source_data: 源数据,用于AI生成内容(可选)
|
||||||
|
topic_name: 主题名称
|
||||||
|
image_path: 指定图片路径
|
||||||
|
image_dir: 图片目录
|
||||||
|
output_dir: 输出目录
|
||||||
|
temperature: AI生成温度参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果字典
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.info(f"开始生成海报,模板: {template_id}, 主题: {topic_name}")
|
||||||
|
|
||||||
|
# 生成请求ID
|
||||||
|
request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
|
||||||
|
|
||||||
|
# 获取模板信息
|
||||||
|
template_info = self.get_template_info(template_id)
|
||||||
|
if not template_info:
|
||||||
|
raise ValueError(f"未知的模板ID: {template_id}")
|
||||||
|
|
||||||
|
# 确定内容
|
||||||
|
if content is None:
|
||||||
|
if source_data is None:
|
||||||
|
raise ValueError("必须提供content或source_data中的一个")
|
||||||
|
|
||||||
|
# 使用AI生成内容
|
||||||
|
content = await self.generate_content(template_id, source_data, temperature)
|
||||||
|
generation_method = "ai_generated"
|
||||||
|
else:
|
||||||
|
generation_method = "direct"
|
||||||
|
|
||||||
|
# 验证内容
|
||||||
|
self.validate_content(template_id, content)
|
||||||
|
|
||||||
|
# 选择图片
|
||||||
|
if image_path is None:
|
||||||
|
image_path = self.select_random_image(image_dir)
|
||||||
|
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
raise ValueError(f"指定的图片文件不存在: {image_path}")
|
||||||
|
|
||||||
|
# 设置默认值
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = self.config_manager.get_default_config("output_dir")
|
||||||
|
if topic_name is None:
|
||||||
|
topic_name = f"poster_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
|
||||||
|
# 获取模板类并生成海报
|
||||||
|
try:
|
||||||
|
template_class = self.config_manager.get_template_class(template_id)
|
||||||
|
template_instance = template_class(template_info.size)
|
||||||
|
|
||||||
|
# 设置字体(如果支持)
|
||||||
|
font_dir = self.config_manager.get_default_config("font_dir")
|
||||||
|
if hasattr(template_instance, 'set_font_dir') and font_dir:
|
||||||
|
template_instance.set_font_dir(font_dir)
|
||||||
|
|
||||||
|
poster = template_instance.generate(image_path=image_path, content=content)
|
||||||
|
|
||||||
|
if not poster:
|
||||||
|
raise ValueError("海报生成失败,模板返回了 None")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成海报时发生错误: {e}", exc_info=True)
|
||||||
|
raise ValueError(f"海报生成失败: {str(e)}")
|
||||||
|
|
||||||
|
# 保存海报
|
||||||
|
try:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
title = content.get('title', topic_name)
|
||||||
|
if isinstance(title, str):
|
||||||
|
title = title.replace('/', '_').replace('\\', '_')
|
||||||
|
output_filename = f"{template_id}_{title}_{timestamp}.png"
|
||||||
|
poster_path = os.path.join(output_dir, output_filename)
|
||||||
|
|
||||||
|
poster.save(poster_path, 'PNG')
|
||||||
|
logger.info(f"海报已成功生成并保存至: {poster_path}")
|
||||||
|
|
||||||
|
processing_time = round(time.time() - start_time, 2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"request_id": request_id,
|
||||||
|
"template_id": template_id,
|
||||||
|
"topic_name": topic_name,
|
||||||
|
"poster_path": poster_path,
|
||||||
|
"content": content,
|
||||||
|
"metadata": {
|
||||||
|
"image_used": image_path,
|
||||||
|
"generation_method": generation_method,
|
||||||
|
"template_size": template_info.size,
|
||||||
|
"processing_time": processing_time,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存海报失败: {e}", exc_info=True)
|
||||||
|
raise ValueError(f"保存海报失败: {str(e)}")
|
||||||
|
|
||||||
|
async def batch_generate_posters(self,
|
||||||
|
template_id: str,
|
||||||
|
base_path: str,
|
||||||
|
image_dir: Optional[str] = None,
|
||||||
|
source_files: Optional[Dict[str, str]] = None,
|
||||||
|
output_base: str = "result/posters",
|
||||||
|
parallel_count: int = 3,
|
||||||
|
temperature: float = 0.7) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
批量生成海报
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
base_path: 包含多个topic目录的基础路径
|
||||||
|
image_dir: 图片目录
|
||||||
|
source_files: 源文件配置字典
|
||||||
|
output_base: 输出基础目录
|
||||||
|
parallel_count: 并发数量
|
||||||
|
temperature: AI生成温度参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
批量处理结果
|
||||||
|
"""
|
||||||
|
logger.info(f"开始批量生成海报,模板: {template_id}, 基础路径: {base_path}")
|
||||||
|
|
||||||
|
# 生成批处理ID
|
||||||
|
batch_request_id = f"batch-{template_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||||
|
|
||||||
|
# 查找topic目录
|
||||||
|
topic_dirs = self._find_topic_directories(base_path)
|
||||||
|
|
||||||
|
if not topic_dirs:
|
||||||
|
raise ValueError("未找到任何包含article_judged.json的topic目录")
|
||||||
|
|
||||||
|
logger.info(f"找到 {len(topic_dirs)} 个topic目录,准备批量生成海报")
|
||||||
|
|
||||||
|
# 准备输出目录
|
||||||
|
base_name = Path(base_path).name
|
||||||
|
output_base_dir = os.path.join(output_base, base_name)
|
||||||
|
|
||||||
|
# 这里简化实现,实际项目中可以使用asyncio.gather进行真正的异步批处理
|
||||||
|
results = []
|
||||||
|
successful_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
|
||||||
|
for topic_path, topic_name in topic_dirs:
|
||||||
|
try:
|
||||||
|
article_path = os.path.join(topic_path, 'article_judged.json')
|
||||||
|
topic_output_dir = os.path.join(output_base_dir, topic_name)
|
||||||
|
|
||||||
|
# 读取文章数据
|
||||||
|
source_data = self._read_data_file(article_path)
|
||||||
|
if not source_data:
|
||||||
|
raise ValueError(f"无法读取文章文件: {article_path}")
|
||||||
|
|
||||||
|
# 构建源数据
|
||||||
|
final_source_data = {"tweet_info": source_data}
|
||||||
|
|
||||||
|
# 如果提供了额外的源文件,读取并添加
|
||||||
|
if source_files:
|
||||||
|
for key, file_path in source_files.items():
|
||||||
|
if file_path and os.path.exists(file_path):
|
||||||
|
data = self._read_data_file(file_path)
|
||||||
|
if data:
|
||||||
|
final_source_data[key] = data
|
||||||
|
|
||||||
|
# 生成海报
|
||||||
|
result = await self.generate_poster(
|
||||||
|
template_id=template_id,
|
||||||
|
source_data=final_source_data,
|
||||||
|
topic_name=topic_name,
|
||||||
|
image_dir=image_dir,
|
||||||
|
output_dir=topic_output_dir,
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"topic": topic_name,
|
||||||
|
"success": True,
|
||||||
|
"result": result
|
||||||
|
})
|
||||||
|
successful_count += 1
|
||||||
|
logger.info(f"成功生成海报: {topic_name}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
results.append({
|
||||||
|
"topic": topic_name,
|
||||||
|
"success": False,
|
||||||
|
"error": error_msg
|
||||||
|
})
|
||||||
|
failed_count += 1
|
||||||
|
logger.error(f"生成海报失败 {topic_name}: {error_msg}")
|
||||||
|
|
||||||
|
successful_topics = [r["topic"] for r in results if r["success"]]
|
||||||
|
failed_topics = [{"topic": r["topic"], "error": r["error"]} for r in results if not r["success"]]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"request_id": batch_request_id,
|
||||||
|
"template_id": template_id,
|
||||||
|
"total_topics": len(topic_dirs),
|
||||||
|
"successful_count": successful_count,
|
||||||
|
"failed_count": failed_count,
|
||||||
|
"output_base_dir": output_base_dir,
|
||||||
|
"successful_topics": successful_topics,
|
||||||
|
"failed_topics": failed_topics,
|
||||||
|
"detailed_results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
def _find_topic_directories(self, base_path: str) -> List[Tuple[str, str]]:
|
||||||
|
"""查找topic目录"""
|
||||||
|
topic_dirs = []
|
||||||
|
base_path = Path(base_path)
|
||||||
|
|
||||||
|
if not base_path.exists():
|
||||||
|
return topic_dirs
|
||||||
|
|
||||||
|
for item in base_path.iterdir():
|
||||||
|
if item.is_dir() and item.name.startswith('topic_'):
|
||||||
|
article_path = item / 'article_judged.json'
|
||||||
|
if article_path.exists():
|
||||||
|
topic_dirs.append((str(item), item.name))
|
||||||
|
|
||||||
|
return topic_dirs
|
||||||
|
|
||||||
|
def _read_data_file(self, file_path: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""读取数据文件(简化版)"""
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
try:
|
||||||
|
return json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 简单的文本内容处理
|
||||||
|
return {"content": content}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取文件失败 {file_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def reload_config(self):
|
||||||
|
"""重新加载配置"""
|
||||||
|
self.config_manager.reload_config()
|
||||||
Loading…
x
Reference in New Issue
Block a user