#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
简化版图像相似度检测器

直接计算指定目录中图片两两之间的相似度，以表格（CSV）形式输出结果，并生成可视化图表。

Simplified image-similarity checker: computes the pairwise similarity of all
images in a directory and writes the results as CSV tables plus charts.
"""
import argparse
import hashlib
import itertools
import os
import sys
import time
from datetime import datetime

import cv2
import imagehash
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from skimage.metrics import structural_similarity as ssim


class SimpleImageDupeChecker:
    """Simplified image-similarity checker.

    Compares every pair of images in a directory with several algorithms
    (file MD5, perceptual / average / difference hashes, HSV colour
    histogram, SIFT feature matching and SSIM). Each comparison yields a
    similarity score clamped to [0, 1].
    """

    # Lower-case file extensions treated as images when scanning a directory.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.gif')

    def __init__(self):
        """Register the comparison algorithms and build the SIFT / FLANN objects."""
        # Algorithm name -> bound comparison method taking two image paths.
        self.algorithms = {
            'md5': self.compare_md5,
            'phash': self.compare_phash,
            'ahash': self.compare_ahash,
            'dhash': self.compare_dhash,
            'color_hist': self.compare_color_histogram,
            'sift': self.compare_sift,
            'ssim': self.compare_ssim,
        }

        # SIFT detector, reused for every comparison.
        self.sift = cv2.SIFT_create()

        # FLANN matcher with a KD-tree index (suited to SIFT's float descriptors).
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_cv2_image(image_path, flags):
        """Load an image with OpenCV, raising ValueError instead of returning None.

        cv2.imread reports failure (missing file, unsupported format such as
        GIF on many builds, ...) by returning None. Without this check the
        failure surfaces later as an obscure error inside cvtColor / SIFT /
        resize.
        """
        # NOTE(review): cv2.imread is known to fail on non-ASCII paths on
        # Windows; if that matters, decode via np.fromfile + cv2.imdecode.
        img = cv2.imread(image_path, flags)
        if img is None:
            raise ValueError(f"无法读取图像: {image_path}")
        return img

    @staticmethod
    def _hash_similarity(hash1, hash2):
        """Map the Hamming distance of two image hashes to a similarity in [0, 1].

        The distance is normalised by the actual hash bit length, so the
        score stays correct for any ``hash_size``. The default 8x8 hashes
        have 64 bits.
        """
        bit_count = hash1.hash.size
        hash_diff = hash1 - hash2  # number of differing bits
        return max(0.0, 1.0 - hash_diff / bit_count)

    # ------------------------------------------------------------------
    # MD5 (byte-exact duplicates)
    # ------------------------------------------------------------------

    def get_md5(self, image_path):
        """Return the hex MD5 digest of the raw bytes of the file at ``image_path``."""
        digest = hashlib.md5()
        with open(image_path, 'rb') as f:
            # Stream in 1 MiB chunks so large files are not read into memory at once.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def compare_md5(self, image1_path, image2_path):
        """Return 1.0 if both files have identical bytes (same MD5), else 0.0."""
        hash1 = self.get_md5(image1_path)
        hash2 = self.get_md5(image2_path)
        return 1.0 if hash1 == hash2 else 0.0

    # ------------------------------------------------------------------
    # Perceptual / average / difference hashes
    # ------------------------------------------------------------------

    def get_phash(self, image_path, hash_size=8):
        """Compute the perceptual hash (pHash) of an image."""
        # `with` closes the underlying file handle; convert() returns an
        # independent in-memory image, so it stays usable after the block.
        with Image.open(image_path) as src:
            img = src.convert('L').resize((hash_size * 4, hash_size * 4), Image.LANCZOS)
        return imagehash.phash(img, hash_size=hash_size)

    def compare_phash(self, image1_path, image2_path):
        """Return the pHash similarity of two images in [0, 1]."""
        return self._hash_similarity(self.get_phash(image1_path), self.get_phash(image2_path))

    def get_ahash(self, image_path, hash_size=8):
        """Compute the average hash (aHash) of an image."""
        with Image.open(image_path) as src:
            img = src.convert('L').resize((hash_size, hash_size), Image.LANCZOS)
        return imagehash.average_hash(img, hash_size=hash_size)

    def compare_ahash(self, image1_path, image2_path):
        """Return the aHash similarity of two images in [0, 1]."""
        return self._hash_similarity(self.get_ahash(image1_path), self.get_ahash(image2_path))

    def get_dhash(self, image_path, hash_size=8):
        """Compute the difference hash (dHash) of an image."""
        with Image.open(image_path) as src:
            # One extra column: dHash compares horizontally adjacent pixels.
            img = src.convert('L').resize((hash_size + 1, hash_size), Image.LANCZOS)
        return imagehash.dhash(img, hash_size=hash_size)

    def compare_dhash(self, image1_path, image2_path):
        """Return the dHash similarity of two images in [0, 1]."""
        return self._hash_similarity(self.get_dhash(image1_path), self.get_dhash(image2_path))

    # ------------------------------------------------------------------
    # Colour histogram
    # ------------------------------------------------------------------

    def get_color_histogram(self, image_path):
        """Return a flattened, min-max normalised 8x8x8 HSV colour histogram."""
        img = self._read_cv2_image(image_path, cv2.IMREAD_COLOR)
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        # For 8-bit images OpenCV stores hue in [0, 180) and S/V in [0, 256).
        hist = cv2.calcHist([hsv], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256])
        cv2.normalize(hist, hist, 0, 1.0, cv2.NORM_MINMAX)
        return hist.flatten()

    def compare_color_histogram(self, image1_path, image2_path):
        """Return the histogram correlation of two images, clamped to [0, 1]."""
        hist1 = self.get_color_histogram(image1_path)
        hist2 = self.get_color_histogram(image2_path)
        similarity = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
        # Correlation lies in [-1, 1]; anti-correlated histograms count as 0.
        return max(0.0, similarity)

    # ------------------------------------------------------------------
    # SIFT feature matching
    # ------------------------------------------------------------------

    def get_sift_features(self, image_path):
        """Return ``(keypoints, descriptors)`` of the greyscale image; descriptors may be None."""
        img = self._read_cv2_image(image_path, cv2.IMREAD_GRAYSCALE)
        keypoints, descriptors = self.sift.detectAndCompute(img, None)
        return keypoints, descriptors

    def compare_sift(self, image1_path, image2_path):
        """Return the fraction of keypoints with a good SIFT match, capped at 1.0.

        Good matches pass Lowe's ratio test (best distance < 0.7 * second
        best). The count is divided by the larger keypoint count.
        """
        kp1, des1 = self.get_sift_features(image1_path)
        kp2, des2 = self.get_sift_features(image2_path)

        # Too few features to run a 2-NN ratio test: treat as dissimilar.
        if des1 is None or des2 is None or len(des1) < 2 or len(des2) < 2:
            return 0.0

        matches = self.flann.knnMatch(des1, des2, k=2)

        # knnMatch may return fewer than two neighbours for some queries;
        # unpacking those as `m, n` would raise, so skip such entries.
        good_matches = [
            pair[0]
            for pair in matches
            if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
        ]

        max_keypoints = max(len(kp1), len(kp2))
        if max_keypoints == 0:
            return 0.0
        return min(1.0, len(good_matches) / max_keypoints)

    # ------------------------------------------------------------------
    # SSIM
    # ------------------------------------------------------------------

    def compare_ssim(self, image1_path, image2_path):
        """Return the structural similarity (SSIM) of two greyscale images, clamped at 0."""
        img1 = self._read_cv2_image(image1_path, cv2.IMREAD_GRAYSCALE)
        img2 = self._read_cv2_image(image2_path, cv2.IMREAD_GRAYSCALE)

        # SSIM needs equal shapes: resize both to the smaller height/width.
        h = min(img1.shape[0], img2.shape[0])
        w = min(img1.shape[1], img2.shape[1])
        img1 = cv2.resize(img1, (w, h))
        img2 = cv2.resize(img2, (w, h))

        similarity = ssim(img1, img2)
        return max(0.0, similarity)

    # ------------------------------------------------------------------
    # Driver
    # ------------------------------------------------------------------

    def _collect_image_files(self, input_dir):
        """Return the sorted paths of regular files in ``input_dir`` with an image extension."""
        image_files = []
        # Sorted so results and output order are reproducible across platforms.
        for name in sorted(os.listdir(input_dir)):
            path = os.path.join(input_dir, name)
            if os.path.isfile(path) and name.lower().endswith(self.IMAGE_EXTENSIONS):
                image_files.append(path)
        return image_files

    def _compare_pair(self, img1, img2, img1_name, img2_name):
        """Run every registered algorithm on one image pair; return one result dict each."""
        rows = []
        for alg_name, compare_func in self.algorithms.items():
            try:
                start_time = time.perf_counter()
                similarity = compare_func(img1, img2)
                compute_time = time.perf_counter() - start_time
            except Exception as e:
                # Deliberate best-effort: one failing algorithm (e.g. an
                # unreadable format) is recorded and must not abort the run.
                print(f" 算法 {alg_name} 比较出错: {e}")
                rows.append({
                    'image1': img1_name,
                    'image2': img2_name,
                    'algorithm': alg_name,
                    'similarity': None,
                    'compute_time': None,
                    'error': str(e),
                })
                continue

            rows.append({
                'image1': img1_name,
                'image2': img2_name,
                'algorithm': alg_name,
                'similarity': similarity,
                'compute_time': compute_time,
            })
        return rows

    def check_images(self, input_dir, output_dir=None):
        """Compare every pair of images in ``input_dir`` with all algorithms.

        Writes a long-format CSV, a pivot CSV (one column per algorithm) and
        several charts to ``output_dir``, and prints a summary.

        Args:
            input_dir: directory containing the images (not searched recursively).
            output_dir: directory for the result files; defaults to ``input_dir``.

        Returns:
            A DataFrame with one row per (image pair, algorithm) and columns
            image1, image2, algorithm, similarity, compute_time (plus error
            when any comparison failed). Returns None when fewer than two
            images are found.
        """
        if output_dir is None:
            output_dir = input_dir
        os.makedirs(output_dir, exist_ok=True)

        print("正在读取图像文件...")
        image_files = self._collect_image_files(input_dir)

        if len(image_files) == 0:
            print(f"错误: 在目录 '{input_dir}' 中没有找到图像文件")
            return None
        if len(image_files) < 2:
            # One image has nothing to compare against; the empty result
            # table would otherwise crash pivot_table below.
            print(f"错误: 在目录 '{input_dir}' 中只找到 1 个图像文件,至少需要 2 个才能比较相似度")
            return None

        print(f"找到 {len(image_files)} 个图像文件,开始计算相似度...")

        results = []
        total_pairs = len(image_files) * (len(image_files) - 1) // 2
        # combinations yields each unordered pair once, in the same (i < j)
        # order as a nested loop over the sorted file list.
        pairs = itertools.combinations(image_files, 2)
        for pair_count, (img1, img2) in enumerate(pairs, start=1):
            img1_name = os.path.basename(img1)
            img2_name = os.path.basename(img2)
            print(f"比较图像对 {pair_count}/{total_pairs}: {img1_name} vs {img2_name}")
            results.extend(self._compare_pair(img1, img2, img1_name, img2_name))

        df = pd.DataFrame(results)
        # Failed comparisons store None; force a numeric column (None -> NaN).
        df['similarity'] = pd.to_numeric(df['similarity'], errors='coerce')

        # One timestamp per run, shared by every output file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_path = os.path.join(output_dir, f"similarity_results_{timestamp}.csv")
        df.to_csv(csv_path, index=False)
        print(f"结果已保存至: {csv_path}")

        # Wide table: one row per image pair, one column per algorithm.
        pivot_df = df.pivot_table(
            index=['image1', 'image2'],
            columns='algorithm',
            values='similarity',
            aggfunc='first',
        ).reset_index()
        pivot_path = os.path.join(output_dir, f"similarity_pivot_{timestamp}.csv")
        pivot_df.to_csv(pivot_path, index=False)
        print(f"透视表已保存至: {pivot_path}")

        self.print_summary(df)
        self.visualize_results(df, output_dir, timestamp=timestamp)

        return df

    def print_summary(self, df):
        """Print per-algorithm statistics and the most / least similar pair per algorithm.

        Rows with a missing similarity (failed comparisons) are ignored. An
        algorithm that failed on every pair is shown with NaN statistics and
        left out of the best/worst listings.
        """
        print("\n===== 图像相似度检测结果摘要 =====")

        scores = df.assign(similarity=pd.to_numeric(df['similarity'], errors='coerce'))

        alg_summary = (
            scores.groupby('algorithm')['similarity']
            .agg(['mean', 'min', 'max', 'std'])
            .reset_index()
            .sort_values('mean', ascending=False)
        )

        print("\n各算法平均相似度:")
        for _, row in alg_summary.iterrows():
            print(f" {row['algorithm']:<12}: 平均值 = {row['mean']:.4f} (最小值: {row['min']:.4f}, 最大值: {row['max']:.4f}, 标准差: {row['std']:.4f})")

        # idxmax/idxmin on an all-NaN group would not yield a usable label,
        # so only rows with a real score are considered.
        valid = scores.dropna(subset=['similarity'])
        algorithms = scores['algorithm'].unique()

        print("\n相似度最高的图像对:")
        for alg in algorithms:
            alg_df = valid[valid['algorithm'] == alg]
            if not alg_df.empty:
                max_row = alg_df.loc[alg_df['similarity'].idxmax()]
                print(f" {alg:<12}: {max_row['image1']} 和 {max_row['image2']} (相似度: {max_row['similarity']:.4f})")

        print("\n相似度最低的图像对:")
        for alg in algorithms:
            alg_df = valid[valid['algorithm'] == alg]
            if not alg_df.empty:
                min_row = alg_df.loc[alg_df['similarity'].idxmin()]
                print(f" {alg:<12}: {min_row['image1']} 和 {min_row['image2']} (相似度: {min_row['similarity']:.4f})")

    def visualize_results(self, df, output_dir, timestamp=None):
        """Save a bar chart, one heatmap per algorithm and a box plot as PNG files.

        Args:
            df: long-format results as produced by ``check_images`` (columns
                image1, image2, algorithm, similarity).
            output_dir: directory the PNG files are written to.
            timestamp: file-name suffix. Defaults to the current time;
                ``check_images`` passes its own so all files of one run match.
        """
        print("\n生成结果可视化...")
        if timestamp is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # NOTE(review): the chart labels are Chinese. Matplotlib's default font
        # may lack CJK glyphs (rendered as boxes); configure
        # plt.rcParams['font.sans-serif'] if that happens.
        scores = df.assign(similarity=pd.to_numeric(df['similarity'], errors='coerce'))

        # 1. Bar chart: mean similarity per algorithm.
        fig = plt.figure(figsize=(10, 6))
        avg_similarity = (
            scores.groupby('algorithm')['similarity'].mean()
            .reset_index()
            .sort_values('similarity', ascending=False)
        )
        plt.bar(avg_similarity['algorithm'], avg_similarity['similarity'], color='skyblue')
        plt.xlabel('算法')
        plt.ylabel('平均相似度')
        plt.title('各算法的平均相似度')
        plt.xticks(rotation=45)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        bar_chart_path = os.path.join(output_dir, f"algorithm_avg_similarity_{timestamp}.png")
        plt.savefig(bar_chart_path)
        plt.close(fig)

        # 2. One symmetric similarity-matrix heatmap per algorithm.
        image_names = sorted(set(scores['image1']) | set(scores['image2']))
        name_to_index = {name: k for k, name in enumerate(image_names)}
        n_images = len(image_names)

        for alg in scores['algorithm'].unique():
            fig = plt.figure(figsize=(10, 8))

            heatmap_data = np.zeros((n_images, n_images))
            alg_rows = scores[scores['algorithm'] == alg]
            for name1, name2, value in zip(alg_rows['image1'], alg_rows['image2'], alg_rows['similarity']):
                i, j = name_to_index[name1], name_to_index[name2]
                heatmap_data[i, j] = value
                heatmap_data[j, i] = value  # the matrix is symmetric
            # Every image is identical to itself.
            np.fill_diagonal(heatmap_data, 1.0)

            plt.imshow(heatmap_data, cmap='viridis')
            plt.colorbar(label='相似度')
            plt.title(f'{alg} 算法的图像相似度热力图')
            plt.xticks(np.arange(n_images), image_names, rotation=90)
            plt.yticks(np.arange(n_images), image_names)

            # Annotate each cell; white text on dark (low) cells for contrast.
            for i in range(n_images):
                for j in range(n_images):
                    value = heatmap_data[i, j]
                    plt.text(j, i, f"{value:.2f}",
                             ha="center", va="center",
                             color="w" if value < 0.7 else "black")

            plt.tight_layout()
            heatmap_path = os.path.join(output_dir, f"similarity_heatmap_{alg}_{timestamp}.png")
            plt.savefig(heatmap_path)
            plt.close(fig)

        # 3. Box plot: distribution of similarity per algorithm.
        fig = plt.figure(figsize=(10, 6))
        data = []
        labels = []
        for alg in sorted(scores['algorithm'].unique()):
            # NaNs (failed comparisons) would produce empty boxes.
            alg_data = scores.loc[scores['algorithm'] == alg, 'similarity'].dropna()
            if not alg_data.empty:
                data.append(alg_data.to_numpy())
                labels.append(alg)

        if data:
            plt.boxplot(data)
            # Tick labels are set separately because boxplot's `labels` kwarg
            # was renamed to `tick_labels` in Matplotlib 3.9.
            plt.xticks(range(1, len(labels) + 1), labels)
        plt.ylabel('相似度')
        plt.title('各算法的相似度分布')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        boxplot_path = os.path.join(output_dir, f"algorithm_boxplot_{timestamp}.png")
        plt.savefig(boxplot_path)
        plt.close(fig)

        # Safety net in case anything above left a figure open.
        plt.close('all')

        print(f"可视化图表已保存到:\n 条形图: {bar_chart_path}\n 箱线图: {boxplot_path}")
        print(" 各算法热力图已保存到输出目录")
def main():
    """Command-line entry point: parse the options and run the similarity check.

    Options:
        --input_dir / -i   directory holding the images to compare
                           (default: ./scripts/image_test/test_images)
        --output_dir / -o  directory for the result files
                           (default: same as the input directory)
    """
    cli = argparse.ArgumentParser(description='简化版图像相似度检测工具')
    cli.add_argument(
        '--input_dir', '-i',
        type=str,
        default='./scripts/image_test/test_images',
        help='输入图像目录路径',
    )
    cli.add_argument(
        '--output_dir', '-o',
        type=str,
        default=None,
        help='输出结果目录路径,默认与输入目录相同',
    )
    options = cli.parse_args()

    input_dir = options.input_dir
    output_dir = options.output_dir
    # Mirror check_images' fallback so the banner shows the effective directory.
    shown_output = output_dir if output_dir else input_dir

    for banner_line in (
        "===== 简化版图像相似度检测工具 =====",
        f"输入目录: {input_dir}",
        f"输出目录: {shown_output}",
    ):
        print(banner_line)

    SimpleImageDupeChecker().check_images(input_dir, output_dir)

    print("\n===== 检测完成 =====")
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()