# TravelContentCreator/utils/tweet_generator.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import random
import argparse
import json
from datetime import datetime
import sys
import traceback
import logging # Add logging
# sys.path.append('/root/autodl-tmp') # No longer needed if running as a module or if path is set correctly
# 从本地模块导入
# from TravelContentCreator.core.ai_agent import AI_Agent # Remove project name prefix
# from TravelContentCreator.core.topic_parser import TopicParser # Remove project name prefix
# ResourceLoader is now used implicitly via PromptManager
# from TravelContentCreator.utils.resource_loader import ResourceLoader
# from TravelContentCreator.utils.prompt_manager import PromptManager # Remove project name prefix
# from ..core import contentGen as core_contentGen # Change to absolute import
# from ..core import posterGen as core_posterGen # Change to absolute import
# from ..core import simple_collage as core_simple_collage # Change to absolute import
from core.ai_agent import AI_Agent
from core.topic_parser import TopicParser
from utils.prompt_manager import PromptManager # Keep this as it's importing from the same level package 'utils'
from core import contentGen as core_contentGen
from core import posterGen as core_posterGen
from core import simple_collage as core_simple_collage
class tweetTopic:
    """Value object describing one generated tweet topic.

    Holds the topic's position in the batch, its date, the chosen subject /
    product / style / audience, and the model's stated rationale for each
    choice (the ``*_logic`` fields).
    """

    # Every constructor parameter is stored verbatim as a same-named attribute.
    def __init__(self, index, date, logic, object, product, product_logic,
                 style, style_logic, target_audience, target_audience_logic):
        captured = locals()
        for field in ("index", "date", "logic", "object", "product",
                      "product_logic", "style", "style_logic",
                      "target_audience", "target_audience_logic"):
            setattr(self, field, captured[field])
class tweetTopicRecord:
    """Bundles a parsed topic list with the prompts and run metadata that produced it."""

    def __init__(self, topics_list, system_prompt, user_prompt, output_dir, run_id):
        self.topics_list = topics_list
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt
        self.output_dir = output_dir
        self.run_id = run_id

    def save_topics(self, path):
        """Serialize the topic list to *path* as pretty-printed UTF-8 JSON.

        Returns:
            True on success, False on any write/serialization error.
        """
        try:
            with open(path, "w", encoding="utf-8") as fp:
                json.dump(self.topics_list, fp, ensure_ascii=False, indent=4)
            logging.info(f"Topics list successfully saved to {path}")
        except Exception as e:
            logging.exception(f"保存选题失败到 {path}: {e}")
            return False
        return True

    def save_prompt(self, path):
        """Write the system prompt followed by the user prompt to *path*.

        Returns:
            True on success, False on any write error.
        """
        try:
            with open(path, "w", encoding="utf-8") as fp:
                fp.write(f"{self.system_prompt}\n{self.user_prompt}\n")
            logging.info(f"Prompts saved to {path}")
        except Exception as e:
            logging.exception(f"保存提示词失败: {e}")
            return False
        return True
class tweetContent:
    """Parses a raw AI generation result into title/content and persists it.

    The raw result must contain ``<title>...</title>`` and
    ``<content>...</content>`` tags, optionally preceded by a reasoning
    section terminated by ``</think>`` (which is stripped before parsing).
    Parsing failures never raise out of the constructor: the instance is
    still built with placeholder fields and an error-flagged JSON payload
    so batch processing can continue.
    """

    def __init__(self, result, prompt, output_dir, run_id, article_index, variant_index):
        self.result = result          # raw model output
        self.prompt = prompt          # user prompt that produced the result
        self.output_dir = output_dir
        self.run_id = run_id
        self.article_index = article_index
        self.variant_index = variant_index
        try:
            self.title, self.content = self.split_content(result)
            self.json_file = self.gen_result_json()
        except Exception as e:
            logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")
            # str() guards against result being None in this error path.
            logging.debug(f"Raw result: {str(result)[:500]}...")
            self.title = "[Parsing Error]"
            self.content = "[Failed to parse AI content]"
            self.json_file = {"title": self.title, "content": self.content, "error": True, "raw_result": result}

    def split_content(self, result):
        """Extract ``(title, content)`` from the raw result.

        BUG FIX: ``</think>`` is now optional. The previous check rejected
        any result lacking ``</think>`` even though the parsing code below
        already handled its absence, so outputs from non-reasoning models
        always fell into the error path.

        Raises:
            ValueError: when the result is empty or missing title/content tags.
        """
        if not result or "title>" not in result or "content>" not in result:
            logging.warning(f"AI result format unexpected: {str(result)[:200]}...")
            raise ValueError("AI result missing expected tags or is empty")
        processed_result = result
        if "</think>" in result:
            # Drop the chain-of-thought prefix; only the final answer is parsed.
            processed_result = result.split("</think>", 1)[1]
        title = processed_result.split("title>", 1)[1].split("</title>", 1)[0]
        content = processed_result.split("content>", 1)[1].split("</content>", 1)[0]
        return title.strip(), content.strip()

    def gen_result_json(self):
        """Return the dict persisted by save_content()."""
        return {
            "title": self.title,
            "content": self.content
        }

    def save_content(self, json_path):
        """Write the parsed (or error) payload as UTF-8 JSON.

        Returns:
            The path on success, None on failure.
        """
        try:
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(self.json_file, f, ensure_ascii=False, indent=4)
            logging.info(f"Content JSON saved to: {json_path}")
        except Exception as e:
            logging.exception(f"Failed to save content JSON to {json_path}: {e}")
            return None
        return json_path

    def save_prompt(self, path):
        """Write the generation prompt to *path*.

        Returns:
            The path on success, None on failure.
        """
        try:
            with open(path, "w", encoding="utf-8") as f:
                f.write(self.prompt + "\n")
            logging.info(f"Content prompt saved to: {path}")
        except Exception as e:
            logging.exception(f"Failed to save content prompt to {path}: {e}")
            return None
        return path

    def get_content(self):
        return self.content

    def get_title(self):
        return self.title

    def get_json_file(self):
        return self.json_file
def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
    """Run one topic-generation API call and wrap the parsed topics in a record.

    When parsing yields an empty list, the raw model output is dumped to a
    per-run error file for offline inspection; an (empty) tweetTopicRecord is
    still returned so the caller decides how to proceed.

    Returns:
        tweetTopicRecord holding the parsed topics (possibly empty) plus the
        prompts and run metadata.
    """
    logging.info("Starting topic generation...")
    result, tokens, time_cost = ai_agent.work(
        system_prompt, user_prompt, "", temperature, top_p, presence_penalty
    )
    logging.info(f"Topic generation API call completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
    result_list = TopicParser.parse_topics(result)
    if not result_list:
        logging.warning("Topic parsing resulted in an empty list.")
        # Preserve the raw output so the parsing failure can be diagnosed later.
        error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt")
        try:
            os.makedirs(os.path.dirname(error_log_path), exist_ok=True)
            with open(error_log_path, "w", encoding="utf-8") as dump:
                dump.write("--- Topic Parsing Failed ---\n")
                dump.write(result)
            logging.info(f"Saved raw AI output due to parsing failure to: {error_log_path}")
        except Exception as log_err:
            logging.error(f"Failed to save raw AI output on parsing failure: {log_err}")
    # A record is created even when the list is empty.
    return tweetTopicRecord(result_list, system_prompt, user_prompt, output_dir, run_id)
def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
                            article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
    """Generate one article variant from pre-built prompts and persist it.

    NOTE: ``item`` is accepted for interface compatibility but is not used
    here; the prompts must already encode the topic.

    Returns:
        Tuple (tweetContent, raw_result) on success, (None, None) when the
        prompts are missing or generation/persistence raises.
    """
    logging.info(f"Generating content for topic {article_index}, variant {variant_index}")
    try:
        if not system_prompt or not user_prompt:
            logging.error("System or User prompt is empty. Cannot generate content.")
            return None, None
        logging.debug(f"Using pre-constructed prompts. User prompt length: {len(user_prompt)}")
        # Small jitter so batched calls do not hit the API in lock-step.
        time.sleep(random.random() * 0.5)
        result, tokens, time_cost = ai_agent.work(
            system_prompt, user_prompt, "", temperature, top_p, presence_penalty
        )
        logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        # Results live under <output_dir>/<run_id>/<topic>_<variant>/.
        variant_dir = os.path.join(output_dir, run_id, f"{article_index}_{variant_index}")
        os.makedirs(variant_dir, exist_ok=True)
        # tweetContent handles result-parsing errors internally.
        tweet_content = tweetContent(result, user_prompt, output_dir, run_id, article_index, variant_index)
        tweet_content.save_content(os.path.join(variant_dir, "article.json"))
        tweet_content.save_prompt(os.path.join(variant_dir, "tweet_prompt.txt"))
        return tweet_content, result
    except Exception:
        logging.exception(f"Error generating single content for {article_index}_{variant_index}:")
        return None, None
def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompts_dir, resource_dir,
                     variants=2, temperature=0.3, start_index=0, end_index=None):
    """Legacy batch driver: generate `variants` articles for a slice of topics.

    BUG FIX: the original call to generate_single_content() omitted the
    user_prompt argument, which shifted every following positional argument
    (the topic item became the user prompt, output_dir became the item,
    run_id/article/variant indices were all wrong and the temperature landed
    in variant_index). The call is now aligned with the signature.

    NOTE(review): this legacy path has no PromptManager, so the per-topic user
    prompt is a best-effort JSON rendering of the topic item; the modern
    pipeline (generate_content_for_topic) builds proper prompts and should be
    preferred. `prompts_dir` / `resource_dir` are kept for interface
    compatibility but unused here.

    Returns:
        List of successfully generated tweetContent objects (possibly empty),
        or None when *topics* is empty.
    """
    if not topics:
        print("没有选题,无法生成内容")
        return
    # Clamp the processing window to the available topics.
    if end_index is None or end_index > len(topics):
        end_index = len(topics)
    topics_to_process = topics[start_index:end_index]
    print(f"准备处理{len(topics_to_process)}个选题...")
    processed_results = []
    for i, item in enumerate(topics_to_process):
        print(f"处理第 {i+1}/{len(topics_to_process)} 篇文章")
        # Render the topic as a readable prompt payload (see NOTE above).
        user_prompt = json.dumps(item, ensure_ascii=False, indent=2)
        for j in range(variants):
            print(f"正在生成变体 {j+1}/{variants}")
            tweet_content, result = generate_single_content(
                ai_agent, system_prompt, user_prompt, item,
                output_dir, run_id, i + 1, j + 1, temperature
            )
            if tweet_content:
                processed_results.append(tweet_content)
    print(f"完成{len(processed_results)}篇文章生成")
    return processed_results
def prepare_topic_generation(
    config  # full configuration dictionary
):
    """Prepare the environment for topic generation: prompts plus an AI agent.

    Args:
        config: Configuration dict. Required keys: "api_url", "model",
            "api_key", "output_dir" (plus whatever PromptManager consumes).
            Optional: "request_timeout" (seconds, default 30), "max_retries"
            (default 3).

    Returns:
        Tuple (ai_agent, system_prompt, user_prompt, output_dir), or
        (None, None, None, None) when the prompts cannot be built or the
        agent fails to initialize.
    """
    # Prompt loading/building is delegated entirely to PromptManager.
    prompt_manager = PromptManager(config)
    system_prompt, user_prompt = prompt_manager.get_topic_prompts()
    if not system_prompt or not user_prompt:
        # Consistency fix: use logging (not print) like the rest of the module.
        logging.error("Error: Failed to get topic generation prompts.")
        return None, None, None, None
    try:
        logging.info("Initializing AI Agent for topic generation...")
        # Network-resilience knobs with conservative defaults.
        request_timeout = config.get("request_timeout", 30)  # seconds
        max_retries = config.get("max_retries", 3)
        ai_agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=request_timeout,
            max_retries=max_retries
        )
    except Exception:
        # logging.exception already records the traceback; the redundant
        # traceback.print_exc() was removed to avoid double output.
        logging.exception("Error initializing AI Agent for topic generation:")
        return None, None, None, None
    return ai_agent, system_prompt, user_prompt, config["output_dir"]
def run_topic_generation_pipeline(config, run_id=None):
    """Runs the complete topic generation pipeline based on the configuration.

    Flow: resolve run_id -> prepare prompts/agent -> call the model ->
    close the agent -> persist topics and prompts under
    ``<output_dir>/<run_id>/``.

    Args:
        config: Configuration dict; must provide "output_dir" plus the keys
            consumed by prepare_topic_generation(); optional sampling keys
            "topic_temperature", "topic_top_p", "topic_presence_penalty".
        run_id: Optional run identifier; a timestamp-based one is generated
            when omitted.

    Returns:
        Tuple (run_id, tweet_topic_record) on success, (None, None) on any
        failure (preparation, API call, or saving of results).
    """
    logging.info("Starting Step 1: Topic Generation Pipeline...")
    # --- Handle run_id ---
    if run_id is None:
        logging.info("No run_id provided, generating one based on timestamp.")
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        logging.info(f"Using provided run_id: {run_id}")
    # --- End run_id handling ---
    # Prepare necessary inputs and the AI agent for topic generation.
    ai_agent, system_prompt, user_prompt, base_output_dir = None, None, None, None
    try:
        # prepare_topic_generation returns all-None on failure; the explicit
        # raise below converts that into the common error path.
        ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config)
        if not ai_agent or not system_prompt or not user_prompt:
            raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
    except Exception as e:
        logging.exception("Error during topic generation preparation:")
        traceback.print_exc()
        return None, None
    # Generate topics using the prepared agent and prompts.
    try:
        tweet_topic_record = generate_topics(
            ai_agent, system_prompt, user_prompt, config["output_dir"],
            run_id,  # the run_id resolved above
            config.get("topic_temperature", 0.2),
            config.get("topic_top_p", 0.5),
            config.get("topic_presence_penalty", 1.5)
        )
    except Exception as e:
        logging.exception("Error during topic generation API call:")
        traceback.print_exc()
        if ai_agent: ai_agent.close()  # Ensure agent is closed on error
        return None, None
    # The agent is only needed for this single call; release it here.
    if ai_agent:
        logging.info("Closing topic generation AI Agent...")
        ai_agent.close()
    # Process results.
    if not tweet_topic_record:
        logging.error("Topic generation failed (generate_topics returned None).")
        return None, None  # No run_id either when the record is missing
    # Per-run output directory: <output_dir>/<run_id>/
    output_dir = os.path.join(config["output_dir"], run_id)
    try:
        os.makedirs(output_dir, exist_ok=True)
        # --- Debug: log the data before attempting to save ---
        logging.info("--- Debug: Data to be saved in tweet_topic.json ---")
        logging.info(tweet_topic_record.topics_list)
        logging.info("--- End Debug ---")
        # Save topics and prompt details; both savers return False on failure.
        save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
        save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
        if not save_topics_success or not save_prompt_success:
            logging.warning("Warning: Failed to save topic generation results or prompts.")
            # Continue but warn the user.
    except Exception as e:
        logging.exception("Error saving topic generation results:")
        traceback.print_exc()
        # NOTE(review): a saving failure currently discards the generated
        # record; returning (run_id, tweet_topic_record) here would allow
        # partial success if that is acceptable downstream.
        return None, None
    logging.info(f"Topics generated successfully. Run ID: {run_id}")
    return run_id, tweet_topic_record
# --- Decoupled Functional Units (Moved from main.py) ---
def generate_content_for_topic(ai_agent, prompt_manager, config, topic_item, output_dir, run_id, topic_index):
    """Generates all content variants for a single topic item.

    Args:
        ai_agent: An initialized AI_Agent instance.
        prompt_manager: An initialized PromptManager instance.
        config: The global configuration dictionary (reads "variants" and the
            content sampling parameters).
        topic_item: The dictionary representing a single topic.
        output_dir: The base output directory for the entire run (e.g. ./result).
        run_id: The ID for the current run.
        topic_index: The 1-based index of the current topic.

    Returns:
        A list of content dicts (one per successful variant), or None when no
        variant could be generated.
    """
    logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
    variants = config.get("variants", 1)
    collected = []
    for offset in range(variants):
        variant_index = offset + 1
        logging.info(f"Generating Variant {variant_index}/{variants}...")
        # Prompts are rebuilt per variant from the same topic item.
        sys_prompt, usr_prompt = prompt_manager.get_content_prompts(topic_item)
        if not sys_prompt or not usr_prompt:
            logging.warning(f"Skipping Variant {variant_index} due to missing content prompts.")
            continue
        # Stagger requests slightly to avoid hammering the API.
        time.sleep(random.random() * 0.5)
        try:
            tweet_content, gen_result = generate_single_content(
                ai_agent, sys_prompt, usr_prompt, topic_item,
                output_dir, run_id, topic_index, variant_index,
                config.get("content_temperature", 0.3),
                config.get("content_top_p", 0.4),
                config.get("content_presence_penalty", 1.5)
            )
        except Exception:
            logging.exception(f"Error during content generation for Topic {topic_index}, Variant {variant_index}:")
            continue
        if not tweet_content:
            logging.warning(f"Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.")
            continue
        try:
            payload = tweet_content.get_json_file()
        except Exception as parse_err:
            logging.error(f"Error processing tweet content after generation for Topic {topic_index}, Variant {variant_index}: {parse_err}")
            continue
        if payload:
            collected.append(payload)
        else:
            logging.warning(f"Warning: tweet_content.get_json_file() for Topic {topic_index}, Variant {variant_index} returned empty data.")
    if not collected:
        logging.warning(f"No valid content generated for Topic {topic_index}.")
        return None
    logging.info(f"Successfully generated {len(collected)} content variants for Topic {topic_index}.")
    return collected
def generate_posters_for_topic(config, topic_item, tweet_content_list, output_dir, run_id, topic_index):
    """Generates all posters for a single topic item based on its generated content.

    Pipeline per topic: initialize generators -> resolve the topic's image
    directory -> locate an optional description file from config -> build
    poster text configs from the generated articles -> for each variant,
    create a collage and render the final poster.

    Args:
        config: The global configuration dictionary (many keys consumed; see
            the .get() calls below for defaults).
        topic_item: The dictionary representing a single topic.
        tweet_content_list: List of content data generated by generate_content_for_topic.
        output_dir: The base output directory for the entire run (e.g., ./result).
        run_id: The ID for the current run.
        topic_index: The 1-based index of the current topic.

    Returns:
        True if poster generation was attempted (regardless of individual
        variant success), False if setup failed before attempting variants.
    """
    logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
    # Generators are created fresh here; assumed stateless/cheap to construct.
    try:
        content_gen_instance = core_contentGen.ContentGenerator()
        # Poster assets (fonts, templates, ...) live under this base dir.
        poster_assets_base_dir = config.get("poster_assets_base_dir")
        if not poster_assets_base_dir:
            logging.error("Error: 'poster_assets_base_dir' not found in configuration. Cannot generate posters.")
            return False  # Cannot proceed without the assets base dir
        poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir)
    except Exception as e:
        logging.exception("Error initializing generators for poster creation:")
        return False
    # --- Setup: paths and object name ---
    image_base_dir = config.get("image_base_dir")
    if not image_base_dir:
        logging.error("Error: image_base_dir missing in config for poster generation.")
        return False
    modify_image_subdir = config.get("modify_image_subdir", "modify")
    # NOTE(review): camera_image_subdir is read but never used below.
    camera_image_subdir = config.get("camera_image_subdir", "相机")
    object_name = topic_item.get("object", "")
    if not object_name:
        logging.warning("Warning: Topic object name is missing. Cannot generate posters.")
        return False
    # Clean the object name: drop a file extension and the "景点信息-" prefix.
    try:
        object_name_cleaned = object_name.split(".")[0].replace("景点信息-", "").strip()
        if not object_name_cleaned:
            logging.warning(f"Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.")
            return False
        object_name = object_name_cleaned  # Use the cleaned name for searching
    except Exception as e:
        logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
        return False
    # Input images for the collage: <image_base_dir>/<modify_subdir>/<object>.
    input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name)
    if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
        logging.warning(f"Warning: Modify Image directory not found or not a directory: '{input_img_dir_path}'. Skipping posters for this topic.")
        return False
    # --- Locate an optional description file via resource_dir type "Description" ---
    info_directory = []
    description_file_path = None
    resource_dir_config = config.get("resource_dir", [])
    found_description = False
    for dir_info in resource_dir_config:
        if dir_info.get("type") == "Description":
            for file_path in dir_info.get("file_path", []):
                # Match a description file by object-name containment in its basename.
                if object_name in os.path.basename(file_path):
                    description_file_path = file_path
                    if os.path.exists(description_file_path):
                        info_directory = [description_file_path]  # Pass the found path
                        logging.info(f"Found and using description file from config: {description_file_path}")
                        found_description = True
                    else:
                        logging.warning(f"Warning: Description file specified in config not found: {description_file_path}")
                    break  # Found the matching entry in this list
        if found_description:  # Stop searching resource_dir if found
            break
    if not found_description:
        logging.info(f"Warning: No matching description file found for object '{object_name}' in config resource_dir (type='Description').")
    # --- Generate text configurations for all variants ---
    try:
        poster_text_configs_raw = content_gen_instance.run(info_directory, config["variants"], tweet_content_list)
        if not poster_text_configs_raw:
            logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.")
            return False
        poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
    except Exception as e:
        logging.exception("Error running ContentGenerator or parsing poster configs:")
        traceback.print_exc()
        return False  # Cannot proceed if text config fails
    # --- Poster generation loop, one poster per variant ---
    poster_num = config.get("variants", 1)
    target_size = tuple(config.get("poster_target_size", [900, 1200]))
    any_poster_attempted = False
    # Probability of including a second text line on the poster.
    text_possibility = config.get("text_possibility", 0.3)
    for j_index in range(poster_num):
        variant_index = j_index + 1
        logging.info(f"Generating Poster {variant_index}/{poster_num}...")
        any_poster_attempted = True
        try:
            poster_config = poster_config_summary.get_config_by_index(j_index)
            if not poster_config:
                logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.")
                continue
            # Output layout: <output_dir>/<run_id>/<topic>_<variant>/{collage_img,poster}/
            run_output_dir = os.path.join(output_dir, run_id)
            variant_output_dir = os.path.join(run_output_dir, f"{topic_index}_{variant_index}")
            output_collage_subdir = config.get("output_collage_subdir", "collage_img")
            output_poster_subdir = config.get("output_poster_subdir", "poster")
            collage_output_dir = os.path.join(variant_output_dir, output_collage_subdir)
            poster_output_dir = os.path.join(variant_output_dir, output_poster_subdir)
            os.makedirs(collage_output_dir, exist_ok=True)
            os.makedirs(poster_output_dir, exist_ok=True)
            # --- Image collage: one collage per variant from the input directory ---
            logging.info(f"Generating collage from: {input_img_dir_path}")
            img_list = core_simple_collage.process_directory(
                input_img_dir_path,
                target_size=target_size,
                output_count=1,
                output_dir=collage_output_dir
            )
            if not img_list or len(img_list) == 0 or not img_list[0].get('path'):
                logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.")
                continue
            collage_img_path = img_list[0]['path']
            logging.info(f"Using collage image: {collage_img_path}")
            # --- Create poster: title from config, 0-2 extra bottom texts ---
            text_data = {
                "title": poster_config.get('main_title', 'Default Title'),
                "subtitle": "",
                "additional_texts": []
            }
            texts = poster_config.get('texts', [])
            if texts:
                text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5})
                # The second text line is included randomly per text_possibility.
                if len(texts) > 1 and random.random() < text_possibility:
                    text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5})
            output_poster_filename = config.get("output_poster_filename", "poster.jpg")
            final_poster_path = os.path.join(poster_output_dir, output_poster_filename)
            result_path = poster_gen_instance.create_poster(collage_img_path, text_data, final_poster_path)
            if result_path:
                logging.info(f"Successfully generated poster: {result_path}")
            else:
                logging.warning(f"Warning: Poster generation function did not return a valid path for {final_poster_path}.")
        except Exception as e:
            logging.exception(f"Error during poster generation for Variant {variant_index}:")
            traceback.print_exc()
            continue  # Continue to next variant
    return any_poster_attempted
def main():
    """Entry point: run topic generation then content generation with a hard-coded config.

    BUG FIXES over the previous revision:
      * generate_topics() takes (..., output_dir, run_id, temperature, ...) and
        returns only a tweetTopicRecord; the old call passed the temperature in
        the run_id slot and unpacked two values from a single return.
      * ResourceLoader was referenced but never imported (its import is
        commented out at the top of the file), causing a NameError; the content
        system prompt file is now read directly.
      * The pointless ``if True:`` wrapper was removed.
    """
    config_file = {
        "date": "4月17日",
        "num": 5,
        "model": "qwenQWQ",
        "api_url": "vllm",
        "api_key": "EMPTY",
        "topic_system_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/systemPrompt.txt",
        "topic_user_prompt": "/root/autodl-tmp/TravelContentCreator/SelectPrompt/userPrompt.txt",
        "content_system_prompt": "/root/autodl-tmp/TravelContentCreator/genPrompts/systemPrompt.txt",
        "resource_dir": [{
            "type": "Object",
            "num": 4,
            "file_path": ["/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-尚书第.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-明清园.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-泰宁古城.txt",
                          "/root/autodl-tmp/TravelContentCreator/resource/Object/景点信息-甘露寺.txt"
                          ]},
            {
                "type": "Product",
                "num": 0,
                "file_path": []
            }
        ],
        "prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
        "output_dir": "/root/autodl-tmp/TravelContentCreator/result",
        "variants": 2,
        "topic_temperature": 0.2,
        "content_temperature": 0.3
    }
    # 1. 首先生成选题
    ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config_file)
    if not ai_agent or not system_prompt or not user_prompt:
        print("选题生成失败,退出程序")
        return
    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tweet_topic_record = generate_topics(
        ai_agent, system_prompt, user_prompt, config_file["output_dir"], run_id,
        config_file["topic_temperature"], 0.5, 1.5
    )
    if not tweet_topic_record or not tweet_topic_record.topics_list:
        print("选题生成失败,退出程序")
        return
    output_dir = os.path.join(config_file["output_dir"], run_id)
    os.makedirs(output_dir, exist_ok=True)
    tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
    tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
    # 2. 然后生成内容
    print("\n开始根据选题生成内容...")
    # 加载内容生成的系统提示词 (read directly; ResourceLoader is not imported)
    content_system_prompt = ""
    try:
        with open(config_file["content_system_prompt"], "r", encoding="utf-8") as f:
            content_system_prompt = f.read()
    except OSError as e:
        logging.warning(f"Failed to read content system prompt file: {e}")
    if not content_system_prompt:
        print("内容生成系统提示词为空,使用选题生成的系统提示词")
        content_system_prompt = system_prompt
    # 直接使用同一个AI Agent实例
    result = generate_content(
        ai_agent, content_system_prompt, tweet_topic_record.topics_list, output_dir, run_id,
        config_file["prompts_dir"], config_file["resource_dir"],
        config_file["variants"], config_file["content_temperature"]
    )

if __name__ == "__main__":
    main()