Spaces:
Runtime error
Runtime error
import os
import tempfile
from io import StringIO

import matplotlib.pyplot as plt
import requests
import streamlit as st
import torch
from skimage.transform import resize
from torchvision.models.inception import inception_v3
def load_stuff():
    """Build the pretrained Inception-v3 classifier and fetch its class labels.

    Returns:
        A ``(model, labels)`` tuple: the Inception-v3 network put into
        inference mode, and a dict mapping class index -> label, built by
        enumerating the JSON document at ``LABELS_URL`` (assumed to be a
        list of class names in model-output order).

    Raises:
        requests.RequestException: if the labels download fails, e.g.
            ``requests.HTTPError`` on an error status or
            ``requests.Timeout`` if the server does not answer in time.
    """
    model = inception_v3(
        pretrained=True,  # load existing weights
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; confirm against the installed version.
        transform_input=True,  # preprocess input image the same way as in training
    )
    # Don't produce the intermediate (auxiliary-head) logits.
    model.aux_logits = False
    model.eval()  # inference mode; same as model.train(False)

    LABELS_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
    # Without a timeout a stalled server would block the app indefinitely.
    response = requests.get(LABELS_URL, timeout=30)
    # Fail loudly rather than try to parse an error page as the label list.
    response.raise_for_status()
    labels = dict(enumerate(response.json()))
    return model, labels
# Streamlit re-executes this whole script on every user interaction. Cache the
# result of load_stuff() across reruns so the network is built and the labels
# are downloaded once, instead of on every rerun.
_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is None:
    # Older Streamlit releases only provide the legacy st.cache decorator;
    # allow_output_mutation skips hashing the returned model on each access.
    _cache_resource = st.cache(allow_output_mutation=True)
model, labels = _cache_resource(load_stuff)()
def transform_input(img):
    """Convert one 299 x 299 x 3 image array into a float32 batch tensor.

    ``img`` is reshaped to (1, 299, 299, 3), so it must hold exactly
    299 * 299 * 3 values. Its axes are then reordered to
    (batch, channel, height, width) before conversion to ``torch.float32``.
    """
    batch_hwc = img.reshape([1, 299, 299, 3])
    batch_chw = batch_hwc.transpose([0, 3, 1, 2])  # channels-first
    return torch.as_tensor(batch_chw, dtype=torch.float32)
def predict(img, top_k=10):
    """Classify an image and describe its most probable classes as text.

    Args:
        img: Image array accepted by ``transform_input`` (the app passes a
            299 x 299 x 3 array).
        top_k: Number of highest-probability classes to list. Defaults to
            10, which reproduces the original "top-10" report.

    Returns:
        A Markdown-friendly string: a header followed by one
        ``'<prob> :\t<label>'`` entry per class, most probable first.
    """
    batch = transform_input(img)
    # Inference only: no need to record operations for autograd.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=-1).numpy().ravel()

    # argsort is ascending, so reverse it before taking the first top_k
    # indices; this yields exactly top_k classes (fewer only if there are
    # fewer classes than that).
    top_ix = probs.argsort()[::-1][:top_k]

    header = 'top-%d classes are: \n\n [prob : class label]\n\n' % top_k
    entries = [
        '%.4f :\t%s' % (probs[ix], labels[ix].split(',')[0]) + '\n\n'
        for ix in top_ix
    ]
    return header + ''.join(entries)
st.markdown("### Hello dude!")
uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()

    # Streamlit serves every browser session from the same process and
    # working directory, so a fixed 'tmp' path would let concurrent uploads
    # overwrite each other. Give each upload its own private directory (the
    # file keeps the same extension-less name as before).
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'tmp')
        with open(tmp_path, 'wb') as f:
            f.write(bytes_data)
        img = resize(plt.imread(tmp_path), (299, 299))

    if img.ndim == 2:
        # Greyscale image: replicate the single channel to get H x W x 3,
        # otherwise [..., :3] below would slice columns instead of channels.
        img = img[..., None].repeat(3, axis=-1)
    img = img[..., :3]  # drop an alpha channel if present

    top_classes = predict(img)
    st.markdown(top_classes)
    # Show this session's own upload rather than re-reading a shared file.
    st.image(bytes_data)