import os
import json
import logging
from pathlib import Path
import traceback
import sys
import random
import re
from datetime import datetime
from openai import OpenAI
import time  # Added for timing the single API call

# --- Path Setup ---
# Ensure the project root is in the path
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    print(f"Added project root {project_root} to sys.path")

# --- Logging Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# --- Constants ---
DEFAULT_REF_TITLE_SAMPLE_SIZE = 20  # Number of reference titles to include in prompt
DEFAULT_RUN_DIR = "/root/autodl-tmp/TravelContentCreator/result/长鹿旅游休博园/2025-04-27_14-03-44"  # Default run directory
DEFAULT_CONFIG_PATH = "poster_gen_config.json"  # Default config path
DEFAULT_REF_DIR = "/root/autodl-tmp/TravelContentCreator/genPrompts/Refer"  # Default reference directory

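# A minimal poster_gen_config.json sketch covering only the keys this script reads
# (endpoint/key values are placeholders; the title_* defaults mirror the code below):
#   {
#       "api_url": "http://<your-openai-compatible-endpoint>/v1",
#       "api_key": "sk-...",
#       "model": "qwq",
#       "title_temperature": 0.7,
#       "title_top_p": 0.8,
#       "title_presence_penalty": 1.0
#   }
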
# --- Helper Functions ---

def load_main_config(config_path="poster_gen_config.json"):
    """Loads the main configuration file."""
    config_file = Path(config_path)
    if not config_file.is_file():
        logger.error(f"Main configuration file not found: {config_path}")
        return None
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
        logger.info(f"Main configuration loaded successfully: {config_path}")
        # Validate necessary keys for this script
        required_keys = ["api_url", "model", "api_key"]  # QWQ model should be set here
        if not all(key in config for key in required_keys):
            missing = [key for key in required_keys if key not in config]
            logger.error(f"Main configuration file is missing required keys: {missing}")
            return None
        if config.get("model", "").lower() != "qwq":
            logger.warning(f"Configured model is '{config.get('model')}', but QWQ was requested. Ensure the config is correct.")
            # You might want to force model='qwq' or exit if strict adherence is needed
            # config['model'] = 'qwq'  # Force it?
        return config
    except json.JSONDecodeError:
        logger.error(f"Failed to parse main configuration file (JSON format error): {config_path}")
        return None
    except Exception as e:
        logger.exception(f"Error loading main configuration from {config_path}: {e}")
        return None


def load_topic_data(run_dir: Path, run_id: str):
    """Loads the topic list from the tweet_topic_{run_id}.json file."""
    topic_file_exact = run_dir / f"tweet_topic_{run_id}.json"
    topic_file_generic = run_dir / "tweet_topic.json"

    topic_file_to_load = None
    if topic_file_exact.is_file():
        topic_file_to_load = topic_file_exact
    elif topic_file_generic.is_file():
        topic_file_to_load = topic_file_generic
        logger.warning(f"Specific topic file {topic_file_exact.name} not found, using generic {topic_file_generic.name}")
    else:
        logger.error(f"Topic file not found in {run_dir} (tried {topic_file_exact.name} and {topic_file_generic.name})")
        return None

    try:
        with open(topic_file_to_load, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, dict) and "topics_list" in data:
            topics_list = data["topics_list"]
        elif isinstance(data, list):
            topics_list = data
        else:
            logger.error(f"Unexpected format in topic file: {topic_file_to_load}")
            return None
        logger.info(f"Successfully loaded {len(topics_list)} topics from {topic_file_to_load}")
        return topics_list
    except json.JSONDecodeError:
        logger.error(f"Failed to parse topic file (JSON format error): {topic_file_to_load}")
        return None
    except Exception as e:
        logger.exception(f"Error loading topics from {topic_file_to_load}: {e}")
        return None

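# Illustrative topic-file shape accepted by load_topic_data (key names taken from the
# code; the example values are assumed):
#   {"topics_list": [{"index": 1, "topic": "亲子一日游"}, {"index": 2, "topic": "夜游灯光秀"}]}
# or a bare JSON list of the same topic objects.
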
def get_topic_index(topic_item: dict, list_index: int) -> int:
    """Safely gets the topic index, falling back to list index + 1."""
    topic_index = topic_item.get("index")
    try:
        if isinstance(topic_index, str) and topic_index.isdigit():
            topic_index = int(topic_index)
        elif not isinstance(topic_index, int) or topic_index < 1:
            logger.warning(f"Topic item {list_index} has invalid or missing 'index'. Using list index {list_index + 1}.")
            topic_index = list_index + 1
    except Exception:
        logger.warning(f"Error processing index for topic item {list_index}. Using list index {list_index + 1}.")
        topic_index = list_index + 1
    return topic_index

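# Illustrative behaviour of get_topic_index (inputs assumed for the example):
#   get_topic_index({"index": "3", "topic": "..."}, 0)  -> 3  (digit strings are coerced)
#   get_topic_index({"topic": "..."}, 0)                -> 1  (falls back to list index + 1)
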
def find_and_load_variant_contents(run_dir: Path, topic_index: int) -> list[str]:
    """Finds all variant directories for a topic and loads their article *content*."""
    variant_contents = []
    # Keep only directories named "{topic_index}_{variant_number}" so the numeric sort
    # below cannot raise on unrelated directory names.
    candidate_dirs = [
        p for p in run_dir.glob(f"{topic_index}_*")
        if p.is_dir() and p.name.split('_')[-1].isdigit()
    ]
    variant_dirs = sorted(candidate_dirs, key=lambda p: int(p.name.split('_')[-1]))

    if not variant_dirs:
        logger.warning(f"No variant directories found for topic {topic_index} in {run_dir}")
        return variant_contents

    logger.debug(f"Found {len(variant_dirs)} potential variant directories for topic {topic_index}: {[d.name for d in variant_dirs]}")

    for variant_dir in variant_dirs:
        parts = variant_dir.name.split('_')
        if len(parts) == 2 and parts[0] == str(topic_index):
            content_data = load_variant_article_data(variant_dir)
            if content_data and isinstance(content_data.get("content"), str):
                variant_contents.append(content_data["content"])  # Only add the content string
            else:
                logger.warning(f"Could not load valid 'content' string for variant {variant_dir.name}, it will be excluded.")

    logger.info(f"Loaded content strings for {len(variant_contents)} variants for topic {topic_index}.")
    return variant_contents


def load_variant_article_data(variant_dir: Path):
    """Loads the article data (dict) from article.json."""
    content_file = variant_dir / "article.json"
    if not content_file.is_file():
        logger.warning(f"Article content file not found: {content_file}")
        return None
    try:
        with open(content_file, 'r', encoding='utf-8') as f:
            content_data = json.load(f)
        # Basic check for dict format
        if isinstance(content_data, dict):
            # logger.debug(f"Successfully loaded content data from {content_file}")
            return content_data
        else:
            logger.warning(f"Invalid format (not a dict) in article content file: {content_file}.")
            return None
    except json.JSONDecodeError:
        logger.error(f"Failed to parse article content file (JSON format error): {content_file}")
        return None
    except Exception as e:
        logger.exception(f"Error loading content data from {content_file}: {e}")
        return None

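# The two loaders above assume a run-directory layout roughly like this
# (sketch derived from the glob pattern and the "content" lookup; names are illustrative):
#   <run_dir>/
#       tweet_topic_<run_id>.json
#       1_1/article.json      # {"content": "正文文本...", ...}
#       1_2/article.json
#       2_1/article.json
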
def load_reference_titles(ref_dir: Path) -> list[str]:
    """Loads all lines from text files in the reference directory."""
    titles = []
    if not ref_dir or not ref_dir.is_dir():
        logger.error(f"Reference title directory not found or invalid: {ref_dir}")
        return titles
    logger.info(f"Loading reference titles from: {ref_dir}")
    try:
        for file_path in ref_dir.glob("*.txt"):  # Assuming titles are in .txt files
            if file_path.is_file():
                logger.debug(f"Reading reference file: {file_path.name}")
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        cleaned_line = line.strip()
                        if cleaned_line:  # Avoid empty lines
                            titles.append(cleaned_line)
    except Exception as e:
        logger.exception(f"Error reading reference files from {ref_dir}: {e}")

    logger.info(f"Loaded {len(titles)} reference titles.")
    return titles

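# load_reference_titles expects plain-text files with one candidate title per line,
# e.g. (file name and contents assumed for illustration):
#   Refer/xiaohongshu_titles.txt
#       周末遛娃好去处,一整天玩不够!
#       人均100+解锁宝藏乐园
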
# --- Prompt Generation (Single Topic, All Contents) ---
# Logic reverted: Use all j contents for topic i, expect j titles in order.
def create_single_topic_prompts(
    topic_description: str,
    content_list: list[str],     # Use all j contents for the topic
    num_titles_needed: int,      # The original variant count (j)
    reference_titles: list[str],
    sample_size: int
) -> tuple[str, str]:
    """Creates the system and user prompts for a single topic, using all j of its
    contents and asking for j titles that correspond to them one-to-one."""

    # Select a sample of reference titles
    if len(reference_titles) > sample_size:
        sampled_refs = random.sample(reference_titles, sample_size)
    else:
        sampled_refs = reference_titles

    # System prompt - ask for j titles corresponding to the j contents
    system_prompt = f"""你是一位专业的社交媒体文案撰稿人,尤其擅长小红书、推特等平台的旅行内容。
你的任务是为一个关于"{topic_description}"的主题,根据我提供的 {num_titles_needed} 篇具体内容,生成恰好 {num_titles_needed} 个与之对应的推文标题。
生成的每个标题必须严格按照用户提示中内容项的顺序一一对应。
请确保标题简洁(理想情况下少于15个字)、吸引人、与具体内容相关,并能抓住描述的旅行体验精髓。
输出格式必须严格为仅包含一个 JSON 列表,其中包含 {num_titles_needed} 个字符串,每个字符串是一个生成的标题。列表中的标题顺序必须与提供的内容项顺序完全一致。
例如,对于3个内容项,输出格式应为:["第1项内容的标题", "第2项内容的标题", "第3项内容的标题"]
不要在 JSON 列表前后包含任何解释、道歉或其他无关文本。"""

    # User prompt - show all j contents for the topic
    content_block = ""
    for i, content in enumerate(content_list):
        # Combine into a single multi-line f-string for clarity and safety
        content_block += f"""--- 内容项 {i+1} ---
内容摘要:
{content[:800]}...\n\n"""

    ref_block = "\n".join([f"- {ref}" for ref in sampled_refs])

    user_prompt = f"""请为主题"{topic_description}"的以下 {num_titles_needed} 篇内容,生成恰好 {num_titles_needed} 个一一对应的推文标题。
请确保输出列表中的标题与这些内容项按顺序一一对应:

{content_block}
这里还有一些参考标题可供启发:
{ref_block}

请记住,只输出包含 {num_titles_needed} 个生成标题且顺序正确的 JSON 列表。"""

    return system_prompt, user_prompt


# --- AI Response Parsing (Should be suitable for a flat list) ---

def parse_ai_title_response(response: str, num_expected: int) -> list[str] | None:
    """Parses the AI response, expecting a JSON list of strings."""
    try:
        # Find the JSON list part (handle potential markdown fences)
        json_match = re.search(r'\[.*\]', response, re.DOTALL)
        if not json_match:
            logger.error(f"Could not find JSON list structure in AI response: {response[:200]}...")
            return None

        json_str = json_match.group(0)
        titles = json.loads(json_str)

        if isinstance(titles, list) and all(isinstance(t, str) for t in titles):
            if len(titles) != num_expected:
                logger.warning(f"AI generated {len(titles)} titles, but {num_expected} were expected. Using generated titles anyway.")
                # Optionally pad or truncate here if a strict count is needed and alignment is not critical:
                # while len(titles) < num_expected: titles.append("[MISSING TITLE]")
                # titles = titles[:num_expected]
            return titles
        else:
            logger.error(f"Parsed JSON is not a list of strings: {titles}")
            return None
    except json.JSONDecodeError as e:
        logger.error(f"Failed to decode AI response as JSON: {e}. Response snippet: {response[:200]}...")
        return None
    except Exception as e:
        logger.exception(f"Unexpected error parsing AI response: {e}")
        return None

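# Illustrative behaviour (input assumed): the regex above tolerates markdown fences
# around the list, so a response like
#   '```json\n["标题一", "标题二"]\n```'
# parses to ["标题一", "标题二"] when num_expected == 2.
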
# --- Title Generation (Single Topic) ---
# Adjusted from batch generation logic
def generate_titles_for_topic(
    openai_client: OpenAI,
    model_name: str,
    config: dict,
    system_prompt: str,
    user_prompt: str,
    num_titles_needed: int,  # Expecting j titles
) -> list[str] | None:
    """Generates titles for a single topic using the OpenAI API."""
    if num_titles_needed == 0:
        logger.warning("num_titles_needed is 0 for this topic. Skipping generation.")
        return []  # Return empty list

    logger.info(f"--- Regenerating {num_titles_needed} Titles for Topic (one per variant content) ---")
    logger.debug(f"System Prompt:\n{system_prompt}")
    logger.debug(f"User Prompt Snippet:\n{user_prompt[:500]}...")  # Log only a snippet

    # Call the OpenAI API
    try:
        temp = config.get("title_temperature", 0.7)
        top_p = config.get("title_top_p", 0.8)
        pres_penalty = config.get("title_presence_penalty", 1.0)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        logger.info(f"Calling OpenAI model '{model_name}' to generate {num_titles_needed} titles...")
        start_time = time.time()
        response_object = openai_client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temp,
            top_p=top_p,
            presence_penalty=pres_penalty,
        )
        end_time = time.time()
        time_cost = end_time - start_time
        logger.info(f"OpenAI API call completed in {time_cost:.2f}s.")

        response_content = response_object.choices[0].message.content
        logger.debug(f"Raw OpenAI Response content: {response_content[:500]}...")
        if hasattr(response_object, 'usage') and response_object.usage:
            logger.info(f"API Usage: Prompt={response_object.usage.prompt_tokens}, Completion={response_object.usage.completion_tokens}, Total={response_object.usage.total_tokens}")
        else:
            logger.info("API Usage: Information not available in response object.")

    except Exception as e:
        logger.exception(f"OpenAI API call failed: {e}")
        return None

    # Parse the response
    if response_content:
        generated_titles = parse_ai_title_response(response_content, num_titles_needed)
        if generated_titles:
            logger.info(f"Successfully parsed {len(generated_titles)} titles from response.")
            # A count mismatch is already warned about inside parse_ai_title_response
            return generated_titles
        else:
            logger.error("Failed to parse titles from AI response.")
            return None  # Indicate failure
    else:
        logger.error("AI returned an empty response.")
        return None

# --- Main Logic ---
def main(run_dir_path: Path, config_path: str, ref_dir_path: Path, debug_mode: bool):
    """Main processing function for a single run directory."""

    if debug_mode:
        logger.setLevel(logging.DEBUG)
        logging.getLogger().setLevel(logging.DEBUG)
        logger.info("DEBUG logging enabled")

    run_id = run_dir_path.name
    logger.info(f"\n===== Processing Run Directory: {run_dir_path} (Run ID: {run_id}) =====")

    # Load the main config
    main_config = load_main_config(config_path)
    if main_config is None:
        logger.critical("Failed to load main configuration. Aborting.")
        return

    # Load reference titles
    reference_titles = load_reference_titles(ref_dir_path)
    if not reference_titles:
        logger.warning("No reference titles loaded. Title generation quality may be affected.")

    # Initialize the OpenAI client
    try:
        openai_client = OpenAI(
            base_url=main_config["api_url"],
            api_key=main_config["api_key"],
            # Add timeout if needed: timeout=main_config.get("request_timeout", 180)
        )
        logger.info(f"OpenAI client initialized for base_url: {main_config['api_url']}")
        model_to_use = main_config["model"]  # Get the model name from the config
        logger.info(f"Using model: {model_to_use}")
    except KeyError as e:
        logger.critical(f"Failed to initialize OpenAI client: Missing key {e} in config")
        return
    except Exception as e:
        logger.critical(f"Failed to initialize OpenAI client: {e}")
        traceback.print_exc()
        return

    # Load topics for this specific run
    topics_list = load_topic_data(run_dir_path, run_id)
    if topics_list is None:
        logger.critical(f"Failed to load topics for run {run_id}. Aborting.")
        return

    # --- Process Topics Individually ---
    regenerated_results = {
        "run_id": run_id,
        "regenerated_at": datetime.now().isoformat(),
        "reference_directory": str(ref_dir_path),
        "config_model": model_to_use,
        "generation_mode": "single_topic_all_content_j_titles",
        "titles_by_topic": {}
    }
    success_topics = 0
    failed_topics = 0
    titles_generated_count = 0

    for i, topic_item in enumerate(topics_list):
        topic_index = get_topic_index(topic_item, i)
        topic_description = topic_item.get("topic", f"未知主题 {topic_index}")
        logger.info(f"\n--- Processing Topic {topic_index}: '{topic_description}' ---")

        # Find and load content for all variants of this topic
        variant_contents = find_and_load_variant_contents(run_dir_path, topic_index)
        original_variant_count = len(variant_contents)  # This is 'j'

        if original_variant_count == 0:
            logger.warning(f"Skipping topic {topic_index} due to missing variant content.")
            regenerated_results["titles_by_topic"][str(topic_index)] = []  # Store empty list
            failed_topics += 1
            continue

        # Use all variant_contents (j items); the earlier "select at most 2 contents" logic was removed.
        logger.info(f"Topic {topic_index} has {original_variant_count} variants. Using all for title generation.")

        # Regenerate titles (expecting original_variant_count titles, 1 per content)
        try:
            # 1. Create prompts for this topic using all contents
            system_prompt, user_prompt = create_single_topic_prompts(
                topic_description,
                variant_contents,          # Pass all j contents
                original_variant_count,    # Ask for j titles
                reference_titles,
                DEFAULT_REF_TITLE_SAMPLE_SIZE
            )

            # 2. Call the AI function for this topic
            new_titles = generate_titles_for_topic(
                openai_client,
                model_to_use,
                main_config,
                system_prompt,
                user_prompt,
                original_variant_count     # Expect j titles
            )

            if new_titles is not None:
                # Store results using the determined topic_index as a string key
                regenerated_results["titles_by_topic"][str(topic_index)] = new_titles
                titles_generated_count += len(new_titles)
                # Count as a success if the API call returned something, even if the count
                # mismatches (the warning is emitted in the parse function)
                success_topics += 1
                logger.info(f"Successfully generated {len(new_titles)} titles for topic {topic_index}.")
            else:
                # API call or parsing failed
                logger.error(f"Failed to generate titles for topic {topic_index}.")
                regenerated_results["titles_by_topic"][str(topic_index)] = []  # Store empty list on failure
                failed_topics += 1

        except Exception:
            logger.exception(f"Unhandled error regenerating titles for topic {topic_index}:")
            regenerated_results["titles_by_topic"][str(topic_index)] = []  # Store empty list on error
            failed_topics += 1

    # --- End Processing Topics ---

    # Save the results
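    # Rough shape of the saved file (illustrative values assumed):
    #   {"run_id": "...", "regenerated_at": "...", "reference_directory": "...",
    #    "config_model": "...", "generation_mode": "single_topic_all_content_j_titles",
    #    "titles_by_topic": {"1": ["标题A", "标题B"], "2": []}}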
    output_file = run_dir_path / f"regenerated_titles_{run_id}.json"
    try:
        with open(output_file, 'w', encoding='utf-8') as f_out:
            json.dump(regenerated_results, f_out, ensure_ascii=False, indent=4)
        logger.info(f"Regenerated titles saved to: {output_file}")
    except Exception as e:
        logger.exception(f"Failed to save regenerated titles to {output_file}: {e}")

    logger.info("=" * 30)
    logger.info(f"Title Regeneration Summary for Run ID: {run_id}")
    logger.info(f"Successfully Generated Titles for Topics: {success_topics}")
    logger.info(f"Failed/Skipped Topics: {failed_topics}")
    logger.info(f"Total Titles Generated (across all topics): {titles_generated_count}")
    logger.info("=" * 30)


if __name__ == "__main__":
    # --- Configuration (Set parameters directly here) ---
    # Use Path objects for directories
    run_directory_path_str = DEFAULT_RUN_DIR
    config_path_str = DEFAULT_CONFIG_PATH
    reference_directory_path_str = DEFAULT_REF_DIR
    debug_mode_enabled = False  # Set to True to enable debug logging

    run_directory = Path(run_directory_path_str)
    reference_directory = Path(reference_directory_path_str)
    # --- End Configuration ---

    # --- Validate Paths ---
    if not run_directory.is_dir():
        print(f"Error: the specified run directory does not exist or is not a directory: {run_directory_path_str}")
        sys.exit(1)

    if not reference_directory.is_dir():
        print(f"Error: the specified reference title directory does not exist or is not a directory: {reference_directory_path_str}")
        sys.exit(1)
    # --- End Validate Paths ---

    # Call the main processing function
    main(run_directory, config_path_str, reference_directory, debug_mode_enabled)