File size: 7,071 Bytes
fb2c012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import io
import zipfile
import os

# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening pipeline built from two pretrained Hugging Face models.

    * A SwinV2 image classifier that predicts the glaucoma category.
    * A SegFormer semantic-segmentation model whose per-pixel labels are
      used as 0 = background, 1 = optic disc, 2 = optic cup.

    Both models are loaded once in ``__init__``, moved to ``device`` and put
    in eval mode.
    """

    def __init__(self, 
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification", 
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load the image processors and models.

        Args:
            cls_model_path: Hub id or local path of the classification model.
            seg_model_path: Hub id or local path of the segmentation model.
            device: torch device the models and their inputs are moved to.
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Mapping from class index to human-readable label
        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify *image*.

        Args:
            image: RGB image as a NumPy array (the app passes uint8 H x W x 3).

        Returns:
            disease_idx: index of the most probable class (a key of
                ``cls_id2label``), as a plain ``int``.
            confidence: softmax probability of that class, in percent.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.numpy().argmax())
        confidence = probs[disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in *image*.

        Args:
            image: RGB image as a NumPy array; its first two dims give the
                output resolution.

        Returns:
            mask: (H, W) uint8 label map at the input resolution.
            cup_confidence: mean softmax probability (percent) of the cup
                class over the pixels predicted as cup; 0.0 if none are.
            disc_confidence: same for the disc class.
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.seg_model(**inputs)
        logits = outputs.logits.cpu()
        # The model's logits are not at the input resolution; resize them so
        # the label map lines up pixel-for-pixel with `image`.
        upsampled_logits = nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._class_confidence(seg_probs, pred_disc_cup, 2)
        disc_confidence = self._class_confidence(seg_probs, pred_disc_cup, 1)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @staticmethod
    def _class_confidence(probs, pred, class_id):
        """Mean probability (percent) of *class_id* over the pixels predicted as it.

        Averaging over the whole image (the previous approach) mostly measured
        how much of the frame the structure covers, not how confident the
        model is about it.
        """
        region = pred == class_id
        if not bool(region.any()):
            return 0.0
        return probs[class_id][region].mean().item() * 100

    @staticmethod
    def _crop_around_disc(image, disc_cup):
        """Crop *image* to a padded box around all non-background pixels of *disc_cup*.

        Falls back to a full copy of *image* when nothing was segmented or the
        resulting box is smaller than 50 px in either dimension.
        """
        mask = (disc_cup > 0).astype(np.uint8)
        if not mask.any():
            # cv2.boundingRect returns (0, 0, 0, 0) for an empty mask; padding
            # that would yield a meaningless crop of the top-left corner.
            return image.copy()
        x, y, w, h = cv2.boundingRect(mask)
        padding = max(50, int(0.2 * max(w, h)))
        x = max(x - padding, 0)
        y = max(y - padding, 0)
        w = min(w + 2 * padding, image.shape[1] - x)
        h = min(h + 2 * padding, image.shape[0] - y)
        if w >= 50 and h >= 50:
            return image[y:y+h, x:x+w]
        return image.copy()

    def process(self, image):
        """Run classification and segmentation and build the display artefacts.

        Args:
            image: RGB image as a NumPy array.

        Returns:
            A 7-tuple ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)``. ``vcdr`` is
            ``np.nan`` when it cannot be computed.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)

        try:
            vcdr = simple_vcdr(disc_cup)
        except (ValueError, ZeroDivisionError):
            vcdr = np.nan

        cropped_image = self._crop_around_disc(image, disc_cup)
        # add_mask returns (blended, painted); the un-blended painted copy is displayed.
        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)

        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image

# --- Utility Functions ---
def simple_vcdr(mask):
    """Estimate the vertical cup-to-disc ratio (vCDR) from a label mask.

    *mask* is the label map produced by ``GlaucomaModel.optic_disc_cup_pred``
    (0 = background, 1 = disc, 2 = cup). That map comes from a per-pixel
    argmax, so every pixel has exactly one label. Cup pixels are therefore
    labelled 2 rather than 1, and the full disc region is ``mask >= 1``.

    The vCDR is the vertical extent (rows spanned) of the cup divided by the
    vertical extent of the disc.

    Args:
        mask: 2-D array of per-pixel class labels.

    Returns:
        The vCDR as a ``float``. Returns ``0.0`` when a disc is present but no
        cup, and ``np.nan`` when there are no disc/cup pixels at all.
    """
    mask = np.asarray(mask)
    disc_rows = np.flatnonzero((mask >= 1).any(axis=1))
    if disc_rows.size == 0:
        return np.nan
    cup_rows = np.flatnonzero((mask == 2).any(axis=1))
    if cup_rows.size == 0:
        return 0.0
    disc_height = disc_rows[-1] - disc_rows[0] + 1
    cup_height = cup_rows[-1] - cup_rows[0] + 1
    return float(cup_height / disc_height)

def add_mask(image, mask, classes, colors, alpha=0.5):
    """Paint labelled regions of *mask* onto a copy of *image* and blend them.

    Args:
        image: image array (``GlaucomaModel.process`` passes an RGB array);
            it is copied, never modified in place.
        mask: label array aligned with the first two dims of *image*.
        classes: label values to paint.
        colors: one colour per entry of *classes*, paired positionally;
            unpaired trailing entries on either side are ignored.
        alpha: weight of the painted copy in the blend; *image* gets
            ``1 - alpha``.

    Returns:
        ``(blended, painted)``: the weighted blend of the painted copy with
        *image*, and the painted copy itself.
    """
    painted = image.copy()
    for label, rgb in zip(classes, colors):
        region = mask == label
        painted[region] = rgb
    blended = cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
    return blended, painted

# --- Streamlit Interface ---
def main():
    """Streamlit entry point: batch-screen uploaded retinal fundus images.

    For each uploaded image the app shows the input, the disc/cup overlay, a
    crop around the optic disc and the model outputs. It also marks whether
    the classifier's confidence reaches the sidebar threshold. The crops of
    the confident images are offered as one ZIP through a sidebar download
    button.
    """
    st.set_page_config(layout="wide")
    st.title("Batch Glaucoma Screening from Retinal Fundus Images")
    
    # Explanation for the confidence threshold
    st.sidebar.write("**Confidence Threshold** (optional): Set a threshold to filter images based on the model's confidence in glaucoma classification.")
    confidence_threshold = st.sidebar.slider("Confidence Threshold (%)", 0, 100, 70)
    uploaded_files = st.sidebar.file_uploader("Upload Images", type=['png', 'jpeg', 'jpg'], accept_multiple_files=True)

    if not uploaded_files:
        st.sidebar.info("Upload images to begin analysis.")
        return

    # Load the models once per script run, not once per uploaded image:
    # from_pretrained is expensive and the models do not depend on the input.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with st.spinner('Loading models...'):
        model = GlaucomaModel(device=device)

    # (cropped_image, original file name) for every confidently classified image
    download_confident_images = []

    for uploaded_file in uploaded_files:
        try:
            image = Image.open(uploaded_file).convert('RGB')
        except (OSError, ValueError) as err:
            # One unreadable upload should not abort the whole batch.
            st.error(f"Could not read {uploaded_file.name}: {err}")
            continue
        image_np = np.array(image).astype(np.uint8)

        with st.spinner(f'Processing {uploaded_file.name}...'):
            (disease_idx, disc_cup_image, vcdr, cls_conf,
             cup_conf, disc_conf, cropped_image) = model.process(image_np)

        # Confidence-based grouping
        is_confident = cls_conf >= confidence_threshold
        if is_confident:
            download_confident_images.append((cropped_image, uploaded_file.name))

        # Display Results
        with st.container():
            st.subheader(f"Results for {uploaded_file.name}")
            cols = st.columns(4)
            cols[0].image(image_np, caption="Input Image", use_column_width=True)
            cols[1].image(disc_cup_image, caption="Disc/Cup Segmentation", use_column_width=True)
            # NOTE(review): placeholder. No activation map is computed; this
            # column shows the raw input image.
            cols[2].image(image_np, caption="Class Activation Map", use_column_width=True)
            cols[3].image(cropped_image, caption="Cropped Image", use_column_width=True)

            st.write(f"**Vertical cup-to-disc ratio:** {vcdr:.04f}")
            st.write(f"**Category:** {model.cls_id2label[disease_idx]} ({cls_conf:.02f}% confidence)")
            st.write(f"**Optic Cup Segmentation Confidence:** {cup_conf:.02f}%")
            st.write(f"**Optic Disc Segmentation Confidence:** {disc_conf:.02f}%")
            st.write(f"**Confidence Group:** {'Confident' if is_confident else 'Not Confident'}")

    # Download Button for Confident Images
    if download_confident_images:
        # Build the archive in memory and hand it to download_button. A file
        # written to the server's working directory is not reachable via a
        # relative markdown link, and a fixed on-disk path would be shared
        # (and overwritten) across concurrent sessions.
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w") as zf:
            for cropped_image, name in download_confident_images:
                img_buffer = io.BytesIO()
                Image.fromarray(cropped_image).save(img_buffer, format="PNG")
                zf.writestr(f"{name}_cropped.png", img_buffer.getvalue())
        st.sidebar.download_button(
            "Download Confident Cropped Images",
            data=zip_buffer.getvalue(),
            file_name="confident_cropped_images.zip",
            mime="application/zip",
        )

if __name__ == '__main__':
    main()