339 lines
12 KiB
Python
339 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
文档处理API路由
|
||
"""
|
||
|
||
import logging
|
||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||
from typing import List, Dict, Any
|
||
from pathlib import Path
|
||
|
||
from core.config import ConfigManager
|
||
from core.ai import AIAgent
|
||
from utils.file_io import OutputManager
|
||
from api.services.document_service import DocumentService
|
||
from api.models.document import (
|
||
DocumentProcessRequest, DocumentProcessResponse,
|
||
BatchProcessRequest, BatchProcessResponse,
|
||
TextExtractionRequest, TextExtractionResponse,
|
||
TextParsingRequest, TextParsingResponse,
|
||
DocumentTransformRequest, DocumentTransformResponse,
|
||
SupportedFormatsResponse, ProcessingStatisticsResponse
|
||
)
|
||
|
||
# 从依赖注入模块导入依赖
|
||
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router holding the document endpoints: every route registered on it is
# tagged "document" and advertises a generic 404 response.
router = APIRouter(tags=["document"], responses={404: {"description": "Not found"}})
|
||
|
||
# 依赖注入函数
|
||
def get_document_service(
    config_manager: ConfigManager = Depends(get_config),
    ai_agent: AIAgent = Depends(get_ai_agent),
    output_manager: OutputManager = Depends(get_output_manager)
) -> DocumentService:
    """Build the document-processing service from the injected dependencies.

    The three collaborators are resolved by FastAPI through ``Depends`` and
    handed to ``DocumentService`` positionally, in the order
    (ai_agent, config_manager, output_manager).
    """
    service = DocumentService(ai_agent, config_manager, output_manager)
    return service
|
||
|
||
|
||
@router.post("/process", response_model=DocumentProcessResponse, summary="处理文档")
async def process_document(
    request: DocumentProcessRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Run the full pipeline on a single document: text extraction → document parsing → content transformation.

    - **file_path**: path of the document file
    - **attraction_format**: attraction conversion format (standard/marketing/travel_guide)
    - **product_format**: product conversion format (standard/sales/catalog)

    An ``HTTPException`` raised while processing is passed through unchanged;
    any other exception is logged and reported as a 500.
    """
    try:
        # NOTE(review): file_path is taken straight from the request body;
        # confirm the service restricts which paths may be read.
        result = await document_service.process_document(
            file_path=request.file_path,
            attraction_format=request.attraction_format,
            product_format=request.product_format
        )

        return DocumentProcessResponse(**result)

    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.error(f"文档处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文档处理失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/batch-process", response_model=BatchProcessResponse, summary="批量处理文档")
async def batch_process_documents(
    request: BatchProcessRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Process several documents in one call.

    - **file_paths**: list of document file paths
    - **attraction_format**: attraction conversion format
    - **product_format**: product conversion format

    An ``HTTPException`` raised while processing is passed through unchanged;
    any other exception is logged and reported as a 500.
    """
    try:
        # Each requested path is wrapped in a Path before being handed over.
        # NOTE(review): the paths come straight from the request body;
        # confirm the service restricts which paths may be read.
        result = await document_service.process_multiple_documents(
            file_paths=[Path(fp) for fp in request.file_paths],
            attraction_format=request.attraction_format,
            product_format=request.product_format
        )

        return BatchProcessResponse(**result)

    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.error(f"批量文档处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"批量文档处理失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/extract-text", response_model=TextExtractionResponse, summary="提取文本")
async def extract_text(
    request: TextExtractionRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Extract the text content of a document.

    - **file_path**: path of the document file

    An ``HTTPException`` raised while extracting is passed through unchanged;
    any other exception is logged and reported as a 500.
    """
    try:
        # NOTE(review): file_path is taken straight from the request body;
        # confirm the service restricts which paths may be read.
        result = await document_service.extract_text_only(request.file_path)
        return TextExtractionResponse(**result)

    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.error(f"文本提取失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文本提取失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/parse-text", response_model=TextParsingResponse, summary="解析文本")
async def parse_text(
    request: TextParsingRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Parse text content to identify attraction and product information.

    - **text**: the text content
    - **metadata**: metadata (optional)

    An ``HTTPException`` raised while parsing is passed through unchanged;
    any other exception is logged and reported as a 500.
    """
    try:
        result = await document_service.parse_text_only(
            text=request.text,
            # A missing or empty metadata value is replaced by an empty dict.
            metadata=request.metadata or {}
        )
        return TextParsingResponse(**result)

    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.error(f"文本解析失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文本解析失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/transform", response_model=DocumentTransformResponse, summary="转换文档")
async def transform_document(
    request: DocumentTransformRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Convert an already-parsed document into the standard format.

    - **parsed_document**: the parsed document data
    - **attraction_format**: attraction conversion format
    - **product_format**: product conversion format

    An ``HTTPException`` raised while transforming is passed through unchanged;
    any other exception is logged and reported as a 500.
    """
    try:
        result = await document_service.transform_parsed_document(
            parsed_document=request.parsed_document,
            attraction_format=request.attraction_format,
            product_format=request.product_format
        )
        return DocumentTransformResponse(**result)

    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.error(f"文档转换失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文档转换失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/supported-formats", response_model=SupportedFormatsResponse, summary="获取支持的格式")
async def get_supported_formats(
    document_service: DocumentService = Depends(get_document_service)
):
    """Return the supported file formats and conversion formats."""
    try:
        # The service's mapping is expanded directly into the response model.
        return SupportedFormatsResponse(**document_service.get_supported_formats())
    except Exception as e:
        # Any failure is logged with its traceback and surfaced as a 500.
        logger.error(f"获取支持格式失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取支持格式失败: {str(e)}")
|
||
|
||
|
||
@router.get("/statistics", response_model=ProcessingStatisticsResponse, summary="获取处理统计")
async def get_processing_statistics(
    document_service: DocumentService = Depends(get_document_service)
):
    """Return the statistics reported by the document-processing service."""
    try:
        # The service's mapping is expanded directly into the response model.
        return ProcessingStatisticsResponse(**document_service.get_processing_statistics())
    except Exception as e:
        # Any failure is logged with its traceback and surfaced as a 500.
        logger.error(f"获取统计信息失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")
|
||
|
||
|
||
# 文件上传接口
|
||
@router.post("/upload-and-process", response_model=DocumentProcessResponse, summary="上传并处理文档")
async def upload_and_process_document(
    file: UploadFile = File(...),
    attraction_format: str = "standard",
    product_format: str = "standard",
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Upload a document file and process it.

    - **file**: the uploaded document file
    - **attraction_format**: attraction conversion format
    - **product_format**: product conversion format

    Responds with 400 when the upload has no filename or when its extension is
    not listed under ``supported_file_formats`` by the service. Any other
    failure is logged and reported as a 500. The temporary copy of the upload
    is removed once the request finishes, whether it succeeded or not.
    """
    # Function-scope imports: these modules are not imported at file level.
    import os
    import tempfile

    tmp_file_path = None
    try:
        # Validate the upload before anything is written to disk.
        if not file.filename:
            raise HTTPException(status_code=400, detail="文件名不能为空")

        file_extension = Path(file.filename).suffix.lower()
        supported_formats = document_service.get_supported_formats()

        if file_extension not in supported_formats['supported_file_formats']:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的文件格式: {file_extension}"
            )

        # delete=False keeps the file on disk after the handle is closed so the
        # service can open it by path; removal happens in the finally clause.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
            # Record the path first so a failed read/write is still cleaned up.
            tmp_file_path = tmp_file.name
            tmp_file.write(await file.read())

        result = await document_service.process_document(
            file_path=tmp_file_path,
            attraction_format=attraction_format,
            product_format=product_format
        )

        # Report the client's filename rather than the temporary one.
        if result.get('success') and 'source_file' in result:
            result['source_file']['original_name'] = file.filename
            result['source_file']['uploaded'] = True

        return DocumentProcessResponse(**result)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"文件上传处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文件上传处理失败: {str(e)}") from e
    finally:
        # Best-effort cleanup: a failure here must not mask the real outcome.
        if tmp_file_path is not None:
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(
                    "Failed to remove temporary file %s: %s",
                    tmp_file_path,
                    cleanup_error,
                )
|
||
|
||
|
||
@router.post("/upload-batch-process", response_model=BatchProcessResponse, summary="批量上传并处理文档")
async def upload_batch_process_documents(
    files: List[UploadFile] = File(...),
    attraction_format: str = "standard",
    product_format: str = "standard",
    document_service: DocumentService = Depends(get_document_service)
):
    """
    Upload several document files and process them as a batch.

    - **files**: list of uploaded document files
    - **attraction_format**: attraction conversion format
    - **product_format**: product conversion format

    Uploads without a filename, or with an extension not listed under
    ``supported_file_formats`` by the service, are skipped with a warning.
    Responds with 400 when no upload remains after filtering. Any other
    failure is logged and reported as a 500. Temporary copies of the uploads
    are removed once the request finishes, whether it succeeded or not.
    """
    # Function-scope imports: these modules are not imported at file level.
    import os
    import tempfile

    # One entry per saved upload: {'path': temp path, 'original_name': client name}.
    temp_files = []
    try:
        supported_formats = document_service.get_supported_formats()

        for file in files:
            if not file.filename:
                logger.warning("跳过没有文件名的文件")
                continue

            file_extension = Path(file.filename).suffix.lower()

            if file_extension not in supported_formats['supported_file_formats']:
                logger.warning(f"跳过不支持的文件: {file.filename}")
                continue

            # delete=False keeps the file on disk after the handle is closed so
            # the service can open it by path; removal happens in finally.
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                # Register the path first so a failed read/write is still cleaned up.
                temp_files.append({
                    'path': tmp_file.name,
                    'original_name': file.filename
                })
                tmp_file.write(await file.read())

        if not temp_files:
            raise HTTPException(status_code=400, detail="没有有效的文件可以处理")

        file_paths = [entry['path'] for entry in temp_files]
        result = await document_service.process_multiple_documents(
            file_paths=file_paths,
            attraction_format=attraction_format,
            product_format=product_format
        )

        # Report the clients' filenames rather than the temporary ones.
        # NOTE(review): pairing by position assumes the service returns results
        # in the same order as file_paths; confirm against the service.
        if 'results' in result:
            for file_result, entry in zip(result['results'], temp_files):
                if 'source_file' in file_result:
                    file_result['source_file']['original_name'] = entry['original_name']
                    file_result['source_file']['uploaded'] = True

        return BatchProcessResponse(**result)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"批量文件上传处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"批量文件上传处理失败: {str(e)}") from e
    finally:
        # Best-effort cleanup: a failure here must not mask the real outcome.
        for entry in temp_files:
            try:
                os.unlink(entry['path'])
            except OSError as cleanup_error:
                logger.warning(
                    "Failed to remove temporary file %s: %s",
                    entry['path'],
                    cleanup_error,
                )