import streamlit as st
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
st.title("Simple Classification")
# func() converts every image to RGB, so PNGs (including ones with an alpha
# channel) are handled as well as JPEGs. "jpeg" is listed explicitly so that
# files with the ".jpeg" extension are not rejected by the uploader.
file_up = st.file_uploader("upload an image", type=["jpg", "jpeg", "png"])
# Streamlit re-executes this whole script on every widget interaction, so an
# ordinary module-level cache would be rebuilt on each run. Streamlit's
# resource cache (st.cache_resource, or its predecessor
# st.experimental_singleton on Streamlit < 1.18) keeps one model per server
# process instead of rebuilding ResNet-50 and reloading its weights each time.
_cache_model = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_model
def _load_model():
    """Return ``(model, weights)``: ResNet-50 with its default weights, in eval mode."""
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    return model, weights


def func(image):
    """Classify ``image`` with ResNet-50 and write the top-1 result to the page.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object such as a Streamlit ``UploadedFile``), or an
            already-opened ``PIL.Image.Image``.

    Writes the predicted category name (from ``weights.meta["categories"]``)
    and its softmax probability, as a percentage with two decimals, via
    ``st.write``. Returns None.
    """
    if isinstance(image, Image.Image):
        img = image.convert('RGB')
    else:
        img = Image.open(image).convert('RGB')

    model, weights = _load_model()
    preprocess = weights.transforms()
    # Add a leading batch dimension of size 1 for the model.
    batch = preprocess(img).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)

    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    st.write("category: ", category_name, "\n")
    st.write("score:", f"{score * 100:.2f}", "%")
# Only render and classify once the user has uploaded a file.
if file_up is not None:
    try:
        image = Image.open(file_up)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot decode, e.g. a corrupt or mislabeled upload. Report it in the
        # page instead of showing a traceback, and skip classification.
        st.error("Could not read the uploaded file as an image.")
    else:
        st.image(image, caption='Uploaded Image', use_column_width=True)
        st.write("")  # blank spacer between the preview and the results
        func(file_up)