|
import torch |
|
import torchvision.models as models |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import os |
|
import streamlit as st |
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model(input_dir="lfw"):
    """Compute ResNet-50 feature vectors for every JPEG under *input_dir*.

    Despite its name, this does not return the model. It loads an ImageNet
    pretrained ResNet-50 and removes its last child module (the classifier
    head in torchvision's ResNet). It then runs every ``.jpg``/``.jpeg`` file
    found recursively under *input_dir* through the truncated network.

    The result is cached by ``st.cache``. ``allow_output_mutation=True`` skips
    re-hashing the (large) returned dict on every rerun.

    Args:
        input_dir: Root directory to scan recursively. Defaults to ``"lfw"``.

    Returns:
        dict mapping each image path (``os.path.join(root, file)`` as yielded
        by ``os.walk``) to its squeezed feature vector as a NumPy array.
        Insertion order is deterministic (sorted directory/file traversal).
    """
    print("Loading model...")

    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=`; kept so older torchvision releases still work.
    model = models.resnet50(pretrained=True)
    model.eval()
    # Keep every child except the last, so the output is a feature embedding
    # rather than class scores.
    feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
    feature_extractor.eval()

    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    features_dict = {}
    # One no_grad context for the whole pass: inference only, no autograd graph.
    with torch.no_grad():
        for root, dirs, files in os.walk(input_dir):
            # Sorting `dirs` in place makes os.walk visit sub-directories in a
            # stable order, so the dict (and the nn index rows built from it)
            # is reproducible across runs.
            dirs.sort()
            for file in sorted(files):
                if not file.lower().endswith((".jpg", ".jpeg")):
                    continue
                path = os.path.join(root, file)
                # Close the file promptly. Force 3 channels because Normalize
                # above has exactly three mean/std entries, so grayscale or
                # RGBA input would fail.
                with Image.open(path) as img:
                    batch = transform(img.convert("RGB")).unsqueeze(0)
                features_dict[path] = feature_extractor(batch).squeeze().numpy()

    return features_dict
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def create_nearest_neighbors_object(features_dict, n_neighbors=11):
    """Fit a Euclidean nearest-neighbour index over precomputed feature vectors.

    Args:
        features_dict: Mapping of image path -> 1-D feature vector (as returned
            by ``load_model``). Row *i* of the fitted index corresponds to the
            *i*-th key in the dict's iteration order.
        n_neighbors: Default neighbour count used by ``kneighbors()``. The
            default of 11 leaves room for the query image itself plus 10
            results. It is clamped to the number of images, because
            ``kneighbors()`` rejects a count larger than the fitted sample count.

    Returns:
        A fitted ``sklearn.neighbors.NearestNeighbors`` instance.

    Raises:
        ValueError: If ``features_dict`` is empty.
    """
    print("Creating nearest neighbors object...")
    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    if not features_dict:
        raise ValueError("features_dict is empty; no images were found to index.")

    # Stack the 1-D vectors into an (n_images, n_features) matrix; row order
    # follows the dict's key order.
    features_array = np.stack(list(features_dict.values()))

    nn = NearestNeighbors(
        n_neighbors=min(n_neighbors, len(features_array)),
        metric="euclidean",
    )
    nn.fit(features_array)
    return nn
|
|
|
|
|
|
|
def get_nearest_neighbors(image_path, n_results=10):
    """Show the images most similar to *image_path* in the Streamlit page.

    The query's features are looked up in the module-level ``features_dict``,
    and the module-level ``nn`` index is queried with them. Each match is then
    rendered with ``st.image``, captioned with its Euclidean distance.

    Args:
        image_path: Key into ``features_dict`` (a path built with
            ``os.path.join`` exactly as ``load_model`` builds its keys).
        n_results: Maximum number of similar images to show (default 10).

    Returns:
        list of ``(path, distance)`` tuples for the images shown, nearest
        first. Returns an empty list (and shows an error) if *image_path* has
        no precomputed features.
    """
    print(image_path)

    # Direct dict lookup instead of scanning every item.
    query_features = features_dict.get(image_path)
    if query_features is None:
        st.error("No features found for " + image_path)
        return []

    # Row i of the nn index corresponds to the i-th key of features_dict;
    # build this list once, not once per neighbour.
    paths = list(features_dict.keys())

    # Ask for one extra neighbour because the query usually matches itself,
    # but never for more rows than the index holds.
    k = min(max(n_results, 0) + 1, len(paths))
    distances, indices = nn.kneighbors(query_features.reshape(1, -1), n_neighbors=k)

    # Drop the query by path rather than assuming it is ranked first: an exact
    # duplicate image can tie with it at distance 0.
    results = [
        (paths[idx], dist)
        for idx, dist in zip(indices[0], distances[0])
        if paths[idx] != image_path
    ][: max(n_results, 0)]

    for path, dist in results:
        with Image.open(path) as image:
            st.image(
                image,
                caption="Distance: " + str(dist),
                use_column_width=True,
            )
    return results
|
|
|
|
|
|
|
# Streamlit re-executes this script on every interaction; both calls below are
# served from st.cache after the first run.
features_dict = load_model()
nn = create_nearest_neighbors_object(features_dict)

st.title("Similarity Search")

st.write("This app finds the 10 most similar images to a query image.")

# Offer only person sub-directories (skip stray files), in a stable order.
person_names = sorted(
    name for name in os.listdir("lfw") if os.path.isdir(os.path.join("lfw", name))
)
query_image = st.selectbox("Or select an image from the list", person_names)

if query_image is None:
    # Empty directory: nothing to select, so stop before building a path from None.
    st.warning("No image folders found in 'lfw'.")
    st.stop()

# Build the path with os.path.join so it matches the features_dict keys,
# which load_model builds with os.walk + os.path.join. A hard-coded "/"
# would not match on Windows.
query_path = os.path.join("lfw", query_image, query_image + "_0001.jpg")
print(query_path)

st.image(
    query_path,
    caption="Query Image",
    use_column_width=True,
)

if st.button("Find Similar Images"):
    get_nearest_neighbors(query_path)
|
|