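"""ObjectVision: a Gradio demo that classifies images with a ViT feature extractor
fine-tuned on the Caltech-101 dataset."""
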
import gradio as gr
import torch
import torchvision
from pathlib import Path
from PIL import Image
from timeit import default_timer as timer
from typing import Dict, List, Tuple

from model import create_vit_instance
# Read all available class names (one class name per line in class_names.txt)
with open('class_names.txt', 'r') as f:
    all_classes = [name.replace('\n', '') for name in f.readlines()]
# Build the ViT model and its matching transforms, then load the fine-tuned weights
demo_vit_model, demo_vit_transforms = create_vit_instance(num_classes=len(all_classes),
                                                          device='cpu')
weights_path = Path("ViT_Caltech101_five_epochs.pth")
demo_vit_model.load_state_dict(torch.load(f=weights_path,
                                          map_location='cpu'))
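# Note: create_vit_instance lives in model.py (not shown here); it is assumed to
# return a (model, transforms) pair, i.e. a ViT whose classification head has
# len(all_classes) outputs, plus the preprocessing pipeline the model expects.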
# Predict function: returns a {class_name: probability} dict and the time taken for the prediction
def predict(img: Image.Image,
            model: torch.nn.Module = demo_vit_model,
            transform=demo_vit_transforms,
            classes: List[str] = all_classes) -> Tuple[Dict[str, float], float]:
    model = model.to('cpu')
    # gr.Image(type="pil") already passes in a PIL image, so no Image.open is needed
    transformed_image = transform(img)

    start = timer()
    model.eval()
    with torch.inference_mode():
        batch_img = transformed_image.unsqueeze(dim=0).to(device='cpu')
        logits = model(batch_img)
        pred_probs = torch.softmax(input=logits, dim=1)
    end = timer()
    total_time = round(end - start, 4)

    # Map every class name to its predicted probability; gr.Label sorts these itself
    pred_probs = pred_probs[0].tolist()
    pred_prob_dict = {classes[idx]: pred_probs[idx] for idx in range(len(classes))}
    return pred_prob_dict, total_time
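# Quick local check (hypothetical filename; any image in the examples/ folder works):
#   probs, secs = predict(Image.open("examples/example_1.jpg"))
#   print(max(probs, key=probs.get), secs)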
title = "ObjectVision" | |
description = "ViT Feature Extractor trained for Image Classification based on Caltech101 dataset." | |
samples = [[path] for path in Path("examples").iterdir()] | |
demo = gr.Interface(fn=predict, | |
title=title, | |
description=description, | |
inputs=gr.Image(type="pil"), | |
examples=samples, | |
outputs=[ | |
gr.Label(num_top_classes=5, | |
label="Model thinks"), | |
gr.Number(label="Prediction time (in seconds)") | |
]) | |
if __name__ == "__main__":
    demo.launch(debug=True)