TravelContentCreator/domain/aigc/engine_executor.py

306 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Engine executor.

Responsible for invoking engines and managing their tasks.
"""
import logging
import uuid
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from dataclasses import dataclass, field

from .engine_registry import EngineRegistry
from .engines.base import EngineResult

logger = logging.getLogger(__name__)


@dataclass
class TaskInfo:
    """Bookkeeping record for a single engine execution."""

    task_id: str
    engine_id: str
    params: Dict[str, Any]
    status: str = "pending"  # pending, running, completed, failed, cancelled
    progress: int = 0
    # Quoted so the annotation is not evaluated when the dataclass is built.
    result: Optional["EngineResult"] = None
    error: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the task to a plain dict.

        ``params`` is intentionally not included. ``result`` is serialized
        through its own ``to_dict()``; timestamps are ISO-8601 strings, or
        None when unset.
        """
        return {
            "task_id": self.task_id,
            "engine_id": self.engine_id,
            "status": self.status,
            "progress": self.progress,
            "result": self.result.to_dict() if self.result else None,
            "error": self.error,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }


class EngineExecutor:
    """
    Engine executor (process-wide singleton).

    Supports:
    - synchronous execution (awaits the engine and returns its result dict)
    - asynchronous execution (schedules an asyncio task, returns a task_id)
    - task status queries
    - task cancellation (asynchronously executed tasks only)
    """

    _instance: Optional['EngineExecutor'] = None

    def __new__(cls):
        """Return the shared instance, creating it on first use (singleton)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Python calls __init__ on every EngineExecutor() call; this flag
            # makes the setup in __init__ run only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._registry = EngineRegistry()
        # Every submitted task (until cleared), keyed by task_id.
        self._tasks: Dict[str, TaskInfo] = {}
        # Strong references to in-flight asyncio tasks, keyed by task_id.
        self._running_tasks: Dict[str, asyncio.Task] = {}
        self._initialized = True
        logger.info("引擎执行器初始化")

    async def execute(
        self,
        engine_id: str,
        params: Dict[str, Any],
        async_mode: bool = False,
        task_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Run an engine, either inline or as a background asyncio task.

        Args:
            engine_id: ID of the engine to look up in the registry.
            params: Engine parameters; validated by the engine before running.
            async_mode: If True, schedule the run and return immediately;
                otherwise await the run and return its result.
            task_id: Optional caller-supplied task ID; a 16-character hex ID
                is generated when omitted.

        Returns:
            Sync mode: the engine result's ``to_dict()``.
            Async mode: ``{"success": True, "task_id": ..., "status": "pending",
            "estimated_duration": ...}``.
            If the task cannot be started: ``{"success": False, "error": ...,
            "error_code": ...}`` with error_code ``ENGINE_NOT_FOUND``,
            ``INVALID_PARAMS`` or ``TASK_EXISTS``.

        Raises:
            asyncio.CancelledError: in sync mode, when the calling task is
                cancelled while the engine is running.
        """
        if not task_id:
            task_id = uuid.uuid4().hex[:16]

        engine = self._registry.get(engine_id)
        if not engine:
            return {
                "success": False,
                "error": f"引擎 {engine_id} 不存在",
                "error_code": "ENGINE_NOT_FOUND"
            }

        valid, error_msg = engine.validate_params(params)
        if not valid:
            return {
                "success": False,
                "error": error_msg,
                "error_code": "INVALID_PARAMS"
            }

        # Reusing the ID of an unfinished task would overwrite its record and
        # drop the only strong reference to its asyncio task.
        existing = self._tasks.get(task_id)
        if existing is not None and existing.status in ("pending", "running"):
            return {
                "success": False,
                "error": f"任务 {task_id} 已存在且未结束",
                "error_code": "TASK_EXISTS"
            }

        self._tasks[task_id] = TaskInfo(
            task_id=task_id,
            engine_id=engine_id,
            params=params
        )

        if async_mode:
            asyncio_task = asyncio.create_task(
                self._execute_task(task_id, engine, params)
            )
            self._track_running_task(task_id, asyncio_task)
            return {
                "success": True,
                "task_id": task_id,
                "status": "pending",
                "estimated_duration": engine.estimate_duration(params)
            }

        result = await self._execute_task(task_id, engine, params)
        return result.to_dict()

    def _track_running_task(self, task_id: str, asyncio_task: asyncio.Task) -> None:
        """
        Hold a strong reference to *asyncio_task* until it finishes.

        The entry is removed by a done-callback, so cleanup also happens when
        the task is cancelled before its coroutine ever starts running.
        """
        self._running_tasks[task_id] = asyncio_task

        def _forget(done: asyncio.Task) -> None:
            # Only drop the entry if it still refers to this very task.
            if self._running_tasks.get(task_id) is done:
                del self._running_tasks[task_id]

        asyncio_task.add_done_callback(_forget)

    async def _execute_task(
        self,
        task_id: str,
        engine,
        params: Dict[str, Any]
    ) -> "EngineResult":
        """
        Run *engine* for the task registered under *task_id* and record the outcome.

        Args:
            task_id: ID of a task previously registered in ``self._tasks``.
            engine: Engine instance; its ``execute(params)`` is awaited.
            params: Engine parameters.

        Returns:
            The engine's result with ``execution_time`` filled in. Returns a
            failed EngineResult if the task is unknown or the engine raised.

        Raises:
            asyncio.CancelledError: re-raised after the task is marked
                "cancelled", so cancellation reaches whoever awaits this call.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return EngineResult(success=False, error="任务不存在")

        task_info.status = "running"
        task_info.started_at = datetime.now()
        start_time = time.time()
        # Stays None only if a BaseException other than CancelledError
        # (e.g. KeyboardInterrupt) escapes; guarded in ``finally``.
        result = None
        try:
            result = await engine.execute(params)
            task_info.status = "completed" if result.success else "failed"
            task_info.result = result
            task_info.progress = 100
        except asyncio.CancelledError:
            task_info.status = "cancelled"
            result = EngineResult(success=False, error="任务已取消", error_code="CANCELLED")
            task_info.result = result
            # asyncio requires CancelledError to propagate so awaiting callers
            # (and timeouts) observe the cancellation.
            raise
        except Exception as e:
            logger.exception("任务 %s 执行失败: %s", task_id, e)
            task_info.status = "failed"
            task_info.error = str(e)
            result = EngineResult(success=False, error=str(e), error_code="EXECUTION_ERROR")
            task_info.result = result
        finally:
            task_info.completed_at = datetime.now()
            if result is not None:
                result.execution_time = time.time() - start_time
        return result

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Return the serialized state of a task.

        For running tasks, progress is refreshed from the engine's
        ``get_progress(task_id)`` first.

        Args:
            task_id: ID of the task.

        Returns:
            ``TaskInfo.to_dict()`` output, or None if the task is unknown.
        """
        task_info = self._tasks.get(task_id)
        if not task_info:
            return None

        engine = self._registry.get(task_info.engine_id)
        if engine and task_info.status == "running":
            task_info.progress = engine.get_progress(task_id)

        return task_info.to_dict()

    async def cancel_task(self, task_id: str) -> bool:
        """
        Request cancellation of an asynchronously executed task.

        Args:
            task_id: ID of the task to cancel.

        Returns:
            True if a cancellation request was delivered to the task's asyncio
            task. False if the task is unknown, already finished, or has no
            asyncio task (it is being executed synchronously).
        """
        task_info = self._tasks.get(task_id)
        if not task_info or task_info.status not in ("pending", "running"):
            return False

        running = self._running_tasks.get(task_id)
        # Sync-mode executions have no asyncio task and cannot be cancelled
        # here; Task.cancel() itself returns False once the task is done.
        if running is None or not running.cancel():
            return False

        task_info.status = "cancelled"
        task_info.completed_at = datetime.now()
        return True

    def list_tasks(
        self,
        status: Optional[str] = None,
        engine_id: Optional[str] = None,
        limit: int = 100
    ) -> list[Dict[str, Any]]:
        """
        List tasks in insertion order, optionally filtered.

        Args:
            status: Only include tasks with this status (falsy = no filter).
            engine_id: Only include tasks for this engine (falsy = no filter).
            limit: Maximum number of tasks to return; values <= 0 yield [].

        Returns:
            ``TaskInfo.to_dict()`` output for each matching task.
        """
        tasks = []
        for task_info in self._tasks.values():
            # Check before appending so limit <= 0 returns nothing.
            if len(tasks) >= limit:
                break
            if status and task_info.status != status:
                continue
            if engine_id and task_info.engine_id != engine_id:
                continue
            tasks.append(task_info.to_dict())
        return tasks

    def clear_completed_tasks(self, max_age_seconds: int = 3600):
        """
        Forget finished tasks whose completion time is older than a cutoff.

        Only tasks with status completed, failed or cancelled and a recorded
        ``completed_at`` are considered.

        Args:
            max_age_seconds: Tasks finished more than this many seconds ago
                are removed.
        """
        now = datetime.now()
        expired = [
            task_id
            for task_id, task_info in self._tasks.items()
            if task_info.status in ("completed", "failed", "cancelled")
            and task_info.completed_at
            and (now - task_info.completed_at).total_seconds() > max_age_seconds
        ]

        for task_id in expired:
            del self._tasks[task_id]

        if expired:
            logger.info("清理了 %d 个已完成任务", len(expired))