TravelContentCreator/scripts/regenerate_title.py

497 lines
22 KiB
Python
Raw Normal View History

import os
import json
import logging
from pathlib import Path
import traceback
import sys
import random
import re
from datetime import datetime
from openai import OpenAI
import time # Added for timing the single API call
# --- Path Setup ---
# The project root is taken to be two directory levels above this file
# (the parent of the directory that contains this script).
project_root = Path(__file__).resolve().parent.parent
# Prepend the root to sys.path only when it is not already listed, and report
# the change on stdout.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    print(f"Added project root {project_root} to sys.path")
# --- Logging Configuration ---
# Configure the root logger at INFO level. Every record is formatted with a
# timestamp, the level name, the source file name and line number of the
# logging call, and the message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used by the functions defined below.
logger = logging.getLogger(__name__)
# --- Constants ---
# Number of reference titles passed as the sample size when prompts are built.
DEFAULT_REF_TITLE_SAMPLE_SIZE = 20
# Run directory processed when the script is executed directly.
DEFAULT_RUN_DIR = "/root/autodl-tmp/TravelContentCreator/result/长鹿旅游休博园/2025-04-27_14-03-44"
# Main configuration file; a relative path, so it is resolved against the
# current working directory when opened.
DEFAULT_CONFIG_PATH = "poster_gen_config.json"
# Directory scanned for reference-title text files when run directly.
DEFAULT_REF_DIR = "/root/autodl-tmp/TravelContentCreator/genPrompts/Refer"
# --- Helper Functions ---
def load_main_config(config_path="poster_gen_config.json"):
    """Read the main JSON configuration and check the keys this script needs.

    Args:
        config_path: Location of the JSON configuration file.

    Returns:
        The parsed configuration, or ``None`` when the file does not exist,
        is not valid JSON, lacks any of ``api_url`` / ``model`` / ``api_key``,
        or any other error occurs while loading. A ``model`` value other than
        "qwq" (compared case-insensitively) only produces a warning.
    """
    path = Path(config_path)
    if not path.is_file():
        logger.error(f"Main configuration file not found: {config_path}")
        return None

    try:
        with open(path, 'r', encoding='utf-8') as handle:
            config = json.load(handle)
        logger.info(f"Main configuration loaded successfully: {config_path}")

        # Keys this script cannot work without; the QWQ model is expected
        # to be configured under "model".
        missing = [key for key in ("api_url", "model", "api_key") if key not in config]
        if missing:
            logger.error(f"Main configuration file is missing required keys: {missing}")
            return None

        model_name = config.get("model", "")
        if model_name.lower() != "qwq":
            # Deliberately non-fatal: the configured model is used as-is.
            logger.warning(f"Configured model is '{config.get('model')}', but QWQ was requested. Ensure the config is correct.")
        return config
    except json.JSONDecodeError:
        logger.error(f"Failed to parse main configuration file (JSON format error): {config_path}")
        return None
    except Exception as e:
        logger.exception(f"Error loading main configuration from {config_path}: {e}")
        return None
def load_topic_data(run_dir: Path, run_id: str):
    """Load the topic list of a run from its topic JSON file.

    ``tweet_topic_{run_id}.json`` is preferred; ``tweet_topic.json`` in the
    same directory is used (with a warning) when the run-specific file is
    absent.

    Args:
        run_dir: Directory that holds the topic file.
        run_id: Identifier used to build the run-specific file name.

    Returns:
        The value stored under ``"topics_list"`` when the file holds a dict
        with that key, the parsed data itself when it is a list, or ``None``
        when no file is found, the JSON is invalid, the layout is neither of
        those, or any other error occurs.
    """
    exact_path = run_dir / f"tweet_topic_{run_id}.json"
    generic_path = run_dir / "tweet_topic.json"

    if exact_path.is_file():
        chosen_path = exact_path
    elif generic_path.is_file():
        chosen_path = generic_path
        logger.warning(f"Specific topic file {exact_path.name} not found, using generic {generic_path.name}")
    else:
        logger.error(f"Topic file not found in {run_dir} (tried {exact_path.name} and {generic_path.name})")
        return None

    try:
        with open(chosen_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)

        # Two layouts are accepted: {"topics_list": [...]} or a bare list.
        if isinstance(payload, dict) and "topics_list" in payload:
            topics = payload["topics_list"]
        elif isinstance(payload, list):
            topics = payload
        else:
            logger.error(f"Unexpected format in topic file: {chosen_path}")
            return None

        logger.info(f"Successfully loaded {len(topics)} topics from {chosen_path}")
        return topics
    except json.JSONDecodeError:
        logger.error(f"Failed to parse topic file (JSON format error): {chosen_path}")
        return None
    except Exception as e:
        logger.exception(f"Error loading topics from {chosen_path}: {e}")
        return None
def get_topic_index(topic_item: dict, list_index: int) -> int:
    """Return the positive index of a topic, falling back to ``list_index + 1``.

    The ``"index"`` entry of ``topic_item`` is accepted when it is an integer
    >= 1 or a string of decimal digits whose value is >= 1. Every other value
    (missing, ``None``, booleans, zero, negative numbers, non-numeric strings)
    triggers the fallback and a warning.

    Args:
        topic_item: Topic record; only its ``"index"`` entry is read.
        list_index: Zero-based position of the record in the topic list.

    Returns:
        The validated topic index, or ``list_index + 1``.
    """
    raw_index = topic_item.get("index")
    parsed_index = None

    if isinstance(raw_index, str):
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() cannot convert.
        if raw_index.isdecimal():
            parsed_index = int(raw_index)
    elif isinstance(raw_index, int) and not isinstance(raw_index, bool):
        # bool is a subclass of int; True/False are not meaningful indices.
        parsed_index = raw_index

    # The lower bound applies to string-derived values too, so "0" is
    # rejected exactly like the integer 0.
    if parsed_index is None or parsed_index < 1:
        logger.warning(f"Topic item {list_index} has invalid or missing 'index'. Using list index {list_index + 1}.")
        return list_index + 1
    return parsed_index
def find_and_load_variant_contents(run_dir: Path, topic_index: int) -> list[str]:
    """Collect the article ``content`` strings of every variant of a topic.

    Variant directories are the sub-directories of ``run_dir`` named
    ``{topic_index}_{n}`` where ``n`` consists of decimal digits. They are
    processed in ascending numeric order of ``n``. Entries that merely share
    the ``{topic_index}_`` prefix (plain files, names such as ``1_backup`` or
    ``1_2_3``) are ignored instead of breaking the numeric sort.

    Args:
        run_dir: Run directory that contains the variant directories.
        topic_index: Index of the topic whose variants are wanted.

    Returns:
        The ``content`` strings in variant order. Variants whose article data
        cannot be loaded, or whose ``content`` is not a string, are skipped
        with a warning; the list is empty when no variant directory exists.
    """
    prefix = str(topic_index)

    # Filter BEFORE sorting: converting the suffix of an unrelated entry to
    # int would otherwise raise ValueError.
    numbered_dirs = []
    for entry in run_dir.glob(f"{topic_index}_*"):
        parts = entry.name.split('_')
        if (
            len(parts) == 2
            and parts[0] == prefix
            and parts[1].isdecimal()
            and entry.is_dir()
        ):
            numbered_dirs.append((int(parts[1]), entry))
    numbered_dirs.sort(key=lambda pair: pair[0])
    variant_dirs = [entry for _, entry in numbered_dirs]

    variant_contents: list[str] = []
    if not variant_dirs:
        logger.warning(f"No variant directories found for topic {topic_index} in {run_dir}")
        return variant_contents

    logger.debug(f"Found {len(variant_dirs)} potential variant directories for topic {topic_index}: {[d.name for d in variant_dirs]}")
    for variant_dir in variant_dirs:
        content_data = load_variant_article_data(variant_dir)
        if content_data and isinstance(content_data.get("content"), str):
            # Only the content string is kept, not the whole article record.
            variant_contents.append(content_data["content"])
        else:
            logger.warning(f"Could not load valid 'content' string for variant {variant_dir.name}, it will be excluded.")

    logger.info(f"Loaded content strings for {len(variant_contents)} variants for topic {topic_index}.")
    return variant_contents
def load_variant_article_data(variant_dir: Path):
    """Read ``article.json`` from a variant directory.

    Args:
        variant_dir: Directory expected to contain ``article.json``.

    Returns:
        The parsed JSON when it is a dict; ``None`` when the file is missing,
        cannot be parsed, holds something other than a dict, or any other
        error occurs while reading it.
    """
    article_path = variant_dir / "article.json"
    if not article_path.is_file():
        logger.warning(f"Article content file not found: {article_path}")
        return None

    try:
        with open(article_path, 'r', encoding='utf-8') as handle:
            article = json.load(handle)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse article content file (JSON format error): {article_path}")
        return None
    except Exception as e:
        logger.exception(f"Error loading content data from {article_path}: {e}")
        return None

    # Only a dict is a usable article record.
    if not isinstance(article, dict):
        logger.warning(f"Invalid format (not a dict) in article content file: {article_path}.")
        return None
    return article
def load_reference_titles(ref_dir: Path) -> list[str]:
    """Collect the non-empty, stripped lines of every ``*.txt`` file in a directory.

    Files are read in sorted path order so the result is deterministic. A
    file that cannot be opened or decoded as UTF-8 is logged and skipped as a
    whole; the remaining files are still read.

    Args:
        ref_dir: Directory holding the reference-title text files.

    Returns:
        All collected titles; an empty list when ``ref_dir`` is falsy, is not
        a directory, or contains no readable titles.
    """
    titles: list[str] = []
    if not ref_dir or not ref_dir.is_dir():
        logger.error(f"Reference title directory not found or invalid: {ref_dir}")
        return titles

    logger.info(f"Loading reference titles from: {ref_dir}")
    try:
        # Assumes titles live in .txt files, one per line.
        ref_files = sorted(path for path in ref_dir.glob("*.txt") if path.is_file())
    except OSError as e:
        logger.exception(f"Error reading reference files from {ref_dir}: {e}")
        ref_files = []

    for file_path in ref_files:
        logger.debug(f"Reading reference file: {file_path.name}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_titles = [text for text in (line.strip() for line in f) if text]
        except (OSError, UnicodeError) as e:
            # One bad file must not discard the titles of the other files.
            logger.exception(f"Error reading reference file {file_path}: {e}")
            continue
        titles.extend(file_titles)

    logger.info(f"Loaded {len(titles)} reference titles.")
    return titles
# --- Prompt Generation (Single Topic, All Contents) ---
# Logic reverted: Use all j contents for topic i, expect j titles in order.
def create_single_topic_prompts(
    topic_description: str,
    content_list: list[str],
    num_titles_needed: int,
    reference_titles: list[str],
    sample_size: int
) -> tuple[str, str]:
    """Build the system and user prompts for a single topic.

    The prompts ask for exactly ``num_titles_needed`` titles, one per entry
    of ``content_list`` and in the same order, returned as a JSON list.

    Args:
        topic_description: Topic text inserted into both prompts.
        content_list: Article contents of the topic; each is cut to its first
            800 characters and followed by ``...`` in the user prompt.
        num_titles_needed: Number of titles requested in the prompt wording.
        reference_titles: Pool of reference titles offered as inspiration.
        sample_size: Maximum number of reference titles to include; a random
            sample is drawn only when the pool is larger than this.

    Returns:
        ``(system_prompt, user_prompt)``.
    """
    # Sample only when the pool exceeds the requested size; otherwise the
    # whole pool is used in its given order.
    chosen_refs = (
        random.sample(reference_titles, sample_size)
        if len(reference_titles) > sample_size
        else reference_titles
    )

    # System prompt: one title per content item, as a bare JSON list.
    system_prompt = f"""你是一位专业的社交媒体文案撰稿人,尤其擅长小红书、推特等平台的旅行内容。
你的任务是为一个关于"{topic_description}"的主题根据我提供的 {num_titles_needed} 篇具体内容生成恰好 {num_titles_needed} 个与之对应的推文标题
生成的每个标题必须严格按照用户提示中内容项的顺序一一对应
请确保标题简洁理想情况下少于15个字吸引人与具体内容相关并能抓住描述的旅行体验精髓
输出格式必须严格为仅包含一个 JSON 列表其中包含 {num_titles_needed} 个字符串每个字符串是一个生成的标题列表中的标题顺序必须与提供的内容项顺序完全一致
例如对于3个内容项输出格式应为["第1项内容的标题", "第2项内容的标题", "第3项内容的标题"]
不要在 JSON 列表前后包含任何解释道歉或其他无关文本"""

    # User prompt: every content item, numbered from 1.
    content_block = "".join(
        f"--- 内容项 {position} ---\n内容摘要:\n{body[:800]}...\n\n"
        for position, body in enumerate(content_list, start=1)
    )
    ref_block = "\n".join(f"- {title}" for title in chosen_refs)

    user_prompt = f"""请为主题"{topic_description}"的以下 {num_titles_needed} 篇内容,生成恰好 {num_titles_needed} 个一一对应的推文标题。
请确保输出列表中的标题与这些内容项按顺序一一对应
{content_block}
这里还有一些参考标题可供启发
{ref_block}
请记住只输出包含 {num_titles_needed} 个生成标题且顺序正确的 JSON 列表"""
    return system_prompt, user_prompt
# --- AI Response Parsing (Should be suitable for flat list) ---
def parse_ai_title_response(response: str, num_expected: int) -> list[str] | None:
"""Parses the AI response expecting a JSON list of strings."""
try:
# Find the JSON list part (handle potential markdown fences)
json_match = re.search(r'\[.*\]', response, re.DOTALL)
if not json_match:
logger.error(f"Could not find JSON list structure in AI response: {response[:200]}...")
return None
json_str = json_match.group(0)
titles = json.loads(json_str)
if isinstance(titles, list) and all(isinstance(t, str) for t in titles):
if len(titles) != num_expected:
logger.warning(f"AI generated {len(titles)} titles, but {num_expected} were expected. Using generated titles anyway.")
# Optionally pad or truncate here if strict count is needed and alignment is not critical
# while len(titles) < num_expected: titles.append("[MISSING TITLE]")
# titles = titles[:num_expected]
return titles
else:
logger.error(f"Parsed JSON is not a list of strings: {titles}")
return None
except json.JSONDecodeError as e:
logger.error(f"Failed to decode AI response as JSON: {e}. Response snippet: {response[:200]}...")
return None
except Exception as e:
logger.exception(f"Unexpected error parsing AI response: {e}")
return None
# --- Title Generation (Single Topic) ---
# Adjusted from batch generation logic
def generate_titles_for_topic(
    openai_client: OpenAI,
    model_name: str,
    config: dict,
    system_prompt: str,
    user_prompt: str,
    num_titles_needed: int,
) -> list[str] | None:
    """Request titles for one topic from the chat-completions API and parse them.

    Args:
        openai_client: Client whose ``chat.completions.create`` is called.
        model_name: Model identifier passed to the API.
        config: Settings dict; ``title_temperature`` (default 0.7),
            ``title_top_p`` (default 0.8) and ``title_presence_penalty``
            (default 1.0) are read from it.
        system_prompt: System message content.
        user_prompt: User message content.
        num_titles_needed: Number of titles expected back.

    Returns:
        ``[]`` when ``num_titles_needed`` is 0 (no API call is made), the
        parsed list of titles on success, or ``None`` when the API call
        fails, the response is empty or malformed, or no titles can be
        parsed from it.
    """
    if num_titles_needed == 0:
        logger.warning("num_titles_needed is 0 for this topic. Skipping generation.")
        return []

    logger.info(f"--- Regenerating {num_titles_needed} Titles for Topic --- ")
    logger.debug(f"System Prompt:\n{system_prompt}")
    logger.debug(f"User Prompt Snippet:\n{user_prompt[:500]}...")  # Snippet only

    # --- API call: only request building and the call itself are guarded
    # here, so other problems are not mislabelled as API failures. ---
    try:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        logger.info(f"Calling OpenAI model '{model_name}' to generate {num_titles_needed} titles...")
        start_time = time.time()
        response_object = openai_client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=config.get("title_temperature", 0.7),
            top_p=config.get("title_top_p", 0.8),
            presence_penalty=config.get("title_presence_penalty", 1.0),
        )
        time_cost = time.time() - start_time
    except Exception as e:
        logger.exception(f"OpenAI API call failed: {e}")
        return None
    logger.info(f"OpenAI API call completed in {time_cost:.2f}s.")

    # --- Response inspection ---
    try:
        usage = getattr(response_object, 'usage', None)
        if usage:
            logger.info(f"API Usage: Prompt={usage.prompt_tokens}, Completion={usage.completion_tokens}, Total={usage.total_tokens}")
        else:
            logger.info("API Usage: Information not available in response object.")
        choices = response_object.choices
        # The message content may be absent; treat that as an empty reply.
        response_content = choices[0].message.content if choices else None
    except (AttributeError, IndexError, TypeError) as e:
        logger.exception(f"Unexpected OpenAI response structure: {e}")
        return None

    if not response_content:
        logger.error("AI returned an empty response.")
        return None
    logger.debug(f"Raw OpenAI Response content: {response_content[:500]}...")

    # --- Parsing: a count mismatch is already reported by the parser. ---
    generated_titles = parse_ai_title_response(response_content, num_titles_needed)
    if generated_titles:
        logger.info(f"Successfully parsed {len(generated_titles)} titles from response.")
        return generated_titles

    logger.error("Failed to parse titles from AI response.")
    return None
# --- Main Logic ---
def main(run_dir_path: Path, config_path: str, ref_dir_path: Path, debug_mode: bool):
    """Regenerate the titles of every topic in one run directory and save them.

    Steps: load the main configuration and the reference titles, create the
    OpenAI client, load the run's topic list, then for each topic load all
    variant contents and request one title per variant. The results are
    written to ``regenerated_titles_{run_id}.json`` inside the run directory
    and a summary is logged.

    Args:
        run_dir_path: Run directory; its name is used as the run id.
        config_path: Path of the main JSON configuration file.
        ref_dir_path: Directory holding reference-title text files.
        debug_mode: When true, this module's logger and the root logger are
            switched to DEBUG level.

    Returns:
        ``None``. The function returns early (after a critical log entry)
        when the configuration, the client or the topic list cannot be
        obtained.
    """
    if debug_mode:
        for target in (logger, logging.getLogger()):
            target.setLevel(logging.DEBUG)
        logger.info("DEBUG 日志已启用")

    run_id = run_dir_path.name
    logger.info(f"\n===== Processing Run Directory: {run_dir_path} (Run ID: {run_id}) =====")

    # --- Configuration ---
    main_config = load_main_config(config_path)
    if main_config is None:
        logger.critical("Failed to load main configuration. Aborting.")
        return

    # --- Reference titles (optional; their absence only degrades quality) ---
    reference_titles = load_reference_titles(ref_dir_path)
    if not reference_titles:
        logger.warning("No reference titles loaded. Title generation quality may be affected.")

    # --- OpenAI client ---
    try:
        openai_client = OpenAI(
            base_url=main_config["api_url"],
            api_key=main_config["api_key"],
        )
        logger.info(f"OpenAI client initialized for base_url: {main_config['api_url']}")
        model_to_use = main_config["model"]
        logger.info(f"Using model: {model_to_use}")
    except KeyError as e:
        logger.critical(f"Failed to initialize OpenAI client: Missing key {e} in config")
        return
    except Exception as e:
        logger.critical(f"Failed to initialize OpenAI client: {e}")
        traceback.print_exc()
        return

    # --- Topics of this run ---
    topics_list = load_topic_data(run_dir_path, run_id)
    if topics_list is None:
        logger.critical(f"Failed to load topics for run {run_id}. Aborting.")
        return

    # --- Per-topic processing ---
    titles_by_topic = {}
    regenerated_results = {
        "run_id": run_id,
        "regenerated_at": datetime.now().isoformat(),
        "reference_directory": str(ref_dir_path),
        "config_model": model_to_use,
        "generation_mode": "single_topic_all_content_j_titles",
        "titles_by_topic": titles_by_topic,
    }
    success_topics = 0
    failed_topics = 0
    titles_generated_count = 0

    for position, topic_item in enumerate(topics_list):
        topic_index = get_topic_index(topic_item, position)
        topic_key = str(topic_index)  # Result keys are the index as a string
        topic_description = topic_item.get("topic", f"未知主题 {topic_index}")
        logger.info(f"\n--- Processing Topic {topic_index}: '{topic_description}' ---")

        # All variants of the topic are used: j contents -> j titles.
        variant_contents = find_and_load_variant_contents(run_dir_path, topic_index)
        variant_count = len(variant_contents)
        if not variant_count:
            logger.warning(f"Skipping topic {topic_index} due to missing variant content.")
            titles_by_topic[topic_key] = []
            failed_topics += 1
            continue

        logger.info(f"Topic {topic_index} has {variant_count} variants. Using all for title generation.")

        try:
            system_prompt, user_prompt = create_single_topic_prompts(
                topic_description,
                variant_contents,
                variant_count,
                reference_titles,
                DEFAULT_REF_TITLE_SAMPLE_SIZE
            )
            new_titles = generate_titles_for_topic(
                openai_client,
                model_to_use,
                main_config,
                system_prompt,
                user_prompt,
                variant_count
            )

            if new_titles is None:
                # API call or parsing failed; keep an empty entry.
                logger.error(f"Failed to generate titles for topic {topic_index}.")
                titles_by_topic[topic_key] = []
                failed_topics += 1
            else:
                # Counted as success even when the number of titles differs
                # from the number of variants (the parser warns about that).
                titles_by_topic[topic_key] = new_titles
                titles_generated_count += len(new_titles)
                success_topics += 1
                logger.info(f"Successfully generated {len(new_titles)} titles for topic {topic_index}.")
        except Exception:
            logger.exception(f"Unhandled error regenerating titles for topic {topic_index}:")
            titles_by_topic[topic_key] = []
            failed_topics += 1

    # --- Persist results ---
    output_file = run_dir_path / f"regenerated_titles_{run_id}.json"
    try:
        with open(output_file, 'w', encoding='utf-8') as f_out:
            json.dump(regenerated_results, f_out, ensure_ascii=False, indent=4)
        logger.info(f"Regenerated titles saved to: {output_file}")
    except Exception as e:
        logger.exception(f"Failed to save regenerated titles to {output_file}: {e}")

    # --- Summary ---
    separator = "=" * 30
    logger.info(separator)
    logger.info(f"Title Regeneration Summary for Run ID: {run_id}")
    logger.info(f"Successfully Generated Titles for Topics: {success_topics}")
    logger.info(f"Failed/Skipped Topics: {failed_topics}")
    logger.info(f"Total Titles Generated (across all topics): {titles_generated_count}")
    logger.info(separator)
if __name__ == "__main__":
    # --- Configuration: edit these values to change what the script does ---
    debug_mode_enabled = False  # True switches logging to DEBUG
    config_path_str = DEFAULT_CONFIG_PATH
    run_directory_path_str = DEFAULT_RUN_DIR
    reference_directory_path_str = DEFAULT_REF_DIR

    run_directory = Path(run_directory_path_str)
    reference_directory = Path(reference_directory_path_str)

    # --- Path validation: both directories must exist; the run directory is
    # checked first and the first failure ends the script with status 1 ---
    required_dirs = (
        (run_directory, f"错误: 指定的运行目录不存在或不是一个目录: {run_directory_path_str}"),
        (reference_directory, f"错误: 指定的参考标题目录不存在或不是一个目录: {reference_directory_path_str}"),
    )
    for directory, complaint in required_dirs:
        if not directory.is_dir():
            print(complaint)
            sys.exit(1)

    # --- Run ---
    main(run_directory, config_path_str, reference_directory, debug_mode_enabled)