TravelContentCreator/utils/resource_loader.py

214 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
class ResourceLoader:
    """Resource loader for prompt files and reference material.

    All methods are static; the class is used only as a namespace. Most
    helpers are best-effort: on failure they print a message and return an
    empty string or ``None`` instead of raising (``create_summary_file`` is
    the exception and lets I/O errors propagate).
    """

    @staticmethod
    def load_file_content(file_path):
        """Return the UTF-8 text of *file_path*.

        Returns ``""`` (after printing a message) when the path does not exist
        or cannot be read/decoded.
        """
        try:
            if os.path.exists(file_path):
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            print(f"文件不存在: {file_path}")
            return ""
        except Exception as e:
            print(f"加载文件内容失败: {e}")
            return ""

    @staticmethod
    def load_all_refer_files(refer_dir, refer_content_length=50):
        """Build a Markdown digest of every regular file in *refer_dir*.

        For each file (visited in sorted name order), up to
        ``refer_content_length`` lines are sampled at random without
        replacement and emitted under a ``## <file name>`` heading. A file
        with fewer lines than requested contributes all of its lines, in
        random order.

        Args:
            refer_dir: Directory holding the reference files.
            refer_content_length: Maximum number of lines sampled per file.

        Returns:
            The concatenated digest, or ``""`` if *refer_dir* is not a
            directory or an error occurs.
        """
        try:
            if not os.path.isdir(refer_dir):
                return ""
            sections = []
            # os.listdir order is arbitrary; sort for a stable section order.
            for name in sorted(os.listdir(refer_dir)):
                file_path = os.path.join(refer_dir, name)
                if not os.path.isfile(file_path):
                    continue
                lines = ResourceLoader.load_file_content(file_path).split("\n")
                # random.sample raises ValueError when k exceeds the population;
                # cap k so a short file contributes all its lines instead of
                # aborting the digest for every file.
                sample_size = min(refer_content_length, len(lines))
                content = "\n".join(random.sample(lines, sample_size))
                sections.append(f"## {name}\n{content}\n\n")
            return "".join(sections)
        except Exception as e:
            print(f"加载Refer目录文件失败: {e}")
            return ""

    @staticmethod
    def find_file_by_name(directory, file_name, exact_match=True):
        """Locate a ``.txt`` file in *directory* by name.

        ``.txt`` is appended to *file_name* when it lacks that suffix. The
        exact path ``directory/file_name`` is tried first. When *exact_match*
        is false and no exact file exists, the first regular file (in sorted
        order) whose name contains the suffix-less base name is returned.

        Returns:
            The matching path, or ``None`` if nothing matches or an error occurs.
        """
        suffix = ".txt"
        try:
            if not file_name.endswith(suffix):
                file_name = f"{file_name}{suffix}"
            exact_path = os.path.join(directory, file_name)
            if os.path.isfile(exact_path):
                return exact_path
            if not exact_match and os.path.isdir(directory):
                # Strip only the trailing suffix; str.replace would also drop
                # ".txt" occurring elsewhere in the name.
                base_name = file_name[:-len(suffix)]
                for candidate in sorted(os.listdir(directory)):
                    candidate_path = os.path.join(directory, candidate)
                    # Skip directories: they cannot be loaded as prompt text.
                    if base_name in candidate and os.path.isfile(candidate_path):
                        return candidate_path
            return None
        except Exception as e:
            print(f"查找文件失败: {e}")
            return None

    @staticmethod
    def _find_resource(resources, resource_type):
        """Return the first resource dict whose ``"type"`` equals *resource_type*, else ``None``."""
        return next((res for res in resources if res["type"] == resource_type), None)

    @staticmethod
    def _first_path_containing(paths, needle):
        """Return the first entry of *paths* containing *needle* as a substring, else ``None``."""
        return next((path for path in paths if needle in path), None)

    @staticmethod
    def build_user_prompt(item, prompts_dir, resource_dir):
        """Build the user prompt text for a single topic.

        Args:
            item: Topic mapping. ``"date"`` is required. ``"index"``,
                ``"logic"``, ``"object"``, ``"product"``, ``"product_logic"``,
                ``"style"``, ``"style_logic"``, ``"target_audience"`` and
                ``"target_audience_logic"`` are optional and skipped when
                missing or falsy.
            prompts_dir: Directory expected to contain ``Style``, ``Demand``
                and ``Refer`` sub-directories.
            resource_dir: Despite its name, an iterable of resource dicts.
                Each has a ``"type"`` key (``"Object"`` and ``"Product"`` are
                used here) and a ``"file_path"`` iterable of paths. ``None``
                is treated as empty.

        Returns:
            The assembled prompt string.

        Raises:
            KeyError: if *item* has no ``"date"`` key.
        """
        parts = [f"选题日期:{item['date']}\n"]

        style_dir = os.path.join(prompts_dir, "Style")
        demand_dir = os.path.join(prompts_dir, "Demand")
        refer_dir = os.path.join(prompts_dir, "Refer")

        # Materialize once so a one-shot iterator can be searched twice.
        resources = list(resource_dir or ())
        object_resource = ResourceLoader._find_resource(resources, "Object")
        product_resource = ResourceLoader._find_resource(resources, "Product")

        if item.get('index'):
            parts.append(f"选题序号:{item['index']}\n")
        if item.get('logic'):
            parts.append(f"选定逻辑:{item['logic']}\n")

        # Object: attach the first Object file whose path contains the object name.
        if item.get('object') and object_resource:
            parts.append(f"选定对象:{item['object']}\n")
            object_file = ResourceLoader._first_path_containing(
                object_resource["file_path"], item['object'])
            if object_file:
                object_content = ResourceLoader.load_file_content(object_file)
                parts.append(f"对象信息:\n{object_content}\n")

        # Product: match on the part of the name before the first "-",
        # falling back to the full name when that part is empty.
        if item.get('product') and product_resource:
            parts.append(f"选定产品:{item['product']}\n")
            product_prefix = item['product'].split("-")[0] or item['product']
            product_file = ResourceLoader._first_path_containing(
                product_resource["file_path"], product_prefix)
            if product_file:
                product_content = ResourceLoader.load_file_content(product_file)
                parts.append(f"产品信息:\n{product_content}\n")
        if item.get('product_logic'):
            parts.append(f"选定产品的逻辑:{item['product_logic']}\n")

        # Style: load the style prompt from prompts_dir/Style (fuzzy name match).
        if item.get('style'):
            parts.append(f"选题风格:{item['style']}\n")
            style_file = ResourceLoader.find_file_by_name(style_dir, item['style'], False)
            if style_file:
                style_content = ResourceLoader.load_file_content(style_file)
                parts.append(f"风格提示词:\n{style_content}\n")
        if item.get('style_logic'):
            parts.append(f"选题风格的逻辑:{item['style_logic']}\n")

        # Target audience: load its description from prompts_dir/Demand (fuzzy name match).
        if item.get('target_audience'):
            parts.append(f"目标受众:{item['target_audience']}\n")
            audience_file = ResourceLoader.find_file_by_name(
                demand_dir, item['target_audience'], False)
            if audience_file:
                audience_content = ResourceLoader.load_file_content(audience_file)
                parts.append(f"目标受众信息:\n{audience_content}\n")
        if item.get('target_audience_logic'):
            parts.append(f"目标受众逻辑:{item['target_audience_logic']}\n")

        refer_content = ResourceLoader.load_all_refer_files(refer_dir)
        if refer_content:
            parts.append(f"\n参考资料:\n{refer_content}\n")
        return "".join(parts)

    @staticmethod
    def load_system_prompt(prompt_file):
        """Return the UTF-8 text of the system prompt file.

        Unlike ``load_file_content`` there is no existence pre-check: any
        error (including a missing file) prints a message and yields ``""``.
        """
        try:
            with open(prompt_file, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            print(f"加载系统提示词文件失败: {e}")
            return ""

    @staticmethod
    def create_summary_file(output_dir, run_id, topics_count):
        """Create (or truncate) ``output_dir/summary.md`` with a header and return its path.

        *output_dir* is created if needed. I/O errors propagate to the caller.
        """
        summary_file = os.path.join(output_dir, "summary.md")
        os.makedirs(output_dir, exist_ok=True)
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write(f"# 小红书选题文案生成 - {run_id}\n\n")
            f.write(f"共生成 {topics_count} 篇选题文案\n\n")
        return summary_file

    @staticmethod
    def update_summary(summary_file, article_index, prompt, result):
        """Append one article section (prompt and generated result) to *summary_file*.

        Errors are printed and otherwise ignored.
        """
        try:
            with open(summary_file, 'a', encoding='utf-8') as f:
                f.write(f"## 文章 {article_index}\n\n")
                f.write("### 选题信息\n")
                f.write(f"```\n{prompt}\n```\n\n")
                f.write("### 生成内容\n")
                f.write(f"```\n{result}\n```\n\n")
                f.write("--------------------------------\n\n")
        except Exception as e:
            print(f"更新汇总文件时出错: {e}")

    @staticmethod
    def save_article(result, prompt, output_dir, run_id, article_index, variant_index):
        """Write the prompt and result to ``output_dir/article_<index>_<variant>.txt``.

        *output_dir* is created if needed. *run_id* is currently unused.

        Returns:
            The written file path, or ``None`` (after printing a message) on error.
        """
        try:
            os.makedirs(output_dir, exist_ok=True)
            filename = f"article_{article_index}_{variant_index}.txt"
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(f"prompt: {prompt}\n\n")
                f.write(f"result: {result}\n")
            return filepath
        except Exception as e:
            print(f"保存文章时出错: {e}")
            return None