208 lines
5.8 KiB
Python
208 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
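"""
Repository base class (Repository 基类).

Provides generic read-only query helpers (lookup by primary key or by column,
batch lookup, pagination, counting) with soft-delete filtering. Write
operations (create/update/delete) are not implemented in this module.
"""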
|
||
|
||
import logging
|
||
from typing import Dict, Any, Optional, List, TypeVar, Generic
|
||
from abc import ABC, abstractmethod
|
||
|
||
from .connection import DatabaseConnection
|
||
|
||
logger = logging.getLogger(__name__)

# Type variable for the entity type a concrete repository manages.
# NOTE(review): every visible method returns plain dict rows, so T is currently
# only a typing marker for subclasses.
T = TypeVar('T')


class BaseRepository(ABC, Generic[T]):
    """Base class for table-backed repositories.

    Provides shared read helpers on top of a ``DatabaseConnection``:

    - lookup by primary key (single and batch),
    - lookup by an arbitrary column,
    - paginated listing and counting,

    all with soft-delete filtering. Subclasses must define :attr:`table_name`
    and may override :attr:`primary_key` and :attr:`soft_delete_field`.

    Every query method catches database errors, logs them, and returns an
    "empty" value (``None``, ``[]``, ``0`` or ``False``) instead of raising.
    """

    def __init__(self, db: Optional["DatabaseConnection"] = None):
        """Create the repository.

        Args:
            db: Database connection to use. When ``None``, a default
                ``DatabaseConnection()`` is created (the original docs describe
                this as the shared singleton connection).
        """
        # Compare against None explicitly: a connection object that happens to
        # be falsy (e.g. defines __len__) must not be silently replaced.
        self._db = db if db is not None else DatabaseConnection()

    @property
    @abstractmethod
    def table_name(self) -> str:
        """Name of the table this repository queries (must be overridden)."""

    @property
    def primary_key(self) -> str:
        """Primary-key column name; ``'id'`` unless overridden."""
        return 'id'

    @property
    def soft_delete_field(self) -> Optional[str]:
        """Soft-delete flag column, or ``None`` to disable soft-delete filtering.

        When set, only rows whose flag equals ``0`` are returned unless a method
        is called with ``include_deleted=True``.
        """
        return 'is_delete'

    @staticmethod
    def _join_sql(*parts: str) -> str:
        """Join the non-empty SQL fragments in *parts* with single spaces."""
        return " ".join(part for part in parts if part)

    @staticmethod
    def _is_safe_identifier(name: Any) -> bool:
        """Return True if *name* is a plain (optionally dotted) column identifier.

        Column names cannot be bound as query parameters, so they end up in the
        SQL text; this rejects anything containing spaces, quotes, operators or
        other characters that could change the meaning of the query.
        """
        return isinstance(name, str) and all(part.isidentifier() for part in name.split('.'))

    def _build_where(self, condition: Optional[str] = None, include_deleted: bool = False) -> str:
        """Combine the soft-delete filter with an optional extra predicate.

        Args:
            condition: Extra SQL predicate (may contain ``%s`` placeholders),
                AND-ed after the soft-delete filter; ignored when empty.
            include_deleted: When True, the soft-delete filter is omitted.

        Returns:
            ``"WHERE ..."``, or ``""`` when there is nothing to filter on.
        """
        clauses = []
        if self.soft_delete_field and not include_deleted:
            clauses.append(f"{self.soft_delete_field} = 0")
        if condition:
            clauses.append(condition)
        return f"WHERE {' AND '.join(clauses)}" if clauses else ""

    def _build_where_clause(self, include_deleted: bool = False) -> str:
        """Build the WHERE clause that applies only the soft-delete filter.

        Returns ``"WHERE <soft_delete_field> = 0"`` when soft delete is enabled
        and ``include_deleted`` is False, otherwise ``""``. Kept for subclasses
        that compose their own queries.
        """
        return self._build_where(include_deleted=include_deleted)

    def _field_condition(self, field: str, value: Any) -> Optional[tuple]:
        """Return ``(sql_predicate, params)`` matching ``field == value``.

        ``None`` is matched with ``IS NULL`` because ``= NULL`` is never true in
        SQL. Returns ``None`` (after logging) when *field* is not a safe
        identifier.
        """
        if not self._is_safe_identifier(field):
            logger.error("查询 %s 失败: 非法字段名 %r", self.table_name, field)
            return None
        if value is None:
            return f"{field} IS NULL", ()
        return f"{field} = %s", (value,)

    def find_by_id(self, id: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """Fetch one row by primary key.

        Args:
            id: Primary-key value.
            include_deleted: Also match soft-deleted rows.

        Returns:
            Whatever ``execute_one`` returns (presumably ``None`` when no row
            matches), or ``None`` if the query fails.
        """
        where = self._build_where(f"{self.primary_key} = %s", include_deleted)
        query = self._join_sql(f"SELECT * FROM {self.table_name}", where)
        try:
            return self._db.execute_one(query, (id,))
        except Exception as e:
            logger.error("查询 %s 失败: %s", self.table_name, e)
            return None

    def find_by_ids(self, ids: List[Any], include_deleted: bool = False) -> List[Dict[str, Any]]:
        """Fetch all rows whose primary key is in *ids*.

        Args:
            ids: Primary-key values; any iterable is accepted.
            include_deleted: Also match soft-deleted rows.

        Returns:
            Matching rows in no particular order, or ``[]`` when *ids* is empty
            or the query fails.
        """
        # Materialize first so generators work with len()/tuple() below.
        id_list = list(ids) if ids else []
        if not id_list:
            # "IN ()" is invalid SQL, so skip the round trip entirely.
            return []

        placeholders = ','.join(['%s'] * len(id_list))
        where = self._build_where(f"{self.primary_key} IN ({placeholders})", include_deleted)
        query = self._join_sql(f"SELECT * FROM {self.table_name}", where)
        try:
            return self._db.execute_query(query, tuple(id_list))
        except Exception as e:
            logger.error("批量查询 %s 失败: %s", self.table_name, e)
            return []

    def find_all(self, limit: int = 100, offset: int = 0, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """Fetch one page of rows, ordered by primary key.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            include_deleted: Also include soft-deleted rows.

        Returns:
            The requested page, or ``[]`` if the query fails.
        """
        where = self._build_where(include_deleted=include_deleted)
        # Without ORDER BY the row order is unspecified, so consecutive
        # LIMIT/OFFSET pages could overlap or skip rows; order by primary key
        # to make pagination stable.
        query = self._join_sql(
            f"SELECT * FROM {self.table_name}",
            where,
            f"ORDER BY {self.primary_key} LIMIT %s OFFSET %s",
        )
        try:
            return self._db.execute_query(query, (limit, offset))
        except Exception as e:
            logger.error("查询 %s 列表失败: %s", self.table_name, e)
            return []

    def find_by_field(self, field: str, value: Any, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """Fetch all rows where column *field* equals *value*.

        Args:
            field: Column name. It must be a plain (optionally dotted)
                identifier because it is interpolated into the SQL text.
            value: Value to match, bound as a query parameter; ``None``
                matches NULL columns.
            include_deleted: Also match soft-deleted rows.

        Returns:
            Matching rows, or ``[]`` if *field* is rejected or the query fails.
        """
        spec = self._field_condition(field, value)
        if spec is None:
            return []
        condition, params = spec
        where = self._build_where(condition, include_deleted)
        query = self._join_sql(f"SELECT * FROM {self.table_name}", where)
        try:
            return self._db.execute_query(query, params)
        except Exception as e:
            logger.error("查询 %s 失败: %s", self.table_name, e)
            return []

    def find_one_by_field(self, field: str, value: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """Fetch a single row where column *field* equals *value*.

        Uses the same matching rules as :meth:`find_by_field`, but asks the
        database for at most one row (``LIMIT 1``) instead of fetching every
        match. Which row is returned when several match is unspecified.

        Args:
            field: Column name (plain or dotted identifier).
            value: Value to match; ``None`` matches NULL columns.
            include_deleted: Also match soft-deleted rows.

        Returns:
            The row, or ``None`` if nothing matches, *field* is rejected, or the
            query fails.
        """
        spec = self._field_condition(field, value)
        if spec is None:
            return None
        condition, params = spec
        where = self._build_where(condition, include_deleted)
        query = self._join_sql(f"SELECT * FROM {self.table_name}", where, "LIMIT 1")
        try:
            return self._db.execute_one(query, params)
        except Exception as e:
            logger.error("查询 %s 失败: %s", self.table_name, e)
            return None

    def count(self, include_deleted: bool = False) -> int:
        """Count rows in the table.

        Args:
            include_deleted: Also count soft-deleted rows.

        Returns:
            The ``count`` column of the result row, or ``0`` if there is no row
            or the query fails.
        """
        where = self._build_where(include_deleted=include_deleted)
        query = self._join_sql(f"SELECT COUNT(*) as count FROM {self.table_name}", where)
        try:
            result = self._db.execute_one(query)
            # .get() assumes dict-style rows; a failure here is logged below.
            return result.get('count', 0) if result else 0
        except Exception as e:
            logger.error("统计 %s 失败: %s", self.table_name, e)
            return 0

    def exists(self, id: Any, include_deleted: bool = False) -> bool:
        """Return whether a row with primary key *id* exists.

        Selects a constant with ``LIMIT 1`` rather than fetching the full row.

        Args:
            id: Primary-key value.
            include_deleted: Also consider soft-deleted rows (default False,
                matching the previous behavior).

        Returns:
            True if a row is found; False if none is found or the query fails.
        """
        where = self._build_where(f"{self.primary_key} = %s", include_deleted)
        query = self._join_sql(f"SELECT 1 FROM {self.table_name}", where, "LIMIT 1")
        try:
            return self._db.execute_one(query, (id,)) is not None
        except Exception as e:
            logger.error("查询 %s 失败: %s", self.table_name, e)
            return False
|