228 lines
7.1 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库连接管理
提供连接池和连接获取功能
"""
import logging
import os
import json
from typing import Dict, Any, Optional
from pathlib import Path
from contextlib import contextmanager
logger = logging.getLogger(__name__)

# mysql-connector-python is an optional dependency: when it is missing the
# module still imports, MYSQL_AVAILABLE is False and a warning is logged.
try:
    import mysql.connector
    from mysql.connector import pooling
except ImportError:
    MYSQL_AVAILABLE = False
    logger.warning("mysql-connector-python 未安装")
else:
    MYSQL_AVAILABLE = True
class DatabaseConnection:
    """
    Database connection manager (process-wide singleton).

    Provides:
    - connection-pool management
    - acquiring and releasing pooled connections
    - configuration loading (a JSON file whose top-level values may be
      ``${ENV_VAR}`` references, falling back to ``DB_*`` environment
      variables)

    Only the first instantiation loads configuration and builds the pool;
    later ``DatabaseConnection(...)`` calls return the same instance and
    ignore their ``config_path`` argument.
    """

    _instance: Optional['DatabaseConnection'] = None

    def __new__(cls, config_path: Optional[str] = None):
        """Return the shared instance, creating it on first use (singleton)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Checked by __init__ so configuration/pool setup runs only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Load configuration and build the connection pool (first call only).

        Args:
            config_path: Path to a JSON config file. When omitted,
                ``Path(__file__).parent.parent.parent / "config" / "database.json"``
                is used.
        """
        if self._initialized:
            return
        self._config_path = config_path
        self._pool = None
        self._config: Dict[str, Any] = {}
        self._load_config()
        self._init_pool()
        self._initialized = True

    def _load_config(self):
        """
        Populate ``self._config``.

        The JSON config file is tried first. If it is missing, fails to load
        (the error is logged as a warning) or yields an empty mapping, the
        configuration is built from ``DB_*`` environment variables instead.

        Raises:
            ValueError: if ``DB_PORT`` or ``DB_POOL_SIZE`` is set to a
                non-integer value (environment fallback only).
        """
        # Try the configuration file first.
        if self._config_path:
            config_file = Path(self._config_path)
        else:
            # Default configuration path.
            config_file = Path(__file__).parent.parent.parent / "config" / "database.json"
        if config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    raw_config = json.load(f)
                self._config = self._process_env_vars(raw_config)
                logger.info(f"从配置文件加载数据库配置: {config_file}")
            except Exception as e:
                logger.warning(f"加载配置文件失败: {e}")
        # Fall back to environment variables.
        if not self._config:
            self._config = {
                'host': os.getenv('DB_HOST', 'localhost'),
                'user': os.getenv('DB_USER', 'root'),
                'password': os.getenv('DB_PASSWORD', ''),
                'database': os.getenv('DB_NAME', 'travel_content'),
                'port': int(os.getenv('DB_PORT', '3306')),
                'charset': 'utf8mb4',
                'pool_size': int(os.getenv('DB_POOL_SIZE', '10'))
            }
            logger.info("使用环境变量构建数据库配置")

    def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve ``${NAME}`` environment-variable references in *config*.

        Only top-level string values of exactly the form ``${NAME}`` are
        replaced; an unset variable becomes ``''``. All other values
        (including nested structures) are copied unchanged.

        Args:
            config: Mapping loaded from the JSON config file.

        Returns:
            A new dict with references substituted.
        """
        processed = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith('${') and value.endswith('}'):
                env_key = value[2:-1]
                processed[key] = os.getenv(env_key, '')
            else:
                processed[key] = value
        return processed

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """
        Convert a numeric config value to ``int``, using *default* if absent.

        Values coming from ``${ENV_VAR}`` substitution are strings, and an
        unset variable yields ``''``; both must become integers before being
        handed to the MySQL pool (a ``str`` pool size is rejected with
        ``TypeError``).

        Args:
            value: Raw config value (``int``, numeric ``str``, ``''`` or ``None``).
            default: Value used when *value* is ``None`` or ``''``.

        Returns:
            The integer value.

        Raises:
            ValueError: if *value* is a non-empty, non-numeric string.
        """
        if value is None or value == '':
            return default
        return int(value)

    def _init_pool(self):
        """
        Create the MySQL connection pool from ``self._config``.

        Best effort: any failure is logged and leaves ``self._pool`` as
        ``None``, in which case ``get_connection`` raises ``RuntimeError``.
        """
        if not MYSQL_AVAILABLE:
            logger.error("mysql-connector-python 未安装,无法创建连接池")
            return
        try:
            pool_config = {
                'pool_name': 'travel_pool',
                # Numeric settings may arrive as strings (or '') through
                # ${ENV_VAR} substitution, so normalise them to int here.
                'pool_size': self._coerce_int(self._config.get('pool_size'), 10),
                'host': self._config.get('host', 'localhost'),
                'port': self._coerce_int(self._config.get('port'), 3306),
                'user': self._config.get('user', 'root'),
                'password': self._config.get('password', ''),
                'database': self._config.get('database', 'travel_content'),
                'charset': self._config.get('charset', 'utf8mb4'),
                'autocommit': True,
            }
            self._pool = pooling.MySQLConnectionPool(**pool_config)
            logger.info(f"数据库连接池初始化成功: {pool_config['host']}:{pool_config['port']}/{pool_config['database']}")
        except Exception as e:
            logger.error(f"初始化数据库连接池失败: {e}")
            self._pool = None

    @contextmanager
    def get_connection(self):
        """
        Borrow a connection from the pool (context manager).

        The connection is closed (returned to the pool) on exit, even if the
        body raises.

        Usage:
            with db.get_connection() as conn:
                cursor = conn.cursor(dictionary=True)
                cursor.execute("SELECT * FROM ...")

        Raises:
            RuntimeError: if the pool was never initialised.
        """
        if not self._pool:
            raise RuntimeError("数据库连接池未初始化")
        conn = None
        try:
            conn = self._pool.get_connection()
            yield conn
        finally:
            if conn:
                conn.close()

    @contextmanager
    def get_cursor(self, dictionary: bool = True, buffered: Optional[bool] = None):
        """
        Open a cursor on a pooled connection (context manager).

        Usage:
            with db.get_cursor() as cursor:
                cursor.execute("SELECT * FROM ...")
                results = cursor.fetchall()

        Args:
            dictionary: Return rows as dicts keyed by column name.
            buffered: Fetch the whole result set on ``execute``. ``None``
                keeps the connection's default. A buffered cursor can be
                closed without reading every row.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor(dictionary=dictionary, buffered=buffered)
            try:
                yield cursor
            finally:
                cursor.close()

    def execute_query(self, query: str, params: Optional[tuple] = None) -> list:
        """
        Run a query and return all rows.

        Args:
            query: SQL statement with ``%s`` placeholders.
            params: Placeholder values.

        Returns:
            List of result rows (dicts).
        """
        with self.get_cursor() as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchall()

    def execute_one(self, query: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
        """
        Run a query and return its first row.

        Args:
            query: SQL statement with ``%s`` placeholders.
            params: Placeholder values.

        Returns:
            The first row as a dict, or ``None`` if there are no rows.
        """
        # Buffered: with an unbuffered cursor, closing it while further rows
        # remain unread raises "Unread result found" in mysql-connector.
        with self.get_cursor(buffered=True) as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchone()

    def execute_update(self, query: str, params: Optional[tuple] = None) -> int:
        """
        Run a data-modifying statement and commit it.

        Args:
            query: SQL statement with ``%s`` placeholders.
            params: Placeholder values.

        Returns:
            Number of affected rows (``cursor.rowcount``).
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, params or ())
                # Pool connections are created with autocommit=True; the
                # explicit commit keeps this correct if that ever changes.
                conn.commit()
                return cursor.rowcount
            finally:
                cursor.close()

    @property
    def is_connected(self) -> bool:
        """Whether the connection pool was created successfully."""
        return self._pool is not None

    def get_info(self) -> Dict[str, Any]:
        """Return non-secret connection details (no user/password)."""
        return {
            'host': self._config.get('host'),
            'port': self._config.get('port'),
            'database': self._config.get('database'),
            'pool_size': self._config.get('pool_size'),
            'is_connected': self.is_connected,
        }