TravelContentCreator/scripts/regenerate_img.py

411 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import json
import logging
import os
import sys
import traceback
from pathlib import Path
from typing import Optional

from PIL import Image
# Make the project importable regardless of the current working directory.
# The root is taken to be the parent of the directory that contains this script
# (two levels above the file itself); adjust sys.path by hand if the project
# is laid out differently.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
print(f"Added {project_root} to sys.path")

try:
    from core.poster_gen import PosterGenerator
except ImportError as import_error:
    # Nothing below can work without PosterGenerator, so report and abort.
    print(f"Error importing PosterGenerator: {import_error}")
    print("Initial sys.path import failed. Please ensure the script is run from the project root or adjust paths.")
    sys.exit(1)
# --- Logging configuration ---
# Timestamp, level, source location and message, at INFO level by default.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
def load_main_config(config_path="poster_gen_config.json"):
    """Load and validate the main JSON configuration file.

    Args:
        config_path: Path of the configuration file. The default is the
            relative name ``poster_gen_config.json``.

    Returns:
        The parsed configuration ``dict`` when the file exists, parses as a
        JSON object and contains every required key; ``None`` otherwise.
        Failures are logged instead of raised.
    """
    required_keys = ("poster_assets_base_dir", "output_dir")

    config_file = Path(config_path)
    if not config_file.is_file():
        logger.error(f"主配置文件未找到: {config_path}")
        return None

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"无法解析主配置文件 (JSON格式错误): {config_path}")
        return None
    except Exception as e:
        # Best-effort loader: any other read failure (permissions, encoding, ...)
        # is logged with its traceback and reported to the caller as None.
        logger.exception(f"加载主配置文件时发生错误 {config_path}: {e}")
        return None

    # A JSON list that merely *contains* the key names would satisfy an
    # ``in`` check, so insist on an object before looking for keys.
    if not isinstance(config, dict):
        logger.error(f"主配置文件的顶层结构必须是 JSON 对象: {config_path}")
        return None

    missing = [key for key in required_keys if key not in config]
    if missing:
        logger.error(f"主配置文件缺少必要的键: {missing}")
        return None

    # Announce success only once the content has actually been validated.
    logger.info(f"主配置文件加载成功: {config_path}")
    return config
def find_topic_variants(run_dir: Path):
    """Collect the ``<topic>_<variant>`` sub-directories of a run directory.

    Only direct child directories whose name is exactly two decimal numbers
    joined by one underscore (e.g. ``1_1``, ``2_3``) are considered.

    Args:
        run_dir: Directory to scan.

    Returns:
        A list of dicts with the keys ``topic_index`` (int), ``variant_index``
        (int) and ``path`` (Path), sorted by topic index and then variant
        index. The list is empty when ``run_dir`` is not a directory.
    """
    if not run_dir.is_dir():
        logger.error(f"运行目录不存在或不是目录: {run_dir}")
        return []

    variants = []
    for item in run_dir.iterdir():
        if not item.is_dir():
            continue
        parts = item.name.split('_')
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() cannot parse and would make
        # the whole scan fail with ValueError.
        if len(parts) != 2 or not all(part.isdecimal() for part in parts):
            continue
        topic_index, variant_index = (int(part) for part in parts)
        variants.append({
            "topic_index": topic_index,
            "variant_index": variant_index,
            "path": item,
        })

    logger.info(f"{run_dir} 中找到 {len(variants)} 个主题变体目录")
    return sorted(variants, key=lambda x: (x["topic_index"], x["variant_index"]))
def get_variant_poster_config(run_dir: Path, topic_index: int, variant_index: int):
    """Load the poster text configuration of one variant.

    Reads ``topic_<topic_index>_poster_configs.json`` inside ``run_dir``, which
    must hold a JSON list, and returns the first dict entry whose ``index``
    equals ``variant_index``.

    Args:
        run_dir: Run directory containing the per-topic configuration files.
        topic_index: Topic number used to build the file name.
        variant_index: Value matched against each entry's ``index`` field.

    Returns:
        The matching entry with its ``texts`` field normalised to a list of
        exactly two elements, or ``None`` when the file is missing or invalid,
        no entry matches, or the matching entry lacks ``main_title`` or
        ``texts``. Failures are logged instead of raised.
    """
    config_file = run_dir / f"topic_{topic_index}_poster_configs.json"
    if not config_file.is_file():
        logger.warning(f"主题 {topic_index} 的海报配置文件未找到: {config_file}")
        return None

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            all_configs = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"解析配置文件 JSON 失败: {config_file}")
        return None
    except Exception as e:
        logger.exception(f"加载或查找变体配置时出错 {config_file}: {e}")
        return None

    if not isinstance(all_configs, list):
        logger.warning(f"配置文件 {config_file} 格式不是列表")
        return None

    for config_item in all_configs:
        if not isinstance(config_item, dict) or config_item.get("index") != variant_index:
            continue

        if "main_title" not in config_item or "texts" not in config_item:
            logger.warning(f"变体 {topic_index}_{variant_index} 的配置缺少 main_title 或 texts 字段")
            return None

        texts = config_item["texts"]
        if texts is None:
            # A JSON null must not become the literal poster text "None".
            texts = []
        elif not isinstance(texts, list):
            texts = [str(texts)]
        # Pad with empty strings, then keep exactly two entries.
        config_item["texts"] = (texts + ["", ""])[:2]

        logger.debug(f"找到变体 {topic_index}_{variant_index} 的配置: {config_item}")
        return config_item

    logger.warning(f"{config_file} 中未找到索引为 {variant_index} 的配置")
    return None
# Default location of the per-hotel source image folders.
_DEFAULT_HOTEL_IMG_BASE_DIR = "/root/autodl-tmp/TravelContentCreator/hotel_img"


def get_hotel_img_dir(run_dir_path: str, hotel_img_base_dir=_DEFAULT_HOTEL_IMG_BASE_DIR) -> Optional[Path]:
    """Return the hotel image directory that belongs to a run directory.

    The hotel name is the name of the run directory's parent, e.g.
    ``.../result/<hotel name>/2025-04-27_12-55-40`` yields ``<hotel name>``,
    and the result is ``<hotel_img_base_dir>/<hotel name>``.

    Args:
        run_dir_path: Path of the run directory.
        hotel_img_base_dir: Directory holding one sub-directory per hotel.
            Defaults to the location this script has always used.

    Returns:
        The hotel image directory, or ``None`` when no hotel name can be
        derived, the directory does not exist, or an unexpected error occurs.
        Failures are logged instead of raised.
    """
    try:
        hotel_name = Path(run_dir_path).parent.name
        if not hotel_name:
            # Without a name the lookup would resolve to the base directory
            # itself, i.e. to the images of every hotel at once.
            logger.warning(f"无法从运行目录路径中提取酒店名称: {run_dir_path}")
            return None

        hotel_img_dir = Path(hotel_img_base_dir) / hotel_name
        # is_dir() rather than exists(): a regular file is not an image folder.
        if not hotel_img_dir.is_dir():
            logger.warning(f"酒店图片目录不存在: {hotel_img_dir}")
            return None

        logger.info(f"找到酒店图片目录: {hotel_img_dir}")
        return hotel_img_dir
    except Exception as e:
        logger.exception(f"获取酒店图片目录时出错: {e}")
        return None
# Image modes written to JPEG unchanged; every other mode (RGBA, P, LA, ...)
# is converted to RGB first because saving it as JPEG would fail.
_JPEG_SAFE_MODES = ("L", "RGB", "CMYK")


def _to_jpeg_compatible(image):
    """Return ``image`` itself if its mode is JPEG-safe, otherwise an RGB copy."""
    if image.mode in _JPEG_SAFE_MODES:
        return image
    return image.convert('RGB')


def _load_existing_collage(collage_path):
    """Load the collage already on disk as RGBA, or return ``None`` if the file is absent."""
    if collage_path is None or not collage_path.exists():
        return None
    # The context manager closes the file handle; convert() returns a new
    # image object, which is what gets handed back to the caller.
    with Image.open(collage_path) as existing:
        return existing.convert('RGBA')


def _build_text_data(poster_config):
    """Translate a variant's poster config into the ``text_data`` dict passed to ``create_poster``."""
    text_data = {
        "title": poster_config.get("main_title", ""),
        "subtitle": "",  # the subtitle is always left empty here
        "additional_texts": [],
    }
    for text in poster_config.get("texts") or []:
        if text and isinstance(text, str):
            # Plain strings get the default placement and size.
            text_data["additional_texts"].append({"text": text, "position": "middle", "size_factor": 0.8})
        elif isinstance(text, dict) and text.get("text"):
            # Entries that are already dicts are passed through untouched.
            text_data["additional_texts"].append(text)
    return text_data


def regenerate_single_poster(generator: PosterGenerator, variant_info: dict, run_dir: Path):
    """Regenerate the collage and the poster of a single topic variant.

    Steps: make sure the ``collage_img`` and ``poster`` sub-directories exist,
    locate a collage already on disk, load the variant's text configuration,
    try to build a fresh collage from the hotel image directory (falling back
    to the collage on disk when that is not possible), render the poster with
    ``generator.create_poster`` and save it as ``poster/poster.jpg``.

    Args:
        generator: Poster generator whose ``create_poster(image, text_data)``
            is called.
        variant_info: Dict with the keys ``topic_index``, ``variant_index``
            and ``path`` (the variant directory).
        run_dir: Run directory holding the variant directories and the
            per-topic poster configuration files.

    Returns:
        ``True`` when a poster was generated and saved, ``False`` when the
        variant was skipped or any step failed. Failures are logged.
    """
    topic_index = variant_info["topic_index"]
    variant_index = variant_info["variant_index"]
    variant_path = variant_info["path"]
    logger.info(f"--- 开始处理: 主题 {topic_index}, 变体 {variant_index} (目录: {variant_path.name}) ---")

    # 1. Make sure the working directories exist.
    collage_dir = variant_path / "collage_img"
    poster_dir = variant_path / "poster"
    collage_dir.mkdir(parents=True, exist_ok=True)
    poster_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"确保目录存在: collage_dir={collage_dir}, poster_dir={poster_dir}")

    # 2. Locate a collage that is already on disk (png first, then jpg, jpeg).
    #    Sorting within each extension makes the pick deterministic.
    image_files = [
        candidate
        for pattern in ("*.png", "*.jpg", "*.jpeg")
        for candidate in sorted(collage_dir.glob(pattern))
    ]
    if image_files:
        collage_path = image_files[0]
        logger.info(f"找到拼贴图: {collage_path}")
    else:
        logger.warning(f"{collage_dir} 中未找到拼贴图 (png/jpg/jpeg)")
        # Keep going: a new collage is generated below, this path is only
        # the fallback location.
        collage_path = collage_dir / "collage.jpg"
        logger.info(f"将使用默认拼贴图路径: {collage_path}")

    # 3. Text configuration of this variant.
    poster_config = get_variant_poster_config(run_dir, topic_index, variant_index)
    if poster_config is None:
        logger.warning(f"无法获取变体 {topic_index}_{variant_index} 的海报配置,跳过生成。")
        return False

    # 4. Build text_data.
    text_data = _build_text_data(poster_config)
    logger.debug(f"准备好的 text_data: {text_data}")

    # 5. Regenerate the collage, falling back to the one on disk.
    try:
        from core.simple_collage import ImageCollageCreator
        collage_creator = ImageCollageCreator()

        hotel_img_dir = get_hotel_img_dir(str(run_dir))
        if hotel_img_dir is None:
            logger.warning("无法获取酒店图片目录,尝试使用原始拼贴图继续")
            collage_img = _load_existing_collage(collage_path)
            if collage_img is None:
                logger.error("无法获取酒店图片目录且原始拼贴图不存在,跳过生成")
                return False
        else:
            new_collage, _used_images = collage_creator.create_collage_with_style(
                str(hotel_img_dir),
                style=None,  # let the creator pick the style
                target_size=(900, 1200)
            )
            if new_collage:
                regenerated_collage_path = collage_dir / "regenerated_collage.jpg"
                new_collage = _to_jpeg_compatible(new_collage)
                new_collage.save(regenerated_collage_path)
                logger.info(f"成功重新生成并保存拼贴图到: {regenerated_collage_path}")
                collage_img = new_collage
            else:
                logger.warning("拼贴图重新生成失败,尝试使用原始拼贴图继续")
                collage_img = _load_existing_collage(collage_path)
                if collage_img is None:
                    logger.error("拼贴图重新生成失败且原始拼贴图不存在,跳过生成")
                    return False
    except Exception as e:
        logger.exception(f"重新生成拼贴图时发生错误: {e}")
        collage_img = _load_existing_collage(collage_path)
        if collage_img is None:
            logger.error("处理拼贴图时发生错误且原始拼贴图不存在,跳过生成")
            return False

    # 6. Render the poster.
    logger.info("调用 PosterGenerator.create_poster...")
    try:
        regenerated_poster_img = generator.create_poster(collage_img, text_data)
    except Exception as e:
        # logger.exception already records the traceback.
        logger.exception(f"调用 create_poster 时发生严重错误: {e}")
        regenerated_poster_img = None

    if not regenerated_poster_img:
        logger.error("PosterGenerator未能生成海报图像 (返回 None)")
        return False

    # 7. Save the poster; the output name is fixed to poster.jpg.
    poster_output_path = poster_dir / "poster.jpg"
    try:
        jpeg_ready = _to_jpeg_compatible(regenerated_poster_img)
        if jpeg_ready is not regenerated_poster_img:
            logger.debug(f"将图像从 {regenerated_poster_img.mode} 转换为 RGB 以保存为 JPG")
        jpeg_ready.save(poster_output_path)
        logger.info(f"成功重新生成并保存海报到: {poster_output_path}")
        return True
    except Exception as e:
        logger.exception(f"保存重新生成的海报失败 {poster_output_path}: {e}")
        return False
# --- Main logic ---
def main(run_dirs_to_process, config_path, debug_mode):
    """Regenerate the posters of every variant in the given run directories.

    Args:
        run_dirs_to_process: Iterable of run directory paths.
        config_path: Path of the main configuration file.
        debug_mode: When true, switch this module's logger and the root
            logger to DEBUG level.

    Returns:
        None. Progress, per-run counts and overall totals are logged. The
        function returns early, after logging a critical message, when the
        configuration cannot be loaded or PosterGenerator cannot be set up.
    """
    if debug_mode:
        for target in (logger, logging.getLogger()):
            target.setLevel(logging.DEBUG)
        logger.info("DEBUG 日志已启用")

    # Load the main configuration.
    main_config = load_main_config(config_path)
    if main_config is None:
        logger.critical("无法加载主配置,脚本终止。")
        return

    poster_assets_dir = main_config.get("poster_assets_base_dir")
    if not poster_assets_dir:
        logger.critical("配置文件中未找到 'poster_assets_base_dir'")
        return

    # Build the poster generator; its tunables come from the config file.
    try:
        poster_generator = PosterGenerator(
            base_dir=poster_assets_dir,
            output_dir=main_config.get("output_dir"),
        )
        poster_generator.set_img_frame_possibility(main_config.get("img_frame_possibility", 0.7))
        poster_generator.set_text_bg_possibility(main_config.get("text_bg_possibility", 0))
        logger.info(f"PosterGenerator 初始化成功,资源目录: {poster_assets_dir}")
    except Exception as e:
        logger.critical(f"初始化 PosterGenerator 失败: {e}")
        traceback.print_exc()
        return

    total_success_count = 0
    total_failure_count = 0

    for run_dir_str in run_dirs_to_process:
        run_directory = Path(run_dir_str)
        logger.info(f"\n===== 开始处理运行目录: {run_directory} =====")

        # Skip anything that is not a directory and move on to the next one.
        if not run_directory.is_dir():
            logger.error(f"指定的运行目录不存在或不是目录: {run_directory}")
            continue

        topic_variants = find_topic_variants(run_directory)
        if not topic_variants:
            logger.warning(f"{run_directory} 中未找到任何主题变体目录,跳过。")
            continue

        # One boolean per variant: True for success, False for failure or skip.
        outcomes = []
        for variant in topic_variants:
            try:
                outcomes.append(bool(regenerate_single_poster(poster_generator, variant, run_directory)))
            except Exception:
                logger.exception(f"处理变体 {variant['path']} 时发生未捕获的错误:")
                outcomes.append(False)

        run_success_count = sum(outcomes)
        run_failure_count = len(outcomes) - run_success_count

        logger.info(f"===== 完成处理运行目录: {run_directory} =====")
        logger.info(f"本轮成功: {run_success_count}, 失败/跳过: {run_failure_count}")
        total_success_count += run_success_count
        total_failure_count += run_failure_count

    separator = "=" * 30
    logger.info(separator)
    logger.info("所有指定目录的海报重新生成完成。")
    logger.info(f"总成功: {total_success_count}")
    logger.info(f"总失败/跳过: {total_failure_count}")
    logger.info(separator)
if __name__ == "__main__":
    # Command-line entry point. Run directories may be given as positional
    # arguments; when none are given, the built-in `run_directories` list
    # below is used, exactly as before.
    parser = argparse.ArgumentParser(description="重新生成指定运行ID目录下所有主题变体的海报")
    parser.add_argument(
        "run_dirs",
        nargs="*",
        help="要处理的运行结果目录路径 (可指定多个; 省略时使用脚本中的 run_directories 列表)"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="poster_gen_config.json",
        help="主配置文件路径 (poster_gen_config.json)"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用 DEBUG 级别日志"
    )
    args = parser.parse_args()

    # ==================================================
    # Built-in list of run directories, used when no path is passed on the
    # command line. Uncomment or add entries as needed.
    # ==================================================
    run_directories = [
        # "/root/autodl-tmp/TravelContentCreator/result/齐云山度假酒店/2025-04-27_11-51-56",
        "/root/autodl-tmp/TravelContentCreator/result/长鹿旅游休博园/2025-04-27_14-03-44",
        # "/root/autodl-tmp/TravelContentCreator/result/笔架山居森林度假酒店/2025-04-27_02-02-34",
        # "/root/autodl-tmp/TravelContentCreator/result/笔架山居森林度假酒店/2025-04-27_02-23-17",
        # "/root/autodl-tmp/TravelContentCreator/result/笔架山居森林度假酒店/2025-04-27_07-57-20",
        # "/root/autodl-tmp/TravelContentCreator/result/笔架山居森林度假酒店/2025-04-27_09-29-20",
        # "/root/autodl-tmp/TravelContentCreator/result/ANOTHER_RUN_ID",
    ]
    # ==================================================

    # Paths given on the command line take precedence over the built-in list.
    if args.run_dirs:
        run_directories = args.run_dirs

    if not run_directories:
        print("错误:请通过命令行参数或脚本中的 `run_directories` 列表,指定至少一个要处理的运行目录。")
        sys.exit(1)

    main(run_directories, args.config, args.debug)