File size: 3,757 Bytes
95f4b63
 
f4343d2
eaeb709
96884f9
3886cf7
 
96884f9
eaeb709
 
95f4b63
 
677dcdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410c1fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677dcdb
 
 
 
 
 
 
 
 
 
 
 
410c1fe
eaeb709
0d43e79
3b5fa06
0d43e79
 
 
 
b52ffae
be87173
 
0d43e79
 
 
 
 
 
 
 
 
 
 
 
be87173
0d43e79
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import sys
import csv
import numpy as np
import gradio as gr
import nibabel as nib
import matplotlib.pyplot as plt

from scipy import ndimage
from huggingface_hub import from_pretrained_keras

def _raise_csv_field_size_limit():
    """Set the csv module's per-field size cap as high as this platform accepts.

    ``csv.field_size_limit`` stores its argument in a C ``long``. On platforms
    where that is 32 bits (e.g. Windows), ``sys.maxsize`` does not fit and an
    ``OverflowError`` is raised, so the candidate is halved until it is accepted.

    Returns:
        The limit that was installed.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 2


# NOTE(review): csv is not used by the code in this file; presumably the limit
# is raised for a library that reads large CSV fields (e.g. Gradio's
# example/flagging logs) — confirm before removing.
_raise_csv_field_size_limit()

def read_nifti_file(filepath):
    """Load the NIfTI image stored at *filepath* and return its voxel array.

    The array is whatever nibabel's ``get_fdata()`` produces for the image
    (a floating-point NumPy array).
    """
    return nib.load(filepath).get_fdata()

def normalize(volume, min_hu=-1000, max_hu=400):
    """Clip a volume to an intensity window and rescale it linearly to [0, 1].

    Values at or below ``min_hu`` map to 0, values at or above ``max_hu`` map
    to 1, and everything in between is mapped linearly. The input is not
    modified.

    Args:
        volume: Array-like of voxel intensities. NOTE(review): the default
            window (-1000..400) looks like a Hounsfield-unit lung window —
            confirm the scans are in HU.
        min_hu: Lower bound of the window (default -1000).
        max_hu: Upper bound of the window (default 400).

    Returns:
        A new ``float32`` array with the same shape as ``volume``.

    Raises:
        ValueError: if ``max_hu`` is not greater than ``min_hu``.
    """
    if max_hu <= min_hu:
        raise ValueError(
            f"max_hu ({max_hu}) must be greater than min_hu ({min_hu})"
        )
    # np.clip returns a new array, so the caller's data is left untouched
    # (masked assignment would overwrite it in place).
    clipped = np.clip(volume, min_hu, max_hu)
    return ((clipped - min_hu) / (max_hu - min_hu)).astype("float32")

def resize_volume(img, *, desired_width=128, desired_height=128, desired_depth=64):
    """Rotate a 3-D volume by 90 degrees in-plane and resample it to a fixed size.

    The volume is first rotated by 90 degrees in the plane of its first two
    axes (``ndimage.rotate`` default axes) with ``reshape=False``, so its shape
    is unchanged. Then every axis is linearly interpolated (``order=1``) to
    the requested size.

    Args:
        img: 3-D array, with axes named (width, height, depth) by this function.
            With ``reshape=False`` a non-square in-plane slice is cropped/padded
            by the rotation rather than having its axes swapped.
        desired_width: Output size along axis 0 (default 128).
        desired_height: Output size along axis 1 (default 128).
        desired_depth: Output size along the last axis (default 64).

    Returns:
        Array of shape ``(desired_width, desired_height, desired_depth)``.
    """
    # Factors come from the pre-rotation shape; rotate(reshape=False) keeps it.
    zoom_factors = (
        desired_width / img.shape[0],
        desired_height / img.shape[1],
        desired_depth / img.shape[-1],
    )
    img = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(img, zoom_factors, order=1)

def process_scan(path):
    """Turn the NIfTI file at *path* into a model-ready volume.

    Pipeline: load voxels (``read_nifti_file``) -> clip/rescale intensities
    (``normalize``) -> rotate and resample to the fixed model size
    (``resize_volume``).
    """
    raw = read_nifti_file(path)
    scaled = normalize(raw)
    return resize_volume(scaled)

def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a ``num_rows`` x ``num_columns`` montage of grayscale slices.

    Args:
        num_rows: Number of rows in the montage grid.
        num_columns: Number of columns in the montage grid.
        width: Size of each slice along its first image axis after reshaping.
        height: Size of each slice along its second image axis after reshaping.
        data: Array-like volume. After ``np.rot90`` and a full transpose it must
            contain exactly ``num_rows * num_columns`` slices of
            ``width x height``. NOTE(review): the caller passes a
            (128, 128, 20) sub-volume, i.e. slices along the last axis — confirm
            for other callers.

    Returns:
        A ``matplotlib.figure.Figure`` with one axis-free grayscale panel per slice.

    Raises:
        ValueError: if ``data`` cannot be reshaped into the requested grid.
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive until explicitly closed, so a long-running server
    # that calls this per request would accumulate figures.
    from matplotlib.figure import Figure

    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    # Keep each panel's aspect ratio by scaling the figure height accordingly.
    fig_height = fig_width * sum(heights) / sum(widths)
    fig = Figure(figsize=(fig_width, fig_height))
    # squeeze=False always yields a 2-D array of Axes, so axarr[i, j] also works
    # when the grid has a single row or a single column.
    axarr = fig.subplots(
        rows_data,
        columns_data,
        squeeze=False,
        gridspec_kw={"height_ratios": heights},
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i, j], cmap="gray")
            axarr[i, j].axis("off")

    return fig

def infer(filename):
    """Classify one uploaded scan and render a preview montage.

    Args:
        filename: Object whose ``.name`` attribute is the path of the uploaded
            NIfTI file (presumably the temp-file wrapper Gradio's File input
            passes — confirm for the installed Gradio version).

    Returns:
        A pair ``(messages, figure)``:
        - ``messages`` holds two confidence sentences, "normal" first and then
          "abnormal". The abnormal score is the model's first output and the
          normal score is its complement.
        - ``figure`` is a 2 x 10 montage of the first 20 slices along the
          volume's last axis.

    Relies on the module-level ``model`` that the ``__main__`` block loads.
    """
    batch = np.expand_dims(process_scan(filename.name), axis=0)
    abnormal_prob = model.predict(batch)[0][0]

    labelled_scores = (("normal", 1 - abnormal_prob), ("abnormal", abnormal_prob))
    messages = [
        f"This model is {(100 * score):.2f} percent confident that CT scan is {name}"
        for name, score in labelled_scores
    ]

    montage = plot_slices(2, 10, 128, 128, batch[0, :, :, :20])
    return messages, montage

if __name__ == "__main__":
    # `infer` looks up `model` as a module-level global, so it must be bound
    # here at module scope (not inside a helper function).
    model = from_pretrained_keras('keras-io/3D_CNN_Pneumonia')

    # Text shown around the demo UI.
    title = "Predicting Viral Pneumonia in CT scans using 3D CNN"
    description = "This space implements 3D convolutional neural network to predict the presence of viral pneumonia in computer tomography scans."
    article = """<p style='text-align: center'>
        <a href='https://keras.io/examples/vision/3D_image_classification/' target='_blank'>Keras Example given by Hasib Zunair</a>
        <br>
        Space by Faizan Shaikh
    </p>
    """

    # One file upload in; a textbox (confidence sentences) and a plot (slice
    # montage) out, matching the two values `infer` returns.
    demo = gr.Interface(
        infer,
        gr.inputs.File(),
        [gr.outputs.Textbox(), 'plot'],
        title=title,
        description=description,
        article=article,
        examples=['example_1_normal.nii.gz'],
    )
    demo.launch(enable_queue=True)