luigi12345 commited on
Commit
7a986e7
1 Parent(s): 053a66b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -49
app.py CHANGED
@@ -1,99 +1,131 @@
1
- # Copyright (C) 2023, Xu Sun.
2
-
3
- # This program is licensed under the Apache License version 2.
4
- # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
-
6
  import torch
7
  import numpy as np
8
-
 
 
9
  import matplotlib.pyplot as plt
10
  import streamlit as st
11
-
12
  from PIL import Image
13
- from glaucoma import GlaucomaModel
14
-
15
- run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
-
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def main():
19
- # Wide mode
20
  st.set_page_config(layout="wide")
21
 
22
- # Designing the interface
23
  st.title("Glaucoma Screening from Retinal Fundus Images")
24
- # For newline
25
- st.write('\n')
26
- # Author info
27
  st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
28
- # For newline
29
- st.write('\n')
30
- # Instructions
31
- st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
32
- # Set the columns
33
  cols = st.beta_columns((1, 1, 1))
34
  cols[0].subheader("Input image")
35
  cols[1].subheader("Optic disc and optic cup")
36
- cols[2].subheader("Class activation map")
37
-
38
- # set the visualization figure
39
- fig, ax = plt.subplots()
40
 
41
- # Sidebar
42
- # File selection
43
  st.sidebar.title("Image selection")
44
- # Disabling warning
45
  st.set_option('deprecation.showfileUploaderEncoding', False)
46
- # Choose your own image
47
  uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
 
48
  if uploaded_file is not None:
49
- # read the upload image
50
  image = Image.open(uploaded_file).convert('RGB')
51
  image = np.array(image).astype(np.uint8)
52
- # page_idx = 0
53
  ax.imshow(image)
54
  ax.axis('off')
55
  cols[0].pyplot(fig)
56
 
57
- # For newline
58
- st.sidebar.write('\n')
59
-
60
- # actions
61
  if st.sidebar.button("Analyze image"):
62
-
63
  if uploaded_file is None:
64
  st.sidebar.write("Please upload an image")
65
-
66
  else:
67
  with st.spinner('Loading model...'):
68
- # load model
 
69
  model = GlaucomaModel(device=run_device)
70
 
71
  with st.spinner('Analyzing...'):
72
- # Forward the image to the model and get results
73
- disease_idx, disc_cup_image, cam, vcdr = model.process(image)
74
 
75
- # plot the optic disc and optic cup image
76
  ax.imshow(disc_cup_image)
77
  ax.axis('off')
78
  cols[1].pyplot(fig)
79
 
80
- # plot the stitched image
81
- ax.imshow(cam)
82
  ax.axis('off')
83
  cols[2].pyplot(fig)
84
 
85
- # Display JSON
86
- st.subheader(" Screening results:")
87
- st.write('\n')
88
-
89
  final_results_as_table = f"""
90
  |Parameters|Outcomes|
91
  |---|---|
92
  |Vertical cup-to-disc ratio|{vcdr:.04f}|
93
- |Category|{model.cls_id2label[disease_idx]}|
 
 
94
  """
95
  st.markdown(final_results_as_table)
96
 
97
-
98
  if __name__ == '__main__':
99
  main()
 
1
+ import cv2
 
 
 
 
2
  import torch
3
  import numpy as np
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
7
  import matplotlib.pyplot as plt
8
  import streamlit as st
 
9
  from PIL import Image
 
 
 
 
10
 
11
+ # --- GlaucomaModel Class ---
12
+ class GlaucomaModel(object):
13
+ def __init__(self,
14
+ cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
15
+ seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
16
+ device=torch.device('cpu')):
17
+ self.device = device
18
+ # Classification model for glaucoma
19
+ self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
20
+ self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
21
+ # Segmentation model for optic disc and cup
22
+ self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
23
+ self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
24
+
25
+ # Class activation map
26
+ self.cls_id2label = self.cls_model.config.id2label
27
+ self.seg_id2label = self.seg_model.config.id2label
28
+
29
+ def glaucoma_pred(self, image):
30
+ inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
31
+ with torch.no_grad():
32
+ inputs.to(self.device)
33
+ outputs = self.cls_model(**inputs).logits
34
+ # Softmax for probabilities
35
+ probs = F.softmax(outputs, dim=-1)
36
+ disease_idx = probs.cpu()[0, :].numpy().argmax()
37
+ confidence = probs.cpu()[0, disease_idx].item()
38
+ return disease_idx, confidence
39
+
40
+ def optic_disc_cup_pred(self, image):
41
+ inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
42
+ with torch.no_grad():
43
+ inputs.to(self.device)
44
+ outputs = self.seg_model(**inputs)
45
+ logits = outputs.logits.cpu()
46
+ upsampled_logits = nn.functional.interpolate(
47
+ logits, size=image.shape[:2], mode="bilinear", align_corners=False
48
+ )
49
+ # Softmax for segmentation confidence
50
+ seg_probs = F.softmax(upsampled_logits, dim=1)
51
+ pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
52
+ cup_confidence = seg_probs[0, 2, :, :].mean().item()
53
+ disc_confidence = seg_probs[0, 1, :, :].mean().item()
54
+ return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
55
+
56
+ def process(self, image):
57
+ image_shape = image.shape[:2]
58
+ disease_idx, cls_confidence = self.glaucoma_pred(image)
59
+ disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
60
+ try:
61
+ vcdr = simple_vcdr(disc_cup) # Assuming simple_vcdr() is defined elsewhere
62
+ except:
63
+ vcdr = np.nan
64
+ _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
65
+ return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence
66
+
67
+ # --- Streamlit Interface ---
68
  def main():
69
+ # Wide mode in Streamlit
70
  st.set_page_config(layout="wide")
71
 
 
72
  st.title("Glaucoma Screening from Retinal Fundus Images")
 
 
 
73
  st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
74
+
75
+ # Set columns for the interface
 
 
 
76
  cols = st.beta_columns((1, 1, 1))
77
  cols[0].subheader("Input image")
78
  cols[1].subheader("Optic disc and optic cup")
79
+ cols[2].subheader("Classification Map")
 
 
 
80
 
81
+ # File uploader
 
82
  st.sidebar.title("Image selection")
 
83
  st.set_option('deprecation.showfileUploaderEncoding', False)
 
84
  uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
85
+
86
  if uploaded_file is not None:
87
+ # Read and display uploaded image
88
  image = Image.open(uploaded_file).convert('RGB')
89
  image = np.array(image).astype(np.uint8)
90
+ fig, ax = plt.subplots()
91
  ax.imshow(image)
92
  ax.axis('off')
93
  cols[0].pyplot(fig)
94
 
 
 
 
 
95
  if st.sidebar.button("Analyze image"):
 
96
  if uploaded_file is None:
97
  st.sidebar.write("Please upload an image")
 
98
  else:
99
  with st.spinner('Loading model...'):
100
+ # Load the model on available device
101
+ run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
102
  model = GlaucomaModel(device=run_device)
103
 
104
  with st.spinner('Analyzing...'):
105
+ # Get predictions from the model
106
+ disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence = model.process(image)
107
 
108
+ # Display optic disc and cup image
109
  ax.imshow(disc_cup_image)
110
  ax.axis('off')
111
  cols[1].pyplot(fig)
112
 
113
+ # Display classification map
114
+ ax.imshow(image)
115
  ax.axis('off')
116
  cols[2].pyplot(fig)
117
 
118
+ # Display results with confidence
119
+ st.subheader("Screening results:")
 
 
120
  final_results_as_table = f"""
121
  |Parameters|Outcomes|
122
  |---|---|
123
  |Vertical cup-to-disc ratio|{vcdr:.04f}|
124
+ |Category|{model.cls_id2label[disease_idx]} ({cls_confidence*100:.02f}% confidence)|
125
+ |Optic Cup Segmentation Confidence|{cup_confidence*100:.02f}%|
126
+ |Optic Disc Segmentation Confidence|{disc_confidence*100:.02f}%|
127
  """
128
  st.markdown(final_results_as_table)
129
 
 
130
  if __name__ == '__main__':
131
  main()