# repositivator's picture
# Update app.py
# 8ccff8c
import gradio as gr
import os
import torch
import torchvision
from model import create_model
from timeit import default_timer as timer
# Class labels, one per line. Index i of the model's output is reported under
# class_names[i] by predict(), so the file order must match training order.
# An explicit encoding keeps decoding independent of the host's locale.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [class_name.strip() for class_name in f]

# Build the architecture with one output per class, then load the trained weights.
# model_transforms is currently unused: predict() builds ConvNeXt-Tiny's default
# transforms itself.
model, model_transforms = create_model(num_classes=len(class_names))
# map_location loads every tensor onto the CPU even if the checkpoint was saved on a GPU.
# NOTE(review): torch.load unpickles the file; acceptable for this bundled checkpoint,
# but never point it at an untrusted path.
model.load_state_dict(
    torch.load("ConvNeXt_Tiny_101classes_20_10epochs.pth", map_location=torch.device("cpu"))
)  # load -> load_state_dict
model = model.to("cpu")  # Place the model on the CPU
model.eval()  # this app only runs inference; switch to eval mode once at load time
# Preprocessing pipeline of the default ConvNeXt-Tiny weights. It is built lazily on
# first use and then reused, instead of being rebuilt on every request.
_convnext_tiny_transform = None


def _get_transform():
    """Return the cached ConvNeXt-Tiny preprocessing transform, creating it on first call."""
    global _convnext_tiny_transform
    if _convnext_tiny_transform is None:
        weights = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT
        _convnext_tiny_transform = weights.transforms()
    return _convnext_tiny_transform


def predict(img):
    """Classify an image and report how long the prediction took.

    Args:
        img: The uploaded image. The interface declares ``gr.Image(type="pil")``,
            so this is expected to be a PIL image. ``None`` (no upload) is
            handled defensively.

    Returns:
        A ``(label_probs, seconds)`` tuple:
        - ``label_probs`` maps every name in ``class_names`` to its softmax
          probability (the dict shape ``gr.Label`` consumes).
        - ``seconds`` is the elapsed time rounded to 5 decimal places.
        Returns ``(None, None)`` when ``img`` is ``None`` instead of failing
        inside the transform.
    """
    if img is None:
        return None, None

    time_start = timer()
    # [Channels, Height, Width] -> [Batch_size, Channels, Height, Width]
    img_tensor = _get_transform()(img).unsqueeze(dim=0)

    model.eval()
    with torch.inference_mode():
        predicted_probs = model(img_tensor).softmax(dim=1)

    # Class name & predicted probability for each class (required by Gradio).
    # Convert the single batch row to Python floats in one call rather than
    # indexing the tensor once per class.
    pred_labels_probs = dict(zip(class_names, predicted_probs[0].tolist()))
    return pred_labels_probs, round(timer() - time_start, 5)
# One single-input example per file in ./examples.
# - os.listdir returns entries in arbitrary order; sorting keeps the examples
#   order stable across machines and deployments.
# - Hidden files (e.g. .gitkeep, .DS_Store) are skipped because they are not images.
example_paths = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]

app = gr.Interface(
    fn=predict,  # mapping function for [ input -> output ]
    inputs=gr.Image(type="pil"),  # Input data
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),  # class -> probability dict from predict
        gr.Number(label="Inference time (s)"),  # elapsed seconds from predict
    ],
    examples=example_paths,
    title='ConvNeXt_Food101',
    description='A ConvNext CV model to classify 101 foods',
    article='Model trained on 150 images per class',
)
app.launch()