208 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Repository base class.

Provides shared data-access operations for concrete repositories.
"""
import logging
from typing import Dict, Any, Optional, List, TypeVar, Generic
from abc import ABC, abstractmethod
from .connection import DatabaseConnection
logger = logging.getLogger(__name__)

# Type parameter of BaseRepository; no method in the class references it.
T = TypeVar('T')


class BaseRepository(ABC, Generic[T]):
    """Abstract base class for repositories backed by a single table.

    Subclasses supply ``table_name`` and may override ``primary_key`` and
    ``soft_delete_field``. The base class provides:

    - lookup by primary key (single and batch)
    - paged listing
    - lookup by an arbitrary column
    - counting and existence checks

    Every query method is best-effort: a failure raised by the database
    layer is logged and reported as an "empty" result (``None``, ``[]`` or
    ``0``) rather than propagated.

    SQL is built with ``%s`` placeholders, so the connection object is
    expected to accept that parameter style and to expose
    ``execute_one(query, params)`` and ``execute_query(query, params)``.
    """

    def __init__(self, db: Optional["DatabaseConnection"] = None):
        """Bind the repository to a database connection.

        Args:
            db: Connection used to run queries. When omitted, a
                ``DatabaseConnection()`` is created (the original author
                describes it as a singleton — defined in ``.connection``).
        """
        # Test against None instead of truthiness so that an injected
        # connection that happens to evaluate falsy is not thrown away.
        self._db = db if db is not None else DatabaseConnection()

    @property
    @abstractmethod
    def table_name(self) -> str:
        """Name of the table this repository reads from."""

    @property
    def primary_key(self) -> str:
        """Name of the primary-key column (``'id'`` unless overridden)."""
        return 'id'

    @property
    def soft_delete_field(self) -> Optional[str]:
        """Name of the soft-delete flag column.

        Return ``None`` from an override to disable soft-delete filtering.
        """
        return 'is_delete'

    def _build_where_clause(self, include_deleted: bool = False) -> str:
        """Build the WHERE clause that hides soft-deleted rows.

        Args:
            include_deleted: When True, no soft-delete filter is applied.

        Returns:
            ``"WHERE <soft_delete_field> = 0"`` or an empty string when no
            filter is needed.
        """
        if self.soft_delete_field and not include_deleted:
            return f"WHERE {self.soft_delete_field} = 0"
        return ""

    def _where_with(self, condition: str, include_deleted: bool = False) -> str:
        """Combine the soft-delete filter with one extra SQL condition.

        Kept separate from ``_build_where_clause`` so subclasses that
        override that method with its original signature keep working.
        """
        base = self._build_where_clause(include_deleted)
        if base:
            return f"{base} AND {condition}"
        return f"WHERE {condition}"

    @staticmethod
    def _is_safe_identifier(name: Any) -> bool:
        """Return True if *name* is a plain or dotted column identifier.

        Column names cannot be sent as bound parameters, so they are
        validated before being interpolated into SQL. Quoted identifiers
        and names containing characters such as ``$`` are rejected.
        """
        return isinstance(name, str) and all(
            part.isidentifier() for part in name.split('.')
        )

    def _select_by_field(
        self,
        field: str,
        value: Any,
        include_deleted: bool,
        first_only: bool = False,
    ) -> List[Dict[str, Any]]:
        """Run ``SELECT * ... WHERE <field> = %s``, optionally with LIMIT 1."""
        if not self._is_safe_identifier(field):
            logger.error(
                "查询 %s 失败: invalid field name %r", self.table_name, field
            )
            return []
        where = self._where_with(f"{field} = %s", include_deleted)
        query = f"SELECT * FROM {self.table_name} {where}"
        if first_only:
            query += " LIMIT 1"
        try:
            return self._db.execute_query(query, (value,))
        except Exception as e:  # best-effort API: log and report "no rows"
            logger.exception("查询 %s 失败: %s", self.table_name, e)
            return []

    def find_by_id(self, id: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """Fetch one record by primary key.

        Args:
            id: Primary-key value.
            include_deleted: Whether soft-deleted records may be returned.

        Returns:
            The record as a dict, or ``None`` when nothing matches or the
            query fails.
        """
        where = self._where_with(f"{self.primary_key} = %s", include_deleted)
        query = f"SELECT * FROM {self.table_name} {where}"
        try:
            return self._db.execute_one(query, (id,))
        except Exception as e:  # best-effort API: log and report "not found"
            logger.exception("查询 %s 失败: %s", self.table_name, e)
            return None

    def find_by_ids(self, ids: List[Any], include_deleted: bool = False) -> List[Dict[str, Any]]:
        """Fetch several records by primary key in one query.

        Args:
            ids: Primary-key values. Any iterable is accepted; ``None`` or
                an empty iterable short-circuits to an empty result.
            include_deleted: Whether soft-deleted records may be returned.

        Returns:
            The matching records, or ``[]`` when nothing matches or the
            query fails.
        """
        # Materialise once so generators and other one-shot iterables work
        # and the placeholder count always matches the parameter count.
        id_values = tuple(ids) if ids is not None else ()
        if not id_values:
            return []
        placeholders = ','.join(['%s'] * len(id_values))
        where = self._where_with(
            f"{self.primary_key} IN ({placeholders})", include_deleted
        )
        query = f"SELECT * FROM {self.table_name} {where}"
        try:
            return self._db.execute_query(query, id_values)
        except Exception as e:  # best-effort API: log and report "no rows"
            logger.exception("批量查询 %s 失败: %s", self.table_name, e)
            return []

    def find_all(self, limit: int = 100, offset: int = 0, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """Fetch one page of records.

        Args:
            limit: Maximum number of records to return.
            offset: Number of records to skip.
            include_deleted: Whether soft-deleted records may be returned.

        Returns:
            The records of the requested page, or ``[]`` when the page is
            empty or the query fails.
        """
        clauses = [f"SELECT * FROM {self.table_name}"]
        where = self._build_where_clause(include_deleted)
        if where:
            clauses.append(where)
        # LIMIT/OFFSET without a defined order lets rows repeat or vanish
        # between pages; ordering by the primary key keeps pages stable.
        clauses.append(f"ORDER BY {self.primary_key}")
        clauses.append("LIMIT %s OFFSET %s")
        query = " ".join(clauses)
        try:
            return self._db.execute_query(query, (limit, offset))
        except Exception as e:  # best-effort API: log and report "no rows"
            logger.exception("查询 %s 列表失败: %s", self.table_name, e)
            return []

    def find_by_field(self, field: str, value: Any, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """Fetch all records whose column *field* equals *value*.

        Args:
            field: Column name. It must be a plain or dotted identifier;
                anything else is rejected (logged, ``[]`` returned) because
                the name is interpolated into the SQL text.
            value: Value to compare against (sent as a bound parameter).
            include_deleted: Whether soft-deleted records may be returned.

        Returns:
            The matching records, or ``[]`` when nothing matches, the field
            name is rejected, or the query fails.
        """
        return self._select_by_field(field, value, include_deleted)

    def find_one_by_field(self, field: str, value: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """Fetch a single record whose column *field* equals *value*.

        Args:
            field: Column name (same restrictions as ``find_by_field``).
            value: Value to compare against (sent as a bound parameter).
            include_deleted: Whether soft-deleted records may be returned.

        Returns:
            One matching record, or ``None`` when nothing matches, the
            field name is rejected, or the query fails.
        """
        # LIMIT 1 avoids transferring every matching row just to keep one.
        rows = self._select_by_field(
            field, value, include_deleted, first_only=True
        )
        return rows[0] if rows else None

    def count(self, include_deleted: bool = False) -> int:
        """Count the records in the table.

        Args:
            include_deleted: Whether soft-deleted records are counted.

        Returns:
            The number of records, or ``0`` when the query fails or yields
            no row.
        """
        where = self._build_where_clause(include_deleted)
        query = f"SELECT COUNT(*) as count FROM {self.table_name} {where}".rstrip()
        try:
            result = self._db.execute_one(query)
            return result.get('count', 0) if result else 0
        except Exception as e:  # best-effort API: log and report zero
            logger.exception("统计 %s 失败: %s", self.table_name, e)
            return 0

    def exists(self, id: Any) -> bool:
        """Report whether a non-deleted record with this primary key exists.

        Delegates to ``find_by_id``, so a failed query is reported as
        ``False``.

        Args:
            id: Primary-key value.

        Returns:
            True when ``find_by_id`` returns a record.
        """
        return self.find_by_id(id) is not None