306 lines
8.7 KiB
Python
306 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
引擎执行器
|
||
负责引擎的调用和任务管理
|
||
"""
|
||
|
||
import logging
|
||
import uuid
|
||
import time
|
||
import asyncio
|
||
from typing import Dict, Any, Optional
|
||
from datetime import datetime
|
||
from dataclasses import dataclass, field
|
||
|
||
from .engine_registry import EngineRegistry
|
||
from .engines.base import EngineResult
|
||
|
||
# Module-level logger, named after this module so its records can be configured per module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
@dataclass
class TaskInfo:
    """Record of one engine task: its inputs, lifecycle status, and outcome."""

    task_id: str
    engine_id: str
    # Parameters the engine was invoked with (deliberately absent from to_dict()).
    params: Dict[str, Any]
    # One of: pending, running, completed, failed, cancelled
    status: str = "pending"
    progress: int = 0
    result: Optional[EngineResult] = None
    error: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the task.

        The nested result is serialized via its own ``to_dict()``; timestamps
        become ISO-8601 strings, and unset (falsy) values become None.
        """
        snapshot: Dict[str, Any] = {
            "task_id": self.task_id,
            "engine_id": self.engine_id,
            "status": self.status,
            "progress": self.progress,
        }
        snapshot["result"] = self.result.to_dict() if self.result else None
        snapshot["error"] = self.error

        # Key order matches the field order so serialized output is unchanged.
        for stamp_name in ("created_at", "started_at", "completed_at"):
            moment = getattr(self, stamp_name)
            snapshot[stamp_name] = moment.isoformat() if moment else None

        return snapshot
|
||
|
||
|
||
class EngineExecutor:
    """
    Engine executor.

    Looks engines up in the ``EngineRegistry``, runs them, and keeps a
    ``TaskInfo`` record per task. Supports:
    - synchronous execution (awaits the engine and returns its result dict)
    - asynchronous execution (schedules an asyncio task, returns a task_id)
    - task status queries
    - task cancellation (asynchronous tasks only)

    The class is a singleton: every ``EngineExecutor()`` call returns the same
    instance, and ``__init__`` performs its setup only once.
    """

    _instance: Optional['EngineExecutor'] = None

    def __new__(cls):
        """Return the shared instance, creating it on first use (singleton)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Read by __init__ so repeated EngineExecutor() calls don't reset state.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self._registry = EngineRegistry()
        # task_id -> record for every task not yet removed by clear_completed_tasks().
        self._tasks: Dict[str, TaskInfo] = {}
        # task_id -> asyncio handle, only for async-mode tasks that have not finished.
        self._running_tasks: Dict[str, asyncio.Task] = {}
        self._initialized = True

        logger.info("引擎执行器初始化")

    async def execute(
        self,
        engine_id: str,
        params: Dict[str, Any],
        async_mode: bool = False,
        task_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Run an engine.

        Args:
            engine_id: ID of the engine to look up in the registry.
            params: Parameters passed to the engine.
            async_mode: If True, schedule the run as a background asyncio task
                and return immediately; otherwise await the engine here.
            task_id: Optional task ID; a 16-hex-character ID is generated when
                it is omitted or empty.

        Returns:
            Sync mode: the engine result's ``to_dict()``.
            Async mode: ``{"success": True, "task_id": ..., "status": "pending",
            "estimated_duration": ...}``.
            In either mode, when the engine is unknown or rejects the params:
            ``{"success": False, "error": ..., "error_code": ...}`` with
            ``error_code`` "ENGINE_NOT_FOUND" or "INVALID_PARAMS".

        Raises:
            asyncio.CancelledError: In sync mode, if the calling task is
                cancelled while the engine is running.
        """
        if not task_id:
            # Same value as str(uuid.uuid4()).replace('-', '')[:16].
            task_id = uuid.uuid4().hex[:16]

        engine = self._registry.get(engine_id)
        if not engine:
            return {
                "success": False,
                "error": f"引擎 {engine_id} 不存在",
                "error_code": "ENGINE_NOT_FOUND"
            }

        valid, error_msg = engine.validate_params(params)
        if not valid:
            return {
                "success": False,
                "error": error_msg,
                "error_code": "INVALID_PARAMS"
            }

        # NOTE(review): a caller-supplied task_id that is still pending/running
        # silently replaces the existing TaskInfo here; consider rejecting reuse.
        task_info = TaskInfo(
            task_id=task_id,
            engine_id=engine_id,
            params=params
        )
        self._tasks[task_id] = task_info

        if async_mode:
            asyncio_task = asyncio.create_task(
                self._execute_task(task_id, engine, params)
            )
            self._running_tasks[task_id] = asyncio_task
            # Unregister the handle however the task ends -- including when it is
            # cancelled before its first step, in which case _execute_task's body
            # (and any cleanup inside it) never runs.
            asyncio_task.add_done_callback(
                lambda done: self._forget_running(task_id, done)
            )

            return {
                "success": True,
                "task_id": task_id,
                "status": "pending",
                "estimated_duration": engine.estimate_duration(params)
            }

        result = await self._execute_task(task_id, engine, params)
        return result.to_dict()

    def _forget_running(self, task_id: str, handle: asyncio.Task) -> None:
        """Drop *handle* from the running-task registry if it is still the entry for *task_id*."""
        # Identity check: if task_id was reused by a later execute(), an older
        # task finishing must not remove the newer task's handle.
        if self._running_tasks.get(task_id) is handle:
            del self._running_tasks[task_id]

    async def _execute_task(
        self,
        task_id: str,
        engine,
        params: Dict[str, Any]
    ) -> EngineResult:
        """
        Run *engine* and record the outcome on the task's TaskInfo.

        Args:
            task_id: Task ID (must already be registered in ``self._tasks``).
            engine: Engine instance to run.
            params: Parameters passed to ``engine.execute``.

        Returns:
            The engine's result. If the engine raised an exception, a failure
            EngineResult with error_code "EXECUTION_ERROR" is returned instead.
            If the task ID is unknown, a failure EngineResult is returned.
            When a result exists, its ``execution_time`` is set to the elapsed
            time in seconds.

        Raises:
            asyncio.CancelledError: Re-raised after the task is recorded as
                cancelled, so asyncio cancellation semantics are preserved.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return EngineResult(success=False, error="任务不存在")

        task_info.status = "running"
        task_info.started_at = datetime.now()

        # perf_counter is monotonic, so durations are immune to wall-clock changes.
        start_time = time.perf_counter()
        # Pre-bound so the finally clause is safe even if a BaseException other
        # than CancelledError escapes before any result exists.
        result = None

        try:
            result = await engine.execute(params)

            task_info.status = "completed" if result.success else "failed"
            task_info.result = result
            task_info.progress = 100
            if not result.success:
                # Surface the engine-reported error, as the exception path does.
                task_info.error = getattr(result, "error", None)

        except asyncio.CancelledError:
            task_info.status = "cancelled"
            result = EngineResult(success=False, error="任务已取消", error_code="CANCELLED")
            task_info.result = result
            # Propagate: swallowing CancelledError would make a cancelled asyncio
            # task finish normally and hide the cancellation from the awaiter.
            raise

        except Exception as e:
            logger.exception("任务 %s 执行失败: %s", task_id, e)
            task_info.status = "failed"
            task_info.error = str(e)
            result = EngineResult(success=False, error=str(e), error_code="EXECUTION_ERROR")
            task_info.result = result

        finally:
            task_info.completed_at = datetime.now()
            if result is not None:
                result.execution_time = time.perf_counter() - start_time

        return result

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a task's status.

        For a task whose status is "running", progress is first refreshed from
        the engine's ``get_progress(task_id)``.

        Args:
            task_id: Task ID.

        Returns:
            The task's ``to_dict()`` snapshot, or None if the task is unknown.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return None

        engine = self._registry.get(task_info.engine_id)
        if engine and task_info.status == "running":
            task_info.progress = engine.get_progress(task_id)

        return task_info.to_dict()

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task.

        Only tasks started with ``async_mode=True`` can be cancelled. A
        synchronous run is awaited by its caller and has no handle here.

        Args:
            task_id: Task ID.

        Returns:
            True if a cancellation request was delivered. False if the task is
            unknown, already finished, or has no live asyncio handle to cancel.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return False

        if task_info.status not in ("pending", "running"):
            return False

        handle = self._running_tasks.get(task_id)
        if handle is None or handle.done():
            # Nothing to cancel. Reporting success here would be wrong: the run
            # would later overwrite the "cancelled" status with its real outcome.
            return False

        handle.cancel()

        # Set immediately so the status is correct even if the task is cancelled
        # before it ever starts (its own cancellation handler never runs then).
        task_info.status = "cancelled"
        task_info.completed_at = datetime.now()

        return True

    def list_tasks(
        self,
        status: Optional[str] = None,
        engine_id: Optional[str] = None,
        limit: int = 100
    ) -> list[Dict[str, Any]]:
        """
        List tasks in insertion order.

        Args:
            status: Only include tasks with this status (no filter if falsy).
            engine_id: Only include tasks for this engine (no filter if falsy).
            limit: Maximum number of tasks to return; 0 or less returns [].

        Returns:
            A list of ``to_dict()`` snapshots.
        """
        # Guard first: the length check below runs after an append, so without
        # this a non-positive limit would still return one task.
        if limit <= 0:
            return []

        tasks = []
        for task_info in self._tasks.values():
            if status and task_info.status != status:
                continue
            if engine_id and task_info.engine_id != engine_id:
                continue
            tasks.append(task_info.to_dict())
            if len(tasks) >= limit:
                break

        return tasks

    def clear_completed_tasks(self, max_age_seconds: int = 3600):
        """
        Remove finished tasks older than a cutoff.

        A task is removed when its status is completed, failed or cancelled and
        its ``completed_at`` is more than *max_age_seconds* ago. Tasks without
        ``completed_at`` are kept.

        Args:
            max_age_seconds: Maximum retention time in seconds.
        """
        now = datetime.now()
        to_delete = [
            task_id
            for task_id, task_info in self._tasks.items()
            if task_info.status in ("completed", "failed", "cancelled")
            and task_info.completed_at
            and (now - task_info.completed_at).total_seconds() > max_age_seconds
        ]

        for task_id in to_delete:
            del self._tasks[task_id]

        if to_delete:
            logger.info("清理了 %d 个已完成任务", len(to_delete))
|