import sys
import os
import glob
import shutil
import torch
import argparse
import mediapy
import cv2
import numpy as np
import gradio as gr
from skimage import color, img_as_ubyte
from monai import transforms, data
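# Clone the companion repo so the SwinUNETR model code and sample scans are importable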
os.system("git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc")
sys.path.append("pmrc/SwinUNETR/BTCV")
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig
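# Point mediapy at the system ffmpeg binary so it can encode the output video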
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)
# Load the pretrained tiny SwinUNETR checkpoint and switch to inference mode
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()
# Index the sample scans bundled with the cloned repo by file name
input_files = glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
input_files = {os.path.basename(f): f for f in input_files}
# Load and preprocess each NIfTI scan with MONAI transforms
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.AddChanneld(keys=["image"]),
        # Resample to a common voxel spacing of 1.5 x 1.5 x 2.0 mm
        transforms.Spacingd(keys="image",
                            pixdim=(1.5, 1.5, 2.0),
                            mode="bilinear"),
        # Window intensities to [-175, 250] HU and rescale to [0, 1]
        transforms.ScaleIntensityRanged(keys=["image"],
                                        a_min=-175.0,
                                        a_max=250.0,
                                        b_min=0.0,
                                        b_max=1.0,
                                        clip=True),
        # transforms.Resized(keys=["image"], spatial_size=(256, 256, -1)),
        transforms.ToTensord(keys=["image"]),
    ])
# Create Data Loader
def create_dl(test_files):
    ds = test_transform(test_files)
    loader = data.DataLoader(ds,
                             batch_size=1,
                             shuffle=False)
    return loader
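# After the transforms above, each batch["image"] has shape (1, 1, H, W, D):
# batch and channel dimensions, then the three spatial dimensions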
# Inference and video generation
def generate_dicom_video(selected_file, n_frames):
    # Build a single-item loader for the selected scan
    test_file = input_files[selected_file]
    test_files = [{'image': test_file}]
    dl = create_dl(test_files)
    batch = next(iter(dl))
    # Keep only the last n_frames axial slices
    tst_inputs = batch["image"]
    tst_inputs = tst_inputs[:, :, :, :, -n_frames:]
    # Sliding-window inference over 96x96x96 ROIs, 8 windows per batch,
    # 50% overlap, gaussian blending of overlapping predictions
    with torch.no_grad():
        outputs = model(tst_inputs,
                        (96, 96, 96),
                        8,
                        overlap=0.5,
                        mode="gaussian")
        tst_outputs = torch.softmax(outputs.logits, 1)
        tst_outputs = torch.argmax(tst_outputs, dim=1)
    # Render each slice next to its segmentation overlay and write to video
    for inp, outp in zip(tst_inputs, tst_outputs):
        frames = []
        for idx in range(inp.shape[-1]):
            # Predicted label mask for this slice
            seg = outp[:, :, idx].numpy().astype(np.uint8)
            # Input slice rescaled to 8-bit grayscale, then converted to RGB
            img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            # Colour the predicted labels over the slice
            frame = color.label2rgb(seg, img, bg_label=0)
            frame = img_as_ubyte(frame)
            # Side by side: raw slice on the left, overlay on the right
            frame = np.concatenate((img, frame), 1)
            frames.append(frame)
    mediapy.write_video("dicom.mp4", frames, fps=4)
    return 'dicom.mp4'
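# Optional local smoke test (assumes the clone above succeeded and sample scans exist):
# generate_dicom_video(sorted(input_files)[0], n_frames=8)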
theme = 'dark-peach'
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''<center><h1>SwinUnetr BTCV</h1></center>
This is a Gradio Blocks app of the winning transformer in the Beyond the Cranial Vault (BTCV) Segmentation Challenge, <a href="https://github.com/darraghdog/Project-MONAI-research-contributions/tree/main/SwinUNETR/BTCV">SwinUnetr</a> (tiny version).
''')
    selected_dicom_key = gr.Dropdown(
        choices=sorted(input_files),
        type="value",
        label="Select a dicom file")
    n_frames = gr.Slider(1, 100, value=32, step=1,
                         label="Choose the number of dicom slices to process")
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    button_gen_video.click(fn=generate_dicom_video,
                           inputs=[selected_dicom_key, n_frames],
                           outputs=output_interpolation)
demo.launch(debug=True, enable_queue=True)