# BirdClassifier / app.py
# Last commit 94e002e by taufiqdp: "Update app.py"
import torch
import torchvision
import gradio as gr
import pathlib
import random
from torch import nn
from typing import Tuple, Dict
from PIL import Image
from timeit import default_timer as timer
from typing import Tuple, Dict
# Run inference on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_class_names(path: str = 'class-names.txt') -> list:
    """Read one class name per line from ``path``.

    Blank lines (including a trailing newline or a missing one) are ignored,
    and surrounding whitespace such as a Windows ``\\r`` is stripped. Slicing
    off the last element instead would silently drop a real class whenever
    the file lacks a trailing newline.

    Args:
        path: Text file with one class label per line.

    Returns:
        The class names, in file order.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [name.strip() for name in f.read().splitlines() if name.strip()]


# The order must match the output units of the classifier head.
class_names = read_class_names()
def load_model() -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Build the fine-tuned ShuffleNetV2 (x1.5) bird classifier.

    The architecture's ``fc`` head is replaced with a linear layer that has one
    output per entry in ``class_names``. All parameters are then loaded from
    the ``ShuffleNetV2.pt`` checkpoint.

    Returns:
        ``(model, transforms)``: the model moved to ``device`` and set to eval
        mode, plus the preprocessing preset bundled with the ImageNet weights
        enum (an ``nn.Module`` callable on a PIL image).
    """
    weights = torchvision.models.ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1
    # Only the preprocessing recipe is taken from the weights enum. The
    # checkpoint below is loaded strictly and overwrites every parameter, so
    # downloading the ImageNet backbone weights first would be wasted work.
    shufflenet_transforms = weights.transforms()
    shufflenet = torchvision.models.shufflenet_v2_x1_5(weights=None)
    shufflenet.fc = nn.Linear(
        in_features=shufflenet.fc.in_features,  # read from the model, not hard-coded
        out_features=len(class_names),
        bias=True,
    )
    # NOTE(review): torch.load unpickles the file; only ever point it at a
    # trusted checkpoint.
    state_dict = torch.load('ShuffleNetV2.pt', map_location=device)
    shufflenet.load_state_dict(state_dict)
    shufflenet.to(device)
    shufflenet.eval()
    return shufflenet, shufflenet_transforms


model, transforms = load_model()
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and report how long inference took.

    Args:
        img: The uploaded image as a ``PIL.Image.Image`` (the Gradio input is
            declared with ``type='pil'``), or ``None`` when nothing was uploaded.

    Returns:
        ``(pred_dict, pred_time)``: a mapping from every class name to its
        softmax probability, and the elapsed time in seconds rounded to
        5 decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error('Please upload an image first.')
    start = timer()
    # Idempotent; these are effectively no-ops once the model is already on
    # `device` and in eval mode.
    model.to(device)
    model.eval()
    # The network expects 3-channel input. RGBA/grayscale/palette images
    # would otherwise fail inside normalisation or the first conv layer.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    with torch.inference_mode():
        transformed_img = transforms(img).to(device)
        logits = model(transformed_img.unsqueeze(dim=0))
        # A single device->host copy instead of one `.item()` sync per class.
        pred_probs = torch.softmax(logits, dim=1).squeeze(0).tolist()
    pred_dict = dict(zip(class_names, pred_probs))
    pred_time = round(timer() - start, 5)
    return pred_dict, pred_time
def sample_examples(paths, k=6):
    """Pick up to ``k`` random example images for the Gradio examples gallery.

    Args:
        paths: Candidate image paths (``str`` or ``pathlib.Path``).
        k: Maximum number of examples to return.

    Returns:
        A list of single-element ``[path_str]`` rows (one input per row), or
        ``None`` when ``paths`` is empty so the interface gets no examples.
    """
    candidates = list(paths)
    if not candidates:
        return None
    # random.sample raises ValueError if k exceeds the population size.
    chosen = random.sample(candidates, k=min(k, len(candidates)))
    return [[str(path)] for path in chosen]


# Sorted so the sampled set is reproducible under a fixed random seed.
example_paths = sorted(pathlib.Path('examples').glob("*/*.jpg"))
example_list = sample_examples(example_paths, k=6)
# Static texts for the Gradio page; description and article contain Markdown links.
title = 'Bird Species Classifier 🐦'
description = (
    'A [ShuffleNetV2](https://pytorch.org/vision/main/models/shufflenetv2.html) '
    'feature extractor computer vision model to classify images of '
    '[525 bird species](https://www.kaggle.com/datasets/gpiosenka/100-bird-species/).'
)
article = (
    'Made with ❤️🤗 by '
    '[me](https://www.linkedin.com/in/taufiq-dwi-purnomo/).'
)
# Single-image classifier UI: a PIL image in; top-3 labels and latency out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='Predictions'),
        gr.Number(label="Prediction time (s)"),
    ],
    description=description,
    title=title,
    allow_flagging='never',
    examples=example_list,
    article=article,
)

# Start the server only when run as a script (`python app.py`), so merely
# importing this module does not block on launching a web server.
if __name__ == '__main__':
    demo.launch(debug=False)