149 lines
6.8 KiB
Python
149 lines
6.8 KiB
Python
import os
|
||
import random
|
||
import json
|
||
class ResourceLoader:
    """Resource loader for prompt texts and reference material.

    All members are static; the class is only a namespace. Every method
    reports problems with ``print`` and, except for
    ``create_summary_file``, returns a neutral value instead of raising.
    """

    @staticmethod
    def load_file_content(file_path):
        """Read a UTF-8 text file.

        Args:
            file_path: Path of the file to read.

        Returns:
            The file's text, or ``None`` when the path does not exist or
            reading fails. ``None`` (rather than ``""``) lets callers tell
            a missing file apart from an empty one.
        """
        try:
            if not os.path.exists(file_path):
                print(f"文件不存在: {file_path}")
                return None
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            print(f"加载文件 '{file_path}' 内容失败: {e}")
            return None

    @staticmethod
    def _sample_size(total, rate, minimum=0):
        """Return how many of ``total`` items to sample for ``rate``.

        The result is clamped to ``[0, total]`` so that ``random.sample``
        never sees a negative size or one larger than the population
        (either would raise ``ValueError``).
        """
        wanted = max(minimum, int(total * rate))
        return max(0, min(total, wanted))

    @staticmethod
    def _load_txt_refer(file_path, refer_content_rate):
        """Build the reference section for a ``.txt`` file from sampled lines."""
        content = ResourceLoader.load_file_content(file_path)
        if not content:
            # Missing, unreadable or empty file: nothing to contribute.
            return ""

        content_lines = content.split("\n")
        sample_size = ResourceLoader._sample_size(
            len(content_lines), refer_content_rate
        )
        # random.sample picks lines without replacement, in random order.
        content = "\n".join(random.sample(content_lines, sample_size))
        return f"## {file_path}\n{content}\n\n"

    @staticmethod
    def _load_json_refer(file_path, refer_content_rate):
        """Build the reference section for a ``.json`` file.

        The JSON root must be an object with ``title``, ``description``
        and a non-empty list under ``examples``; otherwise a warning is
        printed and ``""`` is returned.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_content = json.load(f)

            required_keys = ("title", "description", "examples")
            if not isinstance(file_content, dict) or any(
                key not in file_content for key in required_keys
            ):
                print(f"Warning: JSON文件 '{file_path}' 缺少必要的键(title/description/examples)")
                # Stop here: indexing the missing keys would only raise.
                return ""

            examples = file_content["examples"]
            if not examples or not isinstance(examples, list):
                print(f"Warning: JSON文件 '{file_path}' 的examples不是有效列表")
                return ""

            # Always keep at least one example, never more than exist.
            sample_size = ResourceLoader._sample_size(
                len(examples), refer_content_rate, minimum=1
            )
            sampled_examples = random.sample(examples, sample_size)
            examples_formatted = json.dumps(
                sampled_examples, ensure_ascii=False, indent=2
            )
            title_content = file_content["title"]
            description_content = file_content["description"]
            content = f"{title_content}\n{description_content}\n{examples_formatted}\n"
            return f"## {file_path}\n{content}\n\n"
        except Exception as json_err:
            print(f"处理JSON文件 '{file_path}' 失败: {json_err}")
            return ""

    @staticmethod
    def load_all_refer_files(file_path, refer_content_rate=0.5):
        """Load one reference file and return a randomly sampled excerpt.

        Despite the name, ``file_path`` must point to a single regular
        file. ``.txt`` files are sampled line by line; ``.json`` files are
        sampled from their ``examples`` list. Other extensions yield ``""``.

        Args:
            file_path: Path of the ``.txt`` or ``.json`` reference file.
            refer_content_rate: Fraction of lines/examples to keep. Values
                outside ``[0, 1]`` are clamped to the available item count
                instead of failing.

        Returns:
            A string of the form ``"## <path>\\n<content>\\n\\n"``, or
            ``""`` when the file is missing, unsupported or invalid.
        """
        if not file_path or not os.path.isfile(file_path):
            print(f"Warning: Refer directory '{file_path}' not found or invalid.")
            return ""
        try:
            if file_path.endswith(".txt"):
                return ResourceLoader._load_txt_refer(
                    file_path, refer_content_rate
                )
            if file_path.endswith(".json"):
                return ResourceLoader._load_json_refer(
                    file_path, refer_content_rate
                )
            return ""
        except Exception as e:
            print(f"加载Refer目录文件失败: {e}")
            return ""

    @staticmethod
    def find_file_by_name(directory, file_name, exact_match=True):
        """Locate a ``.txt`` file in ``directory``.

        ``.txt`` is appended to ``file_name`` when it is not already the
        suffix. The exact path is tried first; when ``exact_match`` is
        false, the first regular file (in sorted name order) whose name
        contains the suffix-less ``file_name`` is returned instead.

        Args:
            directory: Directory to search (not recursive).
            file_name: File name, with or without the ``.txt`` suffix.
            exact_match: When true, only the exact file name is accepted.

        Returns:
            The path of the match, or ``None`` when nothing matches, an
            argument is empty, or an error occurs.
        """
        if not directory or not file_name:
            return None
        try:
            suffix = ".txt"
            if not file_name.endswith(suffix):
                file_name = f"{file_name}{suffix}"

            exact_path = os.path.join(directory, file_name)
            if os.path.isfile(exact_path):
                return exact_path

            if not exact_match and os.path.isdir(directory):
                # Strip only the trailing suffix; str.replace would also
                # remove ".txt" occurrences in the middle of the name.
                file_name_base = file_name[:-len(suffix)]
                # Sorted so the chosen match does not depend on the
                # arbitrary order of os.listdir.
                for entry in sorted(os.listdir(directory)):
                    candidate = os.path.join(directory, entry)
                    if file_name_base in entry and os.path.isfile(candidate):
                        return candidate

            return None
        except Exception as e:
            print(f"查找文件 '{file_name}' 在 '{directory}' 失败: {e}")
            return None

    @staticmethod
    def create_summary_file(output_dir, run_id, topics_count):
        """Create (or overwrite) ``summary.md`` in ``output_dir``.

        Args:
            output_dir: Directory for the summary; created if missing.
            run_id: Run identifier written into the heading.
            topics_count: Number of generated articles written below it.

        Returns:
            The path of the summary file.

        Raises:
            OSError: Propagated when the directory or file cannot be
                created; this method does not catch I/O errors.
        """
        os.makedirs(output_dir, exist_ok=True)
        summary_file = os.path.join(output_dir, "summary.md")

        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write(f"# 小红书选题文案生成 - {run_id}\n\n")
            f.write(f"共生成 {topics_count} 篇选题文案\n\n")

        return summary_file

    @staticmethod
    def update_summary(summary_file, article_index, prompt, result):
        """Append one article section to the summary file.

        Args:
            summary_file: Path of the summary file to append to.
            article_index: Index shown in the section heading.
            prompt: Topic/prompt text, written inside a fenced block.
            result: Generated text, written inside a fenced block.

        Errors are printed and swallowed; nothing is returned.
        """
        try:
            with open(summary_file, 'a', encoding='utf-8') as f:
                f.write(f"## 文章 {article_index}\n\n")
                f.write("### 选题信息\n")
                f.write(f"```\n{prompt}\n```\n\n")
                f.write("### 生成内容\n")
                f.write(f"```\n{result}\n```\n\n")
                f.write("--------------------------------\n\n")
        except Exception as e:
            print(f"更新汇总文件时出错: {e}")

    @staticmethod
    def save_article(result, prompt, output_dir, run_id, article_index, variant_index):
        """Write one generated article to its own text file.

        The file is named ``article_<article_index>_<variant_index>.txt``
        and holds the prompt followed by the result.

        Args:
            result: Generated article text.
            prompt: Prompt that produced the article.
            output_dir: Target directory; created if missing.
            run_id: Accepted for interface compatibility; not used here.
            article_index: Article number used in the file name.
            variant_index: Variant number used in the file name.

        Returns:
            The path of the written file, or ``None`` on failure.
        """
        try:
            os.makedirs(output_dir, exist_ok=True)

            filename = f"article_{article_index}_{variant_index}.txt"
            filepath = os.path.join(output_dir, filename)

            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(f"prompt: {prompt}\n\n")
                f.write(f"result: {result}\n")

            return filepath
        except Exception as e:
            print(f"保存文章时出错: {e}")
            return None