# Fashion-Search / utils.py
# Uploaded by Ajay-user ("Upload 2 files", revision 9bdc63e).
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors
import pickle
import requests
from PIL import Image
import streamlit as st
@st.cache_resource
def feature_extractor() -> tf.keras.Sequential:
    """Build the image-embedding model, cached via ``st.cache_resource``.

    The model is an ImageNet-pretrained MobileNetV2 without its
    classification head, configured for ``[80, 60, 3]`` inputs. A Flatten
    layer follows it so each image yields a single flat feature vector.

    Returns:
        A ``tf.keras.Sequential`` of ``[MobileNetV2 backbone, Flatten]``.
    """
    backbone = tf.keras.applications.MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=[80, 60, 3],
    )
    return tf.keras.Sequential([backbone, tf.keras.layers.Flatten()])
@st.cache_data
def load_resource(resource_path):
    """Unpickle and return the object stored at ``./Embeddings/<resource_path>``.

    Results are memoized by ``st.cache_data``.

    Args:
        resource_path: File name relative to the ``./Embeddings/`` directory.

    Returns:
        Whatever object the pickle file contains.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at trusted, locally produced files.
    full_path = f'./Embeddings/{resource_path}'
    with open(full_path, 'rb') as handle:
        return pickle.load(handle)
class FashionSearch:
    """Image similarity search over precomputed image embeddings.

    On construction, the class loads three pickles from ``./Embeddings/``:
    the embeddings, the image ids and a name-to-link mapping. It also obtains
    the cached MobileNetV2 feature extractor. A query embeds the input image
    with that extractor and returns the ids of the nearest stored embeddings.
    """

    def __init__(self) -> None:
        # NOTE(review): rows of ``embeddings`` are presumably aligned with
        # ``name``, since neighbour indices are used to index ``name``.
        # TODO: confirm.
        self.embeddings = load_resource(resource_path='image_embeddings.pkl')
        self.name = np.array(load_resource(resource_path='image_ids.pkl'))
        self.image_link = load_resource(resource_path='name_link_map.pkl')
        self.feature_extractor = feature_extractor()

    def KNN(self, metric: str = 'minkowski') -> NearestNeighbors:
        """Return a brute-force ``NearestNeighbors`` index fitted on the embeddings.

        Args:
            metric: Distance metric forwarded to scikit-learn. Earlier
                versions ignored this and always used ``'minkowski'``.

        Returns:
            A fitted ``NearestNeighbors`` instance. Its default
            ``n_neighbors`` is 10, but ``find_k_neighbors`` always passes
            an explicit value.
        """
        knn = NearestNeighbors(n_neighbors=10, algorithm='brute', metric=metric)
        knn.fit(self.embeddings)
        return knn

    def image_feature_extraction(self, img: Image.Image):
        """Embed a single PIL image with the feature extractor.

        The image is first converted to RGB, so greyscale, RGBA and palette
        uploads still have 3 channels (this is a no-op for RGB images). It is
        then resized to 60x80 (width x height) to match the model's
        ``[80, 60, 3]`` input shape.

        Returns:
            The extractor output for a batch containing this one image.
        """
        # PIL's resize takes (width, height), so the array is (80, 60, 3).
        sample_img_arr = np.array(img.convert('RGB').resize((60, 80)))
        sample_img_arr = tf.keras.applications.mobilenet_v2.preprocess_input(sample_img_arr)
        # Add a leading batch axis: (1, 80, 60, 3).
        return self.feature_extractor(sample_img_arr[None, ...])

    def find_k_neighbors(
        self,
        sample_img: Image.Image,
        metric: str = 'minkowski',
        n_neighbors: int = 16,
    ) -> list[str]:
        """Return the ids of the stored images closest to ``sample_img``.

        Args:
            sample_img: Query image.
            metric: Distance metric for the neighbour search.
            n_neighbors: Number of ids to return. It is capped at the number
                of stored embeddings so a small index does not raise.

        Returns:
            Image ids as strings, in the order returned by scikit-learn's
            ``kneighbors`` (nearest first).
        """
        knn = self.KNN(metric=metric)
        features = np.asarray(self.image_feature_extraction(img=sample_img))
        k = min(n_neighbors, len(self.embeddings))
        _, indices = knn.kneighbors(X=features, n_neighbors=k)
        # There is exactly one query row, so ``indices`` has shape (1, k).
        return [str(image_id) for image_id in self.name[indices[0]]]