bhimrazy commited on
Commit
65847f8
1 Parent(s): 624f6e5

Adds gradio app

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from src.model import DRModel
4
+ from torchvision import transforms as T
5
+
6
+ CHECKPOINT_PATH = "checkpoints/epoch=19-step=8800.ckpt"
7
+ model = DRModel.load_from_checkpoint(CHECKPOINT_PATH)
8
+
9
+ labels = {
10
+ 0: "No DR",
11
+ 1: "Mild",
12
+ 2: "Moderate",
13
+ 3: "Severe",
14
+ 4: "Proliferative DR",
15
+ }
16
+
17
+ transform = T.Compose(
18
+ [
19
+ T.Resize((192, 192)),
20
+ T.ToTensor(),
21
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
22
+ ]
23
+ )
24
+
25
+ # Define the prediction function
26
+ def predict(input_img):
27
+ input_img = transform(input_img).unsqueeze(0)
28
+ with torch.no_grad():
29
+ prediction = torch.nn.functional.softmax(model(input_img)[0], dim=0)
30
+ confidences = {labels[i]: float(prediction[i]) for i in labels}
31
+ return confidences
32
+
33
+
34
+ # Set up the Gradio app interface
35
+ dr_app = gr.Interface(
36
+ fn=predict,
37
+ inputs=gr.Image(type="pil"),
38
+ outputs=gr.Label(),
39
+ title="Diabetic Retinopathy Detection",
40
+ examples=[
41
+ "data/sample/10_left.jpeg",
42
+ "data/sample/10_right.jpeg",
43
+ "data/sample/15_left.jpeg",
44
+ "data/sample/16_right.jpeg",
45
+ ],
46
+ )
47
+
48
+ # Run the Gradio app
49
+ if __name__ == "__main__":
50
+ dr_app.launch()