# jalFaizy — "Update app.py" (commit 3b5fa06)
import csv
import os
import sys

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from huggingface_hub import from_pretrained_keras
from scipy import ndimage
# Lift the csv module's per-field size cap as high as the platform allows.
# sys.maxsize does not fit in a C long on platforms where that type is 32 bits
# (e.g. Windows), where csv.field_size_limit raises OverflowError; fall back to
# the largest 32-bit signed value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
def read_nifti_file(filepath):
    """Load a NIfTI image from ``filepath`` and return its voxel array.

    Args:
        filepath: Path to the image file (the bundled example is a ``.nii.gz``).

    Returns:
        The array produced by nibabel's ``get_fdata()`` for that image.
    """
    return nib.load(filepath).get_fdata()
def normalize(volume, *, min_value=-1000, max_value=400):
    """Clip intensities to a fixed window and rescale that window to [0, 1].

    Values below ``min_value`` become 0 and values above ``max_value`` become
    1; everything in between is mapped linearly. The default window
    (-1000, 400) is presumably in Hounsfield units for CT input — confirm
    before reusing it with other modalities.

    The input is left unmodified; a new array is returned.

    Args:
        volume: Array-like of intensities (any shape).
        min_value: Lower bound of the intensity window.
        max_value: Upper bound of the intensity window.

    Returns:
        A new ``float32`` array with the same shape as ``volume``.

    Raises:
        ValueError: If ``max_value`` is not greater than ``min_value``.
    """
    if max_value <= min_value:
        raise ValueError(
            f"max_value ({max_value}) must be greater than min_value ({min_value})"
        )
    # np.clip allocates a new array, so the caller's data is never mutated.
    clipped = np.clip(volume, min_value, max_value)
    scaled = (clipped - min_value) / (max_value - min_value)
    return scaled.astype("float32")
def resize_volume(img, *, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a 3-D volume by 90 degrees and resample it to a fixed size.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth.
    The volume is first rotated 90 degrees with ``ndimage.rotate`` (default
    axes, i.e. the plane of the first two axes; ``reshape=False`` keeps the
    array shape). It is then linearly interpolated (``order=1``) along all
    three axes to the requested size.

    Args:
        img: A 3-D array.
        desired_depth: Target size of the last axis.
        desired_width: Target size of axis 0.
        desired_height: Target size of axis 1.

    Returns:
        The resampled array, of shape
        ``(desired_width, desired_height, desired_depth)`` as rounded by
        ``ndimage.zoom``.
    """
    current_width = img.shape[0]
    current_height = img.shape[1]
    current_depth = img.shape[-1]
    # Per-axis zoom factor = target / current. Written as 1 / (current / target)
    # to reproduce the original floating-point arithmetic (and thus the
    # rounded output shape) exactly.
    width_factor = 1 / (current_width / desired_width)
    height_factor = 1 / (current_height / desired_height)
    depth_factor = 1 / (current_depth / desired_depth)
    img = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
def process_scan(path):
    """Turn a scan file on disk into the model-ready volume.

    Steps: load the voxel data with ``read_nifti_file``, window/rescale the
    intensities with ``normalize``, then rotate and resample with
    ``resize_volume``.
    """
    raw_volume = read_nifti_file(path)
    scaled_volume = normalize(raw_volume)
    return resize_volume(scaled_volume)
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a grid montage of 2-D slices taken from a 3-D array.

    ``data`` is rotated 90 degrees with ``np.rot90`` (first two axes). It is
    then transposed (all axes reversed, so the last axis — the slice index —
    comes first) and reshaped into a ``num_rows x num_columns`` grid of
    ``width x height`` images. It must therefore hold exactly
    ``num_rows * num_columns`` slices of that size, otherwise ``np.reshape``
    raises ``ValueError``.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: Size of each slice along its first image axis.
        height: Size of each slice along its second image axis.
        data: Array-like whose last axis indexes the slices.

    Returns:
        The matplotlib ``Figure`` holding the montage.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    # Size the figure so each row keeps the slices' aspect ratio.
    heights = [row[0].shape[0] for row in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps axarr 2-D, so axarr[i, j] also works for a single
    # row or a single column (plt.subplots would otherwise return a 1-D array).
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    return f
def infer(filename):
    """Classify an uploaded CT scan and return messages plus a slice montage.

    Args:
        filename: The uploaded file as passed by Gradio. This is either an
            object whose ``.name`` attribute holds the file path
            (tempfile-style upload) or the path itself as ``str`` /
            ``os.PathLike``.

    Returns:
        A ``(messages, figure)`` pair:

        - ``messages``: one confidence sentence for each class, "normal"
          then "abnormal".
        - ``figure``: a 2x10 montage of the first 20 slices along the last
          axis of the preprocessed volume.

    Note:
        Reads the module-global ``model``, which is only assigned in the
        ``__main__`` block. Calling this after merely importing the module
        raises ``NameError``.
    """
    # Check for real paths first: pathlib.Path also has a ``.name`` attribute,
    # but that is only the basename.
    if isinstance(filename, (str, os.PathLike)):
        path = filename
    else:
        path = filename.name
    vol = np.expand_dims(process_scan(path), axis=0)  # prepend a batch axis
    prediction = model.predict(vol)[0]
    # The model's first output is taken as the "abnormal" score and its
    # complement as the "normal" score.
    scores = [1 - prediction[0], prediction[0]]
    class_names = ["normal", "abnormal"]
    result = [
        f"This model is {(100 * score):.2f} percent confident that CT scan is {name}"
        for score, name in zip(scores, class_names)
    ]
    return result, plot_slices(2, 10, 128, 128, vol[0, :, :, :20])
if __name__ == "__main__":
    # ``infer`` reads this module-level name, so it must be bound here before
    # the interface starts serving requests.
    model = from_pretrained_keras('keras-io/3D_CNN_Pneumonia')

    title = "Predicting Viral Pneumonia in CT scans using 3D CNN"
    description = (
        "This space implements 3D convolutional neural network to predict the "
        "presence of viral pneumonia in computer tomography scans."
    )
    article = (
        "<p style='text-align: center'>\n"
        "<a href='https://keras.io/examples/vision/3D_image_classification/' "
        "target='_blank'>Keras Example given by Hasib Zunair</a>\n"
        "<br>\n"
        "Space by Faizan Shaikh\n"
        "</p>\n"
    )

    iface = gr.Interface(
        fn=infer,
        inputs=gr.inputs.File(),
        outputs=[gr.outputs.Textbox(), 'plot'],
        title=title,
        article=article,
        description=description,
        examples=['example_1_normal.nii.gz'],
    )
    iface.launch(enable_queue=True)