refer模块正确抽样

This commit is contained in:
jinye_huang 2025-05-08 15:21:02 +08:00
parent 5c5d03d2a3
commit bbaa6eee53

View File

@ -305,232 +305,112 @@ class PromptManager:
return None return None
def _get_all_refer_contents(self, random_sample=True):
    """Build one combined text block from every available Refer file.

    Files come from two places: entries already present in
    ``self._refer_cache`` and, when ``self.prompts_dir`` is set, regular
    files inside its ``Refer`` sub-directory that are not cached yet
    (those are loaded through ``ResourceLoader.load_file_content`` and
    added to the cache when non-empty).

    Every file is always included; sampling applies to the *content* of
    each file, never to the set of files.

    Args:
        random_sample: When True, files with more than 10 examples
            (JSON objects) or more than 10 lines (plain text) are reduced
            to ``max(10, int(count * self._sample_rate))`` items, capped
            at the number of items available.

    Returns:
        str: The concatenated sections, each introduced by a
        ``--- Refer File: <name> ---`` header and followed by a blank
        line. An empty string when no Refer file is found.
    """
    import json

    def _sample_count(total):
        # Keep at least 10 items, but never ask random.sample for more
        # than the population holds (it raises ValueError if we do,
        # which would happen for a sample rate above 1).
        return min(total, max(10, int(total * self._sample_rate)))

    # 1. Start from whatever is already cached.
    all_refer_files = {}
    if self._refer_cache:
        all_refer_files.update(self._refer_cache)

    # 2. Add files from the local Refer directory that are not cached yet.
    if self.prompts_dir:
        refer_dir = os.path.join(self.prompts_dir, "Refer")
        if os.path.isdir(refer_dir):
            for refer_file in os.listdir(refer_dir):
                refer_path = os.path.join(refer_dir, refer_file)
                if refer_file in all_refer_files or not os.path.isfile(refer_path):
                    continue
                content = ResourceLoader.load_file_content(refer_path)
                if content:
                    all_refer_files[refer_file] = content
                    # Remember the file so later calls skip the disk read.
                    self._refer_cache[refer_file] = content

    if not all_refer_files:
        logging.warning("没有找到任何Refer文件")
        return ""

    logging.info(f"找到{len(all_refer_files)}个Refer文件")

    # 3. Render every file; pieces are joined once at the end.
    pieces = []
    for filename, content in all_refer_files.items():
        json_data = None
        if filename.lower().endswith('.json'):
            try:
                json_data = json.loads(content)
                logging.info(f"成功解析JSON格式的refer文件: {filename}")
            except json.JSONDecodeError as e:
                logging.warning(f"文件{filename}扩展名为.json但内容不是有效的JSON格式: {str(e)}")
                logging.info(f"将以文本格式处理文件: {filename}")

        pieces.append(f"--- Refer File: {filename} ---\n")

        # Only a non-empty JSON object has the title/description/examples
        # layout; any other JSON value (list, number, ...) is emitted as
        # plain text instead of crashing on ``.get``.
        if isinstance(json_data, dict) and json_data:
            title = json_data.get("title", "未命名参考资料")
            description = json_data.get("description", "")
            examples = json_data.get("examples", [])
            if not isinstance(examples, list):
                # Cannot be enumerated or sampled as a sequence of examples.
                examples = []

            pieces.append(f"标题: {title}\n")
            pieces.append(f"描述: {description}\n\n")

            if examples:
                chosen = examples
                if random_sample and len(examples) > 10:
                    sample_size = _sample_count(len(examples))
                    chosen = random.sample(examples, sample_size)
                    logging.info(f"从文件{filename}的JSON中随机抽样了{sample_size}/{len(examples)}个示例")

                pieces.append("示例:\n")
                for idx, example in enumerate(chosen, 1):
                    # Tolerate examples stored as bare values rather than
                    # ``{"content": ...}`` objects.
                    if isinstance(example, dict):
                        content_text = example.get("content", "")
                    else:
                        content_text = example
                    pieces.append(f"{idx}. {content_text}\n")
        else:
            text = content
            if random_sample:
                lines = content.split('\n')
                if len(lines) > 10:  # only long content is worth sampling
                    sample_size = _sample_count(len(lines))
                    # Sample positions, not values: sorting positions keeps
                    # the original order even when lines are duplicated and
                    # avoids a quadratic ``lines.index`` lookup.
                    positions = sorted(random.sample(range(len(lines)), sample_size))
                    text = '\n'.join(lines[pos] for pos in positions)
                    logging.info(f"从文件{filename}中随机抽样了{sample_size}/{len(lines)}行内容")
            pieces.append(f"{text}\n")

        # Blank line separating this file from the next one.
        pieces.append("\n")

    return "".join(pieces)