Jainesh212 commited on
Commit
5f65b55
1 Parent(s): 6c2d4f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import os
5
+ from sklearn.neighbors import NearestNeighbors
6
+ import numpy as np
7
+
8
+
9
+ # Load pre-trained ResNet-50 model
10
+ model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', weights=None)
11
+ model.eval()
12
+
13
+ # Define image transformation
14
+ transform = transforms.Compose([
15
+ transforms.Resize((256, 256)),
16
+ transforms.CenterCrop((224, 224)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ # Directory containing images
22
+ images_dir = "picture/"
23
+
24
+ # List all image files in directory
25
+ image_files = [f for f in os.listdir(images_dir) if f.endswith('.jpg') or f.endswith('.png')]
26
+
27
+ if not image_files:
28
+ print("No images found in directory")
29
+ else:
30
+ # Dictionary to store feature vectors
31
+ feature_dict = {}
32
+
33
+ # Loop through images in the directory
34
+ for filename in image_files:
35
+ # Load image
36
+ image_path = os.path.join(images_dir, filename)
37
+ with Image.open(image_path) as img:
38
+ img = transform(img).unsqueeze(0)
39
+
40
+ # Extract features from penultimate layer
41
+ with torch.no_grad():
42
+ features = model(img)
43
+ features = torch.squeeze(features).detach().numpy()
44
+
45
+ feature_dict[filename] = features
46
+
47
+ # Convert dictionary of feature vectors to array
48
+ feature_array = np.array(list(feature_dict.values()))
49
+
50
+ if len(feature_array) == 0:
51
+ print("No feature vectors extracted")
52
+ else:
53
+ # Fit nearest neighbor model on feature vectors
54
+ nbrs = NearestNeighbors(n_neighbors=10, algorithm='auto').fit(feature_array)
55
+
56
+ # Loop through images again to query nearest neighbors
57
+ for query_filename in image_files:
58
+ query_feature = feature_dict[query_filename]
59
+ distances, indices = nbrs.kneighbors(query_feature.reshape(1, -1))
60
+
61
+ print("Query image:", query_filename)
62
+ print("Most similar images:")
63
+ for i, idx in enumerate(indices[0]):
64
+ if i == 0:
65
+ continue # Skip first index, as it will always be the query image itself
66
+ print(image_files[idx])
67
+ print("-----")