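# Page-header residue that appears to come from the file viewer this source
# was copied from (kept as a comment so the file remains valid Python):
# jalFaizy's picture
# Update app.py
# 3b5fa06
# raw
# history blame contribute delete
# No virus
# 3.76 kB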
import sys
import csv
import numpy as np
import gradio as gr
import nibabel as nib
import matplotlib.pyplot as plt
from scipy import ndimage
from huggingface_hub import from_pretrained_keras
def _raise_csv_field_size_limit():
    """Raise the ``csv`` per-field size limit to the largest accepted value.

    ``csv.field_size_limit(sys.maxsize)`` raises ``OverflowError`` on
    platforms where a C ``long`` is 32 bits wide (e.g. 64-bit Windows), which
    would crash this module at import time. Start at ``sys.maxsize`` and
    divide by ten until the interpreter accepts the value.

    Returns:
        The limit that was finally set.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
        except OverflowError:
            limit //= 10
        else:
            return limit


# NOTE(review): nothing in this file reads CSV directly; presumably this is
# for Gradio's CSV-backed examples/flagging data — confirm before removing.
_raise_csv_field_size_limit()
def read_nifti_file(filepath):
    """Open the NIfTI image at *filepath* and return its voxel data.

    The file is loaded with nibabel and the data array is obtained through
    ``get_fdata()``, i.e. as a floating-point array.
    """
    image = nib.load(filepath)
    return image.get_fdata()
def normalize(volume, min_hu=-1000, max_hu=400):
    """Clip a volume to an intensity window and rescale it to [0, 1].

    Values below ``min_hu`` become ``min_hu`` and values above ``max_hu``
    become ``max_hu``. The result is then mapped linearly so that
    ``min_hu -> 0.0`` and ``max_hu -> 1.0``. The input is not modified.

    Args:
        volume: Array-like of intensities (presumably Hounsfield units, as the
            default window suggests — confirm against the data source).
        min_hu: Lower bound of the window (default -1000).
        max_hu: Upper bound of the window (default 400).

    Returns:
        A new ``float32`` array with the same shape as ``volume``.

    Raises:
        ValueError: If ``max_hu`` is not greater than ``min_hu`` (the rescale
            would divide by zero or invert the range).
    """
    if max_hu <= min_hu:
        raise ValueError(
            f"max_hu ({max_hu}) must be greater than min_hu ({min_hu})"
        )
    # np.clip returns a new array, so the caller's data is left untouched.
    clipped = np.clip(np.asarray(volume), min_hu, max_hu)
    return ((clipped - min_hu) / (max_hu - min_hu)).astype("float32")
def resize_volume(img, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a volume 90 degrees in-plane and resample it to a fixed shape.

    The volume is first rotated by 90 degrees in the plane of its first two
    axes (``ndimage.rotate`` default ``axes=(1, 0)``) with ``reshape=False``,
    so its shape is unchanged. It is then linearly interpolated
    (``order=1``) to ``(desired_width, desired_height, desired_depth)``.

    Args:
        img: 3-D array indexed as (width, height, depth); the last axis is
            treated as the slice/depth axis.
        desired_depth: Target size of the last axis (default 64).
        desired_width: Target size of axis 0 (default 128).
        desired_height: Target size of axis 1 (default 128).

    Returns:
        The resampled array of shape
        ``(desired_width, desired_height, desired_depth)``.

    Raises:
        ValueError: If ``img`` is not 3-D or has an axis of length 0.
    """
    if img.ndim != 3 or 0 in img.shape:
        raise ValueError(f"expected a non-empty 3-D volume, got shape {img.shape}")
    current_width, current_height, current_depth = img.shape
    # One zoom factor per axis, mapping the current size onto the target size.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )
    # The rotation uses reshape=False, so the shape (and therefore the factors
    # computed above) is unaffected by it.
    img = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(img, zoom_factors, order=1)
def process_scan(path):
    """Load the scan at *path* and prepare it for the model.

    Pipeline: read the NIfTI voxel data, clip and rescale intensities with
    ``normalize``, then rotate and resample with ``resize_volume``.
    """
    return resize_volume(normalize(read_nifti_file(path)))
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a ``num_rows`` x ``num_columns`` montage of CT slices.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: Size of each slice along the first axis of ``data``.
        height: Size of each slice along the second axis of ``data``.
        data: Array-like intended to have shape
            ``(width, height, num_rows * num_columns)``; slices are taken
            along the last axis. Only the total element count is checked
            (by ``np.reshape``).

    Returns:
        The matplotlib Figure holding the montage.
    """
    # Reorient for display and move the slice axis to the front:
    # (width, height, N) --rot90--> (height, width, N) --transpose--> (N, width, height)
    data = np.transpose(np.rot90(np.array(data)))
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [row[0].shape[0] for row in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    # Keep the figure's aspect ratio consistent with the tiled slices.
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps axarr 2-D even for a single row or column, so the
    # [i, j] indexing below works for every grid shape.
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i, row_axes in enumerate(axarr):
        for j, ax in enumerate(row_axes):
            ax.imshow(data[i][j], cmap="gray")
            ax.axis("off")
    return f
def infer(filename):
    """Classify an uploaded CT scan and render a preview montage.

    Args:
        filename: Uploaded-file object; its ``.name`` attribute is used as the
            path of the scan to load.

    Returns:
        A tuple ``(messages, figure)``. ``messages`` holds one confidence
        string per class ("normal" first, then "abnormal"). ``figure`` is a
        2x10 montage of the first 20 slices of the preprocessed volume.

    NOTE(review): reads the module-level ``model``, which is only assigned
    under ``if __name__ == "__main__"``; calling this after a plain import
    raises NameError.
    """
    # Add a leading batch axis of size 1 for the model.
    batch = np.expand_dims(process_scan(filename.name), axis=0)
    # The model's first output is used as the "abnormal" score (presumably a
    # sigmoid probability); "normal" is its complement.
    abnormal = model.predict(batch)[0][0]
    messages = [
        f"This model is {(100 * score):.2f} percent confident that CT scan is {name}"
        for name, score in (("normal", 1 - abnormal), ("abnormal", abnormal))
    ]
    return messages, plot_slices(2, 10, 128, 128, batch[0, :, :, :20])
if __name__ == "__main__":
    # Load the pretrained Keras classifier 'keras-io/3D_CNN_Pneumonia' via
    # huggingface_hub. This assignment creates the module-level name `model`
    # that `infer` reads, so `infer` only works when this file runs as a script.
    model = from_pretrained_keras('keras-io/3D_CNN_Pneumonia')
    # Single file-upload input; `infer` uses the uploaded object's `.name` as
    # the scan path.
    # NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) look
    # like the legacy Gradio API — confirm the installed Gradio version before
    # upgrading it.
    inputs = gr.inputs.File()
    # Two outputs, matching infer's (messages, figure) return value.
    outputs = [gr.outputs.Textbox(), 'plot']
    title = "Predicting Viral Pneumonia in CT scans using 3D CNN"
    description = "This space implements 3D convolutional neural network to predict the presence of viral pneumonia in computer tomography scans."
    # HTML footer with credits and a link to the original Keras example. The
    # continuation lines are deliberately unindented: the string content is
    # emitted verbatim.
    article = """<p style='text-align: center'>
<a href='https://keras.io/examples/vision/3D_image_classification/' target='_blank'>Keras Example given by Hasib Zunair</a>
<br>
Space by Faizan Shaikh
</p>
"""
    iface = gr.Interface(
        infer,
        inputs,
        outputs,
        title=title,
        article=article,
        description=description,
        # NOTE(review): presumably resolved relative to the working directory;
        # confirm example_1_normal.nii.gz ships alongside this file.
        examples=['example_1_normal.nii.gz']
    )
    # NOTE(review): enable_queue=True presumably routes requests through
    # Gradio's queue (useful for slow inference) — confirm for the installed
    # version.
    iface.launch(enable_queue=True)