Spaces:
Runtime error
Runtime error
cabustillo13
commited on
Commit
•
5c3647f
1
Parent(s):
e43c5ed
Upload 3 files
Browse files- tools/extract_features.py +63 -0
- tools/load_database.py +17 -0
- tools/search.py +40 -0
tools/extract_features.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torchvision import models, transforms
|
4 |
+
|
5 |
+
|
6 |
+
# https://towardsdatascience.com/image-feature-extraction-using-pytorch-e3b327c3607a
|
7 |
+
class FeatureExtractor(nn.Module):
|
8 |
+
def __init__(self, model):
|
9 |
+
super(FeatureExtractor, self).__init__()
|
10 |
+
# Extract VGG-16 Feature Layers
|
11 |
+
self.features = list(model.features)
|
12 |
+
self.features = nn.Sequential(*self.features)
|
13 |
+
# Extract VGG-16 Average Pooling Layer
|
14 |
+
self.pooling = model.avgpool
|
15 |
+
# Convert the image into one-dimensional vector
|
16 |
+
self.flatten = nn.Flatten()
|
17 |
+
# Extract the first part of fully-connected layer from VGG16
|
18 |
+
self.fc = model.classifier[0]
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
# It will take the input 'x' until it returns the feature vector called 'out'
|
22 |
+
out = self.features(x)
|
23 |
+
out = self.pooling(out)
|
24 |
+
out = self.flatten(out)
|
25 |
+
out = self.fc(out)
|
26 |
+
return out
|
27 |
+
|
28 |
+
# Initialize the pretrained VGG-16 model.
# Weights downloaded from: https://download.pytorch.org/models/vgg16-397923af.pth
# NOTE(review): `pretrained=True` is deprecated in recent torchvision; the
# modern equivalent is models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).
# Kept as-is to preserve compatibility with the torchvision version in use.
# NOTE(review): this download runs at import time — importing this module
# requires network access (or a warm cache) on first run.
MODEL = models.vgg16(pretrained=True)
NEW_MODEL = FeatureExtractor(MODEL)

# Transform the image so it becomes readable by the model.
# Without CenterCrop this wasn't working; since the crop size equals the
# Resize size, the crop does not actually change the image content.
# https://stackoverflow.com/questions/69334048/understanding-transforms-resize-and-centercrop-with-same-size
# https://discuss.pytorch.org/t/transforms-resize-vs-centercrop/86588/2
TRANSFORMS = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(448),
    transforms.CenterCrop(448),
    transforms.ToTensor(),
])
|
48 |
+
|
49 |
+
|
50 |
+
def extract_features_image(img):
    """Extract a VGG-16 feature vector from a single image.

    Parameters
    ----------
    img : image accepted by transforms.ToPILImage (e.g. an H x W x C array).

    Returns
    -------
    list of float: the feature vector produced by the first fully-connected
    layer of the module-level NEW_MODEL.
    """
    # Transform the image into a 3 x 448 x 448 tensor.
    img = TRANSFORMS(img)
    # PyTorch models read a 4-D batch tensor: [batch_size, channels, height, width].
    # unsqueeze(0) adds the batch dimension without hard-coding the spatial size
    # (the original reshape(1, 3, 448, 448) would break if the transform sizes
    # change, or silently scramble data for a non-3-channel tensor).
    img = img.unsqueeze(0)

    # We only extract features, so no gradients are needed.
    with torch.no_grad():
        # Run the truncated VGG-16 on the single-image batch.
        feature = NEW_MODEL(img)

    # Drop the batch dimension and return a plain Python list.
    return feature.tolist()[0]
|
63 |
+
|
tools/load_database.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd

# Default values: path to the pre-built product catalogue CSV.
# NOTE(review): read_csv runs at import time — the file must exist at this
# relative path from the process working directory, or the import fails.
filename = "./database/dataframes/all_database.csv"
# Loaded once and shared (read-only) by select_database below.
DATABASE = pd.read_csv(filename)
|
6 |
+
|
7 |
+
|
8 |
+
def select_database(marca, prenda, database=None):
    """Filter the product database by garment type and, optionally, brand.

    Parameters
    ----------
    marca : str
        Brand name to match against the 'marca' column; pass "Ninguno"
        ("none") to skip brand filtering.
    prenda : str
        Garment tag to match against the 'tag' column.
    database : pandas.DataFrame, optional
        Source dataframe; defaults to the module-level DATABASE. Exposed as
        a parameter so the filter can be tested without the CSV on disk.

    Returns
    -------
    pandas.DataFrame
        Matching rows with a fresh 0..n-1 index.
    """
    if database is None:
        database = DATABASE

    filtered_df = database[database['tag'] == prenda]

    # "Ninguno" is the sentinel for "no brand selected".
    if marca != "Ninguno":
        filtered_df = filtered_df[filtered_df['marca'] == marca]

    # Reset the index and drop the old one. Returning a new frame (instead of
    # reset_index(inplace=True) on a filtered slice) avoids pandas
    # SettingWithCopy warnings from mutating a view.
    return filtered_df.reset_index(drop=True)
|
tools/search.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.spatial import cKDTree
|
3 |
+
import ast
|
4 |
+
|
5 |
+
|
6 |
+
def search_similar_products(vgg_search, df, number_of_items=5):
    """Find similar products via nearest-neighbour search on VGG feature vectors.

    Parameters
    ----------
    vgg_search : sequence of float
        Feature vector of the query image.
    df : pandas.DataFrame
        Candidate products; must have a 'vgg_vector' column (stringified
        Python lists, parsed with ast.literal_eval) and an 'image_urls' column.
    number_of_items : int, optional
        Maximum number of results to return (capped at len(df)).

    Returns
    -------
    dict with keys "success" (bool) and "result" (list of image URLs,
    nearest first; empty on failure).
    """
    try:
        # Parse the stored feature vectors positionally, so the lookup does
        # not depend on the dataframe's index labels (indexing the list with
        # df.index values breaks whenever the index is not 0..n-1).
        vectors = [ast.literal_eval(raw) for raw in df["vgg_vector"].tolist()]
        matrix = np.array(vectors)

        # Cap k at the number of candidates: cKDTree.query with k > n pads the
        # result with distance=inf and index=n, which would crash df.iloc.
        k = min(number_of_items, len(matrix))

        # query returns (distances, indexes); only the indexes are needed.
        neighbour_indexes = cKDTree(matrix).query(vgg_search, k=k)[1]
        # With k == 1 the index comes back as a scalar — normalize to 1-D so
        # the loop below works for every k.
        neighbour_indexes = np.atleast_1d(neighbour_indexes)

        images_urls_list = [df.iloc[int(i)]['image_urls'] for i in neighbour_indexes]

        result = {
            "success": True,
            "result": images_urls_list,
        }

    except Exception:
        # Best-effort API: any failure (malformed vectors, empty dataframe,
        # missing columns, ...) yields an unsuccessful empty result instead of
        # raising. Narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate.
        result = {
            "success": False,
            "result": [],
        }

    return result
|