# TravelContentCreator/utils/tweet_generator.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
# sys.path.append('/root/autodl-tmp') # No longer needed if running as a module or if path is set correctly
# 从本地模块导入
# from TravelContentCreator.core.ai_agent import AI_Agent # Remove project name prefix
# from TravelContentCreator.core.topic_parser import TopicParser # Remove project name prefix
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
# from TravelContentCreator.utils.prompt_manager import PromptManager # Remove project name prefix
# from ..core import contentGen as core_contentGen # Change to absolute import
# from ..core import posterGen as core_posterGen # Change to absolute import
# from ..core import simple_collage as core_simple_collage # Change to absolute import
from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser
from utils.prompt_manager import PromptManager # Keep this as it's importing from the same level package 'utils'
from core import contentGen as core_contentGen
from core import posterGen as core_posterGen
from core import simple_collage as core_simple_collage
class tweetTopic:
    """A single generated tweet topic; each ``*_logic`` field holds the
    model's justification for the matching choice."""

    def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic):
        """Store every constructor argument verbatim as an attribute.

        Attribute names intentionally mirror the parameter names one-to-one
        (including ``object``, kept for caller compatibility even though it
        shadows the builtin).
        """
        field_values = (
            ("index", index),
            ("date", date),
            ("logic", logic),
            ("object", object),
            ("product", product),
            ("product_logic", product_logic),
            ("style", style),
            ("style_logic", style_logic),
            ("target_audience", target_audience),
            ("target_audience_logic", target_audience_logic),
        )
        for attr_name, attr_value in field_values:
            setattr(self, attr_name, attr_value)
class tweetTopicRecord:
    """Bundle of parsed topics plus the prompts and run metadata that produced them."""

    def __init__(self, topics_list, system_prompt, user_prompt, output_dir, run_id):
        self.topics_list = topics_list      # parsed topic dicts returned by TopicParser
        self.system_prompt = system_prompt  # system prompt used for topic generation
        self.user_prompt = user_prompt      # user prompt used for topic generation
        self.output_dir = output_dir        # base output directory for the run
        self.run_id = run_id                # identifier of this generation run

    def save_topics(self, path):
        """Write the topics list to *path* as UTF-8 JSON.

        Returns:
            True on success, False on failure (the error is logged with a
            traceback instead of being raised).
        """
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.topics_list, f, ensure_ascii=False, indent=4)
            print(f"Topics list successfully saved to {path}")
        except Exception as e:
            print(f"保存选题失败到 {path}: {e}")
            print("--- Traceback for save_topics error ---")
            traceback.print_exc()
            print("--- End Traceback ---")
            return False
        return True

    def save_prompt(self, path):
        """Write the prompts and run metadata to *path*, one field per line.

        Returns:
            True on success, False on failure.
        """
        try:
            with open(path, "w", encoding="utf-8") as f:
                f.write(self.system_prompt + "\n")
                f.write(self.user_prompt + "\n")
                f.write(self.output_dir + "\n")
                f.write(self.run_id + "\n")
        except Exception as e:
            print(f"保存提示词失败: {e}")
            # Consistency fix: save_topics prints the traceback on failure;
            # do the same here so save errors are equally debuggable.
            traceback.print_exc()
            return False
        return True
class tweetContent:
    """A single generated article variant: raw model output split into title/content."""

    def __init__(self, result, prompt, output_dir, run_id, article_index, variant_index):
        self.result = result                # raw model output
        self.prompt = prompt                # user prompt that produced it
        self.output_dir = output_dir        # base output directory (informational)
        self.run_id = run_id                # identifier of the current run
        self.article_index = article_index  # 1-based topic index
        self.variant_index = variant_index  # 1-based variant index
        self.title, self.content = self.split_content(result)
        self.json_file = self.gen_result_json()

    def split_content(self, result):
        """Extract ``(title, content)`` from model output shaped like
        ``...</think>...<title>T</title>...<content>C</content>...``.

        The chain-of-thought section before ``</think>`` is dropped when
        present. Raises IndexError if the title/content tags are missing.
        """
        # Drop the reasoning section if the model emitted one (the old code
        # crashed with IndexError when </think> was absent).
        if "</think>" in result:
            result = result.split("</think>", 1)[1]
        # BUGFIX: splitting on the bare markers "title>"/"content>" also
        # matched inside the closing tags "</title>"/"</content>", which left
        # a trailing "</" on both extracted strings. Split on the full
        # opening/closing tags instead.
        title = result.split("<title>", 1)[1].split("</title>", 1)[0]
        content = result.split("<content>", 1)[1].split("</content>", 1)[0]
        return title, content

    def gen_result_json(self):
        """Return the serializable ``{"title", "content"}`` payload."""
        json_file = {
            "title": self.title,
            "content": self.content
        }
        return json_file

    def save_content(self, json_path):
        """Write the title/content JSON to *json_path*; returns the path."""
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.json_file, f, ensure_ascii=False, indent=4)
        return json_path

    def save_prompt(self, path):
        """Write the generation prompt to *path*; returns the path."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.prompt + "\n")
        return path

    def get_content(self):
        return self.content

    def get_title(self):
        return self.title

    def get_json_file(self):
        return self.json_file
def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """生成选题列表 (run_id is now passed in)"""
    print("开始生成选题...")
    started_at = time.time()
    # Ask the model for the raw topic list.
    result, system_prompt, user_prompt, file_folder, file_name, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )
    print(f"选题生成完成,耗时:{time.time() - started_at}")
    # Parse the raw output into structured topics and wrap them with the
    # prompts/run metadata; the caller already knows the run_id.
    parsed_topics = TopicParser.parse_topics(result)
    return tweetTopicRecord(parsed_topics, system_prompt, user_prompt, output_dir, run_id)
def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
    """生成单篇文章内容. Requires prompts to be passed in.

    Args:
        ai_agent: Initialized AI_Agent used for the completion call.
        system_prompt: Fully built system prompt (must be non-empty).
        user_prompt: Fully built user prompt (must be non-empty).
        item: The topic dict this article is based on (not used directly here).
        output_dir: Base output directory for the run (e.g. ./result).
        run_id: Identifier of the current run; results go under output_dir/run_id.
        article_index: 1-based topic index, used in the variant subdirectory name.
        variant_index: 1-based variant index, used in the variant subdirectory name.
        temperature/top_p/presence_penalty: Sampling parameters.

    Returns:
        (tweetContent, raw_result) on success, (None, None) on failure.
    """
    try:
        if not system_prompt or not user_prompt:
            print("Error: System or User prompt is empty. Cannot generate content.")
            return None, None
        print(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
        # Small random pause to avoid hammering the API with back-to-back calls.
        time.sleep(random.random() * 0.5 + 0.1)
        # Generate the article.
        result, _, _, _, _, tokens, time_cost = ai_agent.work(
            system_prompt, user_prompt, "", temperature, top_p, presence_penalty
        )
        print(f"生成完成tokens: {tokens}, 耗时: {time_cost}s")
        # Results live under <output_dir>/<run_id>/<article>_<variant>/
        run_specific_output_dir = os.path.join(output_dir, run_id)
        variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}")
        os.makedirs(variant_result_dir, exist_ok=True)
        # Parse the raw output and persist both the article and its prompt.
        tweet_content = tweetContent(result, user_prompt, output_dir, run_id, article_index, variant_index)
        content_save_path = os.path.join(variant_result_dir, "article.json")
        prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt")
        tweet_content.save_content(content_save_path)
        tweet_content.save_prompt(prompt_save_path)
        print(f" Saved article content to: {content_save_path}")
        return tweet_content, result
    except Exception as e:
        print(f"生成单篇文章时出错: {e}")
        # BUGFIX: the traceback was previously swallowed, which made failures
        # (e.g. malformed model output inside tweetContent) hard to diagnose;
        # every other handler in this module prints it.
        traceback.print_exc()
        return None, None
def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
                     variants=2, temperature=0.3, start_index=0, end_index=None):
    """根据选题生成内容 (legacy batch driver).

    Args:
        ai_agent: Initialized AI_Agent.
        system_prompt: System prompt used for every article.
        topics: List of topic dicts to process.
        output_dir: Base output directory for the run.
        run_id: Identifier of the current run.
        prompts_dir/resource_dir: Kept for interface compatibility; the old
            ResourceLoader-based prompt building is disabled, so they are
            currently unused.
        variants: Number of article variants per topic.
        temperature: Sampling temperature.
        start_index/end_index: Slice of `topics` to process.

    Returns:
        List of successfully generated tweetContent objects, or None when
        `topics` is empty.
    """
    if not topics:
        print("没有选题,无法生成内容")
        return
    # Clamp the processing range to the available topics.
    if end_index is None or end_index > len(topics):
        end_index = len(topics)
    topics_to_process = topics[start_index:end_index]
    print(f"准备处理{len(topics_to_process)}个选题...")
    processed_results = []
    for i, item in enumerate(topics_to_process):
        print(f"处理第 {i+1}/{len(topics_to_process)} 篇文章")
        # NOTE(review): ResourceLoader.build_user_prompt is no longer
        # available, so serialize the topic itself as the user prompt.
        # Prefer generate_content_for_topic + PromptManager for new code.
        user_prompt = json.dumps(item, ensure_ascii=False)
        # 为每个选题生成多个变体
        for j in range(variants):
            print(f"正在生成变体 {j+1}/{variants}")
            # BUGFIX: the previous call omitted user_prompt entirely, which
            # shifted every following positional argument by one (run_id
            # received the topic index, etc.), so generation always failed.
            tweet_content, result = generate_single_content(
                ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                i + 1, j + 1, temperature
            )
            if tweet_content:
                processed_results.append(tweet_content)
    print(f"完成{len(processed_results)}篇文章生成")
    return processed_results
def prepare_topic_generation(
    config  # Full configuration dictionary for the run
):
    """准备选题生成的环境和参数. Returns agent and prompts.

    Returns:
        (ai_agent, system_prompt, user_prompt, output_dir) on success,
        (None, None, None, None) when prompts are missing or the agent
        cannot be initialized.
    """
    # Build the topic prompts through the PromptManager abstraction.
    manager = PromptManager(config)
    system_prompt, user_prompt = manager.get_topic_prompts()
    if not system_prompt or not user_prompt:
        print("Error: Failed to get topic generation prompts.")
        return None, None, None, None
    # Create the AI agent used for the topic-generation phase.
    try:
        print("Initializing AI Agent for topic generation...")
        # Timeout/retry behaviour is configurable with sensible defaults.
        agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=config.get("request_timeout", 30),
            max_retries=config.get("max_retries", 3),
        )
    except Exception as e:
        print(f"Error initializing AI Agent for topic generation: {e}")
        traceback.print_exc()
        return None, None, None, None
    return agent, system_prompt, user_prompt, config["output_dir"]
def run_topic_generation_pipeline(config, run_id=None):
    """Runs the complete topic generation pipeline based on the configuration.

    Prepares the agent and prompts, calls the model, parses the topics, and
    persists both the topic list and the prompts under output_dir/run_id.

    Args:
        config: Global configuration dict (API settings, prompts, output_dir,
            optional topic_* sampling parameters).
        run_id: Optional run identifier; a timestamp-based one is generated
            when omitted.

    Returns:
        (run_id, tweetTopicRecord) on success, (None, None) on any failure.
    """
    print("Step 1: Generating Topics...")
    # --- Handle run_id ---
    if run_id is None:
        print("No run_id provided, generating one based on timestamp.")
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        print(f"Using provided run_id: {run_id}")
    # --- End run_id handling ---
    # Prepare necessary inputs and the AI agent for topic generation.
    ai_agent, system_prompt, user_prompt, base_output_dir = None, None, None, None
    try:
        # Pass the config directly to prepare_topic_generation.
        ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config)
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
    except Exception as e:
        print(f"Error during topic generation preparation: {e}")
        traceback.print_exc()
        return None, None
    # Generate topics using the prepared agent and prompts.
    try:
        # Pass the determined run_id to generate_topics.
        tweet_topic_record = generate_topics(
            ai_agent, system_prompt, user_prompt, config["output_dir"],
            run_id,  # Pass the run_id
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            config.get("topic_presence_penalty", 1.5)
        )
    except Exception as e:
        print(f"Error during topic generation API call: {e}")
        traceback.print_exc()
        if ai_agent: ai_agent.close()  # Ensure agent is closed on error
        return None, None
    # Ensure the AI agent is closed after generation (success path).
    if ai_agent:
        ai_agent.close()
    # Process results.
    if not tweet_topic_record:
        print("Topic generation failed (generate_topics returned None).")
        return None, None  # Return None for run_id as well if record is None
    # Use the determined run_id for the output directory.
    output_dir = os.path.join(config["output_dir"], run_id)
    try:
        os.makedirs(output_dir, exist_ok=True)
        # --- Debug: Print the data before attempting to save ---
        print("--- Debug: Data to be saved in tweet_topic.json ---")
        print(tweet_topic_record.topics_list)
        print("--- End Debug ---")
        # Save topics and prompt details.
        save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
        save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
        if not save_topics_success or not save_prompt_success:
            print("Warning: Failed to save topic generation results or prompts.")
            # Continue but warn the user.
    except Exception as e:
        print(f"Error saving topic generation results: {e}")
        traceback.print_exc()
        # NOTE(review): an exception while persisting is treated as a hard
        # failure even though the topics were generated successfully.
        return None, None
    print(f"Topics generated successfully. Run ID: {run_id}")
    # Return the determined run_id and the record.
    return run_id, tweet_topic_record
# --- Decoupled Functional Units (Moved from main.py) ---
def generate_content_for_topic(ai_agent, prompt_manager, config, topic_item, output_dir, run_id, topic_index):
    """Generates all content variants for a single topic item.

    Args:
        ai_agent: An initialized AI_Agent instance.
        prompt_manager: An initialized PromptManager instance.
        config: The global configuration dictionary.
        topic_item: The dictionary representing a single topic.
        output_dir: The base output directory for the entire run (e.g., ./result).
        run_id: The ID for the current run.
        topic_index: The 1-based index of the current topic.

    Returns:
        A list of tweet content data (dictionaries) generated for the topic,
        or None if generation failed.
    """
    print(f" Generating content for Topic {topic_index}...")
    tweet_content_list = []
    variants = config.get("variants", 1)
    for j in range(variants):
        variant_index = j + 1
        print(f" Generating Variant {variant_index}/{variants}...")
        # Build the prompts for this specific topic item via the PromptManager.
        content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item)
        if not content_system_prompt or not content_user_prompt:
            print(f" Skipping Variant {variant_index} due to missing content prompts.")
            continue  # Skip this variant
        # Brief random pause between variants to avoid hammering the API.
        time.sleep(random.random() * 0.5)
        try:
            # Core generation (defined in this module); also saves the
            # article and prompt to disk under output_dir/run_id.
            tweet_content, gen_result = generate_single_content(
                ai_agent, content_system_prompt, content_user_prompt, topic_item,
                output_dir, run_id, topic_index, variant_index,
                config.get("content_temperature", 0.3),
                config.get("content_top_p", 0.4),
                config.get("content_presence_penalty", 1.5)
            )
            if tweet_content:
                try:
                    # Collect the serializable {title, content} payload.
                    tweet_content_data = tweet_content.get_json_file()
                    if tweet_content_data:
                        tweet_content_list.append(tweet_content_data)
                    else:
                        print(f" Warning: tweet_content.get_json_file() for Topic {topic_index}, Variant {variant_index} returned empty data.")
                except Exception as parse_err:
                    print(f" Error processing tweet content after generation for Topic {topic_index}, Variant {variant_index}: {parse_err}")
            else:
                print(f" Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.")
        except Exception as e:
            print(f" Error during content generation for Topic {topic_index}, Variant {variant_index}: {e}")
            # traceback.print_exc()
    if not tweet_content_list:
        print(f" No valid content generated for Topic {topic_index}.")
        return None
    else:
        print(f" Successfully generated {len(tweet_content_list)} content variants for Topic {topic_index}.")
        return tweet_content_list
def generate_posters_for_topic(config, topic_item, tweet_content_list, output_dir, run_id, topic_index):
    """Generates all posters for a single topic item based on its generated content.

    Args:
        config: The global configuration dictionary.
        topic_item: The dictionary representing a single topic.
        tweet_content_list: List of content data generated by generate_content_for_topic.
        output_dir: The base output directory for the entire run (e.g., ./result).
        run_id: The ID for the current run.
        topic_index: The 1-based index of the current topic.

    Returns:
        True if poster generation was attempted (regardless of individual variant success),
        False if setup failed before attempting variants.
    """
    print(f" Generating posters for Topic {topic_index}...")
    # Initialize necessary generators here, assuming they are stateless or
    # cheap to create; pass initialized instances instead if they hold state.
    try:
        content_gen_instance = core_contentGen.ContentGenerator()
        poster_gen_instance = core_posterGen.PosterGenerator()
    except Exception as e:
        print(f" Error initializing generators for poster creation: {e}")
        return False
    # --- Setup: Paths and Object Name ---
    image_base_dir = config.get("image_base_dir")
    if not image_base_dir:
        print(" Error: image_base_dir missing in config for poster generation.")
        return False
    modify_image_subdir = config.get("modify_image_subdir", "modify")
    camera_image_subdir = config.get("camera_image_subdir", "相机")
    object_name = topic_item.get("object", "")
    if not object_name:
        print(" Warning: Topic object name is missing. Cannot generate posters.")
        return False
    # Clean the object name: drop any file extension and the "景点信息-"
    # (attraction info) prefix to get the bare attraction name.
    try:
        object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
        if not object_name_cleaned:
            print(f" Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.")
            return False
        object_name = object_name_cleaned
    except Exception as e:
        print(f" Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
        return False
    # Construct and check the per-attraction image paths.
    input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name)
    camera_img_dir_path = os.path.join(image_base_dir, camera_image_subdir, object_name)
    description_file_path = os.path.join(camera_img_dir_path, "description.txt")
    if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
        print(f" Image directory not found or not a directory: '{input_img_dir_path}'. Skipping posters for this topic.")
        return False
    # Optional description file feeds the poster text generator.
    info_directory = []
    if os.path.exists(description_file_path):
        info_directory = [description_file_path]
        print(f" Using description file: {description_file_path}")
    else:
        print(f" Description file not found: '{description_file_path}'. Using generated content for poster text.")
    # --- Generate Text Configurations for All Variants ---
    try:
        poster_text_configs_raw = content_gen_instance.run(info_directory, config["variants"], tweet_content_list)
        if not poster_text_configs_raw:
            print(" Warning: ContentGenerator returned empty configuration data. Skipping posters.")
            return False
        poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
    except Exception as e:
        print(f" Error running ContentGenerator or parsing poster configs: {e}")
        traceback.print_exc()
        return False  # Cannot proceed if text config fails
    # --- Poster Generation Loop for each variant ---
    poster_num = config.get("variants", 1)
    target_size = tuple(config.get("poster_target_size", [900, 1200]))
    any_poster_attempted = False
    text_possibility = config.get("text_possibility", 0.3)  # Chance of adding a second text line
    for j_index in range(poster_num):
        variant_index = j_index + 1
        print(f" Generating Poster {variant_index}/{poster_num}...")
        any_poster_attempted = True
        try:
            poster_config = poster_config_summary.get_config_by_index(j_index)
            if not poster_config:
                print(f" Warning: Could not get poster config for index {j_index}. Skipping.")
                continue
            # Define output directories for this specific variant:
            # <output_dir>/<run_id>/<topic>_<variant>/{collage_img,poster}
            run_output_dir = os.path.join(output_dir, run_id)
            variant_output_dir = os.path.join(run_output_dir, f"{topic_index}_{variant_index}")
            collage_output_dir = os.path.join(variant_output_dir, "collage_img")
            poster_output_dir = os.path.join(variant_output_dir, "poster")
            os.makedirs(collage_output_dir, exist_ok=True)
            os.makedirs(poster_output_dir, exist_ok=True)
            # --- Image Collage: build one collage from the attraction's images ---
            img_list = core_simple_collage.process_directory(
                input_img_dir_path,
                target_size=target_size,
                output_count=1,
                output_dir=collage_output_dir
            )
            if not img_list or len(img_list) == 0 or not img_list[0].get('path'):
                print(f" Failed to generate collage image for Variant {variant_index}. Skipping poster.")
                continue
            collage_img_path = img_list[0]['path']
            print(f" Using collage image: {collage_img_path}")
            # --- Create Poster: overlay title + optional extra text lines ---
            text_data = {
                "title": poster_config.get('main_title', 'Default Title'),
                "subtitle": "",
                "additional_texts": []
            }
            texts = poster_config.get('texts', [])
            if texts:
                text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5})
                # Second line is added probabilistically, governed by config.
                if len(texts) > 1 and random.random() < text_possibility:
                    text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5})
            final_poster_path = os.path.join(poster_output_dir, "poster.jpg")
            result_path = poster_gen_instance.create_poster(collage_img_path, text_data, final_poster_path)
            if result_path:
                print(f" Successfully generated poster: {result_path}")
            else:
                print(f" Poster generation function did not return a valid path.")
        except Exception as e:
            print(f" Error during poster generation for Variant {variant_index}: {e}")
            traceback.print_exc()
            continue  # Continue to next variant
    return any_poster_attempted
def main():
    """主函数入口 — smoke-test driver with a hard-coded configuration."""
    config_file = {
        "date": "4月17日",
        "num": 5,
        "model": "qwenQWQ",
        "api_url": "vllm",
        "api_key": "EMPTY",
        "topic_system_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/systemPrompt.txt",
        "topic_user_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/userPrompt.txt",
        "content_system_prompt": "/root/autodl-tmp/TravelContentCreator/genPrompts/systemPrompt.txt",
        "resource_dir": [{
            "type": "Object",
            "num": 4,
            "file_path": ["/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-尚书第.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-明清园.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-泰宁古城.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-甘露寺.txt"
                          ]},
            {
                "type": "Product",
                "num": 0,
                "file_path": []
            }
        ],
        "prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
        "output_dir": "/root/autodl-tmp/TravelContentCreator/result",
        "variants": 2,
        "topic_temperature": 0.2,
        "content_temperature": 0.3
    }
    # 1. 首先生成选题
    ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(
        config_file
    )
    if not ai_agent or not system_prompt or not user_prompt:
        print("选题生成准备失败,退出程序")
        return
    # BUGFIX: generate_topics requires an explicit run_id and returns only
    # the record. The old call passed the temperature where run_id belonged
    # and tried to unpack two values from a single object.
    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tweet_topic_record = generate_topics(
        ai_agent, system_prompt, user_prompt, config_file["output_dir"], run_id,
        config_file["topic_temperature"], 0.5, 1.5
    )
    if not tweet_topic_record:
        print("选题生成失败,退出程序")
        return
    output_dir = os.path.join(config_file["output_dir"], run_id)
    os.makedirs(output_dir, exist_ok=True)
    tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
    tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
    # 2. 然后生成内容
    print("\n开始根据选题生成内容...")
    # BUGFIX: ResourceLoader is no longer imported in this module; read the
    # content system prompt file directly instead.
    content_system_prompt = ""
    try:
        with open(config_file["content_system_prompt"], "r", encoding="utf-8") as f:
            content_system_prompt = f.read()
    except OSError as e:
        print(f"读取内容系统提示词失败: {e}")
    if not content_system_prompt:
        print("内容生成系统提示词为空,使用选题生成的系统提示词")
        content_system_prompt = system_prompt
    # 直接使用同一个AI Agent实例
    # BUGFIX: pass the base output dir, not the run dir — generate_content's
    # workers join run_id themselves, so passing the joined path produced
    # result/<run_id>/<run_id>/... nesting.
    result = generate_content(
        ai_agent, content_system_prompt, tweet_topic_record.topics_list,
        config_file["output_dir"], run_id, config_file["prompts_dir"], config_file["resource_dir"],
        config_file["variants"], config_file["content_temperature"]
    )
if __name__ == "__main__":
    main()