TravelContentCreator/api/services/prompt_service.py

651 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词服务层
负责提示词的存储、检索和构建
"""
import logging
import json
import os
import re
import sys
import traceback
import time
from typing import Dict, Any, Optional, List, cast
from pathlib import Path
import mysql.connector
from mysql.connector import pooling
import random
from core.config import ConfigManager, ResourceConfig
from utils.prompts import BasePromptBuilder, PromptTemplate
from utils.file_io import ResourceLoader
logger = logging.getLogger(__name__)
class PromptService:
"""提示词服务类"""
def __init__(self, config_manager: ConfigManager):
"""
初始化提示词服务
Args:
config_manager: 配置管理器
"""
self.config_manager = config_manager
# 按需加载resource配置
if 'resource' not in self.config_manager._loaded_configs:
self.config_manager.load_task_config('resource')
self.resource_config = self.config_manager.get_config('resource', ResourceConfig)
# ResourceLoader是静态类不需要实例化
self.resource_dirs = self.resource_config.resource_dirs
# 初始化数据库连接池
self.db_pool = self._init_db_pool()
# 创建必要的目录结构
# self._create_resource_directories()
def _create_resource_directories(self):
pass
"""创建必要的资源目录"""
try:
dirs_to_create = ['styles', 'attractions', 'products', 'audiences']
for dir_name in dirs_to_create:
dir_path = Path(dir_name)
if not dir_path.exists():
dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"创建目录: {dir_path}")
except Exception as e:
logger.error(f"创建资源目录失败: {e}")
def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""处理配置中的环境变量"""
processed = {}
for key, value in config.items():
if isinstance(value, str) and "${" in value:
# 匹配 ${ENV_VAR:-default} 格式
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
match = re.match(pattern, value)
if match:
env_var, default = match.groups()
processed_value = os.environ.get(env_var, default or "")
# 尝试转换为数字
if key == "port":
try:
processed_value = int(processed_value)
except (ValueError, TypeError):
processed_value = 3306
processed[key] = processed_value
else:
processed[key] = value
else:
processed[key] = value
return processed
def _init_db_pool(self):
"""初始化数据库连接池,尝试多种连接方式"""
# 获取数据库配置
raw_db_config = self.config_manager.get_raw_config('database')
# 处理环境变量
db_config = self._process_env_vars(raw_db_config)
# 连接尝试配置
connection_attempts = [
{"desc": "使用配置文件中的设置", "config": db_config},
{"desc": "使用明确的密码", "config": {**db_config, "password": "password"}},
{"desc": "使用空密码", "config": {**db_config, "password": ""}},
{"desc": "使用auth_plugin", "config": {**db_config, "auth_plugin": "mysql_native_password"}}
]
# 尝试不同的连接方式
for attempt in connection_attempts:
try:
# 打印连接信息(不包含密码)
connection_info = {k: v for k, v in attempt["config"].items() if k != 'password'}
logger.info(f"尝试连接数据库 ({attempt['desc']}): {connection_info}")
# 创建连接池
# 从配置中分离MySQL连接池支持的参数和不支持的参数
config = attempt["config"].copy()
# MySQL连接池不支持的参数需要移除
unsupported_params = [
'max_retry_attempts', 'query_timeout', 'soft_delete_field', 'active_record_value'
]
for param in unsupported_params:
config.pop(param, None)
# 设置连接池参数,使用配置文件中的值或默认值
pool_size = config.pop('pool_size', 5)
pool = pooling.MySQLConnectionPool(
pool_name=f"prompt_service_pool_{int(time.time())}",
pool_size=pool_size,
**config
)
# 测试连接
with pool.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.fetchall()
logger.info(f"数据库连接池初始化成功 ({attempt['desc']})")
return pool
except Exception as e:
error_details = traceback.format_exc()
logger.error(f"数据库连接尝试 ({attempt['desc']}) 失败: {e}\n{error_details}")
logger.warning("所有数据库连接尝试均失败,将使用文件系统作为数据源")
return None
def get_style_content(self, styleName: str) -> str:
"""获取风格提示词内容"""
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute(
"SELECT description FROM contentStyle WHERE styleName = %s",
(styleName,)
)
result = cursor.fetchone()
if result:
logger.info(f"从数据库获取风格提示词: {styleName}")
return result['description']
except Exception as e:
logger.error(f"从数据库获取风格提示词失败: {e}")
# 从文件系统获取
logger.info(f"从文件系统获取风格提示词: {styleName}")
for dir_path in self.resource_dirs:
style_file = ResourceLoader.find_file(os.path.join(dir_path, "styles"), styleName)
if style_file:
content = ResourceLoader.load_text_file(style_file)
if content:
return content
return f"未找到风格 '{styleName}' 的提示词"
def get_audience_content(self, audience_name: str) -> str:
"""
获取目标受众提示词
Args:
audience_name: 受众名称
Returns:
受众提示词内容
"""
# 优先从数据库获取
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute(
"SELECT description FROM targetAudience WHERE audienceName = %s",
(audience_name,)
)
result = cursor.fetchone()
if result:
logger.info(f"从数据库获取受众提示词: {audience_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取受众提示词失败: {e}")
# 回退到文件系统
try:
demand_paths = self.resource_config.demand.paths
for path_str in demand_paths:
try:
if audience_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取受众提示词: {audience_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
# 如果没有精确匹配,尝试模糊匹配
for path_str in demand_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
if audience_name.lower() in full_path.stem.lower():
logger.info(f"通过模糊匹配找到受众提示词: {audience_name} -> {full_path.name}")
return content
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取受众提示词失败: {e}")
logger.warning(f"未找到受众提示词: {audience_name},将使用默认值")
return "通用用户画像"
def get_scenic_spot_info(self, spot_name: str) -> str:
"""
获取景区信息
Args:
spot_name: 景区名称
Returns:
景区信息内容
"""
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute(
"SELECT description FROM scenicSpot WHERE name = %s",
(spot_name,)
)
result = cursor.fetchone()
if result:
logger.info(f"从数据库获取景区信息: {spot_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取景区信息失败: {e}")
# 回退到文件系统
try:
object_paths = self.resource_config.object.paths
for path_str in object_paths:
try:
if spot_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取景区信息: {spot_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取景区文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取景区信息失败: {e}")
logger.warning(f"未找到景区信息: {spot_name}")
return ""
def get_product_info(self, product_name: str) -> str:
"""
获取产品信息
Args:
product_name: 产品名称
Returns:
产品信息内容
"""
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute(
"SELECT detailedDescription FROM product WHERE productName = %s",
(product_name,)
)
result = cursor.fetchone()
if result:
logger.info(f"从数据库获取产品信息: {product_name}")
return result["detailedDescription"]
except Exception as e:
logger.error(f"从数据库获取产品信息失败: {e}")
# 回退到文件系统
try:
product_paths = self.resource_config.product.paths
for path_str in product_paths:
try:
if product_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取产品信息: {product_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取产品文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取产品信息失败: {e}")
logger.warning(f"未找到产品信息: {product_name}")
return ""
def get_refer_content(self, step: str = "") -> str:
"""
获取参考内容
Args:
step: 当前步骤,用于过滤参考内容
Returns:
参考内容
"""
refer_content = "参考内容:\n"
# 从文件系统获取参考内容
try:
refer_list = self.resource_config.refer.refer_list
filtered_configs = [
item for item in refer_list
if not item.step or item.step == step
]
for ref_item in filtered_configs:
try:
path_str = ref_item.path
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
if full_path.suffix.lower() == '.json':
# 处理JSON文件
with open(full_path, 'r', encoding='utf-8') as f:
data = json.load(f)
if isinstance(data, dict) and 'examples' in data:
examples = data['examples']
if isinstance(examples, list):
sample_size = max(1, int(len(examples) * ref_item.sampling_rate))
sampled_examples = random.sample(examples, sample_size)
sampled_content = json.dumps({'examples': sampled_examples}, ensure_ascii=False, indent=4)
elif isinstance(data, list):
sample_size = max(1, int(len(data) * ref_item.sampling_rate))
sampled_examples = random.sample(data, sample_size)
sampled_content = json.dumps(sampled_examples, ensure_ascii=False, indent=4)
else:
# 如果不是预期结构,按原方式处理
with open(full_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
sample_size = max(1, int(len(lines) * ref_item.sampling_rate))
sampled_lines = random.sample(lines, sample_size)
sampled_content = ''.join(sampled_lines)
else:
# 非JSON文件按原方式处理
with open(full_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
sample_size = max(1, int(len(lines) * ref_item.sampling_rate))
sampled_lines = random.sample(lines, sample_size)
sampled_content = ''.join(sampled_lines)
refer_content += f"--- {full_path.name} (sampled {ref_item.sampling_rate * 100}%) ---\n{sampled_content}\n\n"
except Exception as e:
logger.error(f"读取或采样参考文件失败 {ref_item.path}: {e}")
except Exception as e:
logger.error(f"获取参考内容失败: {e}")
return refer_content
def _get_full_path(self, path_str: str) -> Path:
"""根据基准目录解析相对或绝对路径"""
if not self.resource_config.resource_dirs:
raise ValueError("Resource directories list is empty in config.")
base_path = Path(self.resource_config.resource_dirs[0])
file_path = Path(path_str)
return file_path if file_path.is_absolute() else (base_path / file_path).resolve()
def get_all_styles(self) -> List[Dict[str, str]]:
"""
获取所有内容风格
Returns:
风格列表每个元素包含name和description
"""
styles = []
# 优先从数据库获取
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute("SELECT styleName as name, description FROM contentStyle")
results = cursor.fetchall()
if results:
logger.info(f"从数据库获取所有风格: {len(results)}")
return results
except Exception as e:
logger.error(f"从数据库获取所有风格失败: {e}")
# 回退到文件系统
try:
style_paths = self.resource_config.style.paths
for path_str in style_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
name = full_path.stem
if "文案提示词" in name:
name = name.replace("文案提示词", "")
styles.append({
"name": name,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有风格失败: {e}")
return styles
def get_all_audiences(self) -> List[Dict[str, str]]:
"""
获取所有目标受众
Returns:
受众列表每个元素包含name和description
"""
audiences = []
# 优先从数据库获取
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute("SELECT audienceName as name, description FROM targetAudience")
results = cursor.fetchall()
if results:
logger.info(f"从数据库获取所有受众: {len(results)}")
return results
except Exception as e:
logger.error(f"从数据库获取所有受众失败: {e}")
# 回退到文件系统
try:
demand_paths = self.resource_config.demand.paths
for path_str in demand_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
name = full_path.stem
if "文旅需求" in name:
name = name.replace("文旅需求", "")
audiences.append({
"name": name,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有受众失败: {e}")
return audiences
def get_all_scenic_spots(self) -> List[Dict[str, str]]:
"""
获取所有景区
Returns:
景区列表每个元素包含name和description
"""
spots = []
# 优先从数据库获取
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
cursor.execute("SELECT name as name, description FROM scenicSpot")
results = cursor.fetchall()
if results:
logger.info(f"从数据库获取所有景区: {len(results)}")
return [{
"name": item["name"],
"description": item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"]
} for item in results]
except Exception as e:
logger.error(f"从数据库获取所有景区失败: {e}")
# 回退到文件系统
try:
object_paths = self.resource_config.object.paths
for path_str in object_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
spots.append({
"name": full_path.stem,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取景区文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有景区失败: {e}")
return spots
def save_style(self, name: str, description: str) -> bool:
"""
保存内容风格
Args:
name: 风格名称
description: 风格描述
Returns:
是否保存成功
"""
# 优先保存到数据库
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor() as cursor:
# 检查是否存在
cursor.execute(
"SELECT COUNT(*) FROM contentStyle WHERE styleName = %s",
(name,)
)
count = cursor.fetchone()[0]
if count > 0:
# 更新
cursor.execute(
"UPDATE contentStyle SET description = %s WHERE styleName = %s",
(description, name)
)
else:
# 插入
cursor.execute(
"INSERT INTO contentStyle (styleName, description) VALUES (%s, %s)",
(name, description)
)
conn.commit()
logger.info(f"风格保存到数据库成功: {name}")
return True
except Exception as e:
logger.error(f"风格保存到数据库失败: {e}")
# 回退到文件系统
try:
# 确保风格目录存在
style_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Style"
style_dir.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = style_dir / f"{name}文案提示词.md"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(description)
# 更新配置
if str(file_path) not in self.resource_config.style.paths:
self.resource_config.style.paths.append(str(file_path))
logger.info(f"风格保存到文件系统成功: {file_path}")
return True
except Exception as e:
logger.error(f"风格保存到文件系统失败: {e}")
return False
def save_audience(self, name: str, description: str) -> bool:
"""
保存目标受众
Args:
name: 受众名称
description: 受众描述
Returns:
是否保存成功
"""
# 优先保存到数据库
if self.db_pool:
try:
with self.db_pool.get_connection() as conn:
with conn.cursor() as cursor:
# 检查是否存在
cursor.execute(
"SELECT COUNT(*) FROM targetAudience WHERE audienceName = %s",
(name,)
)
count = cursor.fetchone()[0]
if count > 0:
# 更新
cursor.execute(
"UPDATE targetAudience SET description = %s WHERE audienceName = %s",
(description, name)
)
else:
# 插入
cursor.execute(
"INSERT INTO targetAudience (audienceName, description) VALUES (%s, %s)",
(name, description)
)
conn.commit()
logger.info(f"受众保存到数据库成功: {name}")
return True
except Exception as e:
logger.error(f"受众保存到数据库失败: {e}")
# 回退到文件系统
try:
# 确保受众目录存在
demand_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Demand"
demand_dir.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = demand_dir / f"{name}文旅需求.md"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(description)
# 更新配置
if str(file_path) not in self.resource_config.demand.paths:
self.resource_config.demand.paths.append(str(file_path))
logger.info(f"受众保存到文件系统成功: {file_path}")
return True
except Exception as e:
logger.error(f"受众保存到文件系统失败: {e}")
return False