新建了poster接口
This commit is contained in:
parent
3aa2775ec1
commit
5637305a36
Binary file not shown.
BIN
api/config/__pycache__/poster_config_manager.cpython-312.pyc
Normal file
BIN
api/config/__pycache__/poster_config_manager.cpython-312.pyc
Normal file
Binary file not shown.
214
api/config/poster_config_manager.py
Normal file
214
api/config/poster_config_manager.py
Normal file
@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
海报配置管理器
|
||||
负责加载和管理海报相关的配置信息
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PosterConfigManager:
|
||||
"""海报配置管理器"""
|
||||
|
||||
def __init__(self, config_file: Optional[str] = None):
|
||||
"""
|
||||
初始化配置管理器
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径,如果为空则使用默认路径
|
||||
"""
|
||||
if config_file is None:
|
||||
config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml")
|
||||
|
||||
self.config_file = config_file
|
||||
self.config = {}
|
||||
self.load_config()
|
||||
|
||||
def load_config(self):
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
logger.info(f"成功加载配置文件: {self.config_file}")
|
||||
except FileNotFoundError:
|
||||
logger.error(f"配置文件不存在: {self.config_file}")
|
||||
self.config = self._get_default_config()
|
||||
except yaml.YAMLError as e:
|
||||
logger.error(f"解析配置文件失败: {e}")
|
||||
self.config = self._get_default_config()
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件时发生未知错误: {e}")
|
||||
self.config = self._get_default_config()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""获取默认配置"""
|
||||
return {
|
||||
"poster_prompts": {},
|
||||
"templates": {},
|
||||
"defaults": {
|
||||
"template": "vibrant",
|
||||
"temperature": 0.7,
|
||||
"output_dir": "result/posters",
|
||||
"image_dir": "/root/TravelContentCreator/data/images",
|
||||
"font_dir": "/root/TravelContentCreator/assets/font"
|
||||
}
|
||||
}
|
||||
|
||||
def get_template_list(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有可用的模板列表"""
|
||||
templates = self.config.get("templates", {})
|
||||
return [
|
||||
{
|
||||
"id": template_id,
|
||||
"name": template_info.get("name", template_id),
|
||||
"description": template_info.get("description", ""),
|
||||
"size": template_info.get("size", [900, 1200]),
|
||||
"required_fields": template_info.get("required_fields", []),
|
||||
"optional_fields": template_info.get("optional_fields", [])
|
||||
}
|
||||
for template_id, template_info in templates.items()
|
||||
]
|
||||
|
||||
def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取指定模板的详细信息"""
|
||||
return self.config.get("templates", {}).get(template_id)
|
||||
|
||||
def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取指定模板的提示词配置"""
|
||||
template_info = self.get_template_info(template_id)
|
||||
if not template_info:
|
||||
return None
|
||||
|
||||
prompt_key = template_info.get("prompt_key", template_id)
|
||||
return self.config.get("poster_prompts", {}).get(prompt_key)
|
||||
|
||||
def get_system_prompt(self, template_id: str) -> Optional[str]:
|
||||
"""获取系统提示词"""
|
||||
prompt_config = self.get_prompt_config(template_id)
|
||||
if prompt_config:
|
||||
return prompt_config.get("system_prompt")
|
||||
return None
|
||||
|
||||
def get_user_prompt_template(self, template_id: str) -> Optional[str]:
|
||||
"""获取用户提示词模板"""
|
||||
prompt_config = self.get_prompt_config(template_id)
|
||||
if prompt_config:
|
||||
return prompt_config.get("user_prompt_template")
|
||||
return None
|
||||
|
||||
def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
格式化用户提示词
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
**kwargs: 用于格式化的参数
|
||||
|
||||
Returns:
|
||||
格式化后的用户提示词
|
||||
"""
|
||||
template = self.get_user_prompt_template(template_id)
|
||||
if not template:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 确保参数中的字典类型转为JSON字符串
|
||||
formatted_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
formatted_kwargs[key] = json.dumps(value, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
formatted_kwargs[key] = str(value)
|
||||
|
||||
return template.format(**formatted_kwargs)
|
||||
except KeyError as e:
|
||||
logger.error(f"格式化提示词失败,缺少参数: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"格式化提示词时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def get_default_config(self, key: str) -> Any:
|
||||
"""获取默认配置值"""
|
||||
return self.config.get("defaults", {}).get(key)
|
||||
|
||||
def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> tuple[bool, List[str]]:
|
||||
"""
|
||||
验证模板内容是否符合要求
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
content: 要验证的内容
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息列表)
|
||||
"""
|
||||
template_info = self.get_template_info(template_id)
|
||||
if not template_info:
|
||||
return False, [f"未知的模板ID: {template_id}"]
|
||||
|
||||
errors = []
|
||||
required_fields = template_info.get("required_fields", [])
|
||||
|
||||
# 检查必填字段
|
||||
for field in required_fields:
|
||||
if field not in content:
|
||||
errors.append(f"缺少必填字段: {field}")
|
||||
elif not content[field]:
|
||||
errors.append(f"必填字段 {field} 不能为空")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def get_template_class(self, template_id: str):
|
||||
"""
|
||||
动态获取模板类
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
|
||||
Returns:
|
||||
模板类
|
||||
"""
|
||||
template_mapping = {
|
||||
"vibrant": "poster.templates.vibrant_template.VibrantTemplate",
|
||||
"business": "poster.templates.business_template.BusinessTemplate"
|
||||
}
|
||||
|
||||
class_path = template_mapping.get(template_id)
|
||||
if not class_path:
|
||||
raise ValueError(f"未支持的模板类型: {template_id}")
|
||||
|
||||
# 动态导入模板类
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
try:
|
||||
module = __import__(module_path, fromlist=[class_name])
|
||||
return getattr(module, class_name)
|
||||
except ImportError as e:
|
||||
logger.error(f"导入模板类失败: {e}")
|
||||
raise ValueError(f"无法加载模板类: {template_id}")
|
||||
|
||||
def reload_config(self):
|
||||
"""重新加载配置"""
|
||||
logger.info("正在重新加载配置...")
|
||||
self.load_config()
|
||||
|
||||
|
||||
# 全局配置管理器实例
|
||||
_config_manager = None
|
||||
|
||||
|
||||
def get_poster_config_manager() -> PosterConfigManager:
|
||||
"""获取全局配置管理器实例"""
|
||||
global _config_manager
|
||||
if _config_manager is None:
|
||||
_config_manager = PosterConfigManager()
|
||||
return _config_manager
|
||||
102
api/config/poster_prompts.yaml
Normal file
102
api/config/poster_prompts.yaml
Normal file
@ -0,0 +1,102 @@
|
||||
# 海报生成提示词配置文件
|
||||
poster_prompts:
|
||||
vibrant:
|
||||
system_prompt: |
|
||||
你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的信息,生成适合Vibrant模板的海报内容。
|
||||
|
||||
## Vibrant模板特点:
|
||||
- 单图背景,毛玻璃渐变效果
|
||||
- 两栏布局(左栏内容,右栏价格)
|
||||
- 适合展示套餐内容和价格信息
|
||||
|
||||
## 你需要生成的数据结构包含以下字段:
|
||||
|
||||
**必填字段:**
|
||||
1. `title`: 主标题(8-12字符,体现产品特色)
|
||||
2. `slogan`: 副标题/宣传语(10-20字符,吸引人的描述)
|
||||
3. `price`: 价格数字(纯数字,不含符号),如果材料中没有价格,则用"欢迎咨询"替代
|
||||
4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等)
|
||||
5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等)
|
||||
6. `content_items`: 套餐内容列表(3-5个项目,每项5-15字符,不要只包含项目名称,要做合适的美化,可以适当省略)
|
||||
|
||||
**可选字段:**
|
||||
7. `remarks`: 备注信息(1-3条,每条10-20字符)
|
||||
8. `tag`: 标签(1条, 如"#限时优惠"等)
|
||||
9. `pagination`: 分页信息(如"1/3",可为空)
|
||||
|
||||
## 内容创作要求:
|
||||
1. 套餐内容要具体实用:明确说明包含的服务、时间、数量
|
||||
2. 价格要有吸引力:突出性价比和优惠信息
|
||||
|
||||
## 输出格式:
|
||||
请严格按照JSON格式输出,不要有任何额外内容。
|
||||
|
||||
user_prompt_template: |
|
||||
请根据以下信息,生成适合在旅游海报上展示的文案:
|
||||
|
||||
## 景区信息
|
||||
{scenic_info}
|
||||
|
||||
## 产品信息
|
||||
{product_info}
|
||||
|
||||
## 推文信息
|
||||
{tweet_info}
|
||||
|
||||
请提取关键信息并整合成一个JSON对象,包含title、slogan、price、ticket_type、content_items、remarks和tag字段。
|
||||
|
||||
business:
|
||||
system_prompt: |
|
||||
你是一名专业的商务海报设计师。你需要根据提供的信息生成适合Business模板的海报内容。
|
||||
|
||||
## Business模板特点:
|
||||
- 商务风格,简洁专业
|
||||
- 突出核心信息和价值主张
|
||||
- 适合企业服务推广
|
||||
|
||||
## 生成字段结构:
|
||||
1. `title`: 服务标题(6-10字符)
|
||||
2. `subtitle`: 副标题(12-20字符)
|
||||
3. `features`: 核心特性列表(3-4个特性)
|
||||
4. `price`: 价格信息
|
||||
5. `contact`: 联系方式信息
|
||||
|
||||
请以JSON格式输出结果。
|
||||
|
||||
user_prompt_template: |
|
||||
请为以下商务信息生成海报内容:
|
||||
|
||||
## 服务信息
|
||||
{service_info}
|
||||
|
||||
## 目标客户
|
||||
{target_audience}
|
||||
|
||||
## 核心卖点
|
||||
{key_points}
|
||||
|
||||
# 模板配置
|
||||
templates:
|
||||
vibrant:
|
||||
name: "Vibrant活力模板"
|
||||
description: "适合旅游景点、娱乐活动的活力海报模板"
|
||||
size: [900, 1200]
|
||||
required_fields: ["title", "slogan", "price", "ticket_type", "content_items"]
|
||||
optional_fields: ["remarks", "tag", "pagination"]
|
||||
prompt_key: "vibrant"
|
||||
|
||||
business:
|
||||
name: "Business商务模板"
|
||||
description: "适合企业服务、B2B推广的商务海报模板"
|
||||
size: [1080, 1920]
|
||||
required_fields: ["title", "subtitle", "features", "price"]
|
||||
optional_fields: ["contact"]
|
||||
prompt_key: "business"
|
||||
|
||||
# 默认配置
|
||||
defaults:
|
||||
template: "vibrant"
|
||||
temperature: 0.7
|
||||
output_dir: "result/posters"
|
||||
image_dir: "/root/TravelContentCreator/data/images"
|
||||
font_dir: "/root/TravelContentCreator/assets/font"
|
||||
@ -56,11 +56,12 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# 导入路由
|
||||
from api.routers import tweet, poster, prompt, document, data, integration, content_integration
|
||||
from api.routers import tweet, poster, poster_unified, prompt, document, data, integration, content_integration
|
||||
|
||||
# 包含路由
|
||||
app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"])
|
||||
app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"])
|
||||
app.include_router(poster_unified.router, prefix="/api/v2/poster", tags=["poster-unified"])
|
||||
app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"])
|
||||
app.include_router(document.router, prefix="/api/v1/document", tags=["document"])
|
||||
app.include_router(data.router, prefix="/api/v1", tags=["data"])
|
||||
|
||||
Binary file not shown.
BIN
api/models/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
BIN
api/models/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
Binary file not shown.
331
api/models/vibrant_poster.py
Normal file
331
api/models/vibrant_poster.py
Normal file
@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
海报API通用模型定义
|
||||
支持多种模板类型的海报生成
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# 基础模板信息
|
||||
class TemplateInfo(BaseModel):
|
||||
"""模板信息模型"""
|
||||
id: str = Field(..., description="模板ID")
|
||||
name: str = Field(..., description="模板名称")
|
||||
description: str = Field(..., description="模板描述")
|
||||
size: List[int] = Field(..., description="模板尺寸 [宽, 高]")
|
||||
required_fields: List[str] = Field(..., description="必填字段列表")
|
||||
optional_fields: List[str] = Field(default=[], description="可选字段列表")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"id": "vibrant",
|
||||
"name": "Vibrant活力模板",
|
||||
"description": "适合旅游景点、娱乐活动的活力海报模板",
|
||||
"size": [900, 1200],
|
||||
"required_fields": ["title", "slogan", "price", "ticket_type", "content_items"],
|
||||
"optional_fields": ["remarks", "tag", "pagination"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 通用内容模型(支持任意字段)
|
||||
class PosterContent(BaseModel):
|
||||
"""通用海报内容模型,支持动态字段"""
|
||||
|
||||
class Config:
|
||||
extra = "allow" # 允许额外字段
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"title": "海洋奇幻世界",
|
||||
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||||
"price": "299",
|
||||
"ticket_type": "成人票",
|
||||
"content_items": [
|
||||
"海洋馆门票1张(含所有展区)",
|
||||
"海豚表演VIP座位"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 保持向后兼容的Vibrant内容模型
|
||||
class VibrantPosterContent(PosterContent):
|
||||
"""Vibrant海报内容模型(向后兼容)"""
|
||||
title: Optional[str] = Field(None, description="主标题(8-12字符)")
|
||||
slogan: Optional[str] = Field(None, description="副标题/宣传语(10-20字符)")
|
||||
price: Optional[str] = Field(None, description="价格,如果没有价格则用'欢迎咨询'")
|
||||
ticket_type: Optional[str] = Field(None, description="票种类型(如成人票、套餐票等)")
|
||||
content_button: Optional[str] = Field(None, description="内容按钮文字")
|
||||
content_items: Optional[List[str]] = Field(None, description="套餐内容列表(3-5个项目)")
|
||||
remarks: Optional[List[str]] = Field(None, description="备注信息(1-3条)")
|
||||
tag: Optional[str] = Field(None, description="标签")
|
||||
pagination: Optional[str] = Field(None, description="分页信息")
|
||||
|
||||
|
||||
# 通用海报生成请求
|
||||
class PosterGenerationRequest(BaseModel):
|
||||
"""通用海报生成请求模型"""
|
||||
# 模板相关
|
||||
template_id: str = Field(..., description="模板ID")
|
||||
|
||||
# 内容相关(二选一)
|
||||
content: Optional[Dict[str, Any]] = Field(None, description="直接提供的海报内容")
|
||||
source_data: Optional[Dict[str, Any]] = Field(None, description="源数据,用于AI生成内容")
|
||||
|
||||
# 生成参数
|
||||
topic_name: Optional[str] = Field(None, description="主题名称,用于文件命名")
|
||||
image_path: Optional[str] = Field(None, description="指定图片路径,如果不提供则随机选择")
|
||||
image_dir: Optional[str] = Field(None, description="图片目录,如果不提供使用默认目录")
|
||||
output_dir: Optional[str] = Field(None, description="输出目录,如果不提供使用默认目录")
|
||||
|
||||
# AI参数
|
||||
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"template_id": "vibrant",
|
||||
"source_data": {
|
||||
"scenic_info": {
|
||||
"name": "天津冒险湾",
|
||||
"location": "天津市",
|
||||
"type": "水上乐园"
|
||||
},
|
||||
"product_info": {
|
||||
"name": "冒险湾门票",
|
||||
"price": 299,
|
||||
"type": "成人票"
|
||||
},
|
||||
"tweet_info": {
|
||||
"title": "夏日清凉首选",
|
||||
"content": "天津冒险湾水上乐园,多种刺激项目等你来挑战..."
|
||||
}
|
||||
},
|
||||
"topic_name": "天津冒险湾",
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 内容生成请求
|
||||
class ContentGenerationRequest(BaseModel):
|
||||
"""内容生成请求模型"""
|
||||
template_id: str = Field(..., description="模板ID")
|
||||
source_data: Dict[str, Any] = Field(..., description="源数据,用于AI生成内容")
|
||||
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"template_id": "vibrant",
|
||||
"source_data": {
|
||||
"scenic_info": {"name": "天津冒险湾", "type": "水上乐园"},
|
||||
"product_info": {"name": "门票", "price": 299},
|
||||
"tweet_info": {"title": "夏日清凉", "content": "水上乐园体验"}
|
||||
},
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 保持向后兼容
|
||||
class VibrantPosterRequest(PosterGenerationRequest):
|
||||
"""Vibrant海报生成请求模型(向后兼容)"""
|
||||
template_id: str = Field(default="vibrant", description="模板ID,默认为vibrant")
|
||||
|
||||
# 兼容旧的字段结构
|
||||
scenic_info: Optional[Dict[str, Any]] = Field(None, description="景区信息")
|
||||
product_info: Optional[Dict[str, Any]] = Field(None, description="产品信息")
|
||||
tweet_info: Optional[Dict[str, Any]] = Field(None, description="推文信息")
|
||||
|
||||
def __init__(self, **data):
|
||||
# 转换旧格式到新格式
|
||||
if 'scenic_info' in data or 'product_info' in data or 'tweet_info' in data:
|
||||
source_data = {}
|
||||
for key in ['scenic_info', 'product_info', 'tweet_info']:
|
||||
if key in data and data[key] is not None:
|
||||
source_data[key] = data.pop(key)
|
||||
if source_data and 'source_data' not in data:
|
||||
data['source_data'] = source_data
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
# 标准API响应基础类
|
||||
class BaseAPIResponse(BaseModel):
|
||||
"""API响应基础模型"""
|
||||
success: bool = Field(..., description="操作是否成功")
|
||||
message: str = Field(default="", description="响应消息")
|
||||
request_id: str = Field(..., description="请求ID")
|
||||
timestamp: str = Field(..., description="响应时间戳")
|
||||
|
||||
|
||||
# 通用海报生成响应
|
||||
class PosterGenerationResponse(BaseAPIResponse):
|
||||
"""通用海报生成响应模型"""
|
||||
data: Optional[Dict[str, Any]] = Field(None, description="响应数据")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "海报生成成功",
|
||||
"request_id": "poster-20240715-123456-a1b2c3d4",
|
||||
"timestamp": "2024-07-15T12:34:56Z",
|
||||
"data": {
|
||||
"template_id": "vibrant",
|
||||
"topic_name": "天津冒险湾",
|
||||
"poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png",
|
||||
"content": {
|
||||
"title": "海洋奇幻世界",
|
||||
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||||
"price": "299"
|
||||
},
|
||||
"metadata": {
|
||||
"image_used": "/data/images/ocean_park.jpg",
|
||||
"generation_method": "ai_generated",
|
||||
"template_size": [900, 1200],
|
||||
"processing_time": 3.2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 内容生成响应
|
||||
class ContentGenerationResponse(BaseAPIResponse):
|
||||
"""内容生成响应模型"""
|
||||
data: Optional[Dict[str, Any]] = Field(None, description="生成的内容")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "内容生成成功",
|
||||
"request_id": "content-20240715-123456-a1b2c3d4",
|
||||
"timestamp": "2024-07-15T12:34:56Z",
|
||||
"data": {
|
||||
"template_id": "vibrant",
|
||||
"content": {
|
||||
"title": "冒险湾水世界",
|
||||
"slogan": "夏日激情体验,清凉无限乐趣",
|
||||
"price": "299",
|
||||
"ticket_type": "成人票"
|
||||
},
|
||||
"metadata": {
|
||||
"generation_method": "ai_generated",
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 模板列表响应
|
||||
class TemplateListResponse(BaseAPIResponse):
|
||||
"""模板列表响应模型"""
|
||||
data: Optional[List[TemplateInfo]] = Field(None, description="模板列表")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "获取模板列表成功",
|
||||
"request_id": "templates-20240715-123456",
|
||||
"timestamp": "2024-07-15T12:34:56Z",
|
||||
"data": [
|
||||
{
|
||||
"id": "vibrant",
|
||||
"name": "Vibrant活力模板",
|
||||
"description": "适合旅游景点、娱乐活动的活力海报模板",
|
||||
"size": [900, 1200],
|
||||
"required_fields": ["title", "slogan", "price"],
|
||||
"optional_fields": ["remarks", "tag"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 保持向后兼容的响应模型
|
||||
class VibrantPosterResponse(BaseModel):
|
||||
"""Vibrant海报生成响应模型(向后兼容)"""
|
||||
request_id: str = Field(..., description="请求ID")
|
||||
topic_name: str = Field(..., description="主题名称")
|
||||
poster_path: str = Field(..., description="生成的海报文件路径")
|
||||
generated_content: Dict[str, Any] = Field(..., description="生成或使用的海报内容")
|
||||
image_used: str = Field(..., description="使用的图片路径")
|
||||
generation_method: str = Field(..., description="生成方式:direct(直接使用提供内容)或ai_generated(AI生成)")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"request_id": "vibrant-20240715-123456-a1b2c3d4",
|
||||
"topic_name": "天津冒险湾",
|
||||
"poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png",
|
||||
"generated_content": {
|
||||
"title": "海洋奇幻世界",
|
||||
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||||
"price": "299",
|
||||
"ticket_type": "成人票",
|
||||
"content_items": ["海洋馆门票1张", "海豚表演VIP座位"]
|
||||
},
|
||||
"image_used": "/data/images/ocean_park.jpg",
|
||||
"generation_method": "ai_generated"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BatchVibrantPosterRequest(BaseModel):
|
||||
"""批量Vibrant海报生成请求模型"""
|
||||
base_path: str = Field(..., description="包含多个topic目录的基础路径")
|
||||
image_dir: Optional[str] = Field(None, description="图片目录")
|
||||
scenic_info_file: Optional[str] = Field(None, description="景区信息文件路径")
|
||||
product_info_file: Optional[str] = Field(None, description="产品信息文件路径")
|
||||
output_base: Optional[str] = Field(default="result/posters", description="输出基础目录")
|
||||
parallel_count: Optional[int] = Field(default=3, description="并发处理数量")
|
||||
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"base_path": "/root/TravelContentCreator/result/run_20250710_165327",
|
||||
"image_dir": "/root/TravelContentCreator/data/images",
|
||||
"scenic_info_file": "/root/TravelContentCreator/resource/data/Object/天津冒险湾.txt",
|
||||
"product_info_file": "/root/TravelContentCreator/resource/data/Product/product.bak",
|
||||
"output_base": "result/posters",
|
||||
"parallel_count": 3,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BatchVibrantPosterResponse(BaseModel):
|
||||
"""批量Vibrant海报生成响应模型"""
|
||||
request_id: str = Field(..., description="批量处理请求ID")
|
||||
total_topics: int = Field(..., description="总共处理的topic数量")
|
||||
successful_count: int = Field(..., description="成功生成的海报数量")
|
||||
failed_count: int = Field(..., description="失败的海报数量")
|
||||
output_base_dir: str = Field(..., description="输出基础目录")
|
||||
successful_topics: List[str] = Field(..., description="成功处理的topic列表")
|
||||
failed_topics: List[Dict[str, str]] = Field(..., description="失败的topic及错误信息")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"request_id": "batch-vibrant-20240715-123456",
|
||||
"total_topics": 5,
|
||||
"successful_count": 4,
|
||||
"failed_count": 1,
|
||||
"output_base_dir": "/result/posters/run_20250710_165327",
|
||||
"successful_topics": ["topic_1", "topic_2", "topic_3", "topic_4"],
|
||||
"failed_topics": [
|
||||
{"topic": "topic_5", "error": "图片文件不存在"}
|
||||
]
|
||||
}
|
||||
}
|
||||
BIN
api/routers/__pycache__/poster_unified.cpython-312.pyc
Normal file
BIN
api/routers/__pycache__/poster_unified.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
api/routers/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
BIN
api/routers/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
Binary file not shown.
314
api/routers/poster_unified.py
Normal file
314
api/routers/poster_unified.py
Normal file
@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
统一海报API路由
|
||||
支持多种模板类型的海报生成,配置化管理
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from core.ai import AIAgent
|
||||
from api.services.poster_service import UnifiedPosterService
|
||||
from api.models.vibrant_poster import (
|
||||
PosterGenerationRequest, PosterGenerationResponse,
|
||||
ContentGenerationRequest, ContentGenerationResponse,
|
||||
TemplateListResponse, TemplateInfo,
|
||||
BaseAPIResponse
|
||||
)
|
||||
|
||||
# 从依赖注入模块导入依赖
|
||||
from api.dependencies import get_ai_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
# 依赖注入函数
|
||||
def get_unified_poster_service(
|
||||
ai_agent: AIAgent = Depends(get_ai_agent)
|
||||
) -> UnifiedPosterService:
|
||||
"""获取统一海报服务"""
|
||||
return UnifiedPosterService(ai_agent)
|
||||
|
||||
|
||||
def create_response(success: bool, message: str, data: Any = None, request_id: str = None) -> Dict[str, Any]:
|
||||
"""创建标准响应"""
|
||||
if request_id is None:
|
||||
request_id = f"req-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"message": message,
|
||||
"request_id": request_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": data
|
||||
}
|
||||
|
||||
|
||||
@router.get("/templates", response_model=TemplateListResponse, summary="获取所有可用模板")
|
||||
async def get_templates(
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""
|
||||
获取所有可用的海报模板列表
|
||||
|
||||
返回每个模板的详细信息,包括:
|
||||
- 模板ID和名称
|
||||
- 模板描述
|
||||
- 模板尺寸
|
||||
- 必填字段和可选字段
|
||||
"""
|
||||
try:
|
||||
templates = service.get_available_templates()
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message="获取模板列表成功",
|
||||
data=templates
|
||||
)
|
||||
return TemplateListResponse(**response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取模板列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/templates/{template_id}", response_model=BaseAPIResponse, summary="获取指定模板信息")
|
||||
async def get_template_info(
|
||||
template_id: str,
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""
|
||||
获取指定模板的详细信息
|
||||
|
||||
参数:
|
||||
- **template_id**: 模板ID
|
||||
"""
|
||||
try:
|
||||
template_info = service.get_template_info(template_id)
|
||||
if not template_info:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
|
||||
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message="获取模板信息成功",
|
||||
data=template_info.dict()
|
||||
)
|
||||
return BaseAPIResponse(**response_data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取模板信息失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/content/generate", response_model=ContentGenerationResponse, summary="生成海报内容")
|
||||
async def generate_content(
|
||||
request: ContentGenerationRequest,
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""
|
||||
根据源数据生成海报内容,不生成实际图片
|
||||
|
||||
用于:
|
||||
1. 预览生成的内容
|
||||
2. 调试和测试内容生成
|
||||
3. 分步骤生成(先生成内容,再生成图片)
|
||||
|
||||
参数:
|
||||
- **template_id**: 模板ID
|
||||
- **source_data**: 源数据,用于AI生成内容
|
||||
- **temperature**: AI生成温度参数
|
||||
"""
|
||||
try:
|
||||
content = await service.generate_content(
|
||||
template_id=request.template_id,
|
||||
source_data=request.source_data,
|
||||
temperature=request.temperature
|
||||
)
|
||||
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message="内容生成成功",
|
||||
data={
|
||||
"template_id": request.template_id,
|
||||
"content": content,
|
||||
"metadata": {
|
||||
"generation_method": "ai_generated",
|
||||
"temperature": request.temperature
|
||||
}
|
||||
}
|
||||
)
|
||||
return ContentGenerationResponse(**response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成内容失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
|
||||
async def generate_poster(
|
||||
request: PosterGenerationRequest,
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""
|
||||
生成海报图片
|
||||
|
||||
支持两种模式:
|
||||
1. 直接提供内容(content字段)
|
||||
2. 提供源数据让AI生成内容(source_data字段)
|
||||
|
||||
参数:
|
||||
- **template_id**: 模板ID
|
||||
- **content**: 直接提供的海报内容(可选)
|
||||
- **source_data**: 源数据,用于AI生成内容(可选)
|
||||
- **topic_name**: 主题名称,用于文件命名
|
||||
- **image_path**: 指定图片路径(可选)
|
||||
- **image_dir**: 图片目录(可选)
|
||||
- **output_dir**: 输出目录(可选)
|
||||
- **temperature**: AI生成温度参数
|
||||
"""
|
||||
try:
|
||||
result = await service.generate_poster(
|
||||
template_id=request.template_id,
|
||||
content=request.content,
|
||||
source_data=request.source_data,
|
||||
topic_name=request.topic_name,
|
||||
image_path=request.image_path,
|
||||
image_dir=request.image_dir,
|
||||
output_dir=request.output_dir,
|
||||
temperature=request.temperature
|
||||
)
|
||||
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message="海报生成成功",
|
||||
data=result,
|
||||
request_id=result["request_id"]
|
||||
)
|
||||
return PosterGenerationResponse(**response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成海报失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/batch", response_model=BaseAPIResponse, summary="批量生成海报")
|
||||
async def batch_generate_posters(
|
||||
template_id: str,
|
||||
base_path: str,
|
||||
image_dir: str = None,
|
||||
source_files: Dict[str, str] = None,
|
||||
output_base: str = "result/posters",
|
||||
parallel_count: int = 3,
|
||||
temperature: float = 0.7,
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""
|
||||
批量生成海报
|
||||
|
||||
自动扫描指定目录下的topic文件夹,为每个topic生成海报。
|
||||
|
||||
参数:
|
||||
- **template_id**: 模板ID
|
||||
- **base_path**: 包含多个topic目录的基础路径
|
||||
- **image_dir**: 图片目录(可选)
|
||||
- **source_files**: 源文件配置字典(可选)
|
||||
- **output_base**: 输出基础目录
|
||||
- **parallel_count**: 并发处理数量
|
||||
- **temperature**: AI生成温度参数
|
||||
"""
|
||||
try:
|
||||
result = await service.batch_generate_posters(
|
||||
template_id=template_id,
|
||||
base_path=base_path,
|
||||
image_dir=image_dir,
|
||||
source_files=source_files or {},
|
||||
output_base=output_base,
|
||||
parallel_count=parallel_count,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message=f"批量生成完成,成功: {result['successful_count']}, 失败: {result['failed_count']}",
|
||||
data=result,
|
||||
request_id=result["request_id"]
|
||||
)
|
||||
return BaseAPIResponse(**response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量生成海报失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"批量生成海报失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/config/reload", response_model=BaseAPIResponse, summary="重新加载配置")
|
||||
async def reload_config(
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""
|
||||
重新加载海报配置
|
||||
|
||||
用于在不重启服务的情况下更新配置,包括:
|
||||
- 提示词模板
|
||||
- 模板配置
|
||||
- 默认参数
|
||||
"""
|
||||
try:
|
||||
service.reload_config()
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message="配置重新加载成功"
|
||||
)
|
||||
return BaseAPIResponse(**response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"重新加载配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"重新加载配置失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
"""服务健康检查"""
|
||||
return create_response(
|
||||
success=True,
|
||||
message="统一海报服务运行正常",
|
||||
data={
|
||||
"service": "unified_poster",
|
||||
"status": "healthy",
|
||||
"version": "2.0.0"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config", summary="获取服务配置")
|
||||
async def get_service_config(
|
||||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||||
):
|
||||
"""获取服务配置信息"""
|
||||
try:
|
||||
config_info = {
|
||||
"default_image_dir": service.config_manager.get_default_config("image_dir"),
|
||||
"default_output_dir": service.config_manager.get_default_config("output_dir"),
|
||||
"default_font_dir": service.config_manager.get_default_config("font_dir"),
|
||||
"default_template": service.config_manager.get_default_config("template"),
|
||||
"supported_image_formats": ["png", "jpg", "jpeg", "webp"],
|
||||
"available_templates": len(service.get_available_templates())
|
||||
}
|
||||
|
||||
response_data = create_response(
|
||||
success=True,
|
||||
message="获取配置成功",
|
||||
data=config_info
|
||||
)
|
||||
return BaseAPIResponse(**response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}")
|
||||
BIN
api/services/__pycache__/poster_service.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/poster_service.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
api/services/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/vibrant_poster.cpython-312.pyc
Normal file
Binary file not shown.
409
api/services/poster_service.py
Normal file
409
api/services/poster_service.py
Normal file
@ -0,0 +1,409 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
通用海报服务层
|
||||
支持多种模板类型的海报生成,配置化管理
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from core.ai.ai_agent import AIAgent
|
||||
from api.config.poster_config_manager import get_poster_config_manager
|
||||
from api.models.vibrant_poster import TemplateInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnifiedPosterService:
|
||||
"""统一海报服务类"""
|
||||
|
||||
def __init__(self, ai_agent: AIAgent):
|
||||
"""
|
||||
初始化统一海报服务
|
||||
|
||||
Args:
|
||||
ai_agent: AI代理
|
||||
"""
|
||||
self.ai_agent = ai_agent
|
||||
self.config_manager = get_poster_config_manager()
|
||||
|
||||
def get_available_templates(self) -> List[TemplateInfo]:
|
||||
"""获取所有可用的模板列表"""
|
||||
template_list = self.config_manager.get_template_list()
|
||||
return [TemplateInfo(**template) for template in template_list]
|
||||
|
||||
def get_template_info(self, template_id: str) -> Optional[TemplateInfo]:
|
||||
"""获取指定模板的信息"""
|
||||
template_info = self.config_manager.get_template_info(template_id)
|
||||
if template_info:
|
||||
return TemplateInfo(
|
||||
id=template_id,
|
||||
name=template_info.get("name", template_id),
|
||||
description=template_info.get("description", ""),
|
||||
size=template_info.get("size", [900, 1200]),
|
||||
required_fields=template_info.get("required_fields", []),
|
||||
optional_fields=template_info.get("optional_fields", [])
|
||||
)
|
||||
return None
|
||||
|
||||
async def generate_content(self, template_id: str, source_data: Dict[str, Any],
|
||||
temperature: float = 0.7) -> Dict[str, Any]:
|
||||
"""
|
||||
生成海报内容
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
source_data: 源数据
|
||||
temperature: AI生成温度参数
|
||||
|
||||
Returns:
|
||||
生成的内容字典
|
||||
"""
|
||||
logger.info(f"正在为模板 {template_id} 生成内容...")
|
||||
|
||||
# 获取系统提示词
|
||||
system_prompt = self.config_manager.get_system_prompt(template_id)
|
||||
if not system_prompt:
|
||||
raise ValueError(f"模板 {template_id} 没有配置系统提示词")
|
||||
|
||||
# 格式化用户提示词
|
||||
user_prompt = self.config_manager.format_user_prompt(template_id, **source_data)
|
||||
if not user_prompt:
|
||||
raise ValueError(f"模板 {template_id} 用户提示词格式化失败")
|
||||
|
||||
try:
|
||||
# 调用AI生成内容
|
||||
response, _, _, _ = await self.ai_agent.generate_text(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
temperature=temperature,
|
||||
stage=f"海报内容生成-{template_id}"
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
json_start = response.find('{')
|
||||
json_end = response.rfind('}') + 1
|
||||
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = response[json_start:json_end]
|
||||
content_dict = json.loads(json_str)
|
||||
logger.info(f"AI成功生成内容: {content_dict}")
|
||||
|
||||
# 确保所有值都是字符串类型(除了列表)
|
||||
for key, value in content_dict.items():
|
||||
if isinstance(value, (int, float)):
|
||||
content_dict[key] = str(value)
|
||||
elif isinstance(value, list):
|
||||
content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value]
|
||||
|
||||
return content_dict
|
||||
else:
|
||||
logger.error(f"无法在AI响应中找到JSON对象: {response}")
|
||||
raise ValueError("AI响应格式不正确")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"无法解析AI响应为JSON: {e}")
|
||||
raise ValueError("AI响应JSON解析失败")
|
||||
except Exception as e:
|
||||
logger.error(f"调用AI生成内容时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def select_random_image(self, image_dir: Optional[str] = None) -> str:
|
||||
"""从指定目录随机选择一张图片"""
|
||||
if image_dir is None:
|
||||
image_dir = self.config_manager.get_default_config("image_dir")
|
||||
|
||||
try:
|
||||
image_files = [f for f in os.listdir(image_dir)
|
||||
if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))]
|
||||
if not image_files:
|
||||
raise ValueError(f"在目录 {image_dir} 中未找到任何图片文件")
|
||||
|
||||
random_image_name = random.choice(image_files)
|
||||
image_path = os.path.join(image_dir, random_image_name)
|
||||
logger.info(f"随机选择图片: {image_path}")
|
||||
return image_path
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"图片目录不存在: {image_dir}")
|
||||
|
||||
def validate_content(self, template_id: str, content: Dict[str, Any]) -> None:
|
||||
"""验证内容是否符合模板要求"""
|
||||
is_valid, errors = self.config_manager.validate_template_content(template_id, content)
|
||||
if not is_valid:
|
||||
raise ValueError(f"内容验证失败: {', '.join(errors)}")
|
||||
|
||||
async def generate_poster(self,
|
||||
template_id: str,
|
||||
content: Optional[Dict[str, Any]] = None,
|
||||
source_data: Optional[Dict[str, Any]] = None,
|
||||
topic_name: Optional[str] = None,
|
||||
image_path: Optional[str] = None,
|
||||
image_dir: Optional[str] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
temperature: float = 0.7) -> Dict[str, Any]:
|
||||
"""
|
||||
生成海报
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
content: 直接提供的内容(可选)
|
||||
source_data: 源数据,用于AI生成内容(可选)
|
||||
topic_name: 主题名称
|
||||
image_path: 指定图片路径
|
||||
image_dir: 图片目录
|
||||
output_dir: 输出目录
|
||||
temperature: AI生成温度参数
|
||||
|
||||
Returns:
|
||||
生成结果字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(f"开始生成海报,模板: {template_id}, 主题: {topic_name}")
|
||||
|
||||
# 生成请求ID
|
||||
request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
|
||||
|
||||
# 获取模板信息
|
||||
template_info = self.get_template_info(template_id)
|
||||
if not template_info:
|
||||
raise ValueError(f"未知的模板ID: {template_id}")
|
||||
|
||||
# 确定内容
|
||||
if content is None:
|
||||
if source_data is None:
|
||||
raise ValueError("必须提供content或source_data中的一个")
|
||||
|
||||
# 使用AI生成内容
|
||||
content = await self.generate_content(template_id, source_data, temperature)
|
||||
generation_method = "ai_generated"
|
||||
else:
|
||||
generation_method = "direct"
|
||||
|
||||
# 验证内容
|
||||
self.validate_content(template_id, content)
|
||||
|
||||
# 选择图片
|
||||
if image_path is None:
|
||||
image_path = self.select_random_image(image_dir)
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
raise ValueError(f"指定的图片文件不存在: {image_path}")
|
||||
|
||||
# 设置默认值
|
||||
if output_dir is None:
|
||||
output_dir = self.config_manager.get_default_config("output_dir")
|
||||
if topic_name is None:
|
||||
topic_name = f"poster_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
# 获取模板类并生成海报
|
||||
try:
|
||||
template_class = self.config_manager.get_template_class(template_id)
|
||||
template_instance = template_class(template_info.size)
|
||||
|
||||
# 设置字体(如果支持)
|
||||
font_dir = self.config_manager.get_default_config("font_dir")
|
||||
if hasattr(template_instance, 'set_font_dir') and font_dir:
|
||||
template_instance.set_font_dir(font_dir)
|
||||
|
||||
poster = template_instance.generate(image_path=image_path, content=content)
|
||||
|
||||
if not poster:
|
||||
raise ValueError("海报生成失败,模板返回了 None")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成海报时发生错误: {e}", exc_info=True)
|
||||
raise ValueError(f"海报生成失败: {str(e)}")
|
||||
|
||||
# 保存海报
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 生成文件名
|
||||
title = content.get('title', topic_name)
|
||||
if isinstance(title, str):
|
||||
title = title.replace('/', '_').replace('\\', '_')
|
||||
output_filename = f"{template_id}_{title}_{timestamp}.png"
|
||||
poster_path = os.path.join(output_dir, output_filename)
|
||||
|
||||
poster.save(poster_path, 'PNG')
|
||||
logger.info(f"海报已成功生成并保存至: {poster_path}")
|
||||
|
||||
processing_time = round(time.time() - start_time, 2)
|
||||
|
||||
return {
|
||||
"request_id": request_id,
|
||||
"template_id": template_id,
|
||||
"topic_name": topic_name,
|
||||
"poster_path": poster_path,
|
||||
"content": content,
|
||||
"metadata": {
|
||||
"image_used": image_path,
|
||||
"generation_method": generation_method,
|
||||
"template_size": template_info.size,
|
||||
"processing_time": processing_time,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存海报失败: {e}", exc_info=True)
|
||||
raise ValueError(f"保存海报失败: {str(e)}")
|
||||
|
||||
async def batch_generate_posters(self,
|
||||
template_id: str,
|
||||
base_path: str,
|
||||
image_dir: Optional[str] = None,
|
||||
source_files: Optional[Dict[str, str]] = None,
|
||||
output_base: str = "result/posters",
|
||||
parallel_count: int = 3,
|
||||
temperature: float = 0.7) -> Dict[str, Any]:
|
||||
"""
|
||||
批量生成海报
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
base_path: 包含多个topic目录的基础路径
|
||||
image_dir: 图片目录
|
||||
source_files: 源文件配置字典
|
||||
output_base: 输出基础目录
|
||||
parallel_count: 并发数量
|
||||
temperature: AI生成温度参数
|
||||
|
||||
Returns:
|
||||
批量处理结果
|
||||
"""
|
||||
logger.info(f"开始批量生成海报,模板: {template_id}, 基础路径: {base_path}")
|
||||
|
||||
# 生成批处理ID
|
||||
batch_request_id = f"batch-{template_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||
|
||||
# 查找topic目录
|
||||
topic_dirs = self._find_topic_directories(base_path)
|
||||
|
||||
if not topic_dirs:
|
||||
raise ValueError("未找到任何包含article_judged.json的topic目录")
|
||||
|
||||
logger.info(f"找到 {len(topic_dirs)} 个topic目录,准备批量生成海报")
|
||||
|
||||
# 准备输出目录
|
||||
base_name = Path(base_path).name
|
||||
output_base_dir = os.path.join(output_base, base_name)
|
||||
|
||||
# 这里简化实现,实际项目中可以使用asyncio.gather进行真正的异步批处理
|
||||
results = []
|
||||
successful_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for topic_path, topic_name in topic_dirs:
|
||||
try:
|
||||
article_path = os.path.join(topic_path, 'article_judged.json')
|
||||
topic_output_dir = os.path.join(output_base_dir, topic_name)
|
||||
|
||||
# 读取文章数据
|
||||
source_data = self._read_data_file(article_path)
|
||||
if not source_data:
|
||||
raise ValueError(f"无法读取文章文件: {article_path}")
|
||||
|
||||
# 构建源数据
|
||||
final_source_data = {"tweet_info": source_data}
|
||||
|
||||
# 如果提供了额外的源文件,读取并添加
|
||||
if source_files:
|
||||
for key, file_path in source_files.items():
|
||||
if file_path and os.path.exists(file_path):
|
||||
data = self._read_data_file(file_path)
|
||||
if data:
|
||||
final_source_data[key] = data
|
||||
|
||||
# 生成海报
|
||||
result = await self.generate_poster(
|
||||
template_id=template_id,
|
||||
source_data=final_source_data,
|
||||
topic_name=topic_name,
|
||||
image_dir=image_dir,
|
||||
output_dir=topic_output_dir,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
results.append({
|
||||
"topic": topic_name,
|
||||
"success": True,
|
||||
"result": result
|
||||
})
|
||||
successful_count += 1
|
||||
logger.info(f"成功生成海报: {topic_name}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
results.append({
|
||||
"topic": topic_name,
|
||||
"success": False,
|
||||
"error": error_msg
|
||||
})
|
||||
failed_count += 1
|
||||
logger.error(f"生成海报失败 {topic_name}: {error_msg}")
|
||||
|
||||
successful_topics = [r["topic"] for r in results if r["success"]]
|
||||
failed_topics = [{"topic": r["topic"], "error": r["error"]} for r in results if not r["success"]]
|
||||
|
||||
return {
|
||||
"request_id": batch_request_id,
|
||||
"template_id": template_id,
|
||||
"total_topics": len(topic_dirs),
|
||||
"successful_count": successful_count,
|
||||
"failed_count": failed_count,
|
||||
"output_base_dir": output_base_dir,
|
||||
"successful_topics": successful_topics,
|
||||
"failed_topics": failed_topics,
|
||||
"detailed_results": results
|
||||
}
|
||||
|
||||
def _find_topic_directories(self, base_path: str) -> List[Tuple[str, str]]:
|
||||
"""查找topic目录"""
|
||||
topic_dirs = []
|
||||
base_path = Path(base_path)
|
||||
|
||||
if not base_path.exists():
|
||||
return topic_dirs
|
||||
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith('topic_'):
|
||||
article_path = item / 'article_judged.json'
|
||||
if article_path.exists():
|
||||
topic_dirs.append((str(item), item.name))
|
||||
|
||||
return topic_dirs
|
||||
|
||||
def _read_data_file(self, file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""读取数据文件(简化版)"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 简单的文本内容处理
|
||||
return {"content": content}
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件失败 {file_path}: {e}")
|
||||
return None
|
||||
|
||||
def reload_config(self):
|
||||
"""重新加载配置"""
|
||||
self.config_manager.reload_config()
|
||||
Loading…
x
Reference in New Issue
Block a user