208 lines
5.8 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Repository 基类
提供通用的 CRUD 操作
"""
import logging
from typing import Dict, Any, Optional, List, TypeVar, Generic
from abc import ABC, abstractmethod
from .connection import DatabaseConnection
logger = logging.getLogger(__name__)

# Type parameter available to subclasses; the base class itself returns plain
# row dicts and does not use T in any method signature.
T = TypeVar('T')


class BaseRepository(ABC, Generic[T]):
    """
    Base repository providing generic read helpers for a single table.

    Provides:
    - lookup by primary key or by an arbitrary column
    - batch lookup by a list of primary keys
    - paginated listing and row counting

    Subclasses must implement ``table_name``; ``primary_key`` and
    ``soft_delete_field`` may be overridden. These names are interpolated into
    SQL verbatim, so they must be trusted, code-defined identifiers.

    Every query method logs database errors and returns an empty result
    (``None`` / ``[]`` / ``0`` / ``False``) instead of raising.
    """

    def __init__(self, db: Optional["DatabaseConnection"] = None):
        """
        Initialise the repository.

        Args:
            db: Database connection to use. When ``None``, a fresh
                ``DatabaseConnection()`` is constructed (the original docs
                describe this as the shared singleton connection).
        """
        # Explicit None check: a caller-supplied connection object that
        # happens to be falsy must not be silently replaced.
        self._db = db if db is not None else DatabaseConnection()

    @property
    @abstractmethod
    def table_name(self) -> str:
        """Name of the table this repository reads from."""

    @property
    def primary_key(self) -> str:
        """Primary-key column name (default ``'id'``)."""
        return 'id'

    @property
    def soft_delete_field(self) -> Optional[str]:
        """Soft-delete flag column; ``None`` disables soft-delete filtering."""
        return 'is_delete'

    def _build_where_clause(self, include_deleted: bool = False) -> str:
        """
        Build the base WHERE clause (soft-delete filter only).

        Returns ``"WHERE <soft_delete_field> = 0"`` when soft delete is enabled
        and ``include_deleted`` is False; otherwise an empty string.
        """
        if self.soft_delete_field and not include_deleted:
            return f"WHERE {self.soft_delete_field} = 0"
        return ""

    def _combine_where(self, condition: str, include_deleted: bool = False) -> str:
        """AND *condition* onto the clause produced by ``_build_where_clause``."""
        # Routed through _build_where_clause so subclass overrides still apply.
        where = self._build_where_clause(include_deleted)
        return f"{where} AND {condition}" if where else f"WHERE {condition}"

    def _select_sql(self, where: str = "", suffix: str = "", columns: str = "*") -> str:
        """Assemble ``SELECT <columns> FROM <table> [where] [suffix]``."""
        parts = (f"SELECT {columns} FROM {self.table_name}", where, suffix)
        return " ".join(part for part in parts if part)

    @staticmethod
    def _is_safe_identifier(name: Any) -> bool:
        """
        Return True if *name* is a plain, optionally dot-qualified identifier.

        Column names cannot be bound as query parameters, so any name that is
        spliced into SQL text is checked here to rule out injected SQL.
        """
        return isinstance(name, str) and all(
            part.isidentifier() for part in name.split('.')
        )

    def find_by_id(self, id: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """
        Fetch a single row by primary key.

        Args:
            id: Primary-key value.
            include_deleted: Also match soft-deleted rows.

        Returns:
            The row returned by ``execute_one`` (a dict or None); None on error.
        """
        where = self._combine_where(f"{self.primary_key} = %s", include_deleted)
        query = self._select_sql(where)
        try:
            return self._db.execute_one(query, (id,))
        except Exception as e:
            logger.exception("查询 %s 失败: %s", self.table_name, e)
            return None

    def find_by_ids(self, ids: List[Any], include_deleted: bool = False) -> List[Dict[str, Any]]:
        """
        Fetch all rows whose primary key is in *ids*.

        Args:
            ids: Primary-key values; any iterable is accepted.
            include_deleted: Also match soft-deleted rows.

        Returns:
            Matching rows; ``[]`` for empty/None input or on error.
        """
        # Materialise once so sets/generators work and len() is available.
        id_list = list(ids) if ids else []
        if not id_list:
            return []
        placeholders = ','.join(['%s'] * len(id_list))
        where = self._combine_where(f"{self.primary_key} IN ({placeholders})", include_deleted)
        query = self._select_sql(where)
        try:
            return self._db.execute_query(query, tuple(id_list))
        except Exception as e:
            logger.exception("批量查询 %s 失败: %s", self.table_name, e)
            return []

    def find_all(self, limit: int = 100, offset: int = 0, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """
        Fetch one page of rows, ordered by primary key.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            include_deleted: Also return soft-deleted rows.

        Returns:
            The rows of the requested page; ``[]`` on error.
        """
        where = self._build_where_clause(include_deleted)
        # Without ORDER BY the row order is undefined in SQL, so consecutive
        # LIMIT/OFFSET pages could overlap or skip rows.
        query = self._select_sql(where, f"ORDER BY {self.primary_key} LIMIT %s OFFSET %s")
        try:
            return self._db.execute_query(query, (limit, offset))
        except Exception as e:
            logger.exception("查询 %s 列表失败: %s", self.table_name, e)
            return []

    def find_by_field(self, field: str, value: Any, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """
        Fetch all rows where column *field* equals *value*.

        Args:
            field: Column name; must be a plain (optionally dot-qualified)
                identifier, because it is inserted into the SQL text.
            value: Value to match (bound as a query parameter).
            include_deleted: Also match soft-deleted rows.

        Returns:
            Matching rows; ``[]`` on error or when *field* is not a valid
            identifier.
        """
        if not self._is_safe_identifier(field):
            # Refuse anything that is not a plain column name rather than
            # splice attacker-controlled text into the query.
            logger.error("查询 %s 失败: 非法字段名 %r", self.table_name, field)
            return []
        where = self._combine_where(f"{field} = %s", include_deleted)
        query = self._select_sql(where)
        try:
            return self._db.execute_query(query, (value,))
        except Exception as e:
            logger.exception("查询 %s 失败: %s", self.table_name, e)
            return []

    def find_one_by_field(self, field: str, value: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """
        Fetch the first row returned by ``find_by_field``.

        Args:
            field: Column name (validated by ``find_by_field``).
            value: Value to match.
            include_deleted: Also match soft-deleted rows.

        Returns:
            The first matching row, or None if there is none.
        """
        results = self.find_by_field(field, value, include_deleted)
        return results[0] if results else None

    def count(self, include_deleted: bool = False) -> int:
        """
        Count rows in the table.

        Args:
            include_deleted: Also count soft-deleted rows.

        Returns:
            The row count; 0 on error.
        """
        where = self._build_where_clause(include_deleted)
        query = self._select_sql(where, columns="COUNT(*) as count")
        try:
            result = self._db.execute_one(query)
            return result.get('count', 0) if result else 0
        except Exception as e:
            logger.exception("统计 %s 失败: %s", self.table_name, e)
            return 0

    def exists(self, id: Any) -> bool:
        """
        Return True if a non-deleted row with primary key *id* exists.

        Delegates to ``find_by_id``, so a database error also yields False.
        """
        return self.find_by_id(id) is not None