BirdClassifier / app.py
taufiqdp's picture
Update app.py
f3dc754
raw history blame
No virus
2.14 kB
import torch
import torchvision
import gradio as gr
import pathlib
import random
from torch import nn
from typing import Tuple, Dict
from PIL import Image
from timeit import default_timer as timer
from typing import Tuple, Dict
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# One class name per line. Line order must match the classifier's output indices,
# because predictions are paired with these names positionally.
with open('class-names.txt', 'r', encoding='utf-8') as f:
    # splitlines() yields no empty trailing entry for a final newline, keeps the last
    # name when the file has no trailing newline (the previous `split('\n')[:-1]`
    # dropped it), and strips '\r' from CRLF line endings.
    class_names = f.read().splitlines()
def load_model(checkpoint_path: str = 'ShuffleNetV2.pt') -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Build the fine-tuned ShuffleNetV2 x1.5 bird classifier and its preprocessing.

    The architecture is torchvision's ``shufflenet_v2_x1_5`` with the final ``fc``
    layer replaced by a linear head sized to ``len(class_names)``. All parameters
    are then loaded from ``checkpoint_path``.

    Args:
        checkpoint_path: Path to the saved ``state_dict`` of the fine-tuned model.

    Returns:
        ``(model, transforms)``: the classifier on ``device`` in eval mode, and the
        preprocessing transform associated with the ImageNet pretrained weights.

    Raises:
        RuntimeError: if the checkpoint's keys or shapes do not match the model
            (``load_state_dict`` is strict by default).
    """
    weights = torchvision.models.ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1
    # Only the preprocessing recipe is taken from the pretrained-weights enum.
    shufflenet_transforms = weights.transforms()
    # weights=None: the strict load_state_dict below overwrites every parameter and
    # buffer, so fetching the ImageNet weights at startup would be wasted work.
    shufflenet = torchvision.models.shufflenet_v2_x1_5(weights=None)
    # Size the new head from the backbone instead of hard-coding its width.
    shufflenet.fc = nn.Linear(in_features=shufflenet.fc.in_features, out_features=len(class_names), bias=True)
    # The checkpoint is a plain state_dict, so weights_only=True is sufficient and
    # avoids unpickling arbitrary objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    shufflenet.load_state_dict(state_dict)
    shufflenet.to(device)
    shufflenet.eval()
    return shufflenet, shufflenet_transforms


# Build the classifier once at import time; predict() reuses these module-level objects.
model, transforms = load_model()
def predict(img) -> Tuple[Dict, float]:
    """Classify one bird image and time the call.

    Args:
        img: The image to classify, as delivered by ``gr.Image(type='pil')``
            (a PIL image).

    Returns:
        ``(pred_dict, pred_time)``. ``pred_dict`` maps every class name to its
        softmax probability. ``pred_time`` is the wall-clock seconds spent in this
        call, rounded to 5 decimal places.
    """
    start = timer()
    # The backbone expects 3-channel input; RGBA, greyscale or palette images would
    # otherwise produce a tensor with the wrong number of channels.
    if getattr(img, 'mode', 'RGB') != 'RGB':
        img = img.convert('RGB')
    # Kept here so predict() works regardless of how the model was prepared.
    model.to(device)
    model.eval()
    with torch.inference_mode():
        transformed_img = transforms(img).to(device)
        # Add a batch dimension of 1 for the single image.
        logits = model(transformed_img.unsqueeze(dim=0))
        pred_prob = torch.softmax(logits, dim=1)
    # One bulk device-to-host copy instead of a separate .item() call per class.
    probs = pred_prob.squeeze(0).tolist()
    pred_dict = dict(zip(class_names, probs))
    pred_time = round(timer() - start, 5)
    return pred_dict, pred_time
# Offer up to five randomly chosen sample images as clickable examples.
example_paths = list(pathlib.Path('examples').glob("*.jpg"))
# random.sample raises ValueError when k exceeds the population, so cap k at the
# number of example images actually present.
example_list = [[str(filepath)] for filepath in random.sample(example_paths, k=min(5, len(example_paths)))]
title = 'Birds Species Classifier 🐦'
description = 'A [ShuffleNetV2](https://pytorch.org/vision/main/models/shufflenetv2.html) feature extractor computer vision model to classify images of [525 species birds](https://www.kaggle.com/datasets/gpiosenka/100-bird-species/).'
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    # predict() returns (class -> probability dict, elapsed seconds).
    outputs=[gr.Label(num_top_classes=3, label='Predictions'),
             gr.Number(label="Prediction time (s)")],
    description=description,
    title=title,
    allow_flagging='never',
    # None (no examples section) when the examples folder has no .jpg files.
    examples=example_list or None
)
demo.launch(debug=False)