import numpy as np
import gradio as gr
from fastai.vision.all import *
from PIL import Image  # explicit import (also re-exported by fastai.vision.all)
from skimage import filters, measure

def crop_to_shape(arr, shape, cval=0):
    """Crop a numpy array to the specified shape. If the array is larger, return a centered
    crop; if it is smaller, return a larger array filled with `cval` with the original data
    centered."""
    if arr.ndim != len(shape):
        raise ValueError("Array and crop shape dimensions do not match")

    arr_shape = np.array(arr.shape)
    shape = np.array(shape)
    max_shape = np.stack([arr_shape, shape]).max(axis=0)
    output_arr = np.ones(max_shape, dtype=arr.dtype) * cval

    arr_min = ((max_shape - arr_shape) / 2).astype(int)
    arr_max = arr_min + arr_shape
    slicer_obj = tuple(slice(idx_min, idx_max, 1) for idx_min, idx_max in zip(arr_min, arr_max))
    output_arr[slicer_obj] = arr

    crop_min = ((max_shape - shape) / 2).astype(int)
    crop_max = crop_min + shape
    slicer_obj = tuple(slice(idx_min, idx_max, 1) for idx_min, idx_max in zip(crop_min, crop_max))
    return output_arr[slicer_obj].copy()  # Return a copy of the view so the larger padded buffer can be garbage collected
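# Illustrative example (not part of the app logic): crop_to_shape both center-crops
# larger arrays and zero-pads smaller ones to the requested shape, e.g.:
#   crop_to_shape(np.zeros((100, 100)), (64, 64)).shape  -> (64, 64)  (centered crop)
#   crop_to_shape(np.ones((32, 32)), (64, 64)).shape     -> (64, 64)  (zero-padded)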


def crop_retina(image):
    """Return a square crop of the image centered on the retina.
    This function makes the following assumptions:
    - image is an np.array with dimensions [height, width, channels] or [height, width]
    - the background of the retinography has a stark contrast with the retina itself
    """
    # Check that the dimensionality of the array is valid
    if image.ndim > 3:
        raise ValueError("image has too many dimensions (max 3)")
    elif image.ndim < 2:
        raise ValueError("image has too few dimensions (min 2)")
    
    # Pad the image to ensure there is a black border around the retina (even if the original was already tightly cropped)
    image = crop_to_shape(
        image,
        np.array(image.shape) + np.array([20, 20, 0])[:image.ndim],
        cval=0
    )
    
    # If image is an RGB array, convert to grayscale
    if image.ndim == 3:
        bw_image = np.mean(image, axis=-1)
    else:
        bw_image = image
    
    # Find and apply threshold, to create a binary mask
    thresh = filters.threshold_triangle(bw_image)
    binary = bw_image > thresh
        
    # Label image regions and select the largest one (the retina)
    label_image = measure.label(binary)
    eye_region = sorted(measure.regionprops(label_image), key=lambda p: -p.area)[0]
    
    # Crop around the retina, extending the shorter side of the bounding box to match
    # the longer one (clamped to the image bounds) so the crop stays roughly square
    y_start, x_start, y_end, x_end = eye_region.bbox
    y_diff = y_end - y_start
    x_diff = x_end - x_start
    if x_diff > y_diff:
        if (y_start + x_diff) <= binary.shape[0]:
            y_end_x_diff = (y_start + x_diff)
            cropped_image = image[y_start:y_end_x_diff, x_start:x_end]
        else:
            y_start_x_diff = (y_end - x_diff) if (y_end - x_diff) > 0 else 0
            cropped_image = image[y_start_x_diff:y_end, x_start:x_end]
    else:
        if (x_start + y_diff) <= binary.shape[1]:
            x_end_y_diff = (x_start + y_diff)
            cropped_image = image[y_start:y_end, x_start:x_end_y_diff]
        else:
            x_start_y_diff = (x_end - y_diff) if (x_end - y_diff) > 0 else 0
            cropped_image = image[y_start:y_end, x_start_y_diff:x_end]

    # Ensure aspect ratio will be square
    max_axis = max(cropped_image.shape)
    if cropped_image.ndim == 3:
        square_crop = (max_axis, max_axis, cropped_image.shape[-1])
    else:
        square_crop = (max_axis, max_axis)
    square_image = crop_to_shape(cropped_image, square_crop)
    return square_image
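# Illustrative example (synthetic, not part of the app logic): any bright blob on a
# black background comes back as a square crop around that blob, e.g.:
#   yy, xx = np.mgrid[:200, :300]
#   disk = ((xx - 150) ** 2 + (yy - 100) ** 2 < 80 ** 2).astype(np.uint8) * 255
#   crop_retina(disk).shape  -> square (height == width)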

def equalize_histogram(
        img,
        mean_rgb_vals=np.array([120, 100, 80]),
        std_rgb_vals=np.array([75.40455101, 60.72748057, 46.14914927])
):
    """Shift and scale the image towards the reference per-channel mean and std values.
    Near-black pixels (value < 10) are treated as background: they are excluded from the
    mean estimate and set back to zero in the output."""
    mask = img < 10
    img_masked = np.ma.array(img, mask=mask, fill_value=0)

    if img.ndim == 3:
        equalized_img = (img - img_masked.mean(axis=(0, 1))) / img.std(axis=(0, 1)) * std_rgb_vals + mean_rgb_vals
    elif img.ndim == 2:
        equalized_img = (img - img_masked.mean()) / img.std() * std_rgb_vals[1] + mean_rgb_vals[1]
    else:
        raise ValueError("img ndim is neither 2 nor 3")
    equalized_img = (equalized_img * ~mask).clip(0, 255).astype(np.uint8)
    return equalized_img


def crop_and_equalize_fd(img, crop_resize=None, equalize=True):
    """Crop a fundus photograph around the retina, optionally equalize its histogram and
    resize it to a square of side `crop_resize`."""
    retina_arr = crop_retina(np.array(img))
    if equalize:
        retina_arr = equalize_histogram(retina_arr)
    retina_img = Image.fromarray(retina_arr)
    if crop_resize:
        retina_img = retina_img.resize((crop_resize, crop_resize), Image.LANCZOS)
    return retina_img
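# Illustrative usage (assumes a local fundus photograph "fundus.jpg", which is not
# shipped with this demo):
#   preview = crop_and_equalize_fd(Image.open("fundus.jpg"), crop_resize=512)
#   preview.size  -> (512, 512)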


learn = load_learner('convnext_base.pkl')
labels = learn.dls.vocab


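# predict() returns, in order: the referable-glaucoma likelihood (probability of class
# index 1), its binary decision at a 0.5 threshold, an ungradability score defined as
# 1 - max class probability, and an ungradability flag raised when that maximum
# probability falls below 0.9.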
def predict(img):
    img = PILImage.create(img)
    im = crop_and_equalize_fd(img, 512)
    im.save('tmp.jpg')
    pred, _ = learn.tta(dl=learn.dls.test_dl(['tmp.jpg']))
    rg_likelihood = float(pred[0][1])
    rg_binary = bool(pred[0][1] > 0.5)
    ungradability_score = 1 - float(pred[0][np.argmax(pred[0])])
    ungradability_binary = bool(float(pred[0][np.argmax(pred[0])]) < 0.9)
    return rg_likelihood, rg_binary, ungradability_score, ungradability_binary

title = "Image classification"
description = "Demo for the submission to the AIROGS challenge."
examples = ['TRAIN000000.jpg', 'TRAIN001000.jpg', 'TRAIN002000.jpg', 'TRAIN003000.jpg']
interpretation = 'default'
enable_queue = True

gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(512, 512)),
    outputs=[
        gr.outputs.Textbox(label="Multiple referable glaucoma likelihoods"),
        gr.outputs.Textbox(label="Multiple referable glaucoma binary"),
        gr.outputs.Textbox(label="Multiple ungradability score"),
        gr.outputs.Textbox(label="Multiple ungradability binary"),
    ],
    title=title,
    description=description,
    examples=examples,
    interpretation=interpretation,
    enable_queue=enable_queue,
).launch()