luigi12345's picture
Update app.py
a4f839e verified
raw
history blame
7.95 kB
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import io
# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening pipeline built from two pretrained models.

    A SwinV2 image classifier predicts the glaucoma category, and a SegFormer
    semantic-segmentation model labels every pixel. The rest of this file
    treats segmentation class 1 as the optic disc and class 2 as the optic
    cup.

    Args:
        cls_model_path: Hub id or local path of the classification checkpoint.
        seg_model_path: Hub id or local path of the segmentation checkpoint.
        device: torch device both models are moved to.
    """

    # Segmentation class indices used by this pipeline.
    _DISC_CLASS = 1
    _CUP_CLASS = 2
    # Minimum padding (pixels) added around the detected structures, and the
    # minimum side length a crop must have to be used.
    _MIN_PADDING = 50
    _MIN_CROP_SIDE = 50

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Index -> label lookups taken from the model configurations
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify a fundus image.

        Args:
            image: image array (copied before preprocessing, so the caller's
                array is not modified). `main` passes an H x W x 3 uint8 RGB
                array.

        Returns:
            Tuple ``(disease_idx, confidence)``: the index of the most likely
            class as a plain ``int`` and its softmax probability in percent.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        # Use the returned object rather than relying on in-place mutation.
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100  # Scale to percentage
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment the image and score the disc and cup predictions.

        Args:
            image: image array; its first two dimensions give the size the
                logits are upsampled to.

        Returns:
            Tuple ``(mask, cup_confidence, disc_confidence)``. ``mask`` is a
            uint8 array of per-pixel class indices with the same height and
            width as ``image``. Each confidence is the mean softmax
            probability (in percent) of the class over the pixels that were
            predicted as that class, or ``0.0`` when no pixel was.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
            # The model emits low-resolution logits; bring them back to the
            # input resolution before taking the per-pixel argmax.
            upsampled_logits = F.interpolate(
                logits, size=image.shape[:2], mode="bilinear", align_corners=False
            )
            seg_probs = F.softmax(upsampled_logits, dim=1)[0]
            pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._region_confidence(seg_probs, pred_disc_cup, self._CUP_CLASS)
        disc_confidence = self._region_confidence(seg_probs, pred_disc_cup, self._DISC_CLASS)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @staticmethod
    def _region_confidence(class_probs, prediction, class_id):
        """Mean probability (percent) of ``class_id`` inside its predicted region.

        Averaging over the whole image would mostly measure how much of the
        image the structure covers, so only pixels assigned to ``class_id``
        contribute. Returns ``0.0`` when the class was not predicted anywhere.
        """
        region = prediction == class_id
        if not bool(region.any()):
            return 0.0
        return class_probs[class_id][region].mean().item() * 100

    def _crop_to_structures(self, image, disc_cup):
        """Crop ``image`` around all non-background pixels of ``disc_cup``.

        The bounding box is padded on every side by 20% of its longer side
        (at least ``_MIN_PADDING`` pixels) and clipped to the image. A copy of
        the full image is returned when nothing was segmented or the crop
        would be smaller than ``_MIN_CROP_SIDE`` in either dimension.
        """
        mask = (disc_cup > 0).astype(np.uint8)
        if not mask.any():
            # boundingRect of an empty mask is (0, 0, 0, 0); padding it would
            # crop an arbitrary corner of the image.
            return image.copy()
        x, y, w, h = cv2.boundingRect(mask)
        padding = max(self._MIN_PADDING, int(0.2 * max(w, h)))
        # Work with explicit corners so that clamping one side does not
        # enlarge the padding on the opposite side.
        left = max(x - padding, 0)
        top = max(y - padding, 0)
        right = min(x + w + padding, image.shape[1])
        bottom = min(y + h + padding, image.shape[0])
        if right - left < self._MIN_CROP_SIDE or bottom - top < self._MIN_CROP_SIDE:
            return image.copy()
        return image[top:bottom, left:right]

    def process(self, image):
        """Run classification, segmentation and the derived outputs.

        Args:
            image: image array passed to both models.

        Returns:
            Tuple ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)``. ``vcdr`` is
            ``np.nan`` when it cannot be computed.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        try:
            vcdr = simple_vcdr(disc_cup)  # Calculate vertical cup-to-disc ratio
        except Exception:
            # Best effort: a failed ratio must not abort the screening, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            vcdr = np.nan
        cropped_image = self._crop_to_structures(image, disc_cup)
        # add_mask returns (blended, overlay); the overlay (second element)
        # is what gets displayed.
        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
# --- Utility Functions ---
def simple_vcdr(mask):
    """
    Estimate the vertical cup-to-disc ratio (VCDR) from a segmentation mask.

    The ratio is the vertical extent (number of rows spanned) of the optic
    cup divided by the vertical extent of the optic disc. Because the cup lies
    inside the disc but carries its own label, the disc extent is measured
    over the union of both classes.

    Assumes:
    - mask is a 2-D array with class 1 for optic disc and class 2 for optic cup.

    Returns:
    - the ratio as a float; 0.0 when no cup pixels exist; np.nan when the mask
      contains no optic disc pixels.
    """
    mask = np.asarray(mask)
    is_cup = mask == 2
    is_disc_rim = mask == 1
    if not is_disc_rim.any():  # No disc found: the ratio is undefined
        return np.nan

    # Rows that contain at least one pixel of the respective structure.
    disc_rows = np.flatnonzero((is_disc_rim | is_cup).any(axis=1))
    cup_rows = np.flatnonzero(is_cup.any(axis=1))

    disc_height = disc_rows[-1] - disc_rows[0] + 1
    cup_height = cup_rows[-1] - cup_rows[0] + 1 if cup_rows.size else 0
    return float(cup_height / disc_height)
def add_mask(image, mask, classes, colors, alpha=0.5):
    """
    Paint segmentation classes onto a copy of the image and blend the result.

    Args:
    - image: the original RGB image (left unmodified)
    - mask: the predicted segmentation mask
    - classes: class indices to paint (e.g., [1, 2])
    - colors: one color per class index (e.g., [[0, 255, 0], [255, 0, 0]])
    - alpha: weight of the painted copy in the blend (default = 0.5)

    Returns:
    - (blended, painted): the weighted mix of the painted copy and the
      original image, and the painted copy itself.
    """
    painted = image.copy()
    for label, rgb in zip(classes, colors):
        region = mask == label
        painted[region] = rgb
    blended = cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
    return blended, painted
# --- Streamlit Interface ---
def main():
    """Render the Streamlit page for glaucoma screening.

    Lets the user upload a fundus image from the sidebar, shows it, and, when
    "Analyze image" is pressed, loads `GlaucomaModel`, runs it on the image
    and displays the segmentation overlay, the input image, the cropped
    region, a download button for the crop and a results table.
    """
    st.set_page_config(layout="wide")
    st.title("Glaucoma Screening from Retinal Fundus Images")
    st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')

    # `st.beta_columns` was the pre-release name of `st.columns`; prefer the
    # stable API and fall back only on Streamlit versions that lack it.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1, 1, 1))
    cols[0].subheader("Input image")
    cols[1].subheader("Optic disc and optic cup")
    cols[2].subheader("Class activation map")
    cols[3].subheader("Cropped Image")

    def show_image(column, picture):
        """Draw `picture` without axes in `column` using a fresh figure."""
        fig, ax = plt.subplots()
        ax.imshow(picture)
        ax.axis('off')
        column.pyplot(fig)
        # Close the figure so repeated script reruns do not accumulate them.
        plt.close(fig)

    # File uploader
    st.sidebar.title("Image selection")
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except Exception:
        # Best effort: this option only silences an old deprecation notice
        # and may be rejected by Streamlit releases that no longer accept it.
        pass
    uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])

    image = None
    if uploaded_file is not None:
        # Read and display uploaded image
        image = np.array(Image.open(uploaded_file).convert('RGB')).astype(np.uint8)
        show_image(cols[0], image)

    if not st.sidebar.button("Analyze image"):
        return
    if image is None:
        st.sidebar.write("Please upload an image")
        return

    with st.spinner('Loading model...'):
        run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = GlaucomaModel(device=run_device)
    with st.spinner('Analyzing...'):
        # Get predictions from the model
        (disease_idx, disc_cup_image, vcdr, cls_confidence,
         cup_confidence, disc_confidence, cropped_image) = model.process(image)

    # Display optic disc and cup image
    show_image(cols[1], disc_cup_image)
    # NOTE(review): the "Class activation map" column shows the unmodified
    # input image; no activation map is computed anywhere in this file.
    show_image(cols[2], image)
    # Display the cropped image
    show_image(cols[3], cropped_image)

    # Make cropped image downloadable
    buf = io.BytesIO()
    Image.fromarray(cropped_image).save(buf, format="PNG")
    st.sidebar.download_button(
        label="Download Cropped Image",
        data=buf.getvalue(),
        file_name="cropped_image.png",
        mime="image/png"
    )

    # Display results with confidence
    st.subheader("Screening results:")
    # Table rows must start at column 0 of the markdown text; building the
    # string from a list keeps that independent of the code's indentation.
    table_rows = [
        "",
        "|Parameters|Outcomes|",
        "|---|---|",
        f"|Vertical cup-to-disc ratio|{vcdr:.04f}|",
        f"|Category|{model.cls_id2label[disease_idx]} ({cls_confidence:.02f}% confidence)|",
        f"|Optic Cup Segmentation Confidence|{cup_confidence:.02f}%|",
        f"|Optic Disc Segmentation Confidence|{disc_confidence:.02f}%|",
        "",
    ]
    final_results_as_table = "\n".join(table_rows)
    st.markdown(final_results_as_table)
# Start the Streamlit page only when this file is executed as a script.
if __name__ == "__main__":
    main()