TravelContentCreator/api/services/prompt_service.py

619 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词服务层
负责提示词的存储、检索和构建
"""
import logging
import json
import os
import re
from typing import Dict, Any, Optional, List, cast
from pathlib import Path
import mysql.connector
from mysql.connector import pooling
from core.config import ConfigManager, ResourceConfig
from utils.prompts import BasePromptBuilder, PromptTemplate
from utils.file_io import ResourceLoader
logger = logging.getLogger(__name__)
class PromptService:
"""提示词服务类"""
def __init__(self, config_manager: ConfigManager):
"""
初始化提示词服务
Args:
config_manager: 配置管理器
"""
self.config_manager = config_manager
self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)
# 初始化数据库连接池
self._init_db_pool()
def _init_db_pool(self):
"""初始化数据库连接池"""
try:
# 尝试直接从配置文件加载数据库配置
config_dir = Path("config")
db_config_path = config_dir / "database.json"
if not db_config_path.exists():
logger.warning(f"数据库配置文件不存在: {db_config_path}")
self.db_pool = None
return
# 加载配置文件
with open(db_config_path, 'r', encoding='utf-8') as f:
db_config = json.load(f)
# 处理环境变量
processed_config = {}
for key, value in db_config.items():
if isinstance(value, str) and "${" in value:
# 匹配 ${ENV_VAR:-default} 格式
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
match = re.match(pattern, value)
if match:
env_var, default = match.groups()
processed_value = os.environ.get(env_var, default)
# 尝试转换为数字
if key == "port":
try:
processed_value = int(processed_value)
except (ValueError, TypeError):
processed_value = 3306
processed_config[key] = processed_value
else:
processed_config[key] = value
# 创建连接池
self.db_pool = pooling.MySQLConnectionPool(
pool_name="prompt_pool",
pool_size=5,
host=processed_config.get("host", "localhost"),
user=processed_config.get("user", "root"),
password=processed_config.get("password", ""),
database=processed_config.get("database", "travel_content"),
port=processed_config.get("port", 3306),
charset=processed_config.get("charset", "utf8mb4")
)
logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}")
except Exception as e:
logger.error(f"初始化数据库连接池失败: {e}")
self.db_pool = None
def get_style_content(self, style_name: str) -> str:
"""
获取内容风格提示词
Args:
style_name: 风格名称
Returns:
风格提示词内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM contentStyle WHERE styleName = %s",
(style_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取风格提示词: {style_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取风格提示词失败: {e}")
# 回退到文件系统
try:
style_paths = self.resource_config.style.paths
for path_str in style_paths:
try:
if style_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取风格提示词: {style_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
# 如果没有精确匹配,尝试模糊匹配
for path_str in style_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
if style_name.lower() in full_path.stem.lower():
logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}")
return content
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取风格提示词失败: {e}")
logger.warning(f"未找到风格提示词: {style_name},将使用默认值")
return "通用风格"
def get_audience_content(self, audience_name: str) -> str:
"""
获取目标受众提示词
Args:
audience_name: 受众名称
Returns:
受众提示词内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM targetAudience WHERE audienceName = %s",
(audience_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取受众提示词: {audience_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取受众提示词失败: {e}")
# 回退到文件系统
try:
demand_paths = self.resource_config.demand.paths
for path_str in demand_paths:
try:
if audience_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取受众提示词: {audience_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
# 如果没有精确匹配,尝试模糊匹配
for path_str in demand_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
if audience_name.lower() in full_path.stem.lower():
logger.info(f"通过模糊匹配找到受众提示词: {audience_name} -> {full_path.name}")
return content
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取受众提示词失败: {e}")
logger.warning(f"未找到受众提示词: {audience_name},将使用默认值")
return "通用用户画像"
def get_scenic_spot_info(self, spot_name: str) -> str:
"""
获取景区信息
Args:
spot_name: 景区名称
Returns:
景区信息内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM scenicSpot WHERE spotName = %s",
(spot_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取景区信息: {spot_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取景区信息失败: {e}")
# 回退到文件系统
try:
object_paths = self.resource_config.object.paths
for path_str in object_paths:
try:
if spot_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取景区信息: {spot_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取景区文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取景区信息失败: {e}")
logger.warning(f"未找到景区信息: {spot_name}")
return ""
def get_product_info(self, product_name: str) -> str:
"""
获取产品信息
Args:
product_name: 产品名称
Returns:
产品信息内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM product WHERE productName = %s",
(product_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取产品信息: {product_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取产品信息失败: {e}")
# 回退到文件系统
try:
product_paths = self.resource_config.product.paths
for path_str in product_paths:
try:
if product_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取产品信息: {product_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取产品文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取产品信息失败: {e}")
logger.warning(f"未找到产品信息: {product_name}")
return ""
def get_refer_content(self, step: str = "") -> str:
"""
获取参考内容
Args:
step: 当前步骤,用于过滤参考内容
Returns:
参考内容
"""
refer_content = "参考内容:\n"
# 从文件系统获取参考内容
try:
refer_list = self.resource_config.refer.refer_list
filtered_configs = [
item for item in refer_list
if not item.step or item.step == step
]
for ref_item in filtered_configs:
try:
path_str = ref_item.path
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
refer_content += f"--- {full_path.name} ---\n{content}\n\n"
except Exception as e:
logger.error(f"读取参考文件失败 {ref_item.path}: {e}")
except Exception as e:
logger.error(f"获取参考内容失败: {e}")
return refer_content
def _get_full_path(self, path_str: str) -> Path:
"""根据基准目录解析相对或绝对路径"""
if not self.resource_config.resource_dirs:
raise ValueError("Resource directories list is empty in config.")
base_path = Path(self.resource_config.resource_dirs[0])
file_path = Path(path_str)
return file_path if file_path.is_absolute() else (base_path / file_path).resolve()
def get_all_styles(self) -> List[Dict[str, str]]:
"""
获取所有内容风格
Returns:
风格列表每个元素包含name和description
"""
styles = []
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT styleName as name, description FROM contentStyle")
results = cursor.fetchall()
cursor.close()
conn.close()
if results:
logger.info(f"从数据库获取所有风格: {len(results)}")
return results
except Exception as e:
logger.error(f"从数据库获取所有风格失败: {e}")
# 回退到文件系统
try:
style_paths = self.resource_config.style.paths
for path_str in style_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
name = full_path.stem
if "文案提示词" in name:
name = name.replace("文案提示词", "")
styles.append({
"name": name,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有风格失败: {e}")
return styles
def get_all_audiences(self) -> List[Dict[str, str]]:
"""
获取所有目标受众
Returns:
受众列表每个元素包含name和description
"""
audiences = []
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT audienceName as name, description FROM targetAudience")
results = cursor.fetchall()
cursor.close()
conn.close()
if results:
logger.info(f"从数据库获取所有受众: {len(results)}")
return results
except Exception as e:
logger.error(f"从数据库获取所有受众失败: {e}")
# 回退到文件系统
try:
demand_paths = self.resource_config.demand.paths
for path_str in demand_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
name = full_path.stem
if "文旅需求" in name:
name = name.replace("文旅需求", "")
audiences.append({
"name": name,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有受众失败: {e}")
return audiences
def get_all_scenic_spots(self) -> List[Dict[str, str]]:
"""
获取所有景区
Returns:
景区列表每个元素包含name和description
"""
spots = []
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT spotName as name, description FROM scenicSpot")
results = cursor.fetchall()
cursor.close()
conn.close()
if results:
logger.info(f"从数据库获取所有景区: {len(results)}")
return [{
"name": item["name"],
"description": item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"]
} for item in results]
except Exception as e:
logger.error(f"从数据库获取所有景区失败: {e}")
# 回退到文件系统
try:
object_paths = self.resource_config.object.paths
for path_str in object_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
spots.append({
"name": full_path.stem,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取景区文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有景区失败: {e}")
return spots
def save_style(self, name: str, description: str) -> bool:
"""
保存内容风格
Args:
name: 风格名称
description: 风格描述
Returns:
是否保存成功
"""
# 优先保存到数据库
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor()
# 检查是否存在
cursor.execute(
"SELECT COUNT(*) FROM contentStyle WHERE styleName = %s",
(name,)
)
count = cursor.fetchone()[0]
if count > 0:
# 更新
cursor.execute(
"UPDATE contentStyle SET description = %s WHERE styleName = %s",
(description, name)
)
else:
# 插入
cursor.execute(
"INSERT INTO contentStyle (styleName, description) VALUES (%s, %s)",
(name, description)
)
conn.commit()
cursor.close()
conn.close()
logger.info(f"风格保存到数据库成功: {name}")
return True
except Exception as e:
logger.error(f"风格保存到数据库失败: {e}")
# 回退到文件系统
try:
# 确保风格目录存在
style_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Style"
style_dir.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = style_dir / f"{name}文案提示词.md"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(description)
# 更新配置
if str(file_path) not in self.resource_config.style.paths:
self.resource_config.style.paths.append(str(file_path))
logger.info(f"风格保存到文件系统成功: {file_path}")
return True
except Exception as e:
logger.error(f"风格保存到文件系统失败: {e}")
return False
def save_audience(self, name: str, description: str) -> bool:
"""
保存目标受众
Args:
name: 受众名称
description: 受众描述
Returns:
是否保存成功
"""
# 优先保存到数据库
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor()
# 检查是否存在
cursor.execute(
"SELECT COUNT(*) FROM targetAudience WHERE audienceName = %s",
(name,)
)
count = cursor.fetchone()[0]
if count > 0:
# 更新
cursor.execute(
"UPDATE targetAudience SET description = %s WHERE audienceName = %s",
(description, name)
)
else:
# 插入
cursor.execute(
"INSERT INTO targetAudience (audienceName, description) VALUES (%s, %s)",
(name, description)
)
conn.commit()
cursor.close()
conn.close()
logger.info(f"受众保存到数据库成功: {name}")
return True
except Exception as e:
logger.error(f"受众保存到数据库失败: {e}")
# 回退到文件系统
try:
# 确保受众目录存在
demand_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Demand"
demand_dir.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = demand_dir / f"{name}文旅需求.md"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(description)
# 更新配置
if str(file_path) not in self.resource_config.demand.paths:
self.resource_config.demand.paths.append(str(file_path))
logger.info(f"受众保存到文件系统成功: {file_path}")
return True
except Exception as e:
logger.error(f"受众保存到文件系统失败: {e}")
return False