# itzRahul's picture
# Update app.py
# e35dcdb
import gradio as gr
from model import create_vit_instance
from pathlib import Path
import torch
import torchvision
from PIL import Image
from typing import List, Dict, Tuple
from timeit import default_timer as timer
# Load the class labels: one label per line of class_names.txt.
with open('class_names.txt', 'r') as class_file:
    all_classes = [line.rstrip('\n') for line in class_file]

# Build the ViT model and its paired input transforms on the CPU, with one
# output per class label read above.
demo_vit_model, demo_vit_transforms = create_vit_instance(
    num_classes=len(all_classes),
    device='cpu',
)

# Load the saved weights into the model, mapping all tensors onto the CPU.
weights_path = Path("ViT_Caltech101_five_epochs.pth")
demo_vit_model.load_state_dict(
    torch.load(f=weights_path, map_location='cpu')
)
def predict(img_path: "Image.Image",
            model: torch.nn.Module = demo_vit_model,
            transform: torchvision.transforms = demo_vit_transforms,
            classes: List[str] = all_classes) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long the prediction took.

    Args:
        img_path: The image to classify. The name is kept for backward
            compatibility, but the value is handed to ``transform`` as-is and
            is never opened from disk; the Gradio input in this file uses
            ``type="pil"``, so it arrives as a PIL image, not a path.
        model: Network that produces one logit per class for a batch of
            transformed images. It is moved to the CPU before use.
        transform: Callable applied to the image; its result must support
            ``unsqueeze`` so a batch dimension can be added.
        classes: Class names, indexed by the model's output position.

    Returns:
        A ``(probabilities, seconds)`` tuple. ``probabilities`` maps every
        class name to its softmax probability. ``seconds`` is the elapsed
        time, rounded to 4 decimals, spent switching to eval mode and running
        the forward pass and softmax (preprocessing is not included).
    """
    model = model.to('cpu')
    transformed_image = transform(img_path)

    start = timer()
    model.eval()
    with torch.inference_mode():
        # The model expects a batch, so add a leading dimension of size 1.
        batch_img = transformed_image.unsqueeze(dim=0).to(device='cpu')
        logit = model(batch_img)
        pred_probs = torch.softmax(input=logit, dim=1)
    end = timer()
    total_time = round(end - start, 4)

    # Row 0 is the only image in the batch. Indexing ``classes`` (rather than
    # zipping) keeps the IndexError if there are fewer names than outputs.
    probabilities = pred_probs[0].tolist()
    pred_prob_dict = {
        classes[idx]: prob for idx, prob in enumerate(probabilities)
    }
    return pred_prob_dict, total_time
# Heading and blurb passed to the Gradio interface.
title = "ObjectVision"
description = "ViT Feature Extractor trained for Image Classification based on Caltech101 dataset."

# One example row per entry in the "examples" directory. iterdir() yields
# entries in arbitrary order, so sort them to keep the example order stable
# between runs, and pass plain string paths to Gradio.
samples = [[str(path)] for path in sorted(Path("examples").iterdir())]

# predict() returns (probability dict, elapsed seconds); the two outputs below
# receive those values in that order.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=gr.Image(type="pil"),
    examples=samples,
    outputs=[
        gr.Label(num_top_classes=5, label="Model thinks"),
        gr.Number(label="Prediction time (in seconds)"),
    ],
)

if __name__ == "__main__":
    demo.launch(debug=True)