# Streamlit / app.py
# Jainesh212's picture
# Create app.py
# 5f65b55
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
# Load an ImageNet-pretrained ResNet-50 to use as a fixed feature extractor.
# The hub entry point at the v0.6.0 tag takes `pretrained`, not `weights`:
# `weights=None` is forwarded to the ResNet constructor and raises TypeError
# (and would have left the network randomly initialised in any case).
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
# Replace the 1000-way classification head with an identity so that calling
# the model returns the pooled penultimate-layer activations instead of
# class logits.
model.fc = torch.nn.Identity()
# Inference mode: freezes batch-norm statistics so outputs are deterministic.
model.eval()
# Preprocessing applied to every image before it is passed to the model.
# Per-channel statistics used by the normalisation step.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    # Resize to exactly 256x256 (a (h, w) tuple does not keep aspect ratio).
    transforms.Resize((256, 256)),
    # Keep the central 224x224 window.
    transforms.CenterCrop((224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Standardise each channel with the statistics above.
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Directory containing images
images_dir = "picture/"
# Collect the image files in the directory. Extensions are matched
# case-insensitively so files such as "IMG_001.JPG" are not silently ignored,
# and the listing is sorted because os.listdir() returns entries in
# arbitrary order (sorting makes the printed results reproducible).
image_files = sorted(
    f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.png'))
)
if not image_files:
    print("No images found in directory")
else:
    # Maps each filename to the 1-D NumPy feature vector computed for it.
    feature_dict = {}
    for filename in image_files:
        image_path = os.path.join(images_dir, filename)
        with Image.open(image_path) as img:
            # Force three channels: grayscale or RGBA files would otherwise
            # fail in the 3-channel normalisation step of `transform`.
            # unsqueeze(0) adds the batch dimension the model expects.
            batch = transform(img.convert("RGB")).unsqueeze(0)
        # Forward pass without gradient tracking; the model's output vector
        # is used as the descriptor of the image.
        with torch.no_grad():
            features = model(batch)
        feature_dict[filename] = torch.squeeze(features).numpy()

    # Row i of feature_array belongs to filenames[i] (dicts keep insertion
    # order, so keys and values line up).
    filenames = list(feature_dict)
    feature_array = np.array(list(feature_dict.values()))
    if len(feature_array) == 0:
        print("No feature vectors extracted")
    else:
        # kneighbors() raises ValueError when asked for more neighbours than
        # there are indexed samples, so cap the count for small directories.
        n_neighbors = min(10, len(feature_array))
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='auto').fit(feature_array)
        # Query the index once per image and report its closest matches.
        for query_idx, query_filename in enumerate(filenames):
            query_feature = feature_dict[query_filename]
            _, indices = nbrs.kneighbors(query_feature.reshape(1, -1))
            print("Query image:", query_filename)
            print("Most similar images:")
            # Drop the query by index rather than assuming it is the first
            # hit: with duplicate images (identical vectors) the tie can be
            # returned in any order.
            similar = [filenames[idx] for idx in indices[0] if idx != query_idx]
            for similar_filename in similar[:n_neighbors - 1]:
                print(similar_filename)
            print("-----")