TravelContentCreator/domain/prompt/reference_manager.py

414 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
参考文献库管理器
功能:
1. 加载和缓存参考文献
2. 按人群/风格智能匹配
3. 随机抽样
4. 支持增删改查 (通过文件操作)
使用方式:
from domain.prompt.reference_manager import ReferenceManager
manager = ReferenceManager()
# 获取标题参考 (自动匹配)
titles = manager.get_titles(audience_id='qinzi', count=20)
# 获取正文范文 (自动匹配)
contents = manager.get_contents(style_id='gonglue', count=3)
# 添加新的标题参考
manager.add_title("新的爆款标题模板")
# 列出所有参考文献
manager.list_all()
"""
import json
import os
import random
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
import yaml
# Module-level logger used by ReferenceManager for init, cache, load and save messages.
logger = logging.getLogger(__name__)
class ReferenceManager:
    """Manage the reference-material library (title templates and sample articles).

    Each kind of reference ('titles', 'contents') lives in one file under
    ``base_path``. Files ending in ``.yaml``/``.yml`` are parsed as YAML, where
    ``examples`` is a list of strings; any other file is parsed as JSON, where
    ``examples`` is a list of ``{"content": ...}`` objects. Both entry shapes
    are accepted when reading. An optional ``index.yaml`` in ``base_path``
    chooses the file and default sample size for each kind.

    Parsed files are cached on the instance. A successful write drops that
    file's cache entry, and ``clear_cache()`` drops everything.
    """

    # Built-in settings per kind. They are used when index.yaml is absent or
    # unreadable, and they fill in any key index.yaml leaves out, so every
    # read and write method resolves a kind to the same file.
    _DEFAULT_CONFIG: Dict[str, Dict[str, Any]] = {
        'titles': {'file': '标题参考格式.json', 'sample_count': 20},
        'contents': {'file': '正文范文参考.json', 'sample_count': 3},
    }

    def __init__(self, base_path: Optional[str] = None):
        """Create a manager rooted at *base_path*.

        Args:
            base_path: Root directory of the reference library. Defaults to
                ``../../prompts/reference`` relative to this module's directory.
        """
        if base_path is None:
            base_path = os.path.join(
                os.path.dirname(__file__),
                '../../prompts/reference'
            )
        self.base_path = os.path.abspath(base_path)
        self._config: Optional[Dict[str, Any]] = None
        self._cache: Dict[str, Dict[str, Any]] = {}  # file name -> parsed file data
        logger.info(f"ReferenceManager 初始化: {self.base_path}")

    # ==================== Read API ====================

    def get_titles(self,
                   audience_id: Optional[str] = None,
                   style_id: Optional[str] = None,
                   count: Optional[int] = None) -> List[str]:
        """Return a random sample of title references.

        Args:
            audience_id: Audience ID intended for matching; currently unused.
            style_id: Style ID intended for matching; currently unused.
            count: How many titles to draw. ``None`` uses the configured
                ``sample_count``; values <= 0 give an empty list.

        Returns:
            A new list of titles. If the library holds no more than *count*
            entries, all of them are returned in file order. Otherwise a
            random sample without replacement is returned.
        """
        # TODO: choose the source file by audience_id / style_id; for now the
        # 'default' entry of the config is always used.
        return self._sample('titles', count)

    def get_contents(self,
                     audience_id: Optional[str] = None,
                     style_id: Optional[str] = None,
                     count: Optional[int] = None) -> List[str]:
        """Return a random sample of sample articles.

        Args:
            audience_id: Audience ID intended for matching; currently unused.
            style_id: Style ID intended for matching; currently unused.
            count: How many articles to draw. ``None`` uses the configured
                ``sample_count``; values <= 0 give an empty list.

        Returns:
            A new list of article texts, selected with the same rules as
            ``get_titles``.
        """
        # TODO: choose the source file by audience_id / style_id.
        return self._sample('contents', count)

    def list_all(self) -> Dict[str, Any]:
        """Summarize each reference kind.

        Returns:
            ``{'titles': {...}, 'contents': {...}}``. Each value holds:

            - ``file``: the configured file name.
            - ``count``: the number of entries in that file.
            - ``sample_count``: the configured default draw size.
        """
        summary: Dict[str, Any] = {}
        for kind in ('titles', 'contents'):
            settings = self._section_config(kind)
            entries = self._extract_examples(self._load_file(settings['file']))
            summary[kind] = {
                'file': settings['file'],
                'count': len(entries),
                'sample_count': settings['sample_count'],
            }
        return summary

    def get_all_titles(self) -> List[str]:
        """Return every title reference as a new list, in file order."""
        return self._extract_examples(self._load_file(self._file_for('titles')))

    def get_all_contents(self) -> List[str]:
        """Return every sample article as a new list, in file order."""
        return self._extract_examples(self._load_file(self._file_for('contents')))

    # ==================== Write API ====================

    def add_title(self, content: str) -> bool:
        """Append a new title reference.

        Args:
            content: Title text.

        Returns:
            True if the file was written. False if the title already exists,
            the file is unreadable, or the save failed.
        """
        return self._add_example(self._file_for('titles'), content)

    def add_content(self, content: str) -> bool:
        """Append a new sample article.

        Args:
            content: Article text.

        Returns:
            True if the file was written. False if the article already
            exists, the file is unreadable, or the save failed.
        """
        return self._add_example(self._file_for('contents'), content)

    def remove_title(self, index: int) -> bool:
        """Delete a title reference.

        Args:
            index: Position of the title, as in ``get_all_titles()``.

        Returns:
            True if the file was written, False otherwise.
        """
        return self._remove_example(self._file_for('titles'), index)

    def remove_content(self, index: int) -> bool:
        """Delete a sample article.

        Args:
            index: Position of the article, as in ``get_all_contents()``.

        Returns:
            True if the file was written, False otherwise.
        """
        return self._remove_example(self._file_for('contents'), index)

    def update_title(self, index: int, content: str) -> bool:
        """Replace the text of the title at *index*; True if the file was written."""
        return self._update_example(self._file_for('titles'), index, content)

    def update_content(self, index: int, content: str) -> bool:
        """Replace the text of the article at *index*; True if the file was written."""
        return self._update_example(self._file_for('contents'), index, content)

    def clear_cache(self):
        """Drop all cached files and the cached config so they are re-read on next use."""
        self._cache = {}
        self._config = None
        logger.info("参考文献缓存已清除")

    # ==================== Internals ====================

    def _get_config(self) -> Dict:
        """Load ``index.yaml`` once and memoize it.

        Falls back to the built-in defaults, shaped as
        ``{kind: {'default': {'file': ..., 'sample_count': ...}}}``, when:

        - the file is missing;
        - it cannot be read or parsed (the error is logged);
        - it does not contain a mapping.
        """
        if self._config is not None:
            return self._config
        config_path = os.path.join(self.base_path, 'index.yaml')
        loaded: Any = None
        if os.path.exists(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = yaml.safe_load(f) or {}
            except (OSError, ValueError, yaml.YAMLError) as e:
                logger.error(f"加载参考文献配置失败: {config_path}, {e}")
        if not isinstance(loaded, dict):
            if loaded is not None:
                logger.error(f"参考文献配置格式错误 (需要字典): {config_path}")
            loaded = {
                kind: {'default': dict(settings)}
                for kind, settings in self._DEFAULT_CONFIG.items()
            }
        self._config = loaded
        return self._config

    def _section_config(self, kind: str) -> Dict[str, Any]:
        """Return the effective ``default`` settings for *kind*.

        Values from index.yaml override the built-ins. Missing or ``None``
        values keep the built-in value.
        """
        section = self._get_config().get(kind)
        user = section.get('default') if isinstance(section, dict) else None
        merged = dict(self._DEFAULT_CONFIG[kind])
        if isinstance(user, dict):
            merged.update({k: v for k, v in user.items() if v is not None})
        return merged

    def _file_for(self, kind: str) -> str:
        """Return the configured file name for *kind* ('titles' or 'contents')."""
        return self._section_config(kind)['file']

    def _sample(self, kind: str, count: Optional[int]) -> List[str]:
        """Draw up to *count* examples of *kind* at random, without replacement.

        ``count=None`` uses the configured ``sample_count`` and
        ``count <= 0`` returns ``[]``. When the library holds at most *count*
        entries, all of them are returned in file order.
        """
        settings = self._section_config(kind)
        if count is None:
            count = settings['sample_count']
        if count <= 0:
            # random.sample rejects negative sizes; an empty draw is the sensible answer.
            return []
        examples = self._extract_examples(self._load_file(settings['file']))
        if len(examples) <= count:
            return examples  # already a fresh list, safe to hand out
        return random.sample(examples, count)

    def _extract_examples(self, data: Dict) -> List[str]:
        """Return the text of every entry in ``data['examples']`` as a new list.

        Entries may be plain strings (YAML style) or ``{"content": ...}``
        dicts (JSON style), and both may appear in the same list. Nothing is
        filtered out, so the result stays index-aligned with
        ``data['examples']``, which the remove/update methods index into.
        """
        examples = data.get('examples') if isinstance(data, dict) else None
        if not isinstance(examples, list):
            return []
        return [self._entry_text(entry) for entry in examples]

    @staticmethod
    def _entry_text(entry: Any) -> str:
        """Return the text of one example entry.

        Strings are returned as-is, dicts yield their ``content`` value, and
        any other value is converted with ``str()``.
        """
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            return entry.get('content', '')
        return str(entry)

    def _try_load(self, file_name: str) -> Optional[Dict]:
        """Return the parsed (and cached) contents of *file_name*, or None on failure.

        The parser is chosen by extension: ``.yaml``/``.yml`` files use
        ``yaml.safe_load`` and everything else uses ``json.load``. Missing
        files, parse errors and a top level that is not a mapping are logged
        and return ``None``. Only successful parses are cached.
        """
        if file_name in self._cache:
            return self._cache[file_name]
        file_path = os.path.join(self.base_path, file_name)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                if file_name.endswith(('.yaml', '.yml')):
                    data = yaml.safe_load(f) or {}
                else:
                    data = json.load(f)
        except (OSError, ValueError, yaml.YAMLError) as e:
            # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"加载参考文献失败: {file_path}, {e}")
            return None
        if not isinstance(data, dict):
            logger.error(f"加载参考文献失败: {file_path}, 顶层结构不是字典")
            return None
        self._cache[file_name] = data
        return data

    def _load_file(self, file_name: str) -> Dict:
        """Load a reference file with caching.

        Returns ``{'examples': []}`` (uncached) when the file is missing or
        unreadable. The returned dict may be the cached object, so callers
        must not mutate it; writers use ``_load_for_write`` instead.
        """
        data = self._try_load(file_name)
        return {'examples': []} if data is None else data

    def _load_for_write(self, file_name: str) -> Optional[Dict]:
        """Return an editable copy of *file_name*'s data, or None if it must not be written.

        - A file that neither exists nor is cached yields a fresh
          ``{'examples': []}`` so it can be created.
        - A file that exists but cannot be parsed, or whose ``examples`` is
          not a list, yields ``None``. Callers then abort instead of
          overwriting it and losing its contents.

        The copy owns its ``examples`` list, so the cache is not affected
        unless the save succeeds. Entry dicts are still shared with the
        cache, so writers must replace entries rather than mutate them.
        """
        file_path = os.path.join(self.base_path, file_name)
        if file_name not in self._cache and not os.path.exists(file_path):
            return {'examples': []}
        data = self._try_load(file_name)
        if data is None:
            return None
        examples = data.get('examples')
        if examples is None:
            examples = []
        elif not isinstance(examples, list):
            logger.error(f"参考文献格式错误, examples 不是列表: {file_path}")
            return None
        editable = dict(data)
        editable['examples'] = list(examples)
        return editable

    def _save_file(self, file_name: str, data: Dict) -> bool:
        """Write *data* to *file_name* (YAML or JSON by extension) and drop its cache entry.

        The data is dumped to a sibling ``.tmp`` file, which is then moved
        over the target with ``os.replace``. A failed dump therefore leaves
        the existing file untouched. Returns True on success, or False after
        logging the error.
        """
        file_path = os.path.join(self.base_path, file_name)
        tmp_path = file_path + '.tmp'
        try:
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                if file_name.endswith(('.yaml', '.yml')):
                    yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
                else:
                    json.dump(data, f, ensure_ascii=False, indent=4)
            os.replace(tmp_path, file_path)
        except (OSError, TypeError, ValueError, yaml.YAMLError) as e:
            logger.error(f"保存参考文献失败: {file_path}, {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best effort: the temp file may never have been created
            return False
        # Force the next read to come from disk.
        self._cache.pop(file_name, None)
        return True

    def _add_example(self, file_name: str, content: str) -> bool:
        """Append *content* to *file_name* unless an identical entry already exists.

        If any existing entry is a dict, the new entry is written as
        ``{"content": ...}``; otherwise it is a plain string. This keeps the
        file readable in the shape it already uses.
        """
        data = self._load_for_write(file_name)
        if data is None:
            return False
        examples = data['examples']
        if content in self._extract_examples(data):
            logger.warning(f"参考文献已存在: {content[:50]}...")
            return False
        uses_dicts = any(isinstance(entry, dict) for entry in examples)
        examples.append({'content': content} if uses_dicts else content)
        return self._save_file(file_name, data)

    def _remove_example(self, file_name: str, index: int) -> bool:
        """Delete the entry at *index* from *file_name*; False if out of range or unwritable."""
        data = self._load_for_write(file_name)
        if data is None:
            return False
        examples = data['examples']
        if not 0 <= index < len(examples):
            logger.error(f"索引超出范围: {index}, 总数: {len(examples)}")
            return False
        removed = examples.pop(index)
        logger.info(f"删除参考文献: {self._entry_text(removed)[:50]}...")
        return self._save_file(file_name, data)

    def _update_example(self, file_name: str, index: int, content: str) -> bool:
        """Replace the text of the entry at *index*; False if out of range or unwritable.

        A dict entry keeps its other keys and only ``content`` changes, so the
        list never ends up mixing formats unexpectedly.
        """
        data = self._load_for_write(file_name)
        if data is None:
            return False
        examples = data['examples']
        if not 0 <= index < len(examples):
            logger.error(f"索引超出范围: {index}, 总数: {len(examples)}")
            return False
        current = examples[index]
        # Build a new dict rather than mutating `current`, which is shared with the cache.
        examples[index] = {**current, 'content': content} if isinstance(current, dict) else content
        logger.info(f"更新参考文献[{index}]: {content[:50]}...")
        return self._save_file(file_name, data)
# Process-wide ReferenceManager, created lazily by get_reference_manager().
_manager_instance = None


def get_reference_manager() -> ReferenceManager:
    """Return the shared ReferenceManager, constructing it on the first call.

    Every later call returns the same object, so its file and config caches
    are shared by all callers that use this accessor.
    """
    global _manager_instance
    instance = _manager_instance
    if instance is None:
        instance = ReferenceManager()
        _manager_instance = instance
    return instance