Image / app.py
timo1227's picture
Update app.py
5cd92f1
raw
history blame
4.06 kB
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os
import streamlit as st
@st.cache()
def load_model(input_dir="lfw"):
    """Embed every ``.jpg`` under *input_dir* with a headless ResNet-50.

    Despite its name, this returns image features, not the model: the network
    is built, used for feature extraction, and then discarded.

    Args:
        input_dir: Root directory walked recursively (``os.walk``) for files
            whose name ends in ``.jpg``. Defaults to ``"lfw"``, the directory
            the app reads from.

    Returns:
        dict mapping each image path, spelled ``os.path.join(root, file)`` as
        produced by ``os.walk``, to a 1-D NumPy feature vector taken from the
        network's penultimate (pooling) layer.
    """
    print("Loading model...")
    # Pre-trained ResNet-50 in eval mode so BatchNorm uses its stored stats.
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
    # of `weights=`; kept as-is for compatibility with the installed version.
    model = models.resnet50(pretrained=True)
    model.eval()
    # Drop the final fully-connected classifier so the network returns the
    # pooled penultimate-layer activations instead of class scores.
    model = torch.nn.Sequential(*list(model.children())[:-1])

    # ImageNet preprocessing matching the pre-trained weights.
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    features_dict = {}
    # Gradients are never needed; disable them once for the whole walk.
    with torch.no_grad():
        for root, _dirs, files in os.walk(input_dir):
            for file in files:
                if not file.endswith(".jpg"):
                    continue
                path = os.path.join(root, file)
                # `with` releases the file handle promptly; convert("RGB")
                # guarantees the 3 channels that Normalize above expects,
                # even for grayscale or CMYK JPEGs.
                with Image.open(path) as img:
                    tensor = transform(img.convert("RGB"))
                # Add a batch dimension, run the network, then squeeze the
                # singleton dims away to get a flat feature vector.
                features = model(tensor.unsqueeze(0)).squeeze()
                features_dict[path] = features.numpy()
    return features_dict
@st.cache()
def create_nearest_neighbors_object(features_dict, n_neighbors=11):
    """Fit a Euclidean nearest-neighbours index over *features_dict*'s vectors.

    Args:
        features_dict: Mapping of image path -> 1-D feature vector. Rows of
            the fitted index follow the dict's iteration order, so row ``i``
            corresponds to ``list(features_dict)[i]``.
        n_neighbors: Default neighbour count returned by ``kneighbors``. The
            default of 11 leaves room for the query image itself plus the 10
            results the app displays.

    Returns:
        A fitted ``sklearn.neighbors.NearestNeighbors`` instance.

    Raises:
        ValueError: If *features_dict* is empty, because there is nothing to index.
    """
    print("Creating nearest neighbors object...")
    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    if not features_dict:
        raise ValueError("features_dict is empty; no images were found to index")
    # Stack the vectors into one (n_images, n_features) array, in dict order.
    features_array = np.array(list(features_dict.values()))
    nn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean")
    nn.fit(features_array)
    return nn
def get_nearest_neighbors(image_path):
    """Display the 10 images most similar to *image_path* and return them.

    Reads the module-level ``features_dict`` (path -> feature vector) and the
    ``nn`` index fitted on that dict's values at app start-up.

    Args:
        image_path: Key into ``features_dict``. It must be spelled exactly as
            stored (``os.path.join(root, file)`` from ``os.walk``), otherwise
            no features are found.

    Returns:
        List of ``(path, distance)`` tuples, nearest first, excluding the
        query itself. Empty if *image_path* has no stored features; in that
        case an error message is shown in the app instead.
    """
    # Direct dict lookup instead of scanning every entry; a miss used to
    # surface as an UnboundLocalError.
    query_features = features_dict.get(image_path)
    if query_features is None:
        st.error("No features found for " + image_path)
        return []

    # kneighbors expects a 2-D (n_queries, n_features) array.
    distances, indices = nn.kneighbors(query_features.reshape(1, -1))

    # Index rows line up with the dict's iteration order, so materialise the
    # key list once rather than once per result.
    paths = list(features_dict)
    results = []
    for idx, dist in zip(indices[0], distances[0]):
        path = paths[idx]
        # Skip the query by identity, not by position: exact duplicates can
        # tie at distance 0 and push the query out of the first slot.
        if path == image_path:
            continue
        results.append((path, dist))
        if len(results) == 10:
            break

    for path, dist in results:
        # st.image accepts a file path directly, so no image handle is left
        # open here.
        st.image(
            path,
            caption="Distance: " + str(dist),
            use_column_width=True,
        )
    return results
# Streamlit app
# Both loaders are memoised by @st.cache, so feature extraction and index
# fitting are not redone on every Streamlit rerun.
features_dict = load_model()
nn = create_nearest_neighbors_object(features_dict)

# Title
st.title("Similarity Search")
# Subtitle
st.write("This app finds the 10 most similar images to a query image.")

# Offer only sub-directories of lfw/ (stray files are skipped), sorted so the
# list and its default selection are stable across runs.
people = sorted(
    name for name in os.listdir("lfw") if os.path.isdir(os.path.join("lfw", name))
)
query_image = st.selectbox("Or select an image from the list", people)

if query_image is None:
    # selectbox yields None when there is nothing to choose from.
    st.warning("No image folders found in lfw/.")
else:
    # Query image lives at lfw/<name>/<name>_0001.jpg. Build it with
    # os.path.join so it matches the features_dict keys, which are built as
    # os.path.join(root, file) from os.walk, on every OS.
    query_path = os.path.join("lfw", query_image, query_image + "_0001.jpg")
    st.image(
        query_path,
        caption="Query Image",
        use_column_width=True,
    )
    # Find the 10 most similar images
    if st.button("Find Similar Images"):
        get_nearest_neighbors(query_path)