Spaces:
Running
Running
File size: 1,949 Bytes
dec1eb1 4e80033 dec1eb1 815ed91 8405ca4 dec1eb1 069c641 dec1eb1 bf3a51e 069c641 bf3a51e 069c641 77a2024 dec1eb1 bf3a51e be2c3da 77a2024 bf3a51e dec1eb1 95ddfb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
# Load the TorchScript model once at import time.
# map_location="cpu" forces the weights onto the CPU, which matches the CPU
# tensors that classify_img feeds the model, even if the file was saved from
# a GPU session.
model = torch.jit.load("food_classifier_resnet18.ptl", map_location="cpu")
# Switch to inference mode once up front (e.g. batch norm uses its running
# statistics instead of per-batch statistics).
model.eval()
# Preprocessing applied to every input image before inference:
#   1. resize directly to 224x224 (aspect ratio is NOT preserved); the result
#      already has the model's input size, so no extra crop is needed
#      (a CenterCrop((224, 224)) here would be a no-op),
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalize each channel with the commonly used ImageNet mean/std
#      (presumably the statistics the model was trained with -- confirm).
the_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Human-readable labels, looked up by the class index the model predicts
# (position i in this list is the label for output index i).
class_names = [
    'Apple Pie',
    'Bibimbap',
    'Cannoli',
    'Edamame',
    'Falafel',
    'French Toast',
    'Ice Cream',
    'Ramen',
    'Sushi',
    'Tiramisu',
]
def transform_img(img):
    """Run *img* through the module-level ``the_transform`` pipeline.

    Returns whatever the pipeline produces (for the pipeline defined in this
    file, a normalized image tensor).
    """
    preprocessed = the_transform(img)
    return preprocessed
# Returns a dict mapping the top predicted class names to their probabilities
def classify_img(img):
    """Classify a food image and return the most likely categories.

    Args:
        img: PIL image supplied by the Gradio image input (``type="pil"``),
            or ``None`` when the user submits without an image.

    Returns:
        dict mapping class name -> softmax probability for the (up to) three
        highest-scoring classes; an empty dict when ``img`` is ``None``.
    """
    if img is None:
        return {}

    # The model takes 3-channel input. Grayscale ("L") or RGBA images would
    # otherwise yield a 1- or 4-channel tensor that cannot be batched as
    # (1, 3, 224, 224).
    model_img = transform_img(img.convert("RGB"))
    # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    model_img = model_img.unsqueeze(0)

    # Cheap to repeat; guarantees inference mode even if it was never set.
    model.eval()
    with torch.no_grad():
        # Convert the raw scores to probabilities over the classes.
        probs = F.softmax(model(model_img), dim=1)

    # Never ask topk for more entries than there are classes.
    k = min(3, len(class_names))
    top_prob, top_catid = torch.topk(probs, k)

    model_output = {
        class_names[cat.item()]: prob.item()
        for prob, cat in zip(top_prob[0], top_catid[0])
    }
    print(model_output)
    return model_output
# Gradio UI wrapping classify_img. The category list in the description is
# generated from class_names so the two cannot drift apart.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and removed
# in Gradio 4; when upgrading, use gr.Image(type="pil") and
# gr.Label(num_top_classes=3) instead.
demo = gr.Interface(
    classify_img,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Label(type="confidences", num_top_classes=3),
    title="Food Classifier!",
    description=(
        "Insert food image you would like to classify! Returns confidence % "
        "for the top three categories <br> Categories: " + ", ".join(class_names)
    ),
)

# Only start the server when run as a script, so importing this module
# (e.g. for testing) has no side effects beyond building the interface.
if __name__ == "__main__":
    demo.launch(inline=False)