TravelContentCreator/scripts/image_test/simple_dupe_checker.py

416 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
简化版图像相似度检测器
直接计算指定目录中图片之间的相似度,输出表格结果
"""
import os
import sys
import hashlib
import argparse
import time
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
from skimage.metrics import structural_similarity as ssim
import imagehash
from datetime import datetime
class SimpleImageDupeChecker:
"""简化版图像相似度检测器"""
def __init__(self):
"""初始化检测器"""
# 算法列表
self.algorithms = {
'md5': self.compare_md5,
'phash': self.compare_phash,
'ahash': self.compare_ahash,
'dhash': self.compare_dhash,
'color_hist': self.compare_color_histogram,
'sift': self.compare_sift,
'ssim': self.compare_ssim
}
# 初始化SIFT检测器
self.sift = cv2.SIFT_create()
# FLANN匹配器参数
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
search_params = dict(checks=50)
self.flann = cv2.FlannBasedMatcher(index_params, search_params)
def get_md5(self, image_path):
"""计算图像的MD5哈希值"""
with open(image_path, 'rb') as f:
md5hash = hashlib.md5(f.read()).hexdigest()
return md5hash
def compare_md5(self, image1_path, image2_path):
"""比较两张图像的MD5哈希值相似度"""
hash1 = self.get_md5(image1_path)
hash2 = self.get_md5(image2_path)
# 如果哈希完全相同返回1.0否则返回0.0
return 1.0 if hash1 == hash2 else 0.0
def get_phash(self, image_path, hash_size=8):
"""计算图像的感知哈希(pHash)"""
img = Image.open(image_path).convert('L').resize((hash_size * 4, hash_size * 4), Image.LANCZOS)
return imagehash.phash(img, hash_size=hash_size)
def compare_phash(self, image1_path, image2_path):
"""比较两张图像的感知哈希(pHash)相似度"""
hash1 = self.get_phash(image1_path)
hash2 = self.get_phash(image2_path)
# 计算哈希相似度64位哈希的汉明距离
hash_diff = hash1 - hash2
max_diff = 64.0 # 8x8哈希的最大汉明距离
similarity = 1.0 - (hash_diff / max_diff)
return max(0.0, similarity) # 确保相似度不低于0
def get_ahash(self, image_path, hash_size=8):
"""计算图像的平均哈希(aHash)"""
img = Image.open(image_path).convert('L').resize((hash_size, hash_size), Image.LANCZOS)
return imagehash.average_hash(img, hash_size=hash_size)
def compare_ahash(self, image1_path, image2_path):
"""比较两张图像的平均哈希(aHash)相似度"""
hash1 = self.get_ahash(image1_path)
hash2 = self.get_ahash(image2_path)
# 计算哈希相似度
hash_diff = hash1 - hash2
max_diff = 64.0 # 8x8哈希的最大汉明距离
similarity = 1.0 - (hash_diff / max_diff)
return max(0.0, similarity)
def get_dhash(self, image_path, hash_size=8):
"""计算图像的差值哈希(dHash)"""
img = Image.open(image_path).convert('L').resize((hash_size + 1, hash_size), Image.LANCZOS)
return imagehash.dhash(img, hash_size=hash_size)
def compare_dhash(self, image1_path, image2_path):
"""比较两张图像的差值哈希(dHash)相似度"""
hash1 = self.get_dhash(image1_path)
hash2 = self.get_dhash(image2_path)
# 计算哈希相似度
hash_diff = hash1 - hash2
max_diff = 64.0 # 8x8哈希的最大汉明距离
similarity = 1.0 - (hash_diff / max_diff)
return max(0.0, similarity)
def get_color_histogram(self, image_path):
"""计算图像的颜色直方图"""
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
# 计算HSV颜色空间的直方图
hist = cv2.calcHist([img], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256])
cv2.normalize(hist, hist, 0, 1.0, cv2.NORM_MINMAX)
return hist.flatten()
def compare_color_histogram(self, image1_path, image2_path):
"""比较两张图像的颜色直方图相似度"""
hist1 = self.get_color_histogram(image1_path)
hist2 = self.get_color_histogram(image2_path)
# 计算直方图相似度
similarity = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
return max(0.0, similarity) # 确保相似度不低于0
def get_sift_features(self, image_path):
"""获取图像的SIFT特征点和描述符"""
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
keypoints, descriptors = self.sift.detectAndCompute(img, None)
return keypoints, descriptors
def compare_sift(self, image1_path, image2_path):
"""比较两张图像的SIFT特征点相似度"""
kp1, des1 = self.get_sift_features(image1_path)
kp2, des2 = self.get_sift_features(image2_path)
# 如果没有足够的特征点,返回低相似度
if des1 is None or des2 is None or len(des1) < 2 or len(des2) < 2:
return 0.0
# 使用FLANN匹配器查找最佳匹配
matches = self.flann.knnMatch(des1, des2, k=2)
# 应用比率测试,筛选好的匹配
good_matches = []
for m, n in matches:
if m.distance < 0.7 * n.distance:
good_matches.append(m)
# 计算相似度
similarity = len(good_matches) / max(len(kp1), len(kp2)) if max(len(kp1), len(kp2)) > 0 else 0
return min(1.0, similarity) # 确保相似度不超过1
def compare_ssim(self, image1_path, image2_path):
"""比较两张图像的结构相似性(SSIM)"""
img1 = cv2.imread(image1_path, cv2.IMREAD_GRAYSCALE)
img2 = cv2.imread(image2_path, cv2.IMREAD_GRAYSCALE)
# 确保两张图像尺寸相同
h, w = min(img1.shape[0], img2.shape[0]), min(img1.shape[1], img2.shape[1])
img1 = cv2.resize(img1, (w, h))
img2 = cv2.resize(img2, (w, h))
# 计算SSIM
similarity = ssim(img1, img2)
return max(0.0, similarity) # 确保相似度不低于0
def check_images(self, input_dir, output_dir=None):
"""
检查目录中所有图像之间的相似度
Args:
input_dir: 输入图像目录
output_dir: 输出结果目录如果为None则使用input_dir
Returns:
包含相似度信息的DataFrame
"""
# 如果未指定输出目录,使用输入目录
if output_dir is None:
output_dir = input_dir
os.makedirs(output_dir, exist_ok=True)
# 获取所有图像文件
print("正在读取图像文件...")
image_files = []
for file in os.listdir(input_dir):
file_path = os.path.join(input_dir, file)
if os.path.isfile(file_path) and file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.gif')):
image_files.append(file_path)
# 检查图像数量
if len(image_files) == 0:
print(f"错误: 在目录 '{input_dir}' 中没有找到图像文件")
return None
print(f"找到 {len(image_files)} 个图像文件,开始计算相似度...")
# 结果列表
results = []
# 计算所有图像对的相似度
total_pairs = len(image_files) * (len(image_files) - 1) // 2
pair_count = 0
for i, img1 in enumerate(image_files):
img1_name = os.path.basename(img1)
for j, img2 in enumerate(image_files):
if j <= i: # 跳过重复比较
continue
img2_name = os.path.basename(img2)
pair_count += 1
print(f"比较图像对 {pair_count}/{total_pairs}: {img1_name} vs {img2_name}")
# 对每种算法计算相似度
for alg_name, compare_func in self.algorithms.items():
try:
start_time = time.time()
similarity = compare_func(img1, img2)
compute_time = time.time() - start_time
results.append({
'image1': img1_name,
'image2': img2_name,
'algorithm': alg_name,
'similarity': similarity,
'compute_time': compute_time
})
except Exception as e:
print(f" 算法 {alg_name} 比较出错: {e}")
results.append({
'image1': img1_name,
'image2': img2_name,
'algorithm': alg_name,
'similarity': None,
'compute_time': None,
'error': str(e)
})
# 创建DataFrame
df = pd.DataFrame(results)
# 保存结果
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_path = os.path.join(output_dir, f"similarity_results_{timestamp}.csv")
df.to_csv(csv_path, index=False)
print(f"结果已保存至: {csv_path}")
# 创建透视表
pivot_df = df.pivot_table(
index=['image1', 'image2'],
columns='algorithm',
values='similarity',
aggfunc='first'
).reset_index()
# 保存透视表
pivot_path = os.path.join(output_dir, f"similarity_pivot_{timestamp}.csv")
pivot_df.to_csv(pivot_path, index=False)
print(f"透视表已保存至: {pivot_path}")
# 显示结果摘要
self.print_summary(df)
# 生成可视化
self.visualize_results(df, output_dir)
return df
def print_summary(self, df):
"""打印结果摘要"""
print("\n===== 图像相似度检测结果摘要 =====")
# 按算法分组计算平均相似度
alg_summary = df.groupby('algorithm')['similarity'].agg(['mean', 'min', 'max', 'std']).reset_index()
alg_summary = alg_summary.sort_values('mean', ascending=False)
print("\n各算法平均相似度:")
for _, row in alg_summary.iterrows():
print(f" {row['algorithm']:<12}: 平均值 = {row['mean']:.4f} (最小值: {row['min']:.4f}, 最大值: {row['max']:.4f}, 标准差: {row['std']:.4f})")
# 显示最相似的图像对
print("\n相似度最高的图像对:")
for alg in df['algorithm'].unique():
alg_df = df[df['algorithm'] == alg]
if not alg_df.empty:
max_idx = alg_df['similarity'].idxmax()
max_row = alg_df.loc[max_idx]
print(f" {alg:<12}: {max_row['image1']}{max_row['image2']} (相似度: {max_row['similarity']:.4f})")
# 显示最不相似的图像对
print("\n相似度最低的图像对:")
for alg in df['algorithm'].unique():
alg_df = df[df['algorithm'] == alg]
if not alg_df.empty:
min_idx = alg_df['similarity'].idxmin()
min_row = alg_df.loc[min_idx]
print(f" {alg:<12}: {min_row['image1']}{min_row['image2']} (相似度: {min_row['similarity']:.4f})")
def visualize_results(self, df, output_dir):
"""可视化结果"""
print("\n生成结果可视化...")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# 1. 绘制条形图 - 各算法的平均相似度
plt.figure(figsize=(10, 6))
avg_similarity = df.groupby('algorithm')['similarity'].mean().reset_index()
avg_similarity = avg_similarity.sort_values('similarity', ascending=False)
plt.bar(avg_similarity['algorithm'], avg_similarity['similarity'], color='skyblue')
plt.xlabel('算法')
plt.ylabel('平均相似度')
plt.title('各算法的平均相似度')
plt.xticks(rotation=45)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
bar_chart_path = os.path.join(output_dir, f"algorithm_avg_similarity_{timestamp}.png")
plt.savefig(bar_chart_path)
# 2. 绘制热力图 - 各图像对的相似度矩阵
# 为每个算法生成一个热力图
image_names = sorted(list(set(df['image1'].unique()) | set(df['image2'].unique())))
for alg in df['algorithm'].unique():
plt.figure(figsize=(10, 8))
# 创建热力图数据
heatmap_data = np.zeros((len(image_names), len(image_names)))
# 填充热力图数据
for idx, row in df[df['algorithm'] == alg].iterrows():
i = image_names.index(row['image1'])
j = image_names.index(row['image2'])
heatmap_data[i, j] = row['similarity']
heatmap_data[j, i] = row['similarity'] # 对称矩阵
# 热力图对角线设为1.0自己与自己相似度为1
for i in range(len(image_names)):
heatmap_data[i, i] = 1.0
# 创建热力图
plt.imshow(heatmap_data, cmap='viridis')
plt.colorbar(label='相似度')
plt.title(f'{alg} 算法的图像相似度热力图')
# 设置坐标轴刻度
plt.xticks(np.arange(len(image_names)), image_names, rotation=90)
plt.yticks(np.arange(len(image_names)), image_names)
# 添加文本标签
for i in range(len(image_names)):
for j in range(len(image_names)):
text = plt.text(j, i, f"{heatmap_data[i, j]:.2f}",
ha="center", va="center",
color="w" if heatmap_data[i, j] < 0.7 else "black")
plt.tight_layout()
heatmap_path = os.path.join(output_dir, f"similarity_heatmap_{alg}_{timestamp}.png")
plt.savefig(heatmap_path)
# 3. 箱线图 - 各算法的相似度分布
plt.figure(figsize=(10, 6))
# 准备数据
data = []
labels = []
for alg in sorted(df['algorithm'].unique()):
alg_data = df[df['algorithm'] == alg]['similarity']
if not alg_data.empty:
data.append(alg_data)
labels.append(alg)
# 创建箱线图
plt.boxplot(data, labels=labels)
plt.ylabel('相似度')
plt.title('各算法的相似度分布')
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
boxplot_path = os.path.join(output_dir, f"algorithm_boxplot_{timestamp}.png")
plt.savefig(boxplot_path)
# 关闭所有图形
plt.close('all')
print(f"可视化图表已保存到:\n 条形图: {bar_chart_path}\n 箱线图: {boxplot_path}")
print(f" 各算法热力图已保存到输出目录")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description='简化版图像相似度检测工具')
parser.add_argument('--input_dir', '-i', type=str, default='./scripts/image_test/test_images',
help='输入图像目录路径')
parser.add_argument('--output_dir', '-o', type=str, default=None,
help='输出结果目录路径,默认与输入目录相同')
args = parser.parse_args()
print("===== 简化版图像相似度检测工具 =====")
print(f"输入目录: {args.input_dir}")
print(f"输出目录: {args.output_dir or args.input_dir}")
checker = SimpleImageDupeChecker()
checker.check_images(args.input_dir, args.output_dir)
print("\n===== 检测完成 =====")
if __name__ == '__main__':
main()