修复了poster接口提示字构筑的问题
This commit is contained in:
parent
c988ea2911
commit
4f5d8cfbfe
Binary file not shown.
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
from re import T
|
from re import T
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@ -47,9 +48,17 @@ class TemplateInfo(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
handlerPath: str
|
handler_path: str = Field(alias="handlerPath")
|
||||||
className: str
|
class_name: str = Field(alias="className")
|
||||||
isActive: bool
|
system_prompt: Optional[str] = None
|
||||||
|
user_prompt_template: Optional[str] = None
|
||||||
|
required_fields: Optional[List[str]] = None
|
||||||
|
optional_fields: Optional[List[str]] = None
|
||||||
|
size: Optional[List[int]] = None
|
||||||
|
is_active: bool = Field(alias="isActive")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
class TemplateListResponse(BaseModel):
|
class TemplateListResponse(BaseModel):
|
||||||
"""模板列表响应模型"""
|
"""模板列表响应模型"""
|
||||||
@ -71,6 +80,11 @@ class PosterGenerateResponse(BaseModel):
|
|||||||
resultImagesBase64: List[Dict[str, Any]] = Field(description="生成的海报图像(base64编码)列表")
|
resultImagesBase64: List[Dict[str, Any]] = Field(description="生成的海报图像(base64编码)列表")
|
||||||
psdFiles: Optional[List[Dict[str, Any]]] = Field(None, description="生成的PSD文件信息列表")
|
psdFiles: Optional[List[Dict[str, Any]]] = Field(None, description="生成的PSD文件信息列表")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
class ImageUsageRequest(BaseModel):
|
class ImageUsageRequest(BaseModel):
|
||||||
"""图像使用查询请求模型"""
|
"""图像使用查询请求模型"""
|
||||||
|
|||||||
Binary file not shown.
@ -37,18 +37,52 @@ def get_poster_service(
|
|||||||
return PosterService(ai_agent, config_manager, output_manager)
|
return PosterService(ai_agent, config_manager, output_manager)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/templates", response_model=TemplateListResponse, summary="获取可用模板列表")
|
@router.get("/test/templates", summary="测试模板配置")
|
||||||
|
async def test_templates(
|
||||||
|
poster_service: PosterService = Depends(get_poster_service)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试模板配置是否正确加载
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取所有模板信息
|
||||||
|
templates = poster_service._templates
|
||||||
|
|
||||||
|
# 检查每个模板的配置
|
||||||
|
result = {}
|
||||||
|
for template_id, template_info in templates.items():
|
||||||
|
result[template_id] = {
|
||||||
|
"basic_info": template_info,
|
||||||
|
"has_system_prompt": bool(template_info.get('system_prompt')),
|
||||||
|
"has_user_prompt_template": bool(template_info.get('user_prompt_template')),
|
||||||
|
"prompt_lengths": {
|
||||||
|
"system_prompt": len(template_info.get('system_prompt', '')),
|
||||||
|
"user_prompt_template": len(template_info.get('user_prompt_template', ''))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "模板配置检查完成",
|
||||||
|
"template_count": len(templates),
|
||||||
|
"templates": result
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"测试模板配置失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/templates", response_model=TemplateListResponse, summary="获取海报模板列表")
|
||||||
async def get_templates(
|
async def get_templates(
|
||||||
poster_service: PosterService = Depends(get_poster_service)
|
poster_service: PosterService = Depends(get_poster_service)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取可用的海报模板列表
|
获取所有可用的海报模板列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
templates = await poster_service.get_available_templates()
|
templates = poster_service.get_available_templates()
|
||||||
return TemplateListResponse(
|
return TemplateListResponse(
|
||||||
templates=templates,
|
templates=templates,
|
||||||
total_count=len(templates)
|
totalCount=len(templates)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取模板列表失败: {e}", exc_info=True)
|
logger.error(f"获取模板列表失败: {e}", exc_info=True)
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@ -821,7 +821,7 @@ class DatabaseService:
|
|||||||
with self.db_pool.get_connection() as conn:
|
with self.db_pool.get_connection() as conn:
|
||||||
with conn.cursor(dictionary=True) as cursor:
|
with conn.cursor(dictionary=True) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM posterTemplates ORDER BY createdAt"
|
"SELECT * FROM poster_templates ORDER BY created_at"
|
||||||
)
|
)
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
logger.info(f"获取海报模板列表: 找到{len(results)}个模板")
|
logger.info(f"获取海报模板列表: 找到{len(results)}个模板")
|
||||||
@ -849,7 +849,7 @@ class DatabaseService:
|
|||||||
with self.db_pool.get_connection() as conn:
|
with self.db_pool.get_connection() as conn:
|
||||||
with conn.cursor(dictionary=True) as cursor:
|
with conn.cursor(dictionary=True) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM posterTemplates WHERE id = %s",
|
"SELECT * FROM poster_templates WHERE id = %s",
|
||||||
(template_id,)
|
(template_id,)
|
||||||
)
|
)
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
@ -878,7 +878,7 @@ class DatabaseService:
|
|||||||
with self.db_pool.get_connection() as conn:
|
with self.db_pool.get_connection() as conn:
|
||||||
with conn.cursor(dictionary=True) as cursor:
|
with conn.cursor(dictionary=True) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT * FROM posterTemplates WHERE isActive = 1 ORDER BY createdAt"
|
"SELECT * FROM poster_templates WHERE is_active = 1 ORDER BY created_at"
|
||||||
)
|
)
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
logger.info(f"获取激活模板列表: 找到{len(results)}个模板")
|
logger.info(f"获取激活模板列表: 找到{len(results)}个模板")
|
||||||
|
|||||||
@ -63,18 +63,18 @@ class PosterService:
|
|||||||
'vibrant': {
|
'vibrant': {
|
||||||
'id': 'vibrant',
|
'id': 'vibrant',
|
||||||
'name': '活力风格',
|
'name': '活力风格',
|
||||||
'handlerPath': 'poster.templates.vibrant_template',
|
'handler_path': 'poster.templates.vibrant_template',
|
||||||
'className': 'VibrantTemplate',
|
'class_name': 'VibrantTemplate',
|
||||||
'description': '适合景点、活动等充满活力的场景',
|
'description': '适合景点、活动等充满活力的场景',
|
||||||
'isActive': True
|
'is_active': True
|
||||||
},
|
},
|
||||||
'business': {
|
'business': {
|
||||||
'id': 'business',
|
'id': 'business',
|
||||||
'name': '商务风格',
|
'name': '商务风格',
|
||||||
'handlerPath': 'poster.templates.business_template',
|
'handler_path': 'poster.templates.business_template',
|
||||||
'className': 'BusinessTemplate',
|
'class_name': 'BusinessTemplate',
|
||||||
'description': '适合酒店、房地产等商务场景',
|
'description': '适合酒店、房地产等商务场景',
|
||||||
'isActive': True
|
'is_active': True
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,11 +89,11 @@ class PosterService:
|
|||||||
return self._template_instances[template_id]
|
return self._template_instances[template_id]
|
||||||
|
|
||||||
template_info = self._templates[template_id]
|
template_info = self._templates[template_id]
|
||||||
handler_path = template_info.get('handlerPath')
|
handler_path = template_info.get('handler_path')
|
||||||
class_name = template_info.get('className')
|
class_name = template_info.get('class_name')
|
||||||
|
|
||||||
if not handler_path or not class_name:
|
if not handler_path or not class_name:
|
||||||
logger.error(f"模板 {template_id} 缺少 handlerPath 或 className")
|
logger.error(f"模板 {template_id} 缺少 handler_path 或 class_name")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -132,10 +132,17 @@ class PosterService:
|
|||||||
"""获取所有可用的模板信息"""
|
"""获取所有可用的模板信息"""
|
||||||
result = []
|
result = []
|
||||||
for tid in self._templates:
|
for tid in self._templates:
|
||||||
if self._templates[tid].get('is_active'):
|
template = self._templates[tid]
|
||||||
template_info = self.get_template_info(tid)
|
if template.get('is_active', True): # 默认为激活状态
|
||||||
if template_info:
|
template_info = {
|
||||||
result.append(template_info)
|
"id": template["id"],
|
||||||
|
"name": template["name"],
|
||||||
|
"description": template["description"],
|
||||||
|
"handlerPath": template.get("handler_path", ""),
|
||||||
|
"className": template.get("class_name", ""),
|
||||||
|
"isActive": template.get("is_active", True)
|
||||||
|
}
|
||||||
|
result.append(template_info)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
|
def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||||
@ -316,13 +323,17 @@ class PosterService:
|
|||||||
async def _generate_content_with_llm(self, template_id: str, content_id: Optional[int],
|
async def _generate_content_with_llm(self, template_id: str, content_id: Optional[int],
|
||||||
product_id: Optional[int], scenic_spot_id: Optional[int]) -> Optional[Dict[str, Any]]:
|
product_id: Optional[int], scenic_spot_id: Optional[int]) -> Optional[Dict[str, Any]]:
|
||||||
"""使用LLM生成海报内容"""
|
"""使用LLM生成海报内容"""
|
||||||
# 获取提示词
|
# 获取提示词 - 直接从数据库模板信息中获取
|
||||||
template_info = self._templates.get(template_id, {})
|
template_info = self._templates.get(template_id, {})
|
||||||
system_prompt = template_info.get('systemPrompt', "")
|
system_prompt = template_info.get('system_prompt', "")
|
||||||
user_prompt_template = template_info.get('userPromptTemplate', "")
|
user_prompt_template = template_info.get('user_prompt_template', "")
|
||||||
|
|
||||||
if not system_prompt or not user_prompt_template:
|
if not system_prompt or not user_prompt_template:
|
||||||
logger.error(f"模板 {template_id} 缺少提示词配置")
|
logger.error(f"模板 {template_id} 缺少提示词配置")
|
||||||
|
logger.debug(f"模板信息: {template_info}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
logger.info(f"成功加载模板 {template_id} 的提示词配置")
|
||||||
|
|
||||||
# 获取相关数据
|
# 获取相关数据
|
||||||
data = {}
|
data = {}
|
||||||
@ -332,23 +343,91 @@ class PosterService:
|
|||||||
data['product'] = self.db_service.get_product_by_id(product_id)
|
data['product'] = self.db_service.get_product_by_id(product_id)
|
||||||
if scenic_spot_id:
|
if scenic_spot_id:
|
||||||
data['scenic_spot'] = self.db_service.get_scenic_spot_by_id(scenic_spot_id)
|
data['scenic_spot'] = self.db_service.get_scenic_spot_by_id(scenic_spot_id)
|
||||||
logger.info(f"data: {data}")
|
|
||||||
|
logger.info(f"获取到的数据: content={data.get('content') is not None}, product={data.get('product') is not None}, scenic_spot={data.get('scenic_spot') is not None}")
|
||||||
|
|
||||||
# 格式化提示词
|
# 格式化数据为简洁的文本格式,参考其他模块的做法
|
||||||
try:
|
try:
|
||||||
user_prompt = user_prompt_template.format(**data)
|
logger.info("开始格式化数据...")
|
||||||
logger.info(f"user_prompt: {user_prompt}")
|
|
||||||
except KeyError as e:
|
# 景区信息格式化
|
||||||
logger.warning(f"格式化提示词时缺少键: {e}")
|
scenic_info = "无相关景区信息"
|
||||||
user_prompt = user_prompt_template + f"\n可用数据: {json.dumps(data, ensure_ascii=False)}"
|
if data.get('scenic_spot'):
|
||||||
|
logger.info("正在格式化景区信息...")
|
||||||
|
spot = data['scenic_spot']
|
||||||
|
scenic_info = f"""景区名称: {spot.get('name', '')}
|
||||||
|
地址: {spot.get('address', '')}
|
||||||
|
描述: {spot.get('description', '')}
|
||||||
|
优势: {spot.get('advantage', '')}
|
||||||
|
亮点: {spot.get('highlight', '')}
|
||||||
|
交通信息: {spot.get('trafficInfo', '')}"""
|
||||||
|
logger.info("景区信息格式化完成")
|
||||||
|
|
||||||
|
# 产品信息格式化
|
||||||
|
product_info = "无相关产品信息"
|
||||||
|
if data.get('product'):
|
||||||
|
logger.info("正在格式化产品信息...")
|
||||||
|
product = data['product']
|
||||||
|
product_info = f"""产品名称: {product.get('productName', '')}
|
||||||
|
原价: {product.get('originPrice', '')}
|
||||||
|
实际价格: {product.get('realPrice', '')}
|
||||||
|
套餐信息: {product.get('packageInfo', '')}
|
||||||
|
核心优势: {product.get('keyAdvantages', '')}
|
||||||
|
亮点: {product.get('highlights', '')}
|
||||||
|
详细描述: {product.get('detailedDescription', '')}"""
|
||||||
|
logger.info("产品信息格式化完成")
|
||||||
|
|
||||||
|
# 内容信息格式化
|
||||||
|
tweet_info = "无相关内容信息"
|
||||||
|
if data.get('content'):
|
||||||
|
logger.info("正在格式化内容信息...")
|
||||||
|
content = data['content']
|
||||||
|
tweet_info = f"""标题: {content.get('title', '')}
|
||||||
|
内容: {content.get('content', '')}"""
|
||||||
|
logger.info("内容信息格式化完成")
|
||||||
|
|
||||||
|
logger.info("开始构建用户提示词...")
|
||||||
|
# 构建用户提示词
|
||||||
|
user_prompt = user_prompt_template.format(
|
||||||
|
scenic_info=scenic_info,
|
||||||
|
product_info=product_info,
|
||||||
|
tweet_info=tweet_info
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"用户提示词构建成功,长度: {len(user_prompt)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"格式化提示词时发生错误: {e}", exc_info=True)
|
||||||
|
# 提供兜底方案
|
||||||
|
user_prompt = f"""{user_prompt_template}
|
||||||
|
|
||||||
|
当前可用数据:
|
||||||
|
- 景区信息: {'有' if data.get('scenic_spot') else '无'}
|
||||||
|
- 产品信息: {'有' if data.get('product') else '无'}
|
||||||
|
- 内容信息: {'有' if data.get('content') else '无'}
|
||||||
|
|
||||||
|
请根据可用信息生成海报内容。"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _, _, _ = await self.ai_agent.generate_text(system_prompt=system_prompt, user_prompt=user_prompt,use_stream=True)
|
response, _, _, _ = await self.ai_agent.generate_text(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
use_stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取JSON响应
|
||||||
json_start = response.find('{')
|
json_start = response.find('{')
|
||||||
json_end = response.rfind('}') + 1
|
json_end = response.rfind('}') + 1
|
||||||
if json_start != -1 and json_end != -1:
|
if json_start != -1 and json_end != -1:
|
||||||
return json.loads(response[json_start:json_end])
|
result = json.loads(response[json_start:json_end])
|
||||||
logger.error(f"LLM响应中未找到JSON: {response}")
|
logger.info(f"LLM生成内容成功: {list(result.keys())}")
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
logger.error(f"LLM响应中未找到JSON格式内容: {response[:200]}...")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"解析LLM响应JSON失败: {e}")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成内容时发生错误: {e}", exc_info=True)
|
logger.error(f"生成内容时发生错误: {e}", exc_info=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user