File size: 915 Bytes
8ddb2de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import streamlit as st
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
st.title("Simple Classification")
# Accept both JPEG spellings (older Streamlit releases do not treat "jpg" and
# "jpeg" as equivalent) plus PNG; the classifier converts every image to RGB,
# so palette/alpha images are fine.
file_up = st.file_uploader("upload an image", type=["jpg", "jpeg", "png"])
@st.cache_resource
def _load_classifier():
    """Build the pretrained ResNet-50 once and reuse it across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction, so without
    caching the model would be rebuilt and its weights reloaded each time.

    Returns:
        (model, weights): the ResNet-50 in eval mode, and the weights entry
        whose ``transforms()`` and ``meta["categories"]`` the caller needs.
    """
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    return model, weights


def func(image):
    """Classify an image with pretrained ResNet-50 and show the top-1 result.

    Args:
        image: anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object such as the Streamlit upload).

    Writes the predicted category name and its softmax score (as a
    percentage with two decimals) to the Streamlit page. Returns None.
    """
    import torch  # torchvision depends on torch, so it is always installed here

    img = Image.open(image).convert('RGB')
    model, weights = _load_classifier()
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)  # add the batch dimension
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    st.write("category: ", category_name, "\n")
    st.write("score:", f"{score * 100:.2f}", "%")
if file_up is not None:
    # Preview the uploaded picture before running the classifier on it.
    image = Image.open(file_up)
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; confirm against the pinned version.
    st.image(
        image,
        caption='Uploaded Image',
        use_column_width=True,
    )
    st.write("")  # empty spacer line between the preview and the results

    # Hand the raw upload to the classifier; PIL rewinds the stream on open.
    func(file_up)
|