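# NOTE: the following lines are file-viewer page residue, kept as comments so the file parses as Python.
# BigTimeCoderSean's picture
# Update app.py
# a972ea6
# raw
# history blame contribute delete
# No virus
# 2.1 kB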
import os
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import torch
import torchvision
from torch import nn
from torchvision import models
# Class names: one label per line of class_names.txt, with surrounding
# whitespace (including the trailing newline) stripped from each entry.
with open('class_names.txt', "r") as names_file:
    class_names = [line.strip() for line in names_file]
# Model and transforms preparation.
# The weights enum also supplies the preprocessing pipeline that turns an
# input image into the tensor format the network expects.
effnetb0_weights = models.EfficientNet_B0_Weights.DEFAULT
effnetb0 = models.efficientnet_b0(weights=effnetb0_weights)
effnetb0_transforms = effnetb0_weights.transforms()

# Freeze params: this app only runs inference, no gradients are needed.
for param in effnetb0.parameters():
    param.requires_grad = False

# Change classifier: replace the head with a 1280 -> 196 linear layer so the
# module structure matches the saved checkpoint loaded below.
effnetb0.classifier = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=1280, out_features=196),
)

# Load saved weights.
# `map_location` is an argument of `torch.load`, not of `load_state_dict`;
# passing it to `load_state_dict` raises a TypeError. Mapping to CPU lets the
# checkpoint load on machines without a GPU.
effnetb0.load_state_dict(
    torch.load(
        'pretrained_effnetb0_stanford_cars_20_percent.pth',
        map_location=torch.device("cpu"),
    )
)
#predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify an image and report how long the prediction took.

    Args:
        img: The input image as delivered by the Gradio ``Image(type="pil")``
            input component, i.e. a PIL image.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)`` where
        ``pred_labels_and_probs`` maps every entry of ``class_names`` to its
        softmax probability (the format Gradio's ``Label`` output expects) and
        ``pred_time`` is the elapsed time in seconds, rounded to 5 decimals.
    """
    start_time = timer()

    # Gradio hands over a PIL image, which has no `unsqueeze`; run it through
    # the model's preprocessing transforms to get a tensor, then add the
    # batch dimension the network expects.
    batch = effnetb0_transforms(img).unsqueeze(0)

    # Put model into eval mode and disable gradient tracking for inference.
    effnetb0.eval()
    with torch.inference_mode():
        pred_logits = effnetb0(batch)
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Create a prediction label -> probability dictionary covering every class.
    probs = pred_probs[0]
    pred_labels_and_probs = {
        name: float(probs[i]) for i, name in enumerate(class_names)
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
# Gradio app
title = 'effnetb0'
description = 'Pretrained effnetb0 model on stanford cars dataset'

# Create example list: one single-element row (the file path) per file found
# in the "examples" directory. Requires `os` to be imported at the top of the
# file; without that import this line raises NameError.
example_list = [
    [os.path.join("examples", example)] for example in os.listdir("examples")
]

# Create Gradio interface: one image input, two outputs matching the
# (label -> probability dict, elapsed seconds) tuple returned by `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
)

# Launch the app!
demo.launch()