#!/usr/bin/env python3
# -*- coding: utf-8 -*-
|
||
|
||
"""
Reference library manager.

Features:
1. Load and cache reference material
2. Match intelligently by audience/style
3. Random sampling
4. Add/remove/update/query support (via file operations)

Usage:
    from domain.prompt.reference_manager import ReferenceManager

    manager = ReferenceManager()

    # Get title references (automatic matching)
    titles = manager.get_titles(audience_id='qinzi', count=20)

    # Get body-text samples (automatic matching)
    contents = manager.get_contents(style_id='gonglue', count=3)

    # Add a new title reference
    manager.add_title("新的爆款标题模板")

    # List all reference material
    manager.list_all()
"""
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
import logging
|
||
from typing import List, Dict, Any, Optional
|
||
from pathlib import Path
|
||
|
||
import yaml
|
||
|
||
# Module-level logger, named after this module; used by ReferenceManager below.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ReferenceManager:
    """Manage a file-backed library of reference titles and body texts.

    Reference files live under ``base_path``. An optional ``index.yaml`` in
    that directory selects, per section (``titles`` / ``contents``), the data
    file and the default sample size. Data files are YAML (``.yaml``/``.yml``)
    or JSON (any other extension) and hold an ``examples`` list whose entries
    are either plain strings or ``{"content": "..."}`` mappings.

    Parsed files are cached in memory until they are saved through this
    manager or ``clear_cache()`` is called.
    """

    # Fallbacks used when index.yaml is absent or omits a field. Every public
    # method resolves its file through these, so reads and writes always
    # target the same file.
    _DEFAULTS = {
        'titles': {'file': '标题参考格式.json', 'sample_count': 20},
        'contents': {'file': '正文范文参考.json', 'sample_count': 3},
    }

    def __init__(self, base_path: Optional[str] = None):
        """Initialise the manager.

        Args:
            base_path: Root directory of the reference library. Defaults to
                ``../../prompts/reference`` relative to this module's file.
        """
        if base_path is None:
            base_path = os.path.join(
                os.path.dirname(__file__),
                '../../prompts/reference'
            )

        self.base_path = os.path.abspath(base_path)
        self._config = None  # parsed index.yaml, loaded lazily
        self._cache = {}     # file name -> parsed file content

        logger.info(f"ReferenceManager 初始化: {self.base_path}")

    # ==================== Read API ====================

    def get_titles(self,
                   audience_id: str = None,
                   style_id: str = None,
                   count: int = None) -> List[str]:
        """Return a random selection of title references.

        Args:
            audience_id: Audience ID. Accepted for the planned matching
                feature but currently ignored.
            style_id: Style ID. Accepted for the planned matching feature but
                currently ignored.
            count: Number of titles to draw; ``None`` uses the configured
                ``sample_count``.

        Returns:
            All titles when ``count`` is at least the number available,
            otherwise ``count`` titles sampled without replacement. A
            non-positive ``count`` yields an empty list.
        """
        # TODO: select the section entry by audience_id/style_id; only the
        # 'default' entry is used for now.
        return self._sample('titles', count)

    def get_contents(self,
                     audience_id: str = None,
                     style_id: str = None,
                     count: int = None) -> List[str]:
        """Return a random selection of body-text samples.

        Args:
            audience_id: Audience ID (currently ignored).
            style_id: Style ID (currently ignored).
            count: Number of samples to draw; ``None`` uses the configured
                ``sample_count``.

        Returns:
            All samples when ``count`` is at least the number available,
            otherwise ``count`` samples drawn without replacement. A
            non-positive ``count`` yields an empty list.
        """
        return self._sample('contents', count)

    def list_all(self) -> Dict[str, Any]:
        """Summarise every section that has a file configured.

        Returns:
            A mapping such as::

                {
                    'titles': {'file': '...', 'count': 100, 'sample_count': 20},
                    'contents': {'file': '...', 'count': 10, 'sample_count': 3}
                }

            Sections whose configuration names no file are omitted.
        """
        result = {}
        for section in ('titles', 'contents'):
            file_name = self._section_config(section).get('file')
            if not file_name:
                continue
            examples = self._extract_examples(self._load_file(file_name))
            result[section] = {
                'file': file_name,
                'count': len(examples),
                'sample_count': self._section_sample_count(section),
            }
        return result

    def get_all_titles(self) -> List[str]:
        """Return every title reference, in file order."""
        return self._extract_examples(self._load_file(self._section_file('titles')))

    def get_all_contents(self) -> List[str]:
        """Return every body-text sample, in file order."""
        return self._extract_examples(self._load_file(self._section_file('contents')))

    # ==================== Write API ====================

    def add_title(self, content: str) -> bool:
        """Append a title reference.

        Args:
            content: Title text.

        Returns:
            True when the title was stored; False when it already exists or
            the file could not be written.
        """
        return self._add_example(self._section_file('titles'), content)

    def add_content(self, content: str) -> bool:
        """Append a body-text sample.

        Args:
            content: Body text.

        Returns:
            True when the sample was stored; False when it already exists or
            the file could not be written.
        """
        return self._add_example(self._section_file('contents'), content)

    def remove_title(self, index: int) -> bool:
        """Delete the title at ``index`` (position in ``get_all_titles()``).

        Returns:
            True on success; False when ``index`` is out of range or the file
            could not be written.
        """
        return self._remove_example(self._section_file('titles'), index)

    def remove_content(self, index: int) -> bool:
        """Delete the body text at ``index`` (position in ``get_all_contents()``).

        Returns:
            True on success; False when ``index`` is out of range or the file
            could not be written.
        """
        return self._remove_example(self._section_file('contents'), index)

    def update_title(self, index: int, content: str) -> bool:
        """Replace the text of the title at ``index``; return True on success."""
        return self._update_example(self._section_file('titles'), index, content)

    def update_content(self, index: int, content: str) -> bool:
        """Replace the text of the body sample at ``index``; return True on success."""
        return self._update_example(self._section_file('contents'), index, content)

    def clear_cache(self):
        """Drop the cached configuration and all cached reference files."""
        self._cache = {}
        self._config = None
        logger.info("参考文献缓存已清除")

    # ==================== Internal helpers ====================

    def _get_config(self) -> Dict:
        """Return the parsed ``index.yaml``, or built-in defaults if it is absent.

        The result is cached until ``clear_cache()``. A config file whose
        top level is not a mapping is treated as empty.
        """
        if self._config is not None:
            return self._config

        config_path = os.path.join(self.base_path, 'index.yaml')

        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
            if loaded is not None and not isinstance(loaded, dict):
                logger.warning(f"配置文件格式无效 (顶层应为字典): {config_path}")
            self._config = loaded if isinstance(loaded, dict) else {}
        else:
            # No index file: fall back to the built-in defaults (copied so
            # the class-level table is never mutated through the config).
            self._config = {
                section: {'default': dict(values)}
                for section, values in self._DEFAULTS.items()
            }

        return self._config

    def _section_config(self, section: str) -> Dict[str, Any]:
        """Return the ``<section>.default`` mapping from the config, or ``{}``.

        Tolerates sections that are missing, null, or not mappings.
        """
        section_cfg = self._get_config().get(section)
        if not isinstance(section_cfg, dict):
            return {}
        default_cfg = section_cfg.get('default')
        return default_cfg if isinstance(default_cfg, dict) else {}

    def _section_file(self, section: str) -> str:
        """Return the data file for ``section``, falling back to the default."""
        return self._section_config(section).get('file') or self._DEFAULTS[section]['file']

    def _section_sample_count(self, section: str) -> int:
        """Return the sample size for ``section``, falling back to the default."""
        configured = self._section_config(section).get('sample_count')
        if configured is None:
            return self._DEFAULTS[section]['sample_count']
        return configured

    def _sample(self, section: str, count: Optional[int]) -> List[str]:
        """Draw up to ``count`` random examples from ``section``'s file."""
        if count is None:
            count = self._section_sample_count(section)

        examples = self._extract_examples(self._load_file(self._section_file(section)))

        if count <= 0:
            # random.sample rejects negative sizes; treat them like zero.
            return []
        if len(examples) <= count:
            return examples
        return random.sample(examples, count)

    @staticmethod
    def _example_text(entry: Any) -> Any:
        """Return the text of one stored entry (plain string or ``{"content": ...}``)."""
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            return entry.get('content', '')
        return str(entry)

    def _extract_examples(self, data: Dict) -> List[str]:
        """Extract the example texts from a loaded reference file.

        Two entry formats are supported, decided per entry:
        1. plain strings (typical for YAML files);
        2. ``{"content": "..."}`` mappings (typical for JSON files).

        Entries keep their positions, so an index into the returned list is
        also valid for the remove/update methods. The result is always a
        new list; mutating it never touches the cached file content.
        """
        if not isinstance(data, dict):
            return []
        entries = data.get('examples')
        if not isinstance(entries, list):
            return []
        return [self._example_text(entry) for entry in entries]

    def _load_file(self, file_name: str) -> Dict:
        """Load a reference file (YAML or JSON by extension), using the cache.

        On any failure the error is logged and an empty, uncached
        ``{'examples': []}`` is returned, so a later call retries the load.
        """
        if file_name in self._cache:
            return self._cache[file_name]

        file_path = os.path.join(self.base_path, file_name)

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Pick the parser from the file extension.
                if file_name.endswith(('.yaml', '.yml')):
                    data = yaml.safe_load(f) or {}
                else:
                    data = json.load(f)
        except Exception as e:
            logger.error(f"加载参考文献失败: {file_path}, {e}")
            return {'examples': []}

        self._cache[file_name] = data
        return data

    def _save_file(self, file_name: str, data: Dict) -> bool:
        """Write ``data`` to a reference file (YAML or JSON by extension).

        The cache entry is dropped whether or not the write succeeds, so the
        next read reflects what is actually on disk.

        Returns:
            True on success, False if serialising or writing failed.
        """
        file_path = os.path.join(self.base_path, file_name)

        try:
            # Serialise before opening the target, so a serialisation error
            # cannot leave a truncated file behind.
            if file_name.endswith(('.yaml', '.yml')):
                payload = yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False)
            else:
                payload = json.dumps(data, ensure_ascii=False, indent=4)

            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(payload)
            return True
        except Exception as e:
            logger.error(f"保存参考文献失败: {file_path}, {e}")
            return False
        finally:
            self._cache.pop(file_name, None)

    def _add_example(self, file_name: str, content: str) -> bool:
        """Append ``content`` to a reference file unless it is already present.

        The new entry takes the shape of the file's existing entries
        (``{"content": ...}`` mapping or plain string). A file that exists
        but could not be parsed is left untouched.
        """
        data = self._load_file(file_name)

        # A successful load always populates the cache; if it did not and the
        # file exists, the load failed and saving would wipe its content.
        file_path = os.path.join(self.base_path, file_name)
        if file_name not in self._cache and os.path.exists(file_path):
            logger.error(f"参考文献文件已存在但无法读取, 为避免覆盖已放弃写入: {file_path}")
            return False
        if not isinstance(data, dict):
            logger.error(f"参考文献文件格式无效 (顶层应为字典): {file_path}")
            return False

        # Reject duplicates.
        if content in self._extract_examples(data):
            logger.warning(f"参考文献已存在: {content[:50]}...")
            return False

        entries = data.get('examples')
        if not isinstance(entries, list):
            entries = []
            data['examples'] = entries

        # The first entry decides the file's format.
        if entries and isinstance(entries[0], dict):
            entries.append({'content': content})
        else:
            entries.append(content)

        return self._save_file(file_name, data)

    def _remove_example(self, file_name: str, index: int) -> bool:
        """Delete the entry at ``index`` from a reference file."""
        data = self._load_file(file_name)
        entries = data.get('examples') if isinstance(data, dict) else None
        if not isinstance(entries, list):
            entries = []

        if index < 0 or index >= len(entries):
            logger.error(f"索引超出范围: {index}, 总数: {len(entries)}")
            return False

        removed_str = str(self._example_text(entries.pop(index)))
        logger.info(f"删除参考文献: {removed_str[:50]}...")

        return self._save_file(file_name, data)

    def _update_example(self, file_name: str, index: int, content: str) -> bool:
        """Replace the text of the entry at ``index`` in a reference file.

        A ``{"content": ...}`` entry keeps its other fields; only its
        ``content`` value is replaced.
        """
        data = self._load_file(file_name)
        entries = data.get('examples') if isinstance(data, dict) else None
        if not isinstance(entries, list):
            entries = []

        if index < 0 or index >= len(entries):
            logger.error(f"索引超出范围: {index}, 总数: {len(entries)}")
            return False

        if isinstance(entries[index], dict):
            entries[index]['content'] = content
        else:
            entries[index] = content
        logger.info(f"更新参考文献[{index}]: {content[:50]}...")

        return self._save_file(file_name, data)
|
||
|
||
|
||
# Module-level singleton, created lazily on first access.
_manager_instance = None


def get_reference_manager() -> ReferenceManager:
    """Return the shared ReferenceManager, constructing it on the first call."""
    global _manager_instance
    if _manager_instance is not None:
        return _manager_instance
    _manager_instance = ReferenceManager()
    return _manager_instance
|