Spaces:
Runtime error

Jainesh212 committed
Commit • 5f65b55
1 Parent(s): 6c2d4f9
Create app.py
app.py
ADDED
@@ -0,0 +1,67 @@
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+import os
+from sklearn.neighbors import NearestNeighbors
+import numpy as np
+
+# Load pre-trained ResNet-50 (the v0.6.0 hub entry point takes `pretrained`, not `weights`)
+model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
+model.fc = torch.nn.Identity()  # drop the classifier head so forward() returns the 2048-d pooled features
+model.eval()
+
+# Define image transformation
+transform = transforms.Compose([
+    transforms.Resize((256, 256)),
+    transforms.CenterCrop((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Directory containing images
+images_dir = "picture/"
+
+# List all image files in the directory
+image_files = [f for f in os.listdir(images_dir) if f.endswith('.jpg') or f.endswith('.png')]
+
+if not image_files:
+    print("No images found in directory")
+else:
+    # Dictionary to store feature vectors
+    feature_dict = {}
+
+    # Loop through images in the directory
+    for filename in image_files:
+        # Load the image and force RGB so grayscale/RGBA files also work
+        image_path = os.path.join(images_dir, filename)
+        with Image.open(image_path) as img:
+            img = transform(img.convert('RGB')).unsqueeze(0)
+
+        # Extract features from the penultimate (pooled) layer
+        with torch.no_grad():
+            features = model(img)
+            features = torch.squeeze(features).detach().numpy()
+
+        feature_dict[filename] = features
+
+    # Convert dictionary of feature vectors to an array
+    feature_array = np.array(list(feature_dict.values()))
+
+    if len(feature_array) == 0:
+        print("No feature vectors extracted")
+    else:
+        # Fit nearest neighbor model (cap n_neighbors at the number of images)
+        nbrs = NearestNeighbors(n_neighbors=min(10, len(feature_array)), algorithm='auto').fit(feature_array)
+
+        # Loop through images again to query nearest neighbors
+        for query_filename in image_files:
+            query_feature = feature_dict[query_filename]
+            distances, indices = nbrs.kneighbors(query_feature.reshape(1, -1))
+
+            print("Query image:", query_filename)
+            print("Most similar images:")
+            for i, idx in enumerate(indices[0]):
+                if i == 0:
+                    continue  # Skip the first index: it is the query image itself
+                print(image_files[idx])
+            print("-----")
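
The file above runs as a plain batch script that only prints to the console, while a Space's app.py is normally expected to launch a web UI (for example via the Gradio SDK). The sketch below shows one possible way the fitted index could be exposed interactively; it is not part of the commit. The gradio dependency, the find_similar helper, and the gr.Image/gr.JSON components are assumptions of this sketch (component names vary across Gradio versions), and it presumes the script above has already run with a non-empty picture/ folder, so model, transform, nbrs and image_files exist at module level.

import gradio as gr  # assumed dependency, not declared by this commit

def find_similar(query_img):
    # Gradio passes a PIL.Image when the input component uses type="pil"
    x = transform(query_img.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        feat = torch.squeeze(model(x)).numpy()
    # Query the fitted NearestNeighbors index and return the matching filenames
    _, indices = nbrs.kneighbors(feat.reshape(1, -1))
    return [image_files[i] for i in indices[0]]

demo = gr.Interface(
    fn=find_similar,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
    title="ResNet-50 nearest-neighbor image search",
)
demo.launch()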