214 lines
8.9 KiB
Python
214 lines
8.9 KiB
Python
import os
|
||
import random
|
||
|
||
class ResourceLoader:
    """Resource loader for prompt files and reference material.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    Most loaders report problems with ``print`` and return an empty string or
    ``None`` instead of raising.
    """

    @staticmethod
    def load_file_content(file_path):
        """Return the UTF-8 text of *file_path*.

        Returns ``""`` (after printing a message) when the path does not exist
        or when opening/reading/decoding it fails.
        """
        try:
            if os.path.exists(file_path):
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            print(f"文件不存在: {file_path}")
            return ""
        except Exception as e:
            print(f"加载文件内容失败: {e}")
            return ""

    @staticmethod
    def load_all_refer_files(refer_dir, refer_content_length=50):
        """Concatenate a random sample of lines from every file in *refer_dir*.

        Regular files are visited in sorted name order. For each one, up to
        ``refer_content_length`` lines (split on ``"\\n"``) are chosen with
        ``random.sample``, so the selection and its order are random. The chosen
        lines are emitted under a ``## <file name>`` heading.

        Args:
            refer_dir: directory holding the reference files.
            refer_content_length: maximum number of lines sampled per file.

        Returns:
            The combined text. Returns ``""`` if the directory is missing or
            empty, or if an error occurs (the error is printed).
        """
        try:
            if not os.path.exists(refer_dir):
                return ""
            sections = []
            # sorted(): os.listdir order is arbitrary; keep the output stable.
            for file in sorted(os.listdir(refer_dir)):
                file_path = os.path.join(refer_dir, file)
                if not os.path.isfile(file_path):
                    continue
                content_lines = ResourceLoader.load_file_content(file_path).split("\n")
                # random.sample raises ValueError when k exceeds the population.
                # Clamp k so one short file cannot wipe out all reference content.
                sample_size = min(refer_content_length, len(content_lines))
                content = "\n".join(random.sample(content_lines, sample_size))
                sections.append(f"## {file}\n{content}\n\n")
            return "".join(sections)
        except Exception as e:
            print(f"加载Refer目录文件失败: {e}")
            return ""

    @staticmethod
    def find_file_by_name(directory, file_name, exact_match=True):
        """Locate a ``.txt`` file in *directory*, by exact or fuzzy name.

        A ``.txt`` suffix is appended to *file_name* when missing. The exact
        path ``directory/file_name`` is tried first. If it does not exist and
        ``exact_match`` is false, the search falls back to the first regular
        file (sorted by name) whose name contains the suffix-less base name.

        Returns:
            The matching path, or ``None`` when nothing matches or an error
            occurs (the error is printed).
        """
        suffix = ".txt"
        try:
            if not file_name.endswith(suffix):
                file_name = f"{file_name}{suffix}"

            exact_path = os.path.join(directory, file_name)
            if os.path.exists(exact_path):
                return exact_path

            if not exact_match and os.path.exists(directory):
                # Strip only the trailing suffix; str.replace would also drop a
                # ".txt" embedded in the middle of the name.
                file_name_base = file_name[:-len(suffix)]
                # An empty base would "match" every entry, so skip fuzzy search.
                if file_name_base:
                    for file in sorted(os.listdir(directory)):
                        candidate = os.path.join(directory, file)
                        # Skip directories: callers read the result as a file.
                        if file_name_base in file and os.path.isfile(candidate):
                            return candidate

            return None
        except Exception as e:
            print(f"查找文件失败: {e}")
            return None

    @staticmethod
    def _first_path_containing(file_paths, needle):
        """Return the first entry of *file_paths* that contains *needle*, else ``None``."""
        return next((path for path in file_paths if needle in path), None)

    @staticmethod
    def build_user_prompt(item, prompts_dir, resource_dir):
        """Build the user prompt text for one topic entry.

        Args:
            item: dict-like topic description. ``item['date']`` is required.
                Each of these optional keys adds a line when truthy:
                ``index``, ``logic``, ``object``, ``product``, ``product_logic``,
                ``style``, ``style_logic``, ``target_audience``,
                ``target_audience_logic``. For ``object``, ``product``, ``style``
                and ``target_audience``, the content of a matching file is
                appended as well.
            prompts_dir: directory expected to contain ``Style``, ``Demand`` and
                ``Refer`` sub-directories.
            resource_dir: despite its name, an iterable of resource dicts. Each
                dict has a ``"type"`` key (``"Object"`` / ``"Product"``) and a
                ``"file_path"`` key (an iterable of path strings). ``None`` is
                treated as empty.

        Returns:
            The assembled prompt string.

        Raises:
            KeyError: if ``item`` has no ``'date'`` key.
        """
        batch_prompt = f"选题日期:{item['date']}\n"

        style_dir = os.path.join(prompts_dir, "Style")
        demand_dir = os.path.join(prompts_dir, "Demand")
        refer_dir = os.path.join(prompts_dir, "Refer")

        # Materialise once so both lookups below see every resource, even if a
        # generator is passed in.
        resources = list(resource_dir or [])
        object_resource = next((res for res in resources if res["type"] == "Object"), None)
        product_resource = next((res for res in resources if res["type"] == "Product"), None)

        if item.get('index'):
            batch_prompt += f"选题序号:{item['index']}\n"
        if item.get('logic'):
            batch_prompt += f"选定逻辑:{item['logic']}\n"

        # Object: the first resource path containing the object name supplies its details.
        if item.get('object') and object_resource:
            batch_prompt += f"选定对象:{item['object']}\n"
            object_file = ResourceLoader._first_path_containing(
                object_resource["file_path"], item['object'])
            if object_file:
                object_content = ResourceLoader.load_file_content(object_file)
                batch_prompt += f"对象信息:\n{object_content}\n"

        # Product: match on the text before the first "-". If that part is
        # empty, fall back to the whole product name.
        if item.get('product') and product_resource:
            batch_prompt += f"选定产品:{item['product']}\n"
            product_prefix = item['product'].split("-")[0] or item['product']
            product_file = ResourceLoader._first_path_containing(
                product_resource["file_path"], product_prefix)
            if product_file:
                product_content = ResourceLoader.load_file_content(product_file)
                batch_prompt += f"产品信息:\n{product_content}\n"

        if item.get('product_logic'):
            batch_prompt += f"选定产品的逻辑:{item['product_logic']}\n"

        if item.get('style'):
            batch_prompt += f"选题风格:{item['style']}\n"
            style_file = ResourceLoader.find_file_by_name(style_dir, item['style'], False)
            if style_file:
                style_content = ResourceLoader.load_file_content(style_file)
                batch_prompt += f"风格提示词:\n{style_content}\n"

        if item.get('style_logic'):
            batch_prompt += f"选题风格的逻辑:{item['style_logic']}\n"

        if item.get('target_audience'):
            batch_prompt += f"目标受众:{item['target_audience']}\n"
            audience_file = ResourceLoader.find_file_by_name(demand_dir, item['target_audience'], False)
            if audience_file:
                audience_content = ResourceLoader.load_file_content(audience_file)
                batch_prompt += f"目标受众信息:\n{audience_content}\n"

        if item.get('target_audience_logic'):
            batch_prompt += f"目标受众逻辑:{item['target_audience_logic']}\n"

        refer_content = ResourceLoader.load_all_refer_files(refer_dir)
        if refer_content:
            batch_prompt += f"\n参考资料:\n{refer_content}\n"

        return batch_prompt

    @staticmethod
    def load_system_prompt(prompt_file):
        """Return the UTF-8 text of *prompt_file*, or ``""`` (printing the error) on failure."""
        try:
            with open(prompt_file, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            print(f"加载系统提示词文件失败: {e}")
            return ""

    @staticmethod
    def create_summary_file(output_dir, run_id, topics_count):
        """Create (or overwrite) ``summary.md`` in *output_dir* and return its path.

        *output_dir* is created if needed. The file starts with a title line
        containing *run_id* and a line stating *topics_count*. I/O errors
        propagate to the caller.
        """
        os.makedirs(output_dir, exist_ok=True)
        summary_file = os.path.join(output_dir, "summary.md")
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write(f"# 小红书选题文案生成 - {run_id}\n\n")
            f.write(f"共生成 {topics_count} 篇选题文案\n\n")
        return summary_file

    @staticmethod
    def _fence_for(text):
        """Return a backtick fence longer than any backtick run in *text* (at least three)."""
        longest = run = 0
        for ch in text:
            run = run + 1 if ch == "`" else 0
            longest = max(longest, run)
        return "`" * max(3, longest + 1)

    @staticmethod
    def update_summary(summary_file, article_index, prompt, result):
        """Append one article section (prompt and result) to *summary_file*.

        Each text is wrapped in a fenced code block. The fence is chosen longer
        than any backtick run inside the text, so embedded ``` sequences cannot
        close it early. Errors are printed, not raised.
        """
        try:
            prompt_text, result_text = str(prompt), str(result)
            prompt_fence = ResourceLoader._fence_for(prompt_text)
            result_fence = ResourceLoader._fence_for(result_text)
            with open(summary_file, 'a', encoding='utf-8') as f:
                f.write(f"## 文章 {article_index}\n\n")
                f.write("### 选题信息\n")
                f.write(f"{prompt_fence}\n{prompt_text}\n{prompt_fence}\n\n")
                f.write("### 生成内容\n")
                f.write(f"{result_fence}\n{result_text}\n{result_fence}\n\n")
                f.write("--------------------------------\n\n")
        except Exception as e:
            print(f"更新汇总文件时出错: {e}")

    @staticmethod
    def save_article(result, prompt, output_dir, run_id, article_index, variant_index):
        """Write *prompt* and *result* to ``article_<article_index>_<variant_index>.txt``.

        *output_dir* is created if needed. *run_id* is accepted for interface
        compatibility but is not used here.

        Returns:
            The written file path, or ``None`` (after printing the error) on failure.
        """
        try:
            os.makedirs(output_dir, exist_ok=True)
            filename = f"article_{article_index}_{variant_index}.txt"
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(f"prompt: {prompt}\n\n")
                f.write(f"result: {result}\n")
            return filepath
        except Exception as e:
            print(f"保存文章时出错: {e}")
            return None