# Hugging Face Spaces status: Runtime error
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import models
# Image-similarity search: embed every image in a directory with a pretrained
# ResNet-50 and print, for each image, its nearest neighbours in feature space.

# File extensions (compared case-insensitively) that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Neighbours requested per query; the query image itself is one of them.
N_NEIGHBORS = 10


def is_image_file(filename, extensions=IMAGE_EXTENSIONS):
    """Return True if *filename* ends with one of *extensions*, ignoring case."""
    return filename.lower().endswith(tuple(ext.lower() for ext in extensions))


def list_image_files(directory, extensions=IMAGE_EXTENSIONS):
    """Return the sorted names of image files directly inside *directory*.

    Sorting makes the output reproducible (``os.listdir`` order is arbitrary).
    Sub-directories whose names merely look like images are skipped.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if is_image_file(name, extensions)
        and os.path.isfile(os.path.join(directory, name))
    )


def build_model(pretrained=True):
    """Return a ResNet-50 in eval mode that outputs penultimate-layer features.

    The classifier head ``fc`` is replaced by an identity, so a forward pass
    returns the pooled feature vector instead of the 1000 ImageNet class
    logits. With ``pretrained=False`` the weights are random (untrained),
    which makes the similarity results meaningless.
    Uses the torchvision ``weights=`` model-builder API (torchvision >= 0.13).
    """
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    net = models.resnet50(weights=weights)
    net.fc = torch.nn.Identity()
    net.eval()
    return net


def build_transform():
    """Return the preprocessing pipeline: resize, center-crop, to-tensor, ImageNet normalisation."""
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def extract_features(model, transform, image_path):
    """Return the feature vector of the image at *image_path* as a NumPy array."""
    with Image.open(image_path) as img:
        # Grayscale, palette and RGBA images do not have exactly three
        # channels, which the 3-channel Normalize and the model require.
        batch = transform(img.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        features = model(batch)
    # Drop only the batch dimension added by unsqueeze(0) above.
    return features.squeeze(0).cpu().numpy()


def neighbor_filenames(indices, query_index, filenames):
    """Map neighbour *indices* to file names, excluding the query image.

    The query is dropped by index rather than by assuming it is the first
    result: exact duplicates tie at distance 0, so the query may appear later
    in *indices* (or not at all). At most ``len(indices) - 1`` names are
    returned, matching the number of neighbours shown per query.
    """
    others = [filenames[idx] for idx in indices if idx != query_index]
    return others[:max(len(indices) - 1, 0)]


def main(images_dir="picture/", n_neighbors=N_NEIGHBORS):
    """Print, for each image in *images_dir*, its most similar other images."""
    image_files = list_image_files(images_dir)
    if not image_files:
        print("No images found in directory")
        return

    model = build_model()
    transform = build_transform()

    # Names and vectors are kept in lockstep so neighbour indices map back to
    # the right file even when an unreadable image is skipped.
    names, vectors = [], []
    for filename in image_files:
        try:
            vector = extract_features(model, transform, os.path.join(images_dir, filename))
        except OSError as err:  # Pillow raises OSError subclasses for unreadable files
            print(f"Skipping {filename}: {err}")
        else:
            names.append(filename)
            vectors.append(vector)

    if not vectors:
        print("No feature vectors extracted")
        return

    feature_array = np.stack(vectors)
    # NearestNeighbors rejects n_neighbors larger than the number of samples.
    k = min(n_neighbors, len(names))
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(feature_array)
    # One batched query. Passing X explicitly keeps each point as a candidate
    # neighbour of itself; neighbor_filenames removes it afterwards.
    _, all_indices = nbrs.kneighbors(feature_array)

    for query_index, (query_filename, indices) in enumerate(zip(names, all_indices)):
        print("Query image:", query_filename)
        print("Most similar images:")
        for filename in neighbor_filenames(indices, query_index, names):
            print(filename)
        print("-----")


if __name__ == "__main__":
    main()