414 lines
13 KiB
Python
414 lines
13 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
参考文献库管理器
|
|||
|
|
|
|||
|
|
功能:
|
|||
|
|
1. 加载和缓存参考文献
|
|||
|
|
2. 按人群/风格智能匹配
|
|||
|
|
3. 随机抽样
|
|||
|
|
4. 支持增删改查 (通过文件操作)
|
|||
|
|
|
|||
|
|
使用方式:
|
|||
|
|
from domain.prompt.reference_manager import ReferenceManager
|
|||
|
|
|
|||
|
|
manager = ReferenceManager()
|
|||
|
|
|
|||
|
|
# 获取标题参考 (自动匹配)
|
|||
|
|
titles = manager.get_titles(audience_id='qinzi', count=20)
|
|||
|
|
|
|||
|
|
# 获取正文范文 (自动匹配)
|
|||
|
|
contents = manager.get_contents(style_id='gonglue', count=3)
|
|||
|
|
|
|||
|
|
# 添加新的标题参考
|
|||
|
|
manager.add_title("新的爆款标题模板")
|
|||
|
|
|
|||
|
|
# 列出所有参考文献
|
|||
|
|
manager.list_all()
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import random
|
|||
|
|
import logging
|
|||
|
|
from typing import List, Dict, Any, Optional
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
import yaml
|
|||
|
|
|
|||
|
|
# Module-level logger used by ReferenceManager for init, load/save and edit diagnostics.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ReferenceManager:
    """Manage the reference library (title templates and sample articles).

    Reference files live under ``base_path``. Files ending in ``.yaml`` /
    ``.yml`` are parsed as YAML, anything else as JSON. Each file is a mapping
    with an ``examples`` list whose entries are plain strings or
    ``{"content": ...}`` dicts. Which file backs each section, and how many
    entries to sample by default, is read from ``index.yaml`` in ``base_path``;
    built-in defaults fill in when that file or any key is missing.

    Parsed files are cached per instance. Writes re-read the file from disk,
    replace it atomically and then drop the cached copy.
    """

    # Built-in fallback for each section. Used both when index.yaml is absent
    # and when it omits a key, so every read/write method targets the same file.
    _DEFAULT_SECTIONS = {
        'titles': {'file': '标题参考格式.json', 'sample_count': 20},
        'contents': {'file': '正文范文参考.json', 'sample_count': 3},
    }

    # File extensions parsed/written as YAML; everything else is JSON.
    _YAML_SUFFIXES = ('.yaml', '.yml')

    def __init__(self, base_path: Optional[str] = None):
        """Create a manager rooted at ``base_path``.

        Args:
            base_path: Root directory of the reference library. Defaults to
                ``../../prompts/reference`` relative to this module's directory.
        """
        if base_path is None:
            base_path = os.path.join(
                os.path.dirname(__file__),
                '../../prompts/reference'
            )

        self.base_path = os.path.abspath(base_path)
        self._config = None  # parsed index.yaml (or built-in defaults), loaded lazily
        self._cache = {}  # file name -> parsed file contents

        logger.info(f"ReferenceManager 初始化: {self.base_path}")

    # ==================== Read API ====================

    def get_titles(self,
                   audience_id: Optional[str] = None,
                   style_id: Optional[str] = None,
                   count: Optional[int] = None) -> List[str]:
        """Return a random sample of title references.

        Args:
            audience_id: Audience ID. Currently unused: matching is not
                implemented yet, so the default titles file is always used.
            style_id: Style ID. Currently unused (see ``audience_id``).
            count: Number of titles to return; ``None`` means the configured
                ``sample_count``.

        Returns:
            A new list. If the library holds no more than ``count`` entries,
            all of them are returned in file order; otherwise a random sample.
        """
        # TODO: match on audience_id / style_id; the default entry is used for now.
        return self._sample_section('titles', count)

    def get_contents(self,
                     audience_id: Optional[str] = None,
                     style_id: Optional[str] = None,
                     count: Optional[int] = None) -> List[str]:
        """Return a random sample of sample articles.

        Args:
            audience_id: Audience ID. Currently unused (no matching yet).
            style_id: Style ID. Currently unused (no matching yet).
            count: Number of articles to return; ``None`` means the configured
                ``sample_count``.

        Returns:
            A new list; all entries in file order when there are no more than
            ``count``, otherwise a random sample.
        """
        return self._sample_section('contents', count)

    def list_all(self) -> Dict[str, Any]:
        """Return statistics for every section.

        Returns:
            ``{'titles': {'file': ..., 'count': ..., 'sample_count': ...},
            'contents': {...}}`` where ``file`` is the file the getters read
            (configured or default), ``count`` its number of entries and
            ``sample_count`` the default sample size.
        """
        result = {}
        for section in ('titles', 'contents'):
            file_name, sample_count = self._section_settings(section)
            examples = self._extract_examples(self._load_file(file_name))
            result[section] = {
                'file': file_name,
                'count': len(examples),
                'sample_count': sample_count,
            }
        return result

    def get_all_titles(self) -> List[str]:
        """Return every title reference (a new list; safe to modify)."""
        file_name, _ = self._section_settings('titles')
        return self._extract_examples(self._load_file(file_name))

    def get_all_contents(self) -> List[str]:
        """Return every sample article (a new list; safe to modify)."""
        file_name, _ = self._section_settings('contents')
        return self._extract_examples(self._load_file(file_name))

    # ==================== Write API ====================

    def add_title(self, content: str) -> bool:
        """Append a new title reference.

        Args:
            content: Title text.

        Returns:
            True if saved; False if it already exists or the file could not
            be read/written.
        """
        file_name, _ = self._section_settings('titles')
        return self._add_example(file_name, content)

    def add_content(self, content: str) -> bool:
        """Append a new sample article.

        Args:
            content: Article text.

        Returns:
            True if saved; False if it already exists or the file could not
            be read/written.
        """
        file_name, _ = self._section_settings('contents')
        return self._add_example(file_name, content)

    def remove_title(self, index: int) -> bool:
        """Remove the title reference at ``index``.

        Args:
            index: Position as returned by ``get_all_titles``.

        Returns:
            True if saved; False for an out-of-range index or an I/O failure.
        """
        file_name, _ = self._section_settings('titles')
        return self._remove_example(file_name, index)

    def remove_content(self, index: int) -> bool:
        """Remove the sample article at ``index``.

        Args:
            index: Position as returned by ``get_all_contents``.

        Returns:
            True if saved; False for an out-of-range index or an I/O failure.
        """
        file_name, _ = self._section_settings('contents')
        return self._remove_example(file_name, index)

    def update_title(self, index: int, content: str) -> bool:
        """Replace the text of the title reference at ``index``."""
        file_name, _ = self._section_settings('titles')
        return self._update_example(file_name, index, content)

    def update_content(self, index: int, content: str) -> bool:
        """Replace the text of the sample article at ``index``."""
        file_name, _ = self._section_settings('contents')
        return self._update_example(file_name, index, content)

    def clear_cache(self):
        """Drop cached files and config so the next access re-reads disk."""
        self._cache = {}
        self._config = None
        logger.info("参考文献缓存已清除")

    # ==================== Internal helpers ====================

    def _get_config(self) -> Dict:
        """Return the library config, loading ``index.yaml`` on first use.

        Falls back to the built-in defaults when ``index.yaml`` does not exist.
        A file whose top level is not a mapping is treated as empty (all
        defaults). Malformed YAML propagates the parser's exception.
        """
        if self._config is not None:
            return self._config

        config_path = os.path.join(self.base_path, 'index.yaml')
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
            # A list/scalar/empty document would break every `.get` lookup.
            self._config = loaded if isinstance(loaded, dict) else {}
        else:
            # Fresh dicts so mutating the returned config cannot alter the class default.
            self._config = {
                section: {'default': dict(settings)}
                for section, settings in self._DEFAULT_SECTIONS.items()
            }
        return self._config

    def _section_settings(self, section: str) -> tuple:
        """Return ``(file_name, sample_count)`` for ``section``.

        Reads ``<section>.default`` from the config; a missing or null key (or
        a null/non-mapping section) falls back to ``_DEFAULT_SECTIONS``.
        """
        fallback = self._DEFAULT_SECTIONS[section]
        section_cfg = self._get_config().get(section)
        default_cfg = section_cfg.get('default') if isinstance(section_cfg, dict) else None
        if not isinstance(default_cfg, dict):
            default_cfg = {}

        file_name = default_cfg.get('file') or fallback['file']
        sample_count = default_cfg.get('sample_count')
        if sample_count is None:
            sample_count = fallback['sample_count']
        return file_name, sample_count

    def _sample_section(self, section: str, count: Optional[int]) -> List[str]:
        """Load a section's examples and randomly sample ``count`` of them."""
        file_name, default_count = self._section_settings(section)
        if count is None:
            count = default_count

        examples = self._extract_examples(self._load_file(file_name))
        if len(examples) <= count:
            return examples
        return random.sample(examples, count)

    @staticmethod
    def _example_text(item: Any) -> str:
        """Return the text of one raw ``examples`` entry.

        Strings are returned as-is and dicts yield their ``content`` field.
        Any other value (e.g. a YAML number) is stringified, and ``None`` maps
        to ``''``. Every entry yields exactly one string, so positions in an
        extracted list match indices into the raw list.
        """
        if isinstance(item, str):
            return item
        if isinstance(item, dict):
            item = item.get('content', '')
        return '' if item is None else str(item)

    def _extract_examples(self, data: Dict) -> List[str]:
        """Return the example texts of a parsed reference file as a NEW list.

        Entries may be plain strings (typical YAML layout) or
        ``{"content": "..."}`` dicts (typical JSON layout); mixed lists are
        handled entry by entry. Non-mapping ``data`` or a missing/non-list
        ``examples`` yields ``[]``. The result never aliases cached data.
        """
        if not isinstance(data, dict):
            return []
        examples = data.get('examples')
        if not isinstance(examples, list):
            return []
        return [self._example_text(item) for item in examples]

    def _read_file(self, file_path: str, file_name: str) -> Any:
        """Parse ``file_path`` as YAML or JSON, chosen by ``file_name``'s extension.

        An empty / ``null`` document is returned as ``{}``. Errors from
        ``open`` or the parser propagate to the caller.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            if file_name.endswith(self._YAML_SUFFIXES):
                data = yaml.safe_load(f)
            else:
                data = json.load(f)
        return {} if data is None else data

    def _load_file(self, file_name: str) -> Dict:
        """Return the parsed contents of ``file_name`` (cached).

        The returned object is the cached one and must not be mutated; use
        ``_extract_examples`` to get a private copy of the texts. On any read
        or parse error the error is logged and ``{'examples': []}`` is
        returned without being cached, so the next call retries.
        """
        if file_name in self._cache:
            return self._cache[file_name]

        file_path = os.path.join(self.base_path, file_name)
        try:
            data = self._read_file(file_path, file_name)
        except Exception as e:  # best-effort read: callers get an empty library
            logger.error(f"加载参考文献失败: {file_path}, {e}")
            return {'examples': []}

        self._cache[file_name] = data
        return data

    def _load_for_write(self, file_name: str) -> Optional[Dict]:
        """Read a fresh, private copy of ``file_name`` for modification.

        Bypasses the cache so an edit never touches the cached object (which
        would otherwise keep unsaved changes if the save failed) and so edits
        start from what is actually on disk. A missing file yields an empty
        library so it can be created.

        Returns:
            A dict whose ``'examples'`` is guaranteed to be a list, or None
            (after logging) when the file exists but cannot be read/parsed or
            has an unexpected layout; a broken file is never overwritten.
        """
        file_path = os.path.join(self.base_path, file_name)
        try:
            data = self._read_file(file_path, file_name)
        except FileNotFoundError:
            return {'examples': []}
        except Exception as e:
            logger.error(f"加载参考文献失败: {file_path}, {e}")
            return None

        if not isinstance(data, dict):
            logger.error(f"参考文献格式无法识别: {file_path}")
            return None
        if data.get('examples') is None:
            data['examples'] = []
        elif not isinstance(data['examples'], list):
            logger.error(f"参考文献格式无法识别: {file_path}")
            return None
        return data

    def _save_file(self, file_name: str, data: Dict) -> bool:
        """Write ``data`` to ``file_name`` atomically and drop its cache entry.

        The payload is written to a sibling ``.tmp`` file first and then moved
        over the target with ``os.replace``, so a serialization or I/O error
        leaves the existing file intact instead of truncated.

        Returns:
            True on success, False (after logging) on any failure.
        """
        file_path = os.path.join(self.base_path, file_name)
        tmp_path = f"{file_path}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                if file_name.endswith(self._YAML_SUFFIXES):
                    yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
                else:
                    json.dump(data, f, ensure_ascii=False, indent=4)
            os.replace(tmp_path, file_path)
        except Exception as e:
            logger.error(f"保存参考文献失败: {file_path}, {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best effort: the temp file may never have been created
            return False

        # The file on disk changed; force the next read to reparse it.
        self._cache.pop(file_name, None)
        return True

    def _add_example(self, file_name: str, content: str) -> bool:
        """Append ``content`` to a file's examples unless it is already present.

        The new entry follows the existing layout: if the list holds
        ``{"content": ...}`` dicts it is wrapped the same way (a bare string
        there would make the file inconsistent); otherwise, including for an
        empty list, a plain string is appended.

        Returns:
            True if saved; False on duplicate or read/write failure.
        """
        data = self._load_for_write(file_name)
        if data is None:
            return False

        if content in self._extract_examples(data):
            logger.warning(f"参考文献已存在: {content[:50]}...")
            return False

        examples = data['examples']
        uses_dicts = bool(examples) and isinstance(examples[0], dict)
        examples.append({'content': content} if uses_dicts else content)

        return self._save_file(file_name, data)

    def _remove_example(self, file_name: str, index: int) -> bool:
        """Remove the entry at ``index``; returns False if out of range or on failure."""
        data = self._load_for_write(file_name)
        if data is None:
            return False

        examples = data['examples']
        if index < 0 or index >= len(examples):
            logger.error(f"索引超出范围: {index}, 总数: {len(examples)}")
            return False

        removed = examples.pop(index)
        logger.info(f"删除参考文献: {self._example_text(removed)[:50]}...")

        return self._save_file(file_name, data)

    def _update_example(self, file_name: str, index: int, content: str) -> bool:
        """Replace the text of the entry at ``index``.

        Dict entries keep their other fields and only ``content`` changes;
        string entries are replaced outright. Returns False if ``index`` is
        out of range or on failure.
        """
        data = self._load_for_write(file_name)
        if data is None:
            return False

        examples = data['examples']
        if index < 0 or index >= len(examples):
            logger.error(f"索引超出范围: {index}, 总数: {len(examples)}")
            return False

        current = examples[index]
        if isinstance(current, dict):
            examples[index] = {**current, 'content': content}
        else:
            examples[index] = content
        logger.info(f"更新参考文献[{index}]: {content[:50]}...")

        return self._save_file(file_name, data)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Lazily created module-level instance handed out by get_reference_manager().
_manager_instance = None


def get_reference_manager() -> ReferenceManager:
    """Return the shared ReferenceManager, constructing it on first use.

    The instance is built with the default ``base_path``; every later call
    returns that same object.
    """
    global _manager_instance
    if _manager_instance is not None:
        return _manager_instance
    _manager_instance = ReferenceManager()
    return _manager_instance
|