File size: 915 Bytes
8ddb2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import streamlit as st
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

# Page heading for the Streamlit app.
st.title("Simple Classification")
# Upload widget restricted to the "jpg" type. The result is compared against
# None further down before it is used, so it is presumably None until the
# user has provided a file.
file_up = st.file_uploader("upload an image", type="jpg")


def func(image):
    """Classify an image with ResNet-50 and write the top result to the page.

    Args:
        image: Anything ``PIL.Image.open`` accepts; the caller in this file
            passes the object returned by ``st.file_uploader``.

    Returns:
        None. The top-1 category name and its softmax score (as a percentage
        with two decimals) are emitted through ``st.write``.
    """
    # Function-local import: torch is needed only for the no-grad context and
    # is already installed as a dependency of torchvision.
    import torch

    # Normalize to RGB mode regardless of the file's native mode.
    img = Image.open(image).convert('RGB')

    # NOTE(review): the model is rebuilt on every call; if the deployed
    # Streamlit version offers resource caching, consider caching it.
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Use the preprocessing pipeline bundled with the chosen weights, then
    # add a leading batch dimension.
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)

    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]

    st.write("category: ", category_name, "\n")
    st.write("score:", f"{score * 100:.2f}", "%")


# Only preview and classify once the uploader has returned something.
if file_up is not None:
    # Open the upload so it can be previewed on the page.
    image = Image.open(file_up)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Writes an empty string; presumably a visual spacer before the results.
    st.write("")
    # func() is given the upload object itself (not `image`) and opens it
    # again internally.
    func(file_up)