TravelContentCreator/scripts/select_images.py

225 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import json  # Added for reading metadata
import logging
import random
import sys
from pathlib import Path

try:
    import imagehash
    from PIL import Image, UnidentifiedImageError
except ImportError:
    print("错误: 未找到所需的库 PIL (Pillow) 和 imagehash。")
    print("请先安装它们: pip install Pillow imagehash")
    sys.exit(1)
# --- Configuration ---

# Logging: records at INFO and above, each line stamped with the time, the
# level and the emitting source file:line.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# File suffixes recognised as images; callers compare the lower-cased suffix.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.webp', '.jpeg'}

# How many images make up one group.
IMAGES_PER_GROUP = 3
# --- Core functionality ---
def find_image_files(directory: Path) -> list[Path]:
    """List the image files located directly inside *directory*.

    The scan is deliberately non-recursive (sub-directories are not entered)
    to stay aligned with the poster_notes_creator logic. A regular file counts
    as an image when its lower-cased suffix is in IMAGE_EXTENSIONS.

    Returns an empty list when *directory* is not an existing directory or
    when listing it raises; both situations are logged as errors.
    """
    if directory.is_dir():
        logger.info(f"开始在目录中扫描图片 (非递归): {directory}")
    else:
        logger.error(f"错误:目录不存在或不是一个有效的目录 -> {directory}")
        return []

    try:
        found = [
            entry
            for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS
        ]
    except Exception as exc:
        logger.error(f"扫描目录时出错: {exc}")
        return []

    logger.info(f"{directory} 中找到 {len(found)} 个图片文件。")
    return found
def calculate_image_hash(image_path: Path) -> str | None:
    """Compute the perceptual hash (phash) of an image as its digital fingerprint.

    The image is converted to greyscale ('L' mode) before hashing; phash tends
    to work better on greyscale input, and the conversion also normalises other
    modes such as RGBA or palette ('P') images.

    Args:
        image_path: Path of the image file to fingerprint.

    Returns:
        The phash value rendered with ``str()``, or ``None`` if the file is not
        a recognisable image or any other error occurs while opening,
        converting or hashing it. Each failure is logged as a warning.
    """
    try:
        with Image.open(image_path) as img:
            return str(imagehash.phash(img.convert('L')))
    except UnidentifiedImageError:
        # Pillow could not identify the file as any supported image format.
        logger.warning(f"无法识别的图片格式,跳过哈希计算: {image_path}")
        return None
    except Exception as e:
        # Anything else: corrupt/truncated data, permission problems, etc.
        logger.warning(f"无法为图片计算哈希值 {image_path}: {e}")
        return None
def _load_used_image_filenames(metadata_path: Path | None) -> set[str]:
    """Read the names of already-used images from the metadata JSON file.

    The file is expected to be a JSON object whose 'collage_images' key holds
    a list of file names. Every failure mode (no path given, missing file,
    unreadable or malformed JSON, unexpected structure) is logged and yields an
    empty set, meaning no image gets excluded.
    """
    if not metadata_path:
        logger.info("未提供元数据文件,将不排除任何图片。")
        return set()
    if not metadata_path.is_file():
        logger.warning(f"指定的元数据文件不存在或不是文件: {metadata_path}. 将不排除任何图片。")
        return set()

    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"无法解析元数据文件 (JSON格式错误): {metadata_path}. 将不排除任何图片。")
        return set()
    except Exception as e:
        # e.g. permission errors or a file that is not valid UTF-8.
        logger.error(f"读取元数据文件时出错 {metadata_path}: {e}. 将不排除任何图片。")
        return set()

    if not isinstance(metadata, dict):
        # A top-level list/string/number has no 'collage_images' key to read.
        logger.warning(f"元数据文件 {metadata_path.name} 的顶层不是 JSON 对象,无法排除图片。")
        return set()

    # 'collage_images' is assumed to be the key listing used images; adjust it
    # if the producer of the metadata file uses a different layout.
    used_list = metadata.get('collage_images', [])
    if not isinstance(used_list, list):
        logger.warning(f"元数据文件 {metadata_path.name} 中的 'collage_images' 键不是一个列表,无法排除图片。")
        return set()

    # Only strings can equal a file name; skipping other entries also keeps
    # unhashable items (nested objects/lists) from breaking set construction.
    used_names = {name for name in used_list if isinstance(name, str)}
    logger.info(f"从元数据 {metadata_path.name} 中加载了 {len(used_names)} 个已使用的图片文件名进行排除。")
    return used_names


def select_image_groups(
    image_paths: list[Path],
    num_groups: int,
    metadata_path: Path | None,
    *,
    distinct_groups: bool = False,
) -> list[list[dict]]:
    """Randomly draw groups of images, skipping images already recorded as used.

    Images whose file name appears in the metadata file's 'collage_images'
    list are removed from the pool first. Each group then receives
    IMAGES_PER_GROUP distinct images drawn with ``random.sample``.

    Args:
        image_paths: Candidate image files.
        num_groups: Number of groups to draw; values <= 0 yield no groups.
        metadata_path: JSON metadata file listing used images, or ``None`` to
            exclude nothing. Problems reading it are logged and ignored.
        distinct_groups: If True, an image is used in at most one group and
            selection stops early once the pool cannot fill another group.
            If False (the default), the same image may appear in several groups.

    Returns:
        One list per successfully built group. Each entry is a dict with
        "path" (the resolved path as str) and "hash" (the phash string, or
        None if it could not be computed). Returns [] when fewer than
        IMAGES_PER_GROUP images are available.
    """
    used_image_filenames = _load_used_image_filenames(metadata_path)

    available_image_paths = [p for p in image_paths if p.name not in used_image_filenames]
    total_available_images = len(available_image_paths)
    # Report how many images were actually dropped from *this* list; that can
    # differ from the number of names listed in the metadata.
    excluded_count = len(image_paths) - total_available_images
    logger.info(f"总共找到 {len(image_paths)} 张图片, 排除 {excluded_count} 张已用图片后, 剩余可用图片: {total_available_images} 张。")

    if total_available_images < IMAGES_PER_GROUP:
        logger.error(f"可用图片总数 ({total_available_images}) 少于每组所需的 {IMAGES_PER_GROUP} 张。无法进行选择。")
        return []

    logger.info(f"开始从 {total_available_images} 张可用图片中选择 {num_groups} 组,每组 {IMAGES_PER_GROUP} 张...")

    # Working pool; it only shrinks when distinct_groups is set.
    pool = list(available_image_paths)
    # The same image can be drawn more than once; hash each file only once.
    hash_cache: dict[Path, str | None] = {}
    selected_groups_data: list[list[dict]] = []

    for i in range(num_groups):
        logger.info(f"--- 正在选择第 {i + 1} / {num_groups} 组 ---")
        if len(pool) < IMAGES_PER_GROUP:
            # Only reachable when distinct_groups has drained the pool.
            logger.warning(f"剩余可用图片 ({len(pool)}) 不足以组成第 {i+1} 组 (需要 {IMAGES_PER_GROUP} 张)。停止选择。")
            break
        try:
            # Cannot raise ValueError: the pool size was checked just above.
            selected_paths_for_group = random.sample(pool, IMAGES_PER_GROUP)
            group_data = []
            for img_path in selected_paths_for_group:
                if img_path not in hash_cache:
                    hash_cache[img_path] = calculate_image_hash(img_path)
                img_hash = hash_cache[img_path]
                group_data.append({
                    "path": str(img_path.resolve()),
                    "hash": img_hash,
                })
                logger.debug(f" 已选定: {img_path} (哈希: {img_hash if img_hash else '无法计算'})")
        except Exception as e:
            # e.g. Path.resolve() failing on a broken path; skip only this group.
            logger.error(f"选择第 {i+1} 组时发生意外错误: {e}")
            continue

        selected_groups_data.append(group_data)
        if distinct_groups:
            chosen = set(selected_paths_for_group)
            pool = [p for p in pool if p not in chosen]

    logger.info(f"图片组选择完成。共选出 {len(selected_groups_data)} 组。")
    return selected_groups_data
def print_results(selected_groups: list[list[dict]]):
    """Pretty-print the selected image groups to stdout.

    Groups are numbered from 1. Every image is shown with its path and phash
    fingerprint ('N/A' when the hash is missing or empty). An empty group is
    reported as a failed selection; an empty result prints a single notice.
    """
    if not selected_groups:
        print("\n未能选出任何图片组。")
        return

    print("\n========== 图片选择结果 ==========")
    for group_number, group in enumerate(selected_groups, start=1):
        print(f"\n--- 组 {group_number} ---")
        if group:
            for entry in group:
                print(f" 图片路径: {entry['path']}")
                print(f" 数字指纹 (phash): {entry['hash'] or 'N/A'}")
        else:
            print(" (此组选择失败)")
    print("\n===================================")
# --- Script entry point ---
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: select image groups and print their fingerprints.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 on invalid input (non-positive group count, missing
    image directory, missing metadata file) and with status 0 when the image
    directory simply contains no images.
    """
    parser = argparse.ArgumentParser(
        description=f"从指定目录随机选择数组图片(每组 {IMAGES_PER_GROUP} 张),并计算数字指纹。"
    )
    parser.add_argument(
        "--image_dir",
        type=str,
        required=True,
        help="包含图片的源目录路径。"
    )
    parser.add_argument(
        "--num_groups",
        type=int,
        required=True,
        help="要选择的图片组数量。"
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        required=True,  # Required so that already-used images are always excluded.
        help="包含已使用图片列表的 JSON 元数据文件路径。假设键为 'collage_images'"
    )
    args = parser.parse_args(argv)

    source_directory = Path(args.image_dir)
    number_of_groups = args.num_groups
    metadata_file_path = Path(args.metadata_path)

    # Input validation: each problem is fatal.
    if number_of_groups <= 0:
        logger.error("选择的组数必须是一个正整数。")
        sys.exit(1)
    if not source_directory.is_dir():
        # A missing directory is an error, unlike an existing but empty one.
        logger.error(f"错误:目录不存在或不是一个有效的目录 -> {source_directory}")
        sys.exit(1)
    if not metadata_file_path.is_file():
        logger.error(f"指定的元数据文件不存在或不是一个文件: {metadata_file_path}")
        sys.exit(1)

    # 1. Find images (non-recursive).
    all_images = find_image_files(source_directory)
    if not all_images:
        # No images is a valid outcome: report it and exit successfully.
        logger.info("在指定目录中未找到任何符合条件的图片文件。")
        sys.exit(0)

    # 2. Select groups, excluding images listed as used in the metadata.
    final_selected_groups = select_image_groups(all_images, number_of_groups, metadata_file_path)

    # 3. Print the results.
    print_results(final_selected_groups)
    logger.info("脚本执行完毕。")


if __name__ == "__main__":
    main()