# TravelContentCreator/utils/tweet_generator.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
sys.path.append('/root/autodl-tmp')
# 从本地模块导入
from TravelContentCreator.core.ai_agent import AI_Agent
from TravelContentCreator.core.topic_parser import TopicParser
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
from TravelContentCreator.utils.prompt_manager import PromptManager # Import PromptManager
class tweetTopic:
    """A single generated tweet topic together with the reasoning behind it.

    Plain data holder; attribute names mirror the constructor parameters.
    NOTE(review): the parameter name ``object`` shadows the builtin but is
    kept unchanged for backward compatibility with keyword-argument callers.
    """

    def __init__(self, index, date, logic, object, product, product_logic,
                 style, style_logic, target_audience, target_audience_logic):
        # Ordinal position of this topic in the generated list.
        self.index = index
        # Target publication date (free-form string, e.g. "4月17日").
        self.date = date
        # Rationale for choosing this topic.
        self.logic = logic
        # The attraction / subject the tweet is about.
        self.object = object
        self.product = product
        self.product_logic = product_logic
        self.style = style
        self.style_logic = style_logic
        self.target_audience = target_audience
        self.target_audience_logic = target_audience_logic

    def __repr__(self):
        # Added for debuggability; shows the two most identifying fields.
        return f"{type(self).__name__}(index={self.index!r}, object={self.object!r})"
class tweetTopicRecord:
    """Bundle describing one topic-generation run: parsed topics plus the
    prompts that produced them and where the run's outputs live."""

    def __init__(self, topics_list, system_prompt, user_prompt, output_dir, run_id):
        self.topics_list = topics_list
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.output_dir = output_dir
        self.run_id = run_id

    def save_topics(self, path):
        """Write the topic list to *path* as pretty-printed UTF-8 JSON.

        Returns True on success, False on any serialization or I/O error.
        """
        try:
            with open(path, "w", encoding="utf-8") as fh:
                json.dump(self.topics_list, fh, ensure_ascii=False, indent=4)
        except Exception as exc:
            print(f"保存选题失败: {exc}")
            return False
        return True

    def save_prompt(self, path):
        """Write the prompts and run metadata to *path*, one field per line.

        Returns True on success, False on failure.
        """
        fields = (self.system_prompt, self.user_prompt, self.output_dir, self.run_id)
        try:
            with open(path, "w", encoding="utf-8") as fh:
                fh.write("\n".join(fields) + "\n")
        except Exception as exc:
            print(f"保存提示词失败: {exc}")
            return False
        return True
class tweetContent:
    """One generated article variant: the raw model output parsed into a
    title and a content body, plus helpers to persist both."""

    def __init__(self, result, prompt, output_dir, run_id, article_index, variant_index):
        self.result = result
        self.prompt = prompt
        self.output_dir = output_dir
        self.run_id = run_id
        self.article_index = article_index
        self.variant_index = variant_index
        # Parse eagerly so construction fails fast on malformed output.
        self.title, self.content = self.split_content(result)
        self.json_file = self.gen_result_json()

    def split_content(self, result):
        """Extract ``(title, content)`` from the raw model output.

        Strips an optional leading ``<think>...</think>`` reasoning section,
        then pulls the text between ``<title>...</title>`` and
        ``<content>...</content>``.

        Bug fixed: the previous code split on the bare tokens ``"title>"`` /
        ``"content>"``, which also match inside the closing tags and left a
        trailing ``"</"`` on the extracted text. It also raised IndexError
        when no ``</think>`` marker was present; ``[-1]`` handles both cases.
        Raises IndexError if the title/content tags are missing.
        """
        # Keep only the text after the reasoning block, if any.
        result = result.split("</think>")[-1]
        title = result.split("<title>")[1].split("</title>")[0]
        content = result.split("<content>")[1].split("</content>")[0]
        return title, content

    def gen_result_json(self):
        """Return a JSON-serializable dict of the parsed title and content."""
        return {
            "title": self.title,
            "content": self.content
        }

    def save_content(self, json_path):
        """Dump the parsed result to *json_path* as UTF-8 JSON; returns the path."""
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.json_file, f, ensure_ascii=False, indent=4)
        return json_path

    def save_prompt(self, path):
        """Write the user prompt used for this article to *path*; returns the path."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.prompt + "\n")
        return path

    def get_content(self):
        return self.content

    def get_title(self):
        return self.title

    def get_json_file(self):
        return self.json_file
def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """Run the topic-generation prompt and parse the model's reply.

    Returns ``(run_id, tweetTopicRecord)``; ``run_id`` is a timestamp string
    that uniquely identifies this generation run.
    """
    print("开始生成选题...")
    started = time.time()
    # Single LLM call; the agent echoes back the prompts it actually used.
    result, system_prompt, user_prompt, file_folder, file_name, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )
    print(f"选题生成完成,耗时:{time.time() - started}")
    # The timestamp doubles as the run identifier used for output folders.
    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    parsed_topics = TopicParser.parse_topics(result)
    record = tweetTopicRecord(parsed_topics, system_prompt, user_prompt, output_dir, run_id)
    return run_id, record
def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
    """Generate one article variant. Requires prompts to be passed in.

    Returns ``(tweetContent, raw_result)`` on success, or ``(None, None)``
    when a prompt is missing or generation/persistence fails.
    """
    try:
        # Prompts must be pre-built by the caller (e.g. via PromptManager).
        if not system_prompt or not user_prompt:
            print("Error: System or User prompt is empty. Cannot generate content.")
            return None, None
        print(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
        # Small randomized pause so back-to-back requests are spaced out.
        time.sleep(random.random() * 0.5 + 0.1)
        result, _, _, _, _, tokens, time_cost = ai_agent.work(
            system_prompt, user_prompt, "", temperature, top_p, presence_penalty
        )
        print(f"生成完成tokens: {tokens}, 耗时: {time_cost}s")
        # Persist this variant under <output_dir>/<article>_<variant>/.
        tweet_content = tweetContent(result, user_prompt, output_dir, run_id, article_index, variant_index)
        variant_dir = os.path.join(output_dir, f"{article_index}_{variant_index}")
        os.makedirs(variant_dir, exist_ok=True)
        tweet_content.save_content(os.path.join(variant_dir, "article.json"))
        tweet_content.save_prompt(os.path.join(variant_dir, "tweet_prompt.txt"))
        return tweet_content, result
    except Exception as e:
        print(f"生成单篇文章时出错: {e}")
        return None, None
def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
                     variants=2, temperature=0.3, start_index=0, end_index=None):
    """Generate *variants* article variants for each topic in the selected window.

    ``prompts_dir`` and ``resource_dir`` are currently unused here (prompt
    building moved to PromptManager) but kept for interface compatibility.
    Returns the list of successfully generated tweetContent objects, or None
    when *topics* is empty.
    """
    if not topics:
        print("没有选题,无法生成内容")
        return
    # Clamp the processing window to the available topics.
    if end_index is None or end_index > len(topics):
        end_index = len(topics)
    topics_to_process = topics[start_index:end_index]
    print(f"准备处理{len(topics_to_process)}个选题...")
    processed_results = []
    for i, item in enumerate(topics_to_process):
        print(f"处理第 {i+1}/{len(topics_to_process)} 篇文章")
        # BUG FIX: the previous call passed arguments positionally with
        # user_prompt omitted, shifting every later parameter by one (the
        # topic dict became the prompt, run_id became the article index,
        # temperature landed in variant_index). Pass keywords explicitly.
        # TODO(review): the per-topic user prompt should ideally come from
        # PromptManager; serializing the topic dict is a stopgap.
        user_prompt = json.dumps(item, ensure_ascii=False)
        for j in range(variants):
            print(f"正在生成变体 {j+1}/{variants}")
            tweet_content, result = generate_single_content(
                ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                article_index=i + 1, variant_index=j + 1, temperature=temperature
            )
            if tweet_content:
                processed_results.append(tweet_content)
    print(f"完成{len(processed_results)}篇文章生成")
    return processed_results
def prepare_topic_generation(config):
    """Prepare the environment for topic generation. Returns agent and prompts.

    *config* is the full configuration dictionary. Returns the tuple
    ``(ai_agent, system_prompt, user_prompt, output_dir)``; all four elements
    are None when prompt loading or agent initialization fails.
    """
    # All prompt loading/building is delegated to PromptManager.
    prompt_manager = PromptManager(config)
    system_prompt, user_prompt = prompt_manager.get_topic_prompts()
    if not system_prompt or not user_prompt:
        print("Error: Failed to get topic generation prompts.")
        return None, None, None, None
    # The agent for the topic-generation phase is also created here.
    try:
        print("Initializing AI Agent for topic generation...")
        ai_agent = AI_Agent(config["api_url"], config["model"], config["api_key"])
    except Exception as e:
        print(f"Error initializing AI Agent for topic generation: {e}")
        traceback.print_exc()
        return None, None, None, None
    return ai_agent, system_prompt, user_prompt, config["output_dir"]
def run_topic_generation_pipeline(config):
    """Runs the complete topic generation pipeline based on the configuration.

    Prepares the agent and prompts, generates topics, closes the agent, and
    saves the results under ``<output_dir>/<run_id>/``.
    Returns ``(run_id, tweetTopicRecord)`` on success or ``(None, None)``.
    """
    print("Step 1: Generating Topics...")
    ai_agent, system_prompt, user_prompt, base_output_dir = None, None, None, None
    try:
        ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config)
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
    except Exception as e:
        print(f"Error during topic generation preparation: {e}")
        traceback.print_exc()
        return None, None
    try:
        # BUG FIX: the third tuning parameter of generate_topics is
        # presence_penalty, but it was previously read from the misnamed
        # "topic_max_tokens" config key. Prefer the correctly named key and
        # keep the old one as a fallback for existing configs.
        presence_penalty = config.get(
            "topic_presence_penalty", config.get("topic_max_tokens", 1.5)
        )
        run_id, tweet_topic_record = generate_topics(
            ai_agent, system_prompt, user_prompt, config["output_dir"],
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            presence_penalty,
        )
    except Exception as e:
        print(f"Error during topic generation API call: {e}")
        traceback.print_exc()
        if ai_agent: ai_agent.close()  # Ensure agent is closed on error
        return None, None
    # The agent is only needed for the generation call above.
    if ai_agent:
        ai_agent.close()
    if not run_id or not tweet_topic_record:
        print("Topic generation failed (no run_id or topics returned).")
        return None, None
    output_dir = os.path.join(config["output_dir"], run_id)
    try:
        os.makedirs(output_dir, exist_ok=True)
        # Save topics and prompt details for this run.
        save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
        save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
        if not save_topics_success or not save_prompt_success:
            print("Warning: Failed to save topic generation results or prompts.")
            # Continue but warn the user.
    except Exception as e:
        print(f"Error saving topic generation results: {e}")
        traceback.print_exc()
        # Saving is treated as critical: fail the pipeline.
        return None, None
    print(f"Topics generated successfully. Run ID: {run_id}")
    return run_id, tweet_topic_record
def main():
    """Demo entry point: generate topics, then articles for each topic."""
    config_file = {
        "date": "4月17日",
        "num": 5,
        "model": "qwenQWQ",
        "api_url": "vllm",
        "api_key": "EMPTY",
        "topic_system_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/systemPrompt.txt",
        "topic_user_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/userPrompt.txt",
        "content_system_prompt": "/root/autodl-tmp/TravelContentCreator/genPrompts/systemPrompt.txt",
        "resource_dir": [{
            "type": "Object",
            "num": 4,
            "file_path": ["/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-尚书第.txt",
                "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-明清园.txt",
                "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-泰宁古城.txt",
                "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-甘露寺.txt"
            ]},
            {
                "type": "Product",
                "num": 0,
                "file_path": []
            }
        ],
        "prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
        "output_dir": "/root/autodl-tmp/TravelContentCreator/result",
        "variants": 2,
        "topic_temperature": 0.2,
        "content_temperature": 0.3
    }
    # 1. First, generate the topics.
    ai_agent, system_prompt, user_prompt, output_dir = prepare_topic_generation(config_file)
    run_id, tweet_topic_record = generate_topics(
        ai_agent, system_prompt, user_prompt, config_file["output_dir"],
        config_file["topic_temperature"], 0.5, 1.5
    )
    output_dir = os.path.join(config_file["output_dir"], run_id)
    os.makedirs(output_dir, exist_ok=True)
    tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
    tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
    if not run_id or not tweet_topic_record:
        print("选题生成失败,退出程序")
        return
    # 2. Then generate the content for each topic.
    print("\n开始根据选题生成内容...")
    # BUG FIX: this previously called ResourceLoader.load_system_prompt, but
    # the ResourceLoader import was removed (see top of file), so it raised
    # NameError at runtime. Read the prompt file directly instead.
    content_system_prompt = ""
    try:
        with open(config_file["content_system_prompt"], "r", encoding="utf-8") as f:
            content_system_prompt = f.read()
    except OSError as e:
        print(f"读取内容系统提示词失败: {e}")
    if not content_system_prompt:
        print("内容生成系统提示词为空,使用选题生成的系统提示词")
        content_system_prompt = system_prompt
    # Reuse the same AI Agent instance for content generation.
    result = generate_content(
        ai_agent, content_system_prompt, tweet_topic_record.topics_list, output_dir, run_id,
        config_file["prompts_dir"], config_file["resource_dir"],
        config_file["variants"], config_file["content_temperature"]
    )
# Script entry point: run the demo pipeline only when executed directly.
if __name__ == "__main__":
    main()