332 lines
9.3 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIGC 引擎基类
提供所有引擎的公共能力
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List
from pathlib import Path
from datetime import datetime
logger = logging.getLogger(__name__)
@dataclass
class EngineResult:
"""引擎执行结果"""
success: bool
data: Optional[Dict[str, Any]] = None
error: Optional[str] = None
error_code: Optional[str] = None
execution_time: float = 0.0
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"success": self.success,
"data": self.data,
"error": self.error,
"error_code": self.error_code,
"execution_time": self.execution_time,
"metadata": self.metadata
}
class BaseAIGCEngine(ABC):
"""
AIGC 引擎基类
所有 AIGC 功能引擎都应继承此类。
基类提供以下开箱即用的能力:
- self.db: 数据库访问
- self.llm: LLM 客户端
- self.prompt: 提示词构建器
- self.image: 图片处理器
- self.storage: 文件存储
- self.paths: 路径配置
"""
# ========== 子类必须定义的属性 ==========
engine_id: str = "" # 引擎唯一标识
engine_name: str = "" # 引擎显示名称
version: str = "1.0.0" # 版本号
description: str = "" # 引擎描述
# ========== 内部状态 ==========
_initialized: bool = False
_shared_components: Optional[Dict[str, Any]] = None
def __init__(self):
"""初始化引擎"""
self.logger = logging.getLogger(f"engine.{self.engine_id}")
self._task_progress: Dict[str, int] = {}
def initialize(self, shared_components: Dict[str, Any]):
"""
注入共享组件
Args:
shared_components: 包含 db, llm, prompt, image, storage, paths 等
"""
self._shared_components = shared_components
self._initialized = True
self.logger.info(f"引擎 {self.engine_id} 初始化完成")
# ========== 共享组件访问器 ==========
@property
def db(self):
"""数据库访问器"""
self._check_initialized()
return self._shared_components.get('db')
@property
def llm(self):
"""LLM 客户端"""
self._check_initialized()
return self._shared_components.get('llm')
@property
def prompt(self):
"""提示词构建器"""
self._check_initialized()
return self._shared_components.get('prompt')
@property
def image(self):
"""图片处理器"""
self._check_initialized()
return self._shared_components.get('image')
@property
def storage(self):
"""文件存储"""
self._check_initialized()
return self._shared_components.get('storage')
@property
def paths(self) -> Dict[str, Any]:
"""路径配置"""
self._check_initialized()
return self._shared_components.get('paths', {})
@property
def config(self) -> Dict[str, Any]:
"""全局配置"""
self._check_initialized()
return self._shared_components.get('config', {})
def _check_initialized(self):
"""检查是否已初始化"""
if not self._initialized:
raise RuntimeError(f"引擎 {self.engine_id} 未初始化,请先调用 initialize()")
# ========== 子类必须实现的方法 ==========
@abstractmethod
async def execute(self, params: Dict[str, Any]) -> EngineResult:
"""
执行引擎逻辑
Args:
params: 引擎参数
Returns:
EngineResult: 执行结果
"""
pass
@abstractmethod
def get_param_schema(self) -> Dict[str, Any]:
"""
获取参数 Schema
Returns:
参数定义字典,格式:
{
"param_name": {
"type": "str|int|float|bool|list|dict",
"required": True|False,
"default": ...,
"desc": "参数描述",
"enum": [...], # 可选,枚举值
}
}
"""
pass
# ========== 可选覆盖的方法 ==========
def validate_params(self, params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
"""
验证参数
Args:
params: 待验证的参数
Returns:
(是否有效, 错误信息)
"""
schema = self.get_param_schema()
for param_name, param_def in schema.items():
required = param_def.get('required', False)
param_type = param_def.get('type', 'str')
enum_values = param_def.get('enum')
# 检查必填参数
if required and param_name not in params:
return False, f"缺少必填参数: {param_name}"
# 检查参数值
if param_name in params:
value = params[param_name]
# 类型检查
type_map = {
'str': str,
'int': int,
'float': (int, float),
'bool': bool,
'list': list,
'dict': dict,
}
expected_type = type_map.get(param_type)
if expected_type and not isinstance(value, expected_type):
return False, f"参数 {param_name} 类型错误,期望 {param_type}"
# 枚举检查
if enum_values and value not in enum_values:
return False, f"参数 {param_name} 值无效,允许值: {enum_values}"
return True, None
def estimate_duration(self, params: Dict[str, Any]) -> int:
"""
预估执行时间(秒)
Args:
params: 引擎参数
Returns:
预估秒数
"""
return 30 # 默认 30 秒
def get_progress(self, task_id: str) -> int:
"""
获取任务进度
Args:
task_id: 任务 ID
Returns:
进度百分比 (0-100)
"""
return self._task_progress.get(task_id, 0)
def set_progress(self, task_id: str, progress: int):
"""
设置任务进度
Args:
task_id: 任务 ID
progress: 进度百分比 (0-100)
"""
self._task_progress[task_id] = min(100, max(0, progress))
# ========== 工具方法 ==========
def parse_json(self, text: str) -> Optional[Dict[str, Any]]:
"""
从文本中提取 JSON
Args:
text: 可能包含 JSON 的文本
Returns:
解析后的字典,失败返回 None
"""
import json
import re
# 尝试直接解析
try:
return json.loads(text)
except:
pass
# 尝试提取 JSON 块
patterns = [
r'```json\s*([\s\S]*?)\s*```',
r'```\s*([\s\S]*?)\s*```',
r'\{[\s\S]*\}',
]
for pattern in patterns:
match = re.search(pattern, text)
if match:
try:
json_str = match.group(1) if '```' in pattern else match.group(0)
return json.loads(json_str)
except:
continue
return None
def get_path(self, *parts: str) -> Path:
"""
获取项目路径
Args:
*parts: 路径部分
Returns:
完整路径
"""
project_root = self.paths.get('project_root', '')
return Path(project_root) / Path(*parts)
def get_resource_path(self, resource_type: str, *parts: str) -> Path:
"""
获取资源路径
Args:
resource_type: 资源类型 (prompt, font, template)
*parts: 额外路径部分
Returns:
完整路径
"""
resource_paths = self.paths.get('resource', {})
base_path = resource_paths.get(resource_type, f'resource/{resource_type}')
return self.get_path(base_path, *parts)
def log(self, message: str, level: str = 'info'):
"""
记录日志
Args:
message: 日志消息
level: 日志级别 (debug, info, warning, error)
"""
log_func = getattr(self.logger, level, self.logger.info)
log_func(message)
# ========== 引擎信息 ==========
def get_info(self) -> Dict[str, Any]:
"""获取引擎信息"""
return {
"id": self.engine_id,
"name": self.engine_name,
"version": self.version,
"description": self.description,
"param_schema": self.get_param_schema(),
}