diff --git a/api/__pycache__/main.cpython-312.pyc b/api/__pycache__/main.cpython-312.pyc index 4a026e3..a154140 100644 Binary files a/api/__pycache__/main.cpython-312.pyc and b/api/__pycache__/main.cpython-312.pyc differ diff --git a/api/config/__pycache__/poster_config_manager.cpython-312.pyc b/api/config/__pycache__/poster_config_manager.cpython-312.pyc new file mode 100644 index 0000000..8d51435 Binary files /dev/null and b/api/config/__pycache__/poster_config_manager.cpython-312.pyc differ diff --git a/api/config/poster_config_manager.py b/api/config/poster_config_manager.py new file mode 100644 index 0000000..175221d --- /dev/null +++ b/api/config/poster_config_manager.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +海报配置管理器 +负责加载和管理海报相关的配置信息 +""" + +import os +import yaml +import json +import logging +from typing import Dict, Any, Optional, List +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class PosterConfigManager: + """海报配置管理器""" + + def __init__(self, config_file: Optional[str] = None): + """ + 初始化配置管理器 + + Args: + config_file: 配置文件路径,如果为空则使用默认路径 + """ + if config_file is None: + config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml") + + self.config_file = config_file + self.config = {} + self.load_config() + + def load_config(self): + """加载配置文件""" + try: + with open(self.config_file, 'r', encoding='utf-8') as f: + self.config = yaml.safe_load(f) + logger.info(f"成功加载配置文件: {self.config_file}") + except FileNotFoundError: + logger.error(f"配置文件不存在: {self.config_file}") + self.config = self._get_default_config() + except yaml.YAMLError as e: + logger.error(f"解析配置文件失败: {e}") + self.config = self._get_default_config() + except Exception as e: + logger.error(f"加载配置文件时发生未知错误: {e}") + self.config = self._get_default_config() + + def _get_default_config(self) -> Dict[str, Any]: + """获取默认配置""" + return { + "poster_prompts": {}, + "templates": {}, + "defaults": { + "template": "vibrant", + "temperature": 0.7, + "output_dir": "result/posters", + "image_dir": "/root/TravelContentCreator/data/images", + "font_dir": "/root/TravelContentCreator/assets/font" + } + } + + def get_template_list(self) -> List[Dict[str, Any]]: + """获取所有可用的模板列表""" + templates = self.config.get("templates", {}) + return [ + { + "id": template_id, + "name": template_info.get("name", template_id), + "description": template_info.get("description", ""), + "size": template_info.get("size", [900, 1200]), + "required_fields": template_info.get("required_fields", []), + "optional_fields": template_info.get("optional_fields", []) + } + for template_id, template_info in templates.items() + ] + + def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]: + """获取指定模板的详细信息""" + return self.config.get("templates", {}).get(template_id) + + def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]: + """获取指定模板的提示词配置""" + template_info = self.get_template_info(template_id) + if not template_info: + return None + + prompt_key = template_info.get("prompt_key", template_id) + return self.config.get("poster_prompts", {}).get(prompt_key) + + def get_system_prompt(self, template_id: str) -> Optional[str]: + """获取系统提示词""" + prompt_config = self.get_prompt_config(template_id) + if prompt_config: + return prompt_config.get("system_prompt") + return None + + def get_user_prompt_template(self, template_id: str) -> Optional[str]: + """获取用户提示词模板""" + prompt_config = self.get_prompt_config(template_id) + if prompt_config: + return prompt_config.get("user_prompt_template") + return None + + def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]: + """ + 格式化用户提示词 + + Args: + template_id: 模板ID + **kwargs: 用于格式化的参数 + + Returns: + 格式化后的用户提示词 + """ + template = self.get_user_prompt_template(template_id) + if not template: + return None + + try: + # 确保参数中的字典类型转为JSON字符串 + formatted_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, (dict, list)): + formatted_kwargs[key] = json.dumps(value, ensure_ascii=False, indent=2) + else: + formatted_kwargs[key] = str(value) + + return template.format(**formatted_kwargs) + except KeyError as e: + logger.error(f"格式化提示词失败,缺少参数: {e}") + return None + except Exception as e: + logger.error(f"格式化提示词时发生错误: {e}") + return None + + def get_default_config(self, key: str) -> Any: + """获取默认配置值""" + return self.config.get("defaults", {}).get(key) + + def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> tuple[bool, List[str]]: + """ + 验证模板内容是否符合要求 + + Args: + template_id: 模板ID + content: 要验证的内容 + + Returns: + (是否有效, 错误信息列表) + """ + template_info = self.get_template_info(template_id) + if not template_info: + return False, [f"未知的模板ID: {template_id}"] + + errors = [] + required_fields = template_info.get("required_fields", []) + + # 检查必填字段 + for field in required_fields: + if field not in content: + errors.append(f"缺少必填字段: {field}") + elif not content[field]: + errors.append(f"必填字段 {field} 不能为空") + + return len(errors) == 0, errors + + def get_template_class(self, template_id: str): + """ + 动态获取模板类 + + Args: + template_id: 模板ID + + Returns: + 模板类 + """ + template_mapping = { + "vibrant": "poster.templates.vibrant_template.VibrantTemplate", + "business": "poster.templates.business_template.BusinessTemplate" + } + + class_path = template_mapping.get(template_id) + if not class_path: + raise ValueError(f"未支持的模板类型: {template_id}") + + # 动态导入模板类 + module_path, class_name = class_path.rsplit(".", 1) + try: + module = __import__(module_path, fromlist=[class_name]) + return getattr(module, class_name) + except ImportError as e: + logger.error(f"导入模板类失败: {e}") + raise ValueError(f"无法加载模板类: {template_id}") + + def reload_config(self): + """重新加载配置""" + logger.info("正在重新加载配置...") + self.load_config() + + +# 全局配置管理器实例 +_config_manager = None + + +def get_poster_config_manager() -> PosterConfigManager: + """获取全局配置管理器实例""" + global _config_manager + if _config_manager is None: + _config_manager = PosterConfigManager() + return _config_manager \ No newline at end of file diff --git a/api/config/poster_prompts.yaml b/api/config/poster_prompts.yaml new file mode 100644 index 0000000..9f5bd54 --- /dev/null +++ b/api/config/poster_prompts.yaml @@ -0,0 +1,102 @@ +# 海报生成提示词配置文件 +poster_prompts: + vibrant: + system_prompt: | + 你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的信息,生成适合Vibrant模板的海报内容。 + + ## Vibrant模板特点: + - 单图背景,毛玻璃渐变效果 + - 两栏布局(左栏内容,右栏价格) + - 适合展示套餐内容和价格信息 + + ## 你需要生成的数据结构包含以下字段: + + **必填字段:** + 1. `title`: 主标题(8-12字符,体现产品特色) + 2. `slogan`: 副标题/宣传语(10-20字符,吸引人的描述) + 3. `price`: 价格数字(纯数字,不含符号),如果材料中没有价格,则用"欢迎咨询"替代 + 4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等) + 5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等) + 6. `content_items`: 套餐内容列表(3-5个项目,每项5-15字符,不要只包含项目名称,要做合适的美化,可以适当省略) + + **可选字段:** + 7. `remarks`: 备注信息(1-3条,每条10-20字符) + 8. `tag`: 标签(1条, 如"#限时优惠"等) + 9. `pagination`: 分页信息(如"1/3",可为空) + + ## 内容创作要求: + 1. 套餐内容要具体实用:明确说明包含的服务、时间、数量 + 2. 价格要有吸引力:突出性价比和优惠信息 + + ## 输出格式: + 请严格按照JSON格式输出,不要有任何额外内容。 + + user_prompt_template: | + 请根据以下信息,生成适合在旅游海报上展示的文案: + + ## 景区信息 + {scenic_info} + + ## 产品信息 + {product_info} + + ## 推文信息 + {tweet_info} + + 请提取关键信息并整合成一个JSON对象,包含title、slogan、price、ticket_type、content_items、remarks和tag字段。 + + business: + system_prompt: | + 你是一名专业的商务海报设计师。你需要根据提供的信息生成适合Business模板的海报内容。 + + ## Business模板特点: + - 商务风格,简洁专业 + - 突出核心信息和价值主张 + - 适合企业服务推广 + + ## 生成字段结构: + 1. `title`: 服务标题(6-10字符) + 2. `subtitle`: 副标题(12-20字符) + 3. `features`: 核心特性列表(3-4个特性) + 4. `price`: 价格信息 + 5. `contact`: 联系方式信息 + + 请以JSON格式输出结果。 + + user_prompt_template: | + 请为以下商务信息生成海报内容: + + ## 服务信息 + {service_info} + + ## 目标客户 + {target_audience} + + ## 核心卖点 + {key_points} + +# 模板配置 +templates: + vibrant: + name: "Vibrant活力模板" + description: "适合旅游景点、娱乐活动的活力海报模板" + size: [900, 1200] + required_fields: ["title", "slogan", "price", "ticket_type", "content_items"] + optional_fields: ["remarks", "tag", "pagination"] + prompt_key: "vibrant" + + business: + name: "Business商务模板" + description: "适合企业服务、B2B推广的商务海报模板" + size: [1080, 1920] + required_fields: ["title", "subtitle", "features", "price"] + optional_fields: ["contact"] + prompt_key: "business" + +# 默认配置 +defaults: + template: "vibrant" + temperature: 0.7 + output_dir: "result/posters" + image_dir: "/root/TravelContentCreator/data/images" + font_dir: "/root/TravelContentCreator/assets/font" \ No newline at end of file diff --git a/api/main.py b/api/main.py index aba7b57..b000300 100644 --- a/api/main.py +++ b/api/main.py @@ -56,11 +56,12 @@ app.add_middleware( ) # 导入路由 -from api.routers import tweet, poster, prompt, document, data, integration, content_integration +from api.routers import tweet, poster, poster_unified, prompt, document, data, integration, content_integration # 包含路由 app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"]) app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"]) +app.include_router(poster_unified.router, prefix="/api/v2/poster", tags=["poster-unified"]) app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"]) app.include_router(document.router, prefix="/api/v1/document", tags=["document"]) app.include_router(data.router, prefix="/api/v1", tags=["data"]) diff --git a/api/models/__pycache__/tweet.cpython-312.pyc b/api/models/__pycache__/tweet.cpython-312.pyc index 45e950e..e19f132 100644 Binary files a/api/models/__pycache__/tweet.cpython-312.pyc and b/api/models/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/models/__pycache__/vibrant_poster.cpython-312.pyc b/api/models/__pycache__/vibrant_poster.cpython-312.pyc new file mode 100644 index 0000000..6a4d3ab Binary files /dev/null and b/api/models/__pycache__/vibrant_poster.cpython-312.pyc differ diff --git a/api/models/vibrant_poster.py b/api/models/vibrant_poster.py new file mode 100644 index 0000000..c1b6218 --- /dev/null +++ b/api/models/vibrant_poster.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +海报API通用模型定义 +支持多种模板类型的海报生成 +""" + +from typing import List, Dict, Any, Optional, Union +from pydantic import BaseModel, Field + + +# 基础模板信息 +class TemplateInfo(BaseModel): + """模板信息模型""" + id: str = Field(..., description="模板ID") + name: str = Field(..., description="模板名称") + description: str = Field(..., description="模板描述") + size: List[int] = Field(..., description="模板尺寸 [宽, 高]") + required_fields: List[str] = Field(..., description="必填字段列表") + optional_fields: List[str] = Field(default=[], description="可选字段列表") + + class Config: + schema_extra = { + "example": { + "id": "vibrant", + "name": "Vibrant活力模板", + "description": "适合旅游景点、娱乐活动的活力海报模板", + "size": [900, 1200], + "required_fields": ["title", "slogan", "price", "ticket_type", "content_items"], + "optional_fields": ["remarks", "tag", "pagination"] + } + } + + +# 通用内容模型(支持任意字段) +class PosterContent(BaseModel): + """通用海报内容模型,支持动态字段""" + + class Config: + extra = "allow" # 允许额外字段 + schema_extra = { + "example": { + "title": "海洋奇幻世界", + "slogan": "探索深海秘境,感受蓝色奇迹的无限魅力", + "price": "299", + "ticket_type": "成人票", + "content_items": [ + "海洋馆门票1张(含所有展区)", + "海豚表演VIP座位" + ] + } + } + + +# 保持向后兼容的Vibrant内容模型 +class VibrantPosterContent(PosterContent): + """Vibrant海报内容模型(向后兼容)""" + title: Optional[str] = Field(None, description="主标题(8-12字符)") + slogan: Optional[str] = Field(None, description="副标题/宣传语(10-20字符)") + price: Optional[str] = Field(None, description="价格,如果没有价格则用'欢迎咨询'") + ticket_type: Optional[str] = Field(None, description="票种类型(如成人票、套餐票等)") + content_button: Optional[str] = Field(None, description="内容按钮文字") + content_items: Optional[List[str]] = Field(None, description="套餐内容列表(3-5个项目)") + remarks: Optional[List[str]] = Field(None, description="备注信息(1-3条)") + tag: Optional[str] = Field(None, description="标签") + pagination: Optional[str] = Field(None, description="分页信息") + + +# 通用海报生成请求 +class PosterGenerationRequest(BaseModel): + """通用海报生成请求模型""" + # 模板相关 + template_id: str = Field(..., description="模板ID") + + # 内容相关(二选一) + content: Optional[Dict[str, Any]] = Field(None, description="直接提供的海报内容") + source_data: Optional[Dict[str, Any]] = Field(None, description="源数据,用于AI生成内容") + + # 生成参数 + topic_name: Optional[str] = Field(None, description="主题名称,用于文件命名") + image_path: Optional[str] = Field(None, description="指定图片路径,如果不提供则随机选择") + image_dir: Optional[str] = Field(None, description="图片目录,如果不提供使用默认目录") + output_dir: Optional[str] = Field(None, description="输出目录,如果不提供使用默认目录") + + # AI参数 + temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数") + + class Config: + schema_extra = { + "example": { + "template_id": "vibrant", + "source_data": { + "scenic_info": { + "name": "天津冒险湾", + "location": "天津市", + "type": "水上乐园" + }, + "product_info": { + "name": "冒险湾门票", + "price": 299, + "type": "成人票" + }, + "tweet_info": { + "title": "夏日清凉首选", + "content": "天津冒险湾水上乐园,多种刺激项目等你来挑战..." + } + }, + "topic_name": "天津冒险湾", + "temperature": 0.7 + } + } + + +# 内容生成请求 +class ContentGenerationRequest(BaseModel): + """内容生成请求模型""" + template_id: str = Field(..., description="模板ID") + source_data: Dict[str, Any] = Field(..., description="源数据,用于AI生成内容") + temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数") + + class Config: + schema_extra = { + "example": { + "template_id": "vibrant", + "source_data": { + "scenic_info": {"name": "天津冒险湾", "type": "水上乐园"}, + "product_info": {"name": "门票", "price": 299}, + "tweet_info": {"title": "夏日清凉", "content": "水上乐园体验"} + }, + "temperature": 0.7 + } + } + + +# 保持向后兼容 +class VibrantPosterRequest(PosterGenerationRequest): + """Vibrant海报生成请求模型(向后兼容)""" + template_id: str = Field(default="vibrant", description="模板ID,默认为vibrant") + + # 兼容旧的字段结构 + scenic_info: Optional[Dict[str, Any]] = Field(None, description="景区信息") + product_info: Optional[Dict[str, Any]] = Field(None, description="产品信息") + tweet_info: Optional[Dict[str, Any]] = Field(None, description="推文信息") + + def __init__(self, **data): + # 转换旧格式到新格式 + if 'scenic_info' in data or 'product_info' in data or 'tweet_info' in data: + source_data = {} + for key in ['scenic_info', 'product_info', 'tweet_info']: + if key in data and data[key] is not None: + source_data[key] = data.pop(key) + if source_data and 'source_data' not in data: + data['source_data'] = source_data + + super().__init__(**data) + + +# 标准API响应基础类 +class BaseAPIResponse(BaseModel): + """API响应基础模型""" + success: bool = Field(..., description="操作是否成功") + message: str = Field(default="", description="响应消息") + request_id: str = Field(..., description="请求ID") + timestamp: str = Field(..., description="响应时间戳") + + +# 通用海报生成响应 +class PosterGenerationResponse(BaseAPIResponse): + """通用海报生成响应模型""" + data: Optional[Dict[str, Any]] = Field(None, description="响应数据") + + class Config: + schema_extra = { + "example": { + "success": True, + "message": "海报生成成功", + "request_id": "poster-20240715-123456-a1b2c3d4", + "timestamp": "2024-07-15T12:34:56Z", + "data": { + "template_id": "vibrant", + "topic_name": "天津冒险湾", + "poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png", + "content": { + "title": "海洋奇幻世界", + "slogan": "探索深海秘境,感受蓝色奇迹的无限魅力", + "price": "299" + }, + "metadata": { + "image_used": "/data/images/ocean_park.jpg", + "generation_method": "ai_generated", + "template_size": [900, 1200], + "processing_time": 3.2 + } + } + } + } + + +# 内容生成响应 +class ContentGenerationResponse(BaseAPIResponse): + """内容生成响应模型""" + data: Optional[Dict[str, Any]] = Field(None, description="生成的内容") + + class Config: + schema_extra = { + "example": { + "success": True, + "message": "内容生成成功", + "request_id": "content-20240715-123456-a1b2c3d4", + "timestamp": "2024-07-15T12:34:56Z", + "data": { + "template_id": "vibrant", + "content": { + "title": "冒险湾水世界", + "slogan": "夏日激情体验,清凉无限乐趣", + "price": "299", + "ticket_type": "成人票" + }, + "metadata": { + "generation_method": "ai_generated", + "temperature": 0.7 + } + } + } + } + + +# 模板列表响应 +class TemplateListResponse(BaseAPIResponse): + """模板列表响应模型""" + data: Optional[List[TemplateInfo]] = Field(None, description="模板列表") + + class Config: + schema_extra = { + "example": { + "success": True, + "message": "获取模板列表成功", + "request_id": "templates-20240715-123456", + "timestamp": "2024-07-15T12:34:56Z", + "data": [ + { + "id": "vibrant", + "name": "Vibrant活力模板", + "description": "适合旅游景点、娱乐活动的活力海报模板", + "size": [900, 1200], + "required_fields": ["title", "slogan", "price"], + "optional_fields": ["remarks", "tag"] + } + ] + } + } + + +# 保持向后兼容的响应模型 +class VibrantPosterResponse(BaseModel): + """Vibrant海报生成响应模型(向后兼容)""" + request_id: str = Field(..., description="请求ID") + topic_name: str = Field(..., description="主题名称") + poster_path: str = Field(..., description="生成的海报文件路径") + generated_content: Dict[str, Any] = Field(..., description="生成或使用的海报内容") + image_used: str = Field(..., description="使用的图片路径") + generation_method: str = Field(..., description="生成方式:direct(直接使用提供内容)或ai_generated(AI生成)") + + class Config: + schema_extra = { + "example": { + "request_id": "vibrant-20240715-123456-a1b2c3d4", + "topic_name": "天津冒险湾", + "poster_path": "/result/posters/vibrant_海洋奇幻世界_20240715_123456.png", + "generated_content": { + "title": "海洋奇幻世界", + "slogan": "探索深海秘境,感受蓝色奇迹的无限魅力", + "price": "299", + "ticket_type": "成人票", + "content_items": ["海洋馆门票1张", "海豚表演VIP座位"] + }, + "image_used": "/data/images/ocean_park.jpg", + "generation_method": "ai_generated" + } + } + + +class BatchVibrantPosterRequest(BaseModel): + """批量Vibrant海报生成请求模型""" + base_path: str = Field(..., description="包含多个topic目录的基础路径") + image_dir: Optional[str] = Field(None, description="图片目录") + scenic_info_file: Optional[str] = Field(None, description="景区信息文件路径") + product_info_file: Optional[str] = Field(None, description="产品信息文件路径") + output_base: Optional[str] = Field(default="result/posters", description="输出基础目录") + parallel_count: Optional[int] = Field(default=3, description="并发处理数量") + temperature: Optional[float] = Field(default=0.7, description="AI生成温度参数") + + class Config: + schema_extra = { + "example": { + "base_path": "/root/TravelContentCreator/result/run_20250710_165327", + "image_dir": "/root/TravelContentCreator/data/images", + "scenic_info_file": "/root/TravelContentCreator/resource/data/Object/天津冒险湾.txt", + "product_info_file": "/root/TravelContentCreator/resource/data/Product/product.bak", + "output_base": "result/posters", + "parallel_count": 3, + "temperature": 0.7 + } + } + + +class BatchVibrantPosterResponse(BaseModel): + """批量Vibrant海报生成响应模型""" + request_id: str = Field(..., description="批量处理请求ID") + total_topics: int = Field(..., description="总共处理的topic数量") + successful_count: int = Field(..., description="成功生成的海报数量") + failed_count: int = Field(..., description="失败的海报数量") + output_base_dir: str = Field(..., description="输出基础目录") + successful_topics: List[str] = Field(..., description="成功处理的topic列表") + failed_topics: List[Dict[str, str]] = Field(..., description="失败的topic及错误信息") + + class Config: + schema_extra = { + "example": { + "request_id": "batch-vibrant-20240715-123456", + "total_topics": 5, + "successful_count": 4, + "failed_count": 1, + "output_base_dir": "/result/posters/run_20250710_165327", + "successful_topics": ["topic_1", "topic_2", "topic_3", "topic_4"], + "failed_topics": [ + {"topic": "topic_5", "error": "图片文件不存在"} + ] + } + } \ No newline at end of file diff --git a/api/routers/__pycache__/poster_unified.cpython-312.pyc b/api/routers/__pycache__/poster_unified.cpython-312.pyc new file mode 100644 index 0000000..3af06e7 Binary files /dev/null and b/api/routers/__pycache__/poster_unified.cpython-312.pyc differ diff --git a/api/routers/__pycache__/tweet.cpython-312.pyc b/api/routers/__pycache__/tweet.cpython-312.pyc index e3c7635..16f1316 100644 Binary files a/api/routers/__pycache__/tweet.cpython-312.pyc and b/api/routers/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/routers/__pycache__/vibrant_poster.cpython-312.pyc b/api/routers/__pycache__/vibrant_poster.cpython-312.pyc new file mode 100644 index 0000000..bdab10c Binary files /dev/null and b/api/routers/__pycache__/vibrant_poster.cpython-312.pyc differ diff --git a/api/routers/poster_unified.py b/api/routers/poster_unified.py new file mode 100644 index 0000000..86c87c4 --- /dev/null +++ b/api/routers/poster_unified.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +统一海报API路由 +支持多种模板类型的海报生成,配置化管理 +""" + +import logging +import uuid +from datetime import datetime, timezone +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks +from typing import Dict, Any, List + +from core.ai import AIAgent +from api.services.poster_service import UnifiedPosterService +from api.models.vibrant_poster import ( + PosterGenerationRequest, PosterGenerationResponse, + ContentGenerationRequest, ContentGenerationResponse, + TemplateListResponse, TemplateInfo, + BaseAPIResponse +) + +# 从依赖注入模块导入依赖 +from api.dependencies import get_ai_agent + +logger = logging.getLogger(__name__) + +# 创建路由 +router = APIRouter() + +# 依赖注入函数 +def get_unified_poster_service( + ai_agent: AIAgent = Depends(get_ai_agent) +) -> UnifiedPosterService: + """获取统一海报服务""" + return UnifiedPosterService(ai_agent) + + +def create_response(success: bool, message: str, data: Any = None, request_id: str = None) -> Dict[str, Any]: + """创建标准响应""" + if request_id is None: + request_id = f"req-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}" + + return { + "success": success, + "message": message, + "request_id": request_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + "data": data + } + + +@router.get("/templates", response_model=TemplateListResponse, summary="获取所有可用模板") +async def get_templates( + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """ + 获取所有可用的海报模板列表 + + 返回每个模板的详细信息,包括: + - 模板ID和名称 + - 模板描述 + - 模板尺寸 + - 必填字段和可选字段 + """ + try: + templates = service.get_available_templates() + response_data = create_response( + success=True, + message="获取模板列表成功", + data=templates + ) + return TemplateListResponse(**response_data) + + except Exception as e: + logger.error(f"获取模板列表失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}") + + +@router.get("/templates/{template_id}", response_model=BaseAPIResponse, summary="获取指定模板信息") +async def get_template_info( + template_id: str, + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """ + 获取指定模板的详细信息 + + 参数: + - **template_id**: 模板ID + """ + try: + template_info = service.get_template_info(template_id) + if not template_info: + raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在") + + response_data = create_response( + success=True, + message="获取模板信息成功", + data=template_info.dict() + ) + return BaseAPIResponse(**response_data) + + except HTTPException: + raise + except Exception as e: + logger.error(f"获取模板信息失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(e)}") + + +@router.post("/content/generate", response_model=ContentGenerationResponse, summary="生成海报内容") +async def generate_content( + request: ContentGenerationRequest, + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """ + 根据源数据生成海报内容,不生成实际图片 + + 用于: + 1. 预览生成的内容 + 2. 调试和测试内容生成 + 3. 分步骤生成(先生成内容,再生成图片) + + 参数: + - **template_id**: 模板ID + - **source_data**: 源数据,用于AI生成内容 + - **temperature**: AI生成温度参数 + """ + try: + content = await service.generate_content( + template_id=request.template_id, + source_data=request.source_data, + temperature=request.temperature + ) + + response_data = create_response( + success=True, + message="内容生成成功", + data={ + "template_id": request.template_id, + "content": content, + "metadata": { + "generation_method": "ai_generated", + "temperature": request.temperature + } + } + ) + return ContentGenerationResponse(**response_data) + + except Exception as e: + logger.error(f"生成内容失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}") + + +@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报") +async def generate_poster( + request: PosterGenerationRequest, + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """ + 生成海报图片 + + 支持两种模式: + 1. 直接提供内容(content字段) + 2. 提供源数据让AI生成内容(source_data字段) + + 参数: + - **template_id**: 模板ID + - **content**: 直接提供的海报内容(可选) + - **source_data**: 源数据,用于AI生成内容(可选) + - **topic_name**: 主题名称,用于文件命名 + - **image_path**: 指定图片路径(可选) + - **image_dir**: 图片目录(可选) + - **output_dir**: 输出目录(可选) + - **temperature**: AI生成温度参数 + """ + try: + result = await service.generate_poster( + template_id=request.template_id, + content=request.content, + source_data=request.source_data, + topic_name=request.topic_name, + image_path=request.image_path, + image_dir=request.image_dir, + output_dir=request.output_dir, + temperature=request.temperature + ) + + response_data = create_response( + success=True, + message="海报生成成功", + data=result, + request_id=result["request_id"] + ) + return PosterGenerationResponse(**response_data) + + except Exception as e: + logger.error(f"生成海报失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") + + +@router.post("/batch", response_model=BaseAPIResponse, summary="批量生成海报") +async def batch_generate_posters( + template_id: str, + base_path: str, + image_dir: str = None, + source_files: Dict[str, str] = None, + output_base: str = "result/posters", + parallel_count: int = 3, + temperature: float = 0.7, + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """ + 批量生成海报 + + 自动扫描指定目录下的topic文件夹,为每个topic生成海报。 + + 参数: + - **template_id**: 模板ID + - **base_path**: 包含多个topic目录的基础路径 + - **image_dir**: 图片目录(可选) + - **source_files**: 源文件配置字典(可选) + - **output_base**: 输出基础目录 + - **parallel_count**: 并发处理数量 + - **temperature**: AI生成温度参数 + """ + try: + result = await service.batch_generate_posters( + template_id=template_id, + base_path=base_path, + image_dir=image_dir, + source_files=source_files or {}, + output_base=output_base, + parallel_count=parallel_count, + temperature=temperature + ) + + response_data = create_response( + success=True, + message=f"批量生成完成,成功: {result['successful_count']}, 失败: {result['failed_count']}", + data=result, + request_id=result["request_id"] + ) + return BaseAPIResponse(**response_data) + + except Exception as e: + logger.error(f"批量生成海报失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"批量生成海报失败: {str(e)}") + + +@router.post("/config/reload", response_model=BaseAPIResponse, summary="重新加载配置") +async def reload_config( + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """ + 重新加载海报配置 + + 用于在不重启服务的情况下更新配置,包括: + - 提示词模板 + - 模板配置 + - 默认参数 + """ + try: + service.reload_config() + response_data = create_response( + success=True, + message="配置重新加载成功" + ) + return BaseAPIResponse(**response_data) + + except Exception as e: + logger.error(f"重新加载配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"重新加载配置失败: {str(e)}") + + +@router.get("/health", summary="健康检查") +async def health_check(): + """服务健康检查""" + return create_response( + success=True, + message="统一海报服务运行正常", + data={ + "service": "unified_poster", + "status": "healthy", + "version": "2.0.0" + } + ) + + +@router.get("/config", summary="获取服务配置") +async def get_service_config( + service: UnifiedPosterService = Depends(get_unified_poster_service) +): + """获取服务配置信息""" + try: + config_info = { + "default_image_dir": service.config_manager.get_default_config("image_dir"), + "default_output_dir": service.config_manager.get_default_config("output_dir"), + "default_font_dir": service.config_manager.get_default_config("font_dir"), + "default_template": service.config_manager.get_default_config("template"), + "supported_image_formats": ["png", "jpg", "jpeg", "webp"], + "available_templates": len(service.get_available_templates()) + } + + response_data = create_response( + success=True, + message="获取配置成功", + data=config_info + ) + return BaseAPIResponse(**response_data) + + except Exception as e: + logger.error(f"获取配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}") \ No newline at end of file diff --git a/api/services/__pycache__/poster_service.cpython-312.pyc b/api/services/__pycache__/poster_service.cpython-312.pyc new file mode 100644 index 0000000..43ad673 Binary files /dev/null and b/api/services/__pycache__/poster_service.cpython-312.pyc differ diff --git a/api/services/__pycache__/tweet.cpython-312.pyc b/api/services/__pycache__/tweet.cpython-312.pyc index a58d4bd..2ba49f2 100644 Binary files a/api/services/__pycache__/tweet.cpython-312.pyc and b/api/services/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/services/__pycache__/vibrant_poster.cpython-312.pyc b/api/services/__pycache__/vibrant_poster.cpython-312.pyc new file mode 100644 index 0000000..6e191c0 Binary files /dev/null and b/api/services/__pycache__/vibrant_poster.cpython-312.pyc differ diff --git a/api/services/poster_service.py b/api/services/poster_service.py new file mode 100644 index 0000000..026c13c --- /dev/null +++ b/api/services/poster_service.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +通用海报服务层 +支持多种模板类型的海报生成,配置化管理 +""" + +import os +import sys +import json +import random +import asyncio +import logging +import uuid +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple + +# 添加项目根目录到路径 +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from core.ai.ai_agent import AIAgent +from api.config.poster_config_manager import get_poster_config_manager +from api.models.vibrant_poster import TemplateInfo + +logger = logging.getLogger(__name__) + + +class UnifiedPosterService: + """统一海报服务类""" + + def __init__(self, ai_agent: AIAgent): + """ + 初始化统一海报服务 + + Args: + ai_agent: AI代理 + """ + self.ai_agent = ai_agent + self.config_manager = get_poster_config_manager() + + def get_available_templates(self) -> List[TemplateInfo]: + """获取所有可用的模板列表""" + template_list = self.config_manager.get_template_list() + return [TemplateInfo(**template) for template in template_list] + + def get_template_info(self, template_id: str) -> Optional[TemplateInfo]: + """获取指定模板的信息""" + template_info = self.config_manager.get_template_info(template_id) + if template_info: + return TemplateInfo( + id=template_id, + name=template_info.get("name", template_id), + description=template_info.get("description", ""), + size=template_info.get("size", [900, 1200]), + required_fields=template_info.get("required_fields", []), + optional_fields=template_info.get("optional_fields", []) + ) + return None + + async def generate_content(self, template_id: str, source_data: Dict[str, Any], + temperature: float = 0.7) -> Dict[str, Any]: + """ + 生成海报内容 + + Args: + template_id: 模板ID + source_data: 源数据 + temperature: AI生成温度参数 + + Returns: + 生成的内容字典 + """ + logger.info(f"正在为模板 {template_id} 生成内容...") + + # 获取系统提示词 + system_prompt = self.config_manager.get_system_prompt(template_id) + if not system_prompt: + raise ValueError(f"模板 {template_id} 没有配置系统提示词") + + # 格式化用户提示词 + user_prompt = self.config_manager.format_user_prompt(template_id, **source_data) + if not user_prompt: + raise ValueError(f"模板 {template_id} 用户提示词格式化失败") + + try: + # 调用AI生成内容 + response, _, _, _ = await self.ai_agent.generate_text( + system_prompt=system_prompt, + user_prompt=user_prompt, + temperature=temperature, + stage=f"海报内容生成-{template_id}" + ) + + # 解析JSON响应 + json_start = response.find('{') + json_end = response.rfind('}') + 1 + + if json_start >= 0 and json_end > json_start: + json_str = response[json_start:json_end] + content_dict = json.loads(json_str) + logger.info(f"AI成功生成内容: {content_dict}") + + # 确保所有值都是字符串类型(除了列表) + for key, value in content_dict.items(): + if isinstance(value, (int, float)): + content_dict[key] = str(value) + elif isinstance(value, list): + content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value] + + return content_dict + else: + logger.error(f"无法在AI响应中找到JSON对象: {response}") + raise ValueError("AI响应格式不正确") + + except json.JSONDecodeError as e: + logger.error(f"无法解析AI响应为JSON: {e}") + raise ValueError("AI响应JSON解析失败") + except Exception as e: + logger.error(f"调用AI生成内容时发生错误: {e}") + raise + + def select_random_image(self, image_dir: Optional[str] = None) -> str: + """从指定目录随机选择一张图片""" + if image_dir is None: + image_dir = self.config_manager.get_default_config("image_dir") + + try: + image_files = [f for f in os.listdir(image_dir) + if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))] + if not image_files: + raise ValueError(f"在目录 {image_dir} 中未找到任何图片文件") + + random_image_name = random.choice(image_files) + image_path = os.path.join(image_dir, random_image_name) + logger.info(f"随机选择图片: {image_path}") + return image_path + except FileNotFoundError: + raise ValueError(f"图片目录不存在: {image_dir}") + + def validate_content(self, template_id: str, content: Dict[str, Any]) -> None: + """验证内容是否符合模板要求""" + is_valid, errors = self.config_manager.validate_template_content(template_id, content) + if not is_valid: + raise ValueError(f"内容验证失败: {', '.join(errors)}") + + async def generate_poster(self, + template_id: str, + content: Optional[Dict[str, Any]] = None, + source_data: Optional[Dict[str, Any]] = None, + topic_name: Optional[str] = None, + image_path: Optional[str] = None, + image_dir: Optional[str] = None, + output_dir: Optional[str] = None, + temperature: float = 0.7) -> Dict[str, Any]: + """ + 生成海报 + + Args: + template_id: 模板ID + content: 直接提供的内容(可选) + source_data: 源数据,用于AI生成内容(可选) + topic_name: 主题名称 + image_path: 指定图片路径 + image_dir: 图片目录 + output_dir: 输出目录 + temperature: AI生成温度参数 + + Returns: + 生成结果字典 + """ + start_time = time.time() + + logger.info(f"开始生成海报,模板: {template_id}, 主题: {topic_name}") + + # 生成请求ID + request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}" + + # 获取模板信息 + template_info = self.get_template_info(template_id) + if not template_info: + raise ValueError(f"未知的模板ID: {template_id}") + + # 确定内容 + if content is None: + if source_data is None: + raise ValueError("必须提供content或source_data中的一个") + + # 使用AI生成内容 + content = await self.generate_content(template_id, source_data, temperature) + generation_method = "ai_generated" + else: + generation_method = "direct" + + # 验证内容 + self.validate_content(template_id, content) + + # 选择图片 + if image_path is None: + image_path = self.select_random_image(image_dir) + + if not os.path.exists(image_path): + raise ValueError(f"指定的图片文件不存在: {image_path}") + + # 设置默认值 + if output_dir is None: + output_dir = self.config_manager.get_default_config("output_dir") + if topic_name is None: + topic_name = f"poster_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + # 获取模板类并生成海报 + try: + template_class = self.config_manager.get_template_class(template_id) + template_instance = template_class(template_info.size) + + # 设置字体(如果支持) + font_dir = self.config_manager.get_default_config("font_dir") + if hasattr(template_instance, 'set_font_dir') and font_dir: + template_instance.set_font_dir(font_dir) + + poster = template_instance.generate(image_path=image_path, content=content) + + if not poster: + raise ValueError("海报生成失败,模板返回了 None") + + except Exception as e: + logger.error(f"生成海报时发生错误: {e}", exc_info=True) + raise ValueError(f"海报生成失败: {str(e)}") + + # 保存海报 + try: + os.makedirs(output_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 生成文件名 + title = content.get('title', topic_name) + if isinstance(title, str): + title = title.replace('/', '_').replace('\\', '_') + output_filename = f"{template_id}_{title}_{timestamp}.png" + poster_path = os.path.join(output_dir, output_filename) + + poster.save(poster_path, 'PNG') + logger.info(f"海报已成功生成并保存至: {poster_path}") + + processing_time = round(time.time() - start_time, 2) + + return { + "request_id": request_id, + "template_id": template_id, + "topic_name": topic_name, + "poster_path": poster_path, + "content": content, + "metadata": { + "image_used": image_path, + "generation_method": generation_method, + "template_size": template_info.size, + "processing_time": processing_time, + "timestamp": datetime.now(timezone.utc).isoformat() + } + } + + except Exception as e: + logger.error(f"保存海报失败: {e}", exc_info=True) + raise ValueError(f"保存海报失败: {str(e)}") + + async def batch_generate_posters(self, + template_id: str, + base_path: str, + image_dir: Optional[str] = None, + source_files: Optional[Dict[str, str]] = None, + output_base: str = "result/posters", + parallel_count: int = 3, + temperature: float = 0.7) -> Dict[str, Any]: + """ + 批量生成海报 + + Args: + template_id: 模板ID + base_path: 包含多个topic目录的基础路径 + image_dir: 图片目录 + source_files: 源文件配置字典 + output_base: 输出基础目录 + parallel_count: 并发数量 + temperature: AI生成温度参数 + + Returns: + 批量处理结果 + """ + logger.info(f"开始批量生成海报,模板: {template_id}, 基础路径: {base_path}") + + # 生成批处理ID + batch_request_id = f"batch-{template_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + + # 查找topic目录 + topic_dirs = self._find_topic_directories(base_path) + + if not topic_dirs: + raise ValueError("未找到任何包含article_judged.json的topic目录") + + logger.info(f"找到 {len(topic_dirs)} 个topic目录,准备批量生成海报") + + # 准备输出目录 + base_name = Path(base_path).name + output_base_dir = os.path.join(output_base, base_name) + + # 这里简化实现,实际项目中可以使用asyncio.gather进行真正的异步批处理 + results = [] + successful_count = 0 + failed_count = 0 + + for topic_path, topic_name in topic_dirs: + try: + article_path = os.path.join(topic_path, 'article_judged.json') + topic_output_dir = os.path.join(output_base_dir, topic_name) + + # 读取文章数据 + source_data = self._read_data_file(article_path) + if not source_data: + raise ValueError(f"无法读取文章文件: {article_path}") + + # 构建源数据 + final_source_data = {"tweet_info": source_data} + + # 如果提供了额外的源文件,读取并添加 + if source_files: + for key, file_path in source_files.items(): + if file_path and os.path.exists(file_path): + data = self._read_data_file(file_path) + if data: + final_source_data[key] = data + + # 生成海报 + result = await self.generate_poster( + template_id=template_id, + source_data=final_source_data, + topic_name=topic_name, + image_dir=image_dir, + output_dir=topic_output_dir, + temperature=temperature + ) + + results.append({ + "topic": topic_name, + "success": True, + "result": result + }) + successful_count += 1 + logger.info(f"成功生成海报: {topic_name}") + + except Exception as e: + error_msg = str(e) + results.append({ + "topic": topic_name, + "success": False, + "error": error_msg + }) + failed_count += 1 + logger.error(f"生成海报失败 {topic_name}: {error_msg}") + + successful_topics = [r["topic"] for r in results if r["success"]] + failed_topics = [{"topic": r["topic"], "error": r["error"]} for r in results if not r["success"]] + + return { + "request_id": batch_request_id, + "template_id": template_id, + "total_topics": len(topic_dirs), + "successful_count": successful_count, + "failed_count": failed_count, + "output_base_dir": output_base_dir, + "successful_topics": successful_topics, + "failed_topics": failed_topics, + "detailed_results": results + } + + def _find_topic_directories(self, base_path: str) -> List[Tuple[str, str]]: + """查找topic目录""" + topic_dirs = [] + base_path = Path(base_path) + + if not base_path.exists(): + return topic_dirs + + for item in base_path.iterdir(): + if item.is_dir() and item.name.startswith('topic_'): + article_path = item / 'article_judged.json' + if article_path.exists(): + topic_dirs.append((str(item), item.name)) + + return topic_dirs + + def _read_data_file(self, file_path: str) -> Optional[Dict[str, Any]]: + """读取数据文件(简化版)""" + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + try: + return json.loads(content) + except json.JSONDecodeError: + # 简单的文本内容处理 + return {"content": content} + except Exception as e: + logger.error(f"读取文件失败 {file_path}: {e}") + return None + + def reload_config(self): + """重新加载配置""" + self.config_manager.reload_config() \ No newline at end of file