# simple_model / Classification.py
# Last commit 8ddb2de by QIAN137: "Upload Classification.py" (915 bytes).
import streamlit as st
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
# Page header, then a single-file uploader. file_up stays None until the user
# uploads something (the classification section below checks for that).
st.title("Simple Classification")
file_up = st.file_uploader(
    "upload an image",
    type=["jpg"],
)
def _load_classifier():
    """Build the pretrained ResNet-50 classifier.

    Returns:
        A ``(weights, model)`` tuple: the ``ResNet50_Weights.DEFAULT`` member
        (source of the preprocessing transform and the category names) and a
        ``resnet50`` model loaded with those weights and switched to eval mode.
    """
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    return weights, model


# Streamlit re-executes the whole script on every widget interaction. Without
# caching, the model would be rebuilt and its weights reloaded on each rerun.
# st.cache_resource, when this Streamlit version provides it, keeps a single
# shared instance instead.
if hasattr(st, "cache_resource"):
    _load_classifier = st.cache_resource(_load_classifier)


def func(image):
    """Classify an uploaded image and write the top-1 prediction to the page.

    Args:
        image: Anything ``PIL.Image.open`` accepts, e.g. the object returned
            by ``st.file_uploader``. It is converted to RGB before
            preprocessing.

    Writes two lines via ``st.write``: the predicted category name (taken
    from the weights' metadata), and its softmax probability as a percentage
    with two decimals. Returns ``None``.
    """
    img = Image.open(image).convert('RGB')
    weights, model = _load_classifier()
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)  # prepend a batch dimension of size 1
    # Pure inference: skip building the autograd graph to save time and memory.
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    st.write("category: ", category_name, "\n")
    st.write("score:", f"{score * 100:.2f}", "%")
if file_up is not None:
    # Preview the upload, emit a blank spacer line, then classify it.
    # func() re-opens file_up itself (Image.open rewinds the stream), so the
    # raw upload, not the decoded preview, is what gets passed on.
    preview = Image.open(file_up)
    st.image(
        preview,
        caption='Uploaded Image',
        use_column_width=True,
    )
    st.write("")
    func(file_up)