# TravelContentCreator/utils/tweet_generator.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
sys.path.append('/root/autodl-tmp')
# 从本地模块导入
from TravelContentCreator.core.ai_agent import AI_Agent
from TravelContentCreator.core.topic_parser import TopicParser
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
from TravelContentCreator.utils.prompt_manager import PromptManager # Import PromptManager
class tweetTopic:
    """Value object for a single generated tweet topic.

    Holds the schedule slot, subject, product tie-in, writing style and
    target audience, plus the reasoning ("logic") string the model
    produced for each of those choices.
    """

    def __init__(self, index, date, logic, object, product, product_logic,
                 style, style_logic, target_audience, target_audience_logic):
        # Every constructor argument is stored verbatim under the same name.
        self.__dict__.update(
            index=index,
            date=date,
            logic=logic,
            object=object,
            product=product,
            product_logic=product_logic,
            style=style,
            style_logic=style_logic,
            target_audience=target_audience,
            target_audience_logic=target_audience_logic,
        )
class tweetTopicRecord:
    """Bundle describing one topic-generation run.

    Carries the parsed topic list together with the prompts and output
    metadata that produced it, and knows how to persist both.
    """

    def __init__(self, topics_list, system_prompt, user_prompt, output_dir, run_id):
        self.topics_list = topics_list
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.output_dir = output_dir
        self.run_id = run_id

    def save_topics(self, path):
        """Serialize the topic list to *path* as pretty-printed JSON.

        Returns True on success, False on any failure (error is printed).
        """
        try:
            with open(path, "w", encoding="utf-8") as fh:
                json.dump(self.topics_list, fh, ensure_ascii=False, indent=4)
        except Exception as e:
            print(f"保存选题失败: {e}")
            return False
        return True

    def save_prompt(self, path):
        """Write the prompts and run metadata to *path*, one item per line.

        Order: system prompt, user prompt, output dir, run id.
        Returns True on success, False on any failure (error is printed).
        """
        try:
            with open(path, "w", encoding="utf-8") as fh:
                for piece in (self.system_prompt, self.user_prompt,
                              self.output_dir, self.run_id):
                    fh.write(piece + "\n")
        except Exception as e:
            print(f"保存提示词失败: {e}")
            return False
        return True
class tweetContent:
    """One generated article variant.

    Parses the raw model output — ``<title>…</title>`` and
    ``<content>…</content>`` tags, optionally preceded by a
    ``…</think>`` reasoning block — and exposes the parsed pieces plus
    helpers to persist them as JSON / text files.
    """

    def __init__(self, result, prompt, output_dir, run_id, article_index, variant_index):
        self.result = result
        self.prompt = prompt
        self.output_dir = output_dir
        self.run_id = run_id
        self.article_index = article_index
        self.variant_index = variant_index
        self.title, self.content = self.split_content(result)
        self.json_file = self.gen_result_json()

    @staticmethod
    def _extract_tag(text, tag):
        """Return the text strictly between <tag> and </tag>.

        Raises ValueError when either tag is missing, so callers that
        wrap construction in try/except still observe a failure (the
        original raised IndexError in that case).
        """
        open_tag, close_tag = f"<{tag}>", f"</{tag}>"
        start = text.find(open_tag)
        if start == -1:
            raise ValueError(f"missing {open_tag} in model output")
        start += len(open_tag)
        end = text.find(close_tag, start)
        if end == -1:
            raise ValueError(f"missing {close_tag} in model output")
        return text[start:end]

    def split_content(self, result):
        """Split raw model output into (title, content)."""
        # Strip the chain-of-thought block if the model emitted one.
        # (The original indexed split("</think>")[1] unconditionally and
        # crashed on outputs without the marker.)
        if "</think>" in result:
            result = result.split("</think>", 1)[1]
        # Fixed: the original split on "title>"/"content>" also matched
        # the *closing* tag, leaving a trailing "</" inside the extracted
        # text; find()-based extraction returns exactly the tag body.
        title = self._extract_tag(result, "title")
        content = self._extract_tag(result, "content")
        return title, content

    def gen_result_json(self):
        """Build the dict that save_content persists."""
        return {
            "title": self.title,
            "content": self.content,
        }

    def save_content(self, json_path):
        """Write the parsed article to *json_path*; returns the path."""
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.json_file, f, ensure_ascii=False, indent=4)
        return json_path

    def save_prompt(self, path):
        """Write the generation prompt to *path*; returns the path."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.prompt + "\n")
        return path

    def get_content(self):
        return self.content

    def get_title(self):
        return self.title

    def get_json_file(self):
        return self.json_file
def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """生成选题列表 — run one model call and wrap the parsed topics.

    Returns (run_id, tweetTopicRecord).
    """
    print("开始生成选题...")
    started = time.time()
    # Single model call; ai_agent.work echoes the (possibly normalized)
    # prompts back along with bookkeeping values we don't use here.
    result, system_prompt, user_prompt, file_folder, file_name, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )
    print(f"选题生成完成,耗时:{time.time() - started}")
    # The timestamp doubles as the unique run identifier.
    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    parsed_topics = TopicParser.parse_topics(result)
    record = tweetTopicRecord(parsed_topics, system_prompt, user_prompt, output_dir, run_id)
    return run_id, record
def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
    """生成单篇文章内容. Requires prompts to be passed in.

    Returns (tweetContent, raw_result) on success, (None, None) on any
    failure.  `item` is accepted for interface compatibility but is not
    used — the caller must pre-build the prompts.
    """
    try:
        if not system_prompt or not user_prompt:
            print("Error: System or User prompt is empty. Cannot generate content.")
            return None, None
        print(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
        # Random jitter between requests so we don't hammer the backend.
        time.sleep(0.1 + 0.5 * random.random())
        result, _, _, _, _, tokens, time_cost = ai_agent.work(
            system_prompt, user_prompt, "", temperature, top_p, presence_penalty
        )
        print(f"生成完成tokens: {tokens}, 耗时: {time_cost}s")
        # Persist article + prompt under <output_dir>/<article>_<variant>/.
        tweet_content = tweetContent(result, user_prompt, output_dir, run_id, article_index, variant_index)
        variant_dir = os.path.join(output_dir, f"{article_index}_{variant_index}")
        os.makedirs(variant_dir, exist_ok=True)
        tweet_content.save_content(os.path.join(variant_dir, "article.json"))
        tweet_content.save_prompt(os.path.join(variant_dir, "tweet_prompt.txt"))
        return tweet_content, result
    except Exception as e:
        print(f"生成单篇文章时出错: {e}")
        return None, None
def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
                     variants=2, temperature=0.3, start_index=0, end_index=None):
    """根据选题生成内容

    Iterates over topics[start_index:end_index] and produces *variants*
    article variants per topic via generate_single_content.  Returns the
    list of successfully generated tweetContent objects (None when there
    are no topics).

    NOTE(review): prompts_dir / resource_dir are currently unused here —
    per-topic user-prompt construction appears to have moved to
    PromptManager; confirm before removing the parameters.
    """
    if not topics:
        print("没有选题,无法生成内容")
        return
    # Clamp the processing window to the available topics.
    if end_index is None or end_index > len(topics):
        end_index = len(topics)
    topics_to_process = topics[start_index:end_index]
    print(f"准备处理{len(topics_to_process)}个选题...")
    processed_results = []
    for i, item in enumerate(topics_to_process):
        print(f"处理第 {i+1}/{len(topics_to_process)} 篇文章")
        # The topic itself serves as the user prompt; serialize non-string
        # topics (parsed dicts) so downstream code receives text.
        user_prompt = item if isinstance(item, str) else json.dumps(item, ensure_ascii=False)
        for j in range(variants):
            print(f"正在生成变体 {j+1}/{variants}")
            # Fixed: the previous positional call dropped the user_prompt
            # argument entirely, shifting every later argument into the
            # wrong parameter (output_dir received run_id, variant_index
            # received temperature, ...).
            tweet_content, result = generate_single_content(
                ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                article_index=i + 1, variant_index=j + 1, temperature=temperature
            )
            if tweet_content:
                processed_results.append(tweet_content)
    print(f"完成{len(processed_results)}篇文章生成")
    return processed_results
def prepare_topic_generation(config):
    """准备选题生成的环境和参数. Returns agent and prompts.

    Returns (ai_agent, system_prompt, user_prompt, output_dir), or four
    Nones when prompt loading or agent construction fails.
    """
    # All prompt loading/building is delegated to PromptManager.
    manager = PromptManager(config)
    system_prompt, user_prompt = manager.get_topic_prompts()
    if not (system_prompt and user_prompt):
        print("Error: Failed to get topic generation prompts.")
        return None, None, None, None
    # The agent for the topic-generation phase is created here as well.
    try:
        print("Initializing AI Agent for topic generation...")
        agent = AI_Agent(config["api_url"], config["model"], config["api_key"])
    except Exception as e:
        print(f"Error initializing AI Agent for topic generation: {e}")
        traceback.print_exc()
        return None, None, None, None
    return agent, system_prompt, user_prompt, config["output_dir"]
def run_topic_generation_pipeline(config):
    """Runs the complete topic generation pipeline based on the configuration.

    Prepares the agent and prompts, generates topics, persists the
    results under <output_dir>/<run_id>/, and always closes the agent.
    Returns (run_id, tweetTopicRecord) on success, (None, None) on failure.
    """
    print("Step 1: Generating Topics...")
    ai_agent, system_prompt, user_prompt, base_output_dir = None, None, None, None
    try:
        ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config)
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
    except Exception as e:
        print(f"Error during topic generation preparation: {e}")
        traceback.print_exc()
        return None, None
    try:
        run_id, tweet_topic_record = generate_topics(
            ai_agent, system_prompt, user_prompt, config["output_dir"],
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            # Fixed: this argument is generate_topics' presence_penalty,
            # but it was read from the unrelated "topic_max_tokens" key.
            # The old key is kept as a fallback for existing configs.
            config.get("topic_presence_penalty", config.get("topic_max_tokens", 1.5))
        )
    except Exception as e:
        print(f"Error during topic generation API call: {e}")
        traceback.print_exc()
        return None, None
    finally:
        # The agent is only needed for the API call above — always release it.
        if ai_agent:
            ai_agent.close()
    if not run_id or not tweet_topic_record:
        print("Topic generation failed (no run_id or topics returned).")
        return None, None
    output_dir = os.path.join(config["output_dir"], run_id)
    try:
        os.makedirs(output_dir, exist_ok=True)
        save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
        save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
        if not save_topics_success or not save_prompt_success:
            # Non-fatal: the caller still receives the in-memory record.
            print("Warning: Failed to save topic generation results or prompts.")
    except Exception as e:
        print(f"Error saving topic generation results: {e}")
        traceback.print_exc()
        # Saving is treated as critical: fail the pipeline.
        return None, None
    print(f"Topics generated successfully. Run ID: {run_id}")
    return run_id, tweet_topic_record
def _read_text_file(path):
    """Best-effort read of a UTF-8 text file; returns "" when unreadable."""
    try:
        with open(path, encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        print(f"无法读取文件 {path}: {e}")
        return ""


def main():
    """主函数入口: 先生成选题, 再为每个选题生成文章内容."""
    config_file = {
        "date": "4月17日",
        "num": 5,
        "model": "qwenQWQ",
        "api_url": "vllm",
        "api_key": "EMPTY",
        "topic_system_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/systemPrompt.txt",
        "topic_user_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/userPrompt.txt",
        "content_system_prompt": "/root/autodl-tmp/TravelContentCreator/genPrompts/systemPrompt.txt",
        "resource_dir": [{
            "type": "Object",
            "num": 4,
            "file_path": ["/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-尚书第.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-明清园.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-泰宁古城.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-甘露寺.txt"
                          ]},
            {
                "type": "Product",
                "num": 0,
                "file_path": []
            }
        ],
        "prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
        "output_dir": "/root/autodl-tmp/TravelContentCreator/result",
        "variants": 2,
        "topic_temperature": 0.2,
        "content_temperature": 0.3
    }
    # 1. 首先生成选题
    ai_agent, system_prompt, user_prompt, output_dir = prepare_topic_generation(config_file)
    if not ai_agent or not system_prompt or not user_prompt:
        # Fixed: the original proceeded with a None agent and crashed.
        print("选题生成失败,退出程序")
        return
    try:
        run_id, tweet_topic_record = generate_topics(
            ai_agent, system_prompt, user_prompt, config_file["output_dir"],
            config_file["topic_temperature"], 0.5, 1.5
        )
        # Fixed: check for failure *before* using run_id to build paths
        # (the original joined paths first and crashed on run_id=None).
        if not run_id or not tweet_topic_record:
            print("选题生成失败,退出程序")
            return
        output_dir = os.path.join(config_file["output_dir"], run_id)
        os.makedirs(output_dir, exist_ok=True)
        tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
        tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
        # 2. 然后生成内容
        print("\n开始根据选题生成内容...")
        # Fixed: ResourceLoader is no longer imported (its import is
        # commented out at the top of the file), so the original call
        # raised NameError here.  Read the prompt file directly instead.
        content_system_prompt = _read_text_file(config_file["content_system_prompt"])
        if not content_system_prompt:
            print("内容生成系统提示词为空,使用选题生成的系统提示词")
            content_system_prompt = system_prompt
        # 直接使用同一个AI Agent实例
        generate_content(
            ai_agent, content_system_prompt, tweet_topic_record.topics_list, output_dir, run_id,
            config_file["prompts_dir"], config_file["resource_dir"],
            config_file["variants"], config_file["content_temperature"]
        )
    finally:
        # Fixed: the agent was previously never released in main().
        ai_agent.close()


if __name__ == "__main__":
    main()