|
import gradio as gr |
|
from fastai.vision.all import * |
|
import skimage |
|
from skimage import filters, segmentation, measure |
|
|
|
def crop_to_shape(arr, shape, cval=0):
    """Center-crop and/or center-pad ``arr`` so that it has shape ``shape``.

    Each axis is handled independently: an axis that is longer than requested
    is cropped around its center, an axis that is shorter is padded with
    ``cval`` on both sides. When the size difference is odd, the extra element
    is dropped from (or added at) the end of the axis.

    Args:
        arr: Input numpy array.
        shape: Target shape; must have one entry per dimension of ``arr``.
        cval: Fill value used for padded cells.

    Returns:
        A new array (never a view of ``arr``) with the requested shape.

    Raises:
        ValueError: If ``shape`` does not have ``arr.ndim`` entries.
    """
    if arr.ndim != len(shape):
        raise ValueError("Array and crop shape dimensions do not match")

    arr_shape = np.array(arr.shape)
    shape = np.array(shape)

    # Canvas large enough to hold both the input and the requested shape.
    max_shape = np.maximum(arr_shape, shape)
    output_arr = np.ones(max_shape, dtype=arr.dtype) * cval

    # Paste the input in the middle of the canvas. Integer floor division
    # replaces the former ``.astype(np.int)``: ``np.int`` was removed in
    # NumPy 1.24, and the offsets are never negative so the result is the same.
    arr_min = (max_shape - arr_shape) // 2
    arr_max = arr_min + arr_shape
    paste_slices = tuple(slice(lo, hi) for lo, hi in zip(arr_min, arr_max))
    output_arr[paste_slices] = arr

    # Cut the requested shape out of the middle of the canvas.
    crop_min = (max_shape - shape) // 2
    crop_max = crop_min + shape
    crop_slices = tuple(slice(lo, hi) for lo, hi in zip(crop_min, crop_max))
    return output_arr[crop_slices].copy()
|
|
|
|
|
def crop_retina(image):
    """Return a square crop of ``image`` centered on the retina.

    Assumptions made by this function:
    - ``image`` is a numpy array shaped [height, width, channels] or
      [height, width];
    - the background of the retinography contrasts starkly with the retina,
      so a global threshold separates the two.

    The largest connected above-threshold region is taken to be the eye. Its
    bounding box is extended along the shorter side where the image allows,
    and the result is finally padded/cropped to an exact square.

    Raises:
        Exception: If ``image`` has fewer than 2 or more than 3 dimensions.
    """
    if image.ndim > 3:
        raise Exception("image has too many dimensions. Max 3")
    if image.ndim < 2:
        raise Exception("image has too few dimensions. Min 2")

    # Add a 10-pixel zero border on each spatial side (channels untouched).
    # Presumably this keeps a retina that touches the frame surrounded by
    # background -- TODO confirm intent.
    padded_shape = np.array(image.shape) + np.array([20, 20, 0])[:image.ndim]
    image = crop_to_shape(image, padded_shape, cval=0)

    # Collapse colour images to a single intensity plane for thresholding.
    bw_image = image.mean(axis=-1) if image.ndim == 3 else image

    # Foreground mask from the triangle threshold.
    foreground = bw_image > filters.threshold_triangle(bw_image)

    # Largest connected foreground component; indexing [0] raises IndexError
    # when no foreground region exists at all.
    regions = measure.regionprops(measure.label(foreground))
    eye_region = sorted(regions, key=lambda r: r.area, reverse=True)[0]

    top, left, bottom, right = eye_region.bbox
    height = bottom - top
    width = right - left

    rows = slice(top, bottom)
    cols = slice(left, right)
    if width > height:
        # Bounding box is wider than tall: grow the row range to `width`,
        # downwards if it fits, otherwise upwards (clamped at row 0).
        if top + width <= foreground.shape[0]:
            rows = slice(top, top + width)
        else:
            rows = slice(max(bottom - width, 0), bottom)
    elif left + height <= foreground.shape[1]:
        # Otherwise resize the column range to `height`, rightwards if it fits...
        cols = slice(left, left + height)
    else:
        # ...or leftwards (clamped at column 0).
        cols = slice(max(right - height, 0), right)
    cropped_image = image[rows, cols]

    # Clamping above can leave a non-square crop; pad it to a square whose
    # side is the largest entry of the crop's shape.
    side = max(cropped_image.shape)
    square_shape = (side, side) + tuple(cropped_image.shape[2:])
    return crop_to_shape(cropped_image, square_shape)
|
|
|
def equalize_histogram(
    img,
    mean_rgb_vals=np.array([120, 100, 80]),
    std_rgb_vals=np.array([75.40455101, 60.72748057, 46.14914927])
):
    """Shift and scale pixel intensities towards target mean/std values.

    Pixels with a value below 10 are treated as background: they are left out
    of the mean computation and forced to 0 in the output. The standard
    deviation, by contrast, is computed over every pixel of ``img``.

    Args:
        img: 2-D (single channel) or 3-D (height, width, channels) array.
        mean_rgb_vals: Per-channel target means; a 2-D image uses entry 1.
        std_rgb_vals: Per-channel target standard deviations; a 2-D image
            uses entry 1.

    Returns:
        The normalized image, clipped to [0, 255] and cast to ``uint8``.
        NOTE(review): for 3-D input the masked per-channel mean appears to
        make the result a numpy masked array -- confirm if callers care.

    Raises:
        Exception: If ``img`` is neither 2-D nor 3-D.
    """
    background = img < 10
    foreground = np.ma.array(img, mask=background, fill_value=0)

    if img.ndim == 3:
        # Per-channel statistics over the two spatial axes.
        spatial_axes = (0, 1)
        centre = foreground.mean(axis=spatial_axes)
        spread = img.std(axis=spatial_axes)
        target_std, target_mean = std_rgb_vals, mean_rgb_vals
    elif img.ndim == 2:
        # Single-channel images borrow the targets at index 1.
        centre = foreground.mean()
        spread = img.std()
        target_std, target_mean = std_rgb_vals[1], mean_rgb_vals[1]
    else:
        raise Exception("img ndim is neither 2 nor 3")

    rescaled = (img - centre) / spread * target_std + target_mean
    # Multiplying by the inverted mask zeroes the background pixels again.
    return (rescaled * ~background).clip(0, 255).astype(np.uint8)
|
|
|
|
|
def crop_and_equalize_fd(img, crop_resize=None, equalize=True):
    """Crop an image to the retina, optionally normalize it, and return a PIL image.

    Args:
        img: Input image; it is converted with ``np.array`` before cropping.
        crop_resize: If truthy, the side length in pixels of the square the
            result is resized to (LANCZOS resampling). Otherwise the crop
            keeps its own size.
        equalize: Whether to apply ``equalize_histogram`` to the crop.

    Returns:
        The processed image as a PIL ``Image``.
    """
    pixels = crop_retina(np.array(img))
    pixels = equalize_histogram(pixels) if equalize else pixels

    result = Image.fromarray(pixels)
    if not crop_resize:
        return result
    return result.resize((crop_resize, crop_resize), Image.LANCZOS)
|
|
|
|
|
# Load the exported learner once at import time so every request reuses it.
# NOTE(review): load_learner is understood to unpickle the file -- only ship
# exports from a trusted source.
learn = load_learner('convnext_base.pkl')

# Class vocabulary of the learner's data loaders (not referenced elsewhere in
# the visible code).
labels = learn.dls.vocab
|
|
|
|
|
def predict(img):
    """Run the glaucoma model on one fundus image.

    The image is cropped to the retina, normalized, resized to 512x512,
    written to disk as a JPEG and scored with test-time augmentation.

    Args:
        img: Image as delivered by the Gradio image input; anything accepted
            by ``PILImage.create``.

    Returns:
        A 4-tuple ``(rg_likelihood, rg_binary, ungradability_score,
        ungradability_binary)``:
        - ``rg_likelihood``: predicted probability at class index 1
          (presumably "referable glaucoma" -- confirm against the vocab);
        - ``rg_binary``: whether that probability exceeds 0.5;
        - ``ungradability_score``: one minus the highest class probability;
        - ``ungradability_binary``: whether the highest class probability is
          below 0.9.
    """
    import os
    import tempfile

    img = PILImage.create(img)
    retina_img = crop_and_equalize_fd(img, 512)

    # Use a private temporary directory per call: a fixed 'tmp.jpg' in the
    # working directory is shared by all requests, so concurrent/queued
    # predictions could overwrite each other's input, and the file was never
    # cleaned up. The JPEG round-trip itself is kept unchanged.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'tmp.jpg')
        retina_img.save(tmp_path)
        pred, _ = learn.tta(dl=learn.dls.test_dl([tmp_path]))

    probs = pred[0]
    top_prob = float(probs.max())

    rg_likelihood = float(probs[1])
    rg_binary = bool(probs[1] > 0.5)
    ungradability_score = 1 - top_prob
    ungradability_binary = top_prob < 0.9
    return rg_likelihood, rg_binary, ungradability_score, ungradability_binary
|
|
|
# Text and options shown in the Gradio demo.
# The title had been corrupted by an encoding round-trip (UTF-8 bytes decoded
# as GBK turned the accented vowels into CJK characters); this is the
# intended Spanish text ("Image classification").
title = "Clasificación de imágenes"
description = "Demo for the submission to the airogs challenge."

# Example image files offered in the UI; paths are relative to the working
# directory.
examples = ['TRAIN000000.jpg', 'TRAIN001000.jpg', 'TRAIN002000.jpg', 'TRAIN003000.jpg']

# Passed straight through to gr.Interface below.
interpretation = 'default'
enable_queue = True
|
|
|
# Build the demo UI and start serving it: one 512x512 image input and one
# text box per value returned by `predict`, in the same order.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(512, 512)),
    outputs=[
        gr.outputs.Textbox(label="Multiple referable glaucoma likelihoods"),
        gr.outputs.Textbox(label="Multiple referable glaucoma binary"),
        gr.outputs.Textbox(label="Multiple ungradability score"),
        gr.outputs.Textbox(label="Multiple ungradability binary"),
    ],
    title=title,
    description=description,
    examples=examples,
    interpretation=interpretation,
    enable_queue=enable_queue,
).launch()
|
|