# katielink's picture
# Rename slider label
# ffd36f0
import os
import gradio as gr
import torch
from monai import bundle
# Identifier of the MONAI bundle used by this demo, and the local directory
# (<torch hub dir>/bundle/<BUNDLE_NAME>) its inference config is read from.
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(
    torch.hub.get_dir(),
    'bundle',
    BUNDLE_NAME,
)
description = """
## 🚀 To run
Upload a abdominal CT scan, or try one of the examples below!
If you want to see a different slice, update the slider.
More details on the model can be found [here!](https://huggingface.co/katielink/spleen_ct_segmentation_v0.1.0)
## ⚠️ Disclaimer
This is an example, not to be used for diagnostic purposes.
"""
# Ready-made inputs from the test set, one row per example in the same order
# as the interface inputs: [volume file path, initial slice index].
examples = [
    [example_path, 50]
    for example_path in (
        'examples/spleen_1.nii.gz',
        'examples/spleen_11.nii.gz',
    )
]
# Fetch the bundle from its Hugging Face Hub repo and load the TorchScript
# network it ships; bundle.load returns three values here and only the
# network is kept.
model, _, _ = bundle.load(
    source='huggingface_hub',
    repo='katielink/spleen_ct_segmentation_v0.1.0',
    name=BUNDLE_NAME,
    load_ts_module=True,
)
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Read the bundle's inference.json and instantiate the two entries the demo
# reuses, in order: the 'preprocessing' transform and the 'inferer'.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
preproc_transforms, inferer = (
    parser.get_parsed_content(section, lazy=True, eval_expr=True, instantiate=True)
    for section in ('preprocessing', 'inferer')
)
# Define the prediction function
def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded CT volume and return one slice of input and mask.

    Args:
        input_file: The uploaded volume, either as an object exposing its
            file path via ``.name`` (what the original code relied on) or as
            the path string itself.
        z_axis: Index into the last axis of the preprocessed volume. It is
            cast to ``int`` and clamped into range, because the UI slider has
            a fixed 0-200 range while the preprocessed depth varies per scan.
        model: Network to run; defaults to the module-level bundle model.
        device: Torch device string used for inference.

    Returns:
        Tuple ``(image_slice, mask_slice)`` of 2-D numpy arrays: the
        preprocessed input at the chosen slice, and the argmax label map at
        the same slice multiplied by 255 so the foreground is visible.
    """
    # Accept either a file-like wrapper with `.name` or a plain path string.
    file_path = getattr(input_file, 'name', input_file)
    data = {'image': [file_path]}
    data = preproc_transforms(data)
    model.to(device)
    model.eval()
    with torch.no_grad():
        # Prepend a batch dimension before running the inferer.
        inputs = data['image'].to(device)[None, ...]
        data['pred'] = inferer(inputs=inputs, network=model)
    input_image = data['image'].numpy()
    # Collapse class scores (dim 1) to a label map and move it to numpy.
    pred_image = torch.argmax(data['pred'], dim=1).cpu().detach().numpy()
    # Clamp the requested slice into the valid range instead of raising
    # IndexError when the slider exceeds this volume's depth.
    depth = min(input_image.shape[-1], pred_image.shape[-1])
    z = min(max(int(z_axis), 0), depth - 1)
    return input_image[0, :, :, z], pred_image[0, :, :, z] * 255
# Assemble the Gradio app. Inputs: an uploaded volume and a slice selector
# (0-200, starting at 50). Outputs: two images, one for the chosen input
# slice and one for the predicted mask slice returned by `predict`.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='input file'),
        gr.Slider(minimum=0, maximum=200, value=50, label='slice'),
    ],
    outputs=['image', 'image'],
    examples=examples,
    description=description,
    title='Segment the Spleen using MONAI! 🩸',
)

# Start serving the demo.
iface.launch()