import streamlit as st
import torch
from PIL import Image, UnidentifiedImageError
from torchvision.models import resnet50, ResNet50_Weights

st.title("Simple Classification")

# Accept both common JPEG extensions and PNG; func() converts every input to RGB.
file_up = st.file_uploader("upload an image", type=["jpg", "jpeg", "png"])

# Streamlit re-executes this whole script on every widget interaction, so the
# pretrained model must be cached or it is rebuilt on every rerun.
# st.cache_resource keeps one shared instance; on Streamlit versions that lack
# it we fall back to no caching rather than failing at import time.
_cache_resource = getattr(st, "cache_resource", None) or (lambda fn: fn)


@_cache_resource
def _load_model():
    """Build the pretrained ResNet-50 in eval mode; return ``(model, weights)``."""
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    return model, weights


def func(image):
    """Classify *image* with ResNet-50 and write the top-1 category and score.

    Args:
        image: anything ``PIL.Image.open`` accepts (e.g. the object returned by
            ``st.file_uploader``, or a path), or an already-opened
            ``PIL.Image.Image`` (used as-is, avoiding a second decode).

    Writes two lines to the Streamlit page: the predicted category name, and
    its softmax probability as a percentage with two decimals.
    """
    img = image if isinstance(image, Image.Image) else Image.open(image)
    img = img.convert('RGB')

    model, weights = _load_model()
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)  # add batch dimension of 1

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)

    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]

    st.write("category: ", category_name, "\n")
    st.write("score:", f"{score * 100:.2f}", "%")


if file_up is not None:
    try:
        image = Image.open(file_up)
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image.")
    else:
        st.image(image, caption='Uploaded Image', use_column_width=True)
        st.write("")
        # Reuse the already-opened image instead of decoding the upload again.
        func(image)