# demo-hse-22 / app.py
# Author: justheuristic — commit efb2a0e ("Update app.py"), 1.68 kB.
import os
import tempfile
from io import StringIO

import matplotlib.pyplot as plt
import requests
import streamlit as st
import torch
from skimage.transform import resize
from torchvision.models.inception import inception_v3
@st.cache(allow_output_mutation=True)
def load_stuff():
    """Build the pretrained Inception-v3 classifier and fetch ImageNet labels.

    Wrapped in ``st.cache`` so the model is built and the labels are
    downloaded once, then reused across script reruns. Setting
    ``allow_output_mutation=True`` stops Streamlit from hashing the large
    returned model on every rerun to check whether it was mutated.

    Returns:
        (model, labels): the model switched to inference mode, and a dict
        mapping each class index to its label string.

    Raises:
        requests.RequestException: if the labels file cannot be fetched
            (e.g. ``requests.HTTPError`` on a non-2xx status, or
            ``requests.Timeout``).
    """
    model = inception_v3(pretrained=True,  # load existing weights
                         transform_input=True,  # preprocess input image the same way as in training
                         )
    model.aux_logits = False  # don't predict intermediate logits (yellow layers at the bottom)
    model.train(False)
    LABELS_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
    # Bounded wait so a stalled download cannot hang the app forever, and an
    # explicit status check so an error page is not parsed as JSON.
    response = requests.get(LABELS_URL, timeout=30)
    response.raise_for_status()
    labels = dict(enumerate(response.json()))
    return model, labels


model, labels = load_stuff()
def transform_input(img):
    """Convert one HWC image array into a float32 NCHW batch tensor.

    ``img`` is reshaped to ``[1, 299, 299, 3]``, so it must hold exactly
    299 * 299 * 3 values. Its axes are then reordered to
    ``[batch, channel, height, width]``.
    """
    nhwc = img.reshape([1, 299, 299, 3])
    nchw = nhwc.transpose([0, 3, 1, 2])
    return torch.as_tensor(nchw, dtype=torch.float32)
def predict(img, top_k=10):
    """Classify an image and format its most probable labels as Markdown.

    Uses the module-level ``model``, ``labels`` and ``transform_input``.

    Args:
        img: image array accepted by ``transform_input`` (reshaped there to
            ``[1, 299, 299, 3]``).
        top_k: number of most probable classes to list (default 10).

    Returns:
        A Markdown string: a header followed by one ``prob :<TAB>label`` line
        per class, in descending order of probability.
    """
    batch = transform_input(img)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch), dim=-1)
    probs = probs.numpy().ravel()
    # argsort is ascending, so reverse it and keep the first top_k indices.
    top_ix = probs.argsort()[::-1][:top_k]
    header = 'top-%d classes are: \n\n [prob : class label]\n\n' % top_k
    rows = ''.join(
        '%.4f :\t%s' % (probs[ix], labels[ix].split(',')[0]) + '\n\n'
        for ix in top_ix
    )
    return header + rows
st.markdown("### Hello dude!")
uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()
    # Decode through a private, auto-deleted temporary directory rather than a
    # fixed 'tmp' file in the working directory. Streamlit runs each browser
    # session's script concurrently, so a shared path would let one user's
    # upload overwrite another's between the write and the reads.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'tmp')
        with open(tmp_path, 'wb') as f:
            f.write(bytes_data)
        img = resize(plt.imread(tmp_path), (299, 299))
    if img.ndim == 2:
        # Grayscale images decode to a 2-D array; replicate into 3 channels so
        # the reshape in transform_input gets the values it needs.
        img = img[..., None].repeat(3, axis=-1)
    img = img[..., :3]  # drop an alpha channel if present
    top_classes = predict(img)
    st.markdown(top_classes)
    # The upload is an in-memory file object; show it directly instead of
    # re-reading a file from disk.
    st.image(uploaded_file)