akhaliq HF staff commited on
Commit
fb993be
1 Parent(s): 18dde43

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ %config InlineBackend.figure_format = 'retina'
2
+ from pathlib import Path
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchio as tio
6
+ import numpy as np
7
+ from tqdm.notebook import tqdm
8
+ import gradio as gr
9
+ from matplotlib import pyplot as plt
10
+
11
+ torch.set_grad_enabled(False);
12
+ # Download an example image
13
+ import urllib
14
+ url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
15
+ try: urllib.URLopener().retrieve(url, filename)
16
+ except: urllib.request.urlretrieve(url, filename)
17
+ def inference(img):
18
+ path = img.name
19
+ slices = [tio.ScalarImage(path).data]
20
+ tensor = torch.cat(slices, dim=-1)
21
+ guessed_affine = np.diag([-1, -1, 9, 1])
22
+ subject = tio.Subject(mri=tio.ScalarImage(tensor=tensor, affine=guessed_affine))
23
+ subject_preprocessed = tio.ZNormalization()(subject)
24
+ subject_preprocessed.plot()
25
+ subject_preprocessed.mri
26
+ patch_overlap = 0
27
+ patch_size = 256, 256, 1
28
+ grid_sampler = tio.inference.GridSampler(
29
+ subject_preprocessed,
30
+ patch_size,
31
+ patch_overlap,
32
+ )
33
+ patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=8)
34
+ aggregator = tio.inference.GridAggregator(grid_sampler)
35
+ model = torch.hub.load(
36
+ 'mateuszbuda/brain-segmentation-pytorch',
37
+ 'unet',
38
+ in_channels=3,
39
+ out_channels=1,
40
+ init_features=32,
41
+ pretrained=True,
42
+ )
43
+ for patches_batch in tqdm(patch_loader):
44
+ input_tensor = patches_batch['mri'][tio.DATA][..., 0]
45
+ locations = patches_batch[tio.LOCATION]
46
+ probs = model(input_tensor)[..., np.newaxis]
47
+ aggregator.add_batch(probs, locations)
48
+ output_tensor = aggregator.get_output_tensor()
49
+ output_subject = tio.Subject(prediction=tio.ScalarImage(tensor=output_tensor, affine=guessed_affine))
50
+ images = subject_preprocessed.mri.tensor.detach().numpy().reshape((3, 256, 256))
51
+ mask = output_subject.prediction.tensor.detach().numpy().reshape((256, 256))
52
+ images = np.moveaxis(np.moveaxis(images, 0, 2), 0, 1)
53
+ mask = np.moveaxis(mask, 0, 1)
54
+
55
+ f, ax = plt.subplots(1, 2)
56
+ ax[0].set_axis_off()
57
+ ax[1].set_axis_off()
58
+ ax[0].imshow(images)
59
+ ax[1].imshow(mask, cmap='gray')
60
+ return f
61
+
62
+ title = "U-NET FOR BRAIN MRI"
63
+ description = "Gradio demo for u-net for brain mri, U-Net with batch normalization for biomedical image segmentation with pretrained weights for abnormality segmentation in brain MRI. To use it, simply add your image, or click one of the examples to load them. Read more at the links below."
64
+ article = "<p style='text-align: center'><a href='https://mateuszbuda.github.io/2017/12/01/brainseg.html'>Segmentation of brain tumor in magnetic resonance images</a> | <a href='https://github.com/mateuszbuda/brain-segmentation-pytorch'>Github Repo</a></p>"
65
+ examples = [
66
+ ['TCGA_CS_4944.png']
67
+ ]
68
+ gr.Interface(inference, gr.inputs.Image(label="input image", type='file'), gr.outputs.Image(type='plot'), description=description, article=article, title=title, examples=examples, analytics_enabled=False).launch()