# Source: Hugging Face Spaces file view of app.py (3.07 kB).
# Last commit: 737fde0 "Update app.py" by ccaglieri.
import torch
import cv2
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import time
import os
import copy
import pickle
from PIL import Image
import datetime
import gdown
import zipfile
import urllib.request
import gradio as gr
# Side length in pixels used by both the OpenCV resize and the torchvision Resize.
IMG_SIZE = 512
# Class names; a label is looked up by indexing this list with the model's
# output position, so the order must match the order the model was trained with.
CLASSES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
checkpoint = "./demo_checkpoint_convnext.pth"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto `device` while deserializing.
# Without it, a checkpoint that was saved from a CUDA device cannot be loaded
# on a CPU-only host, even though `device` already falls back to "cpu".
# NOTE(review): torch.load unpickles the file, so the checkpoint must come from
# a trusted source. Recent PyTorch releases are also stricter about loading
# fully pickled modules by default -- confirm against the pinned torch version.
model = torch.load(checkpoint, map_location=device).to(device)

# Tensor pipeline applied to the array returned by preprocess():
# ndarray -> PIL -> forced RGB -> resize -> float tensor -> normalization with
# one mean/std value repeated for all three channels.
global_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.2786802, 0.2786802, 0.2786802],
        [0.16637428, 0.16637428, 0.16637428],
    ),
])
def crop_image_from_gray(img, tol=7):
    """Crop away dark borders and return the result as a 3-channel image.

    Rows and columns that contain no pixel brighter than ``tol`` are dropped;
    the remaining single-channel image is then replicated along a new last
    axis.

    Args:
        img: 2-D single-channel array (this is what ``circle_crop`` passes).
        tol: Intensity threshold; a pixel counts as content when ``img > tol``.

    Returns:
        Array of shape ``(rows_kept, cols_kept, 3)`` with the same dtype as
        ``img``. When no pixel exceeds ``tol`` the image is returned uncropped
        (still stacked to 3 channels).
    """
    mask = img > tol
    if mask.any():
        # Keep only the rows/columns that contain at least one bright pixel.
        cropped = img[np.ix_(mask.any(1), mask.any(0))]
    else:
        # Entirely dark image: cropping would produce a 0x0 array, which the
        # later resize step cannot handle, so keep the original extent.
        cropped = img
    # The crop is identical for all three channels, so compute it once.
    return np.stack([cropped, cropped, cropped], axis=-1)
def circle_crop(img):
    """Keep only the centred disc of a 2-D image, then crop dark borders.

    A filled circle is drawn at the image centre with a radius equal to the
    smaller of the two centre coordinates. Pixels outside it are zeroed, and
    the result is passed to ``crop_image_from_gray``.

    Args:
        img: 2-D array (``img.shape`` is unpacked into height and width).

    Returns:
        Whatever ``crop_image_from_gray`` returns for the masked image.
    """
    n_rows, n_cols = img.shape
    center = (int(n_cols / 2), int(n_rows / 2))
    radius = int(min(center))

    # Mask with 1 inside the disc, 0 elsewhere (thickness=-1 fills the circle).
    disc = np.zeros((n_rows, n_cols), np.uint8)
    cv2.circle(disc, center, radius, 1, thickness=-1)

    masked = cv2.bitwise_and(img, img, mask=disc)
    return crop_image_from_gray(masked)
def preprocess(img):
    """Turn a colour image array into the square image fed to the transforms.

    Steps: take channel index 1 (green for RGB or BGR input), apply CLAHE
    contrast equalisation, circle-crop, and resize to IMG_SIZE x IMG_SIZE.

    Args:
        img: Array indexed as ``img[:, :, channel]``.

    Returns:
        The array produced by ``cv2.resize`` on the circle-cropped image.
    """
    green = img[:, :, 1]
    equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(green)
    cropped = circle_crop(equalized)
    return cv2.resize(cropped, (IMG_SIZE, IMG_SIZE))
def do_inference(img):
    """Classify one image and return per-class probabilities.

    Args:
        img: Input image, either a PIL image (the Gradio input in this file is
            declared with ``type="pil"``) or an array already indexable as
            ``img[:, :, channel]``.

    Returns:
        Dict mapping each entry of ``CLASSES`` to its softmax probability
        rounded to two decimals, inserted in descending probability order.
    """
    # preprocess() slices the image ndarray-style, which a PIL image does not
    # support; np.asarray converts PIL input and is a no-op for ndarray input.
    img = np.asarray(img)
    img = preprocess(img)
    img_t = global_transforms(img)
    batch_t = torch.unsqueeze(img_t, 0)  # add the batch dimension

    model.eval()
    # Gradients are not needed for inference; no_grad saves memory.
    with torch.no_grad():
        batch_t = batch_t.to(device)
        output = model(batch_t)
        probs = torch.nn.functional.softmax(output, dim=1)
        # Class indices sorted from most to least probable.
        order = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
        probs = probs.cpu().numpy()[0][order]

    labels = np.array(CLASSES)[order]
    return {label: round(float(prob), 2) for label, prob in zip(labels, probs)}
# Static text and example images shown in the demo UI.
title = "ConvNeXt for Diabetic Retinopathy Detection"
description = ""
examples = [['./noDr.png'], ['./severe.png']]
# Optional footer, currently disabled:
#article="<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab Recipes for Computer Vision - Dr. Mohamed Elawady</a></p>"

# Upload-only image input. type="pil" asks Gradio to hand the callback a PIL
# image; shape and image_mode request a 512x512 RGB image.
im = gr.inputs.Image(
    shape=(512, 512),
    image_mode='RGB',
    invert_colors=False,
    source="upload",
    type="pil",
)

# Wire the classifier to the UI; the label widget shows all five classes.
iface = gr.Interface(
    fn=do_inference,
    inputs=im,
    outputs=gr.outputs.Label(num_top_classes=5),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    examples=examples,
)

# Self-test hook, currently disabled:
#iface.test_launch()
iface.launch()