TravelContentCreator/domain/aigc/engine_executor.py

306 lines
8.7 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
引擎执行器
负责引擎的调用和任务管理
"""
import logging
import uuid
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from dataclasses import dataclass, field
from .engine_registry import EngineRegistry
from .engines.base import EngineResult
logger = logging.getLogger(__name__)
@dataclass
class TaskInfo:
    """Book-keeping record for one engine task.

    ``status`` takes one of the values ``pending``, ``running``,
    ``completed``, ``failed`` or ``cancelled``. ``result`` holds the engine's
    ``EngineResult`` once one exists, and ``error`` holds an error message.
    The three timestamps are naive local ``datetime`` values
    (``created_at`` defaults to ``datetime.now()``).
    """

    task_id: str
    engine_id: str
    params: Dict[str, Any]
    status: str = "pending"  # pending | running | completed | failed | cancelled
    progress: int = 0
    result: Optional[EngineResult] = None
    error: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of the task.

        Timestamps become ISO-8601 strings (or None), and ``result`` is
        expanded with its own ``to_dict()``. ``params`` is intentionally
        not included.
        """
        def iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        payload: Dict[str, Any] = {
            "task_id": self.task_id,
            "engine_id": self.engine_id,
            "status": self.status,
            "progress": self.progress,
            "result": self.result.to_dict() if self.result else None,
            "error": self.error,
            "created_at": iso(self.created_at),
            "started_at": iso(self.started_at),
            "completed_at": iso(self.completed_at),
        }
        return payload
class EngineExecutor:
    """
    Engine executor (process-wide singleton).

    Supports:
    - synchronous execution (awaits the engine and returns its result dict)
    - asynchronous execution (schedules an asyncio task, returns a task_id)
    - task status queries
    - task cancellation
    """

    _instance: Optional['EngineExecutor'] = None

    # Statuses in which a task can still be cancelled.
    _ACTIVE_STATUSES = ("pending", "running")
    # Statuses that clear_completed_tasks() is allowed to purge.
    _TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __new__(cls):
        """Return the shared instance, creating it on first use (singleton)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Checked by __init__ so repeated EngineExecutor() calls do not
            # reset the task tables of the shared instance.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._registry = EngineRegistry()
        self._tasks: Dict[str, TaskInfo] = {}
        self._running_tasks: Dict[str, asyncio.Task] = {}
        self._initialized = True
        logger.info("引擎执行器初始化")

    async def execute(
        self,
        engine_id: str,
        params: Dict[str, Any],
        async_mode: bool = False,
        task_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Run an engine, either inline or as a background asyncio task.

        Args:
            engine_id: ID of the engine to look up in the registry.
            params: Engine parameters, checked with ``engine.validate_params``.
            async_mode: If True, schedule execution and return immediately.
            task_id: Optional task ID. A 16-character hex ID is generated
                when it is omitted or empty.

        Returns:
            Sync mode: the engine result as a dict (``result.to_dict()``).
            Async mode: ``{"success": True, "task_id": ..., "status": "pending",
            "estimated_duration": ...}``.
            On failure before execution: ``{"success": False, "error": ...,
            "error_code": "ENGINE_NOT_FOUND" | "INVALID_PARAMS"}``.

        Raises:
            asyncio.CancelledError: In sync mode, if the awaiting task is
                cancelled while the engine is running.
        """
        if not task_id:
            # uuid4().hex is the dash-free 32-char form; keep a short prefix.
            task_id = uuid.uuid4().hex[:16]

        engine = self._registry.get(engine_id)
        if not engine:
            return {
                "success": False,
                "error": f"引擎 {engine_id} 不存在",
                "error_code": "ENGINE_NOT_FOUND"
            }

        valid, error_msg = engine.validate_params(params)
        if not valid:
            return {
                "success": False,
                "error": error_msg,
                "error_code": "INVALID_PARAMS"
            }

        self._tasks[task_id] = TaskInfo(
            task_id=task_id,
            engine_id=engine_id,
            params=params
        )

        if not async_mode:
            result = await self._execute_task(task_id, engine, params)
            return result.to_dict()

        handle = asyncio.create_task(self._execute_task(task_id, engine, params))
        self._running_tasks[task_id] = handle
        # Drop the handle once the task is done for any reason. This includes
        # a task cancelled before its first step: its coroutine body (and any
        # cleanup inside it) never runs in that case.
        handle.add_done_callback(lambda done: self._discard_handle(task_id, done))
        return {
            "success": True,
            "task_id": task_id,
            "status": "pending",
            "estimated_duration": engine.estimate_duration(params)
        }

    def _discard_handle(self, task_id: str, handle: asyncio.Task) -> None:
        """Forget *handle* if it is still the one registered for *task_id*."""
        # Identity check: a later task that reused the same task_id keeps its entry.
        if self._running_tasks.get(task_id) is handle:
            del self._running_tasks[task_id]

    async def _execute_task(
        self,
        task_id: str,
        engine,
        params: Dict[str, Any]
    ) -> EngineResult:
        """
        Run *engine* for an already-registered task and record the outcome.

        Args:
            task_id: Task ID previously stored in ``self._tasks``.
            engine: Engine instance (provides ``async execute(params)``).
            params: Engine parameters.

        Returns:
            The engine's result. On an unexpected exception, an
            ``EngineResult`` with error_code ``EXECUTION_ERROR`` is returned
            instead. ``execution_time`` is set on the returned result.

        Raises:
            asyncio.CancelledError: Re-raised after the task has been marked
                ``cancelled``, so cancellation keeps propagating to the awaiter.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return EngineResult(success=False, error="任务不存在")

        task_info.status = "running"
        task_info.started_at = datetime.now()
        # Monotonic clock: durations are unaffected by wall-clock adjustments.
        start_time = time.perf_counter()
        result: Optional[EngineResult] = None
        try:
            result = await engine.execute(params)
            task_info.status = "completed" if result.success else "failed"
            if not result.success:
                # Mirror the exception path so status queries expose the reason.
                task_info.error = getattr(result, "error", None)
            task_info.progress = 100
        except asyncio.CancelledError:
            task_info.status = "cancelled"
            result = EngineResult(success=False, error="任务已取消", error_code="CANCELLED")
            # Cancellation must propagate (asyncio guidance); the state is
            # recorded in the finally block below first.
            raise
        except Exception as e:
            logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
            task_info.status = "failed"
            task_info.error = str(e)
            result = EngineResult(success=False, error=str(e), error_code="EXECUTION_ERROR")
        finally:
            task_info.completed_at = datetime.now()
            # result stays None only for BaseExceptions other than CancelledError.
            if result is not None:
                result.execution_time = time.perf_counter() - start_time
                task_info.result = result
        return result

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Return a task's state as a dict, or None if the task is unknown.

        For running tasks, ``progress`` is refreshed from
        ``engine.get_progress(task_id)`` before serializing.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return None

        engine = self._registry.get(task_info.engine_id)
        if engine and task_info.status == "running":
            task_info.progress = engine.get_progress(task_id)

        return task_info.to_dict()

    async def cancel_task(self, task_id: str) -> bool:
        """
        Request cancellation of an asynchronously executing task.

        Args:
            task_id: Task ID.

        Returns:
            True if a cancellation request was delivered to the task's asyncio
            handle. False if the task is unknown, no longer pending/running,
            or has no handle. Sync-mode tasks have no handle because they run
            inside the caller's own coroutine, so they cannot be cancelled here.
        """
        task_info = self._tasks.get(task_id)
        if not task_info or task_info.status not in self._ACTIVE_STATUSES:
            return False

        handle = self._running_tasks.get(task_id)
        # Task.cancel() returns False when the task is already done.
        if handle is None or not handle.cancel():
            return False

        task_info.status = "cancelled"
        task_info.completed_at = datetime.now()
        return True

    def list_tasks(
        self,
        status: Optional[str] = None,
        engine_id: Optional[str] = None,
        limit: int = 100
    ) -> list[Dict[str, Any]]:
        """
        List tasks in insertion order, optionally filtered.

        Args:
            status: Only include tasks with this status (falsy = no filter).
            engine_id: Only include tasks for this engine (falsy = no filter).
            limit: Maximum number of tasks to return. Values <= 0 yield [].

        Returns:
            A list of ``TaskInfo.to_dict()`` snapshots.
        """
        if limit <= 0:
            return []

        tasks: list[Dict[str, Any]] = []
        for task_info in self._tasks.values():
            if status and task_info.status != status:
                continue
            if engine_id and task_info.engine_id != engine_id:
                continue
            tasks.append(task_info.to_dict())
            if len(tasks) >= limit:
                break
        return tasks

    def clear_completed_tasks(self, max_age_seconds: int = 3600):
        """
        Purge finished tasks whose completion time is older than the cutoff.

        Args:
            max_age_seconds: Finished tasks (completed/failed/cancelled) whose
                ``completed_at`` is more than this many seconds ago are removed.
        """
        now = datetime.now()
        expired = [
            task_id
            for task_id, info in self._tasks.items()
            if info.status in self._TERMINAL_STATUSES
            and info.completed_at
            and (now - info.completed_at).total_seconds() > max_age_seconds
        ]
        for task_id in expired:
            del self._tasks[task_id]
        if expired:
            logger.info("清理了 %d 个已完成任务", len(expired))