from pipline import Transformer_Regression, extract_regions_Last, compute_ratios

import time

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.nn import functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters (batch_size and label_smoothing are training-time settings
# kept here for reference).
image_shape = 384
batch_size = 1
dim_patch = 4
num_classes = 3
label_smoothing = 0.1
scale = 1

start = time.time()
torch.manual_seed(0)

# Preprocessing pipeline applied to every input image.
tfms = transforms.Compose([
    transforms.Resize((image_shape, image_shape)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])
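# Note: Normalize(0.5, 0.5) broadcasts a single mean/std over all three RGB
# channels, mapping ToTensor()'s [0, 1] range to roughly [-1, 1]; the display
# code below undoes this with x * 0.5 + 0.5.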
def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2):
    """Segment the optic disc and cup, compute the vertical cup-to-disc ratio
    (vCDR), and draw the detected contours on the input image."""
    Model.eval()
    with torch.no_grad():
        train_batch_tfms = batch_sampler['image'].to(device=device)
        ytrue_seg = batch_sampler['image_original']

        # Coarse pass: segment the full image, then upsample the logits back
        # to the original resolution.
        scores = Model(train_batch_tfms.unsqueeze(0))
        yseg_pred = F.interpolate(scores['seg'],
                                  size=(ytrue_seg.shape[0], ytrue_seg.shape[1]),
                                  mode='bilinear', align_corners=True)

        # Crop a region of interest around the detected optic disc.
        Regions_crop = extract_regions_Last(
            np.array(batch_sampler['image_original']),
            yseg_pred.argmax(1).long()[0].detach().cpu().numpy())
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')

        ytrue_seg_crop = ytrue_seg[Regions_crop['cord'][0]:Regions_crop['cord'][1],
                                   Regions_crop['cord'][2]:Regions_crop['cord'][3]]
        ytrue_seg_crop = np.expand_dims(ytrue_seg_crop, axis=0)

        # Fine pass: re-segment the zoomed-in crop with the auxiliary head and
        # paste the refined logits back into the full-size prediction.
        if num_head == 2:
            scores = Model(tfms(Regions_crop['image']).unsqueeze(0).to(device))
            yseg_pred_crop = F.interpolate(scores['seg_aux_1'],
                                           size=(ytrue_seg_crop.shape[1], ytrue_seg_crop.shape[2]),
                                           mode='bilinear', align_corners=True)
            yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
                      Regions_crop['cord'][2]:Regions_crop['cord'][3]] = yseg_pred_crop

        # Collapse logits to per-pixel class labels (background / disc / cup)
        # and compute the vCDR from the label map.
        yseg_pred = torch.softmax(yseg_pred, dim=1)
        yseg_pred = yseg_pred.argmax(1).long().detach().cpu().numpy()
        ratios = compute_ratios(yseg_pred[0])

        # Recover a displayable RGB image at the prediction's resolution by
        # undoing the normalization.
        p_img = batch_sampler['image'].to(device=device).unsqueeze(0)
        p_img = F.interpolate(p_img, size=(yseg_pred.shape[1], yseg_pred.shape[2]),
                              mode='bilinear', align_corners=True)
        image_orig = (p_img[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy()
        image_orig = np.uint8(image_orig * 255)

        image_cont = image_orig.copy()

        # Outline the cup (labels > 1) and the disc (labels > 0) on the image.
        ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 1, 2, cv2.THRESH_BINARY)
        conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 255, 0), 2)

        ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 0, 2, cv2.THRESH_BINARY)
        conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 0, 255), 2)

        # Rule of thumb: a vCDR of ~0.6 or greater is considered suspicious.
        if ratios.vcdr < 0.6:
            glaucoma = 'None'
        else:
            glaucoma = 'There may be a risk of glaucoma'

        return image_cont, ratios.vcdr, glaucoma, Regions_crop
# Build the segmentation model and load the trained weights.
DeepLab = Transformer_Regression(image_dim=image_shape, dim_patch=dim_patch,
                                 num_classes=num_classes, scale=scale, feat_dim=128)
DeepLab.to(device=device)
DeepLab.load_state_dict(torch.load('TrainAll_Maghrabi84_50iteration_SWIN.pth.tar',
                                   map_location=device))
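# The checkpoint file above is expected next to this script; the constructor
# arguments must match the configuration the weights were trained with.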
def infer(img):
    """Gradio handler: takes a fundus image (H x W x 3 uint8 array) and returns
    the vCDR, a diagnosis string, the annotated image, and a zoomed-in crop."""
    sample_batch = dict()
    sample_batch['image_original'] = img

    im_retina_pil = Image.fromarray(img)
    sample_batch['image'] = tfms(im_retina_pil)

    result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2)

    # Zoom into the annotated image using the crop coordinates.
    cropped = result[cropped['cord'][0]:cropped['cord'][1],
                     cropped['cord'][2]:cropped['cord'][3]]

    return ratio, diagnosis, result, cropped
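# Minimal local smoke test (a sketch, left commented out; assumes one of the
# bundled example images, e.g. 'M00027.png', is present next to this script):
#
#   test_img = np.array(Image.open('M00027.png').convert('RGB'))
#   vcdr, dx, labeled, zoomed = infer(test_img)
#   print(f'vCDR={vcdr:.3f} -> {dx} (elapsed {time.time() - start:.1f}s)')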
title = 'Glaucoma Detection in Retinal Fundus Images'
description = ('The method detects the optic disc and cup in a retinal fundus image, '
               'then computes the vertical cup-to-disc ratio (vCDR).')
with gr.Blocks(css='#title {text-align: center;}') as demo:
    with gr.Row():
        gr.Markdown(
            f'''
# {title}
{description}
''',
            elem_id='title'
        )
    with gr.Row():
        with gr.Column():
            prompt = gr.Image(label='Upload Your Retinal Fundus Image')
            btn = gr.Button(value='Submit')
            examples = gr.Examples(
                ['M00027.png', 'M00056.png', 'M00073.png', 'M00093.png',
                 'M00018.png', 'M00034.png'],
                inputs=[prompt])
        with gr.Column():
            with gr.Row():
                text1 = gr.Textbox(label='Vertical Cup-to-Disc Ratio:')
                text2 = gr.Textbox(label='Predicted Diagnosis (rule of thumb: ~0.6 or greater is suspicious)')
            img = gr.Image(label='Detected disc and cup')
            zoom = gr.Image(label='Cropped')

    outputs = [text1, text2, img, zoom]

    btn.click(fn=infer, inputs=prompt, outputs=outputs)
if __name__ == '__main__':
    demo.launch()