优化了poster模块的接口

This commit is contained in:
jinye_huang 2025-07-17 16:15:02 +08:00
parent fe5fe5e5e2
commit b8140ed820
13 changed files with 649 additions and 510 deletions

Binary file not shown.

View File

@ -56,12 +56,11 @@ app.add_middleware(
) )
# 导入路由 # 导入路由
from api.routers import tweet, poster, poster_unified, prompt, document, data, integration, content_integration from api.routers import tweet, poster, prompt, document, data, integration, content_integration
# 包含路由 # 包含路由
app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"]) app.include_router(tweet.router, prefix="/api/v1/tweet", tags=["tweet"])
app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"]) app.include_router(poster.router, prefix="/api/v1/poster", tags=["poster"])
app.include_router(poster_unified.router, prefix="/api/v2/poster", tags=["poster-unified"])
app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"]) app.include_router(prompt.router, prefix="/api/v1/prompt", tags=["prompt"])
app.include_router(document.router, prefix="/api/v1/document", tags=["document"]) app.include_router(document.router, prefix="/api/v1/document", tags=["document"])
app.include_router(data.router, prefix="/api/v1", tags=["data"]) app.include_router(data.router, prefix="/api/v1", tags=["data"])

View File

@ -2,101 +2,122 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
海报API模型定义 海报API模型定义 - 简化版本
只保留核心功能重点优化图片ID使用追踪
""" """
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class PosterRequest(BaseModel): class PosterGenerateRequest(BaseModel):
"""海报生成请求模型""" """海报生成请求模型"""
content: Dict[str, Any] = Field(..., description="内容数据,包含标题、正文等") content_id: Optional[int] = Field(None, description="内容ID")
topic_index: str = Field(..., description="主题索引,用于文件命名") product_id: Optional[int] = Field(None, description="产品ID")
template_name: Optional[str] = Field(None, description="模板名称如果为None则根据配置选择") scenic_spot_id: Optional[int] = Field(None, description="景区ID")
image_ids: Optional[List[int]] = Field(None, description="图像ID列表")
generate_collage: bool = Field(False, description="是否生成拼图")
class Config: class Config:
schema_extra = { schema_extra = {
"example": { "example": {
"content": { "content_id": 1,
"title": "【北京故宫】避开人潮的秘密路线90%的人都不知道!", "product_id": 2,
"content": "故宫,作为中国最著名的文化遗产之一...", "scenic_spot_id": 3,
"tag": ["北京旅游", "故宫", "旅游攻略", "避暑胜地"] "image_ids": [1, 2, 3],
}, "generate_collage": True
"topic_index": "1",
"template_name": "vibrant"
} }
} }
class PosterResponse(BaseModel): class ImageUsageInfo(BaseModel):
"""图像使用信息模型 - 重点追踪图片使用情况"""
image_id: int = Field(..., description="图像ID")
usage_count: int = Field(..., description="使用次数")
first_used_at: str = Field(..., description="首次使用时间")
last_used_at: str = Field(..., description="最后使用时间")
usage_context: List[str] = Field(default_factory=list, description="使用场景列表")
class PosterGenerateResponse(BaseModel):
"""海报生成响应模型""" """海报生成响应模型"""
request_id: str = Field(..., description="请求ID") request_id: str = Field(..., description="请求ID")
topic_index: str = Field(..., description="主题索引") poster_base64: str = Field(..., description="海报图像的base64编码")
poster_path: str = Field(..., description="生成的海报文件路径") content_info: Optional[Dict[str, Any]] = Field(None, description="内容信息")
template_name: str = Field(..., description="使用的模板名称") product_info: Optional[Dict[str, Any]] = Field(None, description="产品信息")
scenic_spot_info: Optional[Dict[str, Any]] = Field(None, description="景区信息")
used_image_ids: List[int] = Field(default_factory=list, description="使用的图像ID列表")
image_usage_info: List[ImageUsageInfo] = Field(default_factory=list, description="图像使用详情")
collage_base64: Optional[str] = Field(None, description="拼图的base64编码")
metadata: Dict[str, Any] = Field(default_factory=dict, description="处理元数据")
class Config: class Config:
schema_extra = { schema_extra = {
"example": { "example": {
"request_id": "poster-20240715-123456-a1b2c3d4", "request_id": "poster-20240715-123456-a1b2c3d4",
"topic_index": "1", "poster_base64": "...",
"poster_path": "/result/run_20230715_123456/topic_1/poster_vibrant.png", "content_info": {
"template_name": "vibrant" "id": 1,
} "title": "【北京故宫】避开人潮的秘密路线",
} "content": "故宫,作为中国最著名的文化遗产之一...",
"tag": "北京旅游,故宫,旅游攻略"
class TemplateListResponse(BaseModel):
"""模板列表响应模型"""
templates: List[str] = Field(..., description="可用的模板列表")
default_template: str = Field(..., description="默认模板")
class Config:
schema_extra = {
"example": {
"templates": ["vibrant", "business", "collage"],
"default_template": "vibrant"
}
}
class PosterTextRequest(BaseModel):
"""海报文案生成请求模型"""
system_prompt: str = Field(..., description="系统提示词")
user_prompt: str = Field(..., description="用户提示词")
context_data: Optional[Dict[str, Any]] = Field(None, description="上下文数据,用于填充提示词中的占位符")
temperature: Optional[float] = Field(0.3, description="生成温度参数")
top_p: Optional[float] = Field(0.4, description="top_p参数")
class Config:
schema_extra = {
"example": {
"system_prompt": "你是一位专业的旅游海报文案撰写专家...",
"user_prompt": "请为{location}{attraction}创作一段简短有力的海报文案...",
"context_data": {
"location": "北京",
"attraction": "故宫"
}, },
"temperature": 0.3, "used_image_ids": [1, 2, 3],
"top_p": 0.4 "image_usage_info": [
{
"image_id": 1,
"usage_count": 5,
"first_used_at": "2024-07-15 10:30:00",
"last_used_at": "2024-07-15 10:30:00",
"usage_context": ["poster_generation", "collage_creation"]
}
],
"collage_base64": "...",
"metadata": {
"total_images_used": 3,
"has_collage": True,
"processing_time": "2.5s"
}
} }
} }
class PosterTextResponse(BaseModel): class ImageUsageRequest(BaseModel):
"""海报文案生成响应模型""" """图像使用查询请求模型"""
request_id: str = Field(..., description="请求ID") image_ids: List[int] = Field(..., description="要查询的图像ID列表")
text_content: Dict[str, Any] = Field(..., description="生成的文案内容")
class Config: class Config:
schema_extra = { schema_extra = {
"example": { "example": {
"request_id": "text-20240715-123456-a1b2c3d4", "image_ids": [1, 2, 3, 4, 5]
"text_content": { }
"title": "紫禁城的秘密", }
"subtitle": "600年历史等你探索",
"description": "穿越时光,触摸历史,感受帝王的荣耀与辉煌"
class ImageUsageResponse(BaseModel):
"""图像使用情况响应模型"""
request_id: str = Field(..., description="请求ID")
image_usage_info: List[ImageUsageInfo] = Field(..., description="图像使用详情")
summary: Dict[str, Any] = Field(..., description="使用情况汇总")
class Config:
schema_extra = {
"example": {
"request_id": "usage-20240715-123456-a1b2c3d4",
"image_usage_info": [
{
"image_id": 1,
"usage_count": 5,
"first_used_at": "2024-07-15 10:30:00",
"last_used_at": "2024-07-15 10:30:00",
"usage_context": ["poster_generation", "collage_creation"]
}
],
"summary": {
"total_images": 5,
"total_usage_count": 15,
"most_used_image_id": 1,
"least_used_image_id": 5
} }
} }
} }

View File

@ -2,7 +2,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
海报API路由 海报API路由 - 简化版本
只保留核心功能重点优化图片使用追踪
""" """
import logging import logging
@ -14,9 +15,8 @@ from core.ai import AIAgent
from utils.file_io import OutputManager from utils.file_io import OutputManager
from api.services.poster import PosterService from api.services.poster import PosterService
from api.models.poster import ( from api.models.poster import (
PosterRequest, PosterResponse, PosterGenerateRequest, PosterGenerateResponse,
TemplateListResponse, ImageUsageRequest, ImageUsageResponse
PosterTextRequest, PosterTextResponse
) )
# 从依赖注入模块导入依赖 # 从依赖注入模块导入依赖
@ -37,80 +37,51 @@ def get_poster_service(
return PosterService(ai_agent, config_manager, output_manager) return PosterService(ai_agent, config_manager, output_manager)
@router.post("/generate", response_model=PosterResponse, summary="生成海报") @router.post("/generate", response_model=PosterGenerateResponse, summary="生成海报")
async def generate_poster( async def generate_poster(
request: PosterRequest, request: PosterGenerateRequest,
poster_service: PosterService = Depends(get_poster_service) poster_service: PosterService = Depends(get_poster_service)
): ):
""" """
生成海报 生成海报
- **content**: 内容数据包含标题正文等 - **content_id**: 内容ID可选
- **topic_index**: 主题索引用于文件命名 - **product_id**: 产品ID可选
- **template_name**: 模板名称如果为None则根据配置选择 - **scenic_spot_id**: 景区ID可选
- **image_ids**: 图像ID列表可选
- **generate_collage**: 是否生成拼图
""" """
try: try:
request_id, topic_index, poster_path, template_name = poster_service.generate_poster( result = poster_service.generate_poster_simplified(
content=request.content, content_id=request.content_id,
topic_index=request.topic_index, product_id=request.product_id,
template_name=request.template_name scenic_spot_id=request.scenic_spot_id,
image_ids=request.image_ids,
generate_collage=request.generate_collage
) )
return PosterResponse( return PosterGenerateResponse(**result)
request_id=request_id, except ValueError as e:
topic_index=topic_index, logger.error(f"参数错误: {e}")
poster_path=poster_path, raise HTTPException(status_code=400, detail=str(e))
template_name=template_name
)
except Exception as e: except Exception as e:
logger.error(f"生成海报失败: {e}", exc_info=True) logger.error(f"生成海报失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}")
@router.get("/templates", response_model=TemplateListResponse, summary="获取可用模板列表") @router.post("/image-usage", response_model=ImageUsageResponse, summary="查询图像使用情况")
async def get_templates( async def get_image_usage(
poster_service: PosterService = Depends(get_poster_service) request: ImageUsageRequest,
):
"""获取可用的海报模板列表"""
try:
templates, default_template = poster_service.get_available_templates()
return TemplateListResponse(
templates=templates,
default_template=default_template
)
except Exception as e:
logger.error(f"获取模板列表失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}")
@router.post("/text", response_model=PosterTextResponse, summary="生成海报文案")
async def generate_poster_text(
request: PosterTextRequest,
poster_service: PosterService = Depends(get_poster_service) poster_service: PosterService = Depends(get_poster_service)
): ):
""" """
生成海报文案 查询图像使用情况
- **system_prompt**: 系统提示词 - **image_ids**: 要查询的图像ID列表
- **user_prompt**: 用户提示词
- **context_data**: 上下文数据用于填充提示词中的占位符
- **temperature**: 生成温度参数
- **top_p**: top_p参数
""" """
try: try:
request_id, text_content = await poster_service.generate_poster_text( result = poster_service.get_image_usage_info(request.image_ids)
system_prompt=request.system_prompt, return ImageUsageResponse(**result)
user_prompt=request.user_prompt,
context_data=request.context_data,
temperature=request.temperature,
top_p=request.top_p
)
return PosterTextResponse(
request_id=request_id,
text_content=text_content
)
except Exception as e: except Exception as e:
logger.error(f"生成海报文案失败: {e}", exc_info=True) logger.error(f"查询图像使用情况失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"生成海报文案失败: {str(e)}") raise HTTPException(status_code=500, detail=f"查询图像使用情况失败: {str(e)}")

View File

@ -1,314 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统一海报API路由
支持多种模板类型的海报生成配置化管理
"""
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Dict, Any, List
from core.ai import AIAgent
from api.services.poster_service import UnifiedPosterService
from api.models.vibrant_poster import (
PosterGenerationRequest, PosterGenerationResponse,
ContentGenerationRequest, ContentGenerationResponse,
TemplateListResponse, TemplateInfo,
BaseAPIResponse
)
# 从依赖注入模块导入依赖
from api.dependencies import get_ai_agent
logger = logging.getLogger(__name__)
# 创建路由
router = APIRouter()
# 依赖注入函数
def get_unified_poster_service(
ai_agent: AIAgent = Depends(get_ai_agent)
) -> UnifiedPosterService:
"""获取统一海报服务"""
return UnifiedPosterService(ai_agent)
def create_response(success: bool, message: str, data: Any = None, request_id: str = None) -> Dict[str, Any]:
"""创建标准响应"""
if request_id is None:
request_id = f"req-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
return {
"success": success,
"message": message,
"request_id": request_id,
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": data
}
@router.get("/templates", response_model=TemplateListResponse, summary="获取所有可用模板")
async def get_templates(
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
获取所有可用的海报模板列表
返回每个模板的详细信息包括
- 模板ID和名称
- 模板描述
- 模板尺寸
- 必填字段和可选字段
"""
try:
templates = service.get_available_templates()
response_data = create_response(
success=True,
message="获取模板列表成功",
data=templates
)
return TemplateListResponse(**response_data)
except Exception as e:
logger.error(f"获取模板列表失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}")
@router.get("/templates/{template_id}", response_model=BaseAPIResponse, summary="获取指定模板信息")
async def get_template_info(
template_id: str,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
获取指定模板的详细信息
参数
- **template_id**: 模板ID
"""
try:
template_info = service.get_template_info(template_id)
if not template_info:
raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
response_data = create_response(
success=True,
message="获取模板信息成功",
data=template_info.dict()
)
return BaseAPIResponse(**response_data)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取模板信息失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(e)}")
@router.post("/content/generate", response_model=ContentGenerationResponse, summary="生成海报内容")
async def generate_content(
request: ContentGenerationRequest,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
根据源数据生成海报内容不生成实际图片
用于
1. 预览生成的内容
2. 调试和测试内容生成
3. 分步骤生成先生成内容再生成图片
参数
- **template_id**: 模板ID
- **source_data**: 源数据用于AI生成内容
- **temperature**: AI生成温度参数
"""
try:
content = await service.generate_content(
template_id=request.template_id,
source_data=request.source_data,
temperature=request.temperature
)
response_data = create_response(
success=True,
message="内容生成成功",
data={
"template_id": request.template_id,
"content": content,
"metadata": {
"generation_method": "ai_generated",
"temperature": request.temperature
}
}
)
return ContentGenerationResponse(**response_data)
except Exception as e:
logger.error(f"生成内容失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
async def generate_poster(
request: PosterGenerationRequest,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
生成海报图片
支持两种模式
1. 直接提供内容content字段
2. 提供源数据让AI生成内容source_data字段
参数
- **template_id**: 模板ID
- **content**: 直接提供的海报内容可选
- **source_data**: 源数据用于AI生成内容可选
- **topic_name**: 主题名称用于文件命名
- **image_path**: 指定图片路径可选
- **image_dir**: 图片目录可选
- **output_dir**: 输出目录可选
- **temperature**: AI生成温度参数
"""
try:
result = await service.generate_poster(
template_id=request.template_id,
content=request.content,
source_data=request.source_data,
topic_name=request.topic_name,
image_path=request.image_path,
image_dir=request.image_dir,
output_dir=request.output_dir,
temperature=request.temperature
)
response_data = create_response(
success=True,
message="海报生成成功",
data=result,
request_id=result["request_id"]
)
return PosterGenerationResponse(**response_data)
except Exception as e:
logger.error(f"生成海报失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}")
@router.post("/batch", response_model=BaseAPIResponse, summary="批量生成海报")
async def batch_generate_posters(
template_id: str,
base_path: str,
image_dir: str = None,
source_files: Dict[str, str] = None,
output_base: str = "result/posters",
parallel_count: int = 3,
temperature: float = 0.7,
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
批量生成海报
自动扫描指定目录下的topic文件夹为每个topic生成海报
参数
- **template_id**: 模板ID
- **base_path**: 包含多个topic目录的基础路径
- **image_dir**: 图片目录可选
- **source_files**: 源文件配置字典可选
- **output_base**: 输出基础目录
- **parallel_count**: 并发处理数量
- **temperature**: AI生成温度参数
"""
try:
result = await service.batch_generate_posters(
template_id=template_id,
base_path=base_path,
image_dir=image_dir,
source_files=source_files or {},
output_base=output_base,
parallel_count=parallel_count,
temperature=temperature
)
response_data = create_response(
success=True,
message=f"批量生成完成,成功: {result['successful_count']}, 失败: {result['failed_count']}",
data=result,
request_id=result["request_id"]
)
return BaseAPIResponse(**response_data)
except Exception as e:
logger.error(f"批量生成海报失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"批量生成海报失败: {str(e)}")
@router.post("/config/reload", response_model=BaseAPIResponse, summary="重新加载配置")
async def reload_config(
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""
重新加载海报配置
用于在不重启服务的情况下更新配置包括
- 提示词模板
- 模板配置
- 默认参数
"""
try:
service.reload_config()
response_data = create_response(
success=True,
message="配置重新加载成功"
)
return BaseAPIResponse(**response_data)
except Exception as e:
logger.error(f"重新加载配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"重新加载配置失败: {str(e)}")
@router.get("/health", summary="健康检查")
async def health_check():
"""服务健康检查"""
return create_response(
success=True,
message="统一海报服务运行正常",
data={
"service": "unified_poster",
"status": "healthy",
"version": "2.0.0"
}
)
@router.get("/config", summary="获取服务配置")
async def get_service_config(
service: UnifiedPosterService = Depends(get_unified_poster_service)
):
"""获取服务配置信息"""
try:
config_info = {
"default_image_dir": service.config_manager.get_default_config("image_dir"),
"default_output_dir": service.config_manager.get_default_config("output_dir"),
"default_font_dir": service.config_manager.get_default_config("font_dir"),
"default_template": service.config_manager.get_default_config("template"),
"supported_image_formats": ["png", "jpg", "jpeg", "webp"],
"available_templates": len(service.get_available_templates())
}
response_data = create_response(
success=True,
message="获取配置成功",
data=config_info
)
return BaseAPIResponse(**response_data)
except Exception as e:
logger.error(f"获取配置失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}")

View File

@ -632,7 +632,7 @@ class DatabaseService:
conn = self.db_pool.get_connection() conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
cursor.execute( cursor.execute(
"SELECT id FROM product WHERE productName = %s AND isDelete = 0", "SELECT id FROM product WHERE name = %s AND isDelete = 0",
(product_name,) (product_name,)
) )
result = cursor.fetchone() result = cursor.fetchone()
@ -647,4 +647,259 @@ class DatabaseService:
except Exception as e: except Exception as e:
logger.error(f"查询产品ID失败: {e}") logger.error(f"查询产品ID失败: {e}")
return None return None
def get_image_by_id(self, image_id: int) -> Optional[Dict[str, Any]]:
"""
根据ID获取图像信息
Args:
image_id: 图像ID
Returns:
图像信息字典如果未找到则返回None
"""
if not self.db_pool:
logger.error("数据库连接池未初始化")
return None
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT * FROM material WHERE id = %s AND materialType = 'image' AND isDelete = 0",
(image_id,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"找到图像信息: ID={image_id}, 名称={result['materialName']}")
return result
else:
logger.warning(f"未找到图像信息: ID={image_id}")
return None
except Exception as e:
logger.error(f"查询图像信息失败: {e}")
return None
def get_images_by_ids(self, image_ids: List[int]) -> List[Dict[str, Any]]:
"""
根据ID列表批量获取图像信息
Args:
image_ids: 图像ID列表
Returns:
图像信息列表
"""
if not self.db_pool or not image_ids:
return []
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建IN查询
placeholders = ','.join(['%s'] * len(image_ids))
query = f"SELECT * FROM material WHERE id IN ({placeholders}) AND materialType = 'image' AND isDelete = 0"
cursor.execute(query, image_ids)
results = cursor.fetchall()
cursor.close()
conn.close()
logger.info(f"批量查询图像信息: 请求{len(image_ids)}个,找到{len(results)}")
return results
except Exception as e:
logger.error(f"批量查询图像信息失败: {e}")
return []
def get_content_by_id(self, content_id: int) -> Optional[Dict[str, Any]]:
"""
根据ID获取内容信息
Args:
content_id: 内容ID
Returns:
内容信息字典如果未找到则返回None
"""
if not self.db_pool:
logger.error("数据库连接池未初始化")
return None
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT * FROM content WHERE id = %s AND isDelete = 0",
(content_id,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"找到内容信息: ID={content_id}, 标题={result['title']}")
return result
else:
logger.warning(f"未找到内容信息: ID={content_id}")
return None
except Exception as e:
logger.error(f"查询内容信息失败: {e}")
return None
def get_content_by_topic_index(self, topic_index: str) -> Optional[Dict[str, Any]]:
"""
根据主题索引获取内容信息
Args:
topic_index: 主题索引
Returns:
内容信息字典如果未找到则返回None
"""
if not self.db_pool:
logger.error("数据库连接池未初始化")
return None
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT * FROM content WHERE topicIndex = %s AND isDelete = 0 ORDER BY createTime DESC LIMIT 1",
(topic_index,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"找到内容信息: topicIndex={topic_index}, 标题={result['title']}")
return result
else:
logger.warning(f"未找到内容信息: topicIndex={topic_index}")
return None
except Exception as e:
logger.error(f"查询内容信息失败: {e}")
return None
def get_images_by_folder_id(self, folder_id: int) -> List[Dict[str, Any]]:
"""
根据文件夹ID获取图像列表
Args:
folder_id: 文件夹ID
Returns:
图像信息列表
"""
if not self.db_pool:
return []
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT * FROM material WHERE folderId = %s AND materialType = 'image' AND isDelete = 0 ORDER BY createTime DESC",
(folder_id,)
)
results = cursor.fetchall()
cursor.close()
conn.close()
logger.info(f"根据文件夹ID获取图像: folderId={folder_id}, 找到{len(results)}个图像")
return results
except Exception as e:
logger.error(f"根据文件夹ID获取图像失败: {e}")
return []
def get_folder_by_id(self, folder_id: int) -> Optional[Dict[str, Any]]:
"""
根据ID获取文件夹信息
Args:
folder_id: 文件夹ID
Returns:
文件夹信息字典如果未找到则返回None
"""
if not self.db_pool:
logger.error("数据库连接池未初始化")
return None
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT * FROM material_folder WHERE id = %s AND isDelete = 0",
(folder_id,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"找到文件夹信息: ID={folder_id}, 名称={result['folderName']}")
return result
else:
logger.warning(f"未找到文件夹信息: ID={folder_id}")
return None
except Exception as e:
logger.error(f"查询文件夹信息失败: {e}")
return None
def get_related_images_for_content(self, content_id: int, limit: int = 10) -> List[Dict[str, Any]]:
"""
获取与内容相关的图像列表
Args:
content_id: 内容ID
limit: 限制数量
Returns:
相关图像列表
"""
if not self.db_pool:
return []
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
# 获取内容信息
cursor.execute(
"SELECT * FROM content WHERE id = %s AND isDelete = 0",
(content_id,)
)
content = cursor.fetchone()
if not content:
cursor.close()
conn.close()
return []
# 获取相关的图像(这里可以根据业务逻辑调整查询条件)
cursor.execute(
"SELECT * FROM material WHERE materialType = 'image' AND isDelete = 0 ORDER BY RAND() LIMIT %s",
(limit,)
)
results = cursor.fetchall()
cursor.close()
conn.close()
logger.info(f"获取相关内容图像: contentId={content_id}, 找到{len(results)}个图像")
return results
except Exception as e:
logger.error(f"获取相关内容图像失败: {e}")
return []

View File

@ -2,20 +2,22 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
海报服务层 海报服务层 - 简化版本
封装现有功能提供API调用 封装核心功能重点优化图片使用追踪
""" """
import logging import logging
import uuid import uuid
from typing import List, Dict, Any, Optional, Tuple import time
from typing import List, Dict, Any, Optional
from datetime import datetime from datetime import datetime
from core.config import ConfigManager, PosterConfig from core.config import ConfigManager, PosterConfig
from core.ai import AIAgent from core.ai import AIAgent
from utils.file_io import OutputManager from utils.file_io import OutputManager
from utils.image_processor import ImageProcessor
from poster.poster_generator import PosterGenerator from poster.poster_generator import PosterGenerator
from poster.text_generator import PosterContentGenerator from api.services.database_service import DatabaseService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,86 +40,291 @@ class PosterService:
# 初始化各个组件 # 初始化各个组件
self.poster_generator = PosterGenerator(config_manager, output_manager) self.poster_generator = PosterGenerator(config_manager, output_manager)
self.text_generator = PosterContentGenerator(ai_agent)
def generate_poster(self, content: Dict[str, Any], topic_index: str, # 初始化数据库服务
template_name: Optional[str] = None) -> Tuple[str, str, str, str]: self.db_service = DatabaseService(config_manager)
# 图片使用追踪存储(实际应用中应该使用数据库)
self._image_usage_tracker = {}
def generate_poster_simplified(self, content_id: Optional[int] = None,
product_id: Optional[int] = None,
scenic_spot_id: Optional[int] = None,
image_ids: Optional[List[int]] = None,
generate_collage: bool = False) -> Dict[str, Any]:
""" """
生成海报 简化的海报生成方法
Args: Args:
content: 内容数据包含标题正文等 content_id: 内容ID
topic_index: 主题索引用于文件命名 product_id: 产品ID
template_name: 模板名称如果为None则根据配置选择 scenic_spot_id: 景区ID
image_ids: 图像ID列表
generate_collage: 是否生成拼图
Returns: Returns:
请求ID主题索引生成的海报文件路径和使用的模板名称 包含base64图像数据和图片使用信息的字典
""" """
logger.info(f"开始为主题 {topic_index} 生成海报") start_time = time.time()
logger.info(f"开始生成海报: content_id={content_id}, product_id={product_id}, scenic_spot_id={scenic_spot_id}")
# 生成海报 result = {
poster_path = self.poster_generator.generate_poster(content, topic_index, template_name) "request_id": f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}",
"poster_base64": "",
"content_info": None,
"product_info": None,
"scenic_spot_info": None,
"used_image_ids": [],
"image_usage_info": [],
"collage_base64": None,
"metadata": {}
}
# 获取使用的模板名称 try:
if template_name is None: # 1. 获取内容信息
template_name = self.poster_generator._select_template() if content_id:
content_info = self.db_service.get_content_by_id(content_id)
# 生成请求ID if content_info:
request_id = f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}" result["content_info"] = self._build_content_info(content_info)
logger.info(f"海报生成完成请求ID: {request_id}, 主题索引: {topic_index}, 模板: {template_name}")
return request_id, topic_index, poster_path, template_name
def get_available_templates(self) -> Tuple[List[str], str]:
"""
获取可用的模板列表
Returns:
可用的模板列表和默认模板
"""
# 获取配置
poster_config = self.config_manager.get_config('poster', PosterConfig)
# 获取可用模板
available_templates = poster_config.available_templates
# 获取默认模板
default_template = poster_config.template_selection
if default_template == "random" and available_templates:
default_template = available_templates[0]
return available_templates, default_template # 2. 获取产品信息
if product_id:
product_info = self.db_service.get_product_by_id(product_id)
if product_info:
result["product_info"] = self._build_product_info(product_info)
# 3. 获取景区信息
if scenic_spot_id:
scenic_spot_info = self.db_service.get_scenic_spot_by_id(scenic_spot_id)
if scenic_spot_info:
result["scenic_spot_info"] = self._build_scenic_spot_info(scenic_spot_info)
# 4. 处理图像信息并追踪使用情况
image_paths = []
if image_ids:
images = self.db_service.get_images_by_ids(image_ids)
for img in images:
image_paths.append(img['filePath'])
result["used_image_ids"].append(img['id'])
# 更新图片使用追踪
self._update_image_usage(img['id'], "poster_generation")
# 5. 构建海报内容
poster_content = self._build_poster_content_from_info(
result["content_info"], result["product_info"], result["scenic_spot_info"]
)
# 6. 生成海报只使用vibrant模板
poster_path = self.poster_generator.generate_poster(
poster_content,
str(content_id or product_id or scenic_spot_id or "unknown"),
"vibrant"
)
if poster_path:
# 转换为base64
result["poster_base64"] = ImageProcessor.image_to_base64(poster_path)
# 7. 处理拼图
if generate_collage and len(image_paths) > 1:
collage_result = ImageProcessor.process_images_for_poster(
image_paths,
target_size=(900, 1200),
create_collage=True
)
if collage_result.get("collage_image"):
result["collage_base64"] = collage_result["collage_image"]["base64"]
# 更新拼图中使用的图片使用追踪
for img_id in result["used_image_ids"]:
self._update_image_usage(img_id, "collage_creation")
# 8. 构建图片使用信息
result["image_usage_info"] = self._get_image_usage_info(result["used_image_ids"])
# 9. 添加元数据
processing_time = time.time() - start_time
result["metadata"] = {
"total_images_used": len(result["used_image_ids"]),
"has_collage": result["collage_base64"] is not None,
"processing_time": f"{processing_time:.2f}s",
"template_used": "vibrant"
}
logger.info(f"海报生成完成,处理时间: {processing_time:.2f}s")
return result
except Exception as e:
logger.error(f"生成海报失败: {e}", exc_info=True)
result["metadata"]["error"] = str(e)
return result
async def generate_poster_text(self, system_prompt: str, user_prompt: str, def get_image_usage_info(self, image_ids: List[int]) -> Dict[str, Any]:
context_data: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None) -> Tuple[str, Dict[str, Any]]:
""" """
生成海报文案 获取图像使用情况信息
Args: Args:
system_prompt: 系统提示词 image_ids: 图像ID列表
user_prompt: 用户提示词
context_data: 上下文数据用于填充提示词中的占位符
temperature: 生成温度参数
top_p: top_p参数
Returns: Returns:
请求ID和生成的文案内容 图像使用情况信息
""" """
logger.info("开始生成海报文案") request_id = f"usage-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
# 生成文案 image_usage_info = []
text_content = await self.text_generator.generate_text_for_poster( total_usage_count = 0
system_prompt=system_prompt, usage_counts = []
user_prompt=user_prompt,
context_data=context_data,
temperature=temperature,
top_p=top_p
)
# 生成请求ID for img_id in image_ids:
request_id = f"text-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}" usage_info = self._get_single_image_usage_info(img_id)
if usage_info:
image_usage_info.append(usage_info)
total_usage_count += usage_info["usage_count"]
usage_counts.append(usage_info["usage_count"])
logger.info(f"海报文案生成完成请求ID: {request_id}") # 计算汇总信息
return request_id, text_content summary = {
"total_images": len(image_ids),
"total_usage_count": total_usage_count,
"most_used_image_id": None,
"least_used_image_id": None
}
if usage_counts:
max_usage = max(usage_counts)
min_usage = min(usage_counts)
summary["most_used_image_id"] = next(
(info["image_id"] for info in image_usage_info if info["usage_count"] == max_usage),
None
)
summary["least_used_image_id"] = next(
(info["image_id"] for info in image_usage_info if info["usage_count"] == min_usage),
None
)
return {
"request_id": request_id,
"image_usage_info": image_usage_info,
"summary": summary
}
def _update_image_usage(self, image_id: int, context: str):
"""更新图片使用追踪"""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if image_id not in self._image_usage_tracker:
self._image_usage_tracker[image_id] = {
"usage_count": 0,
"first_used_at": current_time,
"last_used_at": current_time,
"usage_context": []
}
tracker = self._image_usage_tracker[image_id]
tracker["usage_count"] += 1
tracker["last_used_at"] = current_time
if context not in tracker["usage_context"]:
tracker["usage_context"].append(context)
def _get_single_image_usage_info(self, image_id: int) -> Optional[Dict[str, Any]]:
"""获取单个图片的使用信息"""
if image_id in self._image_usage_tracker:
tracker = self._image_usage_tracker[image_id]
return {
"image_id": image_id,
"usage_count": tracker["usage_count"],
"first_used_at": tracker["first_used_at"],
"last_used_at": tracker["last_used_at"],
"usage_context": tracker["usage_context"]
}
else:
# 如果图片从未被使用过,返回默认信息
return {
"image_id": image_id,
"usage_count": 0,
"first_used_at": "",
"last_used_at": "",
"usage_context": []
}
def _get_image_usage_info(self, image_ids: List[int]) -> List[Dict[str, Any]]:
"""获取图片使用信息列表"""
usage_info = []
for img_id in image_ids:
info = self._get_single_image_usage_info(img_id)
if info:
usage_info.append(info)
return usage_info
def _build_content_info(self, content_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建内容信息"""
return {
"id": content_data["id"],
"title": content_data["title"],
"content": content_data["content"],
"tag": content_data["tag"]
}
def _build_product_info(self, product_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建产品信息"""
return {
"id": product_data["id"],
"name": product_data["name"],
"description": product_data.get("description"),
"real_price": float(product_data["realPrice"]) if product_data.get("realPrice") else None,
"origin_price": float(product_data["originPrice"]) if product_data.get("originPrice") else None
}
def _build_scenic_spot_info(self, scenic_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建景区信息"""
return {
"id": scenic_data["id"],
"name": scenic_data["name"],
"description": scenic_data.get("description"),
"address": scenic_data.get("address")
}
def _build_poster_content_from_info(self, content_info: Optional[Dict[str, Any]],
product_info: Optional[Dict[str, Any]],
scenic_spot_info: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""从信息构建海报内容"""
title = ""
content_parts = []
tags = []
# 构建标题
if content_info:
title = content_info["title"]
if content_info.get("content"):
content_parts.append(content_info["content"])
if content_info.get("tag"):
tags.extend(content_info["tag"].split(","))
else:
# 如果没有内容信息,使用景区和产品信息构建标题
if scenic_spot_info and product_info:
title = f"{scenic_spot_info['name']} - {product_info['name']}"
elif scenic_spot_info:
title = scenic_spot_info['name']
elif product_info:
title = product_info['name']
# 添加景区信息
if scenic_spot_info and scenic_spot_info.get("description"):
content_parts.append(f"景区介绍: {scenic_spot_info['description']}")
tags.append(scenic_spot_info["name"])
# 添加产品信息
if product_info:
if product_info.get("description"):
content_parts.append(f"产品介绍: {product_info['description']}")
if product_info.get("real_price"):
content_parts.append(f"价格: ¥{product_info['real_price']}")
tags.append(product_info["name"])
content = "\n\n".join(content_parts) if content_parts else "暂无详细内容"
return {
"title": title,
"content": content,
"tag": tags
}

View File

@ -51,7 +51,7 @@ class PromptService:
self.db_pool = self._init_db_pool() self.db_pool = self._init_db_pool()
# 创建必要的目录结构 # 创建必要的目录结构
self._create_resource_directories() # self._create_resource_directories()
def _create_resource_directories(self): def _create_resource_directories(self):
pass pass