cabustillo13 commited on
Commit
5c3647f
1 Parent(s): e43c5ed

Upload 3 files

Browse files
tools/extract_features.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision import models, transforms
4
+
5
+
6
+ # https://towardsdatascience.com/image-feature-extraction-using-pytorch-e3b327c3607a
7
+ class FeatureExtractor(nn.Module):
8
+ def __init__(self, model):
9
+ super(FeatureExtractor, self).__init__()
10
+ # Extract VGG-16 Feature Layers
11
+ self.features = list(model.features)
12
+ self.features = nn.Sequential(*self.features)
13
+ # Extract VGG-16 Average Pooling Layer
14
+ self.pooling = model.avgpool
15
+ # Convert the image into one-dimensional vector
16
+ self.flatten = nn.Flatten()
17
+ # Extract the first part of fully-connected layer from VGG16
18
+ self.fc = model.classifier[0]
19
+
20
+ def forward(self, x):
21
+ # It will take the input 'x' until it returns the feature vector called 'out'
22
+ out = self.features(x)
23
+ out = self.pooling(out)
24
+ out = self.flatten(out)
25
+ out = self.fc(out)
26
+ return out
27
+
28
+ # Initialize the model
29
+ """
30
+ https://download.pytorch.org/models/vgg16-397923af.pth
31
+ """
32
+ MODEL = models.vgg16(pretrained=True)
33
+ NEW_MODEL = FeatureExtractor(MODEL)
34
+
35
+ # Transform the image, so it becomes readable with the model
36
+ """
37
+ Without center crop isn't working!
38
+ -> Así que se lo puse pero al final no me afecta en nada a la imagen mientras el resize sea el mismo.
39
+ https://stackoverflow.com/questions/69334048/understanding-transforms-resize-and-centercrop-with-same-size
40
+ https://discuss.pytorch.org/t/transforms-resize-vs-centercrop/86588/2
41
+ """
42
+ TRANSFORMS = transforms.Compose([
43
+ transforms.ToPILImage(),
44
+ transforms.Resize(448),
45
+ transforms.CenterCrop(448),
46
+ transforms.ToTensor()
47
+ ])
48
+
49
+
50
+ def extract_features_image(img):
51
+ # Transform the image
52
+ img = TRANSFORMS(img)
53
+ # Reshape the image. PyTorch model reads 4-dimensional tensor
54
+ # [batch_size, channels, width, height]
55
+ img = img.reshape(1, 3, 448, 448)
56
+
57
+ # We only extract features, so we don't need gradient
58
+ with torch.no_grad():
59
+ # Extract the feature from the image
60
+ feature = NEW_MODEL(img)
61
+
62
+ return feature.tolist()[0]
63
+
tools/load_database.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ # Default values
4
+ filename = "./database/dataframes/all_database.csv"
5
+ DATABASE = pd.read_csv(filename)
6
+
7
+
8
+ def select_database(marca, prenda):
9
+ filtered_df = DATABASE[DATABASE['tag'] == prenda]
10
+
11
+ if marca != "Ninguno":
12
+ filtered_df = filtered_df[filtered_df['marca'] == marca]
13
+
14
+ # Restar index y eliminar el ex index column
15
+ filtered_df.reset_index(inplace=True, drop=True)
16
+
17
+ return filtered_df
tools/search.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial import cKDTree
3
+ import ast
4
+
5
+
6
+ def search_similar_products(vgg_search, df, number_of_items=5):
7
+ """Find similar products"""
8
+
9
+ try:
10
+ # Similar features
11
+ indexes = df.index
12
+ indexes_list = indexes.tolist()
13
+ vgg_vector_list = df["vgg_vector"].tolist()
14
+ vgg_vector_list_interested = [ast.literal_eval(vgg_vector_list[i]) for i in indexes_list]
15
+
16
+ # Searching
17
+ matrix = np.array(vgg_vector_list_interested)
18
+
19
+ # [0]: distances
20
+ # [1]: indexes
21
+ results_indexes = cKDTree(matrix).query(vgg_search, k=number_of_items)[1]
22
+
23
+ # Make output
24
+ images_urls_list = []
25
+ for index in results_indexes:
26
+ image_url = df.iloc[index]['image_urls']
27
+ images_urls_list.append(image_url)
28
+
29
+ result = {
30
+ "success": True,
31
+ "result": images_urls_list,
32
+ }
33
+
34
+ except:
35
+ result = {
36
+ "success": False,
37
+ "result": [],
38
+ }
39
+
40
+ return result