228 lines
7.1 KiB
Python
228 lines
7.1 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
数据库连接管理
|
|||
|
|
提供连接池和连接获取功能
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
from typing import Dict, Any, Optional
|
|||
|
|
from pathlib import Path
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
|
|||
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# The MySQL driver is optional at import time. Whether it imported is recorded
# in MYSQL_AVAILABLE, which is checked before the connection pool is created.
try:
    import mysql.connector
    from mysql.connector import pooling
except ImportError:
    MYSQL_AVAILABLE = False
    logger.warning("mysql-connector-python 未安装")
else:
    MYSQL_AVAILABLE = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DatabaseConnection:
    """
    Database connection manager, implemented as a process-wide singleton.

    Provides:
    - connection-pool management (mysql-connector-python ``MySQLConnectionPool``)
    - connection / cursor acquisition and release via context managers
    - configuration loading from a JSON file, falling back to environment variables

    Only the first construction configures the instance. Later calls to
    ``DatabaseConnection(...)`` return the same object, and any different
    ``config_path`` they pass is ignored (with a warning).
    """

    _instance: Optional['DatabaseConnection'] = None

    def __new__(cls, config_path: Optional[str] = None):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Lets __init__ know it still has to configure this instance.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Load the configuration and build the connection pool (first call only).

        Args:
            config_path: Path to a JSON configuration file. When None, the
                default ``Path(__file__).parent.parent.parent / "config" /
                "database.json"`` is tried.
        """
        if self._initialized:
            # The singleton is already configured; a new path cannot take effect.
            previous = getattr(self, '_config_path', None)
            if config_path is not None and config_path != previous:
                logger.warning(f"DatabaseConnection 已初始化,忽略新的 config_path: {config_path}")
            return

        self._config_path = config_path
        self._pool = None
        self._config: Dict[str, Any] = {}

        self._load_config()
        self._init_pool()
        self._initialized = True

    def _load_config(self):
        """
        Populate ``self._config`` from the JSON config file, else from env vars.

        If the file exists and loads, its top-level ``"${VAR}"`` values are
        substituted via ``_process_env_vars``. Any error while reading it is
        logged, and the environment-variable fallback is used instead. The
        fallback also applies when the file yields an empty mapping.
        """
        if self._config_path:
            config_file = Path(self._config_path)
        else:
            # Default configuration path.
            config_file = Path(__file__).parent.parent.parent / "config" / "database.json"

        if config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    raw_config = json.load(f)
                self._config = self._process_env_vars(raw_config)
                logger.info(f"从配置文件加载数据库配置: {config_file}")
            except Exception as e:
                # Best effort: an unreadable or invalid file must not prevent
                # start-up, so log it and use the environment fallback below.
                logger.warning(f"加载配置文件失败: {e}")
        elif self._config_path:
            # The caller asked for a specific file that is not there; say so
            # instead of silently switching to environment variables.
            logger.warning(f"配置文件不存在: {config_file}")

        # Fallback: build the configuration from environment variables.
        if not self._config:
            self._config = {
                'host': os.getenv('DB_HOST', 'localhost'),
                'user': os.getenv('DB_USER', 'root'),
                'password': os.getenv('DB_PASSWORD', ''),
                'database': os.getenv('DB_NAME', 'travel_content'),
                # A set-but-blank variable falls back to the default instead of crashing int('').
                'port': self._as_int(os.getenv('DB_PORT'), 3306),
                'charset': 'utf8mb4',
                'pool_size': self._as_int(os.getenv('DB_POOL_SIZE'), 10),
            }
            logger.info("使用环境变量构建数据库配置")

    def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of *config* with ``"${NAME}"`` string values substituted.

        Only top-level values that are exactly of the form ``${NAME}`` are
        replaced, by the value of environment variable ``NAME``. An unset
        variable yields ``''`` and logs a warning naming the variable; the
        value itself is never logged. All other values are copied unchanged.

        Args:
            config: Mapping loaded from the JSON configuration file.

        Returns:
            A new dict with the substitutions applied.
        """
        processed: Dict[str, Any] = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith('${') and value.endswith('}'):
                env_key = value[2:-1]
                env_value = os.getenv(env_key)
                if env_value is None:
                    logger.warning(f"环境变量 {env_key} 未设置,配置项 {key} 使用空字符串")
                    env_value = ''
                processed[key] = env_value
            else:
                processed[key] = value
        return processed

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """
        Coerce a configuration value to ``int``.

        Args:
            value: Raw value (int, numeric string, None, or blank string).
            default: Returned when *value* is None or a blank string.

        Returns:
            The integer value, or *default*.

        Raises:
            ValueError: If *value* is a non-blank string that is not an integer.
        """
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return int(value)

    def _init_pool(self):
        """
        Create the MySQL connection pool from ``self._config``.

        If the driver is not installed, an error is logged and no pool is
        created. Any exception raised while building the pool is logged and
        leaves ``self._pool`` as None, in which case ``get_connection`` raises
        RuntimeError.
        """
        if not MYSQL_AVAILABLE:
            logger.error("mysql-connector-python 未安装,无法创建连接池")
            return

        try:
            pool_config = {
                'pool_name': 'travel_pool',
                # port/pool_size may be strings after "${VAR}" substitution in
                # the JSON config; the driver needs integers.
                'pool_size': self._as_int(self._config.get('pool_size'), 10),
                'host': self._config.get('host', 'localhost'),
                'port': self._as_int(self._config.get('port'), 3306),
                'user': self._config.get('user', 'root'),
                'password': self._config.get('password', ''),
                'database': self._config.get('database', 'travel_content'),
                'charset': self._config.get('charset', 'utf8mb4'),
                'autocommit': True,
            }

            self._pool = pooling.MySQLConnectionPool(**pool_config)
            logger.info(f"数据库连接池初始化成功: {pool_config['host']}:{pool_config['port']}/{pool_config['database']}")

        except Exception as e:
            logger.error(f"初始化数据库连接池失败: {e}")
            self._pool = None

    @contextmanager
    def get_connection(self):
        """
        Borrow a connection from the pool (context manager).

        The connection's ``close()`` is always called on exit. For
        mysql-connector pooled connections, this returns them to the pool.

        Usage:
            with db.get_connection() as conn:
                cursor = conn.cursor(dictionary=True)
                cursor.execute("SELECT * FROM ...")

        Raises:
            RuntimeError: If the connection pool was not initialised.
        """
        if self._pool is None:
            raise RuntimeError("数据库连接池未初始化")

        conn = None
        try:
            conn = self._pool.get_connection()
            yield conn
        finally:
            if conn is not None:
                conn.close()

    @contextmanager
    def get_cursor(self, dictionary: bool = True, buffered: Optional[bool] = None):
        """
        Open a cursor on a pooled connection (context manager).

        Args:
            dictionary: Return rows as dicts keyed by column name.
            buffered: Passed to ``conn.cursor``. None keeps the driver's
                default. Use True when not every row will be fetched, because
                closing an unbuffered cursor with unread rows raises.

        Usage:
            with db.get_cursor() as cursor:
                cursor.execute("SELECT * FROM ...")
                results = cursor.fetchall()
        """
        with self.get_connection() as conn:
            cursor = conn.cursor(dictionary=dictionary, buffered=buffered)
            try:
                yield cursor
            finally:
                cursor.close()

    def execute_query(self, query: str, params: Optional[tuple] = None) -> list:
        """
        Run a query and return all result rows.

        Args:
            query: SQL statement with driver placeholders.
            params: Values bound to the placeholders; None means no parameters.

        Returns:
            List of rows (dicts, since the default cursor is a dictionary cursor).
        """
        with self.get_cursor() as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchall()

    def execute_one(self, query: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
        """
        Run a query and return only its first row.

        A buffered cursor is used, so queries matching several rows do not make
        ``cursor.close()`` fail with "Unread result found".

        Args:
            query: SQL statement with driver placeholders.
            params: Values bound to the placeholders; None means no parameters.

        Returns:
            The first row as a dict, or None if there are no rows.
        """
        with self.get_cursor(buffered=True) as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchone()

    def execute_update(self, query: str, params: Optional[tuple] = None) -> int:
        """
        Run a data-modifying statement and commit it.

        Args:
            query: SQL INSERT/UPDATE/DELETE statement with driver placeholders.
            params: Values bound to the placeholders; None means no parameters.

        Returns:
            Number of affected rows, as reported by ``cursor.rowcount``.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, params or ())
                # The pool is created with autocommit=True; the explicit commit
                # keeps this correct if that setting changes.
                conn.commit()
                return cursor.rowcount
            finally:
                cursor.close()

    @property
    def is_connected(self) -> bool:
        """True when a connection pool has been created."""
        return self._pool is not None

    def get_info(self) -> Dict[str, Any]:
        """Return non-secret connection details (the password is never included)."""
        return {
            'host': self._config.get('host'),
            'port': self._config.get('port'),
            'database': self._config.get('database'),
            'pool_size': self._config.get('pool_size'),
            'is_connected': self.is_connected,
        }
|