# TravelContentCreator/batch_image_selector.py
#
# NOTE: The lines originally here were web-viewer residue, not source code
# ("699 lines, 29 KiB, Python, Raw / Blame / History"). The viewer also warned
# that this file contains ambiguous Unicode characters that might be confused
# with other characters (the full-width punctuation in the Chinese messages).

import os
import random
import logging
import json
import argparse
from typing import List, Dict, Any, Tuple, Optional
from PIL import Image, ImageEnhance, ImageFilter, ImageOps # 添加 PIL 导入
import colorsys # 添加 colorsys 导入
# Set up logging: configure the root logger once at import time (INFO level,
# timestamped format) and create the module-level logger used by every
# function in this file.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def select_image_batch(
    available_images: List[str],
    num_images_to_select: int,
    batch_seed: int
) -> List[str]:
    """Pick a reproducible random batch of image file names.

    The selection is drawn with a private ``random.Random(batch_seed)``
    generator, so the same inputs always yield the same batch and the
    process-wide ``random`` state is left untouched. The picks are identical
    to what ``random.seed(batch_seed); random.sample(...)`` would produce.

    Args:
        available_images: File names that may be selected.
        num_images_to_select: How many images the batch should contain.
        batch_seed: Seed that makes the selection deterministic.

    Returns:
        The selected file names in random order. If fewer images are
        available than requested, all of them are returned (shuffled).
        An empty list is returned when there is nothing to choose from or
        when ``num_images_to_select`` is not positive.
    """
    if not available_images:
        logger.warning("没有可供选择的图像。")
        return []

    # random.sample() raises ValueError for a negative size; treat any
    # non-positive request as "select nothing" instead of crashing.
    if num_images_to_select <= 0:
        logger.warning(f"请求的图像数量 ({num_images_to_select}) 必须为正数,未选择任何图像。")
        return []

    available_count = len(available_images)
    if available_count < num_images_to_select:
        logger.warning(
            f"可用图像数量 ({available_count}) 少于请求的数量 ({num_images_to_select})。"
            f"将选择所有可用图像。"
        )
        # Still go through sample() so the full list comes back shuffled.
        sample_size = available_count
    else:
        logger.info(
            f"使用种子 {batch_seed} 从 {available_count} 张可用图像中选择 {num_images_to_select} 张。"
        )
        sample_size = num_images_to_select

    # A dedicated generator keeps the draw deterministic without reseeding
    # (and afterwards re-randomising) the shared module-level generator.
    rng = random.Random(batch_seed)
    return rng.sample(available_images, sample_size)
# --- 从 poster_notes_creator.py 复制的图像处理函数 --- START ---
def _adjust_brightness(image: Image.Image, factor: float) -> Image.Image:
    """Return ``image`` with its brightness multiplied by ``factor``.

    A factor of exactly 1.0 returns the input object unchanged. The normal
    path uses ``ImageEnhance.Brightness``; if that raises ``ImportError`` a
    per-pixel fallback scales the first three channels, clamps them to
    0..255 and keeps a fourth (alpha) channel as-is.
    """
    if factor == 1.0:
        return image

    try:
        return ImageEnhance.Brightness(image).enhance(factor)
    except ImportError:
        logger.warning("ImageEnhance.Brightness not available, using manual adjustment.")

    def scale(channel):
        # Multiply, truncate to int and clamp into the 8-bit range.
        return min(255, max(0, int(channel * factor)))

    scaled_pixels = []
    for px in image.getdata():
        red, green, blue = px[:3]
        scaled = (scale(red), scale(green), scale(blue))
        if len(px) > 3:
            # Carry the alpha channel over untouched.
            scaled += (px[3],)
        scaled_pixels.append(scaled)

    brightened = Image.new(image.mode, image.size)
    brightened.putdata(scaled_pixels)
    return brightened
def _adjust_contrast(image: Image.Image, factor: float) -> Image.Image:
    """Return ``image`` with its contrast scaled by ``factor``.

    A factor of exactly 1.0 returns the input object unchanged. The normal
    path uses ``ImageEnhance.Contrast``; if that raises ``ImportError`` a
    per-pixel fallback stretches each of the first three channels around
    that channel's integer mean, clamps to 0..255 and keeps a fourth
    (alpha) channel as-is.
    """
    if factor == 1.0:
        return image

    try:
        return ImageEnhance.Contrast(image).enhance(factor)
    except ImportError:
        logger.warning("ImageEnhance.Contrast not available, using manual adjustment.")

    pixels = list(image.getdata())
    pixel_count = len(pixels)
    if pixel_count == 0:
        # No pixels, so no mean can be computed.
        return image

    # Per-channel integer mean (floor division, as a pivot for the stretch).
    mean_r = mean_g = mean_b = 0
    for px in pixels:
        red, green, blue = px[:3]
        mean_r += red
        mean_g += green
        mean_b += blue
    mean_r //= pixel_count
    mean_g //= pixel_count
    mean_b //= pixel_count

    def stretch(value, mean):
        # Push the value away from (or toward) the mean, then clamp.
        return min(255, max(0, int(mean + (value - mean) * factor)))

    stretched_pixels = []
    for px in pixels:
        red, green, blue = px[:3]
        stretched = (stretch(red, mean_r), stretch(green, mean_g), stretch(blue, mean_b))
        if len(px) > 3:
            # Carry the alpha channel over untouched.
            stretched += (px[3],)
        stretched_pixels.append(stretched)

    contrasted = Image.new(image.mode, image.size)
    contrasted.putdata(stretched_pixels)
    return contrasted
def _adjust_saturation(image: Image.Image, factor: float) -> Image.Image:
    """Return ``image`` with its colour saturation scaled by ``factor``.

    A factor of exactly 1.0 returns the input object unchanged. The normal
    path uses ``ImageEnhance.Color``; if that raises ``ImportError`` a
    per-pixel fallback blends each of the first three channels with the
    pixel's luma (0.299 R + 0.587 G + 0.114 B), clamps to 0..255 and keeps
    a fourth (alpha) channel as-is.
    """
    if factor == 1.0:
        return image

    try:
        return ImageEnhance.Color(image).enhance(factor)
    except ImportError:
        logger.warning("ImageEnhance.Color not available, using manual adjustment.")

    resaturated_pixels = []
    for px in image.getdata():
        red, green, blue = px[:3]
        # Weighted grey level of the pixel, truncated to int.
        luma = int(0.299 * red + 0.587 * green + 0.114 * blue)
        # Move every channel away from (or toward) the grey level, then clamp.
        adjusted = tuple(
            min(255, max(0, int(luma + (channel - luma) * factor)))
            for channel in (red, green, blue)
        )
        if len(px) > 3:
            # Carry the alpha channel over untouched.
            adjusted += (px[3],)
        resaturated_pixels.append(adjusted)

    resaturated = Image.new(image.mode, image.size)
    resaturated.putdata(resaturated_pixels)
    return resaturated
def _adjust_hue(image: Image.Image, shift: float) -> Image.Image:
    """Return an RGB copy of ``image`` with every pixel's hue rotated by ``shift``.

    ``shift`` is expressed on the 0..1 hue circle used by ``colorsys`` and
    wraps around. A shift of exactly 0.0 returns the input object unchanged.
    The image is converted to RGB first, so any alpha channel is dropped.
    On any error a warning is logged and the (possibly already converted)
    image is returned without the hue change.
    """
    if shift == 0.0:
        return image

    try:
        # Work on plain RGB triples.
        image = image.convert('RGB')

        def rotate_hue(px):
            red, green, blue = px
            hue, sat, val = colorsys.rgb_to_hsv(red / 255.0, green / 255.0, blue / 255.0)
            # Hue lives in [0, 1); the modulo wraps shifts past either end.
            rotated = colorsys.hsv_to_rgb((hue + shift) % 1.0, sat, val)
            # Back to 8-bit channels (truncating).
            return tuple(int(channel * 255) for channel in rotated)

        shifted_pixels = [rotate_hue(px) for px in image.getdata()]
        shifted_image = Image.new(image.mode, image.size)
        shifted_image.putdata(shifted_pixels)
        return shifted_image
    except Exception as e:
        logger.warning(f"Error adjusting hue: {e}. Returning original image.")
        return image
def _add_noise(image: Image.Image, intensity: float = 0.02) -> Image.Image:
    """Return an RGB copy of ``image`` with faint uniform noise on every channel.

    ``intensity`` (0-1) sets the noise amplitude as a fraction of 255; values
    <= 0 return the input object unchanged. Noise is drawn from the global
    ``random`` generator, one draw per channel in R, G, B order. The image is
    converted to RGB first, so any alpha channel is dropped. On any error a
    warning is logged and the (possibly already converted) image is returned.
    """
    if intensity <= 0:
        return image

    try:
        # Work on plain RGB triples.
        image = image.convert('RGB')
        amplitude = int(intensity * 255)

        noisy_pixels = []
        for px in image.getdata():
            red, green, blue = px
            # One independent offset per channel, clamped to the 8-bit range.
            noisy_pixels.append(tuple(
                max(0, min(255, channel + random.randint(-amplitude, amplitude)))
                for channel in (red, green, blue)
            ))

        noisy_image = Image.new(image.mode, image.size)
        noisy_image.putdata(noisy_pixels)
        return noisy_image
    except Exception as e:
        logger.warning(f"Error adding noise: {e}. Returning original image.")
        return image
def _add_border_and_crop(image: Image.Image, border_size: int) -> Image.Image:
    """Pad ``image`` with a near-black border, then crop back to the original size.

    The crop window is placed at a random offset inside the padded canvas, so
    the result has the original dimensions but shifted content and up to
    ``2 * border_size`` pixels of border along some edges. ``border_size`` <= 0
    returns the input object unchanged. The image is converted to RGB first.
    On any error a warning is logged and the (possibly already converted)
    image is returned.
    """
    if border_size <= 0:
        return image

    try:
        orig_w, orig_h = image.size
        image = image.convert('RGB')

        # Slightly randomised, almost-black border colour (R, G, B drawn in order).
        fill = tuple(random.randint(0, 10) for _ in range(3))
        pad = border_size * 2
        canvas = Image.new(image.mode, (orig_w + pad, orig_h + pad), fill)
        canvas.paste(image, (border_size, border_size))

        # Random crop window of the original size anywhere inside the canvas.
        left = random.randint(0, pad)
        top = random.randint(0, pad)
        return canvas.crop((left, top, left + orig_w, top + orig_h))
    except Exception as e:
        logger.warning(f"Error adding border and cropping: {e}. Returning original image.")
        return image
def _slight_sharpen(image: Image.Image) -> Image.Image:
    """Return ``image`` sharpened slightly (sharpness factor 1.2; 1.0 is unchanged).

    Any failure is logged as a warning and the input image is returned as-is.
    """
    try:
        sharpened = ImageEnhance.Sharpness(image).enhance(1.2)
    except ImportError:
        logger.warning("ImageEnhance.Sharpness not available. Skipping sharpen.")
        return image
    except Exception as e:
        logger.warning(f"Error applying sharpen: {e}. Returning original image.")
        return image
    else:
        return sharpened
def _slight_blur(image: Image.Image) -> Image.Image:
    """Return ``image`` softened with a Gaussian blur of radius 0.5.

    Any failure is logged as a warning and the input image is returned as-is.
    """
    try:
        blur_filter = ImageFilter.GaussianBlur(radius=0.5)
        blurred = image.filter(blur_filter)
    except ImportError:
        logger.warning("ImageFilter.GaussianBlur not available. Skipping blur.")
        return image
    except Exception as e:
        logger.warning(f"Error applying blur: {e}. Returning original image.")
        return image
    else:
        return blurred
def process_image_to_aspect_ratio(
    image: Image.Image,
    target_ratio: Tuple[int, int],
    add_variation: bool = True,
    seed: Optional[int] = None,
    variation_strength: str = "medium",
    extra_effects: bool = True
) -> Optional[Image.Image]:
    """Resize and crop ``image`` to ``target_ratio`` and optionally add subtle variations.

    The image is scaled so that it covers a 900x1200 base box (swapped to
    1200x900 when the target ratio is wider than tall) and then cropped to
    the exact target ratio. With ``add_variation`` the crop position is
    jittered and the result goes through small random brightness, contrast,
    saturation, micro-crop and rotation changes, plus (with
    ``extra_effects``) noise, hue shift, border shift and a sharpen-or-blur
    step.

    Args:
        image: Source image.
        target_ratio: Target aspect ratio as ``(width, height)``, e.g. ``(3, 4)``.
        add_variation: Whether to apply the random variations.
        seed: Optional seed making the variations reproducible. The global
            ``random`` generator is seeded for the duration of the call and
            its previous state is restored afterwards, so callers' own
            random sequences are not disturbed.
        variation_strength: ``"low"``, ``"medium"`` or ``"high"``; any other
            value falls back to the medium preset.
        extra_effects: Whether to apply noise / hue / border / sharpen-blur.

    Returns:
        The processed image, or ``None`` if any step failed (the error is logged).
    """
    # Jitter ranges per strength level.
    presets = {
        "low": {
            "brightness_range": (-0.03, 0.03),
            "contrast_range": (-0.03, 0.03),
            "saturation_range": (-0.03, 0.03),
            "hue_range": (-0.01, 0.01),
            "max_crop_px": 3,
            "max_rotation": 0.5,
            "noise_intensity": 0.01,
            "border_size_range": (0, 2),
        },
        "medium": {
            "brightness_range": (-0.05, 0.05),
            "contrast_range": (-0.05, 0.05),
            "saturation_range": (-0.05, 0.05),
            "hue_range": (-0.015, 0.015),
            "max_crop_px": 5,
            "max_rotation": 1.0,
            "noise_intensity": 0.02,
            "border_size_range": (0, 3),
        },
        "high": {
            "brightness_range": (-0.08, 0.08),
            "contrast_range": (-0.08, 0.08),
            "saturation_range": (-0.08, 0.08),
            "hue_range": (-0.02, 0.02),
            "max_crop_px": 8,
            "max_rotation": 2.0,
            "noise_intensity": 0.03,
            "border_size_range": (0, 4),
        },
    }

    saved_rng_state = None
    try:
        original_mode = image.mode

        # The helper functions draw from the global generator, so it has to be
        # seeded here; remember the previous state so it can be put back.
        if seed is not None:
            saved_rng_state = random.getstate()
            random.seed(seed)

        params = presets.get(variation_strength, presets["medium"])

        width, height = image.size
        current_ratio = width / height
        target_ratio_value = target_ratio[0] / target_ratio[1]

        # Base box is 900x1200 (3:4); landscape targets use 1200x900.
        # Output size only matches the base box exactly for 3:4 / 4:3 targets.
        base_width, base_height = 900, 1200
        if target_ratio_value > 1:
            base_width, base_height = base_height, base_width

        # Scale so the image covers the base box along its limiting side.
        if current_ratio > target_ratio_value:  # source is wider than target
            new_height = base_height
            new_width = int(new_height * current_ratio)
        else:  # source is taller than (or equal to) target
            new_width = base_width
            new_height = int(new_width / current_ratio)
        # Guard against a zero dimension for extreme aspect ratios.
        new_width = max(1, new_width)
        new_height = max(1, new_height)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)

        def jittered_origin(slack):
            # Centre the crop in the spare space, nudged by up to 20 px
            # (or a fifth of the slack) when variation is on, and clamped so
            # the crop window stays inside the image.
            if add_variation:
                limit = max(1, min(20, slack // 5))
                jitter = random.randint(-limit, limit)
            else:
                jitter = 0
            return max(0, min(slack // 2 + jitter, slack))

        resized_width, resized_height = resized_image.size
        if resized_width / resized_height > target_ratio_value:
            # Too wide: keep full height, trim left/right.
            crop_width = int(resized_height * target_ratio_value)
            crop_height = resized_height
            crop_x1 = jittered_origin(resized_width - crop_width)
            crop_y1 = 0
        else:
            # Too tall: keep full width, trim top/bottom.
            crop_height = int(resized_width / target_ratio_value)
            crop_width = resized_width
            crop_x1 = 0
            crop_y1 = jittered_origin(resized_height - crop_height)
        result = resized_image.crop(
            (crop_x1, crop_y1, crop_x1 + crop_width, crop_y1 + crop_height)
        )

        if not add_variation:
            return result

        # The enhancement steps expect RGB or RGBA input.
        if result.mode not in ('RGB', 'RGBA'):
            processed_image = result.convert('RGB')
        else:
            processed_image = result.copy()

        # 1. Brightness
        brightness_factor = 1.0 + random.uniform(*params["brightness_range"])
        processed_image = _adjust_brightness(processed_image, brightness_factor)

        # 2. Contrast
        contrast_factor = 1.0 + random.uniform(*params["contrast_range"])
        processed_image = _adjust_contrast(processed_image, contrast_factor)

        # 3. Saturation
        saturation_factor = 1.0 + random.uniform(*params["saturation_range"])
        processed_image = _adjust_saturation(processed_image, saturation_factor)

        # 4. Micro-crop: shave a few pixels off every edge, then scale back up.
        crop_px = random.randint(0, params["max_crop_px"])
        if crop_px > 0:
            w, h = processed_image.size
            if w > 2 * crop_px and h > 2 * crop_px:
                processed_image = processed_image.crop((crop_px, crop_px, w - crop_px, h - crop_px))
                processed_image = processed_image.resize((w, h), Image.LANCZOS)

        # 5. Micro-rotation. expand=False keeps the canvas size.
        # NOTE(review): the corners uncovered by the rotation get Pillow's
        # default fill (likely black) and are not repaired here.
        rotation_angle = random.uniform(-params["max_rotation"], params["max_rotation"])
        try:
            processed_image = processed_image.rotate(
                rotation_angle, resample=Image.BICUBIC, expand=False
            )
        except ValueError:
            logger.warning("Skipping rotation due to potential transparency issues.")

        # 6. Extra effects
        if extra_effects:
            # 6.1 Noise
            processed_image = _add_noise(processed_image, intensity=params["noise_intensity"])
            # 6.2 Hue
            processed_image = _adjust_hue(
                processed_image, shift=random.uniform(*params["hue_range"])
            )
            # 6.3 Edge shift
            border_size = random.randint(*params["border_size_range"])
            if border_size > 0:
                processed_image = _add_border_and_crop(processed_image, border_size)
            # 6.4 Either sharpen or blur, chosen at random
            if random.random() > 0.5:
                processed_image = _slight_sharpen(processed_image)
            else:
                processed_image = _slight_blur(processed_image)

        # Some effects convert to RGB. For RGBA sources, re-attach the alpha
        # channel of the cropped (pre-variation) image; it therefore does not
        # follow the micro-crop / rotation applied above.
        if original_mode == 'RGBA' and processed_image.mode == 'RGB':
            try:
                processed_image.putalpha(result.getchannel('A'))
            except (ValueError, IndexError):
                logger.warning("Could not restore original alpha channel after processing.")
        elif original_mode != processed_image.mode and original_mode != 'P':
            # Convert back to the source mode, except for palette images.
            try:
                processed_image = processed_image.convert(original_mode)
            except ValueError:
                logger.warning(f"Could not convert processed image back to original mode {original_mode}.")

        return processed_image
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        return None
    finally:
        # Put the global generator back exactly as the caller left it, on
        # both the success and the failure path.
        if saved_rng_state is not None:
            random.setstate(saved_rng_state)
# --- 从 poster_notes_creator.py 复制的图像处理函数 --- END ---
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the batch selector."""
    parser = argparse.ArgumentParser(description="根据指定的目录和批次数,生成指定数量的选图批次,并可选择处理图像。")
    parser.add_argument(
        "--source-dir",
        required=True,
        help="包含源图像的目录路径。"
    )
    parser.add_argument(
        "--num-batches",
        type=int,
        required=True,
        help="要生成的图像批次数。"
    )
    parser.add_argument(
        "--images-per-batch",
        type=int,
        required=True,
        help="每批要选择的图像数量。"
    )
    parser.add_argument(
        "--base-seed",
        type=int,
        default=42,
        help="用于生成批次选择的随机种子基数 (默认为 42)。"
    )
    parser.add_argument(
        "--output-file",
        default="selected_batches.json",
        help="用于保存选定批次信息的JSON文件路径 (默认为 selected_batches.json)。"
    )
    # Image-processing options
    parser.add_argument(
        "--process-images",
        action='store_true',
        help="如果指定,则对选定的图像进行处理。"
    )
    parser.add_argument(
        "--target-ratio",
        default="3:4",
        help="处理图像的目标宽高比 (格式 W:H, 例如 3:4, 16:9)。默认为 3:4。"
    )
    parser.add_argument(
        "--variation-strength",
        choices=["low", "medium", "high"],
        default="medium",
        help="图像处理的微调强度。默认为 medium。"
    )
    parser.add_argument(
        "--extra-effects",
        action=argparse.BooleanOptionalAction,  # gives --extra-effects / --no-extra-effects
        default=True,
        help="是否在图像处理中添加额外效果 (噪点、色相等)。默认启用。"
    )
    parser.add_argument(
        "--output-processed-dir",
        default="processed_batches",
        help="用于保存处理后图像的目录 (默认为 processed_batches)。仅当 --process-images 被指定时使用。"
    )
    return parser


def _parse_target_ratio(ratio_text: str) -> Tuple[int, int]:
    """Parse a ``W:H`` string into two positive ints; raise ValueError if malformed."""
    ratio_parts = ratio_text.split(':')
    if len(ratio_parts) != 2:
        raise ValueError("Ratio must be in W:H format")
    ratio = (int(ratio_parts[0]), int(ratio_parts[1]))
    if ratio[0] <= 0 or ratio[1] <= 0:
        raise ValueError("Ratio dimensions must be positive")
    return ratio


def _save_processed_image(processed_img, output_dir: str, name_stem: str, ext: str) -> bool:
    """Save in the source format if possible, else as JPEG; return True on success."""
    output_filename = f"{name_stem}{ext}"
    output_image_path = os.path.join(output_dir, output_filename)
    try:
        save_format = Image.registered_extensions().get(ext.lower())
        if not save_format or save_format == 'JPEG':  # JPEG gets an explicit quality
            processed_img.save(output_image_path, quality=95, format='JPEG')
        elif save_format == 'PNG':
            processed_img.save(output_image_path, optimize=True, format='PNG')
        else:
            processed_img.save(output_image_path, format=save_format)
        return True
    except (KeyError, ValueError, OSError) as save_err:
        logger.warning(f"无法按原始格式 {ext} 保存 {output_filename} ({save_err}). 尝试保存为 JPEG.")

    # Fallback: write a .jpg next to where the original-format file would go.
    output_filename = f"{name_stem}.jpg"
    output_image_path = os.path.join(output_dir, output_filename)
    try:
        if processed_img.mode == 'RGBA':
            # JPEG has no alpha: flatten onto white using the alpha band as mask.
            rgb_img = Image.new("RGB", processed_img.size, (255, 255, 255))
            rgb_img.paste(processed_img, mask=processed_img.split()[3])
            rgb_img.save(output_image_path, quality=95, format='JPEG')
        elif processed_img.mode != 'RGB':
            processed_img.convert('RGB').save(output_image_path, quality=95, format='JPEG')
        else:
            processed_img.save(output_image_path, quality=95, format='JPEG')
        return True
    except Exception as jpeg_save_err:
        logger.error(f"无法将 {output_filename} 保存为 JPEG: {jpeg_save_err}")
        return False


def _process_selected_images(
    selected_for_batch: List[str],
    batch_index: int,
    batch_seed: int,
    seed_stride: int,
    args: argparse.Namespace,
    target_ratio_tuple: Tuple[int, int],
) -> None:
    """Process and save every selected image of one batch; failures are logged and skipped."""
    logger.info(f"开始处理批次 {batch_index} 的 {len(selected_for_batch)} 张图像...")
    processed_output_dir = args.output_processed_dir

    try:
        if not os.path.exists(processed_output_dir):
            # exist_ok guards against the directory appearing between the
            # check and the call.
            os.makedirs(processed_output_dir, exist_ok=True)
            logger.info(f"创建处理后图像的输出目录: {processed_output_dir}")
    except OSError as e:
        logger.error(f"无法创建处理输出目录 {processed_output_dir}: {e}. 跳过处理。")
        return

    processed_count = 0
    total = len(selected_for_batch)
    for img_idx, original_filename in enumerate(selected_for_batch):
        input_image_path = os.path.join(args.source_dir, original_filename)
        base, ext = os.path.splitext(original_filename)
        # Output keeps the source extension unless saving falls back to JPEG.
        name_stem = f"batch_{batch_index}_processed_{base}"

        # One distinct seed per (batch, position). Scaling the batch seed by
        # seed_stride (>= batch size) keeps consecutive batches from sharing
        # seeds, which a plain ``batch_seed + img_idx`` would do.
        variation_seed = batch_seed * seed_stride + img_idx

        try:
            with Image.open(input_image_path) as img:
                logger.debug(f"处理图像: {original_filename} -> {name_stem}{ext} (种子: {variation_seed})")
                processed_img = process_image_to_aspect_ratio(
                    image=img,
                    target_ratio=target_ratio_tuple,
                    add_variation=True,  # always vary; strength controls how much
                    seed=variation_seed,
                    variation_strength=args.variation_strength,
                    extra_effects=args.extra_effects
                )
                if processed_img is None:
                    logger.warning(f"处理图像失败: {original_filename}")
                    continue
                if not _save_processed_image(processed_img, processed_output_dir, name_stem, ext):
                    continue  # skip this image
                processed_count += 1
                if (img_idx + 1) % 10 == 0 or (img_idx + 1) == total:
                    logger.info(f" 批次 {batch_index}: 已处理 {img_idx + 1}/{total} 张图像...")
        except FileNotFoundError:
            logger.error(f"找不到输入图像文件: {input_image_path}")
        except Image.UnidentifiedImageError:
            logger.error(f"无法识别或打开图像文件: {input_image_path}")
        except Exception as proc_err:
            logger.error(f"处理图像时发生未知错误 '{original_filename}': {proc_err}")

    logger.info(f"批次 {batch_index} 处理完成,成功处理 {processed_count} 张图像。保存在: {processed_output_dir}")


def _write_batches_json(all_batches: Dict[str, List[str]], output_file: str) -> None:
    """Write the selection to ``output_file``; on failure print it so it is not lost."""
    def dump_to_console() -> None:
        print("\n--- 选择结果 (保存失败,打印至控制台) ---")
        print(json.dumps(all_batches, indent=4, ensure_ascii=False))
        print("---------------------------------------")

    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        # Handled separately from the write below: IOError is an alias of
        # OSError, so one try with both handlers could never reach the second.
        try:
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"创建输出目录: {output_dir}")
        except OSError as e:
            logger.error(f"创建输出目录时出错: {e}")
            dump_to_console()
            return

    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(all_batches, f, indent=4, ensure_ascii=False)
        logger.info(f"所有批次的选择结果已保存到: {output_file}")
    except OSError as e:
        logger.error(f"无法写入输出文件 {output_file}: {e}")
        dump_to_console()


def main():
    """Command-line entry point.

    Lists the images in ``--source-dir``, draws ``--num-batches`` seeded
    batches of ``--images-per-batch`` file names (each batch is sampled
    independently from the full list, so batches may overlap), optionally
    processes and saves every selected image, and writes the selection to
    ``--output-file`` as JSON. Invalid arguments are logged and end the run
    early.
    """
    args = _build_arg_parser().parse_args()

    try:
        target_ratio_tuple = _parse_target_ratio(args.target_ratio)
    except ValueError as e:
        logger.error(f"无效的目标宽高比 '{args.target_ratio}': {e}")
        return

    # A negative per-batch count would make random.sample() raise later on.
    if args.num_batches < 1 or args.images_per_batch < 1:
        logger.error(
            f"--num-batches ({args.num_batches}) 和 --images-per-batch "
            f"({args.images_per_batch}) 必须为正整数。"
        )
        return

    if not os.path.isdir(args.source_dir):
        logger.error(f"源目录不存在: {args.source_dir}")
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')
    try:
        # Sorted so the candidate list (and thus the seeded draw) is stable
        # across runs regardless of directory listing order.
        available_images = sorted(
            f for f in os.listdir(args.source_dir)
            if os.path.isfile(os.path.join(args.source_dir, f))
            and f.lower().endswith(image_extensions)
        )
    except OSError as e:
        logger.error(f"无法读取目录 {args.source_dir}: {e}")
        return

    if not available_images:
        logger.error(f"在源目录 {args.source_dir} 中没有找到支持的图像文件。")
        return
    logger.info(f"{args.source_dir} 中找到 {len(available_images)} 张可用图像。")

    # Warn when the batches cannot all be disjoint.
    total_images_needed_if_no_repeat = args.num_batches * args.images_per_batch
    if total_images_needed_if_no_repeat > len(available_images):
        logger.warning(
            f"请求的总图像数 ({total_images_needed_if_no_repeat}) 大于可用图像数 ({len(available_images)})。"
            f"不同批次之间可能会选择到相同的图像。"
        )

    # A batch never holds more than len(available_images) entries, so this
    # stride keeps per-image variation seeds unique across batches.
    seed_stride = len(available_images)

    all_batches: Dict[str, List[str]] = {}
    for i in range(args.num_batches):
        batch_index = i + 1
        batch_seed = args.base_seed + i  # deterministic seed per batch
        logger.info(f"--- 开始处理批次 {batch_index}/{args.num_batches} (种子: {batch_seed}) ---")

        selected_for_batch = select_image_batch(
            available_images=available_images,
            num_images_to_select=args.images_per_batch,
            batch_seed=batch_seed
        )
        if not selected_for_batch:
            logger.warning(f"批次 {batch_index} 未选择任何图像。")
            continue

        all_batches[f"batch_{batch_index}"] = selected_for_batch

        # Log a short preview of the selection.
        preview = ', '.join(selected_for_batch[:5])
        if len(selected_for_batch) > 5:
            preview += '...'
        logger.info(f"批次 {batch_index} 选择完成,选择了 {len(selected_for_batch)} 张图像: [{preview}]")

        if args.process_images:
            _process_selected_images(
                selected_for_batch,
                batch_index,
                batch_seed,
                seed_stride,
                args,
                target_ratio_tuple,
            )

    _write_batches_json(all_batches, args.output_file)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()