diff --git a/api/models/__pycache__/poster.cpython-312.pyc b/api/models/__pycache__/poster.cpython-312.pyc index 19b994a..95239f1 100644 Binary files a/api/models/__pycache__/poster.cpython-312.pyc and b/api/models/__pycache__/poster.cpython-312.pyc differ diff --git a/api/models/poster.py b/api/models/poster.py index f8c48c9..cc19ceb 100644 --- a/api/models/poster.py +++ b/api/models/poster.py @@ -7,6 +7,7 @@ from re import T from typing import List, Dict, Any, Optional +from datetime import datetime from pydantic import BaseModel, Field @@ -20,6 +21,8 @@ class PosterGenerateRequest(BaseModel): scenicSpotId: Optional[int] = Field(None, description="景区ID,用于AI生成内容") numVariations: int = Field(3, description="要生成的海报变体数量, 默认为3", ge=1, le=5) forceLlmGeneration: bool = Field(False, description="是否强制使用LLM重新生成内容") + generatePsd: bool = Field(False, description="是否生成PSD分层文件") + psdOutputPath: Optional[str] = Field(None, description="PSD文件输出路径(可选,默认自动生成)") class Config: json_schema_extra = { @@ -28,6 +31,8 @@ class PosterGenerateRequest(BaseModel): "imagesBase64": "", "numVariations": 1, "forceLlmGeneration":False, + "generatePsd": True, + "psdOutputPath": "custom_poster.psd", "contentId":1, "productId":1, "scenicSpotId":1, @@ -43,9 +48,17 @@ class TemplateInfo(BaseModel): id: str name: str description: str - handlerPath: str - className: str - isActive: bool + handler_path: str = Field(alias="handlerPath") + class_name: str = Field(alias="className") + system_prompt: Optional[str] = None + user_prompt_template: Optional[str] = None + required_fields: Optional[List[str]] = None + optional_fields: Optional[List[str]] = None + size: Optional[List[int]] = None + is_active: bool = Field(alias="isActive") + + class Config: + allow_population_by_field_name = True class TemplateListResponse(BaseModel): """模板列表响应模型""" @@ -65,7 +78,13 @@ class PosterGenerateResponse(BaseModel): requestId: str templateId: str resultImagesBase64: List[Dict[str, Any]] = Field(description="生成的海报图像(base64编码)列表") + psdFiles: Optional[List[Dict[str, Any]]] = Field(None, description="生成的PSD文件信息列表") metadata: Dict[str, Any] = Field(default_factory=dict) + + class Config: + json_encoders = { + datetime: lambda v: v.isoformat() + } class ImageUsageRequest(BaseModel): """图像使用查询请求模型""" diff --git a/api/routers/__pycache__/poster.cpython-312.pyc b/api/routers/__pycache__/poster.cpython-312.pyc index fa14f45..348b934 100644 Binary files a/api/routers/__pycache__/poster.cpython-312.pyc and b/api/routers/__pycache__/poster.cpython-312.pyc differ diff --git a/api/routers/poster.py b/api/routers/poster.py index 6e650c3..0a67291 100644 --- a/api/routers/poster.py +++ b/api/routers/poster.py @@ -37,18 +37,52 @@ def get_poster_service( return PosterService(ai_agent, config_manager, output_manager) -@router.get("/templates", response_model=TemplateListResponse, summary="获取可用模板列表") +@router.get("/test/templates", summary="测试模板配置") +async def test_templates( + poster_service: PosterService = Depends(get_poster_service) +): + """ + 测试模板配置是否正确加载 + """ + try: + # 获取所有模板信息 + templates = poster_service._templates + + # 检查每个模板的配置 + result = {} + for template_id, template_info in templates.items(): + result[template_id] = { + "basic_info": template_info, + "has_system_prompt": bool(template_info.get('system_prompt')), + "has_user_prompt_template": bool(template_info.get('user_prompt_template')), + "prompt_lengths": { + "system_prompt": len(template_info.get('system_prompt', '')), + "user_prompt_template": len(template_info.get('user_prompt_template', '')) + } + } + + return { + "message": "模板配置检查完成", + "template_count": len(templates), + "templates": result + } + except Exception as e: + logger.error(f"测试模板配置失败: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}") + + +@router.get("/templates", response_model=TemplateListResponse, summary="获取海报模板列表") async def get_templates( poster_service: PosterService = Depends(get_poster_service) ): """ - 获取可用的海报模板列表 + 获取所有可用的海报模板列表 """ try: - templates = await poster_service.get_available_templates() + templates = poster_service.get_available_templates() return TemplateListResponse( templates=templates, - total_count=len(templates) + totalCount=len(templates) ) except Exception as e: logger.error(f"获取模板列表失败: {e}", exc_info=True) @@ -87,11 +121,12 @@ async def generate_poster( - **content_id**: 内容ID(可选) - **product_id**: 产品ID(可选) - **scenic_spot_id**: 景区ID(可选) - - **image_ids**: 图像ID列表(可选) + - **images_base64**: 图像base64编码(可选) - **template_id**: 模板ID(默认为vibrant) - - **generate_collage**: 是否生成拼图 - **poster_content**: 用户提供的海报内容(可选) - **force_llm_generation**: 是否强制使用LLM生成内容(可选) + - **generate_psd**: 是否生成PSD分层文件(可选) + - **psd_output_path**: PSD文件输出路径(可选) """ try: result = await poster_service.generate_poster( @@ -102,7 +137,9 @@ async def generate_poster( scenic_spot_id=request.scenicSpotId, images_base64=request.imagesBase64, num_variations=request.numVariations, - force_llm_generation=request.forceLlmGeneration + force_llm_generation=request.forceLlmGeneration, + generate_psd=request.generatePsd, + psd_output_path=request.psdOutputPath ) return PosterGenerateResponse(**result) diff --git a/api/services/__pycache__/database_service.cpython-312.pyc b/api/services/__pycache__/database_service.cpython-312.pyc index 27bc0d2..fa799f5 100644 Binary files a/api/services/__pycache__/database_service.cpython-312.pyc and b/api/services/__pycache__/database_service.cpython-312.pyc differ diff --git a/api/services/__pycache__/poster.cpython-312.pyc b/api/services/__pycache__/poster.cpython-312.pyc index 94ca12f..4171235 100644 Binary files a/api/services/__pycache__/poster.cpython-312.pyc and b/api/services/__pycache__/poster.cpython-312.pyc differ diff --git a/api/services/database_service.py b/api/services/database_service.py index 78fbd47..8d3a855 100644 --- a/api/services/database_service.py +++ b/api/services/database_service.py @@ -821,7 +821,7 @@ class DatabaseService: with self.db_pool.get_connection() as conn: with conn.cursor(dictionary=True) as cursor: cursor.execute( - "SELECT * FROM posterTemplates ORDER BY createdAt" + "SELECT * FROM poster_templates ORDER BY created_at" ) results = cursor.fetchall() logger.info(f"获取海报模板列表: 找到{len(results)}个模板") @@ -849,7 +849,7 @@ class DatabaseService: with self.db_pool.get_connection() as conn: with conn.cursor(dictionary=True) as cursor: cursor.execute( - "SELECT * FROM posterTemplates WHERE id = %s", + "SELECT * FROM poster_templates WHERE id = %s", (template_id,) ) result = cursor.fetchone() @@ -878,7 +878,7 @@ class DatabaseService: with self.db_pool.get_connection() as conn: with conn.cursor(dictionary=True) as cursor: cursor.execute( - "SELECT * FROM posterTemplates WHERE isActive = 1 ORDER BY createdAt" + "SELECT * FROM poster_templates WHERE is_active = 1 ORDER BY created_at" ) results = cursor.fetchall() logger.info(f"获取激活模板列表: 找到{len(results)}个模板") diff --git a/api/services/poster.py b/api/services/poster.py index 473a366..d924d49 100644 --- a/api/services/poster.py +++ b/api/services/poster.py @@ -63,18 +63,18 @@ class PosterService: 'vibrant': { 'id': 'vibrant', 'name': '活力风格', - 'handlerPath': 'poster.templates.vibrant_template', - 'className': 'VibrantTemplate', + 'handler_path': 'poster.templates.vibrant_template', + 'class_name': 'VibrantTemplate', 'description': '适合景点、活动等充满活力的场景', - 'isActive': True + 'is_active': True }, 'business': { 'id': 'business', 'name': '商务风格', - 'handlerPath': 'poster.templates.business_template', - 'className': 'BusinessTemplate', + 'handler_path': 'poster.templates.business_template', + 'class_name': 'BusinessTemplate', 'description': '适合酒店、房地产等商务场景', - 'isActive': True + 'is_active': True } } @@ -89,11 +89,11 @@ class PosterService: return self._template_instances[template_id] template_info = self._templates[template_id] - handler_path = template_info.get('handlerPath') - class_name = template_info.get('className') + handler_path = template_info.get('handler_path') + class_name = template_info.get('class_name') if not handler_path or not class_name: - logger.error(f"模板 {template_id} 缺少 handlerPath 或 className") + logger.error(f"模板 {template_id} 缺少 handler_path 或 class_name") return None try: @@ -132,10 +132,17 @@ class PosterService: """获取所有可用的模板信息""" result = [] for tid in self._templates: - if self._templates[tid].get('is_active'): - template_info = self.get_template_info(tid) - if template_info: - result.append(template_info) + template = self._templates[tid] + if template.get('is_active', True): # 默认为激活状态 + template_info = { + "id": template["id"], + "name": template["name"], + "description": template["description"], + "handlerPath": template.get("handler_path", ""), + "className": template.get("class_name", ""), + "isActive": template.get("is_active", True) + } + result.append(template_info) return result def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]: @@ -161,7 +168,9 @@ class PosterService: scenic_spot_id: Optional[int], images_base64: Optional[List[str]] , num_variations: int = 1, - force_llm_generation: bool = False) -> Dict[str, Any]: + force_llm_generation: bool = False, + generate_psd: bool = False, + psd_output_path: Optional[str] = None) -> Dict[str, Any]: """ 统一的海报生成入口 @@ -171,12 +180,14 @@ class PosterService: content_id: 内容ID,用于从数据库获取内容(可选) product_id: 产品ID,用于从数据库获取产品信息(可选) scenic_spot_id: 景点ID,用于从数据库获取景点信息(可选) - image_ids: 图片ID列表,用于从数据库获取图片(可选) + images_base64: 图片base64编码,用于生成海报(可选) num_variations: 需要生成的变体数量 force_llm_generation: 是否强制使用LLM生成内容 + generate_psd: 是否生成PSD分层文件 + psd_output_path: PSD文件输出路径(可选,默认自动生成) Returns: - 生成结果字典 + 生成结果字典,包含PNG图像和可选的PSD文件 """ start_time = time.time() @@ -230,6 +241,7 @@ class PosterService: # 5. 保存海报并返回结果 variations = [] + psd_files = [] i=0 ## 用于多个海报时,指定海报的编号,此时只有一个没有用上,但是接口开放着。 output_path = self._save_poster(posters, template_id, i) if output_path: @@ -238,6 +250,15 @@ class PosterService: "poster_path": str(output_path), "base64": self._image_to_base64(posters) }) + + # 6. 如果需要,生成PSD分层文件 + if generate_psd: + psd_result = self._generate_psd_file( + template_handler, images, final_content, + template_id, i, psd_output_path + ) + if psd_result: + psd_files.append(psd_result) # 记录模板使用情况 self._update_template_stats(template_id, bool(variations), time.time() - start_time) @@ -246,10 +267,12 @@ class PosterService: "requestId": f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}", "templateId": template_id, "resultImagesBase64": variations, + "psdFiles": psd_files if psd_files else None, "metadata": { "generation_time": f"{time.time() - start_time:.2f}s", "model_used": self.ai_agent.config.model if force_llm_generation or not poster_content else None, - "num_variations": len(variations) + "num_variations": len(variations), + "psd_generated": bool(psd_files) } } except Exception as e: @@ -300,13 +323,17 @@ class PosterService: async def _generate_content_with_llm(self, template_id: str, content_id: Optional[int], product_id: Optional[int], scenic_spot_id: Optional[int]) -> Optional[Dict[str, Any]]: """使用LLM生成海报内容""" - # 获取提示词 + # 获取提示词 - 直接从数据库模板信息中获取 template_info = self._templates.get(template_id, {}) - system_prompt = template_info.get('systemPrompt', "") - user_prompt_template = template_info.get('userPromptTemplate', "") + system_prompt = template_info.get('system_prompt', "") + user_prompt_template = template_info.get('user_prompt_template', "") + if not system_prompt or not user_prompt_template: logger.error(f"模板 {template_id} 缺少提示词配置") + logger.debug(f"模板信息: {template_info}") return None + + logger.info(f"成功加载模板 {template_id} 的提示词配置") # 获取相关数据 data = {} @@ -316,24 +343,192 @@ class PosterService: data['product'] = self.db_service.get_product_by_id(product_id) if scenic_spot_id: data['scenic_spot'] = self.db_service.get_scenic_spot_by_id(scenic_spot_id) - logger.info(f"data: {data}") + + logger.info(f"获取到的数据: content={data.get('content') is not None}, product={data.get('product') is not None}, scenic_spot={data.get('scenic_spot') is not None}") - # 格式化提示词 + # 格式化数据为简洁的文本格式,参考其他模块的做法 try: - user_prompt = user_prompt_template.format(**data) - logger.info(f"user_prompt: {user_prompt}") - except KeyError as e: - logger.warning(f"格式化提示词时缺少键: {e}") - user_prompt = user_prompt_template + f"\n可用数据: {json.dumps(data, ensure_ascii=False)}" + logger.info("开始格式化数据...") + + # 景区信息格式化 + scenic_info = "无相关景区信息" + if data.get('scenic_spot'): + logger.info("正在格式化景区信息...") + spot = data['scenic_spot'] + scenic_info = f"""景区名称: {spot.get('name', '')} +地址: {spot.get('address', '')} +描述: {spot.get('description', '')} +优势: {spot.get('advantage', '')} +亮点: {spot.get('highlight', '')} +交通信息: {spot.get('trafficInfo', '')}""" + logger.info("景区信息格式化完成") + + # 产品信息格式化 + product_info = "无相关产品信息" + if data.get('product'): + logger.info("正在格式化产品信息...") + product = data['product'] + product_info = f"""产品名称: {product.get('productName', '')} +原价: {product.get('originPrice', '')} +实际价格: {product.get('realPrice', '')} +套餐信息: {product.get('packageInfo', '')} +核心优势: {product.get('keyAdvantages', '')} +亮点: {product.get('highlights', '')} +详细描述: {product.get('detailedDescription', '')}""" + logger.info("产品信息格式化完成") + + # 内容信息格式化 + tweet_info = "无相关内容信息" + if data.get('content'): + logger.info("正在格式化内容信息...") + content = data['content'] + tweet_info = f"""标题: {content.get('title', '')} +内容: {content.get('content', '')}""" + logger.info("内容信息格式化完成") + + logger.info("开始构建用户提示词...") + # 构建用户提示词 + user_prompt = user_prompt_template.format( + scenic_info=scenic_info, + product_info=product_info, + tweet_info=tweet_info + ) + + logger.info(f"用户提示词构建成功,长度: {len(user_prompt)}") + + except Exception as e: + logger.error(f"格式化提示词时发生错误: {e}", exc_info=True) + # 提供兜底方案 + user_prompt = f"""{user_prompt_template} + +当前可用数据: +- 景区信息: {'有' if data.get('scenic_spot') else '无'} +- 产品信息: {'有' if data.get('product') else '无'} +- 内容信息: {'有' if data.get('content') else '无'} + +请根据可用信息生成海报内容。""" try: - response, _, _, _ = await self.ai_agent.generate_text(system_prompt=system_prompt, user_prompt=user_prompt,use_stream=True) + response, _, _, _ = await self.ai_agent.generate_text( + system_prompt=system_prompt, + user_prompt=user_prompt, + use_stream=True + ) + + # 提取JSON响应 json_start = response.find('{') json_end = response.rfind('}') + 1 if json_start != -1 and json_end != -1: - return json.loads(response[json_start:json_end]) - logger.error(f"LLM响应中未找到JSON: {response}") + result = json.loads(response[json_start:json_end]) + logger.info(f"LLM生成内容成功: {list(result.keys())}") + return result + else: + logger.error(f"LLM响应中未找到JSON格式内容: {response[:200]}...") + return None + + except json.JSONDecodeError as e: + logger.error(f"解析LLM响应JSON失败: {e}") return None except Exception as e: logger.error(f"生成内容时发生错误: {e}", exc_info=True) + return None + + def _generate_psd_file(self, template_handler: BaseTemplate, images: Image.Image, + content: Dict[str, Any], template_id: str, + variation_id: int, custom_output_path: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + 生成PSD分层文件 + + Args: + template_handler: 模板处理器实例 + images: 图像数据 + content: 海报内容 + template_id: 模板ID + variation_id: 变体ID + custom_output_path: 自定义输出路径 + + Returns: + PSD文件信息字典,包含文件路径、base64编码等 + """ + try: + # 检查模板是否支持PSD生成 + if not hasattr(template_handler, 'generate_layered_psd'): + logger.warning(f"模板 {template_id} 不支持PSD分层输出") + return None + + # 生成PSD文件路径 + if custom_output_path: + psd_filename = custom_output_path + if not psd_filename.endswith('.psd'): + psd_filename += '.psd' + else: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + psd_filename = f"{template_id}_layered_v{variation_id}_{timestamp}.psd" + + # 获取输出目录 + topic_id = f"poster_{template_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + output_dir = self.output_manager.get_topic_dir(topic_id) + psd_path = output_dir / psd_filename + + # 调用模板的PSD生成方法 + logger.info(f"开始生成PSD分层文件: {psd_path}") + generated_psd_path = template_handler.generate_layered_psd( + images=images, + content=content, + output_path=str(psd_path) + ) + + if not generated_psd_path or not Path(generated_psd_path).exists(): + logger.error("PSD文件生成失败或文件不存在") + return None + + # 获取文件信息 + file_size = Path(generated_psd_path).stat().st_size + + # 可选:生成PSD的base64编码(注意:PSD文件通常较大) + psd_base64 = None + if file_size < 10 * 1024 * 1024: # 如果文件小于10MB,才生成base64 + try: + with open(generated_psd_path, 'rb') as f: + psd_base64 = base64.b64encode(f.read()).decode('utf-8') + except Exception as e: + logger.warning(f"生成PSD文件base64编码失败: {e}") + + # 生成预览图(从PSD合成PNG预览) + preview_base64 = None + try: + from psd_tools import PSDImage + psd = PSDImage.open(generated_psd_path) + preview_image = psd.composite() + if preview_image: + preview_base64 = self._image_to_base64(preview_image) + logger.info("PSD预览图生成成功") + except Exception as e: + logger.warning(f"生成PSD预览图失败: {e}") + + logger.info(f"PSD文件生成成功: {generated_psd_path} ({file_size/1024:.1f}KB)") + + return { + "variation_id": variation_id, + "psd_path": str(generated_psd_path), + "file_name": psd_filename, + "file_size": file_size, + "base64": psd_base64, + "preview_base64": preview_base64, + "layer_count": self._get_psd_layer_count(generated_psd_path), + "generation_time": datetime.now().isoformat() + } + + except Exception as e: + logger.error(f"生成PSD文件时发生错误: {e}", exc_info=True) + return None + + def _get_psd_layer_count(self, psd_path: str) -> Optional[int]: + """获取PSD文件的图层数量""" + try: + from psd_tools import PSDImage + psd = PSDImage.open(psd_path) + return len(list(psd)) + except Exception as e: + logger.warning(f"获取PSD图层数量失败: {e}") return None \ No newline at end of file diff --git a/poster/templates/__pycache__/vibrant_template.cpython-312.pyc b/poster/templates/__pycache__/vibrant_template.cpython-312.pyc index 1e70c5d..e945895 100644 Binary files a/poster/templates/__pycache__/vibrant_template.cpython-312.pyc and b/poster/templates/__pycache__/vibrant_template.cpython-312.pyc differ diff --git a/poster/templates/vibrant_template.py b/poster/templates/vibrant_template.py index 2c41243..0aec2ef 100644 --- a/poster/templates/vibrant_template.py +++ b/poster/templates/vibrant_template.py @@ -464,7 +464,7 @@ class VibrantTemplate(BaseTemplate): def _render_right_column(self, draw: ImageDraw.Draw, content: Dict[str, Any], y: int, x: int, right_margin: int): """渲染右栏内容:价格、票种和备注(增强版本,与demo一致)""" - price_text = content.get('price', '') + price_text = str(content.get('price', '')) # 确保价格是字符串类型 price_target_width = int((right_margin - x) * 0.7) price_size, price_actual_width = self._calculate_optimal_font_size_enhanced( price_text, price_target_width, max_size=120, min_size=40 @@ -1020,7 +1020,7 @@ class VibrantTemplate(BaseTemplate): # 使用中文字体 # 价格 - price_text = content.get('price', '') + price_text = str(content.get('price', '')) # 确保价格是字符串类型 if price_text: price_target_width = int((right_margin - x) * 0.7) price_size, price_actual_width = self._calculate_optimal_font_size_enhanced(