457 lines
20 KiB
Python
457 lines
20 KiB
Python
import os
|
|
import json
|
|
import logging
|
|
import argparse
|
|
from pathlib import Path
|
|
import traceback
|
|
import sys
|
|
import re # Needed for cleaning object names and finding description files
|
|
|
|
# --- Path Setup ---
# The project root is taken to be the grandparent of this file, i.e. the
# parent of the folder the script lives in (layout assumed:
# <project_root>/scripts/<this script>.py).
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    # Insert at position 0 so the project's own packages are searched first.
    sys.path.insert(0, str(project_root))
    print(f"Added project root {project_root} to sys.path")
|
|
|
|
try:
    # Project-local import; resolvable only when the project root is on sys.path.
    from utils.content_generator import ContentGenerator
except ImportError as e:
    # Diagnostics go to stderr so they are not mixed into regular stdout output.
    print(f"Error importing project modules: {e}", file=sys.stderr)
    print(f"Sys.path: {sys.path}", file=sys.stderr)
    print(
        "Ensure the script is run from a location where 'utils' and 'core' are importable, or adjust sys.path.",
        file=sys.stderr,
    )
    sys.exit(1)
|
|
|
|
# --- Logging Configuration ---
# Root logging: INFO and above, each record stamped with time, level and the
# source file/line that emitted it.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)
|
|
|
|
# --- Helper Functions ---
|
|
|
|
def load_main_config(config_path="poster_gen_config.json"):
    """Load and validate the main JSON configuration file.

    Args:
        config_path: Path to the configuration file (str or Path).

    Returns:
        The parsed configuration dict, or None when the file is missing, is not
        valid JSON, is not a JSON object, or lacks a required key. Every failure
        is logged rather than raised.
    """
    config_file = Path(config_path)
    if not config_file.is_file():
        logger.error(f"Main configuration file not found: {config_path}")
        return None

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse main configuration file (JSON format error): {config_path}")
        return None
    except Exception as e:
        logger.exception(f"Error loading main configuration from {config_path}: {e}")
        return None

    # Valid JSON may still be a list/string/number. A string would even pass the
    # `key in config` checks below as a substring test, so require an object.
    if not isinstance(config, dict):
        logger.error(
            f"Main configuration must be a JSON object, got {type(config).__name__}: {config_path}"
        )
        return None

    logger.info(f"Main configuration loaded successfully: {config_path}")

    # Keys that must be present for this script to run.
    required_keys = (
        "api_url", "model", "api_key", "output_dir",
        "poster_content_system_prompt", "variants", "resource_dir",
    )
    missing = [key for key in required_keys if key not in config]
    if missing:
        logger.error(f"Main configuration file is missing required keys: {missing}")
        return None
    return config
|
|
|
|
def load_topic_data(run_dir: Path, run_id: str):
    """Load the topic list of one run directory.

    ``tweet_topic_{run_id}.json`` is preferred and ``tweet_topic.json`` is used as
    a fallback. The file may contain either a bare JSON list of topics or an
    object whose ``"topics_list"`` member is that list. Each dict topic is
    tagged in place with ``"run_dir"`` (the run directory as a string) so later
    stages can locate its files.

    Args:
        run_dir: The run directory to read from.
        run_id: Identifier used to build the specific topic file name.

    Returns:
        The list of topics, or None when no topic file exists, it cannot be
        parsed, or it does not hold a list of topics. Failures are logged.
    """
    topic_file_exact = run_dir / f"tweet_topic_{run_id}.json"
    topic_file_generic = run_dir / "tweet_topic.json"

    if topic_file_exact.is_file():
        topic_file_to_load = topic_file_exact
    elif topic_file_generic.is_file():
        topic_file_to_load = topic_file_generic
        logger.warning(f"Specific topic file {topic_file_exact.name} not found, using generic {topic_file_generic.name}")
    else:
        logger.error(f"Topic file not found in {run_dir} (tried {topic_file_exact.name} and {topic_file_generic.name})")
        return None

    try:
        with open(topic_file_to_load, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse topic file (JSON format error): {topic_file_to_load}")
        return None
    except Exception as e:
        logger.exception(f"Error loading topics from {topic_file_to_load}: {e}")
        return None

    if isinstance(data, dict) and "topics_list" in data:
        topics_list = data["topics_list"]
    elif isinstance(data, list):
        topics_list = data
    else:
        logger.error(f"Unexpected format in topic file: {topic_file_to_load}")
        return None

    # "topics_list" itself may be malformed (null, an object, ...). Iterating a
    # dict would silently yield its keys and be returned as "topics", so insist
    # on a real list.
    if not isinstance(topics_list, list):
        logger.error(f"Unexpected format in topic file: {topic_file_to_load}")
        return None

    # Record the run directory on each topic for easier path access later.
    for topic in topics_list:
        if isinstance(topic, dict):
            topic['run_dir'] = str(run_dir)
    logger.info(f"Successfully loaded {len(topics_list)} topics from {topic_file_to_load}")
    return topics_list
|
|
|
|
def get_topic_index(topic_item: dict, list_index: int) -> int:
    """Return the 1-based topic index stored in ``topic_item["index"]``.

    A positive int or a string of decimal digits is accepted. Anything else
    (missing, non-numeric, zero or negative, a bool) is replaced by
    ``list_index + 1`` and a warning is logged.

    Args:
        topic_item: The topic dict.
        list_index: 0-based position of the topic in its list, used for the fallback.

    Returns:
        The topic index (always >= 1 when ``list_index`` >= 0).
    """
    raw_index = topic_item.get("index")
    fallback = list_index + 1

    parsed = None
    # bool is a subclass of int, but True/False are not meaningful indices.
    if isinstance(raw_index, int) and not isinstance(raw_index, bool):
        parsed = raw_index
    elif isinstance(raw_index, str) and raw_index.isdecimal():
        # Unlike isdigit(), isdecimal() only admits characters int() can parse
        # (e.g. it rejects superscript digits such as "²").
        parsed = int(raw_index)

    # Apply the range check to int and string input alike, so "0" is rejected like 0.
    if parsed is not None and parsed >= 1:
        return parsed

    logger.warning(f"Topic item {list_index} has invalid or missing 'index'. Using list index {fallback}.")
    return fallback
|
|
|
|
|
|
def find_and_load_variant_contents(run_dir: Path, topic_index: int) -> list:
    """Load ``article.json`` from every variant directory of one topic.

    Variant directories are named ``{topic_index}_{variant_number}`` and sit
    directly inside ``run_dir``. Other entries matching ``{topic_index}_*``
    (files, ``1_backup``, ``1_2_3``, ...) are ignored instead of aborting the
    scan. Variants are loaded in ascending numeric order, so ``1_2`` comes
    before ``1_10``.

    Args:
        run_dir: Run directory containing the variant directories.
        topic_index: Topic whose variants should be loaded.

    Returns:
        The successfully loaded article dicts, possibly empty. A variant whose
        content cannot be loaded is skipped with a warning.
    """
    topic_prefix = str(topic_index)

    # Keep only well-formed "<topic>_<number>" directories. Filter BEFORE
    # converting to int so stray entries cannot raise ValueError.
    numbered_dirs = []
    for candidate in run_dir.glob(f"{topic_index}_*"):
        parts = candidate.name.split('_')
        if len(parts) != 2 or parts[0] != topic_prefix or not parts[1].isdecimal():
            continue
        if candidate.is_dir():
            numbered_dirs.append((int(parts[1]), candidate.name, candidate))
    # Numeric order first; the name breaks ties such as "1_01" vs "1_1".
    numbered_dirs.sort(key=lambda entry: (entry[0], entry[1]))
    variant_dirs = [entry[2] for entry in numbered_dirs]

    if not variant_dirs:
        logger.warning(f"No variant directories found for topic {topic_index} in {run_dir}")
        return []

    logger.debug(f"Found {len(variant_dirs)} variant directories for topic {topic_index}: {[d.name for d in variant_dirs]}")

    variant_contents = []
    for variant_dir in variant_dirs:
        content = load_variant_content(variant_dir)
        if content:  # Only keep variants whose content loaded successfully.
            variant_contents.append(content)
        else:
            logger.warning(f"Could not load valid content for variant {variant_dir.name}, it will be excluded.")

    logger.info(f"Loaded content for {len(variant_contents)} variants for topic {topic_index}.")
    return variant_contents
|
|
|
|
|
|
def load_variant_content(variant_dir: Path):
    """Read ``variant_dir/article.json`` and return it if it looks like an article.

    The file must decode to a JSON object with both a ``"title"`` and a
    ``"content"`` key.

    Args:
        variant_dir: A variant directory of a topic.

    Returns:
        The decoded dict on success. Returns None (after logging) when the file
        is missing, unreadable, not valid JSON, or lacks the required keys.
    """
    content_file = variant_dir / "article.json"
    if not content_file.is_file():
        logger.warning(f"Article content file not found: {content_file}")
        return None

    try:
        with open(content_file, 'r', encoding='utf-8') as handle:
            article = json.load(handle)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse article content file (JSON format error): {content_file}")
        return None
    except Exception as e:
        logger.exception(f"Error loading content from {content_file}: {e}")
        return None

    looks_like_article = isinstance(article, dict) and {"title", "content"} <= article.keys()
    if not looks_like_article:
        logger.warning(f"Invalid format in article content file: {content_file}. Missing 'title' or 'content'.")
        return None

    logger.debug(f"Successfully loaded content from {content_file}")
    return article
|
|
|
|
def clean_object_name(object_name_raw: str) -> str:
    """Normalize a topic's ``object`` field into a bare object name.

    Intended to mirror the cleaning done in tweet_generator (not visible here).
    It drops everything from the first ``.`` onward (e.g. a file extension),
    removes the ``"景点信息-"`` prefix (Chinese for "scenic-spot info-"), and
    strips surrounding whitespace.

    Args:
        object_name_raw: The raw ``object`` value from a topic.

    Returns:
        The cleaned name. Falsy input (None, "") yields "". Non-string input
        (e.g. a number read from JSON) is converted with ``str()`` first instead
        of raising AttributeError.
    """
    if not object_name_raw:
        return ""
    text = object_name_raw if isinstance(object_name_raw, str) else str(object_name_raw)
    return text.split(".")[0].replace("景点信息-", "").strip()
|
|
|
|
def find_description_file(object_name: str, resource_dir_config: list, desc_dir: Path | None = None) -> list:
|
|
"""
|
|
Tries to find the description file path based on object name.
|
|
Prioritizes searching in desc_dir if provided.
|
|
"""
|
|
info_directory = [] # Expects a list of paths
|
|
found_description = False
|
|
|
|
# --- Step 1: Prioritize search in the specified desc_dir ---
|
|
if desc_dir and desc_dir.is_dir():
|
|
logger.info(f"Prioritizing description file search in specified directory: {desc_dir}")
|
|
# Simple search: find first file containing the object name (case-insensitive)
|
|
for file_path in desc_dir.iterdir():
|
|
# Use stem for matching (filename without extension)
|
|
# Make comparison case-insensitive for robustness
|
|
if file_path.is_file() and object_name.lower() in file_path.stem.lower():
|
|
info_directory = [str(file_path)]
|
|
logger.info(f"Found description file in specified directory: {file_path}")
|
|
found_description = True
|
|
break # Use the first match found in the specified dir
|
|
if not found_description:
|
|
logger.warning(f"No file containing '{object_name}' found in specified directory: {desc_dir}. Falling back to resource_dir config.")
|
|
elif desc_dir:
|
|
logger.warning(f"Specified description directory '{desc_dir}' not found or not a directory. Falling back to resource_dir config.")
|
|
|
|
# --- Step 2: Fallback to searching resource_dir_config from main config ---
|
|
if not found_description:
|
|
if not object_name or not resource_dir_config:
|
|
logger.warning("Cannot find description file: Missing object name or resource_dir config.")
|
|
return info_directory # Return empty list
|
|
|
|
logger.debug(f"Searching for description file for object '{object_name}' in resource_dir config...")
|
|
# keywords = [k for k in re.split(r'[+\s_\-]+', object_name) if len(k) > 1] # Keywords might be needed for fuzzy matching later
|
|
|
|
# Search logic based on resource_dir_config (as before)
|
|
for dir_info in resource_dir_config:
|
|
if dir_info.get("type") == "Description":
|
|
for file_path_str in dir_info.get("file_path", []):
|
|
file_path = Path(file_path_str)
|
|
# Use stem for matching, case-insensitive
|
|
if object_name.lower() in file_path.stem.lower():
|
|
if file_path.is_file():
|
|
info_directory = [str(file_path)]
|
|
logger.info(f"Found potential description file in resource_dir config: {file_path}")
|
|
found_description = True
|
|
break # Take the first match based on this logic
|
|
else:
|
|
logger.warning(f"Configured description file not found at path: {file_path}")
|
|
if found_description:
|
|
break
|
|
|
|
if not found_description:
|
|
logger.warning(f"No description file found for '{object_name}' in resource_dir config either.")
|
|
|
|
return info_directory
|
|
|
|
|
|
def regenerate_topic_poster_config(
    content_gen: ContentGenerator,
    topic_item: dict,
    topic_index: int,  # Index as resolved by the caller
    run_dir: Path,
    main_config: dict,
    desc_dir: Path | None,  # Optional directory searched first for description files
) -> bool:
    """Regenerate ``topic_{topic_index}_poster_configs.json`` for one topic.

    Steps:
      1. Load ``article.json`` from every variant directory of the topic.
      2. Resolve the description file for the topic's ``object`` (``desc_dir``
         first, then ``main_config["resource_dir"]``).
      3. Read the poster system prompt and apply sampling settings from
         ``main_config`` to ``content_gen``.
      4. Call ``content_gen.run`` with all variant contents; ``poster_num`` is
         the number of variants loaded.
      5. Write the returned configs (normalized to a list) into ``run_dir``.

    Returns:
        True if configs were generated and written, False otherwise. Failures in
        steps 1 and 3-5 are logged and reported through the return value.
    """
    logger.info(f"--- Regenerating Poster Config for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')}) ---")

    # 1. Load content from ALL variants of this topic.
    variant_contents_list = find_and_load_variant_contents(run_dir, topic_index)
    if not variant_contents_list:
        logger.error(f"No valid variant contents found for topic {topic_index}. Skipping.")
        return False

    # 2. Find the description file (info_directory); desc_dir takes priority.
    object_name_cleaned = clean_object_name(topic_item.get("object", ""))
    resource_config = main_config.get("resource_dir", [])
    info_directory = find_description_file(object_name_cleaned, resource_config, desc_dir)

    # 3. Gather parameters for ContentGenerator.run.
    poster_num = len(variant_contents_list)
    system_prompt_path = main_config.get("poster_content_system_prompt")
    api_url = main_config.get("api_url")
    model_name = main_config.get("model")
    api_key = main_config.get("api_key")
    timeout = main_config.get("request_timeout", 120)

    if not system_prompt_path or not Path(system_prompt_path).is_file():
        logger.error(f"Poster content system prompt file not found or not specified in config: {system_prompt_path}")
        return False

    try:
        with open(system_prompt_path, "r", encoding="utf-8") as f:
            system_prompt = f.read()
    except Exception as e:
        logger.exception(f"Failed to read system prompt file {system_prompt_path}: {e}")
        return False

    content_gen.set_temperature(main_config.get("content_temperature", 0.7))
    content_gen.set_top_p(main_config.get("content_top_p", 0.8))
    content_gen.set_presence_penalty(main_config.get("content_presence_penalty", 1.2))

    # 4. Generate poster configs from the list of variant contents.
    logger.info(f"Calling ContentGenerator.run for topic {topic_index} with {poster_num} variant contents...")
    logger.debug(f" - Info Dir Used: {info_directory}")
    logger.debug(f" - Poster Num: {poster_num}")

    try:
        regenerated_configs = content_gen.run(
            info_directory=info_directory,
            poster_num=poster_num,
            content_data=variant_contents_list,
            system_prompt=system_prompt,
            api_url=api_url,
            model_name=model_name,
            api_key=api_key,
            timeout=timeout
        )
    except Exception as e:
        # logger.exception already records the traceback; no separate print needed.
        logger.exception(f"ContentGenerator.run failed for topic {topic_index}: {e}")
        return False

    if regenerated_configs is None:
        logger.error(f"ContentGenerator.run did not return valid config data for topic {topic_index}.")
        return False

    # 5. Save the regenerated configs as a JSON list.
    if not isinstance(regenerated_configs, list):
        logger.warning(f"ContentGenerator.run for topic {topic_index} did not return a list. Attempting to wrap. Result type: {type(regenerated_configs)}")
        regenerated_configs = [regenerated_configs] if regenerated_configs else []

    output_config_path = run_dir / f"topic_{topic_index}_poster_configs.json"
    try:
        # Serialize before opening the file. Opening with 'w' truncates it, so a
        # non-serializable result must not be discovered mid-write and leave the
        # previous config destroyed.
        serialized = json.dumps(regenerated_configs, ensure_ascii=False, indent=4)
        with open(output_config_path, 'w', encoding='utf-8') as f_out:
            f_out.write(serialized)
    except Exception as e:
        logger.exception(f"Failed to save regenerated poster configs to {output_config_path}: {e}")
        return False

    logger.info(f"Successfully regenerated and saved poster configs to: {output_config_path} ({len(regenerated_configs)} configs)")
    return True
|
|
|
|
|
|
# --- Main Logic ---
def main(run_dirs_to_process, config_path, debug_mode, desc_dir_path: Path | None):
    """Regenerate poster configs for every topic in the given run directories.

    Phase 1 loads the topic lists of all run directories. Phase 2 regenerates
    each topic's poster config. Per-topic failures are counted and logged, and a
    summary is logged at the end.

    Args:
        run_dirs_to_process: Iterable of run directory paths (str or Path).
        config_path: Path to the main JSON configuration file.
        debug_mode: If true, set this module's logger and the root logger to DEBUG.
        desc_dir_path: Optional directory searched first for description files.
    """
    if debug_mode:
        logger.setLevel(logging.DEBUG)
        logging.getLogger().setLevel(logging.DEBUG)
        logger.info("DEBUG 日志已启用")

    main_config = load_main_config(config_path)
    if main_config is None:
        logger.critical("Failed to load main configuration. Aborting.")
        return

    try:
        content_generator = ContentGenerator(output_dir=main_config.get("output_dir"))
    except Exception as e:
        # exc_info=True attaches the traceback to the log record itself.
        logger.critical(f"Failed to initialize ContentGenerator: {e}", exc_info=True)
        return
    logger.info("ContentGenerator initialized.")

    # Phase 1: load all topics. Each entry is (position within its own run, topic):
    # the fallback topic index must be relative to that run's topic list, not to
    # the combined list across runs.
    all_topics_data = []
    for run_dir_str in run_dirs_to_process:
        run_directory = Path(run_dir_str)
        run_id = run_directory.name
        logger.info(f"\n===== Loading Topics from: {run_directory} (Run ID: {run_id}) =====")
        if not run_directory.is_dir():
            logger.error(f"Directory not found: {run_directory}. Skipping.")
            continue
        topics_list = load_topic_data(run_directory, run_id)
        if topics_list:
            all_topics_data.extend(enumerate(topics_list))
        else:
            logger.error(f"Failed to load topics for {run_id}.")

    if not all_topics_data:
        logger.critical("No topics loaded from any specified run directory. Aborting.")
        return

    logger.info(f"\n===== Loaded a total of {len(all_topics_data)} topics. Starting regeneration. =====")

    total_success_count = 0
    total_failure_count = 0

    # Phase 2: process all loaded topics.
    for i, (position_in_run, topic_item) in enumerate(all_topics_data):
        # Topic files may contain non-object entries; the .get() calls below need a dict.
        if not isinstance(topic_item, dict):
            logger.error(f"Topic item {i} is not a JSON object (got {type(topic_item).__name__}). Cannot process.")
            total_failure_count += 1
            continue

        run_dir_str = topic_item.get('run_dir')
        if not run_dir_str:
            logger.error(f"Topic item {i} is missing 'run_dir' information. Cannot process.")
            total_failure_count += 1
            continue

        run_directory = Path(run_dir_str)
        topic_index = get_topic_index(topic_item, position_in_run)

        try:
            succeeded = regenerate_topic_poster_config(
                content_generator,
                topic_item,
                topic_index,
                run_directory,
                main_config,
                desc_dir_path,
            )
        except Exception:
            logger.exception(f"Unhandled error processing topic index {topic_index} from run {run_directory.name}:")
            succeeded = False

        if succeeded:
            total_success_count += 1
        else:
            total_failure_count += 1

    logger.info("=" * 30)
    logger.info("All topics processed.")
    logger.info(f"Total Configs Regenerated Successfully: {total_success_count}")
    logger.info(f"Total Failed/Skipped Topics: {total_failure_count}")
    logger.info("=" * 30)
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="重新生成指定运行ID目录下所有主题的海报文本配置")
    parser.add_argument(
        "--config",
        type=str,
        default="poster_gen_config.json",
        help="主配置文件路径 (poster_gen_config.json)"
    )
    parser.add_argument(
        "--desc_dir",
        type=str,
        default=None,
        help="可选:指定一个目录优先查找描述文件 (如果提供)"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用 DEBUG 级别日志"
    )
    # Optional override for the built-in run directory list below.
    parser.add_argument(
        "--run_dirs",
        nargs="+",
        default=None,
        help="可选:要处理的运行目录列表 (未提供时使用脚本内置的默认列表)"
    )

    args = parser.parse_args()

    # Built-in defaults, used when the corresponding option is not given.
    default_description_directory = Path("/root/autodl-tmp/TravelContentCreator/resource/Object")
    default_run_directories = [
        "/root/autodl-tmp/TravelContentCreator/result/安吉/2025-04-27_12-55-40",
        "/root/autodl-tmp/TravelContentCreator/result/乌镇/2025-04-27_10-50-34",
        "/root/autodl-tmp/TravelContentCreator/result/齐云山/2025-04-27_11-51-56",
        "/root/autodl-tmp/TravelContentCreator/result/长鹿/2025-04-27_14-03-44",
        "/root/autodl-tmp/TravelContentCreator/result/笔架山/2025-04-27_02-02-34",
        "/root/autodl-tmp/TravelContentCreator/result/笔架山/2025-04-27_02-23-17",
        "/root/autodl-tmp/TravelContentCreator/result/笔架山/2025-04-27_07-57-20",
        "/root/autodl-tmp/TravelContentCreator/result/笔架山/2025-04-27_09-29-20"
    ]

    # --- Process desc_dir argument ---
    if args.desc_dir:
        description_directory = Path(args.desc_dir)
        if not description_directory.is_dir():
            logger.warning(f"指定的描述目录 (--desc_dir) 不是一个有效的目录: {args.desc_dir}. 将忽略此参数.")
            description_directory = None  # Reset if invalid
        else:
            logger.info(f"将优先在以下目录查找描述文件: {description_directory}")
    else:
        description_directory = default_description_directory
        # Validate the default as well, so a missing directory is reported once
        # here instead of on every topic's description-file search.
        if not description_directory.is_dir():
            logger.warning(f"默认描述目录不存在或不是目录: {description_directory}. 将仅使用 resource_dir 配置查找描述文件.")
            description_directory = None
    # --- End Process desc_dir ---

    run_directories = args.run_dirs if args.run_dirs else default_run_directories
    if not run_directories:
        print(
            "ERROR: No run directories to process. Pass --run_dirs or edit the default run directory list in the script.",
            file=sys.stderr,
        )
        sys.exit(1)

    main(run_directories, args.config, args.debug, description_directory)