Ammar2k commited on
Commit
d246b53
1 Parent(s): 4aef0b2

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ from model import create_vit
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+
9
+ class_names = ["NORMAL", "PNEUMONIA"]
10
+
11
+ vit_model, vit_transforms = create_vit(seed=42)
12
+
13
+ vit_model.load_state_dict(
14
+ torch.load(
15
+ f="finetuned_vit_b_16_pneumonia_feature_extractor.pth",
16
+ map_location=torch.device("cpu")
17
+ )
18
+ )
19
+
20
+ def predict(img):
21
+ start_timer = timer()
22
+
23
+ img = vit_transforms(img).unsqueeze(0)
24
+
25
+ vit_model.eval()
26
+ with torch.inference_mode():
27
+ pred_prob_int = torch.sigmoid(vit_model(img)).round().int().squeeze()
28
+
29
+ if pred_prob_int.item() == 1:
30
+ class_name = class_names[1]
31
+ else:
32
+ class_name = class_names[0]
33
+
34
+ pred_time = round(timer() - start_timer, 5)
35
+
36
+ return class_name, pred_time
37
+
38
+ title = "Detect Pneumonia from chest X-Ray"
39
+ description = "A ViT feature extractor Computer Vision model to detect Pneumonia from X-Ray Images."
40
+ article = "Access project repository at [GitHub](https://github.com/Ammar2k/pneumonia_detection)"
41
+
42
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
43
+
44
+ demo = gr.Interface(fn=predict,
45
+ inputs=gr.Image(type="pil"),
46
+ outputs=[gr.Label(num_top_classes=6, label="Predictions"),
47
+ gr.Number(label="Prediction time(s)")],
48
+ examples=example_list,
49
+ title=title,
50
+ description=description,
51
+ article=article
52
+ )
53
+
54
+ demo.launch()
examples/IM-0117-0001.jpeg ADDED
examples/IM-0154-0001.jpeg ADDED
examples/person16_bacteria_54.jpeg ADDED
examples/person3_bacteria_10.jpeg ADDED
finetuned_vit_b_16_pneumonia_feature_extractor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1737cbcc3556394cf4f82fa7f28cb985d11b32212adeff79fad3d7259674b17a
3
+ size 343270545
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+
4
+ def create_vit(seed: int=42):
5
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
6
+
7
+ transforms = weights.transforms()
8
+
9
+ model = torchvision.models.vit_b_16(weights=weights)
10
+
11
+ for param in model.parameters():
12
+ param.requires_grad = False
13
+
14
+ torch.manual_seed(seed)
15
+ model.heads = torch.nn.Sequential(torch.nn.LayerNorm(normalized_shape=768),
16
+ torch.nn.Linear(in_features=768, out_features=1))
17
+ return model, transforms