Spaces:
Running
Running
File size: 5,029 Bytes
2ef318a e9dc61f 2ef318a e9dc61f 2ef318a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from utils import _get_retina_bb, _pad_to_square
@st.cache_resource
def load_model(model_file):
    """Load a saved Keras model for inference.

    The model is loaded with ``compile=False`` and the function is wrapped
    in ``st.cache_resource``.

    Args:
        model_file: Path passed straight to ``tf.keras.models.load_model``.

    Returns:
        The loaded model object.
    """
    segmentation_net = tf.keras.models.load_model(model_file, compile=False)
    # Logged to stdout so a (re)load is visible in the server console.
    print(f'Model {model_file} Loaded!')
    return segmentation_net
@st.cache_resource
def load_gatekeeper():
    """Load the image-quality "gatekeeper" model from its fixed checkpoint.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        The model loaded from ``checkpoints/ResNetV2-EyeQ-QA.tf``.
    """
    checkpoint = 'checkpoints/ResNetV2-EyeQ-QA.tf'
    quality_net = tf.keras.models.load_model(checkpoint)
    # Logged to stdout so a (re)load is visible in the server console.
    print('Gatekeeper Model Loaded!')
    return quality_net
def parse_function(image):
    """Resize an image (or batch of images) to 512x512 and cast it to float32.

    Args:
        image: Image tensor/array accepted by ``tf.image.resize``.

    Returns:
        The resized image as a float32 tensor.
    """
    resized = tf.image.resize(image, size=[512, 512])
    # NOTE(review): tf.image.resize with its default method already returns
    # float32 without rescaling, so this cast presumably leaves pixel values
    # in their original range rather than [0, 1] -- confirm this matches the
    # preprocessing the gatekeeper model was trained with.
    return tf.image.convert_image_dtype(resized, dtype=tf.float32)
# Gatekeeper class index that this app treats as "poor quality".
_POOR_QUALITY_CLASS = 2
# Size (width, height) the centred retina crop is resized to before inference.
_INPUT_SIZE = (1024, 1024)
# File extensions accepted by the uploader.
_UPLOAD_TYPES = ('png', 'jpg', 'jpeg')


def _ask_for_image():
    """Show the file uploader and return the uploaded file (or None)."""
    st.set_option('deprecation.showfileUploaderEncoding', False)
    return st.file_uploader('Choose an image...', type=_UPLOAD_TYPES)


def _read_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB array, or None if undecodable."""
    # getvalue() returns the whole buffer regardless of the read position,
    # so the result does not depend on whether the stream was read before.
    payload = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    if payload.size == 0:
        # cv2.imdecode rejects an empty buffer; treat it as undecodable.
        return None
    decoded = cv2.imdecode(payload, cv2.IMREAD_COLOR)
    if decoded is None:
        return None
    return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)


def _is_poor_quality(image, gatekeeper_model):
    """Return True when the gatekeeper's top class is the poor-quality one."""
    scores = gatekeeper_model(parse_function(image[None, ...]))
    return int(np.argmax(scores)) == _POOR_QUALITY_CLASS


def _centre_retina(image):
    """Crop to the retina bounding box, pad to a square and resize."""
    x, y, w, h, _ = _get_retina_bb(image)
    cropped = image[y:y + h, x:x + w, :]
    squared = _pad_to_square(cropped, border=0)
    return cv2.resize(squared, _INPUT_SIZE)


def _apply_clahe(image):
    """Apply CLAHE to the first LAB channel of an RGB image; returns RGB."""
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


def _prepare_image(uploaded_file, gatekeeper_model):
    """Decode, quality-check and centre an upload.

    Returns the centred RGB image, or None after telling the user why the
    upload cannot be processed. ``gatekeeper_model`` is None when the
    gatekeeper is disabled.
    """
    image = _read_uploaded_image(uploaded_file)
    if image is None:
        st.error('Could not decode the uploaded image.')
        return None
    if gatekeeper_model is not None and _is_poor_quality(image, gatekeeper_model):
        st.image(image)
        st.info('Image is of poor quality')
        return None
    return _centre_retina(image)


def _show_uploaded(column, image):
    """Render the pre-processed upload in the given column."""
    with column:
        st.subheader("Uploaded Image")
        st.image(image)


def _run_vessels(gatekeeper_model):
    """Vessel segmentation page: upload, pre-process, predict, display."""
    uploaded_file = _ask_for_image()
    model = load_model('checkpoints/DeeplabV3Plus_DRIVE.tf')
    if not uploaded_file:
        return

    col1, col2 = st.columns(2)
    image = _prepare_image(uploaded_file, gatekeeper_model)
    if image is None:
        return
    _show_uploaded(col1, image)

    # The vessel model is fed a single-channel image, so batch and channel
    # axes are both added around the 2-D grayscale array.
    grey = cv2.cvtColor(_apply_clahe(image), cv2.COLOR_RGB2GRAY)
    batch = tf.image.convert_image_dtype(grey, tf.float32)[None, ..., None]
    y_pred = model(batch)[0].numpy()

    with col2:
        st.subheader("Predicted Vessel")
        st.image(y_pred)


def _run_lesions(gatekeeper_model):
    """Lesion segmentation page: upload, pre-process, predict, display."""
    st.write('```--- WARNING: This model is highly experimental ---```')
    uploaded_file = _ask_for_image()
    model = load_model('checkpoints/DeeplabV3Plus_FGADR.tf')
    if not uploaded_file:
        return

    col1, col2, col3 = st.columns(3)
    image = _prepare_image(uploaded_file, gatekeeper_model)
    if image is None:
        return
    _show_uploaded(col1, image)

    # The RGB image already has a channel axis, so only the batch axis is
    # added (an extra trailing axis would make the tensor 5-D).
    enhanced = tf.image.convert_image_dtype(_apply_clahe(image), tf.float32)
    y_pred = model(enhanced[None, ...])[0].numpy()

    # (label, output channel, column) in display order.
    panels = (
        ('MA', 1, col2),
        ('HE', 2, col3),
        ('EX', 3, col1),
        ('SE', 4, col2),
        ('OD', 5, col3),
    )
    for label, channel, column in panels:
        with column:
            st.subheader(label)
            st.image(y_pred[..., channel])


def main():
    """Render the retina segmentation Streamlit app.

    The sidebar selects the segmentation task ('Vessels' or
    'Lesions (BETA)') and whether the quality gatekeeper is applied. When
    the gatekeeper is enabled and flags an upload as poor quality, the
    image is shown with a notice and no segmentation is run.
    """
    st.title('Retina Segmentation')
    st.sidebar.title('Segmentation Model')
    options = st.sidebar.selectbox('Select Option:', ('Vessels', 'Lesions (BETA)'))
    gatekeeper = st.sidebar.radio("Gatekeeper:", ('Enabled', 'Disabled'))
    # Only load (and later run) the gatekeeper when the user wants it.
    gatekeeper_model = load_gatekeeper() if gatekeeper == 'Enabled' else None

    if options == 'Vessels':
        _run_vessels(gatekeeper_model)
    elif options == 'Lesions (BETA)':
        _run_lesions(gatekeeper_model)
if __name__ == '__main__':
    # Make no GPU devices visible to TensorFlow before the app starts.
    tf.config.set_visible_devices(devices=[], device_type='GPU')
    main()
|