332 lines
9.3 KiB
Python
332 lines
9.3 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
AIGC 引擎基类
|
|
提供所有引擎的公共能力
|
|
"""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Any, Optional, List
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class EngineResult:
|
|
"""引擎执行结果"""
|
|
success: bool
|
|
data: Optional[Dict[str, Any]] = None
|
|
error: Optional[str] = None
|
|
error_code: Optional[str] = None
|
|
execution_time: float = 0.0
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
return {
|
|
"success": self.success,
|
|
"data": self.data,
|
|
"error": self.error,
|
|
"error_code": self.error_code,
|
|
"execution_time": self.execution_time,
|
|
"metadata": self.metadata
|
|
}
|
|
|
|
|
|
class BaseAIGCEngine(ABC):
|
|
"""
|
|
AIGC 引擎基类
|
|
|
|
所有 AIGC 功能引擎都应继承此类。
|
|
基类提供以下开箱即用的能力:
|
|
- self.db: 数据库访问
|
|
- self.llm: LLM 客户端
|
|
- self.prompt: 提示词构建器
|
|
- self.image: 图片处理器
|
|
- self.storage: 文件存储
|
|
- self.paths: 路径配置
|
|
"""
|
|
|
|
# ========== 子类必须定义的属性 ==========
|
|
engine_id: str = "" # 引擎唯一标识
|
|
engine_name: str = "" # 引擎显示名称
|
|
version: str = "1.0.0" # 版本号
|
|
description: str = "" # 引擎描述
|
|
|
|
# ========== 内部状态 ==========
|
|
_initialized: bool = False
|
|
_shared_components: Optional[Dict[str, Any]] = None
|
|
|
|
def __init__(self):
|
|
"""初始化引擎"""
|
|
self.logger = logging.getLogger(f"engine.{self.engine_id}")
|
|
self._task_progress: Dict[str, int] = {}
|
|
|
|
def initialize(self, shared_components: Dict[str, Any]):
|
|
"""
|
|
注入共享组件
|
|
|
|
Args:
|
|
shared_components: 包含 db, llm, prompt, image, storage, paths 等
|
|
"""
|
|
self._shared_components = shared_components
|
|
self._initialized = True
|
|
self.logger.info(f"引擎 {self.engine_id} 初始化完成")
|
|
|
|
# ========== 共享组件访问器 ==========
|
|
|
|
@property
|
|
def db(self):
|
|
"""数据库访问器"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('db')
|
|
|
|
@property
|
|
def llm(self):
|
|
"""LLM 客户端"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('llm')
|
|
|
|
@property
|
|
def prompt(self):
|
|
"""提示词构建器"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('prompt')
|
|
|
|
@property
|
|
def image(self):
|
|
"""图片处理器"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('image')
|
|
|
|
@property
|
|
def storage(self):
|
|
"""文件存储"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('storage')
|
|
|
|
@property
|
|
def paths(self) -> Dict[str, Any]:
|
|
"""路径配置"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('paths', {})
|
|
|
|
@property
|
|
def config(self) -> Dict[str, Any]:
|
|
"""全局配置"""
|
|
self._check_initialized()
|
|
return self._shared_components.get('config', {})
|
|
|
|
def _check_initialized(self):
|
|
"""检查是否已初始化"""
|
|
if not self._initialized:
|
|
raise RuntimeError(f"引擎 {self.engine_id} 未初始化,请先调用 initialize()")
|
|
|
|
# ========== 子类必须实现的方法 ==========
|
|
|
|
@abstractmethod
|
|
async def execute(self, params: Dict[str, Any]) -> EngineResult:
|
|
"""
|
|
执行引擎逻辑
|
|
|
|
Args:
|
|
params: 引擎参数
|
|
|
|
Returns:
|
|
EngineResult: 执行结果
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_param_schema(self) -> Dict[str, Any]:
|
|
"""
|
|
获取参数 Schema
|
|
|
|
Returns:
|
|
参数定义字典,格式:
|
|
{
|
|
"param_name": {
|
|
"type": "str|int|float|bool|list|dict",
|
|
"required": True|False,
|
|
"default": ...,
|
|
"desc": "参数描述",
|
|
"enum": [...], # 可选,枚举值
|
|
}
|
|
}
|
|
"""
|
|
pass
|
|
|
|
# ========== 可选覆盖的方法 ==========
|
|
|
|
def validate_params(self, params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
|
|
"""
|
|
验证参数
|
|
|
|
Args:
|
|
params: 待验证的参数
|
|
|
|
Returns:
|
|
(是否有效, 错误信息)
|
|
"""
|
|
schema = self.get_param_schema()
|
|
|
|
for param_name, param_def in schema.items():
|
|
required = param_def.get('required', False)
|
|
param_type = param_def.get('type', 'str')
|
|
enum_values = param_def.get('enum')
|
|
|
|
# 检查必填参数
|
|
if required and param_name not in params:
|
|
return False, f"缺少必填参数: {param_name}"
|
|
|
|
# 检查参数值
|
|
if param_name in params:
|
|
value = params[param_name]
|
|
|
|
# 类型检查
|
|
type_map = {
|
|
'str': str,
|
|
'int': int,
|
|
'float': (int, float),
|
|
'bool': bool,
|
|
'list': list,
|
|
'dict': dict,
|
|
}
|
|
expected_type = type_map.get(param_type)
|
|
if expected_type and not isinstance(value, expected_type):
|
|
return False, f"参数 {param_name} 类型错误,期望 {param_type}"
|
|
|
|
# 枚举检查
|
|
if enum_values and value not in enum_values:
|
|
return False, f"参数 {param_name} 值无效,允许值: {enum_values}"
|
|
|
|
return True, None
|
|
|
|
def estimate_duration(self, params: Dict[str, Any]) -> int:
|
|
"""
|
|
预估执行时间(秒)
|
|
|
|
Args:
|
|
params: 引擎参数
|
|
|
|
Returns:
|
|
预估秒数
|
|
"""
|
|
return 30 # 默认 30 秒
|
|
|
|
def get_progress(self, task_id: str) -> int:
|
|
"""
|
|
获取任务进度
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
|
|
Returns:
|
|
进度百分比 (0-100)
|
|
"""
|
|
return self._task_progress.get(task_id, 0)
|
|
|
|
def set_progress(self, task_id: str, progress: int):
|
|
"""
|
|
设置任务进度
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
progress: 进度百分比 (0-100)
|
|
"""
|
|
self._task_progress[task_id] = min(100, max(0, progress))
|
|
|
|
# ========== 工具方法 ==========
|
|
|
|
def parse_json(self, text: str) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
从文本中提取 JSON
|
|
|
|
Args:
|
|
text: 可能包含 JSON 的文本
|
|
|
|
Returns:
|
|
解析后的字典,失败返回 None
|
|
"""
|
|
import json
|
|
import re
|
|
|
|
# 尝试直接解析
|
|
try:
|
|
return json.loads(text)
|
|
except:
|
|
pass
|
|
|
|
# 尝试提取 JSON 块
|
|
patterns = [
|
|
r'```json\s*([\s\S]*?)\s*```',
|
|
r'```\s*([\s\S]*?)\s*```',
|
|
r'\{[\s\S]*\}',
|
|
]
|
|
|
|
for pattern in patterns:
|
|
match = re.search(pattern, text)
|
|
if match:
|
|
try:
|
|
json_str = match.group(1) if '```' in pattern else match.group(0)
|
|
return json.loads(json_str)
|
|
except:
|
|
continue
|
|
|
|
return None
|
|
|
|
def get_path(self, *parts: str) -> Path:
|
|
"""
|
|
获取项目路径
|
|
|
|
Args:
|
|
*parts: 路径部分
|
|
|
|
Returns:
|
|
完整路径
|
|
"""
|
|
project_root = self.paths.get('project_root', '')
|
|
return Path(project_root) / Path(*parts)
|
|
|
|
def get_resource_path(self, resource_type: str, *parts: str) -> Path:
|
|
"""
|
|
获取资源路径
|
|
|
|
Args:
|
|
resource_type: 资源类型 (prompt, font, template)
|
|
*parts: 额外路径部分
|
|
|
|
Returns:
|
|
完整路径
|
|
"""
|
|
resource_paths = self.paths.get('resource', {})
|
|
base_path = resource_paths.get(resource_type, f'resource/{resource_type}')
|
|
return self.get_path(base_path, *parts)
|
|
|
|
def log(self, message: str, level: str = 'info'):
|
|
"""
|
|
记录日志
|
|
|
|
Args:
|
|
message: 日志消息
|
|
level: 日志级别 (debug, info, warning, error)
|
|
"""
|
|
log_func = getattr(self.logger, level, self.logger.info)
|
|
log_func(message)
|
|
|
|
# ========== 引擎信息 ==========
|
|
|
|
def get_info(self) -> Dict[str, Any]:
|
|
"""获取引擎信息"""
|
|
return {
|
|
"id": self.engine_id,
|
|
"name": self.engine_name,
|
|
"version": self.version,
|
|
"description": self.description,
|
|
"param_schema": self.get_param_schema(),
|
|
}
|