"""Gradio demo that classifies a chest X-ray image as NORMAL or PNEUMONIA.

Loads a fine-tuned ViT checkpoint onto the CPU once at start-up and serves a
single-image prediction function through a ``gr.Interface``.
"""

import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch

from model import create_vit

# Order matters: a rounded sigmoid output of 1 maps to index 1 ("PNEUMONIA").
class_names = ["NORMAL", "PNEUMONIA"]

WEIGHTS_PATH = "finetuned_vit_b_16_pneumonia_feature_extractor.pth"
EXAMPLES_DIR = "examples"

# Build the model and its matching preprocessing, then load the fine-tuned
# weights onto the CPU so the app also runs on machines without a GPU.
vit_model, vit_transforms = create_vit(seed=42)
vit_model.load_state_dict(
    torch.load(
        f=WEIGHTS_PATH,
        map_location=torch.device("cpu"),
    )
)
# Inference-only app: switch to eval mode once here instead of on every request.
vit_model.eval()


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one X-ray image and time the prediction.

    Args:
        img: The input image. The interface below is configured with
            ``gr.Image(type="pil")``, so this is a PIL image.

    Returns:
        A tuple ``(pred_probs, pred_time)`` where ``pred_probs`` maps each
        name in ``class_names`` to its probability (the format ``gr.Label``
        uses to show ranked confidences) and ``pred_time`` is the elapsed
        time in seconds, rounded to 5 decimal places.
    """
    start_timer = timer()

    # Preprocess and add a leading batch dimension.
    # NOTE(review): presumably the transform yields a (C, H, W) tensor - confirm
    # against `create_vit`.
    batch = vit_transforms(img).unsqueeze(0)

    with torch.inference_mode():
        # The model emits a single value per image; sigmoid turns it into the
        # probability of the positive class (index 1).
        pneumonia_prob = torch.sigmoid(vit_model(batch)).squeeze().item()

    pred_probs = {
        class_names[0]: 1.0 - pneumonia_prob,
        class_names[1]: pneumonia_prob,
    }

    pred_time = round(timer() - start_timer, 5)
    return pred_probs, pred_time


title = "Detect Pneumonia from chest X-Ray"
description = "A ViT feature extractor Computer Vision model to detect Pneumonia from X-Ray Images."
article = "Access project repository at [GitHub](https://github.com/Ammar2k/pneumonia_detection)"

# One single-input example row per file; sorted so the gallery order is stable
# across platforms (os.listdir returns entries in arbitrary order).
example_list = [
    [os.path.join(EXAMPLES_DIR, example)]
    for example in sorted(os.listdir(EXAMPLES_DIR))
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time(s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    demo.launch()