diff --git a/utils/__pycache__/prompts.cpython-312.pyc b/utils/__pycache__/prompts.cpython-312.pyc index 03ca2c1..8a9ffe9 100644 Binary files a/utils/__pycache__/prompts.cpython-312.pyc and b/utils/__pycache__/prompts.cpython-312.pyc differ diff --git a/utils/prompts.py b/utils/prompts.py index 8dab0a1..69dc7ba 100644 --- a/utils/prompts.py +++ b/utils/prompts.py @@ -140,8 +140,30 @@ class BasePromptBuilder(PromptTemplate): data = json.load(f) if "examples" in data and isinstance(data["examples"], list): - formatted_examples = [f"- {item.get('content', '')}" for item in data["examples"]] - return f"参考标题列表:\n" + "\n".join(formatted_examples) + examples = data["examples"] + return f"参考标题列表:\n" + "\n".join([f"- {item.get('content', '')}" for item in examples]) + else: + return json.dumps(data, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"解析或格式化JSON文件 '{path}' 失败: {e}") + return f"加载文件 '{path.name}' 失败。" + else: + return path.read_text('utf-8') + + def _load_and_format_content_with_sampling(self, path: Path, sampling_rate: float) -> str: + """根据文件类型加载和格式化内容,并应用采样率""" + if path.suffix == '.json': + try: + with path.open('r', encoding='utf-8') as f: + data = json.load(f) + + if "examples" in data and isinstance(data["examples"], list): + examples = data["examples"] + # 应用采样率 + sample_size = max(1, int(len(examples) * sampling_rate)) + sampled_examples = random.sample(examples, sample_size) + logger.info(f"文件 '{path.name}' 中的examples采样: {sample_size}/{len(examples)} (采样率: {sampling_rate:.2f})") + return f"参考标题列表:\n" + "\n".join([f"- {item.get('content', '')}" for item in sampled_examples]) else: return json.dumps(data, ensure_ascii=False, indent=2) except Exception as e: @@ -187,25 +209,37 @@ class BasePromptBuilder(PromptTemplate): full_path = self._get_full_path(path_str) - files_to_read = [] + # 简化逻辑:对于单个文件,直接应用采样率决定是否加载 if full_path.is_file(): - if random.random() < sampling_rate: - files_to_read.append(full_path) - logger.info(f"文件 '{path_str}' 采样成功 (采样率: {sampling_rate})") + # 对于JSON文件,对内容进行采样 + if full_path.suffix == '.json': + file_content = self._load_and_format_content_with_sampling(full_path, sampling_rate) + content_parts.append(f"--- {full_path.name} ---\n{file_content}") + logger.info(f"加载JSON文件 '{path_str}' 并应用内部采样") + # 对于其他文件,根据采样率决定是否完全加载 + elif random.random() < sampling_rate: + file_content = self._load_and_format_content(full_path) + content_parts.append(f"--- {full_path.name} ---\n{file_content}") + logger.info(f"文件 '{path_str}' 采样成功 (采样率: {sampling_rate:.2f})") else: - logger.info(f"文件 '{path_str}' 采样失败 (采样率: {sampling_rate})") + logger.info(f"文件 '{path_str}' 采样失败 (采样率: {sampling_rate:.2f})") + # 对于目录,直接选择指定比例的文件 elif full_path.is_dir(): all_files = sorted(p for p in full_path.iterdir() if p.is_file()) - if sampling_rate < 1.0: - num_to_sample = max(1, int(len(all_files) * sampling_rate)) - files_to_read = random.sample(all_files, num_to_sample) - logger.info(f"对目录 '{path_str}' 进行采样 (采样率: {sampling_rate}),选取 {len(files_to_read)}/{len(all_files)} 个文件。") + if all_files: + if sampling_rate < 1.0: + sample_size = max(1, int(len(all_files) * sampling_rate)) + files_to_read = random.sample(all_files, sample_size) + logger.info(f"目录 '{path_str}' 采样: {sample_size}/{len(all_files)} 个文件 (采样率: {sampling_rate:.2f})") + else: + files_to_read = all_files + logger.info(f"目录 '{path_str}' 全部加载: {len(all_files)} 个文件") + + for f_path in files_to_read: + file_content = self._load_and_format_content(f_path) + content_parts.append(f"--- {f_path.name} ---\n{file_content}") else: - files_to_read = all_files - - for f_path in files_to_read: - file_content = self._load_and_format_content(f_path) - content_parts.append(f"--- {f_path.name} ---\n{file_content}") + logger.warning(f"目录 '{path_str}' 中没有文件") except Exception as e: logger.error(f"加载Refer资源 '{ref_item}' 失败: {e}", exc_info=True)