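# Brain tumor 3D segmentation with MONAI: a Gradio demo.
# A SegResNet (4 input MRI modalities, 3 output tumor-region channels) runs
# sliding-window inference on bundled validation samples, entirely on CPU.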
import torch
import torchvision
import gradio as gr
from functools import partial

from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

# SegResNet configured for 4 input MRI modalities and 3 output tumor-region
# channels, with weights loaded onto the CPU
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
)
model.load_state_dict(
    torch.load("weights/model.pt", map_location=torch.device("cpu"))
)
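# The checkpoint is assumed to come from the MONAI BraTS training notebook
# linked under "Reference" at the bottom of the page.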

# inference: sliding-window prediction over the full volume
VAL_AMP = True

def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        # CUDA autocast does not affect CPU ops, so this branch is harmless
        # when no GPU is present
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)
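# Shape sketch (hypothetical tensor, not used by the app): the model expects a
# batched 4-channel volume and returns 3-channel logits of the same size, e.g.
#   x = torch.randn(1, 4, 240, 240, 160)   # (batch, modality, H, W, D)
#   logits = inference(x)                  # -> shape (1, 3, 240, 240, 160)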

# post-processing: sigmoid on the raw logits, then binarize at 0.5
post_trans = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
)
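# For example (assumed values): sigmoid(-2.0) = 0.12 and sigmoid(3.0) = 0.95,
# so post_trans(torch.tensor([-2.0, 3.0])) yields tensor([0., 1.])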

def load_sample(index):
    """Load bundled validation sample `index` and return axial slice 70 of
    every image and label channel as PIL images."""
    sample = torch.load(f"samples/val{index-1}.pt")
    # four input modalities -> four slices
    pil_images = [
        torchvision.transforms.functional.to_pil_image(sample["image"][i, :, :, 70])
        for i in range(4)
    ]
    # three label channels -> three slices
    pil_labels = [
        torchvision.transforms.functional.to_pil_image(sample["label"][i, :, :, 70])
        for i in range(3)
    ]
    return [index, *pil_images, *pil_labels]
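# e.g. load_sample(1) reads samples/val0.pt and returns
# [1, <4 image slices>, <3 label slices>], matching the outputs wired to the
# example buttons below.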

def predict(sample_index):
    sample = torch.load(f"samples/val{sample_index-1}.pt")
    model.eval()
    with torch.no_grad():
        # add a batch dimension, run sliding-window inference, then binarize
        val_input = sample["image"].unsqueeze(0)
        val_output = inference(val_input)
        val_output = post_trans(val_output[0])
    # slice 70 of each of the three predicted channels, as PIL images
    return [
        torchvision.transforms.functional.to_pil_image(val_output[i, :, :, 70])
        for i in range(3)
    ]
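# In the referenced MONAI tutorial the three output channels correspond to
# tumor core, whole tumor, and enhancing tumor; that convention is assumed
# here, since the labels are fixed by the training pipeline, not this app.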

with gr.Blocks(
    title="Brain tumor 3D segmentation with MONAI - ClassCat",
    css=".gradio-container {background:azure;}",
) as demo:
    # index of the currently selected sample; defaults to 1 so "Infer"
    # works even before an example button is clicked
    sample_index = gr.State(1)
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Brain tumor 3D segmentation with MONAI</div>""")
| gr.HTML("""<h4 style="color:navy;">1. Select an example, which includes input images and label images, by clicking "Example x" button.</h4>""") | |
    with gr.Row():
        input_image0 = gr.Image(label="image channel 0", type="pil", shape=(240, 240))
        input_image1 = gr.Image(label="image channel 1", type="pil", shape=(240, 240))
        input_image2 = gr.Image(label="image channel 2", type="pil", shape=(240, 240))
        input_image3 = gr.Image(label="image channel 3", type="pil", shape=(240, 240))
    with gr.Row():
        label_image0 = gr.Image(label="label channel 0", type="pil")
        label_image1 = gr.Image(label="label channel 1", type="pil")
        label_image2 = gr.Image(label="label channel 2", type="pil")
    # one "Example i" button per bundled sample; all share the same handler
    # (via functools.partial) and the same output components
    example_outputs = [
        sample_index,
        input_image0, input_image1, input_image2, input_image3,
        label_image0, label_image1, label_image2,
    ]
    with gr.Row():
        example_btns = [gr.Button(f"Example {i}") for i in range(1, 9)]
    for i, btn in enumerate(example_btns, start=1):
        btn.click(fn=partial(load_sample, i), inputs=None, outputs=example_outputs)
| gr.HTML("""<br/>""") | |
| gr.HTML("""<h4 style="color:navy;">2. Then, click "Infer" button to predict segmentation images. It will take about 30 seconds (on cpu)</h4>""") | |
    with gr.Row():
        output_image0 = gr.Image(label="output channel 0", type="pil")
        output_image1 = gr.Image(label="output channel 1", type="pil")
        output_image2 = gr.Image(label="output channel 2", type="pil")
    send_btn = gr.Button("Infer")
    send_btn.click(fn=predict, inputs=[sample_index],
                   outputs=[output_image0, output_image1, output_image2])
| gr.HTML("""<br/>""") | |
| gr.HTML("""<h4 style="color:navy;">Reference</h4>""") | |
| gr.HTML("""<ul>""") | |
| gr.HTML("""<li><a href="https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb" target="_blank">Brain tumor 3D segmentation with MONAI</a></li>""") | |
| gr.HTML("""</ul>""") | |

#demo.queue()
demo.launch(debug=True)

### EOF ###