# TravelContentCreator/scripts/reimage/regenerate_missing.py
#
# NOTE: the Chinese log messages in this file contain full-width (CJK)
# punctuation, which code-hosting UIs may flag as "ambiguous Unicode
# characters".

import os
import sys
import json
import argparse
import logging
import shutil
from typing import List, Dict, Any, Tuple, Optional, Set
# Usage example:
#   python scripts/regenerate_missing.py --run_dir /root/autodl-tmp/TravelContentCreator/result/2025-05-12_21-36-33 --config poster_gen_config.json
# Add the project root directory to the module search path.
sys.path.append("/root/autodl-tmp/TravelContentCreator")
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import the required project modules.
from utils.output_handler import FileSystemOutputHandler
from utils.poster_notes_creator import select_additional_images
from utils.tweet_generator import generate_posters_for_topic, generate_content_for_topic
from core.topic_parser import TopicParser
from core.ai_agent import AI_Agent
from utils.prompt_manager import PromptManager
def load_config(config_path="poster_gen_config.json"):
    """Load the JSON configuration file.

    Args:
        config_path: Path of the JSON configuration file to read.

    Returns:
        The parsed JSON value, or ``None`` when the file does not exist or
        cannot be read/parsed (the reason is logged as an error).
    """
    if os.path.exists(config_path):
        try:
            with open(config_path, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except Exception as exc:
            # Any read or parse failure is reported and turned into ``None``.
            logging.error(f"加载配置文件时出错: {exc}")
            return None

    logging.error(f"错误:配置文件 '{config_path}' 不存在")
    return None
def check_content_status(run_dir: str, topic_indices: List[int], variants_count: int,
                         poster_subdir: str = "poster") -> Dict:
    """Inspect the on-disk state of every topic/variant of a run.

    Each variant lives in ``<run_dir>/<topic_index>_<variant_index>``. For an
    existing variant directory the function looks for ``tweet_content.json``
    and, inside the poster sub-directory, for poster images, metadata files
    and additional images.

    Args:
        run_dir: Path of the run directory.
        topic_indices: Indices of all topics of the run.
        variants_count: Number of variants every topic should have; variants
            are numbered ``1..variants_count``.
        poster_subdir: Name of the poster sub-directory inside a variant
            directory. Defaults to ``"poster"``.

    Returns:
        Dict: ``{topic_index: {variant_index: state}}`` where ``state`` has
        the boolean keys ``exists``, ``has_content``, ``has_poster`` (a poster
        ``.jpg`` *and* a ``*_metadata.json`` file are present) and
        ``has_additional`` (at least one ``additional_*.jpg`` is present).
    """
    status: Dict = {}
    for topic_index in topic_indices:
        topic_status: Dict = {}
        status[topic_index] = topic_status
        for variant_index in range(1, variants_count + 1):
            variant_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")

            # A variant without a directory is reported as missing entirely.
            if not os.path.exists(variant_dir):
                topic_status[variant_index] = {
                    "exists": False,
                    "has_content": False,
                    "has_poster": False,
                    "has_additional": False,
                }
                continue

            poster_dir = os.path.join(variant_dir, poster_subdir)
            # isdir (not exists) so that a stray regular file with the poster
            # directory's name is treated as "no poster" instead of making
            # os.listdir raise; the directory is listed only once.
            names = os.listdir(poster_dir) if os.path.isdir(poster_dir) else []

            jpg_names = [name for name in names if name.endswith(".jpg")]
            has_additional_images = any(
                name.startswith("additional_") for name in jpg_names
            )
            has_poster_image = any(
                not name.startswith("additional_") for name in jpg_names
            )
            has_poster_metadata = any(
                name.endswith("_metadata.json") for name in names
            )

            topic_status[variant_index] = {
                "exists": True,
                "has_content": os.path.exists(
                    os.path.join(variant_dir, "tweet_content.json")
                ),
                "has_poster": has_poster_image and has_poster_metadata,
                "has_additional": has_additional_images,
            }
    return status
def _create_ai_components(config: Dict) -> Optional[Tuple[Any, Any]]:
    """Build the AI agent and prompt manager used to generate new content.

    Returns:
        ``(ai_agent, prompt_manager)`` on success, or ``None`` when either
        object could not be constructed (the failure is logged).
    """
    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", ""),
            content_judger_system_prompt_path=config.get("content_judger_system_prompt")
        )
        logging.info("PromptManager实例创建成功")
    except Exception as e:
        logging.error(f"创建PromptManager实例失败: {e}")
        return None

    try:
        ai_agent = AI_Agent(
            config.get("api_url"),
            config.get("model"),
            config.get("api_key"),
            timeout=config.get("request_timeout", 180),
            max_retries=config.get("max_retries", 3),
            stream_chunk_timeout=config.get("stream_chunk_timeout", 30),
        )
        logging.info("AI代理创建成功")
    except Exception as e:
        logging.error(f"创建AI代理失败: {e}")
        return None

    return ai_agent, prompt_manager


def _copy_sibling_content(run_dir: str, topic_index: Any, variant_index: int,
                          variants_count: int) -> bool:
    """Copy ``tweet_content.json`` from another variant of the same topic.

    The first sibling variant (in ascending variant order) that has a content
    file is used; the target variant directory is created when needed.

    Returns:
        True when a content file was copied, False when no sibling had one.
    """
    variant_key = f"{topic_index}_{variant_index}"
    for existing_variant in range(1, variants_count + 1):
        if existing_variant == variant_index:
            continue
        existing_content = os.path.join(
            run_dir, f"{topic_index}_{existing_variant}", "tweet_content.json"
        )
        if not os.path.exists(existing_content):
            continue
        new_variant_dir = os.path.join(run_dir, variant_key)
        os.makedirs(new_variant_dir, exist_ok=True)
        shutil.copy2(existing_content, os.path.join(new_variant_dir, "tweet_content.json"))
        logging.info(f"已从变体 {topic_index}_{existing_variant} 复制内容文件到 {variant_key}")
        return True
    return False


def _generate_poster_for_variant(topic_item: Dict, topic_index: Any, variant_index: int,
                                 run_id: str, output_handler: Any, config: Dict,
                                 poster_options: Dict) -> bool:
    """Generate the poster of a single variant.

    Returns:
        True when ``generate_posters_for_topic`` returned a truthy value,
        False when it returned a falsy value or raised (both are logged).
    """
    variant_key = f"{topic_index}_{variant_index}"
    try:
        posters_attempted = generate_posters_for_topic(
            topic_item=topic_item,
            # The required config keys are read inside the try block so that
            # a missing key is logged like any other poster failure.
            output_dir=config["output_dir"],
            run_id=run_id,
            topic_index=topic_index,
            output_handler=output_handler,
            model_name=config["model"],
            base_url=config["api_url"],
            api_key=config["api_key"],
            variants=1,  # generate exactly one variant ...
            variant_start_index=variant_index,  # ... starting at this index
            **poster_options
        )
    except Exception as e:
        logging.exception(f"生成海报时出错: {e}")
        return False

    if posters_attempted:
        logging.info(f"变体 {variant_key} 海报生成完成")
        return True
    logging.warning(f"变体 {variant_key} 海报生成失败")
    return False


def _generate_additional_images_for_variant(topic_item: Dict, topic_index: Any,
                                            variant_index: int, run_id: str,
                                            poster_dir: str, img_base_dir: str,
                                            output_handler: Any, image_options: Dict,
                                            supplement: bool = False) -> None:
    """Generate the additional images of a single variant (best effort).

    Uses a ``*_metadata.json`` file found in ``poster_dir`` and the source
    image directory derived from the topic's ``object`` field. Every failure
    is logged and swallowed so one variant cannot abort the whole run.

    Args:
        supplement: Only selects the wording of the progress log message
            (completing an existing variant vs. creating a new one).
    """
    variant_key = f"{topic_index}_{variant_index}"
    if not os.path.exists(poster_dir):
        logging.warning(f"海报目录不存在: {poster_dir},无法生成配图")
        return

    try:
        # Sorted because os.listdir returns entries in arbitrary order; this
        # keeps the choice of metadata file deterministic.
        metadata_files = sorted(
            name for name in os.listdir(poster_dir)
            if name.endswith("_metadata.json")
            and os.path.isfile(os.path.join(poster_dir, name))
        )
    except OSError as e:
        logging.warning(f"访问海报目录时出错: {e}")
        return

    if not metadata_files:
        logging.warning("未找到海报元数据文件,无法生成额外配图")
        return

    poster_metadata_path = os.path.join(poster_dir, metadata_files[0])
    if supplement:
        logging.info(f"为变体 {variant_key} 补充生成额外配图,使用元数据: {poster_metadata_path}")
    else:
        logging.info(f"为变体 {variant_key} 生成额外配图,使用元数据: {poster_metadata_path}")

    try:
        # The source image folder is named after the topic's "object" value
        # with its extension and the "景点信息-" prefix removed.
        object_name = topic_item.get("object", "").split(".")[0].replace("景点信息-", "").strip()
        source_image_dir = os.path.join(img_base_dir, object_name)
        if not os.path.isdir(source_image_dir):
            logging.warning(f"源图像目录不存在: {source_image_dir}")
            return

        additional_paths = select_additional_images(
            run_id=run_id,
            topic_index=topic_index,
            variant_index=variant_index,
            poster_metadata_path=poster_metadata_path,
            source_image_dir=source_image_dir,
            output_handler=output_handler,
            **image_options
        )
        if additional_paths:
            logging.info(f"已为变体 {variant_key} 生成 {len(additional_paths)} 张额外配图")
        else:
            logging.warning(f"未能为变体 {variant_key} 生成任何额外配图")
    except Exception as e:
        logging.exception(f"生成额外配图时出错: {e}")


def regenerate_missing_content(run_dir: str, config: Dict):
    """Regenerate missing posters and additional images of a previous run.

    Variants whose directory does not exist at all are created from scratch:
    content is copied from a sibling variant or newly generated, then the
    poster and additional images are produced. Variants that exist but lack a
    poster and/or additional images get only the missing parts.

    Args:
        run_dir: Directory of the previous run; its base name is used as the
            run id.
        config: Parsed configuration dictionary.

    Returns:
        None. Problems are logged; fatal set-up problems end the function
        early.
    """
    # normpath first: basename("…/run/") would otherwise be an empty string.
    run_id = os.path.basename(os.path.normpath(run_dir))
    logging.info(f"处理运行ID: {run_id}")

    output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))

    # Locate the topic file: run-specific name first, generic name second.
    topics_file = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
    if not os.path.exists(topics_file):
        alternative_topics_file = os.path.join(run_dir, "tweet_topic.json")
        if not os.path.exists(alternative_topics_file):
            logging.error(f"主题文件不存在: {topics_file} 或 {alternative_topics_file}")
            return
        topics_file = alternative_topics_file
    logging.info(f"使用主题文件: {topics_file}")

    topics_list = TopicParser.load_topics_from_json(topics_file)
    if not topics_list:
        logging.error("无法加载主题列表")
        return

    # Topics without a truthy index are ignored.
    indexed_topics = [topic for topic in topics_list if topic.get('index')]
    topic_indices = [topic.get('index') for topic in indexed_topics]
    topic_map = {topic.get('index'): topic for topic in indexed_topics}
    variants_count = config.get("variants", 1)
    logging.info(f"找到 {len(topic_indices)} 个主题,每个主题应有 {variants_count} 个变体")

    content_status = check_content_status(run_dir, topic_indices, variants_count)

    # Split the variants into "missing entirely" and "present but incomplete".
    missing_variants = []
    incomplete_variants = []
    for topic_index, variants in content_status.items():
        for variant_index, status in variants.items():
            variant_key = f"{topic_index}_{variant_index}"
            if not status["exists"]:
                missing_variants.append((topic_index, variant_index))
                logging.info(f"变体 {variant_key} 完全不存在,需要创建")
            elif not (status["has_poster"] and status["has_additional"]):
                incomplete_variants.append((topic_index, variant_index))
                logging.info(f"变体 {variant_key} 不完整:海报={status['has_poster']} 配图={status['has_additional']}")
    logging.info(f"需要创建 {len(missing_variants)} 个完全缺失的变体,补充 {len(incomplete_variants)} 个不完整的变体")

    # The AI agent is only needed when new content may have to be generated.
    ai_agent = None
    prompt_manager = None
    if missing_variants:
        components = _create_ai_components(config)
        if components is None:
            return
        ai_agent, prompt_manager = components

    try:
        poster_assets_dir = config.get("poster_assets_base_dir")
        img_base_dir = config.get("image_base_dir")
        poster_content_system_prompt_path = config.get("poster_content_system_prompt")
        poster_subdir = config.get("output_poster_subdir", "poster")

        if not poster_assets_dir or not img_base_dir or not poster_content_system_prompt_path:
            logging.error("缺少关键配置参数")
            return

        try:
            with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
                poster_content_system_prompt = f.read()
        except OSError as e:
            # Same policy as the other set-up failures: log and stop.
            logging.error(f"无法读取海报内容系统提示词文件 '{poster_content_system_prompt_path}': {e}")
            return

        # Keyword arguments shared by every poster-generation call.
        poster_options = {
            "title_possibility": config.get("title_possibility", 0.3),
            "poster_assets_base_dir": poster_assets_dir,
            "image_base_dir": img_base_dir,
            "resource_dir_config": config.get("resource_dir", []),
            "poster_target_size": tuple(config.get("poster_target_size", [900, 1200])),
            "text_possibility": config.get("text_possibility", 0.3),
            "img_frame_possibility": config.get("img_frame_possibility", 0.7),
            "text_bg_possibility": config.get("text_bg_possibility", 0),
            "output_collage_subdir": config.get("output_collage_subdir", "collage_img"),
            "output_poster_subdir": poster_subdir,
            "output_poster_filename": config.get("output_poster_filename", "poster.jpg"),
            "system_prompt": poster_content_system_prompt,
            "collage_style": config.get("collage_style"),
            "timeout": config.get("request_timeout", 180),
        }

        # Keyword arguments shared by every additional-image call.
        image_selection_config = config.get("image_selection", {})
        image_options = {
            "num_additional_images": config.get("additional_images_count", 3),
            "variation_strength": image_selection_config.get("variation_strength", "medium"),
            "extra_effects": image_selection_config.get("extra_effects", True),
        }

        # 1. Variants that are missing entirely.
        if missing_variants and ai_agent and prompt_manager:
            for topic_index, variant_index in missing_variants:
                variant_key = f"{topic_index}_{variant_index}"
                logging.info(f"创建完全缺失的变体: {variant_key}")

                topic_item = topic_map.get(topic_index)
                if not topic_item:
                    logging.error(f"找不到主题 {topic_index} 的数据,跳过")
                    continue

                # 1.1 / 1.2 Reuse a sibling's content, otherwise generate it.
                if not _copy_sibling_content(run_dir, topic_index, variant_index, variants_count):
                    logging.info(f"为变体 {variant_key} 生成新内容...")
                    content_success = generate_content_for_topic(
                        ai_agent,
                        prompt_manager,
                        topic_item,
                        run_id,
                        topic_index,
                        output_handler,
                        variants=1,
                        variant_start_index=variant_index,
                        temperature=config.get("content_temperature", 0.3),
                        top_p=config.get("content_top_p", 0.4),
                        presence_penalty=config.get("content_presence_penalty", 1.5),
                        enable_content_judge=config.get("enable_content_judge", False)
                    )
                    if not content_success:
                        logging.error(f"为变体 {variant_key} 生成内容失败,跳过后续处理")
                        continue
                    logging.info(f"已为变体 {variant_key} 生成内容")

                # 1.3 Poster; without one there is nothing to derive images from.
                logging.info(f"为变体 {variant_key} 生成海报...")
                if not _generate_poster_for_variant(
                    topic_item, topic_index, variant_index, run_id,
                    output_handler, config, poster_options
                ):
                    continue

                # 1.4 Additional images.
                _generate_additional_images_for_variant(
                    topic_item, topic_index, variant_index, run_id,
                    os.path.join(run_dir, variant_key, poster_subdir),
                    img_base_dir, output_handler, image_options
                )

        # 2. Variants that exist but are incomplete.
        for topic_index, variant_index in incomplete_variants:
            variant_key = f"{topic_index}_{variant_index}"
            topic_item = topic_map.get(topic_index)
            if not topic_item:
                logging.error(f"找不到主题 {topic_index} 的数据,跳过")
                continue

            status = content_status[topic_index][variant_index]

            # 2.1 Missing poster.
            if not status["has_poster"]:
                logging.info(f"为变体 {variant_key} 补充生成海报...")
                if _generate_poster_for_variant(
                    topic_item, topic_index, variant_index, run_id,
                    output_handler, config, poster_options
                ):
                    # Lets step 2.2 run for a poster created just now.
                    status["has_poster"] = True

            # 2.2 Poster present but additional images missing.
            if status["has_poster"] and not status["has_additional"]:
                _generate_additional_images_for_variant(
                    topic_item, topic_index, variant_index, run_id,
                    os.path.join(run_dir, variant_key, poster_subdir),
                    img_base_dir, output_handler, image_options,
                    supplement=True
                )
    finally:
        # Always release the agent, including on early return or exception.
        if ai_agent:
            ai_agent.close()
            logging.info("AI代理已关闭")
def main():
    """Command-line entry point.

    Parses the arguments, configures logging, loads the configuration and
    runs the regeneration for the given run directory. Exits with status 1
    when the configuration cannot be loaded (or is empty) or when the run
    directory does not exist.
    """
    cli = argparse.ArgumentParser(description="补充生成丢失的海报和配图")
    cli.add_argument("--run_dir", required=True, help="之前运行结果的目录")
    cli.add_argument("--config", default="poster_gen_config.json", help="配置文件路径")
    cli.add_argument("--debug", action="store_true", help="启用调试日志")
    options = cli.parse_args()

    # Verbose logging only when --debug was given.
    if options.debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    settings = load_config(options.config)
    if not settings:
        sys.exit(1)

    run_directory = options.run_dir
    if not os.path.exists(run_directory):
        logging.error(f"指定的运行目录不存在: {run_directory}")
        sys.exit(1)

    logging.info(f"开始为 {run_directory} 补充生成内容")
    regenerate_missing_content(run_directory, settings)
    logging.info("处理完成")


if __name__ == "__main__":
    main()