#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Database connection management.

Provides a shared MySQL connection pool plus helpers for acquiring
connections/cursors and running queries.
"""

import logging
import os
import json
from typing import Dict, Any, Optional
from pathlib import Path
from contextlib import contextmanager

logger = logging.getLogger(__name__)

try:
    import mysql.connector
    from mysql.connector import pooling
    MYSQL_AVAILABLE = True
except ImportError:
    MYSQL_AVAILABLE = False
    logger.warning("mysql-connector-python 未安装")


class DatabaseConnection:
    """
    Database connection manager (singleton).

    Provides:
    - connection-pool management
    - connection / cursor acquisition and release via context managers
    - configuration loading from a JSON file, with environment-variable fallback

    Only the first instantiation initializes the object; later calls return the
    same instance and their ``config_path`` argument is ignored.
    """

    _instance: Optional['DatabaseConnection'] = None

    def __new__(cls, config_path: Optional[str] = None):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        # __init__ runs on every DatabaseConnection(...) call because __new__
        # returns the shared instance; only the first call does real work.
        if self._initialized:
            return
        self._config_path = config_path
        self._pool = None
        self._config: Dict[str, Any] = {}
        self._load_config()
        self._init_pool()
        self._initialized = True

    @staticmethod
    def _to_int(value: Any, default: int) -> int:
        """
        Coerce a numeric setting to ``int``, using ``default`` for missing/blank values.

        Values resolved from ``${ENV}`` placeholders are always strings (``''``
        when the variable is unset), so numeric settings must be converted
        before being handed to the driver.

        Raises:
            ValueError / TypeError: if ``value`` is non-blank but not an integer.
        """
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return int(value)

    def _load_config(self):
        """
        Populate ``self._config`` with database settings.

        Reads the JSON file at ``config_path`` (or, when none was given,
        ``config/database.json`` three directory levels above this module) and
        resolves ``${ENV_VAR}`` placeholders in it. If the file is missing,
        cannot be loaded, or yields an empty mapping, settings are built from
        the ``DB_*`` environment variables instead.
        """
        if self._config_path:
            config_file = Path(self._config_path)
        else:
            # Default configuration location
            config_file = Path(__file__).parent.parent.parent / "config" / "database.json"

        if config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    raw_config = json.load(f)
                self._config = self._process_env_vars(raw_config)
                logger.info(f"从配置文件加载数据库配置: {config_file}")
            except Exception as e:
                # Best effort: a broken file falls through to the environment fallback.
                logger.warning(f"加载配置文件失败: {e}")

        # Fallback: build the configuration from environment variables.
        if not self._config:
            self._config = {
                'host': os.getenv('DB_HOST', 'localhost'),
                'user': os.getenv('DB_USER', 'root'),
                'password': os.getenv('DB_PASSWORD', ''),
                'database': os.getenv('DB_NAME', 'travel_content'),
                # Blank numeric variables use their defaults instead of crashing on int('').
                'port': self._to_int(os.getenv('DB_PORT'), 3306),
                'charset': 'utf8mb4',
                'pool_size': self._to_int(os.getenv('DB_POOL_SIZE'), 10),
            }
            logger.info("使用环境变量构建数据库配置")

    def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve environment-variable placeholders in the top-level values of ``config``.

        A string value that starts with ``${`` and ends with ``}`` is replaced by
        the value of the named environment variable (``''`` if it is unset).
        All other values, including nested containers, are copied unchanged.
        """
        processed = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith('${') and value.endswith('}'):
                env_key = value[2:-1]
                processed[key] = os.getenv(env_key, '')
            else:
                processed[key] = value
        return processed

    def _init_pool(self):
        """
        Create the MySQL connection pool from ``self._config``.

        Leaves ``self._pool`` unset (``None``) and logs an error when the driver
        is not installed or pool creation fails, so ``is_connected`` is ``False``.
        """
        if not MYSQL_AVAILABLE:
            logger.error("mysql-connector-python 未安装,无法创建连接池")
            return

        try:
            pool_config = {
                'pool_name': 'travel_pool',
                # Numeric settings may arrive as strings via ${ENV} substitution.
                'pool_size': self._to_int(self._config.get('pool_size'), 10),
                'host': self._config.get('host', 'localhost'),
                'port': self._to_int(self._config.get('port'), 3306),
                'user': self._config.get('user', 'root'),
                'password': self._config.get('password', ''),
                'database': self._config.get('database', 'travel_content'),
                'charset': self._config.get('charset', 'utf8mb4'),
                'autocommit': True,
            }
            self._pool = pooling.MySQLConnectionPool(**pool_config)
            logger.info(f"数据库连接池初始化成功: {pool_config['host']}:{pool_config['port']}/{pool_config['database']}")
        except Exception as e:
            logger.error(f"初始化数据库连接池失败: {e}")
            self._pool = None

    @contextmanager
    def get_connection(self):
        """
        Borrow a connection from the pool (context manager).

        The connection is closed (returned to the pool) on exit, even on error.

        Usage:
            with db.get_connection() as conn:
                cursor = conn.cursor(dictionary=True)
                cursor.execute("SELECT * FROM ...")

        Raises:
            RuntimeError: if the pool was never initialized.
        """
        if not self._pool:
            raise RuntimeError("数据库连接池未初始化")

        conn = None
        try:
            conn = self._pool.get_connection()
            yield conn
        finally:
            if conn:
                conn.close()

    @contextmanager
    def get_cursor(self, dictionary: bool = True, buffered: bool = False):
        """
        Open a cursor on a pooled connection (context manager).

        Args:
            dictionary: return rows as dicts keyed by column name.
            buffered: fetch the whole result set at execute time. Needed when not
                every row will be read, since mysql-connector refuses to close an
                unbuffered cursor with rows still pending ("Unread result found").

        Usage:
            with db.get_cursor() as cursor:
                cursor.execute("SELECT * FROM ...")
                results = cursor.fetchall()
        """
        cursor_kwargs: Dict[str, Any] = {'dictionary': dictionary}
        if buffered:
            # Only pass the flag when requested so the driver's default is otherwise kept.
            cursor_kwargs['buffered'] = True
        with self.get_connection() as conn:
            cursor = conn.cursor(**cursor_kwargs)
            try:
                yield cursor
            finally:
                cursor.close()

    def execute_query(self, query: str, params: Optional[tuple] = None) -> list:
        """
        Run a query and return all result rows.

        Args:
            query: SQL query
            params: query parameters

        Returns:
            List of result rows (dicts).
        """
        with self.get_cursor() as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchall()

    def execute_one(self, query: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
        """
        Run a query and return only its first row.

        A buffered cursor is used so that queries matching several rows do not
        fail when the cursor is closed with unread rows.

        Args:
            query: SQL query
            params: query parameters

        Returns:
            The first row (dict), or None when there are no rows.
        """
        with self.get_cursor(buffered=True) as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchone()

    def execute_update(self, query: str, params: Optional[tuple] = None) -> int:
        """
        Execute a data-modifying statement.

        Args:
            query: SQL statement
            params: statement parameters

        Returns:
            Number of affected rows.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, params or ())
                # The pool uses autocommit=True; the explicit commit is kept as a safeguard.
                conn.commit()
                return cursor.rowcount
            finally:
                cursor.close()

    @property
    def is_connected(self) -> bool:
        """Whether the connection pool was created successfully."""
        return self._pool is not None

    def get_info(self) -> Dict[str, Any]:
        """Return a summary of the connection settings (password excluded)."""
        return {
            'host': self._config.get('host'),
            'port': self._config.get('port'),
            'database': self._config.get('database'),
            'pool_size': self._config.get('pool_size'),
            'is_connected': self.is_connected,
        }