luigi12345
commited on
Commit
•
7a986e7
1
Parent(s):
053a66b
Update app.py
Browse files
app.py
CHANGED
@@ -1,99 +1,131 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
# This program is licensed under the Apache License version 2.
|
4 |
-
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
|
5 |
-
|
6 |
import torch
|
7 |
import numpy as np
|
8 |
-
|
|
|
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
import streamlit as st
|
11 |
-
|
12 |
from PIL import Image
|
13 |
-
from glaucoma import GlaucomaModel
|
14 |
-
|
15 |
-
run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
16 |
-
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def main():
|
19 |
-
# Wide mode
|
20 |
st.set_page_config(layout="wide")
|
21 |
|
22 |
-
# Designing the interface
|
23 |
st.title("Glaucoma Screening from Retinal Fundus Images")
|
24 |
-
# For newline
|
25 |
-
st.write('\n')
|
26 |
-
# Author info
|
27 |
st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
|
28 |
-
|
29 |
-
|
30 |
-
# Instructions
|
31 |
-
st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
|
32 |
-
# Set the columns
|
33 |
cols = st.beta_columns((1, 1, 1))
|
34 |
cols[0].subheader("Input image")
|
35 |
cols[1].subheader("Optic disc and optic cup")
|
36 |
-
cols[2].subheader("
|
37 |
-
|
38 |
-
# set the visualization figure
|
39 |
-
fig, ax = plt.subplots()
|
40 |
|
41 |
-
#
|
42 |
-
# File selection
|
43 |
st.sidebar.title("Image selection")
|
44 |
-
# Disabling warning
|
45 |
st.set_option('deprecation.showfileUploaderEncoding', False)
|
46 |
-
# Choose your own image
|
47 |
uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
|
|
|
48 |
if uploaded_file is not None:
|
49 |
-
#
|
50 |
image = Image.open(uploaded_file).convert('RGB')
|
51 |
image = np.array(image).astype(np.uint8)
|
52 |
-
|
53 |
ax.imshow(image)
|
54 |
ax.axis('off')
|
55 |
cols[0].pyplot(fig)
|
56 |
|
57 |
-
# For newline
|
58 |
-
st.sidebar.write('\n')
|
59 |
-
|
60 |
-
# actions
|
61 |
if st.sidebar.button("Analyze image"):
|
62 |
-
|
63 |
if uploaded_file is None:
|
64 |
st.sidebar.write("Please upload an image")
|
65 |
-
|
66 |
else:
|
67 |
with st.spinner('Loading model...'):
|
68 |
-
#
|
|
|
69 |
model = GlaucomaModel(device=run_device)
|
70 |
|
71 |
with st.spinner('Analyzing...'):
|
72 |
-
#
|
73 |
-
disease_idx, disc_cup_image,
|
74 |
|
75 |
-
#
|
76 |
ax.imshow(disc_cup_image)
|
77 |
ax.axis('off')
|
78 |
cols[1].pyplot(fig)
|
79 |
|
80 |
-
#
|
81 |
-
ax.imshow(
|
82 |
ax.axis('off')
|
83 |
cols[2].pyplot(fig)
|
84 |
|
85 |
-
# Display
|
86 |
-
st.subheader("
|
87 |
-
st.write('\n')
|
88 |
-
|
89 |
final_results_as_table = f"""
|
90 |
|Parameters|Outcomes|
|
91 |
|---|---|
|
92 |
|Vertical cup-to-disc ratio|{vcdr:.04f}|
|
93 |
-
|Category|{model.cls_id2label[disease_idx]}|
|
|
|
|
|
94 |
"""
|
95 |
st.markdown(final_results_as_table)
|
96 |
|
97 |
-
|
98 |
if __name__ == '__main__':
|
99 |
main()
|
|
|
1 |
+
import cv2
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
|
7 |
import matplotlib.pyplot as plt
|
8 |
import streamlit as st
|
|
|
9 |
from PIL import Image
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
# --- GlaucomaModel Class ---
|
12 |
+
class GlaucomaModel(object):
|
13 |
+
def __init__(self,
|
14 |
+
cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
|
15 |
+
seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
|
16 |
+
device=torch.device('cpu')):
|
17 |
+
self.device = device
|
18 |
+
# Classification model for glaucoma
|
19 |
+
self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
|
20 |
+
self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
|
21 |
+
# Segmentation model for optic disc and cup
|
22 |
+
self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
|
23 |
+
self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
|
24 |
+
|
25 |
+
# Class activation map
|
26 |
+
self.cls_id2label = self.cls_model.config.id2label
|
27 |
+
self.seg_id2label = self.seg_model.config.id2label
|
28 |
+
|
29 |
+
def glaucoma_pred(self, image):
|
30 |
+
inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
|
31 |
+
with torch.no_grad():
|
32 |
+
inputs.to(self.device)
|
33 |
+
outputs = self.cls_model(**inputs).logits
|
34 |
+
# Softmax for probabilities
|
35 |
+
probs = F.softmax(outputs, dim=-1)
|
36 |
+
disease_idx = probs.cpu()[0, :].numpy().argmax()
|
37 |
+
confidence = probs.cpu()[0, disease_idx].item()
|
38 |
+
return disease_idx, confidence
|
39 |
+
|
40 |
+
def optic_disc_cup_pred(self, image):
|
41 |
+
inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
|
42 |
+
with torch.no_grad():
|
43 |
+
inputs.to(self.device)
|
44 |
+
outputs = self.seg_model(**inputs)
|
45 |
+
logits = outputs.logits.cpu()
|
46 |
+
upsampled_logits = nn.functional.interpolate(
|
47 |
+
logits, size=image.shape[:2], mode="bilinear", align_corners=False
|
48 |
+
)
|
49 |
+
# Softmax for segmentation confidence
|
50 |
+
seg_probs = F.softmax(upsampled_logits, dim=1)
|
51 |
+
pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
|
52 |
+
cup_confidence = seg_probs[0, 2, :, :].mean().item()
|
53 |
+
disc_confidence = seg_probs[0, 1, :, :].mean().item()
|
54 |
+
return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence
|
55 |
+
|
56 |
+
def process(self, image):
|
57 |
+
image_shape = image.shape[:2]
|
58 |
+
disease_idx, cls_confidence = self.glaucoma_pred(image)
|
59 |
+
disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
|
60 |
+
try:
|
61 |
+
vcdr = simple_vcdr(disc_cup) # Assuming simple_vcdr() is defined elsewhere
|
62 |
+
except:
|
63 |
+
vcdr = np.nan
|
64 |
+
_, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
|
65 |
+
return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence
|
66 |
+
|
67 |
+
# --- Streamlit Interface ---
|
68 |
def main():
|
69 |
+
# Wide mode in Streamlit
|
70 |
st.set_page_config(layout="wide")
|
71 |
|
|
|
72 |
st.title("Glaucoma Screening from Retinal Fundus Images")
|
|
|
|
|
|
|
73 |
st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
|
74 |
+
|
75 |
+
# Set columns for the interface
|
|
|
|
|
|
|
76 |
cols = st.beta_columns((1, 1, 1))
|
77 |
cols[0].subheader("Input image")
|
78 |
cols[1].subheader("Optic disc and optic cup")
|
79 |
+
cols[2].subheader("Classification Map")
|
|
|
|
|
|
|
80 |
|
81 |
+
# File uploader
|
|
|
82 |
st.sidebar.title("Image selection")
|
|
|
83 |
st.set_option('deprecation.showfileUploaderEncoding', False)
|
|
|
84 |
uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
|
85 |
+
|
86 |
if uploaded_file is not None:
|
87 |
+
# Read and display uploaded image
|
88 |
image = Image.open(uploaded_file).convert('RGB')
|
89 |
image = np.array(image).astype(np.uint8)
|
90 |
+
fig, ax = plt.subplots()
|
91 |
ax.imshow(image)
|
92 |
ax.axis('off')
|
93 |
cols[0].pyplot(fig)
|
94 |
|
|
|
|
|
|
|
|
|
95 |
if st.sidebar.button("Analyze image"):
|
|
|
96 |
if uploaded_file is None:
|
97 |
st.sidebar.write("Please upload an image")
|
|
|
98 |
else:
|
99 |
with st.spinner('Loading model...'):
|
100 |
+
# Load the model on available device
|
101 |
+
run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
102 |
model = GlaucomaModel(device=run_device)
|
103 |
|
104 |
with st.spinner('Analyzing...'):
|
105 |
+
# Get predictions from the model
|
106 |
+
disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence = model.process(image)
|
107 |
|
108 |
+
# Display optic disc and cup image
|
109 |
ax.imshow(disc_cup_image)
|
110 |
ax.axis('off')
|
111 |
cols[1].pyplot(fig)
|
112 |
|
113 |
+
# Display classification map
|
114 |
+
ax.imshow(image)
|
115 |
ax.axis('off')
|
116 |
cols[2].pyplot(fig)
|
117 |
|
118 |
+
# Display results with confidence
|
119 |
+
st.subheader("Screening results:")
|
|
|
|
|
120 |
final_results_as_table = f"""
|
121 |
|Parameters|Outcomes|
|
122 |
|---|---|
|
123 |
|Vertical cup-to-disc ratio|{vcdr:.04f}|
|
124 |
+
|Category|{model.cls_id2label[disease_idx]} ({cls_confidence*100:.02f}% confidence)|
|
125 |
+
|Optic Cup Segmentation Confidence|{cup_confidence*100:.02f}%|
|
126 |
+
|Optic Disc Segmentation Confidence|{disc_confidence*100:.02f}%|
|
127 |
"""
|
128 |
st.markdown(final_results_as_table)
|
129 |
|
|
|
130 |
if __name__ == '__main__':
|
131 |
main()
|