#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Workflow Router

工作流路由 - API v2

包含完整流水线、文件上传、Cookie管理、任务管理等功能
(Workflow routes for API v2: full content pipeline, file upload,
cookie management and task management endpoints.)
"""

import logging
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse, StreamingResponse

from ..models import (
    PipelineRequest,
    FileUploadRequest,
    CookieManagementRequest,
    PromptBuildRequest,
    PipelineResponse,
    FileUploadResponse,
    CookieManagementResponse,
    PromptBuildResponse,
    TaskManagementResponse,
    ApiResponse,
)
from ..services import DatabaseService, FileService

# Module-level logger shared by every route defined in this file.
logger = logging.getLogger(__name__)

# Router on which all endpoints below are registered (presumably included by
# the application in api_v2.main — confirm there).
router = APIRouter()

# In-process stores (全局存储). They live only in this Python process: contents
# are lost on restart and are not shared between worker processes.
task_storage: Dict[str, Dict[str, Any]] = {}    # task_id -> pipeline result dict
cookie_storage: Dict[str, Dict[str, Any]] = {}  # cookie name -> cookie record dict

# Lazily created service singletons (服务实例); populated by
# get_file_service() / get_database_service() on first use.
_file_service: Optional["FileService"] = None
_db_service: Optional["DatabaseService"] = None

def get_file_service() -> FileService:
    """Return the process-wide FileService, creating it on first call.

    Uploads are stored under the relative directory ``uploads`` (resolved
    against the current working directory), which is created here if it
    does not exist yet.
    """
    global _file_service
    if _file_service is not None:
        return _file_service

    uploads_path = "uploads"
    Path(uploads_path).mkdir(exist_ok=True)
    _file_service = FileService(upload_dir=uploads_path)
    return _file_service

def get_database_service() -> DatabaseService:
    """Return the process-wide DatabaseService, constructing it lazily on first call."""
    global _db_service
    if _db_service is not None:
        return _db_service

    _db_service = DatabaseService()
    return _db_service

@router.post("/pipeline", response_model=PipelineResponse, summary="执行完整流水线")
async def execute_pipeline(
    request: PipelineRequest,
    content_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_content_pipeline']).get_content_pipeline),
    poster_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline),
    db_service: DatabaseService = Depends(get_database_service)
):
    """Run the full content-creation pipeline (执行完整的内容创作流水线).

    Steps:
        0. Optionally (when ``request.use_database_data`` is truthy) query
           scenic spots, products, styles and audiences from the database.
           A failure here is logged and the pipeline continues without it.
        1. Generate topics.
        2. Generate content for the first generated topic.
        3. Judge the generated content.
        4. Optionally (when ``request.generate_poster``) generate a poster.

    Steps 2-3 only run when step 1 returned a non-empty ``topics`` list. The
    poster step only uses request fields, so it runs whenever requested.

    Returns:
        ``PipelineResponse`` whose ``data`` holds the task id, the names of
        completed steps, per-step results and execution details; the same
        dict is stored in ``task_storage`` under the task id. On any
        unhandled error, a 500 ``JSONResponse`` wrapping a failed
        ``PipelineResponse`` is returned and nothing is stored.
    """
    task_id = str(uuid.uuid4())

    try:
        logger.info(f"开始执行流水线,任务ID: {task_id}")

        pipeline_result = {
            "task_id": task_id,
            "steps_completed": [],
            "results": {},
            "execution_details": {
                # Real UTC timestamp (ISO 8601). This used to be a random UUID
                # standing in for a timestamp, which carried no time at all and
                # made ``created_time`` in the task list meaningless.
                "start_time": datetime.now(timezone.utc).isoformat(),
                "database_used": db_service.is_available()
            }
        }

        # 0. Optional database lookup. getattr() tolerates request models that
        #    do not declare ``use_database_data``.
        if getattr(request, 'use_database_data', False):
            logger.info("步骤0: 查询数据库数据")
            try:
                scenic_spots, _ = db_service.get_scenic_spots(limit=5, search=request.scenic_info)
                products, _ = db_service.get_products(limit=5, search=request.product_info)
                styles = db_service.get_styles()
                audiences = db_service.get_audiences()

                pipeline_result["results"]["database_data"] = {
                    "scenic_spots": scenic_spots,
                    "products": products,
                    "styles": styles,
                    "audiences": audiences
                }
                pipeline_result["steps_completed"].append("database_query")
            except Exception as e:
                # Best effort: continue with the raw request parameters.
                logger.warning(f"数据库查询失败,使用原始参数: {e}")

        # 1. Generate topics.
        logger.info("步骤1: 生成主题")
        topic_generator = content_pipeline["topic_generator"]

        _topic_request_id, topics_data = await topic_generator.generate_topics(
            creative_materials=request.creative_materials,
            num_topics=request.num_topics,
            style=request.style
        )

        pipeline_result["steps_completed"].append("topic_generation")
        pipeline_result["results"]["topics"] = topics_data

        if topics_data.get("topics"):
            # Only the first topic is carried forward into content generation.
            selected_topic = topics_data["topics"][0]

            # 2. Generate content.
            logger.info("步骤2: 生成内容")
            content_generator = content_pipeline["content_generator"]

            _content_request_id, content_data = await content_generator.generate_content(
                topic=selected_topic,
                scenic_info=request.scenic_info,
                product_info=request.product_info,
                additional_requirements=f"风格: {request.style}, 目标受众: {request.target_audience}"
            )

            pipeline_result["steps_completed"].append("content_generation")
            pipeline_result["results"]["content"] = content_data

            # 3. Judge content.
            logger.info("步骤3: 评判内容")
            content_judger = content_pipeline["content_judger"]

            # ``content`` may be missing, None or empty; fall back to {} so the
            # judge receives "" instead of the request failing with
            # AttributeError.
            content_section = content_data.get("content") or {}

            _judge_request_id, judge_data = await content_judger.judge_content(
                product_info=request.product_info,
                content_to_judge=content_section.get("content", "")
            )

            pipeline_result["steps_completed"].append("content_judging")
            pipeline_result["results"]["judge_result"] = judge_data
        else:
            # Previously this case passed silently: the request "succeeded"
            # with topics only. Make it visible in the logs.
            logger.warning(f"未生成任何主题,跳过内容生成与评判,任务ID: {task_id}")

        # 4. Optional poster generation (depends only on request fields).
        if request.generate_poster:
            logger.info("步骤4: 生成海报")
            poster_generator = poster_pipeline["poster_generator"]

            _poster_request_id, poster_data = await poster_generator.generate_poster(
                template_name=request.poster_template,
                scenic_info=request.scenic_info,
                product_info=request.product_info,
                style_options={"transparent_background": True}
            )

            pipeline_result["steps_completed"].append("poster_generation")
            pipeline_result["results"]["poster"] = poster_data

        # Persist the finished task so /tasks and /tasks/{task_id} can read it.
        task_storage[task_id] = pipeline_result

        logger.info(f"流水线执行完成,任务ID: {task_id}")

        return PipelineResponse(
            success=True,
            message="流水线执行完成",
            data=pipeline_result,
            request_id=task_id
        )

    except Exception as e:
        error_msg = f"流水线执行失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=PipelineResponse(
                success=False,
                message="流水线执行失败",
                error=error_msg,
                request_id=task_id
            ).dict()
        )

@router.post("/upload", response_model=FileUploadResponse, summary="上传文件")
async def upload_file(
    file: UploadFile = File(...),
    description: str = Form(""),
    process_immediately: bool = Form(True),
    document_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_document_pipeline']).get_document_pipeline),
    file_service: FileService = Depends(get_file_service)
):
    """Upload a file and optionally extract its text right away.

    The file is persisted through ``FileService.save_upload_file`` with size
    validation enabled; the original comments describe this as streamed. The
    returned metadata dict is extended with the caller's ``description``, a
    fresh ``upload_id`` and a ``processing_status``. That status is
    ``"pending"`` when not processed, otherwise ``"completed"`` or
    ``"failed"``.

    Returns:
        ``FileUploadResponse`` on success; a 400 ``JSONResponse`` when the
        save step raises ``ValueError`` (validation failure); a 500
        ``JSONResponse`` for any other error.
    """
    try:
        logger.info(f"开始上传文件: {file.filename}")

        info = await file_service.save_upload_file(
            file=file,
            validate_size=True
        )
        info["description"] = description
        info["upload_id"] = str(uuid.uuid4())

        if not process_immediately:
            info["processing_status"] = "pending"
        else:
            _process_uploaded_document(info, document_pipeline)

        logger.info(f"文件上传成功: {file.filename} (大小: {info['file_size']} 字节)")

        return FileUploadResponse(
            success=True,
            message="文件上传成功",
            data=info
        )

    except ValueError as exc:
        # Raised by the file service when validation (e.g. size) fails.
        logger.error(f"文件验证失败: {exc}")
        failure = FileUploadResponse(
            success=False,
            message="文件验证失败",
            error=str(exc)
        )
        return JSONResponse(status_code=400, content=failure.dict())

    except Exception as exc:
        err_text = f"文件上传失败: {str(exc)}"
        logger.error(err_text, exc_info=True)
        failure = FileUploadResponse(
            success=False,
            message="文件上传失败",
            error=err_text
        )
        return JSONResponse(status_code=500, content=failure.dict())


def _process_uploaded_document(file_info: Dict[str, Any], document_pipeline: Dict[str, Any]) -> None:
    """Extract text from a freshly saved upload and record the outcome in ``file_info``.

    Best effort: any exception (including a missing ``text_extractor``) is
    logged and stored as ``processing_error`` rather than failing the upload.
    """
    try:
        logger.info("开始处理上传的文件")
        extractor = document_pipeline["text_extractor"]
        # NOTE(review): extract_text is called synchronously (not awaited) from
        # an async route, so a slow extraction presumably blocks the event
        # loop; consider offloading it to a worker thread.
        document = extractor.extract_text(file_info["file_path"])

        if not document:
            file_info["processing_status"] = "failed"
            file_info["processing_error"] = "文档内容为空"
            return

        file_info["processing_status"] = "completed"
        file_info["extracted_text_length"] = len(document.content)
        file_info["document_type"] = document.document_type
        file_info["extracted_metadata"] = document.metadata
    except Exception as exc:
        logger.warning(f"文件处理失败: {exc}")
        file_info["processing_status"] = "failed"
        file_info["processing_error"] = str(exc)

@router.get("/files", summary="获取文件列表")
async def list_files(
    file_service: FileService = Depends(get_file_service)
):
    """Return the uploaded files reported by FileService, their count and the upload directory."""
    try:
        listing = file_service.list_files()
        payload = {
            "files": listing,
            "total_files": len(listing),
            "upload_directory": str(file_service.upload_dir)
        }
        return ApiResponse(
            success=True,
            message="获取文件列表成功",
            data=payload
        )

    except Exception as exc:
        err_text = f"获取文件列表失败: {str(exc)}"
        logger.error(err_text)
        failure = ApiResponse(
            success=False,
            message="获取文件列表失败",
            error=err_text
        )
        return JSONResponse(status_code=500, content=failure.dict())

@router.get("/files/{filename}/download", summary="下载文件")
async def download_file(
    filename: str,
    file_service: FileService = Depends(get_file_service)
):
    """Stream an uploaded file back to the client (支持大文件流式下载).

    ``filename`` comes straight from the URL, so it is untrusted. The joined
    path is resolved and must lie strictly inside the upload directory. This
    blocks ``..`` and, on Windows, backslash or drive-letter names from
    reaching files outside the upload area.

    Raises:
        HTTPException: 400 if ``filename`` escapes the upload directory,
            404 if no regular file exists there, 500 on any other failure.
    """
    try:
        file_path = Path(file_service.upload_dir) / filename

        # Validate against the fully resolved path, but hand the service the
        # same (unresolved) path it received before this check was added.
        upload_root = Path(file_service.upload_dir).resolve()
        resolved_path = file_path.resolve()
        if upload_root not in resolved_path.parents:
            raise HTTPException(status_code=400, detail="文件名无效")

        # is_file() rather than exists(): never pass a directory to the streamer.
        if not resolved_path.is_file():
            raise HTTPException(status_code=404, detail="文件不存在")

        return file_service.create_streaming_response(
            file_path=str(file_path),
            filename=filename
        )

    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"文件下载失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg)

@router.delete("/files/{filename}", summary="删除文件")
async def delete_file(
    filename: str,
    file_service: FileService = Depends(get_file_service)
):
    """Delete an uploaded file by name.

    ``filename`` comes straight from the URL, so it is untrusted. The joined
    path is resolved and must lie strictly inside the upload directory before
    anything is deleted; this blocks ``..`` and, on Windows, backslash or
    drive-letter names from removing files elsewhere.

    Raises:
        HTTPException: 400 if ``filename`` escapes the upload directory,
            404 if no regular file exists there, 500 if the service reports
            failure or anything else goes wrong.
    """
    try:
        file_path = Path(file_service.upload_dir) / filename

        # Validate against the fully resolved path, but hand the service the
        # same (unresolved) path it received before this check was added.
        upload_root = Path(file_service.upload_dir).resolve()
        resolved_path = file_path.resolve()
        if upload_root not in resolved_path.parents:
            raise HTTPException(status_code=400, detail="文件名无效")

        # is_file() rather than exists(): only regular files may be deleted here.
        if not resolved_path.is_file():
            raise HTTPException(status_code=404, detail="文件不存在")

        success = await file_service.delete_file(str(file_path))

        if success:
            return ApiResponse(
                success=True,
                message="文件删除成功",
                data={"filename": filename}
            )
        raise HTTPException(status_code=500, detail="文件删除失败")

    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"文件删除失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg)

@router.post("/cookies", response_model=CookieManagementResponse, summary="添加Cookie")
async def add_cookie(request: CookieManagementRequest):
    """Add (or overwrite) a cookie entry in the in-memory ``cookie_storage``.

    Entries are keyed by ``request.name``. An existing entry with the same
    name is replaced, which also resets its ``use_count`` and ``last_used``.

    Returns:
        ``CookieManagementResponse`` with the operation name, the cookie name
        and the total / active cookie counts, or a 500 ``JSONResponse``
        wrapping a failed response on unexpected errors.
    """
    try:
        cookie_storage[request.name] = {
            "name": request.name,
            "cookie_string": request.cookie_string,
            "description": request.description,
            # Real UTC creation time (ISO 8601). Previously a random UUID was
            # stored here as a placeholder timestamp.
            "created_at": datetime.now(timezone.utc).isoformat(),
            "is_active": True,
            "use_count": 0,
            "last_used": None
        }

        result_data = {
            "operation": "add",
            "cookie_name": request.name,
            "total_cookies": len(cookie_storage),
            "valid_cookies": len([c for c in cookie_storage.values() if c["is_active"]])
        }

        # Log only the name; the cookie value itself is sensitive.
        logger.info(f"Cookie添加成功: {request.name}")

        return CookieManagementResponse(
            success=True,
            message="Cookie添加成功",
            data=result_data
        )

    except Exception as e:
        error_msg = f"Cookie添加失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=CookieManagementResponse(
                success=False,
                message="Cookie添加失败",
                error=error_msg
            ).dict()
        )

@router.delete("/cookies/{cookie_name}", response_model=CookieManagementResponse, summary="删除Cookie")
async def delete_cookie(cookie_name: str):
    """Remove a cookie by name from the in-memory ``cookie_storage``.

    Raises:
        HTTPException: 404 when no cookie named ``cookie_name`` is stored.
    """
    try:
        if cookie_name in cookie_storage:
            cookie_storage.pop(cookie_name)
        else:
            raise HTTPException(status_code=404, detail="Cookie不存在")

        active_count = sum(1 for entry in cookie_storage.values() if entry["is_active"])
        summary = {
            "operation": "delete",
            "cookie_name": cookie_name,
            "total_cookies": len(cookie_storage),
            "valid_cookies": active_count
        }

        logger.info(f"Cookie删除成功: {cookie_name}")

        return CookieManagementResponse(
            success=True,
            message="Cookie删除成功",
            data=summary
        )

    except HTTPException:
        raise
    except Exception as exc:
        err_text = f"Cookie删除失败: {str(exc)}"
        logger.error(err_text, exc_info=True)
        failure = CookieManagementResponse(
            success=False,
            message="Cookie删除失败",
            error=err_text
        )
        return JSONResponse(status_code=500, content=failure.dict())

@router.get("/cookies", response_model=CookieManagementResponse, summary="获取Cookie统计")
async def get_cookie_stats():
    """Summarise the stored cookies without exposing their values.

    Each record is copied with ``cookie_string`` replaced by a placeholder,
    so raw cookie values never leave the server through this endpoint.
    """
    try:
        redacted = [
            {**record, "cookie_string": "***已隐藏***"}
            for record in cookie_storage.values()
        ]
        active_count = sum(1 for record in cookie_storage.values() if record["is_active"])

        return CookieManagementResponse(
            success=True,
            message="获取Cookie统计成功",
            data={
                "total_cookies": len(cookie_storage),
                "valid_cookies": active_count,
                "cookies": redacted
            }
        )

    except Exception as exc:
        err_text = f"获取Cookie统计失败: {str(exc)}"
        logger.error(err_text, exc_info=True)
        failure = CookieManagementResponse(
            success=False,
            message="获取Cookie统计失败",
            error=err_text
        )
        return JSONResponse(status_code=500, content=failure.dict())

@router.post("/prompts/build", response_model=PromptBuildResponse, summary="构建提示词")
async def build_prompt(
    request: PromptBuildRequest,
    content_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_content_pipeline']).get_content_pipeline),
    db_service: DatabaseService = Depends(get_database_service)
):
    """Build a system/user prompt pair for ``request.task_type``.

    When the database is available, matching scenic spots and products, plus
    the available style and audience names, are attached as
    ``enhanced_data``. This is best effort: failures are logged and ignored.
    Templates come from the pipeline's ``prompt_manager``; if fetching them
    raises, a generic built-in template is used instead.

    Returns:
        ``PromptBuildResponse`` with both prompts, the input variables and
        any database enrichment, or a 500 ``JSONResponse`` on failure.
    """
    try:
        prompt_manager = content_pipeline["prompt_manager"]

        # Enrichment is attempted whenever the database is reachable; there is
        # no per-request switch for it.
        enhanced_data = {}
        if db_service.is_available():
            try:
                if request.scenic_info:
                    scenic_spots, _ = db_service.get_scenic_spots(limit=3, search=request.scenic_info)
                    enhanced_data["scenic_spots"] = scenic_spots

                if request.product_info:
                    products, _ = db_service.get_products(limit=3, search=request.product_info)
                    enhanced_data["products"] = products

                styles = db_service.get_styles()
                audiences = db_service.get_audiences()
                enhanced_data["available_styles"] = [s["styleName"] for s in styles]
                enhanced_data["available_audiences"] = [a["audienceName"] for a in audiences]

            except Exception as e:
                logger.warning(f"数据库增强失败: {e}")

        # Fetch the base templates, falling back to a generic default.
        try:
            system_prompt = prompt_manager.get_prompt(request.task_type, "system")
            user_prompt_template = prompt_manager.get_prompt(request.task_type, "user")
        except Exception as e:
            # Previously a bare ``except:``. That also swallowed
            # KeyboardInterrupt, SystemExit and asyncio cancellation
            # (CancelledError is a BaseException since Python 3.8). The
            # fallback is now logged instead of silent.
            logger.warning(f"提示词模板获取失败,使用默认模板: {e}")
            system_prompt = f"你是一个专业的{request.task_type}专家。"
            user_prompt_template = "请根据以下信息完成任务:\n景区信息:{scenic_info}\n产品信息:{product_info}\n风格要求:{style}\n目标受众:{target_audience}\n自定义要求:{custom_requirements}"

        # Fill the user template; missing request fields get readable defaults.
        user_prompt = user_prompt_template.format(
            scenic_info=request.scenic_info or "未提供",
            product_info=request.product_info or "未提供",
            style=request.style or "默认",
            target_audience=request.target_audience or "通用受众",
            custom_requirements=request.custom_requirements or "无特殊要求"
        )

        result_data = {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "task_type": request.task_type,
            "variables": {
                "scenic_info": request.scenic_info,
                "product_info": request.product_info,
                "style": request.style,
                "target_audience": request.target_audience,
                "custom_requirements": request.custom_requirements
            },
            "enhanced_data": enhanced_data,
            "database_enhanced": bool(enhanced_data)
        }

        logger.info(f"提示词构建成功: {request.task_type}")

        return PromptBuildResponse(
            success=True,
            message="提示词构建成功",
            data=result_data
        )

    except Exception as e:
        error_msg = f"提示词构建失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        return JSONResponse(
            status_code=500,
            content=PromptBuildResponse(
                success=False,
                message="提示词构建失败",
                error=error_msg
            ).dict()
        )

@router.get("/tasks", response_model=TaskManagementResponse, summary="获取任务列表")
async def get_tasks():
    """List every task in ``task_storage`` with a short per-task summary.

    A task is reported as ``"completed"`` when its ``results`` dict is
    non-empty and ``"processing"`` otherwise. ``created_time`` is the stored
    ``execution_details.start_time``, or ``"unknown"`` when absent.
    """
    try:
        summaries = [
            {
                "task_id": tid,
                "steps_completed": record.get("steps_completed", []),
                "status": "completed" if record.get("results") else "processing",
                "created_time": record.get("execution_details", {}).get("start_time", "unknown"),
                "total_steps": len(record.get("results", {}))
            }
            for tid, record in task_storage.items()
        ]
        completed_count = sum(1 for s in summaries if s["status"] == "completed")
        processing_count = sum(1 for s in summaries if s["status"] == "processing")

        return TaskManagementResponse(
            success=True,
            message="获取任务列表成功",
            data={
                "tasks": summaries,
                "total_tasks": len(summaries),
                "completed_tasks": completed_count,
                "processing_tasks": processing_count
            }
        )

    except Exception as exc:
        err_text = f"获取任务列表失败: {str(exc)}"
        logger.error(err_text, exc_info=True)
        failure = TaskManagementResponse(
            success=False,
            message="获取任务列表失败",
            error=err_text
        )
        return JSONResponse(status_code=500, content=failure.dict())

@router.get("/tasks/{task_id}", response_model=TaskManagementResponse, summary="获取任务详情")
async def get_task(task_id: str):
    """Return the full stored pipeline result for one task.

    Raises:
        HTTPException: 404 when ``task_id`` is not in ``task_storage``.
    """
    try:
        try:
            record = task_storage[task_id]
        except KeyError:
            raise HTTPException(status_code=404, detail="任务不存在") from None

        return TaskManagementResponse(
            success=True,
            message="获取任务详情成功",
            data=record
        )

    except HTTPException:
        raise
    except Exception as exc:
        err_text = f"获取任务详情失败: {str(exc)}"
        logger.error(err_text, exc_info=True)
        failure = TaskManagementResponse(
            success=False,
            message="获取任务详情失败",
            error=err_text
        )
        return JSONResponse(status_code=500, content=failure.dict())