import json
import logging
import os
import random
import re
import sys
import time  # Used to time the single API call
import traceback
from datetime import datetime
from pathlib import Path

from openai import OpenAI

# --- Path Setup ---
# Make the project root (two levels above this file) importable when the script is run directly.
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    print(f"Added project root {project_root} to sys.path")

# --- Logging Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# --- Constants ---
DEFAULT_REF_TITLE_SAMPLE_SIZE = 20  # Number of reference titles to include in prompt
DEFAULT_RUN_DIR = "/root/autodl-tmp/TravelContentCreator/result/长鹿旅游休博园/2025-04-27_14-03-44"  # Default run directory
DEFAULT_CONFIG_PATH = "poster_gen_config.json"  # Default config path
DEFAULT_REF_DIR = "/root/autodl-tmp/TravelContentCreator/genPrompts/Refer"  # Default reference directory

# Keys this script reads from the main configuration file.
_REQUIRED_CONFIG_KEYS = ("api_url", "model", "api_key")


# --- Helper Functions ---
def load_main_config(config_path=DEFAULT_CONFIG_PATH):
    """Load and validate the main JSON configuration file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed config dict, or None when the file is missing, is not valid
        JSON, is not a JSON object, lacks any of ``api_url``/``model``/``api_key``,
        or has a non-string ``model``. A model other than "qwq" only logs a
        warning; the config is still returned.
    """
    config_file = Path(config_path)
    if not config_file.is_file():
        logger.error(f"Main configuration file not found: {config_path}")
        return None
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse main configuration file (JSON format error): {config_path}")
        return None
    except Exception as e:
        logger.exception(f"Error loading main configuration from {config_path}: {e}")
        return None

    # A JSON array or scalar would make the key checks below meaningless.
    if not isinstance(config, dict):
        logger.error(f"Main configuration file must contain a JSON object: {config_path}")
        return None
    logger.info(f"Main configuration loaded successfully: {config_path}")

    missing = [key for key in _REQUIRED_CONFIG_KEYS if key not in config]
    if missing:
        logger.error(f"Main configuration file is missing required keys: {missing}")
        return None

    model_name = config["model"]
    if not isinstance(model_name, str):
        logger.error(f"Configured 'model' must be a string, got {type(model_name).__name__}: {config_path}")
        return None
    if model_name.lower() != "qwq":
        # Not fatal: the configured model is still used, but flag the mismatch.
        logger.warning(f"Configured model is '{model_name}', but QWQ was requested. Ensure the config is correct.")
    return config


def load_topic_data(run_dir: Path, run_id: str):
    """Load the topic list of a run directory.

    Prefers ``tweet_topic_{run_id}.json`` and falls back to ``tweet_topic.json``.
    The file may contain either a bare JSON list or an object whose
    ``topics_list`` key holds the list.

    Returns:
        The list of topic items, or None if no topic file exists or its
        content is not in one of the accepted shapes.
    """
    topic_file_exact = run_dir / f"tweet_topic_{run_id}.json"
    topic_file_generic = run_dir / "tweet_topic.json"

    if topic_file_exact.is_file():
        topic_file_to_load = topic_file_exact
    elif topic_file_generic.is_file():
        topic_file_to_load = topic_file_generic
        logger.warning(f"Specific topic file {topic_file_exact.name} not found, using generic {topic_file_generic.name}")
    else:
        logger.error(f"Topic file not found in {run_dir} (tried {topic_file_exact.name} and {topic_file_generic.name})")
        return None

    try:
        with open(topic_file_to_load, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse topic file (JSON format error): {topic_file_to_load}")
        return None
    except Exception as e:
        logger.exception(f"Error loading topics from {topic_file_to_load}: {e}")
        return None

    if isinstance(data, dict) and "topics_list" in data:
        topics_list = data["topics_list"]
    elif isinstance(data, list):
        topics_list = data
    else:
        logger.error(f"Unexpected format in topic file: {topic_file_to_load}")
        return None

    # Callers iterate this with enumerate(); anything other than a list is unusable.
    if not isinstance(topics_list, list):
        logger.error(f"'topics_list' is not a list in topic file: {topic_file_to_load}")
        return None

    logger.info(f"Successfully loaded {len(topics_list)} topics from {topic_file_to_load}")
    return topics_list


def get_topic_index(topic_item: dict, list_index: int) -> int:
    """Return the 1-based topic index for a topic item.

    Uses ``topic_item["index"]`` when it is a positive int, or a string of
    decimal digits denoting a positive int. Otherwise, falls back to
    ``list_index + 1``. That covers a missing, zero, negative, non-numeric, or
    bool value, and a non-dict item.
    """
    fallback = list_index + 1
    raw_index = topic_item.get("index") if isinstance(topic_item, dict) else None

    # isdecimal() only accepts characters int() can parse (isdigit() also accepts e.g. "²").
    if isinstance(raw_index, str):
        raw_index = int(raw_index) if raw_index.isdecimal() else None

    # bool is a subclass of int, so exclude it explicitly; indices are 1-based,
    # and the range check now also applies to values converted from strings.
    if isinstance(raw_index, bool) or not isinstance(raw_index, int) or raw_index < 1:
        logger.warning(f"Topic item {list_index} has invalid or missing 'index'. Using list index {list_index + 1}.")
        return fallback
    return raw_index


def find_and_load_variant_contents(run_dir: Path, topic_index: int) -> list[str]:
    """Load the article ``content`` string of every variant of a topic.

    Variant directories are named ``{topic_index}_{n}`` with a decimal ``n``
    and are visited in ascending numeric order of ``n``. Other entries matching
    the glob (e.g. ``3_backup`` or ``3_1_old``) are ignored. Variants whose
    ``article.json`` has no string ``content`` are skipped with a warning.
    """
    topic_prefix = str(topic_index)
    numbered_dirs = []
    for candidate in run_dir.glob(f"{topic_index}_*"):
        head, _, suffix = candidate.name.partition('_')
        if candidate.is_dir() and head == topic_prefix and suffix.isdecimal():
            numbered_dirs.append((int(suffix), candidate))
    # Sort only after filtering: calling int() on a non-numeric suffix while sorting
    # used to raise ValueError and abort the whole run.
    numbered_dirs.sort(key=lambda pair: pair[0])
    variant_dirs = [path for _, path in numbered_dirs]

    variant_contents: list[str] = []
    if not variant_dirs:
        logger.warning(f"No variant directories found for topic {topic_index} in {run_dir}")
        return variant_contents

    logger.debug(f"Found {len(variant_dirs)} potential variant directories for topic {topic_index}: {[d.name for d in variant_dirs]}")
    for variant_dir in variant_dirs:
        content_data = load_variant_article_data(variant_dir)
        if content_data and isinstance(content_data.get("content"), str):
            variant_contents.append(content_data["content"])  # Only the content string is kept
        else:
            logger.warning(f"Could not load valid 'content' string for variant {variant_dir.name}, it will be excluded.")

    logger.info(f"Loaded content strings for {len(variant_contents)} variants for topic {topic_index}.")
    return variant_contents


def load_variant_article_data(variant_dir: Path):
    """Load ``article.json`` from a variant directory.

    Returns:
        The parsed dict, or None (after logging) when the file is missing,
        is not valid JSON, or does not contain a JSON object.
    """
    content_file = variant_dir / "article.json"
    if not content_file.is_file():
        logger.warning(f"Article content file not found: {content_file}")
        return None
    try:
        with open(content_file, 'r', encoding='utf-8') as f:
            content_data = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse article content file (JSON format error): {content_file}")
        return None
    except Exception as e:
        logger.exception(f"Error loading content data from {content_file}: {e}")
        return None

    if not isinstance(content_data, dict):
        logger.warning(f"Invalid format (not a dict) in article content file: {content_file}.")
        return None
    return content_data


def load_reference_titles(ref_dir: Path) -> list[str]:
    """Collect every non-blank, stripped line of the ``*.txt`` files in ``ref_dir``.

    Files are read in sorted name order so the pool is deterministic. A file
    that cannot be read or decoded as UTF-8 is skipped with a warning rather
    than discarding the remaining files.
    """
    titles: list[str] = []
    if not ref_dir or not ref_dir.is_dir():
        logger.error(f"Reference title directory not found or invalid: {ref_dir}")
        return titles

    logger.info(f"Loading reference titles from: {ref_dir}")
    for file_path in sorted(ref_dir.glob("*.txt")):  # Titles are assumed to live in .txt files
        if not file_path.is_file():
            continue
        logger.debug(f"Reading reference file: {file_path.name}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_titles = [stripped for line in f if (stripped := line.strip())]
        except (OSError, UnicodeDecodeError) as e:
            logger.warning(f"Skipping unreadable reference file {file_path}: {e}")
            continue
        titles.extend(file_titles)

    logger.info(f"Loaded {len(titles)} reference titles.")
    return titles


# --- Prompt Generation (Single Topic, All Contents) ---
# Uses all j contents of topic i and expects j titles back, in the same order.
def create_single_topic_prompts( topic_description: str, # selected_contents: list[str], content_list: list[str], # Use all j contents for the topic num_titles_needed: int, # The original variant count (j) reference_titles: list[str], sample_size: int ) -> tuple[str, str]: """为单个主题创建系统和用户提示词,使用其全部j篇内容,生成j个一一对应的标题。""" # Select a sample of reference titles if len(reference_titles) > sample_size: sampled_refs = random.sample(reference_titles, sample_size) else: sampled_refs = reference_titles # System Prompt - Ask for j titles corresponding to j contents system_prompt = f"""你是一位专业的社交媒体文案撰稿人,尤其擅长小红书、推特等平台的旅行内容。 你的任务是为一个关于"{topic_description}"的主题,根据我提供的 {num_titles_needed} 篇具体内容,生成恰好 {num_titles_needed} 个与之对应的推文标题。 生成的每个标题必须严格按照用户提示中内容项的顺序一一对应。 请确保标题简洁(理想情况下少于15个字)、吸引人、与具体内容相关,并能抓住描述的旅行体验精髓。 输出格式必须严格为仅包含一个 JSON 列表,其中包含 {num_titles_needed} 个字符串,每个字符串是一个生成的标题。列表中的标题顺序必须与提供的内容项顺序完全一致。 例如,对于3个内容项,输出格式应为:["第1项内容的标题", "第2项内容的标题", "第3项内容的标题"] 不要在 JSON 列表前后包含任何解释、道歉或其他无关文本。""" # User Prompt - Show all j contents for the topic content_block = "" # content_count = len(selected_contents) # for i, content in enumerate(selected_contents): for i, content in enumerate(content_list): # Combine into a single multi-line f-string for clarity and safety # content_block += f"""--- 示例内容 {i+1} --- content_block += f"""--- 内容项 {i+1} --- 内容摘要: {content[:800]}...\n\n""" ref_block = "\n".join([f"- {ref}" for ref in sampled_refs]) # user_prompt = f"""请为主题"{topic_description}"生成恰好 {num_titles_needed} 个独特的推文标题。 # 这里有 {content_count} 篇示例内容供你参考: user_prompt = f"""请为主题"{topic_description}"的以下 {num_titles_needed} 篇内容,生成恰好 {num_titles_needed} 个一一对应的推文标题。 请确保输出列表中的标题与这些内容项按顺序一一对应: {content_block} 这里还有一些参考标题可供启发: {ref_block} 请记住,只输出包含 {num_titles_needed} 个生成标题且顺序正确的 JSON 列表。""" return system_prompt, user_prompt # --- AI Response Parsing (Should be suitable for flat list) --- def parse_ai_title_response(response: str, num_expected: int) -> list[str] | None: """Parses the AI response expecting a JSON list of 
strings.""" try: # Find the JSON list part (handle potential markdown fences) json_match = re.search(r'\[.*\]', response, re.DOTALL) if not json_match: logger.error(f"Could not find JSON list structure in AI response: {response[:200]}...") return None json_str = json_match.group(0) titles = json.loads(json_str) if isinstance(titles, list) and all(isinstance(t, str) for t in titles): if len(titles) != num_expected: logger.warning(f"AI generated {len(titles)} titles, but {num_expected} were expected. Using generated titles anyway.") # Optionally pad or truncate here if strict count is needed and alignment is not critical # while len(titles) < num_expected: titles.append("[MISSING TITLE]") # titles = titles[:num_expected] return titles else: logger.error(f"Parsed JSON is not a list of strings: {titles}") return None except json.JSONDecodeError as e: logger.error(f"Failed to decode AI response as JSON: {e}. Response snippet: {response[:200]}...") return None except Exception as e: logger.exception(f"Unexpected error parsing AI response: {e}") return None # --- Title Generation (Single Topic) --- # Adjusted from batch generation logic def generate_titles_for_topic( openai_client: OpenAI, model_name: str, config: dict, system_prompt: str, user_prompt: str, num_titles_needed: int, # Expecting j titles ) -> list[str] | None: """Generates titles for a single topic using OpenAI API.""" if num_titles_needed == 0: logger.warning("num_titles_needed is 0 for this topic. 
Skipping generation.") return [] # Return empty list logger.info(f"--- Regenerating {num_titles_needed} Titles for Topic (based on <=2 examples) --- ") logger.debug(f"System Prompt:\n{system_prompt}") logger.debug(f"User Prompt Snippet:\n{user_prompt[:500]}...") # Log only snippet # Call OpenAI API try: temp = config.get("title_temperature", 0.7) top_p = config.get("title_top_p", 0.8) pres_penalty = config.get("title_presence_penalty", 1.0) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] logger.info(f"Calling OpenAI model '{model_name}' to generate {num_titles_needed} titles...") start_time = time.time() response_object = openai_client.chat.completions.create( model=model_name, messages=messages, temperature=temp, top_p=top_p, presence_penalty=pres_penalty, ) end_time = time.time() time_cost = end_time - start_time logger.info(f"OpenAI API call completed in {time_cost:.2f}s.") response_content = response_object.choices[0].message.content logger.debug(f"Raw OpenAI Response content: {response_content[:500]}...") if hasattr(response_object, 'usage') and response_object.usage: logger.info(f"API Usage: Prompt={response_object.usage.prompt_tokens}, Completion={response_object.usage.completion_tokens}, Total={response_object.usage.total_tokens}") else: logger.info(f"API Usage: Information not available in response object.") except Exception as e: logger.exception(f"OpenAI API call failed: {e}") return None # Parse Response if response_content: generated_titles = parse_ai_title_response(response_content, num_titles_needed) if generated_titles: logger.info(f"Successfully parsed {len(generated_titles)} titles from response.") # Warning for mismatch is already in parse_ai_title_response return generated_titles else: logger.error("Failed to parse titles from AI response.") return None # Indicate failure else: logger.error("AI returned an empty response.") return None # --- Main Logic --- def main(run_dir_path: Path, 
config_path: str, ref_dir_path: Path, debug_mode: bool): """Main processing function for a single run directory.""" if debug_mode: logger.setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG) logger.info("DEBUG 日志已启用") run_id = run_dir_path.name logger.info(f"\n===== Processing Run Directory: {run_dir_path} (Run ID: {run_id}) =====") # Load main config main_config = load_main_config(config_path) if main_config is None: logger.critical("Failed to load main configuration. Aborting.") return # Load reference titles reference_titles = load_reference_titles(ref_dir_path) if not reference_titles: logger.warning("No reference titles loaded. Title generation quality may be affected.") # Initialize OpenAI client try: openai_client = OpenAI( base_url=main_config["api_url"], api_key=main_config["api_key"], # Add timeout if needed: timeout=main_config.get("request_timeout", 180) ) logger.info(f"OpenAI client initialized for base_url: {main_config['api_url']}") model_to_use = main_config["model"] # Get model name from config logger.info(f"Using model: {model_to_use}") except KeyError as e: logger.critical(f"Failed to initialize OpenAI client: Missing key {e} in config") return except Exception as e: logger.critical(f"Failed to initialize OpenAI client: {e}") traceback.print_exc() return # Load topics for this specific run topics_list = load_topic_data(run_dir_path, run_id) if topics_list is None: logger.critical(f"Failed to load topics for run {run_id}. 
Aborting.") return # --- Process Topics Individually --- regenerated_results = { "run_id": run_id, "regenerated_at": datetime.now().isoformat(), "reference_directory": str(ref_dir_path), "config_model": model_to_use, # "generation_mode": "single_topic_max2_content", # Indicate the mode "generation_mode": "single_topic_all_content_j_titles", # Updated mode description "titles_by_topic": {} } success_topics = 0 failed_topics = 0 titles_generated_count = 0 for i, topic_item in enumerate(topics_list): topic_index = get_topic_index(topic_item, i) topic_description = topic_item.get("topic", f"未知主题 {topic_index}") logger.info(f"\n--- Processing Topic {topic_index}: '{topic_description}' ---") # Find and load content for all variants of this topic variant_contents = find_and_load_variant_contents(run_dir_path, topic_index) original_variant_count = len(variant_contents) # This is 'j' if original_variant_count == 0: logger.warning(f"Skipping topic {topic_index} due to missing variant content.") regenerated_results["titles_by_topic"][str(topic_index)] = [] # Store empty list failed_topics += 1 continue # Removed logic for selecting max 2 contents # We will use all variant_contents (j items) logger.info(f"Topic {topic_index} has {original_variant_count} variants. Using all for title generation.") # Regenerate titles (expecting original_variant_count titles, 1 per content) try: # 1. Create Prompts for this topic using all contents system_prompt, user_prompt = create_single_topic_prompts( topic_description, # selected_contents, variant_contents, # Pass all j contents original_variant_count, # Ask for j titles reference_titles, DEFAULT_REF_TITLE_SAMPLE_SIZE ) # 2. 
Call AI Function for this topic new_titles = generate_titles_for_topic( openai_client, model_to_use, main_config, system_prompt, user_prompt, original_variant_count # Expect j titles ) if new_titles is not None: # Store results using the determined topic_index as string key regenerated_results["titles_by_topic"][str(topic_index)] = new_titles titles_generated_count += len(new_titles) # Consider success if API call returned something, even if count mismatches (warning is in parse func) success_topics += 1 logger.info(f"Successfully generated {len(new_titles)} titles for topic {topic_index}.") else: # API call or parsing failed logger.error(f"Failed to generate titles for topic {topic_index}.") regenerated_results["titles_by_topic"][str(topic_index)] = [] # Store empty list on failure failed_topics += 1 except Exception as e: logger.exception(f"Unhandled error regenerating titles for topic {topic_index}:") regenerated_results["titles_by_topic"][str(topic_index)] = [] # Store empty list on error failed_topics += 1 # --- End Processing Topics --- # Save the results output_file = run_dir_path / f"regenerated_titles_{run_id}.json" try: with open(output_file, 'w', encoding='utf-8') as f_out: json.dump(regenerated_results, f_out, ensure_ascii=False, indent=4) logger.info(f"Regenerated titles saved to: {output_file}") except Exception as e: logger.exception(f"Failed to save regenerated titles to {output_file}: {e}") logger.info("=" * 30) logger.info(f"Title Regeneration Summary for Run ID: {run_id}") logger.info(f"Successfully Generated Titles for Topics: {success_topics}") logger.info(f"Failed/Skipped Topics: {failed_topics}") logger.info(f"Total Titles Generated (across all topics): {titles_generated_count}") logger.info("=" * 30) if __name__ == "__main__": # --- Configuration (Set parameters directly here) --- # Use Path objects for directories run_directory_path_str = DEFAULT_RUN_DIR config_path_str = DEFAULT_CONFIG_PATH reference_directory_path_str = DEFAULT_REF_DIR 
debug_mode_enabled = False # Set to True to enable debug logging run_directory = Path(run_directory_path_str) reference_directory = Path(reference_directory_path_str) # --- End Configuration --- # --- Validate Paths --- if not run_directory.is_dir(): print(f"错误: 指定的运行目录不存在或不是一个目录: {run_directory_path_str}") sys.exit(1) if not reference_directory.is_dir(): print(f"错误: 指定的参考标题目录不存在或不是一个目录: {reference_directory_path_str}") sys.exit(1) # --- End Validate Paths --- # Call the main processing function main(run_directory, config_path_str, reference_directory, debug_mode_enabled)