TravelContentCreator/scripts/select_images.py

225 lines
9.4 KiB
Python
Raw Normal View History

import argparse
import json  # Added for reading metadata
import logging
import random
import sys
from pathlib import Path

try:
    import imagehash
    from PIL import Image
    from PIL import UnidentifiedImageError
except ImportError:
    print("错误: 未找到所需的库 PIL (Pillow) 和 imagehash。")
    print("请先安装它们: pip install Pillow imagehash")
    sys.exit(1)
# --- Configuration ---
# Logging: every record carries a timestamp, level, and the emitting file/line.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)

# Supported image file extensions (lowercase).
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp'}

# Number of images picked for each group.
IMAGES_PER_GROUP = 3
# --- 核心功能 ---
def find_image_files(directory: Path) -> list[Path]:
    """Return the image files located directly inside ``directory``.

    The scan is non-recursive (kept that way to align with the
    poster_notes_creator logic). An entry qualifies when it is a regular
    file whose lower-cased suffix is listed in ``IMAGE_EXTENSIONS``.

    Args:
        directory: Folder to scan.

    Returns:
        The matching paths in the order ``Path.iterdir`` yields them, or an
        empty list when ``directory`` is not a directory or the scan fails.
    """
    if not directory.is_dir():
        logger.error(f"错误:目录不存在或不是一个有效的目录 -> {directory}")
        return []

    logger.info(f"开始在目录中扫描图片 (非递归): {directory}")
    try:
        # iterdir() only lists immediate children, so sub-folders are never walked.
        matches = [
            entry
            for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS
        ]
    except Exception as e:
        logger.error(f"扫描目录时出错: {e}")
        return []

    logger.info(f"{directory} 中找到 {len(matches)} 个图片文件。")
    return matches
def calculate_image_hash(image_path: Path) -> str | None:
    """Compute the perceptual hash (phash) of an image as its fingerprint.

    Args:
        image_path: Location of the image file to fingerprint.

    Returns:
        The string form of the phash, or ``None`` when the file cannot be
        identified as an image or hashing fails for any other reason. Failures
        are logged as warnings and never propagated to the caller.
    """
    try:
        with Image.open(image_path) as img:
            # Hash the grayscale ('L') rendition; converting first also lets
            # images in other modes (e.g. RGBA, P) go through the same path.
            return str(imagehash.phash(img.convert('L')))
    except UnidentifiedImageError:
        # Pillow could not recognise the file as any supported image format.
        logger.warning(f"无法识别的图片格式,跳过哈希计算: {image_path}")
        return None
    except Exception as e:
        # Best-effort: anything else (corrupt file, permission problem, ...)
        # is reported with its cause instead of aborting the whole selection.
        logger.warning(f"无法为图片计算哈希值 {image_path}: {e}")
        return None
def _load_used_image_filenames(metadata_path: Path | None) -> set[str]:
    """Read the file names listed under 'collage_images' in the metadata JSON.

    Every failure mode (no path given, missing file, malformed JSON,
    unexpected structure) is logged and yields an empty set, which means no
    image gets excluded.
    """
    if not metadata_path:
        logger.info("未提供元数据文件,将不排除任何图片。")
        return set()
    if not metadata_path.is_file():
        logger.warning(f"指定的元数据文件不存在或不是文件: {metadata_path}. 将不排除任何图片。")
        return set()

    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        # Assumes the used images are stored under the 'collage_images' key;
        # adjust the key if the metadata structure is different.
        used_list = metadata.get('collage_images', [])
        if not isinstance(used_list, list):
            logger.warning(f"元数据文件 {metadata_path.name} 中的 'collage_images' 键不是一个列表,无法排除图片。")
            return set()
        used_image_filenames = set(used_list)
    except json.JSONDecodeError:
        logger.error(f"无法解析元数据文件 (JSON格式错误): {metadata_path}. 将不排除任何图片。")
        return set()
    except Exception as e:
        # Covers read errors as well as a top-level JSON value that is not an
        # object or list entries that are not hashable.
        logger.error(f"读取元数据文件时出错 {metadata_path}: {e}. 将不排除任何图片。")
        return set()

    logger.info(f"从元数据 {metadata_path.name} 中加载了 {len(used_image_filenames)} 个已使用的图片文件名进行排除。")
    return used_image_filenames


def select_image_groups(
    image_paths: list[Path],
    num_groups: int,
    metadata_path: Path | None = None,
) -> list[list[dict]]:
    """Randomly pick groups of images, skipping those the metadata marks as used.

    Images whose file name appears in the metadata's 'collage_images' list are
    removed from the pool. Each group is then an independent random sample of
    ``IMAGES_PER_GROUP`` images from that pool, so one image may appear in
    more than one group.

    Args:
        image_paths: Candidate image files.
        num_groups: Number of groups to produce; a value <= 0 yields no groups.
        metadata_path: JSON metadata file naming already-used images. When it
            is ``None``, missing or unreadable, nothing is excluded.

    Returns:
        One list per successfully built group. Every entry is a dict with the
        resolved image ``"path"`` (str) and its ``"hash"`` (phash string, or
        ``None`` when it could not be computed). Empty when fewer than
        ``IMAGES_PER_GROUP`` images are available.
    """
    used_image_filenames = _load_used_image_filenames(metadata_path)
    available_image_paths = [p for p in image_paths if p.name not in used_image_filenames]
    # Report how many images were actually filtered out; the metadata may list
    # names that do not exist in this directory at all.
    excluded_count = len(image_paths) - len(available_image_paths)
    logger.info(f"总共找到 {len(image_paths)} 张图片, 排除 {excluded_count} 张已用图片后, 剩余可用图片: {len(available_image_paths)} 张。")

    total_available_images = len(available_image_paths)
    if total_available_images < IMAGES_PER_GROUP:
        logger.error(f"可用图片总数 ({total_available_images}) 少于每组所需的 {IMAGES_PER_GROUP} 张。无法进行选择。")
        return []

    selected_groups_data = []
    logger.info(f"开始从 {total_available_images} 张可用图片中选择 {num_groups} 组,每组 {IMAGES_PER_GROUP} 张...")
    for i in range(num_groups):
        logger.info(f"--- 正在选择第 {i + 1} / {num_groups} 组 ---")
        try:
            # The pool is never depleted between groups, so the size check
            # above guarantees the sample can always be drawn.
            selected_paths_for_group = random.sample(available_image_paths, IMAGES_PER_GROUP)
            group_data = []
            for img_path in selected_paths_for_group:
                img_hash = calculate_image_hash(img_path)
                group_data.append({
                    "path": str(img_path.resolve()),
                    "hash": img_hash,
                })
                logger.debug(f" 已选定: {img_path} (哈希: {img_hash if img_hash else '无法计算'})")
        except Exception as e:
            # Best-effort: a failing group is skipped, the remaining ones are still built.
            logger.error(f"选择第 {i+1} 组时发生意外错误: {e}")
            continue
        selected_groups_data.append(group_data)

    logger.info(f"图片组选择完成。共选出 {len(selected_groups_data)} 组。")
    return selected_groups_data
def print_results(selected_groups: list[list[dict]]):
    """Print the selection result to stdout in a readable layout.

    Args:
        selected_groups: Groups of dicts, each carrying a ``"path"`` and a
            ``"hash"`` entry. A falsy hash is shown as ``N/A``; an empty
            group is reported as a failed selection.
    """
    if not selected_groups:
        print("\n未能选出任何图片组。")
        return

    print("\n========== 图片选择结果 ==========")
    for group_number, members in enumerate(selected_groups, start=1):
        print(f"\n--- 组 {group_number} ---")
        if members:
            for entry in members:
                print(f" 图片路径: {entry['path']}")
                print(f" 数字指纹 (phash): {entry['hash'] or 'N/A'}")
        else:
            print(" (此组选择失败)")
    print("\n===================================")
# --- 主程序入口 ---
if __name__ == "__main__":
    # Command-line interface: source directory, number of groups, and the
    # metadata file naming the images that were already used.
    parser = argparse.ArgumentParser(
        description=f"从指定目录随机选择数组图片(每组 {IMAGES_PER_GROUP} 张),并计算数字指纹。"
    )
    parser.add_argument(
        "--image_dir",
        type=str,
        required=True,
        help="包含图片的源目录路径。"
    )
    parser.add_argument(
        "--num_groups",
        type=int,
        required=True,
        help="要选择的图片组数量。"
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        required=True,  # Make it required as per user request
        help="包含已使用图片列表的 JSON 元数据文件路径。假设键为 'collage_images'"
    )
    args = parser.parse_args()

    # Input validation: every invalid argument aborts with a non-zero status.
    source_directory = Path(args.image_dir)
    number_of_groups = args.num_groups
    metadata_file_path = Path(args.metadata_path)
    if number_of_groups <= 0:
        logger.error("选择的组数必须是一个正整数。")
        # Stop here; previously the error was only logged and the run continued.
        sys.exit(1)
    if not metadata_file_path.is_file():
        logger.error(f"指定的元数据文件不存在或不是一个文件: {metadata_file_path}")
        sys.exit(1)

    # 1. Find the images (non-recursive).
    all_images = find_image_files(source_directory)
    if not all_images:
        # Nothing to select from is not treated as a failure: exit with status 0.
        logger.info("在指定目录中未找到任何符合条件的图片文件。")
        sys.exit(0)

    # 2. Select the image groups (includes filtering out already-used images).
    final_selected_groups = select_image_groups(all_images, number_of_groups, metadata_file_path)

    # 3. Print the result.
    print_results(final_selected_groups)
    logger.info("脚本执行完毕。")