339 lines
12 KiB
Python
339 lines
12 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
文档处理API路由
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
|||
|
|
from typing import List, Dict, Any
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
from core.config import ConfigManager
|
|||
|
|
from core.ai import AIAgent
|
|||
|
|
from utils.file_io import OutputManager
|
|||
|
|
from api.services.document_service import DocumentService
|
|||
|
|
from api.models.document import (
|
|||
|
|
DocumentProcessRequest, DocumentProcessResponse,
|
|||
|
|
BatchProcessRequest, BatchProcessResponse,
|
|||
|
|
TextExtractionRequest, TextExtractionResponse,
|
|||
|
|
TextParsingRequest, TextParsingResponse,
|
|||
|
|
DocumentTransformRequest, DocumentTransformResponse,
|
|||
|
|
SupportedFormatsResponse, ProcessingStatisticsResponse
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 从依赖注入模块导入依赖
|
|||
|
|
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Router shared by all document-processing endpoints below. Every route is
# grouped under the "document" tag and advertises a generic 404 response.
router = APIRouter(
    responses={404: {"description": "Not found"}},
    tags=["document"],
)
|
|||
|
|
|
|||
|
|
# Dependency-injection helpers
def get_document_service(
    config_manager: ConfigManager = Depends(get_config),
    ai_agent: AIAgent = Depends(get_ai_agent),
    output_manager: OutputManager = Depends(get_output_manager)
) -> DocumentService:
    """Build a DocumentService from the injected collaborators.

    Each call constructs a new ``DocumentService``. Its constructor takes the
    collaborators positionally in the order (ai_agent, config_manager,
    output_manager).
    """
    service = DocumentService(ai_agent, config_manager, output_manager)
    return service
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/process", response_model=DocumentProcessResponse, summary="处理文档")
async def process_document(
    request: DocumentProcessRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    处理单个文档的完整流程:文本提取 → 文档解析 → 内容转换

    - **file_path**: 文档文件路径
    - **attraction_format**: 景区转换格式(standard/marketing/travel_guide)
    - **product_format**: 产品转换格式(standard/sales/catalog)
    """
    # Runs DocumentService.process_document on a path supplied by the client.
    # NOTE(review): request.file_path is client-controlled and is not validated
    # in this handler; confirm DocumentService restricts it to allowed locations.
    try:
        result = await document_service.process_document(
            file_path=request.file_path,
            attraction_format=request.attraction_format,
            product_format=request.product_format
        )
        return DocumentProcessResponse(**result)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. 4xx) instead of rewrapping
        # them as a generic 500; matches the upload endpoints in this module.
        raise
    except Exception as e:
        logger.error(f"文档处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文档处理失败: {str(e)}") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/batch-process", response_model=BatchProcessResponse, summary="批量处理文档")
async def batch_process_documents(
    request: BatchProcessRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    批量处理多个文档

    - **file_paths**: 文档文件路径列表
    - **attraction_format**: 景区转换格式
    - **product_format**: 产品转换格式
    """
    # Converts every client-supplied path to a Path before handing the list to
    # DocumentService.process_multiple_documents.
    # NOTE(review): the paths are client-controlled and not validated here.
    try:
        result = await document_service.process_multiple_documents(
            file_paths=[Path(fp) for fp in request.file_paths],
            attraction_format=request.attraction_format,
            product_format=request.product_format
        )
        return BatchProcessResponse(**result)
    except HTTPException:
        # Keep intentional HTTP errors intact rather than turning them into 500s.
        raise
    except Exception as e:
        logger.error(f"批量文档处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"批量文档处理失败: {str(e)}") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/extract-text", response_model=TextExtractionResponse, summary="提取文本")
async def extract_text(
    request: TextExtractionRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    从文档中提取文本内容

    - **file_path**: 文档文件路径
    """
    # Delegates to DocumentService.extract_text_only and validates the returned
    # dict against TextExtractionResponse.
    try:
        result = await document_service.extract_text_only(request.file_path)
        return TextExtractionResponse(**result)
    except HTTPException:
        # Keep intentional HTTP errors intact rather than turning them into 500s.
        raise
    except Exception as e:
        logger.error(f"文本提取失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文本提取失败: {str(e)}") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/parse-text", response_model=TextParsingResponse, summary="解析文本")
async def parse_text(
    request: TextParsingRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    解析文本内容,识别景区和产品信息

    - **text**: 文本内容
    - **metadata**: 元数据(可选)
    """
    # A missing/falsy metadata value is normalised to an empty dict before the
    # service call.
    try:
        result = await document_service.parse_text_only(
            text=request.text,
            metadata=request.metadata or {}
        )
        return TextParsingResponse(**result)
    except HTTPException:
        # Keep intentional HTTP errors intact rather than turning them into 500s.
        raise
    except Exception as e:
        logger.error(f"文本解析失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文本解析失败: {str(e)}") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/transform", response_model=DocumentTransformResponse, summary="转换文档")
async def transform_document(
    request: DocumentTransformRequest,
    document_service: DocumentService = Depends(get_document_service)
):
    """
    将解析后的文档转换为标准格式

    - **parsed_document**: 解析后的文档数据
    - **attraction_format**: 景区转换格式
    - **product_format**: 产品转换格式
    """
    # Delegates to DocumentService.transform_parsed_document and validates the
    # returned dict against DocumentTransformResponse.
    try:
        result = await document_service.transform_parsed_document(
            parsed_document=request.parsed_document,
            attraction_format=request.attraction_format,
            product_format=request.product_format
        )
        return DocumentTransformResponse(**result)
    except HTTPException:
        # Keep intentional HTTP errors intact rather than turning them into 500s.
        raise
    except Exception as e:
        logger.error(f"文档转换失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文档转换失败: {str(e)}") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/supported-formats", response_model=SupportedFormatsResponse, summary="获取支持的格式")
async def get_supported_formats(
    document_service: DocumentService = Depends(get_document_service)
):
    """获取支持的文件格式和转换格式"""
    # The service call is synchronous. Any failure while fetching or validating
    # the format listing is logged and reported to the client as HTTP 500.
    try:
        format_listing = document_service.get_supported_formats()
        payload = SupportedFormatsResponse(**format_listing)
    except Exception as err:
        logger.error(f"获取支持格式失败: {err}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取支持格式失败: {str(err)}")
    else:
        return payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/statistics", response_model=ProcessingStatisticsResponse, summary="获取处理统计")
async def get_processing_statistics(
    document_service: DocumentService = Depends(get_document_service)
):
    """获取文档处理服务的统计信息"""
    # Any failure while collecting or validating the statistics is logged and
    # reported to the client as HTTP 500.
    try:
        statistics_payload = document_service.get_processing_statistics()
        response = ProcessingStatisticsResponse(**statistics_payload)
    except Exception as exc:
        logger.error(f"获取统计信息失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(exc)}")
    return response
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 文件上传接口
@router.post("/upload-and-process", response_model=DocumentProcessResponse, summary="上传并处理文档")
async def upload_and_process_document(
    file: UploadFile = File(...),
    attraction_format: str = "standard",
    product_format: str = "standard",
    document_service: DocumentService = Depends(get_document_service)
):
    """
    上传文档文件并进行处理

    - **file**: 上传的文档文件
    - **attraction_format**: 景区转换格式
    - **product_format**: 产品转换格式
    """
    # Flow:
    #   1. Validate the filename extension against the service's supported list.
    #   2. Spool the upload to a temp file and run the full pipeline on it.
    #   3. Always delete the temp file.
    # Errors: 400 for a missing name or unsupported extension, 500 otherwise.
    import os
    import tempfile

    # Read uploads in bounded chunks instead of loading them fully into memory.
    chunk_size = 1024 * 1024

    try:
        if not file.filename:
            raise HTTPException(status_code=400, detail="文件名不能为空")

        file_extension = Path(file.filename).suffix.lower()
        supported_formats = document_service.get_supported_formats()

        if file_extension not in supported_formats['supported_file_formats']:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的文件格式: {file_extension}"
            )

        tmp_file_path = None
        try:
            # The temp file keeps the upload's extension (presumably so the
            # service can detect the format from the path — TODO confirm).
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                # Record the path before writing so the finally block can remove
                # the file even if reading or writing the upload fails.
                tmp_file_path = tmp_file.name
                while True:
                    chunk = await file.read(chunk_size)
                    if not chunk:
                        break
                    tmp_file.write(chunk)

            result = await document_service.process_document(
                file_path=tmp_file_path,
                attraction_format=attraction_format,
                product_format=product_format
            )

            # The processed path is a temp file; annotate the result with the
            # client's original filename and flag it as an upload.
            if result.get('success') and 'source_file' in result:
                result['source_file']['original_name'] = file.filename
                result['source_file']['uploaded'] = True

            return DocumentProcessResponse(**result)

        finally:
            if tmp_file_path is not None:
                try:
                    os.unlink(tmp_file_path)
                except OSError as cleanup_error:
                    # Best-effort cleanup: log instead of masking the real outcome.
                    logger.warning(f"清理临时文件失败 {tmp_file_path}: {cleanup_error}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"文件上传处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"文件上传处理失败: {str(e)}") from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/upload-batch-process", response_model=BatchProcessResponse, summary="批量上传并处理文档")
async def upload_batch_process_documents(
    files: List[UploadFile] = File(...),
    attraction_format: str = "standard",
    product_format: str = "standard",
    document_service: DocumentService = Depends(get_document_service)
):
    """
    批量上传文档文件并进行处理

    - **files**: 上传的文档文件列表
    - **attraction_format**: 景区转换格式
    - **product_format**: 产品转换格式
    """
    # Flow:
    #   1. Skip (with a warning) uploads lacking a name or a supported extension.
    #   2. Spool the rest to temp files and process them as one batch.
    #   3. Always delete every temp file.
    # Errors: 400 when no upload is usable, 500 for any other failure.
    import os
    import tempfile

    # Read uploads in bounded chunks instead of loading them fully into memory.
    chunk_size = 1024 * 1024

    try:
        supported_file_formats = document_service.get_supported_formats()['supported_file_formats']
        # Each entry: {'path': temp file path, 'original_name': client filename}.
        temp_files: List[Dict[str, str]] = []

        try:
            for file in files:
                if not file.filename:
                    logger.warning("跳过没有文件名的文件")
                    continue

                file_extension = Path(file.filename).suffix.lower()

                if file_extension not in supported_file_formats:
                    logger.warning(f"跳过不支持的文件: {file.filename}")
                    continue

                with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                    # Register for cleanup before writing so a failed read/write
                    # cannot leave an orphaned temp file behind.
                    temp_files.append({
                        'path': tmp_file.name,
                        'original_name': file.filename
                    })
                    while True:
                        chunk = await file.read(chunk_size)
                        if not chunk:
                            break
                        tmp_file.write(chunk)

            if not temp_files:
                raise HTTPException(status_code=400, detail="没有有效的文件可以处理")

            # Pass Path objects, matching how /batch-process calls the same
            # service method.
            result = await document_service.process_multiple_documents(
                file_paths=[Path(f['path']) for f in temp_files],
                attraction_format=attraction_format,
                product_format=product_format
            )

            # NOTE(review): assumes result['results'] is in the same order as the
            # submitted paths; zip stops at the shorter of the two sequences.
            if 'results' in result:
                for file_result, temp_file in zip(result['results'], temp_files):
                    if 'source_file' in file_result:
                        file_result['source_file']['original_name'] = temp_file['original_name']
                        file_result['source_file']['uploaded'] = True

            return BatchProcessResponse(**result)

        finally:
            for temp_file in temp_files:
                try:
                    os.unlink(temp_file['path'])
                except OSError as cleanup_error:
                    # Best-effort cleanup: log instead of masking the real outcome.
                    logger.warning(f"清理临时文件失败 {temp_file['path']}: {cleanup_error}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"批量文件上传处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"批量文件上传处理失败: {str(e)}") from e
|