from pathlib import Path
from timeit import default_timer as timer
from typing import Callable, Dict, List, Tuple

import gradio as gr
import torch
import torchvision
from PIL import Image

from model import create_vit_instance

# Reading all available classes (one class name per line).
with open('class_names.txt', 'r') as f:
    all_classes = [name.replace('\n', '') for name in f.readlines()]

# Build the ViT model and its matching preprocessing transform, then load the
# trained weights onto the CPU.
demo_vit_model, demo_vit_transforms = create_vit_instance(num_classes=len(all_classes), device='cpu')
weights_path = Path("ViT_Caltech101_five_epochs.pth")
demo_vit_model.load_state_dict(torch.load(f=weights_path, map_location='cpu'))


def predict(img_path: Image.Image,
            model: torch.nn.Module = demo_vit_model,
            transform: Callable[[Image.Image], torch.Tensor] = demo_vit_transforms,
            classes: List[str] = all_classes) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long the forward pass took.

    Args:
        img_path: The image to classify. Despite the name, this is the image
            object itself (the Gradio input component is declared with
            ``type="pil"``), not a filesystem path; it is handed straight to
            ``transform``.
        model: Classifier to run. It is moved to the CPU and switched to eval
            mode before inference.
        transform: Preprocessing callable applied to the image; its result
            must support ``unsqueeze`` so a batch dimension can be added.
        classes: Class names, indexed by the model's output position.

    Returns:
        A ``(probabilities, seconds)`` tuple: a dict mapping each class name
        to its softmax probability, and the elapsed time in seconds rounded
        to four decimal places.

    Raises:
        IndexError: If the model emits more outputs than there are entries in
            ``classes``.
    """
    model = model.to('cpu')
    transformed_image = transform(img_path)

    # Only the forward pass (including the eval-mode switch) is timed;
    # preprocessing above is deliberately outside the measured window.
    start = timer()
    model.eval()
    with torch.inference_mode():
        batch_img = transformed_image.unsqueeze(dim=0).to(device='cpu')
        logit = model(batch_img)
        pred_probs = torch.softmax(input=logit, dim=1)
    end = timer()
    total_time = round(end - start, 4)

    # Row 0 is the single image in the batch; pair each probability with the
    # class name at the same index.
    pred_prob_dict = {
        classes[idx]: prob for idx, prob in enumerate(pred_probs[0].tolist())
    }
    return (pred_prob_dict, total_time)


title = "ObjectVision"
description = "ViT Feature Extractor trained for Image Classification based on Caltech101 dataset."
# Every entry found in the "examples" directory becomes one example row; each
# row is a single-element list because the interface has exactly one input.
samples = [[example_file] for example_file in Path("examples").iterdir()]

# Input: the uploaded image is delivered to the callback as a PIL image.
image_input = gr.Image(type="pil")

# Outputs, in the same order as the tuple returned by the callback: a label
# widget limited to the five top classes, then the timing figure.
label_output = gr.Label(num_top_classes=5, label="Model thinks")
time_output = gr.Number(label="Prediction time (in seconds)")

demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=image_input,
    examples=samples,
    outputs=[label_output, time_output],
)

if __name__ == "__main__":
    demo.launch(debug=True)