#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 参考文献库管理器 功能: 1. 加载和缓存参考文献 2. 按人群/风格智能匹配 3. 随机抽样 4. 支持增删改查 (通过文件操作) 使用方式: from domain.prompt.reference_manager import ReferenceManager manager = ReferenceManager() # 获取标题参考 (自动匹配) titles = manager.get_titles(audience_id='qinzi', count=20) # 获取正文范文 (自动匹配) contents = manager.get_contents(style_id='gonglue', count=3) # 添加新的标题参考 manager.add_title("新的爆款标题模板") # 列出所有参考文献 manager.list_all() """ import json import os import random import logging from typing import List, Dict, Any, Optional from pathlib import Path import yaml logger = logging.getLogger(__name__) class ReferenceManager: """参考文献库管理器""" def __init__(self, base_path: str = None): """ 初始化管理器 Args: base_path: 参考文献库根目录,默认为 prompts/reference """ if base_path is None: base_path = os.path.join( os.path.dirname(__file__), '../../prompts/reference' ) self.base_path = os.path.abspath(base_path) self._config = None self._cache = {} # 文件缓存 logger.info(f"ReferenceManager 初始化: {self.base_path}") # ==================== 读取接口 ==================== def get_titles(self, audience_id: str = None, style_id: str = None, count: int = None) -> List[str]: """ 获取标题参考 Args: audience_id: 人群 ID (用于智能匹配) style_id: 风格 ID (用于智能匹配) count: 抽取数量,None 表示使用配置默认值 Returns: 标题列表 """ config = self._get_config() titles_config = config.get('titles', {}).get('default', {}) # TODO: 实现按 audience_id/style_id 匹配 # 目前使用默认配置 file_name = titles_config.get('file', '标题参考格式.json') default_count = titles_config.get('sample_count', 20) if count is None: count = default_count # 加载文件 data = self._load_file(file_name) examples = self._extract_examples(data) # 随机抽样 if len(examples) <= count: return examples return random.sample(examples, count) def get_contents(self, audience_id: str = None, style_id: str = None, count: int = None) -> List[str]: """ 获取正文范文 Args: audience_id: 人群 ID style_id: 风格 ID count: 抽取数量 Returns: 正文范文列表 """ config = self._get_config() contents_config = config.get('contents', {}).get('default', {}) file_name = contents_config.get('file', '正文范文参考.json') default_count = contents_config.get('sample_count', 3) if count is None: count = default_count data = self._load_file(file_name) examples = self._extract_examples(data) if len(examples) <= count: return examples return random.sample(examples, count) def list_all(self) -> Dict[str, Any]: """ 列出所有参考文献统计 Returns: { 'titles': {'count': 100, 'file': '...'}, 'contents': {'count': 10, 'file': '...'} } """ config = self._get_config() result = {} # 标题统计 titles_file = config.get('titles', {}).get('default', {}).get('file') if titles_file: data = self._load_file(titles_file) examples = self._extract_examples(data) result['titles'] = { 'file': titles_file, 'count': len(examples), 'sample_count': config.get('titles', {}).get('default', {}).get('sample_count', 20) } # 正文统计 contents_file = config.get('contents', {}).get('default', {}).get('file') if contents_file: data = self._load_file(contents_file) examples = self._extract_examples(data) result['contents'] = { 'file': contents_file, 'count': len(examples), 'sample_count': config.get('contents', {}).get('default', {}).get('sample_count', 3) } return result def get_all_titles(self) -> List[str]: """获取所有标题参考""" config = self._get_config() file_name = config.get('titles', {}).get('default', {}).get('file', 'titles.yaml') data = self._load_file(file_name) return self._extract_examples(data) def get_all_contents(self) -> List[str]: """获取所有正文范文""" config = self._get_config() file_name = config.get('contents', {}).get('default', {}).get('file', 'contents.yaml') data = self._load_file(file_name) return self._extract_examples(data) # ==================== 写入接口 ==================== def add_title(self, content: str) -> bool: """ 添加新的标题参考 Args: content: 标题内容 Returns: 是否成功 """ config = self._get_config() file_name = config.get('titles', {}).get('default', {}).get('file', '标题参考格式.json') return self._add_example(file_name, content) def add_content(self, content: str) -> bool: """ 添加新的正文范文 Args: content: 正文内容 Returns: 是否成功 """ config = self._get_config() file_name = config.get('contents', {}).get('default', {}).get('file', '正文范文参考.json') return self._add_example(file_name, content) def remove_title(self, index: int) -> bool: """ 删除标题参考 Args: index: 索引位置 Returns: 是否成功 """ config = self._get_config() file_name = config.get('titles', {}).get('default', {}).get('file', '标题参考格式.json') return self._remove_example(file_name, index) def remove_content(self, index: int) -> bool: """ 删除正文范文 Args: index: 索引位置 Returns: 是否成功 """ config = self._get_config() file_name = config.get('contents', {}).get('default', {}).get('file', 'contents.yaml') return self._remove_example(file_name, index) def update_title(self, index: int, content: str) -> bool: """更新标题参考""" config = self._get_config() file_name = config.get('titles', {}).get('default', {}).get('file', 'titles.yaml') return self._update_example(file_name, index, content) def update_content(self, index: int, content: str) -> bool: """更新正文范文""" config = self._get_config() file_name = config.get('contents', {}).get('default', {}).get('file', 'contents.yaml') return self._update_example(file_name, index, content) def clear_cache(self): """清除缓存""" self._cache = {} self._config = None logger.info("参考文献缓存已清除") # ==================== 内部方法 ==================== def _get_config(self) -> Dict: """获取配置""" if self._config is not None: return self._config config_path = os.path.join(self.base_path, 'index.yaml') if os.path.exists(config_path): with open(config_path, 'r', encoding='utf-8') as f: self._config = yaml.safe_load(f) or {} else: # 默认配置 self._config = { 'titles': { 'default': { 'file': '标题参考格式.json', 'sample_count': 20 } }, 'contents': { 'default': { 'file': '正文范文参考.json', 'sample_count': 3 } } } return self._config def _extract_examples(self, data: Dict) -> List[str]: """ 从数据中提取示例列表 支持两种格式: 1. YAML 格式: examples 是字符串列表 2. JSON 格式: examples 是 [{"content": "..."}, ...] 列表 """ examples = data.get('examples', []) if not examples: return [] # 判断格式 if isinstance(examples[0], str): # YAML 格式: 直接是字符串列表 return examples elif isinstance(examples[0], dict): # JSON 格式: 需要提取 content 字段 return [ex.get('content', '') for ex in examples] else: return [] def _load_file(self, file_name: str) -> Dict: """加载参考文献文件 (支持 YAML 和 JSON,带缓存)""" if file_name in self._cache: return self._cache[file_name] file_path = os.path.join(self.base_path, file_name) try: with open(file_path, 'r', encoding='utf-8') as f: # 根据扩展名选择解析器 if file_name.endswith('.yaml') or file_name.endswith('.yml'): data = yaml.safe_load(f) or {} else: data = json.load(f) self._cache[file_name] = data return data except Exception as e: logger.error(f"加载参考文献失败: {file_path}, {e}") return {'examples': []} def _save_file(self, file_name: str, data: Dict) -> bool: """保存参考文献文件 (支持 YAML 和 JSON)""" file_path = os.path.join(self.base_path, file_name) try: with open(file_path, 'w', encoding='utf-8') as f: if file_name.endswith('.yaml') or file_name.endswith('.yml'): yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False) else: json.dump(data, f, ensure_ascii=False, indent=4) # 清除缓存 if file_name in self._cache: del self._cache[file_name] return True except Exception as e: logger.error(f"保存参考文献失败: {file_path}, {e}") return False def _add_example(self, file_name: str, content: str) -> bool: """添加示例""" data = self._load_file(file_name) examples = self._extract_examples(data) # 检查重复 if content in examples: logger.warning(f"参考文献已存在: {content[:50]}...") return False data.setdefault('examples', []).append(content) return self._save_file(file_name, data) def _remove_example(self, file_name: str, index: int) -> bool: """删除示例""" data = self._load_file(file_name) examples = data.get('examples', []) if index < 0 or index >= len(examples): logger.error(f"索引超出范围: {index}, 总数: {len(examples)}") return False removed = examples.pop(index) removed_str = removed if isinstance(removed, str) else removed.get('content', '') logger.info(f"删除参考文献: {removed_str[:50]}...") return self._save_file(file_name, data) def _update_example(self, file_name: str, index: int, content: str) -> bool: """更新示例""" data = self._load_file(file_name) examples = data.get('examples', []) if index < 0 or index >= len(examples): logger.error(f"索引超出范围: {index}, 总数: {len(examples)}") return False examples[index] = content logger.info(f"更新参考文献[{index}]: {content[:50]}...") return self._save_file(file_name, data) # 全局单例 _manager_instance = None def get_reference_manager() -> ReferenceManager: """获取全局 ReferenceManager 实例""" global _manager_instance if _manager_instance is None: _manager_instance = ReferenceManager() return _manager_instance