From 5c5d03d2a32e6113058bfb5ef3766f5da96444a0 Mon Sep 17 00:00:00 2001 From: jinye_huang Date: Thu, 8 May 2025 14:45:18 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E6=8F=90=E7=A4=BA?= =?UTF-8?q?=E8=AF=8D=E9=9A=8F=E6=9C=BA=E8=B0=83=E5=BA=A6=E7=9A=84=E6=96=B9?= =?UTF-8?q?=E6=A1=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/prompt_manager.py | 236 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 222 insertions(+), 14 deletions(-) diff --git a/utils/prompt_manager.py b/utils/prompt_manager.py index 3c68b74..be881c1 100644 --- a/utils/prompt_manager.py +++ b/utils/prompt_manager.py @@ -8,7 +8,9 @@ import os import traceback import logging # Add logging import re # 添加正则表达式支持 +import random # 添加随机模块支持 from .resource_loader import ResourceLoader # Use relative import within the same package +import json class PromptManager: """Handles the loading and construction of prompts.""" @@ -40,6 +42,7 @@ class PromptManager: self._user_prompt_cache = {} # 新增:用户提示词缓存 self._dateline_cache = None # 新增:日期线缓存 + self._sample_rate = 0.4 # 提高随机抽样率 # 初始化时预加载配置的文件 self._preload_prompt_files() @@ -110,7 +113,7 @@ class PromptManager: if os.path.exists(path): filename = os.path.basename(path) # print(filename) - content = ResourceLoader.load_all_refer_files(path, 0.25) + content = ResourceLoader.load_all_refer_files(path, 1) if content: self._refer_cache[filename] = content # print(content) @@ -301,13 +304,118 @@ class PromptManager: logging.warning(f"未能找到Demand文件: '{demand_name}',尝试过以下位置: 缓存, {self.prompts_dir}/Demand/") return None - def _get_all_refer_contents(self): - """获取所有Refer文件内容""" + def _get_all_refer_contents(self, random_sample=True): + """获取所有Refer文件内容,可选择随机抽样 + + Args: + random_sample: 是否进行随机抽样,默认为True + + Returns: + str: 组合后的refer内容 + """ # 如果缓存中有内容,先使用缓存 if self._refer_cache: refer_content_all = "" - for filename, content in self._refer_cache.items(): - refer_content_all += f"--- Refer File: {filename} ---\n{content}\n\n" + + # 获取所有refer文件名列表 + refer_files = list(self._refer_cache.keys()) + + if random_sample and len(refer_files) > 1: + # 随机决定要使用的refer文件数量,至少使用一个 + sample_size = max(1, int(len(refer_files) * self._sample_rate)) + # 随机选择要使用的refer文件 + sampled_files = random.sample(refer_files, sample_size) + logging.info(f"随机抽样了{sample_size}/{len(refer_files)}个refer文件: {sampled_files}") + + for filename in sampled_files: + content = self._refer_cache[filename] + + # 判断是否为JSON格式 + is_json = False + json_data = None + if filename.lower().endswith('.json'): + try: + json_data = json.loads(content) + is_json = True + logging.info(f"检测到JSON格式的refer文件: {filename}") + except json.JSONDecodeError: + logging.warning(f"文件{filename}扩展名为.json但内容不是有效的JSON格式") + + if is_json and json_data: + # 处理JSON格式的refer文件 + title = json_data.get("title", "未命名参考资料") + description = json_data.get("description", "") + examples = json_data.get("examples", []) + + refer_content_all += f"--- Refer File: {filename} ---\n" + refer_content_all += f"标题: {title}\n" + refer_content_all += f"描述: {description}\n\n" + + if examples and random_sample: + # 对examples数组进行抽样 + sample_examples_count = max(5, int(len(examples) * self._sample_rate)) + sampled_examples = random.sample(examples, sample_examples_count) + logging.info(f"从文件{filename}中随机抽样了{sample_examples_count}/{len(examples)}个示例") + + refer_content_all += "示例:\n" + for idx, example in enumerate(sampled_examples, 1): + content_text = example.get("content", "") + refer_content_all += f"{idx}. {content_text}\n" + elif examples: + # 不抽样,使用所有examples + refer_content_all += "示例:\n" + for idx, example in enumerate(examples, 1): + content_text = example.get("content", "") + refer_content_all += f"{idx}. {content_text}\n" + + refer_content_all += "\n" + else: + # 处理普通文本格式的refer文件 + # 对文件内容进行行级抽样 + if random_sample and content: + lines = content.split('\n') + if len(lines) > 5: # 只对较长的内容进行抽样 + sample_lines_count = max(5, int(len(lines) * self._sample_rate)) + sampled_lines = random.sample(lines, sample_lines_count) + # 保持原有顺序 + sampled_lines.sort(key=lambda line: lines.index(line)) + content = '\n'.join(sampled_lines) + logging.info(f"从文件{filename}中随机抽样了{sample_lines_count}/{len(lines)}行内容") + + refer_content_all += f"--- Refer File: {filename} ---\n{content}\n\n" + else: + # 不进行抽样,使用所有文件 + for filename, content in self._refer_cache.items(): + # 判断是否为JSON格式 + is_json = False + json_data = None + if filename.lower().endswith('.json'): + try: + json_data = json.loads(content) + is_json = True + except json.JSONDecodeError: + logging.warning(f"文件{filename}扩展名为.json但内容不是有效的JSON格式") + + if is_json and json_data: + # 处理JSON格式的refer文件 + title = json_data.get("title", "未命名参考资料") + description = json_data.get("description", "") + examples = json_data.get("examples", []) + + refer_content_all += f"--- Refer File: {filename} ---\n" + refer_content_all += f"标题: {title}\n" + refer_content_all += f"描述: {description}\n\n" + refer_content_all += "示例:\n" + + for idx, example in enumerate(examples, 1): + content_text = example.get("content", "") + refer_content_all += f"{idx}. {content_text}\n" + + refer_content_all += "\n" + else: + # 处理普通文本格式的refer文件 + refer_content_all += f"--- Refer File: {filename} ---\n{content}\n\n" + return refer_content_all # 如果缓存为空,尝试从prompts_dir加载(向后兼容) @@ -316,13 +424,113 @@ class PromptManager: refer_dir = os.path.join(self.prompts_dir, "Refer") if os.path.isdir(refer_dir): refer_files = [f for f in os.listdir(refer_dir) if os.path.isfile(os.path.join(refer_dir, f))] - for refer_file in refer_files: - refer_path = os.path.join(refer_dir, refer_file) - content = ResourceLoader.load_file_content(refer_path) - if content: - refer_content_all += f"--- Refer File: {refer_file} ---\n{content}\n\n" - # 保存到缓存 - self._refer_cache[refer_file] = content + + if random_sample and len(refer_files) > 1: + # 随机决定要使用的refer文件数量,至少使用一个 + sample_size = max(1, int(len(refer_files) * self._sample_rate)) + # 随机选择要使用的refer文件 + sampled_files = random.sample(refer_files, sample_size) + logging.info(f"从目录随机抽样了{sample_size}/{len(refer_files)}个refer文件: {sampled_files}") + + for refer_file in sampled_files: + refer_path = os.path.join(refer_dir, refer_file) + content = ResourceLoader.load_file_content(refer_path) + + if content: + # 判断是否为JSON格式 + is_json = False + json_data = None + if refer_file.lower().endswith('.json'): + try: + json_data = json.loads(content) + is_json = True + logging.info(f"从目录检测到JSON格式的refer文件: {refer_file}") + except json.JSONDecodeError: + logging.warning(f"目录中的文件{refer_file}扩展名为.json但内容不是有效的JSON格式") + + if is_json and json_data: + # 处理JSON格式的refer文件 + title = json_data.get("title", "未命名参考资料") + description = json_data.get("description", "") + examples = json_data.get("examples", []) + + refer_content_all += f"--- Refer File: {refer_file} ---\n" + refer_content_all += f"标题: {title}\n" + refer_content_all += f"描述: {description}\n\n" + + if examples and random_sample: + # 对examples数组进行抽样 + sample_examples_count = max(5, int(len(examples) * self._sample_rate)) + sampled_examples = random.sample(examples, sample_examples_count) + logging.info(f"从文件{refer_file}中随机抽样了{sample_examples_count}/{len(examples)}个示例") + + refer_content_all += "示例:\n" + for idx, example in enumerate(sampled_examples, 1): + content_text = example.get("content", "") + refer_content_all += f"{idx}. {content_text}\n" + elif examples: + # 不抽样,使用所有examples + refer_content_all += "示例:\n" + for idx, example in enumerate(examples, 1): + content_text = example.get("content", "") + refer_content_all += f"{idx}. {content_text}\n" + + refer_content_all += "\n" + else: + # 对文件内容进行行级抽样 + if random_sample: + lines = content.split('\n') + if len(lines) > 5: # 只对较长的内容进行抽样 + sample_lines_count = max(5, int(len(lines) * self._sample_rate)) + sampled_lines = random.sample(lines, sample_lines_count) + # 保持原有顺序 + sampled_lines.sort(key=lambda line: lines.index(line)) + content = '\n'.join(sampled_lines) + logging.info(f"从目录文件{refer_file}中随机抽样了{sample_lines_count}/{len(lines)}行内容") + + refer_content_all += f"--- Refer File: {refer_file} ---\n{content}\n\n" + + # 保存到缓存 + self._refer_cache[refer_file] = content + else: + # 不进行抽样,使用所有文件 + for refer_file in refer_files: + refer_path = os.path.join(refer_dir, refer_file) + content = ResourceLoader.load_file_content(refer_path) + + if content: + # 判断是否为JSON格式 + is_json = False + json_data = None + if refer_file.lower().endswith('.json'): + try: + json_data = json.loads(content) + is_json = True + except json.JSONDecodeError: + logging.warning(f"目录中的文件{refer_file}扩展名为.json但内容不是有效的JSON格式") + + if is_json and json_data: + # 处理JSON格式的refer文件 + title = json_data.get("title", "未命名参考资料") + description = json_data.get("description", "") + examples = json_data.get("examples", []) + + refer_content_all += f"--- Refer File: {refer_file} ---\n" + refer_content_all += f"标题: {title}\n" + refer_content_all += f"描述: {description}\n\n" + refer_content_all += "示例:\n" + + for idx, example in enumerate(examples, 1): + content_text = example.get("content", "") + refer_content_all += f"{idx}. {content_text}\n" + + refer_content_all += "\n" + else: + # 处理普通文本格式的refer文件 + refer_content_all += f"--- Refer File: {refer_file} ---\n{content}\n\n" + + # 保存到缓存 + self._refer_cache[refer_file] = content return refer_content_all @@ -469,8 +677,8 @@ class PromptManager: else: logging.warning(f"Demand content for '{demand_name}' not found.") - # Add refer contents - refers_content = self._get_all_refer_contents() + # Add refer contents - 现在使用随机抽样 + refers_content = self._get_all_refer_contents(random_sample=True) if refers_content: refers = f"Reference:\n{refers_content}\n\n"