import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
import streamlit as st
from PIL import Image
import io
import zipfile
import pandas as pd
from datetime import datetime
import os
import tempfile
import base64

class GlaucomaModel(object):
    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        self.device = device

        # Classification model: glaucoma vs. non-glaucoma from the full fundus image.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()

        # Segmentation model: optic disc and cup masks.
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()

        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        # Classify the fundus image; returns the predicted class index and its confidence (%).
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            outputs = self.cls_model(**inputs).logits
        probs = F.softmax(outputs, dim=-1)
        disease_idx = probs.cpu()[0, :].numpy().argmax()
        confidence = probs.cpu()[0, disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        # Segment the optic disc (class 1) and cup (class 2); returns the mask
        # plus the mean per-class confidence (%).
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            outputs = self.seg_model(**inputs)
        logits = outputs.logits.cpu()
        # Upsample the low-resolution logits back to the input image size.
        upsampled_logits = nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]

        cup_mask = pred_disc_cup == 2
        disc_mask = pred_disc_cup == 1

        cup_confidence = seg_probs[0, 2, cup_mask].mean().item() * 100 if cup_mask.any() else 0
        disc_confidence = seg_probs[0, 1, disc_mask].mean().item() * 100 if disc_mask.any() else 0

        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    def process(self, image):
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)

        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            vcdr = np.nan

        # Crop a padded region of interest around the segmented disc/cup.
        mask = (disc_cup > 0).astype(np.uint8)
        x, y, w, h = cv2.boundingRect(mask)
        padding = max(50, int(0.2 * max(w, h)))
        x = max(x - padding, 0)
        y = max(y - padding, 0)
        w = min(w + 2 * padding, image.shape[1] - x)
        h = min(h + 2 * padding, image.shape[0] - y)

        cropped_image = image[y:y+h, x:x+w] if w >= 50 and h >= 50 else image.copy()
        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)

        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image

def simple_vcdr(mask):
    # Area-based approximation of the cup-to-disc ratio (cup pixels / disc pixels).
    disc_area = np.sum(mask == 1)
    cup_area = np.sum(mask == 2)
    if disc_area == 0:
        return np.nan
    vcdr = cup_area / disc_area
    return vcdr

def add_mask(image, mask, classes, colors, alpha=0.5):
    # Paint each class region with its color; return the alpha-blended image and the hard overlay.
    overlay = image.copy()
    for class_id, color in zip(classes, colors):
        overlay[mask == class_id] = color
    output = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    return output, overlay

def get_confidence_level(confidence):
    if confidence >= 90:
        return "Very High"
    elif confidence >= 75:
        return "High"
    elif confidence >= 60:
        return "Moderate"
    elif confidence >= 45:
        return "Low"
    else:
        return "Very Low"

def process_batch(model, images_data, progress_bar=None):
    results = []
    for idx, (file_name, image) in enumerate(images_data):
        try:
            disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image)
            results.append({
                'file_name': file_name,
                'diagnosis': model.cls_id2label[disease_idx],
                'confidence': cls_conf,
                'vcdr': vcdr,
                'cup_conf': cup_conf,
                'disc_conf': disc_conf,
                'processed_image': disc_cup_image,
                'cropped_image': cropped_image
            })
            if progress_bar:
                progress_bar.progress((idx + 1) / len(images_data))
        except Exception as e:
            st.error(f"Error processing {file_name}: {str(e)}")
    return results

def save_results(results, original_images):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Tabular report of the per-image results.
        df = pd.DataFrame([{
            'File': r['file_name'],
            'Diagnosis': r['diagnosis'],
            'Confidence (%)': f"{r['confidence']:.1f}",
            'VCDR': f"{r['vcdr']:.3f}",
            'Cup Confidence (%)': f"{r['cup_conf']:.1f}",
            'Disc Confidence (%)': f"{r['disc_conf']:.1f}"
        } for r in results])

        report_path = os.path.join(temp_dir, 'report.csv')
        df.to_csv(report_path, index=False)

        # Save the original image, segmentation overlay, and ROI crop for each input.
        for result, orig_img in zip(results, original_images):
            img_name = result['file_name']
            base_name = os.path.splitext(img_name)[0]

            orig_path = os.path.join(temp_dir, f"{base_name}_original.jpg")
            Image.fromarray(orig_img).save(orig_path)

            seg_path = os.path.join(temp_dir, f"{base_name}_segmentation.jpg")
            Image.fromarray(result['processed_image']).save(seg_path)

            roi_path = os.path.join(temp_dir, f"{base_name}_roi.jpg")
            Image.fromarray(result['cropped_image']).save(roi_path)

        # Bundle everything except the archive itself into a single ZIP.
        zip_path = os.path.join(temp_dir, 'results.zip')
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for root, _, files in os.walk(temp_dir):
                for file in files:
                    if file != 'results.zip':
                        file_path = os.path.join(root, file)
                        arcname = os.path.basename(file_path)
                        zipf.write(file_path, arcname)

        with open(zip_path, 'rb') as f:
            return f.read()

def main():
    st.set_page_config(layout="wide", page_title="Glaucoma Screening Tool")

    print("App started")

    st.markdown("""
        <h1 style='text-align: center;'>Glaucoma Screening from Retinal Fundus Images</h1>
        <p style='text-align: center; color: gray;'>Upload retinal images for automated glaucoma detection and optic disc/cup segmentation</p>
    """, unsafe_allow_html=True)

    st.sidebar.markdown("### 📤 Upload Images")
    uploaded_files = st.sidebar.file_uploader(
        "Upload Retinal Images",
        type=['png', 'jpeg', 'jpg'],
        accept_multiple_files=True,
        help="Supports multiple images in PNG and JPEG formats"
    )

    print(f"Files uploaded: {uploaded_files}")

    st.sidebar.markdown("### Settings")
    max_batch = st.sidebar.number_input("Max Batch Size",
                                        min_value=1,
                                        max_value=100,
                                        value=20)

    if uploaded_files:
        print("Processing uploaded files")

        if len(uploaded_files) > max_batch:
            st.warning(f"Please upload a maximum of {max_batch} images at once.")
            return

        st.markdown(f"Total images: {len(uploaded_files)}")
        st.markdown(f"Using: {'GPU' if torch.cuda.is_available() else 'CPU'}")

        try:
            print("Initializing model")
            model = GlaucomaModel(device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

            images_data = []
            original_images = []
            print("Starting image processing")

            for file in uploaded_files:
                try:
                    print(f"Processing file: {file.name}")
                    image = Image.open(file).convert('RGB')
                    image_np = np.array(image)
                    images_data.append((file.name, image_np))
                    original_images.append(image_np)
                except Exception as e:
                    print(f"Error processing file {file.name}: {str(e)}")
                    st.error(f"Error loading {file.name}: {str(e)}")
                    continue

            if not images_data:
                st.error("No valid images to process!")
                return

            progress = st.progress(0)
            st.write(f"Processing {len(images_data)} images...")

            print("Starting batch processing")
            results = process_batch(model, images_data, progress)
            print(f"Batch processing complete. Results: {len(results)}")

            if results:
                print("Showing results")

                for result in results:
                    st.markdown(f"### Results for {result['file_name']}")
                    st.markdown(f"**Diagnosis:** {result['diagnosis']}")
                    st.markdown(f"**Confidence:** {result['confidence']:.1f}%")
                    st.markdown(f"**VCDR:** {result['vcdr']:.3f}")

                    st.image(result['processed_image'], caption="Segmentation")
                    st.image(result['cropped_image'], caption="ROI")
                    st.markdown("---")

                print("Generating ZIP file")
                zip_data = save_results(results, original_images)

                st.markdown("### Download Results")

                filename = f"glaucoma_screening_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
                st.markdown(
                    f'<a href="data:application/zip;base64,{base64.b64encode(zip_data).decode()}" download="{filename}">Download All Results (ZIP)</a>',
                    unsafe_allow_html=True
                )

                st.markdown("### Summary")
                glaucoma_count = sum(1 for r in results if r['diagnosis'] == 'Glaucoma')
                normal_count = len(results) - glaucoma_count
                st.markdown(f"**Total Processed:** {len(results)}")
                st.markdown(f"**Glaucoma Detected:** {glaucoma_count}")
                st.markdown(f"**Normal:** {normal_count}")
                st.markdown(f"**Average Confidence:** {sum(r['confidence'] for r in results) / len(results):.1f}%")

        except Exception as e:
            print(f"Error in main processing: {str(e)}")
            st.error(f"An error occurred: {str(e)}")

if __name__ == "__main__":
    print("Starting main")
    main()
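# Usage note (the filename is an assumption, not part of the source): if this script is
# saved as app.py, it can be launched locally with `streamlit run app.py`.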