|
|
|
import gradio as gr |
|
from create_effnet import create_effnetb2_instance |
|
from typing import Dict, Tuple, List |
|
from PIL import Image |
|
from torch import nn |
|
import torchvision |
|
from pathlib import Path |
|
from timeit import default_timer as timer |
|
import torch |
|
|
|
# Cache of ready-to-use (model, transform) pairs keyed by the number of output
# classes. Building EfficientNet-B2 and reading the checkpoint from disk is far
# more expensive than a single forward pass, so it is done once per process
# instead of on every request.
_MODEL_CACHE: Dict[int, tuple] = {}


def _load_model(num_classes: int) -> tuple:
    """Return the cached ``(model, transform)`` pair, building it on first use.

    The model is created on CPU via ``create_effnetb2_instance``, loaded with
    the weights stored in ``EfficientNet_B2_10_Epochs.pth`` and put in eval mode.
    """
    cached = _MODEL_CACHE.get(num_classes)
    if cached is None:
        model, transform = create_effnetb2_instance(num_classes=num_classes,
                                                    device="cpu")
        weight_path = Path("EfficientNet_B2_10_Epochs.pth")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(state_dict=torch.load(f=weight_path,
                                                    map_location="cpu"))
        model.eval()
        cached = (model, transform)
        _MODEL_CACHE[num_classes] = cached
    return cached


def predict(img,
            classes: Tuple[str, ...] = ("pizza", "steak", "sushi"),
            ) -> Tuple[Dict[str, float], float]:
    """Classify ``img`` and report per-class probabilities and inference time.

    Args:
        img: The image delivered by ``gr.Image(type="pil")`` (a PIL image, or
            None when the user submitted without uploading anything).
        classes: Class names in the order of the model's output logits. Its
            length is the number of outputs the model is built with.

    Returns:
        A ``(label_probs, seconds)`` tuple: a dict mapping each class name to
        its softmax probability, and the time spent on the transform plus the
        forward pass, rounded to 4 decimal places.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    model, transform = _load_model(len(classes))

    # Drop any alpha channel / palette so the transform always gets 3 channels.
    img = img.convert("RGB")

    start = timer()
    transformed_image = transform(img)
    with torch.inference_mode():
        # Add a leading batch dimension of size 1 before the forward pass.
        logits = model(transformed_image.unsqueeze(dim=0))
        pred_probs = torch.softmax(input=logits, dim=1)
    end = timer()

    pred_label_prob_dict = {
        class_name: prob
        for class_name, prob in zip(classes, pred_probs.squeeze(dim=0).tolist())
    }
    pred_time = round(end - start, 4)
    return pred_label_prob_dict, pred_time
|
|
|
# Folder of sample images offered as clickable examples in the UI.
example_img_path = Path("examples")

title = "Pizza 🍕, Steak 🥩, Sushi 🍣 Classifier"

description = "Image Classifier which works on Feature Extractor EffNet B2 model and trained on Pizza, Steak, Sushi dataset"

# File extensions treated as example images; anything else in the folder
# (e.g. .DS_Store, notes, sub-directories) is ignored.
_EXAMPLE_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}

# One inner list per example because the interface has a single input.
# Sorted so the example order is stable (iterdir() order is OS-dependent);
# a missing folder yields no examples instead of crashing at import time.
example_images = (
    [
        [str(img)]
        for img in sorted(example_img_path.iterdir())
        if img.is_file() and img.suffix.lower() in _EXAMPLE_IMAGE_SUFFIXES
    ]
    if example_img_path.is_dir()
    else []
)

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Model thinks"),
        gr.Number(label="Time Taken (in seconds)"),
    ],
    title=title,
    description=description,
    # Pass None rather than an empty list so no empty examples section is built.
    examples=example_images or None,
)


if __name__ == "__main__":
    # share=True additionally requests a temporary public share link besides
    # the local server.
    interface.launch(share=True)
|
|
|
|
|
|