import streamlit as st
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

# Page header and the upload widget that feeds the classifier below.
st.title("Simple Classification")

# JPEG files are commonly saved with either extension, so accept both.
file_up = st.file_uploader("upload an image", type=["jpg", "jpeg"])


def _cache_resource(loader):
    """Return *loader* wrapped in Streamlit's resource cache when available.

    Tries ``st.cache_resource`` first, then ``st.experimental_singleton``;
    if the installed Streamlit exposes neither attribute, *loader* is
    returned undecorated so the script still runs (just without caching).
    """
    decorator = getattr(st, "cache_resource", None) or getattr(
        st, "experimental_singleton", None
    )
    return loader if decorator is None else decorator(loader)


@_cache_resource
def _load_model():
    """Build the pretrained ResNet-50 and switch it to evaluation mode."""
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model.eval()
    return model


def func(image):
    """Classify an image with ResNet-50 and write the result to the page.

    Args:
        image: Anything ``PIL.Image.open`` accepts. The call site in this
            script passes the object returned by ``st.file_uploader``.

    Returns:
        None. The top category name and its softmax score (as a percentage
        with two decimals) are emitted through ``st.write``.
    """
    img = Image.open(image).convert("RGB")

    weights = ResNet50_Weights.DEFAULT
    # Reuse one model instance instead of rebuilding it on every call.
    model = _load_model()

    # The weights object supplies the preprocessing that matches the model.
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)

    # Inference only: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)

    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]

    st.write("category: ", category_name, "\n")
    st.write("score:", f"{score * 100:.2f}", "%")


# Runs only once the user has uploaded a file.
if file_up is not None:
    # Echo the upload back to the user before classifying it.
    image = Image.open(file_up)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    # func opens the upload itself, so it is handed the uploaded file
    # rather than the already-opened `image`.
    func(file_up)
|