#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Repository base class.

Provides generic CRUD-style query helpers shared by concrete repositories.
"""

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Iterable, List, Optional, TypeVar

from .connection import DatabaseConnection

logger = logging.getLogger(__name__)

T = TypeVar('T')


class BaseRepository(ABC, Generic[T]):
    """
    Base class for table-backed repositories.

    Provides:
    - generic lookups (by primary key, by an arbitrary field)
    - batch lookup by a list of primary keys
    - paginated listing and counting

    Subclasses must implement ``table_name`` and may override
    ``primary_key``, ``soft_delete_field`` and ``default_order_by``.

    NOTE: ``table_name``, ``primary_key``, ``soft_delete_field``,
    ``default_order_by`` and the ``field`` argument of ``find_by_field`` are
    interpolated verbatim into the SQL text; only values are bound as ``%s``
    parameters. Never feed these identifiers from untrusted input.

    Exceptions raised while executing a query are logged (with traceback)
    and converted into an empty result (``None``, ``[]`` or ``0``).
    """

    def __init__(self, db: Optional[DatabaseConnection] = None):
        """
        Initialize the repository.

        Args:
            db: Database connection to use. When omitted (``None``), falls
                back to ``DatabaseConnection()``, which is documented as the
                shared singleton connection.
        """
        # Compare against None explicitly: ``db or ...`` would silently
        # replace a connection object that happens to evaluate as falsy.
        self._db = db if db is not None else DatabaseConnection()

    @property
    @abstractmethod
    def table_name(self) -> str:
        """Name of the table backing this repository."""
        pass

    @property
    def primary_key(self) -> str:
        """Primary-key column name (defaults to ``'id'``)."""
        return 'id'

    @property
    def soft_delete_field(self) -> Optional[str]:
        """Soft-delete flag column; ``None`` means soft delete is not used."""
        return 'is_delete'

    @property
    def default_order_by(self) -> Optional[str]:
        """
        ORDER BY expression used by ``find_all`` so pagination is stable.

        Defaults to the primary key. Override to change the ordering, or
        return ``None`` to omit ORDER BY entirely.
        """
        return self.primary_key

    def _build_where_clause(self, include_deleted: bool = False) -> str:
        """Build the base WHERE clause (soft-delete filter), or ``""``."""
        if self.soft_delete_field and not include_deleted:
            return f"WHERE {self.soft_delete_field} = 0"
        return ""

    def _where_with(self, condition: str, include_deleted: bool = False) -> str:
        """Return a WHERE clause combining the soft-delete filter with *condition*."""
        # Route through _build_where_clause so subclass overrides still apply.
        where = self._build_where_clause(include_deleted)
        return f"{where} AND {condition}" if where else f"WHERE {condition}"

    def find_by_id(self, id: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """
        Fetch a single record by primary key.

        Args:
            id: Primary-key value.
            include_deleted: Whether to include soft-deleted records.

        Returns:
            The row returned by ``execute_one``, or ``None`` if not found or
            if the query failed (the failure is logged).
        """
        where = self._where_with(f"{self.primary_key} = %s", include_deleted)
        query = f"SELECT * FROM {self.table_name} {where}"
        try:
            return self._db.execute_one(query, (id,))
        except Exception as e:
            logger.exception("查询 %s 失败: %s", self.table_name, e)
            return None

    def find_by_ids(self, ids: Iterable[Any], include_deleted: bool = False) -> List[Dict[str, Any]]:
        """
        Fetch multiple records by primary key.

        Args:
            ids: Primary-key values (any iterable; ``None`` or empty yields ``[]``).
            include_deleted: Whether to include soft-deleted records.

        Returns:
            The rows returned by ``execute_query``, or ``[]`` when *ids* is
            empty or the query failed (the failure is logged).
        """
        # Materialize once so generators and other one-shot iterables work:
        # the placeholder count and the bound parameters must see the same values.
        id_list = list(ids) if ids is not None else []
        if not id_list:
            return []

        placeholders = ','.join(['%s'] * len(id_list))
        where = self._where_with(f"{self.primary_key} IN ({placeholders})", include_deleted)
        query = f"SELECT * FROM {self.table_name} {where}"
        try:
            return self._db.execute_query(query, tuple(id_list))
        except Exception as e:
            logger.exception("批量查询 %s 失败: %s", self.table_name, e)
            return []

    def find_all(self, limit: int = 100, offset: int = 0, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """
        Fetch one page of records.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            include_deleted: Whether to include soft-deleted records.

        Returns:
            The rows returned by ``execute_query``, or ``[]`` if the query
            failed (the failure is logged).
        """
        clauses = [f"SELECT * FROM {self.table_name}", self._build_where_clause(include_deleted)]
        # Without ORDER BY, SQL leaves row order unspecified, so LIMIT/OFFSET
        # pages could overlap or skip rows between calls.
        order_by = self.default_order_by
        if order_by:
            clauses.append(f"ORDER BY {order_by}")
        clauses.append("LIMIT %s OFFSET %s")
        query = " ".join(clause for clause in clauses if clause)
        try:
            return self._db.execute_query(query, (limit, offset))
        except Exception as e:
            logger.exception("查询 %s 列表失败: %s", self.table_name, e)
            return []

    def find_by_field(self, field: str, value: Any, include_deleted: bool = False) -> List[Dict[str, Any]]:
        """
        Fetch all records where ``field = value``.

        Args:
            field: Column name. Interpolated verbatim into SQL -- must be a
                trusted identifier, never user input.
            value: Value to match (bound as a query parameter).
            include_deleted: Whether to include soft-deleted records.

        Returns:
            The rows returned by ``execute_query``, or ``[]`` if the query
            failed (the failure is logged).
        """
        where = self._where_with(f"{field} = %s", include_deleted)
        query = f"SELECT * FROM {self.table_name} {where}"
        try:
            return self._db.execute_query(query, (value,))
        except Exception as e:
            logger.exception("查询 %s 失败: %s", self.table_name, e)
            return []

    def find_one_by_field(self, field: str, value: Any, include_deleted: bool = False) -> Optional[Dict[str, Any]]:
        """
        Fetch the first record where ``field = value``.

        Args:
            field: Column name (trusted identifier; see ``find_by_field``).
            value: Value to match.
            include_deleted: Whether to include soft-deleted records.

        Returns:
            The first row returned by ``find_by_field``, or ``None``.
        """
        results = self.find_by_field(field, value, include_deleted)
        return results[0] if results else None

    def count(self, include_deleted: bool = False) -> int:
        """
        Count records in the table.

        Args:
            include_deleted: Whether to include soft-deleted records.

        Returns:
            The row count, or ``0`` if the query failed (the failure is logged).
        """
        where = self._build_where_clause(include_deleted)
        query = f"SELECT COUNT(*) as count FROM {self.table_name} {where}"
        try:
            result = self._db.execute_one(query)
            # Rows are read as mappings here (``.get``); presumably the
            # connection returns dict rows -- confirm in .connection.
            return result.get('count', 0) if result else 0
        except Exception as e:
            logger.exception("统计 %s 失败: %s", self.table_name, e)
            return 0

    def exists(self, id: Any, include_deleted: bool = False) -> bool:
        """
        Check whether a record with the given primary key exists.

        Args:
            id: Primary-key value.
            include_deleted: Whether soft-deleted records count as existing.

        Returns:
            ``True`` if ``find_by_id`` returns a row, else ``False``.
        """
        return self.find_by_id(id, include_deleted) is not None