File size: 3,687 Bytes
130cbf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b95046
130cbf2
3b95046
130cbf2
 
147df28
130cbf2
 
 
 
a2799b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import keras
from keras.models import Model
from tensorflow.keras.applications import densenet
from keras import backend as K
import tensorflow as tf
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
import numpy as np

# Pretrained CheXNet checkpoint; the filename suggests the brucechou1983
# CheXNet Keras release 0.3.0 (confirm provenance). Relative path, so it is
# resolved against the current working directory.
chexnet_weights_path = "brucechou1983_CheXNet_Keras_0.3.0_weights.h5"
# Square input resolution the DenseNet backbone is built for.
IMG_SIZE = 320

# DenseNet121 architecture only: no ImageNet weights (weights=None) and no
# classification head (include_top=False); the weights come from the
# checkpoint loaded below.
base = densenet.DenseNet121(weights=None,
                            include_top=False,
                            input_shape=(IMG_SIZE,IMG_SIZE,3)
                           )
## Workaround: temporarily add a 14-output sigmoid head so the layer/weight
## layout matches the pretrained checkpoint, load the weights, then drop it.
## NOTE(review): no pooling is applied here, so this Dense runs on the
## un-pooled feature map. It still lines up with the checkpoint only because
## pooling layers carry no weights and load_weights() without by_name assigns
## weights by layer order -- confirm against the checkpoint.
predictions = tf.keras.layers.Dense(14, activation='sigmoid', name='predictions')(base.output)
## Loading with by_name=True could avoid this workaround, but it is unclear
## whether the layer names match the checkpoint or how to validate that.
## See https://github.com/keras-team/keras/issues/5397
base = tf.keras.Model(inputs=base.input, outputs=predictions) 
base.load_weights(chexnet_weights_path)
print("CheXNet loaded")
base.trainable=False # freeze the pretrained layers; the head layers created below are new and unaffected
# NOTE(review): this only assigns a plain attribute; it does not by itself
# switch BatchNormalization layers to inference mode -- confirm intent.
base.training=False

# NOTE(review): intended to drop the temporary 14-way head. In tf.keras,
# Model.layers appears to return a fresh list, so these pop() calls may not
# modify the model at all; the layer actually used below is chosen by index.
base.layers.pop()

base.layers.pop()

### Build the 4-class head on top of the frozen backbone, following
### https://stackoverflow.com/questions/41668813/how-to-add-and-remove-new-layers-in-keras-after-loading-weights
### NOTE(review): which layer [-4] refers to depends on whether the pop()
### calls above took effect. The fine-tuned weights loaded below were saved
### for one specific architecture, so do not change this index without
### re-checking that load_weights() still matches.
new_model = GlobalAveragePooling2D()(base.layers[-4].output) 
# One softmax output per class name in `labels` (four classes).
new_model = Dense(4, activation='softmax')(new_model)
# NOTE(review): mixes standalone `keras` (keras.Model, keras.layers) with
# `tf.keras` objects; this only works when both resolve to the same Keras
# implementation -- confirm the installed versions.
chexnet_model = keras.Model(base.input, new_model)

### Loading Weights from the saved Weights
### Fine-tuned weights for the complete 4-class model (note: the filename
### literally contains a space).

# NOTE(review): load_model is not used in the visible code.
from keras.models import load_model
chexnet_model.load_weights('./chexnet_with_data_aug_and standardized.05.hdf5')

### User-Defined Functions for Pre-Processing

# Human-readable class names, ordered to match the classifier's outputs:
# labels[i] names column i of chexnet_model's 4-way softmax prediction.
labels = [
    'Negative for Pneumonia',
    'Typical Appearance',
    'Indeterminate Appearance',
    'Atypical Appearance',
]

def get_mean_std(inp):
    """Return the global mean and standard deviation of an image's values.

    Both statistics are taken over every element at once (all rows, columns
    and channels together), not per channel.

    Args:
        inp: Image data accepted by ``np.array`` (presumably the H x W x 3
            array produced by the Gradio image input -- confirm).

    Returns:
        A ``(mean, std)`` tuple of NumPy floating-point scalars.
    """
    pixels = np.array(inp)
    return pixels.mean(), pixels.std()

def standardize_image(inp, preprocess=True):
    """Standardize an image to zero mean / unit variance and add a batch axis.

    The mean and standard deviation are computed over the whole image (all
    pixels and channels together) via ``get_mean_std``.

    Args:
        inp: Image data accepted by ``np.array`` (presumably the H x W x 3
            array from the Gradio image input -- confirm).
        preprocess: When False, ``inp`` is returned unchanged: no
            standardization and no batch axis.

    Returns:
        A float32 array of shape ``(1, *inp.shape)`` when ``preprocess`` is
        True; otherwise ``inp`` itself.
    """
    if not preprocess:
        # Nothing to do; skip computing statistics that would be discarded.
        return inp
    mean, std = get_mean_std(inp)
    out = np.array(inp, dtype=np.float32)
    out -= mean
    # A constant image (e.g. a blank upload) has std == 0. Dividing would
    # fill the array with NaN and propagate NaN into every confidence, so
    # leave the centred values (all zeros) as they are instead.
    if std > 0:
        out /= std
    # Model.predict expects a leading batch dimension.
    return np.expand_dims(out, axis=0)

def preprocess_input(inp):
    """Turn a raw image into a standardized, batched model input.

    Thin alias for ``standardize_image(inp, preprocess=True)``.
    """
    return standardize_image(inp, preprocess=True)
    
def get_prediction(inp):
    """Run the classifier on an already-preprocessed batch and label the scores.

    Args:
        inp: A model-ready batch, e.g. the output of ``preprocess_input``
            (presumably shape ``(1, H, W, 3)`` float32). Only the first
            sample of the batch is reported.

    Returns:
        Dict mapping each class name in ``labels`` to its float softmax score.
    """
    prediction = chexnet_model.predict(inp)
    # predict() returns a (batch, num_classes) array. Read the class scores
    # of the first sample; indexing prediction[i] directly would walk the
    # batch axis instead and fail on float() of a whole row.
    scores = prediction[0]
    return {label: float(score) for label, score in zip(labels, scores)}

def classify_image(inp):
    """Gradio callback: classify a single chest X-ray image.

    Args:
        inp: The image supplied by the Gradio ``Image`` input.

    Returns:
        Dict mapping each of the four class names in ``labels`` to the
        model's float confidence, as consumed by ``gr.outputs.Label``.
    """
    batch = standardize_image(inp, preprocess=True)
    # Single-image batch: take row 0 of the (batch, classes) output.
    class_scores = chexnet_model.predict(batch)[0]
    return {name: float(score) for name, score in zip(labels, class_scores)}

### Application Code

import gradio as gr

# Text shown at the top of the Gradio page.
title = "Decision Support System to Diagnose COVID-19 using Chest X-Ray"
description = """
A prototype app for classifying the four COVID-19 Pneumonia Classes, created as a Demo using Gradio and HuggingFace Spaces.
It is currently able to classify a given image as one of the following four classes: Typical Appearance, Atypical Appearance,
Indeterminate Appearance and Negative for Pneumonia. The System has an additional interpretability component to help the user understand why a particular decision has been made.
"""

# Sample images offered for one-click demos; relative paths, presumably
# shipped next to this script -- confirm they exist in the deployment.
examples = ['0a3afeef9c01.png', '65761e66de9f.png', '0bf205469ffa.png', '0c06f6f96a5a.png', '0c1ba97ad7c8.png', '0c2d323a04bf.png', '0c6b440ba98e.png', '0c7b15362352.png', '0c7e3c0eda27.png']
# Use Gradio's built-in interpretation method for the interpretability view.
interpretation = 'default'
# Enable Gradio's request queue.
enable_queue = True

# NOTE(review): gr.inputs / gr.outputs, `interpretation` and `enable_queue`
# belong to the legacy (pre-4.0) Gradio API -- pin the Gradio version.
gr.Interface(
    fn=classify_image,
    # Request the same square resolution the DenseNet backbone was built for.
    inputs=gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE)),
    outputs=gr.outputs.Label(num_top_classes=4),
    title=title,
    description=description,
    examples=examples,
    interpretation=interpretation,
    enable_queue=enable_queue,
).launch()