# Hugging Face Spaces status: Running on CPU Upgrade
import sys | |
import csv | |
import numpy as np | |
import gradio as gr | |
import nibabel as nib | |
import matplotlib.pyplot as plt | |
from scipy import ndimage | |
from huggingface_hub import from_pretrained_keras | |
# Raise the csv module's per-field size limit to sys.maxsize. Nothing else in
# the visible code uses csv directly.
# NOTE(review): presumably this is for CSV handling done by a dependency --
# confirm it is still needed.
csv.field_size_limit(sys.maxsize)
def read_nifti_file(filepath):
    """Load the file at ``filepath`` with nibabel and return its data.

    Args:
        filepath: Location of the scan file, passed straight to ``nib.load``.

    Returns:
        Whatever ``get_fdata()`` yields for the loaded image.
    """
    # Load the image object, then pull the voxel data out of it in one go.
    return nib.load(filepath).get_fdata()
def normalize(volume, lower=-1000, upper=400):
    """Clamp a volume to ``[lower, upper]`` and rescale it to ``[0, 1]``.

    The input is left untouched: clamping is done on a copy, so callers can
    keep using the raw intensities after this call.

    Args:
        volume: Array-like of intensities.
        lower: Values below this bound are raised to it; it maps to 0.0.
        upper: Values above this bound are lowered to it; it maps to 1.0.

    Returns:
        A new ``float32`` array with the same shape as ``volume``.

    Raises:
        ValueError: If ``upper`` is not strictly greater than ``lower``.
    """
    if upper <= lower:
        raise ValueError("upper must be strictly greater than lower")
    # np.clip returns a fresh array, unlike boolean-mask assignment which
    # would overwrite the caller's data.
    clipped = np.clip(volume, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype("float32")
def resize_volume(img, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a 3D volume by 90 degrees and resample it to a target shape.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth;
    all three are rescaled, not only the depth axis.

    Args:
        img: 3D array to resize.
        desired_depth: Target size of the last axis.
        desired_width: Target size of axis 0.
        desired_height: Target size of axis 1.

    Returns:
        The rotated volume resampled with linear interpolation (``order=1``)
        to ``(desired_width, desired_height, desired_depth)``.

    Raises:
        ValueError: If ``img`` is not three-dimensional or has an empty axis.
    """
    if img.ndim != 3:
        raise ValueError(f"expected a 3D volume, got {img.ndim} dimension(s)")
    current_width, current_height, current_depth = img.shape
    if 0 in (current_width, current_height, current_depth):
        raise ValueError(f"cannot resize a volume with an empty axis: {img.shape}")

    # Zoom factor per axis is simply target size over current size.
    width_factor = desired_width / current_width
    height_factor = desired_height / current_height
    depth_factor = desired_depth / current_depth

    # reshape=False keeps the array shape unchanged, so the factors computed
    # above are still valid after the rotation.
    img = ndimage.rotate(img, 90, reshape=False)
    # Resample along all three axes.
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img
def process_scan(path):
    """Load the scan at ``path`` and return it normalized and resized.

    The file is read with ``read_nifti_file``, passed through ``normalize``
    and finally through ``resize_volume``, in that order.
    """
    return resize_volume(normalize(read_nifti_file(path)))
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a ``num_rows`` x ``num_columns`` montage of slices.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: Third dimension of the reshaped slice grid.
        height: Fourth dimension of the reshaped slice grid.
        data: Array-like holding ``num_rows * num_columns * width * height``
            values. It is rotated with ``np.rot90``, transposed and then
            reshaped to ``(num_rows, num_columns, width, height)``.

    Returns:
        The matplotlib figure containing one grayscale image per slice, with
        the axes turned off.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    # One entry per grid row / column, used to keep the figure's aspect ratio
    # in line with the montage.
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        # Always get a 2D axes array so axarr[i, j] also works for a single
        # row or a single column.
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    return f
def infer(filename):
    """Run the module-level ``model`` on an uploaded scan.

    Args:
        filename: Object whose ``name`` attribute is the path of the scan.

    Returns:
        A pair ``(messages, figure)``: one confidence sentence each for the
        "normal" and "abnormal" classes, and the ``plot_slices`` montage of
        the first 20 slices along the last axis of the processed volume.
    """
    # Add a leading batch axis before handing the volume to the model.
    batch = np.expand_dims(process_scan(filename.name), axis=0)
    first_prediction = model.predict(batch)[0]
    # The first output value is used as the "abnormal" score; "normal" is its
    # complement.
    labelled_scores = (
        ("normal", 1 - first_prediction[0]),
        ("abnormal", first_prediction[0]),
    )
    messages = [
        f"This model is {(100 * score):.2f} percent confident that CT scan is {name}"
        for name, score in labelled_scores
    ]
    montage = plot_slices(2, 10, 128, 128, batch[0, :, :, :20])
    return messages, montage
if __name__ == "__main__":
    # `infer` reads `model` as a module-level name, so it must be bound here
    # before the interface starts serving requests.
    model = from_pretrained_keras('keras-io/3D_CNN_Pneumonia')

    # HTML shown with the interface: credits for the Keras example and the
    # Space author.
    article = """<p style='text-align: center'>
<a href='https://keras.io/examples/vision/3D_image_classification/' target='_blank'>Keras Example given by Hasib Zunair</a>
<br>
Space by Faizan Shaikh
</p>
"""

    # One file input; a text box and a plot as outputs, matching the pair
    # returned by `infer`.
    iface = gr.Interface(
        fn=infer,
        inputs=gr.inputs.File(),
        outputs=[gr.outputs.Textbox(), 'plot'],
        title="Predicting Viral Pneumonia in CT scans using 3D CNN",
        description="This space implements 3D convolutional neural network to predict the presence of viral pneumonia in computer tomography scans.",
        article=article,
        examples=['example_1_normal.nii.gz'],
    )
    iface.launch(enable_queue=True)