TravelContentCreator/domain/prompt/reference_manager.py

414 lines
13 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
参考文献库管理器
功能:
1. 加载和缓存参考文献
2. 按人群/风格智能匹配
3. 随机抽样
4. 支持增删改查 (通过文件操作)
使用方式:
from domain.prompt.reference_manager import ReferenceManager
manager = ReferenceManager()
# 获取标题参考 (自动匹配)
titles = manager.get_titles(audience_id='qinzi', count=20)
# 获取正文范文 (自动匹配)
contents = manager.get_contents(style_id='gonglue', count=3)
# 添加新的标题参考
manager.add_title("新的爆款标题模板")
# 列出所有参考文献
manager.list_all()
"""
import json
import os
import random
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
import yaml
logger = logging.getLogger(__name__)
class ReferenceManager:
"""参考文献库管理器"""
def __init__(self, base_path: str = None):
"""
初始化管理器
Args:
base_path: 参考文献库根目录默认为 prompts/reference
"""
if base_path is None:
base_path = os.path.join(
os.path.dirname(__file__),
'../../prompts/reference'
)
self.base_path = os.path.abspath(base_path)
self._config = None
self._cache = {} # 文件缓存
logger.info(f"ReferenceManager 初始化: {self.base_path}")
# ==================== 读取接口 ====================
def get_titles(self,
audience_id: str = None,
style_id: str = None,
count: int = None) -> List[str]:
"""
获取标题参考
Args:
audience_id: 人群 ID (用于智能匹配)
style_id: 风格 ID (用于智能匹配)
count: 抽取数量None 表示使用配置默认值
Returns:
标题列表
"""
config = self._get_config()
titles_config = config.get('titles', {}).get('default', {})
# TODO: 实现按 audience_id/style_id 匹配
# 目前使用默认配置
file_name = titles_config.get('file', '标题参考格式.json')
default_count = titles_config.get('sample_count', 20)
if count is None:
count = default_count
# 加载文件
data = self._load_file(file_name)
examples = self._extract_examples(data)
# 随机抽样
if len(examples) <= count:
return examples
return random.sample(examples, count)
def get_contents(self,
audience_id: str = None,
style_id: str = None,
count: int = None) -> List[str]:
"""
获取正文范文
Args:
audience_id: 人群 ID
style_id: 风格 ID
count: 抽取数量
Returns:
正文范文列表
"""
config = self._get_config()
contents_config = config.get('contents', {}).get('default', {})
file_name = contents_config.get('file', '正文范文参考.json')
default_count = contents_config.get('sample_count', 3)
if count is None:
count = default_count
data = self._load_file(file_name)
examples = self._extract_examples(data)
if len(examples) <= count:
return examples
return random.sample(examples, count)
def list_all(self) -> Dict[str, Any]:
"""
列出所有参考文献统计
Returns:
{
'titles': {'count': 100, 'file': '...'},
'contents': {'count': 10, 'file': '...'}
}
"""
config = self._get_config()
result = {}
# 标题统计
titles_file = config.get('titles', {}).get('default', {}).get('file')
if titles_file:
data = self._load_file(titles_file)
examples = self._extract_examples(data)
result['titles'] = {
'file': titles_file,
'count': len(examples),
'sample_count': config.get('titles', {}).get('default', {}).get('sample_count', 20)
}
# 正文统计
contents_file = config.get('contents', {}).get('default', {}).get('file')
if contents_file:
data = self._load_file(contents_file)
examples = self._extract_examples(data)
result['contents'] = {
'file': contents_file,
'count': len(examples),
'sample_count': config.get('contents', {}).get('default', {}).get('sample_count', 3)
}
return result
def get_all_titles(self) -> List[str]:
"""获取所有标题参考"""
config = self._get_config()
file_name = config.get('titles', {}).get('default', {}).get('file', 'titles.yaml')
data = self._load_file(file_name)
return self._extract_examples(data)
def get_all_contents(self) -> List[str]:
"""获取所有正文范文"""
config = self._get_config()
file_name = config.get('contents', {}).get('default', {}).get('file', 'contents.yaml')
data = self._load_file(file_name)
return self._extract_examples(data)
# ==================== 写入接口 ====================
def add_title(self, content: str) -> bool:
"""
添加新的标题参考
Args:
content: 标题内容
Returns:
是否成功
"""
config = self._get_config()
file_name = config.get('titles', {}).get('default', {}).get('file', '标题参考格式.json')
return self._add_example(file_name, content)
def add_content(self, content: str) -> bool:
"""
添加新的正文范文
Args:
content: 正文内容
Returns:
是否成功
"""
config = self._get_config()
file_name = config.get('contents', {}).get('default', {}).get('file', '正文范文参考.json')
return self._add_example(file_name, content)
def remove_title(self, index: int) -> bool:
"""
删除标题参考
Args:
index: 索引位置
Returns:
是否成功
"""
config = self._get_config()
file_name = config.get('titles', {}).get('default', {}).get('file', '标题参考格式.json')
return self._remove_example(file_name, index)
def remove_content(self, index: int) -> bool:
"""
删除正文范文
Args:
index: 索引位置
Returns:
是否成功
"""
config = self._get_config()
file_name = config.get('contents', {}).get('default', {}).get('file', 'contents.yaml')
return self._remove_example(file_name, index)
def update_title(self, index: int, content: str) -> bool:
"""更新标题参考"""
config = self._get_config()
file_name = config.get('titles', {}).get('default', {}).get('file', 'titles.yaml')
return self._update_example(file_name, index, content)
def update_content(self, index: int, content: str) -> bool:
"""更新正文范文"""
config = self._get_config()
file_name = config.get('contents', {}).get('default', {}).get('file', 'contents.yaml')
return self._update_example(file_name, index, content)
def clear_cache(self):
"""清除缓存"""
self._cache = {}
self._config = None
logger.info("参考文献缓存已清除")
# ==================== 内部方法 ====================
def _get_config(self) -> Dict:
"""获取配置"""
if self._config is not None:
return self._config
config_path = os.path.join(self.base_path, 'index.yaml')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f) or {}
else:
# 默认配置
self._config = {
'titles': {
'default': {
'file': '标题参考格式.json',
'sample_count': 20
}
},
'contents': {
'default': {
'file': '正文范文参考.json',
'sample_count': 3
}
}
}
return self._config
def _extract_examples(self, data: Dict) -> List[str]:
"""
从数据中提取示例列表
支持两种格式:
1. YAML 格式: examples 是字符串列表
2. JSON 格式: examples [{"content": "..."}, ...] 列表
"""
examples = data.get('examples', [])
if not examples:
return []
# 判断格式
if isinstance(examples[0], str):
# YAML 格式: 直接是字符串列表
return examples
elif isinstance(examples[0], dict):
# JSON 格式: 需要提取 content 字段
return [ex.get('content', '') for ex in examples]
else:
return []
def _load_file(self, file_name: str) -> Dict:
"""加载参考文献文件 (支持 YAML 和 JSON带缓存)"""
if file_name in self._cache:
return self._cache[file_name]
file_path = os.path.join(self.base_path, file_name)
try:
with open(file_path, 'r', encoding='utf-8') as f:
# 根据扩展名选择解析器
if file_name.endswith('.yaml') or file_name.endswith('.yml'):
data = yaml.safe_load(f) or {}
else:
data = json.load(f)
self._cache[file_name] = data
return data
except Exception as e:
logger.error(f"加载参考文献失败: {file_path}, {e}")
return {'examples': []}
def _save_file(self, file_name: str, data: Dict) -> bool:
"""保存参考文献文件 (支持 YAML 和 JSON)"""
file_path = os.path.join(self.base_path, file_name)
try:
with open(file_path, 'w', encoding='utf-8') as f:
if file_name.endswith('.yaml') or file_name.endswith('.yml'):
yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
else:
json.dump(data, f, ensure_ascii=False, indent=4)
# 清除缓存
if file_name in self._cache:
del self._cache[file_name]
return True
except Exception as e:
logger.error(f"保存参考文献失败: {file_path}, {e}")
return False
def _add_example(self, file_name: str, content: str) -> bool:
"""添加示例"""
data = self._load_file(file_name)
examples = self._extract_examples(data)
# 检查重复
if content in examples:
logger.warning(f"参考文献已存在: {content[:50]}...")
return False
data.setdefault('examples', []).append(content)
return self._save_file(file_name, data)
def _remove_example(self, file_name: str, index: int) -> bool:
"""删除示例"""
data = self._load_file(file_name)
examples = data.get('examples', [])
if index < 0 or index >= len(examples):
logger.error(f"索引超出范围: {index}, 总数: {len(examples)}")
return False
removed = examples.pop(index)
removed_str = removed if isinstance(removed, str) else removed.get('content', '')
logger.info(f"删除参考文献: {removed_str[:50]}...")
return self._save_file(file_name, data)
def _update_example(self, file_name: str, index: int, content: str) -> bool:
"""更新示例"""
data = self._load_file(file_name)
examples = data.get('examples', [])
if index < 0 or index >= len(examples):
logger.error(f"索引超出范围: {index}, 总数: {len(examples)}")
return False
examples[index] = content
logger.info(f"更新参考文献[{index}]: {content[:50]}...")
return self._save_file(file_name, data)
# 全局单例
_manager_instance = None
def get_reference_manager() -> ReferenceManager:
"""获取全局 ReferenceManager 实例"""
global _manager_instance
if _manager_instance is None:
_manager_instance = ReferenceManager()
return _manager_instance