diff --git a/poster_gen_config.json b/poster_gen_config.json index ea96449..551edbd 100644 --- a/poster_gen_config.json +++ b/poster_gen_config.json @@ -33,8 +33,8 @@ { "type": "Refer", "file_path": [ - "./genPrompts/Refer/标题参考格式.txt", - "./genPrompts/Refer/正文开头引入段落参考.txt" + "./genPrompts/Refer/标题参考格式.json", + "./genPrompts/Refer/正文开头引入段落参考.json" ] } ], diff --git a/scripts/test_refer.py b/scripts/test_refer.py new file mode 100644 index 0000000..8c8d0d3 --- /dev/null +++ b/scripts/test_refer.py @@ -0,0 +1,6 @@ +import json + +with open("genPrompts/Refer/标题参考格式.json", "r", encoding="utf-8") as f: + data = json.load(f) + +print(data) diff --git a/utils/__pycache__/prompt_manager.cpython-312.pyc b/utils/__pycache__/prompt_manager.cpython-312.pyc index e75407e..09861d8 100644 Binary files a/utils/__pycache__/prompt_manager.cpython-312.pyc and b/utils/__pycache__/prompt_manager.cpython-312.pyc differ diff --git a/utils/__pycache__/resource_loader.cpython-312.pyc b/utils/__pycache__/resource_loader.cpython-312.pyc index 5ecade9..8158426 100644 Binary files a/utils/__pycache__/resource_loader.cpython-312.pyc and b/utils/__pycache__/resource_loader.cpython-312.pyc differ diff --git a/utils/prompt_manager.py b/utils/prompt_manager.py index affb5c2..e3075eb 100644 --- a/utils/prompt_manager.py +++ b/utils/prompt_manager.py @@ -71,6 +71,7 @@ class PromptManager: if os.path.exists(dateline_path): self._dateline_cache = ResourceLoader.load_file_content(dateline_path) logging.info(f"预加载日期线文件: {dateline_path}") + # 加载prompts_config配置的文件 if not self.prompts_config: @@ -92,9 +93,12 @@ class PromptManager: elif prompt_type == "demand": for path in file_paths: + # print(path) if os.path.exists(path): filename = os.path.basename(path) + # print(filename) content = ResourceLoader.load_file_content(path) + # print(content) if content: self._demand_cache[filename] = content name_without_ext = os.path.splitext(filename)[0] @@ -102,11 +106,14 @@ class PromptManager: elif prompt_type == "refer": for path in file_paths: + # print(path) if os.path.exists(path): filename = os.path.basename(path) - content = ResourceLoader.load_file_content(path) + # print(filename) + content = ResourceLoader.load_all_refer_files(path) if content: self._refer_cache[filename] = content + # print(content) def find_directory_fuzzy_match(self, name, directory=None, files=None): """ diff --git a/utils/resource_loader.py b/utils/resource_loader.py index fb7175b..4cb4ef3 100644 --- a/utils/resource_loader.py +++ b/utils/resource_loader.py @@ -22,16 +22,14 @@ class ResourceLoader: return None @staticmethod - def load_all_refer_files(refer_dir, refer_content_length=50): - """加载Refer目录下的所有文件内容""" + def load_all_refer_files(file_path, refer_content_rate=0.5): + """加载Refer目录下的指定文件内容""" refer_content = "" - if not refer_dir or not os.path.isdir(refer_dir): - print(f"Warning: Refer directory '{refer_dir}' not found or invalid.") + if not file_path or not os.path.isfile(file_path): + print(f"Warning: Refer directory '{file_path}' not found or invalid.") return "" try: - files = os.listdir(refer_dir) - for file in files: - file_path = os.path.join(refer_dir, file) + if True: # print(file_path) if os.path.isfile(file_path) and file_path.endswith(".txt"): # Use the updated load_file_content content = ResourceLoader.load_file_content(file_path) @@ -39,18 +37,38 @@ class ResourceLoader: # 用\n分割content,取前length条 content_lines = content.split("\n") # Ensure refer_content_length doesn't exceed available lines - sample_size = min(refer_content_length, len(content_lines)) + sample_size = int(len(content_lines) * refer_content_rate) content_lines = random.sample(content_lines, sample_size) content = "\n".join(content_lines) - refer_content += f"## {file}\n{content}\n\n" + refer_content += f"## {file_path}\n{content}\n\n" elif os.path.isfile(file_path) and file_path.endswith(".json"): - # 读取json文件 - with open(file_path, 'r', encoding='utf-8') as f: - content = json.load(f) - - ## 随机进行多次抽样 - - refer_content += f"## {file}\n{content}\n\n" + try: + # 读取json文件 + with open(file_path, 'r', encoding='utf-8') as f: + file_content = json.load(f) + + # 检查必要的键是否存在 + if "title" not in file_content or "description" not in file_content or "examples" not in file_content: + print(f"Warning: JSON文件 '{file_path}' 缺少必要的键(title/description/examples)") + + title_content = file_content["title"] + description_content = file_content["description"] + examples = file_content["examples"] + + # 对examples进行采样 + if examples and isinstance(examples, list): + sample_size = max(1, int(len(examples) * refer_content_rate)) + sampled_examples = random.sample(examples, sample_size) + + # 格式化内容 + examples_formatted = json.dumps(sampled_examples, ensure_ascii=False, indent=2) + content = f"{title_content}\n{description_content}\n{examples_formatted}\n" + + refer_content += f"## {file_path}\n{content}\n\n" + else: + print(f"Warning: JSON文件 '{file_path}' 的examples不是有效列表") + except Exception as json_err: + print(f"处理JSON文件 '{file_path}' 失败: {json_err}") return refer_content except Exception as e: print(f"加载Refer目录文件失败: {e}")