# Copyright (C) 2023, Xu Sun.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to https://www.apache.org/licenses/LICENSE-2.0 for full license details.

"""Streamlit app for glaucoma screening from retinal fundus images."""

import torch
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import cv2

from glaucoma import GlaucomaModel

run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _crop_square(image, center_x, center_y, half_size):
    """Return a ``2*half_size`` x ``2*half_size`` window of *image* centred on a point.

    The parts of the window that fall outside the image are filled with black.
    Each side is padded by exactly the amount that was clipped on that side, so
    the output always has the full window size and (center_x, center_y) stays
    at its centre, even when the point is close to the image border.

    Args:
        image: Array of shape (H, W) or (H, W, C).
        center_x: Column of the window centre (assumed to lie inside the image).
        center_y: Row of the window centre (assumed to lie inside the image).
        half_size: Half of the window's side length, in pixels.

    Returns:
        The cropped and, where needed, black-padded window.
    """
    h, w = image.shape[:2]
    # Unclipped window bounds; these may extend past the image edges.
    x1, y1 = center_x - half_size, center_y - half_size
    x2, y2 = center_x + half_size, center_y + half_size

    cropped = image[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]

    # Pad every side by how far the window overhangs the image on that side.
    return cv2.copyMakeBorder(
        cropped,
        top=max(0, -y1),
        bottom=max(0, y2 - h),
        left=max(0, -x1),
        right=max(0, x2 - w),
        borderType=cv2.BORDER_CONSTANT,
        value=[0, 0, 0],  # Black padding
    )


def crop_green_area_with_padding(image, mask, padding=20):
    """Crop a square window around the centroid of the label-1 ("green") area of *mask*.

    Args:
        image: Image to crop, shape (H, W) or (H, W, C).
        mask: Label map; pixels equal to 1 form the area of interest. Assumed to
            have the same height and width as *image* — TODO confirm with callers.
        padding: Extra margin, in pixels, added on every side of the window.

    Returns:
        A square crop with side ``2 * (radius + padding)``, where
        ``radius = int(min(H, W) * 0.7 / 4)``. The window is centred on the
        centroid of the label-1 pixels, or on the image centre when there are
        none. Regions outside the image are filled with black.
    """
    green_area = (mask == 1).astype(np.uint8)

    # Centroid of the label-1 pixels via image moments.
    moments = cv2.moments(green_area)
    if moments["m00"] != 0:
        center_x = int(moments["m10"] / moments["m00"])
        center_y = int(moments["m01"] / moments["m00"])
    else:
        # No label-1 pixels: fall back to the image centre.
        center_x, center_y = image.shape[1] // 2, image.shape[0] // 2

    # The radius depends only on the image size, not on the extent of the area.
    radius = int(min(image.shape[1], image.shape[0]) * 0.7 / 4)
    return _crop_square(image, center_x, center_y, radius + padding)


def crop_with_padding(image, center, radius, padding=20):
    """Crop a square window of side ``2 * (radius + padding)`` centred on *center*.

    Args:
        image: Image to crop, shape (H, W) or (H, W, C).
        center: ``(row, col)`` of the window centre, e.g. the tuple returned by
            ``np.unravel_index``.
        radius: Half-size of the region of interest, in pixels.
        padding: Extra margin, in pixels, added on every side.

    Returns:
        The square crop. Regions outside the image are filled with black.
    """
    center_y, center_x = int(center[0]), int(center[1])
    return _crop_square(image, center_x, center_y, radius + padding)


def _show_image(image):
    """Render *image* in its own borderless matplotlib figure on the Streamlit page."""
    # A fresh figure per image keeps consecutive images from being drawn on
    # top of each other on a shared Axes.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps figures alive until they are explicitly closed.
    plt.close(fig)


def _analyze(image_np):
    """Run the glaucoma model on *image_np* (H, W, 3 uint8 RGB) and display the results."""
    with st.spinner('Loading model...'):
        model = GlaucomaModel(device=run_device)

    with st.spinner('Analyzing...'):
        disease_idx, disc_cup_image, cam, vcdr = model.process(image_np)

    st.subheader("Optic Disc and Optic Cup Segmentation")
    _show_image(disc_cup_image)

    st.subheader("Class Activation Map (CAM)")
    _show_image(cam)

    # Display results as a table
    st.subheader("Screening Results")
    final_results_as_table = (
        "|Parameters|Outcomes|\n"
        "|---|---|\n"
        f"|Vertical cup-to-disc ratio|{vcdr:.04f}|\n"
        f"|Category|{model.cls_id2label[disease_idx]}|\n"
    )
    st.markdown(final_results_as_table)

    # Crop around the brightest region of the heavily blurred grayscale image.
    st.subheader("Cropped Image with Padding")
    gray_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    blurred_image = cv2.GaussianBlur(gray_image, (65, 65), 0)
    max_intensity_pixel = np.unravel_index(np.argmax(blurred_image), blurred_image.shape)
    height, width = image_np.shape[:2]
    radius = int(min(width, height) * 0.7 / 4)
    cropped_image = crop_with_padding(image_np, max_intensity_pixel, radius, padding=30)
    _show_image(cropped_image)


def main():
    """Build the Streamlit page: upload an image, then analyze it on request."""
    # Wide mode; must be the first Streamlit call.
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('\n')
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
    st.write('\n')
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # Sidebar
    st.sidebar.title("Image selection")
    # NOTE(review): this config option was removed in later Streamlit releases,
    # where set_option rejects it — confirm against the pinned Streamlit version.
    st.set_option('deprecation.showfileUploaderEncoding', False)
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image_np = None
    if uploaded_file is not None:
        # Read and display the uploaded image
        image = Image.open(uploaded_file).convert('RGB')
        image_np = np.array(image).astype(np.uint8)
        _show_image(image_np)

    st.sidebar.write('\n')

    # Analyze button
    if st.sidebar.button("Analyze image"):
        if image_np is None:
            st.sidebar.write("Please upload an image")
        else:
            _analyze(image_np)


if __name__ == '__main__':
    main()