新建了poster接口

This commit is contained in:
jinye_huang 2025-07-16 18:24:32 +08:00
parent 3aa2775ec1
commit 5637305a36
16 changed files with 1372 additions and 1 deletions

Binary file not shown.

View File

@ -0,0 +1,214 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报配置管理器
负责加载和管理海报相关的配置信息
"""
import os
import yaml
import json
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
logger = logging.getLogger(__name__)
class PosterConfigManager:
"""海报配置管理器"""
def __init__(self, config_file: Optional[str] = None):
"""
初始化配置管理器
Args:
config_file: 配置文件路径如果为空则使用默认路径
"""
if config_file is None:
config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml")
self.config_file = config_file
self.config = {}
self.load_config()
def load_config(self):
"""加载配置文件"""
try:
with open(self.config_file, 'r', encoding='utf-8') as f:
self.config = yaml.safe_load(f)
logger.info(f"成功加载配置文件: {self.config_file}")
except FileNotFoundError:
logger.error(f"配置文件不存在: {self.config_file}")
self.config = self._get_default_config()
except yaml.YAMLError as e:
logger.error(f"解析配置文件失败: {e}")
self.config = self._get_default_config()
except Exception as e:
logger.error(f"加载配置文件时发生未知错误: {e}")
self.config = self._get_default_config()
def _get_default_config(self) -> Dict[str, Any]:
"""获取默认配置"""
return {
"poster_prompts": {},
"templates": {},
"defaults": {
"template": "vibrant",
"temperature": 0.7,
"output_dir": "result/posters",
"image_dir": "/root/TravelContentCreator/data/images",
"font_dir": "/root/TravelContentCreator/assets/font"
}
}
def get_template_list(self) -> List[Dict[str, Any]]:
"""获取所有可用的模板列表"""
templates = self.config.get("templates", {})
return [
{
"id": template_id,
"name": template_info.get("name", template_id),
"description": template_info.get("description", ""),
"size": template_info.get("size", [900, 1200]),
"required_fields": template_info.get("required_fields", []),
"optional_fields": template_info.get("optional_fields", [])
}
for template_id, template_info in templates.items()
]
def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
"""获取指定模板的详细信息"""
return self.config.get("templates", {}).get(template_id)
def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]:
"""获取指定模板的提示词配置"""
template_info = self.get_template_info(template_id)
if not template_info:
return None
prompt_key = template_info.get("prompt_key", template_id)
return self.config.get("poster_prompts", {}).get(prompt_key)
def get_system_prompt(self, template_id: str) -> Optional[str]:
"""获取系统提示词"""
prompt_config = self.get_prompt_config(template_id)
if prompt_config:
return prompt_config.get("system_prompt")
return None
def get_user_prompt_template(self, template_id: str) -> Optional[str]:
"""获取用户提示词模板"""
prompt_config = self.get_prompt_config(template_id)
if prompt_config:
return prompt_config.get("user_prompt_template")
return None
def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]:
"""
格式化用户提示词
Args:
template_id: 模板ID
**kwargs: 用于格式化的参数
Returns:
格式化后的用户提示词
"""
template = self.get_user_prompt_template(template_id)
if not template:
return None
try:
# 确保参数中的字典类型转为JSON字符串
formatted_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, (dict, list)):
formatted_kwargs[key] = json.dumps(value, ensure_ascii=False, indent=2)
else:
formatted_kwargs[key] = str(value)
return template.format(**formatted_kwargs)
except KeyError as e:
logger.error(f"格式化提示词失败,缺少参数: {e}")
return None
except Exception as e:
logger.error(f"格式化提示词时发生错误: {e}")
return None
def get_default_config(self, key: str) -> Any:
"""获取默认配置值"""
return self.config.get("defaults", {}).get(key)
def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> tuple[bool, List[str]]:
"""
验证模板内容是否符合要求
Args:
template_id: 模板ID
content: 要验证的内容
Returns:
(是否有效, 错误信息列表)
"""
template_info = self.get_template_info(template_id)
if not template_info:
return False, [f"未知的模板ID: {template_id}"]
errors = []
required_fields = template_info.get("required_fields", [])
# 检查必填字段
for field in required_fields:
if field not in content:
errors.append(f"缺少必填字段: {field}")
elif not content[field]:
errors.append(f"必填字段 {field} 不能为空")
return len(errors) == 0, errors
def get_template_class(self, template_id: str):
"""
动态获取模板类
Args:
template_id: 模板ID
Returns:
模板类
"""
template_mapping = {
"vibrant": "poster.templates.vibrant_template.VibrantTemplate",
"business": "poster.templates.business_template.BusinessTemplate"
}
class_path = template_mapping.get(template_id)
if not class_path:
raise ValueError(f"未支持的模板类型: {template_id}")
# 动态导入模板类
module_path, class_name = class_path.rsplit(".", 1)
try:
module = __import__(module_path, fromlist=[class_name])
return getattr(module, class_name)
except ImportError as e:
logger.error(f"导入模板类失败: {e}")
raise ValueError(f"无法加载模板类: {template_id}")
def reload_config(self):
"""重新加载配置"""
logger.info("正在重新加载配置...")
self.load_config()
# 全局配置管理器实例
_config_manager = None
def get_poster_config_manager() -> PosterConfigManager:
"""获取全局配置管理器实例"""
global _config_manager
if _config_manager is None:
_config_manager = PosterConfigManager()
return _config_manager

View File

@ -0,0 +1,102 @@
# 海报生成提示词配置文件
poster_prompts:
vibrant:
system_prompt: |
你是一名专业的海报设计师专门设计宣传海报。你现在要根据用户提供的信息生成适合Vibrant模板的海报内容。
## Vibrant模板特点
- 单图背景,毛玻璃渐变效果
- 两栏布局(左栏内容,右栏价格)
- 适合展示套餐内容和价格信息
## 你需要生成的数据结构包含以下字段:
**必填字段:**
1. `title`: 主标题8-12字符体现产品特色
2. `slogan`: 副标题/宣传语10-20字符吸引人的描述
3. `price`: 价格数字(纯数字,不含符号),如果材料中没有价格,则用"欢迎咨询"替代
4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等)
5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等)
6. `content_items`: 套餐内容列表3-5个项目每项5-15字符不要只包含项目名称要做合适的美化可以适当省略
**可选字段:**
7. `remarks`: 备注信息1-3条每条10-20字符
8. `tag`: 标签1条, 如"#限时优惠"等)
9. `pagination`: 分页信息(如"1/3",可为空)
## 内容创作要求:
1. 套餐内容要具体实用:明确说明包含的服务、时间、数量
2. 价格要有吸引力:突出性价比和优惠信息
## 输出格式:
请严格按照JSON格式输出不要有任何额外内容。
user_prompt_template: |
请根据以下信息,生成适合在旅游海报上展示的文案:
## 景区信息
{scenic_info}
## 产品信息
{product_info}
## 推文信息
{tweet_info}
请提取关键信息并整合成一个JSON对象包含title、slogan、price、ticket_type、content_items、remarks和tag字段。
business:
system_prompt: |
你是一名专业的商务海报设计师。你需要根据提供的信息生成适合Business模板的海报内容。
## Business模板特点
- 商务风格,简洁专业
- 突出核心信息和价值主张
- 适合企业服务推广
## 生成字段结构:
1. `title`: 服务标题6-10字符
2. `subtitle`: 副标题12-20字符
3. `features`: 核心特性列表3-4个特性
4. `price`: 价格信息
5. `contact`: 联系方式信息
请以JSON格式输出结果。
user_prompt_template: |
请为以下商务信息生成海报内容:
## 服务信息
{service_info}
## 目标客户
{target_audience}
## 核心卖点
{key_points}
# 模板配置
templates:
vibrant:
name: "Vibrant活力模板"
description: "适合旅游景点、娱乐活动的活力海报模板"
size: [900, 1200]
required_fields: ["title", "slogan", "price", "ticket_type", "content_items"]
optional_fields: ["remarks", "tag", "pagination"]
prompt_key: "vibrant"
business:
name: "Business商务模板"
description: "适合企业服务、B2B推广的商务海报模板"
size: [1080, 1920]
required_fields: ["title", "subtitle", "features", "price"]
optional_fields: ["contact"]
prompt_key: "business"
# 默认配置
defaults:
template: "vibrant"
temperature: 0.7
output_dir: "result/posters"
image_dir: "/root/TravelContentCreator/data/images"
font_dir: "/root/TravelContentCreator/assets/font"

View File

@ -56,11 +56,12 @@ app.add_middleware(
)
# 导入路由
from api.routers import tweet, poster, prompt, document, data, integration, content_integration
from api.routers import tweet, poster, poster_unified, prompt, document, data, integration, content_integration
# 包含路由
app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"])
app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"])
app.include_router(poster_unified.router, prefix="/api/v2/poster", tags=["poster-unified"])
app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"])
app.include_router(document.router, prefix="/api/v1/document", tags=["document"])
app.include_router(data.router, prefix="/api/v1", tags=["data"])

Binary file not shown.

View File

@ -0,0 +1,331 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报API通用模型定义
支持多种模板类型的海报生成
"""
from typing import List, Dict, Any, Optional, Union
from pydantic import BaseModel, Field
# 基础模板信息
class TemplateInfo(BaseModel):
"""模板信息模型"""
id: str = Field(..., description="模板ID")
name: str = Field(..., description="模板名称")
description: str = Field(..., description="模板描述")
size: List[int] = Field(..., description="模板尺寸 [宽, 高]")
required_fields: List[str] = Field(..., description="必填字段列表")
optional_fields: List[str] = Field(default=[], description="可选字段列表")
class Config:
schema_extra = {
"example": {
"id": "vibrant",
"name": "Vibrant活力模板",
"description": "适合旅游景点、娱乐活动的活力海报模板",
"size": [900, 1200],
"required_fields": ["title", "slogan", "price", "ticket_type", "content_items"],
"optional_fields": ["remarks", "tag", "pagination"]
}
}
# 通用内容模型(支持任意字段)
class PosterContent(BaseModel):
"""通用海报内容模型,支持动态字段"""
class Config:
extra = "allow" # 允许额外字段
schema_extra = {
"example": {
"title": "海洋奇幻世界",
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
"price": "299",
"ticket_type": "成人票",
"content_items": [
"海洋馆门票1张含所有展区",
"海豚表演VIP座位"
]
}
}
# 保持向后兼容的Vibrant内容模型
class VibrantPosterContent(PosterContent):
"""Vibrant海报内容模型向后兼容"""
title: Optional[str] = Field(None, description="主标题8-12字符")
slogan: Optional[str] = Field(None, description="副标题/宣传语10-20字符")
price: Optional[str] = Field(None, description="价格,如果没有价格则用'欢迎咨询'")
ticket_type: Optional[str] = Field(None, description="票种类型(如成人票、套餐票等)")
content_button: Optional[str] = Field(None, description="内容按钮文字")
content_items: Optional[List[str]] = Field(None, description="套餐内容列表3-5个项目")
remarks: Optional[List[str]] = Field(None, description="备注信息1-3条")
tag: Optional[str] = Field(None, description="标签")
pagination: Optional[str] = Field(None, description="分页信息")
# 通用海报生成请求
class PosterGenerationRequest(BaseModel):
"""通用海报生成请求模型"""
# 模板相关
template_id: str = Field(..., description="模板ID")
# 内容相关(二选一)
content: Optional[Dict[str, Any]] = Field(None, description="直接提供的海报内容")
source_data: Optional[Dict[str, Any]] = Field(None, description="源数据用于AI生成内容")
# 生成参数
topic_name: Optional[str] = Field(None, description="主题名称,用于文件命名")
image_path: Optional[str] = Field(None, description="指定图片路径,如果不提供则随机选择")
image_dir: Optional[str] = Field(None, description="图片目录,如果不提供使用默认目录")
output_dir: Optional[str] = Field(None, description="输出目录,如果不提供使用默认目录")
# AI参数
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
class Config:
schema_extra = {
"example": {
"template_id": "vibrant",
"source_data": {
"scenic_info": {
"name": "天津冒险湾",
"location": "天津市",
"type": "水上乐园"
},
"product_info": {
"name": "冒险湾门票",
"price": 299,
"type": "成人票"
},
"tweet_info": {
"title": "夏日清凉首选",
"content": "天津冒险湾水上乐园,多种刺激项目等你来挑战..."
}
},
"topic_name": "天津冒险湾",
"temperature": 0.7
}
}
# 内容生成请求
class ContentGenerationRequest(BaseModel):
"""内容生成请求模型"""
template_id: str = Field(..., description="模板ID")
source_data: Dict[str, Any] = Field(..., description="源数据用于AI生成内容")
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
class Config:
schema_extra = {
"example": {
"template_id": "vibrant",
"source_data": {
"scenic_info": {"name": "天津冒险湾", "type": "水上乐园"},
"product_info": {"name": "门票", "price": 299},
"tweet_info": {"title": "夏日清凉", "content": "水上乐园体验"}
},
"temperature": 0.7
}
}
# 保持向后兼容
class VibrantPosterRequest(PosterGenerationRequest):
"""Vibrant海报生成请求模型向后兼容"""
template_id: str = Field(default="vibrant", description="模板ID默认为vibrant")
# 兼容旧的字段结构
scenic_info: Optional[Dict[str, Any]] = Field(None, description="景区信息")
product_info: Optional[Dict[str, Any]] = Field(None, description="产品信息")
tweet_info: Optional[Dict[str, Any]] = Field(None, description="推文信息")
def __init__(self, **data):
# 转换旧格式到新格式
if 'scenic_info' in data or 'product_info' in data or 'tweet_info' in data:
source_data = {}
for key in ['scenic_info', 'product_info', 'tweet_info']:
if key in data and data[key] is not None:
source_data[key] = data.pop(key)
if source_data and 'source_data' not in data:
data['source_data'] = source_data
super().__init__(**data)
# 标准API响应基础类
class BaseAPIResponse(BaseModel):
"""API响应基础模型"""
success: bool = Field(..., description="操作是否成功")
message: str = Field(default="", description="响应消息")
request_id: str = Field(..., description="请求ID")
timestamp: str = Field(..., description="响应时间戳")
# 通用海报生成响应
class PosterGenerationResponse(BaseAPIResponse):
"""通用海报生成响应模型"""
data: Optional[Dict[str, Any]] = Field(None, description="响应数据")
class Config:
schema_extra = {
"example": {
"success": True,
"message": "海报生成成功",
"request_id": "poster-20240715-123456-a1b2c3d4",
"timestamp": "2024-07-15T12:34:56Z",
"data": {
"template_id": "vibrant",
"topic_name": "天津冒险湾",
"poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png",
"content": {
"title": "海洋奇幻世界",
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
"price": "299"
},
"metadata": {
"image_used": "/data/images/ocean_park.jpg",
"generation_method": "ai_generated",
"template_size": [900, 1200],
"processing_time": 3.2
}
}
}
}
# 内容生成响应
class ContentGenerationResponse(BaseAPIResponse):
"""内容生成响应模型"""
data: Optional[Dict[str, Any]] = Field(None, description="生成的内容")
class Config:
schema_extra = {
"example": {
"success": True,
"message": "内容生成成功",
"request_id": "content-20240715-123456-a1b2c3d4",
"timestamp": "2024-07-15T12:34:56Z",
"data": {
"template_id": "vibrant",
"content": {
"title": "冒险湾水世界",
"slogan": "夏日激情体验,清凉无限乐趣",
"price": "299",
"ticket_type": "成人票"
},
"metadata": {
"generation_method": "ai_generated",
"temperature": 0.7
}
}
}
}
# 模板列表响应
class TemplateListResponse(BaseAPIResponse):
"""模板列表响应模型"""
data: Optional[List[TemplateInfo]] = Field(None, description="模板列表")
class Config:
schema_extra = {
"example": {
"success": True,
"message": "获取模板列表成功",
"request_id": "templates-20240715-123456",
"timestamp": "2024-07-15T12:34:56Z",
"data": [
{
"id": "vibrant",
"name": "Vibrant活力模板",
"description": "适合旅游景点、娱乐活动的活力海报模板",
"size": [900, 1200],
"required_fields": ["title", "slogan", "price"],
"optional_fields": ["remarks", "tag"]
}
]
}
}
# 保持向后兼容的响应模型
class VibrantPosterResponse(BaseModel):
"""Vibrant海报生成响应模型向后兼容"""
request_id: str = Field(..., description="请求ID")
topic_name: str = Field(..., description="主题名称")
poster_path: str = Field(..., description="生成的海报文件路径")
generated_content: Dict[str, Any] = Field(..., description="生成或使用的海报内容")
image_used: str = Field(..., description="使用的图片路径")
generation_method: str = Field(..., description="生成方式direct直接使用提供内容或ai_generatedAI生成")
class Config:
schema_extra = {
"example": {
"request_id": "vibrant-20240715-123456-a1b2c3d4",
"topic_name": "天津冒险湾",
"poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png",
"generated_content": {
"title": "海洋奇幻世界",
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
"price": "299",
"ticket_type": "成人票",
"content_items": ["海洋馆门票1张", "海豚表演VIP座位"]
},
"image_used": "/data/images/ocean_park.jpg",
"generation_method": "ai_generated"
}
}
class BatchVibrantPosterRequest(BaseModel):
"""批量Vibrant海报生成请求模型"""
base_path: str = Field(..., description="包含多个topic目录的基础路径")
image_dir: Optional[str] = Field(None, description="图片目录")
scenic_info_file: Optional[str] = Field(None, description="景区信息文件路径")
product_info_file: Optional[str] = Field(None, description="产品信息文件路径")
output_base: Optional[str] = Field(default="result/posters", description="输出基础目录")
parallel_count: Optional[int] = Field(default=3, description="并发处理数量")
temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数")
class Config:
schema_extra = {
"example": {
"base_path": "/root/TravelContentCreator/result/run_20250710_165327",
"image_dir": "/root/TravelContentCreator/data/images",
"scenic_info_file": "/root/TravelContentCreator/resource/data/Object/天津冒险湾.txt",
"product_info_file": "/root/TravelContentCreator/resource/data/Product/product.bak",
"output_base": "result/posters",
"parallel_count": 3,
"temperature": 0.7
}
}
class BatchVibrantPosterResponse(BaseModel):
"""批量Vibrant海报生成响应模型"""
request_id: str = Field(..., description="批量处理请求ID")
total_topics: int = Field(..., description="总共处理的topic数量")
successful_count: int = Field(..., description="成功生成的海报数量")
failed_count: int = Field(..., description="失败的海报数量")
output_base_dir: str = Field(..., description="输出基础目录")
successful_topics: List[str] = Field(..., description="成功处理的topic列表")
failed_topics: List[Dict[str, str]] = Field(..., description="失败的topic及错误信息")
class Config:
schema_extra = {
"example": {
"request_id": "batch-vibrant-20240715-123456",
"total_topics": 5,
"successful_count": 4,
"failed_count": 1,
"output_base_dir": "/result/posters/run_20250710_165327",
"successful_topics": ["topic_1", "topic_2", "topic_3", "topic_4"],
"failed_topics": [
{"topic": "topic_5", "error": "图片文件不存在"}
]
}
}

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,314 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统一海报API路由
支持多种模板类型的海报生成配置化管理
"""
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Dict, Any, List
from core.ai import AIAgent
from api.services.poster_service import UnifiedPosterService
from api.models.vibrant_poster import (
PosterGenerationRequest, PosterGenerationResponse,
ContentGenerationRequest, ContentGenerationResponse,
TemplateListResponse, TemplateInfo,
BaseAPIResponse
)
# 从依赖注入模块导入依赖
from api.dependencies import get_ai_agent
logger = logging.getLogger(__name__)
# 创建路由
router = APIRouter()
# 依赖注入函数
def get_unified_poster_service(
ai_agent: AIAgent = Depends(get_ai_agent)
) -> UnifiedPosterService:
"""获取统一海报服务"""
return UnifiedPosterService(ai_agent)
def create_response(success: bool, message: str, data: Any = None, request_id: str = None) -> Dict[str, Any]:
"""创建标准响应"""
if request_id is None:
request_id = f"req-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
return {
"success": success,
"message": message,
"request_id": request_id,
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": data
}
@router.get("/templates", response_model=TemplateListResponse, summary="获取所有可用模板")
async def get_templates(
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
获取所有可用的海报模板列表
返回每个模板的详细信息包括
- 模板ID和名称
- 模板描述
- 模板尺寸
- 必填字段和可选字段
"""
try:
templates = service.get_available_templates()
response_data = create_response(
success=True,
message="获取模板列表成功",
data=templates
)
return TemplateListResponse(**response_data)
except Exception as e:
logger.error(f"获取模板列表失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}")
@router.get("/templates/{template_id}", response_model=BaseAPIResponse, summary="获取指定模板信息")
async def get_template_info(
template_id: str,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
获取指定模板的详细信息
参数
- **template_id**: 模板ID
"""
try:
template_info = service.get_template_info(template_id)
if not template_info:
raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
response_data = create_response(
success=True,
message="获取模板信息成功",
data=template_info.dict()
)
return BaseAPIResponse(**response_data)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取模板信息失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(e)}")
@router.post("/content/generate", response_model=ContentGenerationResponse, summary="生成海报内容")
async def generate_content(
request: ContentGenerationRequest,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
根据源数据生成海报内容不生成实际图片
用于
1. 预览生成的内容
2. 调试和测试内容生成
3. 分步骤生成先生成内容再生成图片
参数
- **template_id**: 模板ID
- **source_data**: 源数据用于AI生成内容
- **temperature**: AI生成温度参数
"""
try:
content = await service.generate_content(
template_id=request.template_id,
source_data=request.source_data,
temperature=request.temperature
)
response_data = create_response(
success=True,
message="内容生成成功",
data={
"template_id": request.template_id,
"content": content,
"metadata": {
"generation_method": "ai_generated",
"temperature": request.temperature
}
}
)
return ContentGenerationResponse(**response_data)
except Exception as e:
logger.error(f"生成内容失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
async def generate_poster(
request: PosterGenerationRequest,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
生成海报图片
支持两种模式
1. 直接提供内容content字段
2. 提供源数据让AI生成内容source_data字段
参数
- **template_id**: 模板ID
- **content**: 直接提供的海报内容可选
- **source_data**: 源数据用于AI生成内容可选
- **topic_name**: 主题名称用于文件命名
- **image_path**: 指定图片路径可选
- **image_dir**: 图片目录可选
- **output_dir**: 输出目录可选
- **temperature**: AI生成温度参数
"""
try:
result = await service.generate_poster(
template_id=request.template_id,
content=request.content,
source_data=request.source_data,
topic_name=request.topic_name,
image_path=request.image_path,
image_dir=request.image_dir,
output_dir=request.output_dir,
temperature=request.temperature
)
response_data = create_response(
success=True,
message="海报生成成功",
data=result,
request_id=result["request_id"]
)
return PosterGenerationResponse(**response_data)
except Exception as e:
logger.error(f"生成海报失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}")
@router.post("/batch", response_model=BaseAPIResponse, summary="批量生成海报")
async def batch_generate_posters(
template_id: str,
base_path: str,
image_dir: str = None,
source_files: Dict[str, str] = None,
output_base: str = "result/posters",
parallel_count: int = 3,
temperature: float = 0.7,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
批量生成海报
自动扫描指定目录下的topic文件夹为每个topic生成海报
参数
- **template_id**: 模板ID
- **base_path**: 包含多个topic目录的基础路径
- **image_dir**: 图片目录可选
- **source_files**: 源文件配置字典可选
- **output_base**: 输出基础目录
- **parallel_count**: 并发处理数量
- **temperature**: AI生成温度参数
"""
try:
result = await service.batch_generate_posters(
template_id=template_id,
base_path=base_path,
image_dir=image_dir,
source_files=source_files or {},
output_base=output_base,
parallel_count=parallel_count,
temperature=temperature
)
response_data = create_response(
success=True,
message=f"批量生成完成,成功: {result['successful_count']}, 失败: {result['failed_count']}",
data=result,
request_id=result["request_id"]
)
return BaseAPIResponse(**response_data)
except Exception as e:
logger.error(f"批量生成海报失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"批量生成海报失败: {str(e)}")
@router.post("/config/reload", response_model=BaseAPIResponse, summary="重新加载配置")
async def reload_config(
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
重新加载海报配置
用于在不重启服务的情况下更新配置包括
- 提示词模板
- 模板配置
- 默认参数
"""
try:
service.reload_config()
response_data = create_response(
success=True,
message="配置重新加载成功"
)
return BaseAPIResponse(**response_data)
except Exception as e:
logger.error(f"重新加载配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"重新加载配置失败: {str(e)}")
@router.get("/health", summary="健康检查")
async def health_check():
"""服务健康检查"""
return create_response(
success=True,
message="统一海报服务运行正常",
data={
"service": "unified_poster",
"status": "healthy",
"version": "2.0.0"
}
)
@router.get("/config", summary="获取服务配置")
async def get_service_config(
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""获取服务配置信息"""
try:
config_info = {
"default_image_dir": service.config_manager.get_default_config("image_dir"),
"default_output_dir": service.config_manager.get_default_config("output_dir"),
"default_font_dir": service.config_manager.get_default_config("font_dir"),
"default_template": service.config_manager.get_default_config("template"),
"supported_image_formats": ["png", "jpg", "jpeg", "webp"],
"available_templates": len(service.get_available_templates())
}
response_data = create_response(
success=True,
message="获取配置成功",
data=config_info
)
return BaseAPIResponse(**response_data)
except Exception as e:
logger.error(f"获取配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}")

View File

@ -0,0 +1,409 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用海报服务层
支持多种模板类型的海报生成配置化管理
"""
import os
import sys
import json
import random
import asyncio
import logging
import uuid
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
# 添加项目根目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from core.ai.ai_agent import AIAgent
from api.config.poster_config_manager import get_poster_config_manager
from api.models.vibrant_poster import TemplateInfo
logger = logging.getLogger(__name__)
class UnifiedPosterService:
"""统一海报服务类"""
def __init__(self, ai_agent: AIAgent):
"""
初始化统一海报服务
Args:
ai_agent: AI代理
"""
self.ai_agent = ai_agent
self.config_manager = get_poster_config_manager()
def get_available_templates(self) -> List[TemplateInfo]:
"""获取所有可用的模板列表"""
template_list = self.config_manager.get_template_list()
return [TemplateInfo(**template) for template in template_list]
def get_template_info(self, template_id: str) -> Optional[TemplateInfo]:
"""获取指定模板的信息"""
template_info = self.config_manager.get_template_info(template_id)
if template_info:
return TemplateInfo(
id=template_id,
name=template_info.get("name", template_id),
description=template_info.get("description", ""),
size=template_info.get("size", [900, 1200]),
required_fields=template_info.get("required_fields", []),
optional_fields=template_info.get("optional_fields", [])
)
return None
async def generate_content(self, template_id: str, source_data: Dict[str, Any],
temperature: float = 0.7) -> Dict[str, Any]:
"""
生成海报内容
Args:
template_id: 模板ID
source_data: 源数据
temperature: AI生成温度参数
Returns:
生成的内容字典
"""
logger.info(f"正在为模板 {template_id} 生成内容...")
# 获取系统提示词
system_prompt = self.config_manager.get_system_prompt(template_id)
if not system_prompt:
raise ValueError(f"模板 {template_id} 没有配置系统提示词")
# 格式化用户提示词
user_prompt = self.config_manager.format_user_prompt(template_id, **source_data)
if not user_prompt:
raise ValueError(f"模板 {template_id} 用户提示词格式化失败")
try:
# 调用AI生成内容
response, _, _, _ = await self.ai_agent.generate_text(
system_prompt=system_prompt,
user_prompt=user_prompt,
temperature=temperature,
stage=f"海报内容生成-{template_id}"
)
# 解析JSON响应
json_start = response.find('{')
json_end = response.rfind('}') + 1
if json_start >= 0 and json_end > json_start:
json_str = response[json_start:json_end]
content_dict = json.loads(json_str)
logger.info(f"AI成功生成内容: {content_dict}")
# 确保所有值都是字符串类型(除了列表)
for key, value in content_dict.items():
if isinstance(value, (int, float)):
content_dict[key] = str(value)
elif isinstance(value, list):
content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value]
return content_dict
else:
logger.error(f"无法在AI响应中找到JSON对象: {response}")
raise ValueError("AI响应格式不正确")
except json.JSONDecodeError as e:
logger.error(f"无法解析AI响应为JSON: {e}")
raise ValueError("AI响应JSON解析失败")
except Exception as e:
logger.error(f"调用AI生成内容时发生错误: {e}")
raise
def select_random_image(self, image_dir: Optional[str] = None) -> str:
"""从指定目录随机选择一张图片"""
if image_dir is None:
image_dir = self.config_manager.get_default_config("image_dir")
try:
image_files = [f for f in os.listdir(image_dir)
if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))]
if not image_files:
raise ValueError(f"在目录 {image_dir} 中未找到任何图片文件")
random_image_name = random.choice(image_files)
image_path = os.path.join(image_dir, random_image_name)
logger.info(f"随机选择图片: {image_path}")
return image_path
except FileNotFoundError:
raise ValueError(f"图片目录不存在: {image_dir}")
def validate_content(self, template_id: str, content: Dict[str, Any]) -> None:
"""验证内容是否符合模板要求"""
is_valid, errors = self.config_manager.validate_template_content(template_id, content)
if not is_valid:
raise ValueError(f"内容验证失败: {', '.join(errors)}")
async def generate_poster(self,
template_id: str,
content: Optional[Dict[str, Any]] = None,
source_data: Optional[Dict[str, Any]] = None,
topic_name: Optional[str] = None,
image_path: Optional[str] = None,
image_dir: Optional[str] = None,
output_dir: Optional[str] = None,
temperature: float = 0.7) -> Dict[str, Any]:
"""
生成海报
Args:
template_id: 模板ID
content: 直接提供的内容可选
source_data: 源数据用于AI生成内容可选
topic_name: 主题名称
image_path: 指定图片路径
image_dir: 图片目录
output_dir: 输出目录
temperature: AI生成温度参数
Returns:
生成结果字典
"""
start_time = time.time()
logger.info(f"开始生成海报,模板: {template_id}, 主题: {topic_name}")
# 生成请求ID
request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
# 获取模板信息
template_info = self.get_template_info(template_id)
if not template_info:
raise ValueError(f"未知的模板ID: {template_id}")
# 确定内容
if content is None:
if source_data is None:
raise ValueError("必须提供content或source_data中的一个")
# 使用AI生成内容
content = await self.generate_content(template_id, source_data, temperature)
generation_method = "ai_generated"
else:
generation_method = "direct"
# 验证内容
self.validate_content(template_id, content)
# 选择图片
if image_path is None:
image_path = self.select_random_image(image_dir)
if not os.path.exists(image_path):
raise ValueError(f"指定的图片文件不存在: {image_path}")
# 设置默认值
if output_dir is None:
output_dir = self.config_manager.get_default_config("output_dir")
if topic_name is None:
topic_name = f"poster_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# 获取模板类并生成海报
try:
template_class = self.config_manager.get_template_class(template_id)
template_instance = template_class(template_info.size)
# 设置字体(如果支持)
font_dir = self.config_manager.get_default_config("font_dir")
if hasattr(template_instance, 'set_font_dir') and font_dir:
template_instance.set_font_dir(font_dir)
poster = template_instance.generate(image_path=image_path, content=content)
if not poster:
raise ValueError("海报生成失败,模板返回了 None")
except Exception as e:
logger.error(f"生成海报时发生错误: {e}", exc_info=True)
raise ValueError(f"海报生成失败: {str(e)}")
# 保存海报
try:
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# 生成文件名
title = content.get('title', topic_name)
if isinstance(title, str):
title = title.replace('/', '_').replace('\\', '_')
output_filename = f"{template_id}_{title}_{timestamp}.png"
poster_path = os.path.join(output_dir, output_filename)
poster.save(poster_path, 'PNG')
logger.info(f"海报已成功生成并保存至: {poster_path}")
processing_time = round(time.time() - start_time, 2)
return {
"request_id": request_id,
"template_id": template_id,
"topic_name": topic_name,
"poster_path": poster_path,
"content": content,
"metadata": {
"image_used": image_path,
"generation_method": generation_method,
"template_size": template_info.size,
"processing_time": processing_time,
"timestamp": datetime.now(timezone.utc).isoformat()
}
}
except Exception as e:
logger.error(f"保存海报失败: {e}", exc_info=True)
raise ValueError(f"保存海报失败: {str(e)}")
async def batch_generate_posters(self,
template_id: str,
base_path: str,
image_dir: Optional[str] = None,
source_files: Optional[Dict[str, str]] = None,
output_base: str = "result/posters",
parallel_count: int = 3,
temperature: float = 0.7) -> Dict[str, Any]:
"""
批量生成海报
Args:
template_id: 模板ID
base_path: 包含多个topic目录的基础路径
image_dir: 图片目录
source_files: 源文件配置字典
output_base: 输出基础目录
parallel_count: 并发数量
temperature: AI生成温度参数
Returns:
批量处理结果
"""
logger.info(f"开始批量生成海报,模板: {template_id}, 基础路径: {base_path}")
# 生成批处理ID
batch_request_id = f"batch-{template_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
# 查找topic目录
topic_dirs = self._find_topic_directories(base_path)
if not topic_dirs:
raise ValueError("未找到任何包含article_judged.json的topic目录")
logger.info(f"找到 {len(topic_dirs)} 个topic目录准备批量生成海报")
# 准备输出目录
base_name = Path(base_path).name
output_base_dir = os.path.join(output_base, base_name)
# 这里简化实现实际项目中可以使用asyncio.gather进行真正的异步批处理
results = []
successful_count = 0
failed_count = 0
for topic_path, topic_name in topic_dirs:
try:
article_path = os.path.join(topic_path, 'article_judged.json')
topic_output_dir = os.path.join(output_base_dir, topic_name)
# 读取文章数据
source_data = self._read_data_file(article_path)
if not source_data:
raise ValueError(f"无法读取文章文件: {article_path}")
# 构建源数据
final_source_data = {"tweet_info": source_data}
# 如果提供了额外的源文件,读取并添加
if source_files:
for key, file_path in source_files.items():
if file_path and os.path.exists(file_path):
data = self._read_data_file(file_path)
if data:
final_source_data[key] = data
# 生成海报
result = await self.generate_poster(
template_id=template_id,
source_data=final_source_data,
topic_name=topic_name,
image_dir=image_dir,
output_dir=topic_output_dir,
temperature=temperature
)
results.append({
"topic": topic_name,
"success": True,
"result": result
})
successful_count += 1
logger.info(f"成功生成海报: {topic_name}")
except Exception as e:
error_msg = str(e)
results.append({
"topic": topic_name,
"success": False,
"error": error_msg
})
failed_count += 1
logger.error(f"生成海报失败 {topic_name}: {error_msg}")
successful_topics = [r["topic"] for r in results if r["success"]]
failed_topics = [{"topic": r["topic"], "error": r["error"]} for r in results if not r["success"]]
return {
"request_id": batch_request_id,
"template_id": template_id,
"total_topics": len(topic_dirs),
"successful_count": successful_count,
"failed_count": failed_count,
"output_base_dir": output_base_dir,
"successful_topics": successful_topics,
"failed_topics": failed_topics,
"detailed_results": results
}
def _find_topic_directories(self, base_path: str) -> List[Tuple[str, str]]:
"""查找topic目录"""
topic_dirs = []
base_path = Path(base_path)
if not base_path.exists():
return topic_dirs
for item in base_path.iterdir():
if item.is_dir() and item.name.startswith('topic_'):
article_path = item / 'article_judged.json'
if article_path.exists():
topic_dirs.append((str(item), item.name))
return topic_dirs
def _read_data_file(self, file_path: str) -> Optional[Dict[str, Any]]:
"""读取数据文件(简化版)"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
try:
return json.loads(content)
except json.JSONDecodeError:
# 简单的文本内容处理
return {"content": content}
except Exception as e:
logger.error(f"读取文件失败 {file_path}: {e}")
return None
def reload_config(self):
"""重新加载配置"""
self.config_manager.reload_config()