643 lines
22 KiB
Python
643 lines
22 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
Workflow Router
|
|||
|
|
工作流路由 - API v2
|
|||
|
|
包含完整流水线、文件上传、Cookie管理、任务管理等功能
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import uuid
|
|||
|
|
from typing import Dict, Any, List, Optional
|
|||
|
|
from pathlib import Path
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
|||
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|||
|
|
|
|||
|
|
from ..models import (
|
|||
|
|
PipelineRequest,
|
|||
|
|
FileUploadRequest,
|
|||
|
|
CookieManagementRequest,
|
|||
|
|
PromptBuildRequest,
|
|||
|
|
PipelineResponse,
|
|||
|
|
FileUploadResponse,
|
|||
|
|
CookieManagementResponse,
|
|||
|
|
PromptBuildResponse,
|
|||
|
|
TaskManagementResponse,
|
|||
|
|
ApiResponse
|
|||
|
|
)
|
|||
|
|
from ..services import DatabaseService, FileService
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
router = APIRouter()

# In-process stores shared by the endpoints in this module (plain dicts,
# so their contents live only as long as the process).
# task_id -> pipeline result dict
task_storage: Dict[str, Dict[str, Any]] = {}
# cookie name -> cookie record dict
cookie_storage: Dict[str, Dict[str, Any]] = {}

# Lazily created service singletons, populated on first use by
# get_file_service() / get_database_service().
_file_service: Optional[FileService] = None
_db_service: Optional[DatabaseService] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_file_service() -> FileService:
    """Return the process-wide FileService, creating it on first call.

    The service is rooted at the relative directory ``uploads`` (relative to
    the current working directory), which is created if it does not exist.
    """
    global _file_service
    if _file_service is not None:
        return _file_service

    uploads_root = "uploads"
    Path(uploads_root).mkdir(exist_ok=True)
    _file_service = FileService(upload_dir=uploads_root)
    return _file_service
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_database_service() -> DatabaseService:
    """Return the process-wide DatabaseService, constructing it lazily."""
    global _db_service
    if _db_service is not None:
        return _db_service
    _db_service = DatabaseService()
    return _db_service
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/pipeline", response_model=PipelineResponse, summary="执行完整流水线")
async def execute_pipeline(
    request: PipelineRequest,
    content_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_content_pipeline']).get_content_pipeline),
    poster_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Run the full content-creation pipeline for one request.

    Steps:
    0. Optionally query reference data from the database
       (when ``request.use_database_data`` is truthy).
    1. Generate topics.
    2. Generate content for the first generated topic.
    3. Judge the generated content.
    4. Optionally generate a poster (when ``request.generate_poster`` is truthy).

    The accumulated result dict is stored in ``task_storage`` under a freshly
    generated task ID and returned inside a ``PipelineResponse``.

    Returns:
        PipelineResponse on success; on any unexpected error, a 500
        JSONResponse whose body is a failed PipelineResponse carrying the
        same task ID.
    """
    from datetime import datetime, timezone

    task_id = str(uuid.uuid4())

    try:
        logger.info(f"开始执行流水线,任务ID: {task_id}")

        pipeline_result = {
            "task_id": task_id,
            "steps_completed": [],
            "results": {},
            "execution_details": {
                # Real UTC timestamp (ISO 8601). Previously a random UUID was
                # stored here as a placeholder "timestamp".
                "start_time": datetime.now(timezone.utc).isoformat(),
                "database_used": db_service.is_available()
            }
        }

        # 0. Optional database lookup. A failure here is non-fatal: it is
        # logged and the pipeline continues with the raw request parameters.
        if getattr(request, 'use_database_data', False):
            logger.info("步骤0: 查询数据库数据")
            try:
                # Related scenic spots / products, plus style and audience lists.
                scenic_spots, _ = db_service.get_scenic_spots(limit=5, search=request.scenic_info)
                products, _ = db_service.get_products(limit=5, search=request.product_info)
                styles = db_service.get_styles()
                audiences = db_service.get_audiences()

                pipeline_result["results"]["database_data"] = {
                    "scenic_spots": scenic_spots,
                    "products": products,
                    "styles": styles,
                    "audiences": audiences
                }
                pipeline_result["steps_completed"].append("database_query")
            except Exception as e:
                logger.warning(f"数据库查询失败,使用原始参数: {e}")

        # 1. Topic generation.
        logger.info("步骤1: 生成主题")
        topic_generator = content_pipeline["topic_generator"]
        _topic_request_id, topics_data = await topic_generator.generate_topics(
            creative_materials=request.creative_materials,
            num_topics=request.num_topics,
            style=request.style
        )
        pipeline_result["steps_completed"].append("topic_generation")
        pipeline_result["results"]["topics"] = topics_data

        topics = topics_data.get("topics")
        if topics:
            # Content generation and judging work on the first topic only.
            selected_topic = topics[0]

            # 2. Content generation.
            logger.info("步骤2: 生成内容")
            content_generator = content_pipeline["content_generator"]
            _content_request_id, content_data = await content_generator.generate_content(
                topic=selected_topic,
                scenic_info=request.scenic_info,
                product_info=request.product_info,
                additional_requirements=f"风格: {request.style}, 目标受众: {request.target_audience}"
            )
            pipeline_result["steps_completed"].append("content_generation")
            pipeline_result["results"]["content"] = content_data

            # 3. Content judging.
            logger.info("步骤3: 评判内容")
            content_judger = content_pipeline["content_judger"]
            # NOTE(review): assumes content_data["content"] is a dict holding
            # the generated text under "content" — confirm against the generator.
            _judge_request_id, judge_data = await content_judger.judge_content(
                product_info=request.product_info,
                content_to_judge=content_data.get("content", {}).get("content", "")
            )
            pipeline_result["steps_completed"].append("content_judging")
            pipeline_result["results"]["judge_result"] = judge_data
        else:
            # Previously this case was skipped silently.
            logger.warning(f"未生成任何主题,跳过内容生成与评判步骤,任务ID: {task_id}")

        # 4. Optional poster generation. The call below uses only the template,
        # scenic info and product info (not the selected topic), so it runs
        # whenever a poster is requested.
        if request.generate_poster:
            logger.info("步骤4: 生成海报")
            poster_generator = poster_pipeline["poster_generator"]
            _poster_request_id, poster_data = await poster_generator.generate_poster(
                template_name=request.poster_template,
                scenic_info=request.scenic_info,
                product_info=request.product_info,
                style_options={"transparent_background": True}
            )
            pipeline_result["steps_completed"].append("poster_generation")
            pipeline_result["results"]["poster"] = poster_data

        # Persist the task result (in memory only).
        task_storage[task_id] = pipeline_result

        logger.info(f"流水线执行完成,任务ID: {task_id}")

        return PipelineResponse(
            success=True,
            message="流水线执行完成",
            data=pipeline_result,
            request_id=task_id
        )

    except Exception as e:
        error_msg = f"流水线执行失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=PipelineResponse(
                success=False,
                message="流水线执行失败",
                error=error_msg,
                request_id=task_id
            ).dict()
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/upload", response_model=FileUploadResponse, summary="上传文件")
async def upload_file(
    file: UploadFile = File(...),
    description: str = Form(""),
    process_immediately: bool = Form(True),
    document_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_document_pipeline']).get_document_pipeline),
    file_service: FileService = Depends(get_file_service)
):
    """
    Upload a file and optionally extract its text immediately.

    The file is persisted via ``file_service.save_upload_file`` with size
    validation enabled. When ``process_immediately`` is true, the document
    pipeline's ``text_extractor`` runs on the saved file. The outcome is
    recorded in ``processing_status`` ("completed" or "failed"); otherwise
    the status is "pending". An extraction failure does not fail the upload.

    Returns:
        FileUploadResponse on success.
        A 400 JSONResponse when a ValueError is raised (treated as a
        file-validation failure).
        A 500 JSONResponse for any other error.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        logger.info(f"开始上传文件: {file.filename}")

        # Save the upload through the file service (streamed).
        file_info = await file_service.save_upload_file(
            file=file,
            validate_size=True
        )

        file_info["description"] = description
        file_info["upload_id"] = str(uuid.uuid4())

        if process_immediately:
            try:
                logger.info("开始处理上传的文件")
                text_extractor = document_pipeline["text_extractor"]
                # extract_text is called synchronously (not awaited), so run it
                # in the worker thread pool instead of blocking the event loop
                # for the duration of a large document.
                extracted_doc = await run_in_threadpool(
                    text_extractor.extract_text, file_info["file_path"]
                )

                if extracted_doc:
                    file_info["processing_status"] = "completed"
                    file_info["extracted_text_length"] = len(extracted_doc.content)
                    file_info["document_type"] = extracted_doc.document_type
                    file_info["extracted_metadata"] = extracted_doc.metadata
                else:
                    file_info["processing_status"] = "failed"
                    file_info["processing_error"] = "文档内容为空"

            except Exception as e:
                # Best effort: keep the upload, just record the failure.
                logger.warning(f"文件处理失败: {e}")
                file_info["processing_status"] = "failed"
                file_info["processing_error"] = str(e)
        else:
            file_info["processing_status"] = "pending"

        # .get(): a missing size key must not turn a successful upload into a 500.
        logger.info(f"文件上传成功: {file.filename} (大小: {file_info.get('file_size')} 字节)")

        return FileUploadResponse(
            success=True,
            message="文件上传成功",
            data=file_info
        )

    except ValueError as e:
        # File validation error.
        logger.error(f"文件验证失败: {e}")
        return JSONResponse(
            status_code=400,
            content=FileUploadResponse(
                success=False,
                message="文件验证失败",
                error=str(e)
            ).dict()
        )
    except Exception as e:
        error_msg = f"文件上传失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=FileUploadResponse(
                success=False,
                message="文件上传失败",
                error=error_msg
            ).dict()
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/files", summary="获取文件列表")
async def list_files(
    file_service: FileService = Depends(get_file_service)
):
    """
    List files known to the file service.

    Returns:
        ApiResponse whose data holds the file list, its length and the upload
        directory. On any error, a 500 JSONResponse with a failed ApiResponse.
    """
    try:
        files = file_service.list_files()

        return ApiResponse(
            success=True,
            message="获取文件列表成功",
            data={
                "files": files,
                "total_files": len(files),
                "upload_directory": str(file_service.upload_dir)
            }
        )

    except Exception as e:
        error_msg = f"获取文件列表失败: {str(e)}"
        # exc_info=True keeps the traceback, matching the other handlers here.
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=ApiResponse(
                success=False,
                message="获取文件列表失败",
                error=error_msg
            ).dict()
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/files/{filename}/download", summary="下载文件")
async def download_file(
    filename: str,
    file_service: FileService = Depends(get_file_service)
):
    """
    Stream a previously uploaded file back to the client.

    ``filename`` comes straight from the URL. It is therefore resolved against
    the upload directory and only accepted when it names a regular file inside
    that directory. This blocks ``..`` / separator tricks (e.g. backslashes on
    Windows) and directory names.

    Raises:
        HTTPException(404): the file does not exist, is not a regular file,
            or lies outside the upload directory.
        HTTPException(500): any other failure while building the response.
    """
    try:
        upload_root = Path(file_service.upload_dir).resolve()
        file_path = (upload_root / filename).resolve()

        # Anything that escapes the upload root is reported as "not found"
        # so the existence of outside paths is not leaked.
        if upload_root not in file_path.parents or not file_path.is_file():
            raise HTTPException(status_code=404, detail="文件不存在")

        # Build a streaming response.
        return file_service.create_streaming_response(
            file_path=str(file_path),
            filename=filename
        )

    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"文件下载失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/files/{filename}", summary="删除文件")
async def delete_file(
    filename: str,
    file_service: FileService = Depends(get_file_service)
):
    """
    Delete a previously uploaded file.

    ``filename`` comes straight from the URL. It is therefore resolved against
    the upload directory and only accepted when it names a regular file inside
    that directory. This prevents deleting files outside the upload root via
    ``..`` or separator tricks.

    Returns:
        ApiResponse with the deleted filename.

    Raises:
        HTTPException(404): the file does not exist, is not a regular file,
            or lies outside the upload directory.
        HTTPException(500): the file service reported failure or any other
            error occurred.
    """
    try:
        upload_root = Path(file_service.upload_dir).resolve()
        file_path = (upload_root / filename).resolve()

        # Reject paths escaping the upload root, as well as directories.
        if upload_root not in file_path.parents or not file_path.is_file():
            raise HTTPException(status_code=404, detail="文件不存在")

        success = await file_service.delete_file(str(file_path))
        if not success:
            raise HTTPException(status_code=500, detail="文件删除失败")

        return ApiResponse(
            success=True,
            message="文件删除成功",
            data={"filename": filename}
        )

    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"文件删除失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/cookies", response_model=CookieManagementResponse, summary="添加Cookie")
async def add_cookie(request: CookieManagementRequest):
    """
    Add a cookie entry to the in-memory ``cookie_storage``.

    The entry is keyed by ``request.name``. An existing entry with the same
    name is replaced, which resets its ``use_count`` and ``last_used``.
    ``created_at`` is an ISO-8601 UTC timestamp.

    Returns:
        CookieManagementResponse with the operation summary and counts. On
        any error, a 500 JSONResponse with a failed CookieManagementResponse.
    """
    from datetime import datetime, timezone

    try:
        cookie_storage[request.name] = {
            "name": request.name,
            "cookie_string": request.cookie_string,
            "description": request.description,
            # Real timestamp; previously a random UUID stood in for it.
            "created_at": datetime.now(timezone.utc).isoformat(),
            "is_active": True,
            "use_count": 0,
            "last_used": None
        }

        result_data = {
            "operation": "add",
            "cookie_name": request.name,
            "total_cookies": len(cookie_storage),
            "valid_cookies": len([c for c in cookie_storage.values() if c["is_active"]])
        }

        logger.info(f"Cookie添加成功: {request.name}")

        return CookieManagementResponse(
            success=True,
            message="Cookie添加成功",
            data=result_data
        )

    except Exception as e:
        error_msg = f"Cookie添加失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=CookieManagementResponse(
                success=False,
                message="Cookie添加失败",
                error=error_msg
            ).dict()
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/cookies/{cookie_name}", response_model=CookieManagementResponse, summary="删除Cookie")
async def delete_cookie(cookie_name: str):
    """
    Remove the named cookie from ``cookie_storage``.

    Raises HTTPException(404) when no cookie with that name exists. Any other
    error is logged and returned as a 500 JSONResponse.
    """
    try:
        if cookie_name not in cookie_storage:
            raise HTTPException(status_code=404, detail="Cookie不存在")
        cookie_storage.pop(cookie_name)

        summary = {
            "operation": "delete",
            "cookie_name": cookie_name,
            "total_cookies": len(cookie_storage),
            "valid_cookies": sum(1 for entry in cookie_storage.values() if entry["is_active"]),
        }
        logger.info(f"Cookie删除成功: {cookie_name}")
        return CookieManagementResponse(success=True, message="Cookie删除成功", data=summary)
    except HTTPException:
        raise
    except Exception as exc:
        failure = f"Cookie删除失败: {str(exc)}"
        logger.error(failure, exc_info=True)
        failed_body = CookieManagementResponse(
            success=False,
            message="Cookie删除失败",
            error=failure,
        )
        return JSONResponse(status_code=500, content=failed_body.dict())
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/cookies", response_model=CookieManagementResponse, summary="获取Cookie统计")
async def get_cookie_stats():
    """
    Report cookie counts and every stored cookie record, with the raw cookie
    string masked out.
    """
    try:
        records = list(cookie_storage.values())
        # Copy each record and overwrite the sensitive field in the copy only.
        masked = [{**record, "cookie_string": "***已隐藏***"} for record in records]
        active_count = sum(1 for record in records if record["is_active"])

        return CookieManagementResponse(
            success=True,
            message="获取Cookie统计成功",
            data={
                "total_cookies": len(cookie_storage),
                "valid_cookies": active_count,
                "cookies": masked,
            },
        )
    except Exception as exc:
        failure = f"获取Cookie统计失败: {str(exc)}"
        logger.error(failure, exc_info=True)
        failed_body = CookieManagementResponse(
            success=False,
            message="获取Cookie统计失败",
            error=failure,
        )
        return JSONResponse(status_code=500, content=failed_body.dict())
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/prompts/build", response_model=PromptBuildResponse, summary="构建提示词")
async def build_prompt(
    request: PromptBuildRequest,
    content_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_content_pipeline']).get_content_pipeline),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Build a system/user prompt pair for ``request.task_type``.

    When the database service is available, related scenic spots, products and
    the available style/audience names are looked up and returned as
    ``enhanced_data``. A lookup failure is logged and ignored.

    Prompt templates come from the content pipeline's ``prompt_manager``. If
    the lookup raises, built-in default templates are used instead.

    Returns:
        PromptBuildResponse with the rendered prompts, the input variables and
        any enhancement data. On any other error, a 500 JSONResponse with a
        failed PromptBuildResponse.
    """
    try:
        prompt_manager = content_pipeline["prompt_manager"]

        # Optional database enhancement.
        enhanced_data = {}
        if db_service.is_available():
            try:
                if request.scenic_info:
                    scenic_spots, _ = db_service.get_scenic_spots(limit=3, search=request.scenic_info)
                    enhanced_data["scenic_spots"] = scenic_spots

                if request.product_info:
                    products, _ = db_service.get_products(limit=3, search=request.product_info)
                    enhanced_data["products"] = products

                styles = db_service.get_styles()
                audiences = db_service.get_audiences()
                enhanced_data["available_styles"] = [s["styleName"] for s in styles]
                enhanced_data["available_audiences"] = [a["audienceName"] for a in audiences]

            except Exception as e:
                logger.warning(f"数据库增强失败: {e}")

        # Base templates. Narrowed from a bare ``except:`` so that
        # KeyboardInterrupt/SystemExit are no longer swallowed, and so the
        # reason for falling back is logged.
        try:
            system_prompt = prompt_manager.get_prompt(request.task_type, "system")
            user_prompt_template = prompt_manager.get_prompt(request.task_type, "user")
        except Exception as e:
            logger.warning(f"未找到提示词模板 {request.task_type},使用默认模板: {e}")
            system_prompt = f"你是一个专业的{request.task_type}专家。"
            user_prompt_template = "请根据以下信息完成任务:\n景区信息:{scenic_info}\n产品信息:{product_info}\n风格要求:{style}\n目标受众:{target_audience}\n自定义要求:{custom_requirements}"

        # Render the user prompt, substituting placeholders for missing fields.
        user_prompt = user_prompt_template.format(
            scenic_info=request.scenic_info or "未提供",
            product_info=request.product_info or "未提供",
            style=request.style or "默认",
            target_audience=request.target_audience or "通用受众",
            custom_requirements=request.custom_requirements or "无特殊要求"
        )

        result_data = {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "task_type": request.task_type,
            "variables": {
                "scenic_info": request.scenic_info,
                "product_info": request.product_info,
                "style": request.style,
                "target_audience": request.target_audience,
                "custom_requirements": request.custom_requirements
            },
            "enhanced_data": enhanced_data,
            "database_enhanced": bool(enhanced_data)
        }

        logger.info(f"提示词构建成功: {request.task_type}")

        return PromptBuildResponse(
            success=True,
            message="提示词构建成功",
            data=result_data
        )

    except Exception as e:
        error_msg = f"提示词构建失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=PromptBuildResponse(
                success=False,
                message="提示词构建失败",
                error=error_msg
            ).dict()
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/tasks", response_model=TaskManagementResponse, summary="获取任务列表")
async def get_tasks():
    """
    Summarise every task held in ``task_storage``.

    A task counts as "completed" when its ``results`` dict is non-empty and as
    "processing" otherwise.
    """
    try:
        tasks = [
            {
                "task_id": tid,
                "steps_completed": record.get("steps_completed", []),
                "status": "completed" if record.get("results") else "processing",
                "created_time": record.get("execution_details", {}).get("start_time", "unknown"),
                "total_steps": len(record.get("results", {})),
            }
            for tid, record in task_storage.items()
        ]

        # Status is always one of two values, so "processing" is the remainder.
        finished = sum(1 for summary in tasks if summary["status"] == "completed")
        payload = {
            "tasks": tasks,
            "total_tasks": len(tasks),
            "completed_tasks": finished,
            "processing_tasks": len(tasks) - finished,
        }
        return TaskManagementResponse(success=True, message="获取任务列表成功", data=payload)
    except Exception as exc:
        failure = f"获取任务列表失败: {str(exc)}"
        logger.error(failure, exc_info=True)
        failed_body = TaskManagementResponse(
            success=False,
            message="获取任务列表失败",
            error=failure,
        )
        return JSONResponse(status_code=500, content=failed_body.dict())
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/tasks/{task_id}", response_model=TaskManagementResponse, summary="获取任务详情")
async def get_task(task_id: str):
    """
    Return the full stored result for one task.

    Raises HTTPException(404) when ``task_id`` is unknown. Any other error is
    logged and returned as a 500 JSONResponse.
    """
    try:
        try:
            record = task_storage[task_id]
        except KeyError:
            raise HTTPException(status_code=404, detail="任务不存在") from None

        return TaskManagementResponse(success=True, message="获取任务详情成功", data=record)
    except HTTPException:
        raise
    except Exception as exc:
        failure = f"获取任务详情失败: {str(exc)}"
        logger.error(failure, exc_info=True)
        failed_body = TaskManagementResponse(
            success=False,
            message="获取任务详情失败",
            error=failure,
        )
        return JSONResponse(status_code=500, content=failed_body.dict())
|