snaramirez872's picture
added files
dc282ed
raw
history blame
2.44 kB
import io
import tarfile
import numpy as np
import random as RAND
import torch
import torchvision.transforms as TRNSFM
import torchvision.models as MDLS
from PIL import Image as IMG
from scipy.spatial.distance import cosine
import streamlit as st
def similar(image, top_k=10):  # Function for Streamlit App
    """Show the dataset images whose ResNet features are closest to *image*.

    Embeds *image* with the module-level model ``mod`` after the ``form``
    preprocessing. Every entry in the module-level ``feats`` dict is scored
    by cosine similarity. The selected picture and the ``top_k`` best
    matches are then rendered in the Streamlit page.

    Args:
        image: Anything ``PIL.Image.open`` accepts; in this app either a
            Streamlit upload or a local file path.
        top_k: Number of matches to display (defaults to 10, the previous
            hard-coded count).

    Returns:
        List of ``(member_name, similarity)`` tuples sorted from most to
        least similar.
    """
    # Open the image once and reuse it for both embedding and display.
    selected = IMG.open(image).convert('RGB')
    with torch.no_grad():  # inference only; no autograd graph needed
        query = mod(form(selected).unsqueeze(0)).numpy().squeeze()

    # Local list: the old module-level list accumulated scores across calls,
    # so a second search in the same run re-ranked stale results too.
    scores = [(name, 1 - cosine(query, feat)) for name, feat in feats.items()]
    scores.sort(key=lambda item: item[1], reverse=True)

    st.write("### Selected Image")
    # st.image renders inside the app; PIL's Image.show() would instead try
    # to open an external viewer on the machine running the server.
    st.image(selected)

    # Most Similar Images from Dictionary
    st.write(f"### {top_k} Most Similar Images")
    for rank, (name, score) in enumerate(scores[:top_k], start=1):
        st.write(f"### {rank}")
        # Presumably tar member names look like "lfw/<person>/<file>.jpg"
        # (TODO confirm against the archive). The public copy lives under
        # ".../lfw/images/<person>/<file>.jpg", so insert "images/" after
        # the first path component.
        prefix, _, rest = name.partition("/")
        url = f"http://vis-www.cs.umass.edu/{prefix}/images/{rest}"
        # PIL cannot open URLs; st.image fetches and displays them directly.
        st.image(url, caption=f"{name} (similarity {score:.3f})")
    return scores
# Image-embedding model: ResNet-50 with its final child layer removed, so a
# forward pass returns pooled features instead of class logits.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour
# of `weights=...`; confirm the installed version before changing it.
mod = MDLS.resnet50(pretrained=True)
mod = torch.nn.Sequential(*list(mod.children())[:-1])
# Switch to inference mode after wrapping, so both the wrapper and its children
# are in eval mode (affects BatchNorm).
mod.eval()
feats = {}  # tar member name -> feature vector for each dataset image
simScores = []  # Similarity Scores for Later
form = TRNSFM.Compose([
    TRNSFM.Resize(256),
    TRNSFM.CenterCrop(224),
    TRNSFM.ToTensor(),
    TRNSFM.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])  # Image Pre-processing
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model load and the dataset embedding below are redone
# each rerun. Consider moving them into a cached function (e.g.
# st.cache_resource) once the installed Streamlit version is confirmed.
# The `with` statement closes the archive even if embedding fails part-way.
# no_grad skips autograd bookkeeping for these inference-only forward passes.
with tarfile.open('/datasets/lfw.tar', 'r') as inFile, torch.no_grad():
    for stuff in inFile.getmembers():  # Going through the TAR file
        if stuff.isdir() or not stuff.name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        f = inFile.extractfile(stuff)
        if f is None:  # extractfile returns None for non-file, non-link members
            continue
        with f:
            # Pre-processes the image before feeding it into the model
            pic = form(IMG.open(io.BytesIO(f.read())).convert('RGB'))
        feats[stuff.name] = mod(pic.unsqueeze(0)).numpy().squeeze()
# --- Streamlit user interface ---------------------------------------------
st.title("Similar Image Finder")

# Option 1: search with an image the user uploads.
upload = st.file_uploader("Upload an Image...", type=['.jpg', '.jpeg', '.png'])
if upload is not None:
    similar(upload)

st.write("## OR")

# Option 2: "Surprise Me!" searches with one of five bundled sample pictures
# (/datasets/random-images/img1.jpg .. img5.jpg).
randImages = [f'/datasets/random-images/img{idx}.jpg' for idx in range(1, 6)]
if st.button("Surprise Me!"):  # Button
    imageOptOne = RAND.choice(randImages)
    similar(imageOptOne)

# Release the dataset archive handle opened during setup.
inFile.close()