Spaces:
Build error
Build error
import gradio as gr | |
import os | |
import torch | |
import torchvision | |
from model import create_model | |
from timeit import default_timer as timer | |
# Class labels, one per line; their count sets the size of the classifier head.
with open("class_names.txt", "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing empty line) so that len(class_names)
    # matches the number of outputs stored in the checkpoint below.
    class_names = [name for name in (line.strip() for line in f) if name]

# Build the architecture, then load the fine-tuned weights onto the CPU.
model, model_transforms = create_model(num_classes=len(class_names))
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load(
        "ConvNeXt_Tiny_101classes_20_10epochs.pth",
        map_location=torch.device("cpu"),
    )
)
model = model.to("cpu")  # Place the model on the CPU
model.eval()  # Evaluation mode for inference, set once at start-up
# Preprocessing pipeline bundled with the ConvNeXt-Tiny weights, built once at
# import time instead of on every request.
# NOTE(review): create_model() also returns `model_transforms`, which is presumably
# the same pipeline — TODO confirm and use it instead so the two cannot drift apart.
_convnext_tiny_transform = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT.transforms()


def predict(img):
    """Classify a single image and time the inference.

    Args:
        img: Image delivered by the Gradio ``gr.Image(type="pil")`` input.

    Returns:
        A tuple ``(pred_labels_probs, elapsed)``: ``pred_labels_probs`` maps every
        class name to its softmax probability (the dict format ``gr.Label``
        consumes) and ``elapsed`` is the wall-clock time of the call in seconds,
        rounded to 5 decimal places.
    """
    time_start = timer()

    # [Channels, Height, Width] -> [Batch_size, Channels, Height, Width]
    img_tensor = _convnext_tiny_transform(img).unsqueeze(dim=0)

    model.eval()
    with torch.inference_mode():
        predicted_probs = model(img_tensor).softmax(dim=1)

    # Convert the single probability row to Python floats in one call rather
    # than indexing the tensor once per class.
    probs = predicted_probs[0].tolist()
    # Class name & predicted probability for each class (required by Gradio)
    pred_labels_probs = {name: float(prob) for name, prob in zip(class_names, probs)}
    return pred_labels_probs, round(timer() - time_start, 5)
def _list_examples(directory="examples"):
    """Return Gradio example rows (``[[path], ...]``) for the files in *directory*.

    Entries are sorted because ``os.listdir`` returns them in arbitrary order,
    hidden files (e.g. ``.DS_Store``) and sub-directories are skipped, and a
    missing directory yields ``None`` (no examples) instead of crashing start-up.
    """
    if not os.path.isdir(directory):
        return None
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]


app = gr.Interface(
    fn=predict,  # mapping function for [ input -> output ]
    inputs=gr.Image(type="pil"),  # Input data
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),  # Output data (fn function's return values)
        gr.Number(label="Inference time (s)"),
    ],
    examples=_list_examples("examples"),
    title='ConvNeXt_Food101',
    description='A ConvNext CV model to classify 101 foods',
    article='Model trained on 150 images per class',
)
app.launch()