File size: 2,275 Bytes
14c80dd
 
 
 
 
39c57ff
14c80dd
 
 
e61c9df
 
612f056
 
 
e61c9df
 
 
612f056
36ccbd1
e61c9df
 
39c57ff
 
 
 
 
14c80dd
39c57ff
14c80dd
 
90b32b2
14c80dd
 
 
 
39c57ff
14c80dd
 
39c57ff
14c80dd
39c57ff
 
 
 
 
 
 
 
14c80dd
39c57ff
14c80dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c57ff
14c80dd
 
 
39c57ff
ffd36f0
14c80dd
 
36ccbd1
e61c9df
39c57ff
14c80dd
 
39c57ff
14c80dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import gradio as gr
import torch
from monai import bundle

# Set the bundle name and download path
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
# Hugging Face Hub repository holding the bundle. It is derived from BUNDLE_NAME
# so that the repo, the local bundle path and the link in the description
# cannot drift apart.
HF_REPO_ID = f'katielink/{BUNDLE_NAME}'
# Local directory the bundle config is read from below. Presumably this matches
# bundle.load's default download location (torch hub dir + 'bundle'); confirm
# for the installed MONAI version.
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

# Markdown shown above the demo.
description = f"""
## 🚀 To run
Upload an abdominal CT scan, or try one of the examples below! 
If you want to see a different slice, update the slider.

More details on the model can be found [here!](https://huggingface.co/{HF_REPO_ID})

## ⚠️ Disclaimer
This is an example, not to be used for diagnostic purposes.

"""

# Set up some examples from the test set for better user experience.
# Each entry is [path to a NIfTI volume, initial slice index].
examples = [
    ['examples/spleen_1.nii.gz', 50],
    ['examples/spleen_11.nii.gz', 50],
]

# Load the pretrained model from Hugging Face Hub. With load_ts_module=True
# the call returns a 3-tuple; only the first element (the network) is used.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo=HF_REPO_ID,
    load_ts_module=True,
)

# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load transforms and inferer directly from the bundle's inference config so
# the demo preprocesses exactly like the bundle's own inference pipeline.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
preproc_transforms = parser.get_parsed_content(
    'preprocessing',
    lazy=True, eval_expr=True, instantiate=True,
)
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True, eval_expr=True, instantiate=True,
)

# Define the prediction function
def predict(input_file, z_axis, model=model, device=device):
    """Segment the spleen in an uploaded CT volume and return one slice.

    Args:
        input_file: Value produced by ``gr.File``. Depending on the Gradio
            version, this is either a tempfile-like object exposing ``.name``
            or the file path as a string. Both forms are accepted.
        z_axis: Index of the slice to show along the last axis, as sent by
            the slider. It may arrive as a float, and it is clamped into the
            valid range of the volume.
        model: Network used for inference. Defaults to the module-level
            bundle model.
        device: Torch device string the model and input are moved to.

    Returns:
        A ``(input_slice, prediction_slice)`` pair of 2-D numpy arrays. The
        prediction is the argmax label map multiplied by 255, so non-zero
        labels render bright.
    """
    # Older Gradio passes a tempfile wrapper; newer Gradio passes a path string.
    file_path = getattr(input_file, 'name', input_file)
    data = preproc_transforms({'image': [file_path]})

    model.to(device)
    model.eval()
    with torch.no_grad():
        # Prepend a batch dimension before running the inferer.
        inputs = data['image'].to(device)[None, ...]
        data['pred'] = inferer(inputs=inputs, network=model)

    # The indexing below assumes a rank-4 image, presumably (C, H, W, D) after
    # the bundle's preprocessing. TODO: confirm.
    input_image = data['image'].numpy()
    # Collapse the class dimension (dim=1 of the batched output) to a label map.
    pred_image = torch.argmax(data['pred'], dim=1).cpu().detach().numpy()

    # The slider range is fixed, but the depth of the resampled volume varies.
    # Clamp the requested index instead of raising IndexError. int() also
    # handles float slider values, which numpy refuses as indices.
    depth = input_image.shape[-1]
    z = min(max(int(z_axis), 0), depth - 1)

    return input_image[0, :, :, z], pred_image[0, :, :, z] * 255

# Set up the demo interface: a file upload plus a slice slider in, and the
# selected input slice plus its segmentation out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='input file'),
        # step=1 keeps the value integral, because predict() uses it as an
        # array index.
        gr.Slider(0, 200, step=1, label='slice', value=50),
    ],
    outputs=['image', 'image'],
    title='Segment the Spleen using MONAI! 🩸',
    description=description,
    examples=examples,
)

# Launch the demo only when run as a script. Importing this module (from tests
# or Gradio's reload mode) should not start a server.
if __name__ == '__main__':
    iface.launch()