2025-07-31 15:35:23 +08:00

643 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Workflow Router
工作流路由 - API v2
包含完整流水线、文件上传、Cookie管理、任务管理等功能
"""
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse, StreamingResponse

from ..models import (
    PipelineRequest,
    FileUploadRequest,
    CookieManagementRequest,
    PromptBuildRequest,
    PipelineResponse,
    FileUploadResponse,
    CookieManagementResponse,
    PromptBuildResponse,
    TaskManagementResponse,
    ApiResponse
)
from ..services import DatabaseService, FileService
logger = logging.getLogger(__name__)
router = APIRouter()

# Module-level, in-memory stores shared by the handlers in this router.
# Nothing is persisted: contents vanish on restart and each worker process
# holds its own copy.
task_storage: Dict[str, Dict[str, Any]] = {}    # pipeline results keyed by task_id
cookie_storage: Dict[str, Dict[str, Any]] = {}  # cookie records keyed by cookie name

# Lazily created service singletons; obtain them via the get_*_service() factories.
_file_service: Optional[FileService] = None
_db_service: Optional[DatabaseService] = None
def get_file_service() -> FileService:
    """Return the shared FileService, creating it on first use.

    The first call ensures an ``uploads`` directory exists (relative to the
    current working directory) and binds the service to it.
    """
    global _file_service
    if _file_service is not None:
        return _file_service
    directory_name = "uploads"
    Path(directory_name).mkdir(exist_ok=True)
    _file_service = FileService(upload_dir=directory_name)
    return _file_service
def get_database_service() -> DatabaseService:
    """Return the shared DatabaseService, constructing it on the first call."""
    global _db_service
    if _db_service is not None:
        return _db_service
    service = DatabaseService()
    _db_service = service
    return service
@router.post("/pipeline", response_model=PipelineResponse, summary="执行完整流水线")
async def execute_pipeline(
    request: PipelineRequest,
    content_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_content_pipeline']).get_content_pipeline),
    poster_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Run the full content-creation pipeline for one request.

    Steps:
        0. (optional) Query related scenic spots, products, styles and
           audiences from the database when ``request.use_database_data`` is
           truthy. A failure here is logged and the pipeline continues with
           the raw request parameters.
        1. Generate topics.
        2. Generate content for the first generated topic.
        3. Judge that content.
           (Steps 2-3 are skipped when topic generation yields no topics.)
        4. (optional) Generate a poster when ``request.generate_poster`` is
           truthy. It uses only request fields, not topic/content output.

    On success the accumulated result is stored in ``task_storage`` under a
    fresh ``task_id`` and returned in a ``PipelineResponse``. Any exception is
    logged and returned as a 500 ``JSONResponse`` wrapping a failed
    ``PipelineResponse``; failed runs are not stored.

    NOTE(review): the pipeline dependencies are looked up in ``api_v2.main``
    via ``__import__`` at import time of this module, presumably to avoid a
    circular import; this requires ``api_v2.main`` to define those getters
    before it imports this router.
    """
    task_id = str(uuid.uuid4())
    try:
        logger.info(f"开始执行流水线任务ID: {task_id}")
        pipeline_result = {
            "task_id": task_id,
            "steps_completed": [],
            "results": {},
            "execution_details": {
                # Real UTC timestamp; previously a random UUID stood in for it.
                "start_time": datetime.now(timezone.utc).isoformat(),
                "database_used": db_service.is_available()
            }
        }
        steps_completed = pipeline_result["steps_completed"]
        results = pipeline_result["results"]

        # 0. Optional database lookup. The request model may not declare
        #    `use_database_data`, hence getattr with a falsy default.
        if getattr(request, 'use_database_data', False):
            logger.info("步骤0: 查询数据库数据")
            try:
                scenic_spots, _ = db_service.get_scenic_spots(limit=5, search=request.scenic_info)
                products, _ = db_service.get_products(limit=5, search=request.product_info)
                styles = db_service.get_styles()
                audiences = db_service.get_audiences()
                results["database_data"] = {
                    "scenic_spots": scenic_spots,
                    "products": products,
                    "styles": styles,
                    "audiences": audiences
                }
                steps_completed.append("database_query")
            except Exception as e:
                # Best effort: later steps only need the request fields.
                logger.warning(f"数据库查询失败,使用原始参数: {e}")

        # 1. Topic generation. The request id each generator returns is unused.
        logger.info("步骤1: 生成主题")
        topic_generator = content_pipeline["topic_generator"]
        _, topics_data = await topic_generator.generate_topics(
            creative_materials=request.creative_materials,
            num_topics=request.num_topics,
            style=request.style
        )
        steps_completed.append("topic_generation")
        results["topics"] = topics_data

        # Only the first topic is developed; a None/empty result skips 2-3
        # instead of failing with AttributeError.
        topics = (topics_data or {}).get("topics")
        if topics:
            selected_topic = topics[0]

            # 2. Content generation.
            logger.info("步骤2: 生成内容")
            content_generator = content_pipeline["content_generator"]
            _, content_data = await content_generator.generate_content(
                topic=selected_topic,
                scenic_info=request.scenic_info,
                product_info=request.product_info,
                additional_requirements=f"风格: {request.style}, 目标受众: {request.target_audience}"
            )
            steps_completed.append("content_generation")
            results["content"] = content_data

            # 3. Content judging.
            logger.info("步骤3: 评判内容")
            content_judger = content_pipeline["content_judger"]
            _, judge_data = await content_judger.judge_content(
                product_info=request.product_info,
                content_to_judge=content_data.get("content", {}).get("content", "")
            )
            steps_completed.append("content_judging")
            results["judge_result"] = judge_data

        # 4. Optional poster generation.
        if request.generate_poster:
            logger.info("步骤4: 生成海报")
            poster_generator = poster_pipeline["poster_generator"]
            _, poster_data = await poster_generator.generate_poster(
                template_name=request.poster_template,
                scenic_info=request.scenic_info,
                product_info=request.product_info,
                style_options={"transparent_background": True}
            )
            steps_completed.append("poster_generation")
            results["poster"] = poster_data

        task_storage[task_id] = pipeline_result
        logger.info(f"流水线执行完成任务ID: {task_id}")
        return PipelineResponse(
            success=True,
            message="流水线执行完成",
            data=pipeline_result,
            request_id=task_id
        )
    except Exception as e:
        error_msg = f"流水线执行失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=PipelineResponse(
                success=False,
                message="流水线执行失败",
                error=error_msg,
                request_id=task_id
            ).dict()
        )
@router.post("/upload", response_model=FileUploadResponse, summary="上传文件")
async def upload_file(
    file: UploadFile = File(...),
    description: str = Form(""),
    process_immediately: bool = Form(True),
    document_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_document_pipeline']).get_document_pipeline),
    file_service: FileService = Depends(get_file_service)
):
    """
    Save an uploaded file and optionally extract its text right away.

    The file is persisted by ``FileService.save_upload_file`` with size
    validation. When ``process_immediately`` is true, the document pipeline's
    ``text_extractor`` runs on the saved file. A processing failure does not
    fail the upload; it is recorded in ``processing_status`` /
    ``processing_error``.

    Returns:
        ``FileUploadResponse`` with the file info on success; a 400
        ``JSONResponse`` when saving raises ``ValueError`` (validation), or a
        500 ``JSONResponse`` for any other error.
    """
    try:
        logger.info(f"开始上传文件: {file.filename}")
        file_info = await file_service.save_upload_file(
            file=file,
            validate_size=True
        )
        file_info["description"] = description
        file_info["upload_id"] = str(uuid.uuid4())

        if process_immediately:
            try:
                logger.info("开始处理上传的文件")
                text_extractor = document_pipeline["text_extractor"]
                # extract_text is synchronous; run it in the default executor
                # so a large document does not stall the event loop for every
                # other request being served.
                loop = asyncio.get_running_loop()
                extracted_doc = await loop.run_in_executor(
                    None, text_extractor.extract_text, file_info["file_path"]
                )
                if extracted_doc:
                    file_info["processing_status"] = "completed"
                    file_info["extracted_text_length"] = len(extracted_doc.content)
                    file_info["document_type"] = extracted_doc.document_type
                    file_info["extracted_metadata"] = extracted_doc.metadata
                else:
                    file_info["processing_status"] = "failed"
                    file_info["processing_error"] = "文档内容为空"
            except Exception as e:
                logger.warning(f"文件处理失败: {e}")
                file_info["processing_status"] = "failed"
                file_info["processing_error"] = str(e)
        else:
            file_info["processing_status"] = "pending"

        # .get: a missing size key must not turn a finished upload into a 500.
        logger.info(f"文件上传成功: {file.filename} (大小: {file_info.get('file_size')} 字节)")
        return FileUploadResponse(
            success=True,
            message="文件上传成功",
            data=file_info
        )
    except ValueError as e:
        # Raised by the file service when validation (e.g. size) fails.
        logger.error(f"文件验证失败: {e}")
        return JSONResponse(
            status_code=400,
            content=FileUploadResponse(
                success=False,
                message="文件验证失败",
                error=str(e)
            ).dict()
        )
    except Exception as e:
        error_msg = f"文件上传失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=FileUploadResponse(
                success=False,
                message="文件上传失败",
                error=error_msg
            ).dict()
        )
@router.get("/files", summary="获取文件列表")
async def list_files(
    file_service: FileService = Depends(get_file_service)
):
    """
    List the files reported by the upload ``FileService``.

    Returns an ``ApiResponse`` whose data contains the file list, its length
    and the upload directory; on failure, a 500 ``JSONResponse`` carrying a
    failed ``ApiResponse``.
    """
    try:
        files = file_service.list_files()
        return ApiResponse(
            success=True,
            message="获取文件列表成功",
            data={
                "files": files,
                "total_files": len(files),
                "upload_directory": str(file_service.upload_dir)
            }
        )
    except Exception as e:
        error_msg = f"获取文件列表失败: {str(e)}"
        # Keep the traceback, consistent with the other handlers in this router.
        logger.error(error_msg, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=ApiResponse(
                success=False,
                message="获取文件列表失败",
                error=error_msg
            ).dict()
        )
@router.get("/files/{filename}/download", summary="下载文件")
async def download_file(
    filename: str,
    file_service: FileService = Depends(get_file_service)
):
    """
    Stream an uploaded file back to the client.

    ``filename`` comes straight from the URL, so it is resolved against the
    upload directory and accepted only if it names a regular file directly
    inside it. This rejects ``..``, absolute paths and (on Windows)
    backslash-separated traversal, and refuses to "download" directories.

    Raises:
        HTTPException: 400 for a name escaping the upload directory, 404 when
            no such file exists, 500 for any other failure.
    """
    try:
        upload_root = Path(file_service.upload_dir).resolve()
        resolved = (upload_root / filename).resolve()
        if resolved.parent != upload_root:
            raise HTTPException(status_code=400, detail="非法文件名")
        if not resolved.is_file():
            raise HTTPException(status_code=404, detail="文件不存在")
        # Rebuild the path from the validated basename so the service receives
        # the same form of path it always did (upload_dir / name).
        file_path = file_service.upload_dir / resolved.name
        return file_service.create_streaming_response(
            file_path=str(file_path),
            filename=resolved.name
        )
    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"文件下载失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg)
@router.delete("/files/{filename}", summary="删除文件")
async def delete_file(
    filename: str,
    file_service: FileService = Depends(get_file_service)
):
    """
    Delete an uploaded file.

    ``filename`` is untrusted URL input. It is resolved against the upload
    directory and accepted only if it names a regular file directly inside it,
    so ``..`` (which previously pointed at the upload directory's parent) or
    absolute/backslash paths can never reach ``FileService.delete_file``.

    Raises:
        HTTPException: 400 for a name escaping the upload directory, 404 when
            no such file exists, 500 when deletion reports failure or raises.
    """
    try:
        upload_root = Path(file_service.upload_dir).resolve()
        resolved = (upload_root / filename).resolve()
        if resolved.parent != upload_root:
            raise HTTPException(status_code=400, detail="非法文件名")
        if not resolved.is_file():
            raise HTTPException(status_code=404, detail="文件不存在")
        # Same path form as before (upload_dir / name), built from the validated basename.
        file_path = file_service.upload_dir / resolved.name
        success = await file_service.delete_file(str(file_path))
        if not success:
            raise HTTPException(status_code=500, detail="文件删除失败")
        return ApiResponse(
            success=True,
            message="文件删除成功",
            data={"filename": filename}
        )
    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"文件删除失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(status_code=500, detail=error_msg)
@router.post("/cookies", response_model=CookieManagementResponse, summary="添加Cookie")
async def add_cookie(request: CookieManagementRequest):
    """
    Add a cookie record to the in-memory ``cookie_storage``.

    An existing record with the same ``request.name`` is replaced, which also
    resets its ``use_count`` and ``last_used``. The response reports counts
    only and never echoes the cookie string.
    """
    try:
        cookie_storage[request.name] = {
            "name": request.name,
            "cookie_string": request.cookie_string,
            "description": request.description,
            # Real UTC creation time; previously a random UUID placeholder.
            "created_at": datetime.now(timezone.utc).isoformat(),
            "is_active": True,
            "use_count": 0,
            "last_used": None
        }
        result_data = {
            "operation": "add",
            "cookie_name": request.name,
            "total_cookies": len(cookie_storage),
            "valid_cookies": sum(1 for c in cookie_storage.values() if c["is_active"])
        }
        logger.info(f"Cookie添加成功: {request.name}")
        return CookieManagementResponse(
            success=True,
            message="Cookie添加成功",
            data=result_data
        )
    except Exception as e:
        error_msg = f"Cookie添加失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=CookieManagementResponse(
                success=False,
                message="Cookie添加失败",
                error=error_msg
            ).dict()
        )
@router.delete("/cookies/{cookie_name}", response_model=CookieManagementResponse, summary="删除Cookie")
async def delete_cookie(cookie_name: str):
    """
    Remove the cookie record named ``cookie_name`` from ``cookie_storage``.

    Raises:
        HTTPException: 404 when no such cookie is stored.
    Any other failure is logged and returned as a 500 ``JSONResponse``.
    """
    try:
        try:
            del cookie_storage[cookie_name]
        except KeyError:
            raise HTTPException(status_code=404, detail="Cookie不存在") from None
        active_count = sum(1 for entry in cookie_storage.values() if entry["is_active"])
        summary = {
            "operation": "delete",
            "cookie_name": cookie_name,
            "total_cookies": len(cookie_storage),
            "valid_cookies": active_count
        }
        logger.info(f"Cookie删除成功: {cookie_name}")
        return CookieManagementResponse(
            success=True,
            message="Cookie删除成功",
            data=summary
        )
    except HTTPException:
        raise
    except Exception as exc:
        failure_text = f"Cookie删除失败: {str(exc)}"
        logger.error(failure_text, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=CookieManagementResponse(
                success=False,
                message="Cookie删除失败",
                error=failure_text
            ).dict()
        )
@router.get("/cookies", response_model=CookieManagementResponse, summary="获取Cookie统计")
async def get_cookie_stats():
    """
    Report cookie counts and every stored record with its secret masked.

    Each record is shallow-copied and its ``cookie_string`` replaced by a
    placeholder, so the stored values are never exposed or mutated.
    """
    try:
        masked_records = [
            {**entry, "cookie_string": "***已隐藏***"}
            for entry in cookie_storage.values()
        ]
        stats = {
            "total_cookies": len(cookie_storage),
            "valid_cookies": sum(1 for entry in cookie_storage.values() if entry["is_active"]),
            "cookies": masked_records
        }
        return CookieManagementResponse(
            success=True,
            message="获取Cookie统计成功",
            data=stats
        )
    except Exception as exc:
        failure_text = f"获取Cookie统计失败: {str(exc)}"
        logger.error(failure_text, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=CookieManagementResponse(
                success=False,
                message="获取Cookie统计失败",
                error=failure_text
            ).dict()
        )
@router.post("/prompts/build", response_model=PromptBuildResponse, summary="构建提示词")
async def build_prompt(
    request: PromptBuildRequest,
    content_pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_content_pipeline']).get_content_pipeline),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Build a system/user prompt pair for ``request.task_type``.

    Templates come from the content pipeline's ``prompt_manager``. When they
    cannot be loaded (lookup raises or returns ``None``), a generic built-in
    pair is used. When the database is available, related scenic spots,
    products and the available style/audience names are attached as
    ``enhanced_data``; enrichment failures are logged and ignored.

    Returns:
        ``PromptBuildResponse`` with both prompts, the input variables and any
        enhanced data; a 500 ``JSONResponse`` on failure (including a
        template whose placeholders do not match the supplied variables).
    """
    try:
        prompt_manager = content_pipeline["prompt_manager"]

        # Best-effort enrichment whenever the database is reachable.
        enhanced_data = {}
        if db_service.is_available():
            try:
                if request.scenic_info:
                    scenic_spots, _ = db_service.get_scenic_spots(limit=3, search=request.scenic_info)
                    enhanced_data["scenic_spots"] = scenic_spots
                if request.product_info:
                    products, _ = db_service.get_products(limit=3, search=request.product_info)
                    enhanced_data["products"] = products
                styles = db_service.get_styles()
                audiences = db_service.get_audiences()
                enhanced_data["available_styles"] = [s["styleName"] for s in styles]
                enhanced_data["available_audiences"] = [a["audienceName"] for a in audiences]
            except Exception as e:
                logger.warning(f"数据库增强失败: {e}")

        # Load the task's templates. This was a bare `except:`, which also
        # swallowed cancellation and KeyboardInterrupt; now only Exception.
        try:
            system_prompt = prompt_manager.get_prompt(request.task_type, "system")
            user_prompt_template = prompt_manager.get_prompt(request.task_type, "user")
        except Exception as e:
            logger.warning(f"提示词模板不可用,使用默认模板: {e}")
            system_prompt = None
            user_prompt_template = None
        # A missing template (None) would otherwise crash on .format below.
        if system_prompt is None or user_prompt_template is None:
            system_prompt = f"你是一个专业的{request.task_type}专家。"
            user_prompt_template = "请根据以下信息完成任务:\n景区信息:{scenic_info}\n产品信息:{product_info}\n风格要求:{style}\n目标受众:{target_audience}\n自定义要求:{custom_requirements}"

        user_prompt = user_prompt_template.format(
            scenic_info=request.scenic_info or "未提供",
            product_info=request.product_info or "未提供",
            style=request.style or "默认",
            target_audience=request.target_audience or "通用受众",
            custom_requirements=request.custom_requirements or "无特殊要求"
        )
        result_data = {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "task_type": request.task_type,
            "variables": {
                "scenic_info": request.scenic_info,
                "product_info": request.product_info,
                "style": request.style,
                "target_audience": request.target_audience,
                "custom_requirements": request.custom_requirements
            },
            "enhanced_data": enhanced_data,
            "database_enhanced": bool(enhanced_data)
        }
        logger.info(f"提示词构建成功: {request.task_type}")
        return PromptBuildResponse(
            success=True,
            message="提示词构建成功",
            data=result_data
        )
    except Exception as e:
        error_msg = f"提示词构建失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=PromptBuildResponse(
                success=False,
                message="提示词构建失败",
                error=error_msg
            ).dict()
        )
@router.get("/tasks", response_model=TaskManagementResponse, summary="获取任务列表")
async def get_tasks():
    """
    Summarise every pipeline run held in ``task_storage``.

    A task counts as "completed" when its ``results`` mapping is non-empty and
    as "processing" otherwise; ``total_steps`` is the number of result entries.
    """
    try:
        summaries = [
            {
                "task_id": tid,
                "steps_completed": record.get("steps_completed", []),
                "status": "completed" if record.get("results") else "processing",
                "created_time": record.get("execution_details", {}).get("start_time", "unknown"),
                "total_steps": len(record.get("results", {}))
            }
            for tid, record in task_storage.items()
        ]
        overview = {
            "tasks": summaries,
            "total_tasks": len(summaries),
            "completed_tasks": sum(1 for item in summaries if item["status"] == "completed"),
            "processing_tasks": sum(1 for item in summaries if item["status"] == "processing")
        }
        return TaskManagementResponse(
            success=True,
            message="获取任务列表成功",
            data=overview
        )
    except Exception as exc:
        failure_text = f"获取任务列表失败: {str(exc)}"
        logger.error(failure_text, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=TaskManagementResponse(
                success=False,
                message="获取任务列表失败",
                error=failure_text
            ).dict()
        )
@router.get("/tasks/{task_id}", response_model=TaskManagementResponse, summary="获取任务详情")
async def get_task(task_id: str):
    """
    Return the stored pipeline result for ``task_id``.

    Raises:
        HTTPException: 404 when no task with that id has been stored.
    Any other failure is logged and returned as a 500 ``JSONResponse``.
    """
    try:
        try:
            stored = task_storage[task_id]
        except KeyError:
            raise HTTPException(status_code=404, detail="任务不存在") from None
        return TaskManagementResponse(
            success=True,
            message="获取任务详情成功",
            data=stored
        )
    except HTTPException:
        raise
    except Exception as exc:
        failure_text = f"获取任务详情失败: {str(exc)}"
        logger.error(failure_text, exc_info=True)
        return JSONResponse(
            status_code=500,
            content=TaskManagementResponse(
                success=False,
                message="获取任务详情失败",
                error=failure_text
            ).dict()
        )