Rodiyah commited on
Commit
e473c8a
·
verified ·
1 Parent(s): 6523e27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -11,20 +11,19 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
11
  import os
12
  import datetime
13
 
 
14
  # Setup
 
15
  device = torch.device("cpu")
16
  save_dir = "/home/user/app/saved_predictions"
17
- if not os.path.exists(save_dir):
18
- os.makedirs(save_dir)
19
- print("📁 Folder created:", save_dir)
20
  os.makedirs(save_dir, exist_ok=True)
21
 
22
-
23
  # Placeholder image for invalid uploads
24
  invalid_img = Image.new("RGB", (224, 224), color=(200, 200, 200))
25
 
26
-
27
  # Load model
 
28
  model = models.resnet50(weights=None)
29
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
30
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
@@ -43,6 +42,9 @@ transform = transforms.Compose([
43
  [0.229, 0.224, 0.225])
44
  ])
45
 
 
 
 
46
  def looks_like_fundus(image):
47
  """
48
  Basic heuristic to check if an image is likely a retinal fundus scan.
@@ -56,6 +58,9 @@ def looks_like_fundus(image):
56
  # Fundus images usually occupy ~40–75% of the area
57
  return 0.40 < white_ratio < 0.75
58
 
 
 
 
59
  def predict_retinopathy(image):
60
  # Validate image first
61
  if not looks_like_fundus(image):
@@ -80,7 +85,10 @@ def predict_retinopathy(image):
80
  # Grad-CAM
81
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
82
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
83
- grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
 
 
 
84
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
85
  cam_pil = Image.fromarray(cam_image)
86
 
@@ -88,12 +96,14 @@ def predict_retinopathy(image):
88
  filename = f"{timestamp}_{label}_{confidence:.2f}.png"
89
  cam_pil.save(os.path.join(save_dir, filename))
90
 
91
- return (
92
  cam_pil,
93
  f"{label} (Confidence: {confidence:.2f})"
94
  )
95
 
 
96
  # Gradio app
 
97
  gr.Interface(
98
  fn=predict_retinopathy,
99
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
@@ -107,8 +117,7 @@ gr.Interface(
107
  "Upload an image to classify DR and visualise the Grad-CAM heatmap showing important regions."
108
  ),
109
  article=(
110
- "⚕️ **OpthaDetect** is an AI-powered ophthalmic diagnostic tool. "
111
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
112
  )
113
  ).launch()
114
-
 
11
  import os
12
  import datetime
13
 
14
+ # -----------------------
15
  # Setup
16
+ # -----------------------
17
  device = torch.device("cpu")
18
  save_dir = "/home/user/app/saved_predictions"
 
 
 
19
  os.makedirs(save_dir, exist_ok=True)
20
 
 
21
  # Placeholder image for invalid uploads
22
  invalid_img = Image.new("RGB", (224, 224), color=(200, 200, 200))
23
 
24
+ # -----------------------
25
  # Load model
26
+ # -----------------------
27
  model = models.resnet50(weights=None)
28
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
29
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
 
42
  [0.229, 0.224, 0.225])
43
  ])
44
 
45
+ # -----------------------
46
+ # Helper: basic fundus check
47
+ # -----------------------
48
  def looks_like_fundus(image):
49
  """
50
  Basic heuristic to check if an image is likely a retinal fundus scan.
 
58
  # Fundus images usually occupy ~40–75% of the area
59
  return 0.40 < white_ratio < 0.75
60
 
61
+ # -----------------------
62
+ # Predict and save
63
+ # -----------------------
64
  def predict_retinopathy(image):
65
  # Validate image first
66
  if not looks_like_fundus(image):
 
85
  # Grad-CAM
86
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
87
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
88
+ grayscale_cam = cam(
89
+ input_tensor=img_tensor,
90
+ targets=[ClassifierOutputTarget(pred)]
91
+ )[0]
92
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
93
  cam_pil = Image.fromarray(cam_image)
94
 
 
96
  filename = f"{timestamp}_{label}_{confidence:.2f}.png"
97
  cam_pil.save(os.path.join(save_dir, filename))
98
 
99
+ return (
100
  cam_pil,
101
  f"{label} (Confidence: {confidence:.2f})"
102
  )
103
 
104
+ # -----------------------
105
  # Gradio app
106
+ # -----------------------
107
  gr.Interface(
108
  fn=predict_retinopathy,
109
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
 
117
  "Upload an image to classify DR and visualise the Grad-CAM heatmap showing important regions."
118
  ),
119
  article=(
120
+ "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
121
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
122
  )
123
  ).launch()