Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| import glob | |
| import shutil | |
| import torch | |
| import argparse | |
| import mediapy | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| from skimage import color, img_as_ubyte | |
| from monai import transforms, data | |
import subprocess

# Fetch the SwinUNETR research fork, which provides the `swinunetr` model wrapper
# and the sample volumes under dataset/imagesSampleTs. Skip the clone when the
# checkout already exists (e.g. after a restart). Use check=True so a failed clone
# stops startup instead of being silently ignored, as os.system's exit status was.
if not os.path.isdir("pmrc"):
    subprocess.run(
        [
            "git",
            "clone",
            "https://github.com/darraghdog/Project-MONAI-research-contributions",
            "pmrc",
        ],
        check=True,
    )
sys.path.append("pmrc/SwinUNETR/BTCV")
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig  # noqa: E402

# Point mediapy at the ffmpeg binary on PATH. If none is found, leave mediapy's
# default lookup in place rather than configuring it with None.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Load the pretrained model once at startup, in inference mode.
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()

# Map each sample volume's file name (shown in the UI dropdown) to its path in the checkout.
input_files = {
    os.path.basename(path): path
    for path in glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
}
# Preprocessing pipeline for one sample given as {"image": <path>}:
#   1. load the volume,
#   2. ensure a leading channel axis,
#   3. resample to a voxel spacing of (1.5, 1.5, 2.0),
#   4. linearly rescale intensities from [-175, 250] to [0, 1], clipping values outside that range,
#   5. convert to a torch tensor.
# EnsureChannelFirstd replaces AddChanneld, which MONAI deprecated and later removed.
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.Spacingd(
            keys="image",
            pixdim=(1.5, 1.5, 2.0),
            mode="bilinear",
        ),
        transforms.ScaleIntensityRanged(
            keys=["image"],
            a_min=-175.0,
            a_max=250.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        transforms.ToTensord(keys=["image"]),
    ]
)
# Create Data Loader
def create_dl(test_files):
    """Build a DataLoader that applies ``test_transform`` to each entry of ``test_files``.

    Args:
        test_files: List of dicts of the form ``{"image": <path to volume>}``.

    Returns:
        A MONAI ``DataLoader`` yielding batches of size 1, in the given order.
    """
    # MONAI's Dataset applies the transform per item when it is accessed. The old
    # code relied on Compose implicitly mapping over a list and transformed everything eagerly.
    ds = data.Dataset(data=test_files, transform=test_transform)
    return data.DataLoader(ds, batch_size=1, shuffle=False)
# Inference and video generation
def generate_dicom_video(selected_file, n_frames):
    """Segment the last ``n_frames`` slices of a sample volume and render them as a video.

    Args:
        selected_file: Key of ``input_files`` (a sample file name) chosen in the UI.
        n_frames: Number of trailing slices along the last axis to process. A Gradio
            slider may deliver this as a float, so it is coerced to ``int`` (minimum 1).

    Returns:
        The path "dicom.mp4". Each frame shows the input slice (left) next to the
        same slice with its predicted labels colour-overlaid (right).

    Raises:
        ValueError: If ``selected_file`` is not one of the available sample files
            (e.g. the dropdown was left empty).
    """
    if selected_file not in input_files:
        raise ValueError(f"Unknown dicom file: {selected_file!r}")
    # Tensor slicing needs an int; a float coming from the slider would raise TypeError.
    n_frames = max(1, int(n_frames))

    dl = create_dl([{'image': input_files[selected_file]}])
    batch = next(iter(dl))
    # Keep only the last n_frames slices along the final axis.
    tst_inputs = batch["image"][..., -n_frames:]

    with torch.no_grad():
        outputs = model(tst_inputs,
                        (96, 96, 96),
                        8,
                        overlap=0.5,
                        mode="gaussian")
    # softmax is monotonic, so argmax over the raw logits gives identical labels.
    tst_outputs = torch.argmax(outputs.logits, dim=1)

    # create_dl uses batch_size=1, so the batch holds exactly one volume.
    inp, outp = tst_inputs[0], tst_outputs[0]
    frames = []
    for idx in range(inp.shape[-1]):
        # Predicted label map for this slice.
        seg = outp[:, :, idx].numpy().astype(np.uint8)
        # Input slice; intensities were clipped to [0, 1] by test_transform, so scale to 0-255.
        img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        overlay = img_as_ubyte(color.label2rgb(seg, img, bg_label=0))
        # Place the raw slice and the labelled slice side by side.
        frames.append(np.concatenate((img, overlay), 1))

    mediapy.write_video("dicom.mp4", frames, fps=4)
    return 'dicom.mp4'
theme = 'dark-peach'
# Gradio Blocks UI: pick a sample volume and a slice count, then render the
# segmentation video produced by generate_dicom_video.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''<center><h1>SwinUnetr BTCV</h1></center>
This is a Gradio Blocks app of the winning transformer in the Beyond the Cranial Vault (BTCV) Segmentation Challenge, <a href="https://github.com/darraghdog/Project-MONAI-research-contributions/tree/main/SwinUNETR/BTCV">SwinUnetr</a> (tiny version).
''')
    # gr.Dropdown replaces the legacy gr.inputs.Dropdown, matching the other components below.
    selected_dicom_key = gr.Dropdown(
        choices=sorted(input_files),
        type="value",
        label="Select a dicom file")
    n_frames = gr.Slider(1, 100, value=32, label="Choose the number of dicom slices to process", step=1)
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    button_gen_video.click(fn=generate_dicom_video,
                           inputs=[selected_dicom_key, n_frames],
                           outputs=output_interpolation)

# Launch only when run as a script, so importing the module does not start a server.
# NOTE(review): enable_queue is a launch() argument only in older Gradio 3.x releases;
# newer versions expect demo.queue() instead. Kept as-is to match the pinned version.
if __name__ == "__main__":
    demo.launch(debug=True, enable_queue=True)