import os

import gradio as gr
import torch
from monai import bundle

# Set the bundle name and download path
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

description = """
## 🚀 To run
Upload an abdominal CT scan, or try one of the examples below!
If you want to see a different slice, update the slider.

More details on the model can be found [here!](https://huggingface.co/katielink/spleen_ct_segmentation_v0.1.0)

## ⚠️ Disclaimer
This is an example, not to be used for diagnostic purposes.
"""

# Set up some examples from the test set for better user experience
examples = [
    ['examples/spleen_1.nii.gz', 50],
    ['examples/spleen_11.nii.gz', 50],
]

# Load the pretrained model from Hugging Face Hub.
# With load_ts_module=True the call returns a 3-tuple; only the model is kept.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/spleen_ct_segmentation_v0.1.0',
    load_ts_module=True,
)

# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load transforms and inferer directly from the bundle's inference config
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
preproc_transforms = parser.get_parsed_content(
    'preprocessing', lazy=True, eval_expr=True, instantiate=True
)
inferer = parser.get_parsed_content(
    'inferer', lazy=True, eval_expr=True, instantiate=True
)


def _resolve_path(input_file):
    """Return a filesystem path for a value produced by ``gr.File``.

    Depending on the Gradio version / ``type`` setting, the callback receives
    either a path (``str`` / ``os.PathLike``) or a tempfile-like object whose
    ``.name`` attribute holds the path.
    """
    if isinstance(input_file, (str, os.PathLike)):
        return os.fspath(input_file)
    return input_file.name


def _clamp_slice_index(z_axis, depth):
    """Convert a slider value to a valid index along an axis of length *depth*.

    The slider range is fixed (0-200) while the preprocessed volume depth
    varies per scan, so out-of-range values are clamped instead of raising
    ``IndexError``; ``int()`` guards against the slider delivering a float.
    """
    return min(max(int(z_axis), 0), depth - 1)


# Define the prediction function
def predict(input_file, z_axis, model=model, device=device):
    """Run the bundle's preprocessing + inferer and return one slice of each.

    Args:
        input_file: value from the ``gr.File`` input (a path, or a
            tempfile-like object exposing ``.name``).
        z_axis: slice index chosen with the slider; cast to ``int`` and
            clamped to the last axis of the preprocessed image.
        model: network passed to the inferer; defaults to the bundle model.
        device: torch device the model and inputs are moved to.

    Returns:
        ``(image_slice, mask_slice)``: the preprocessed input at slice
        ``z_axis`` (first channel) and the argmax label map at the same
        slice, multiplied by 255 for display.

    Raises:
        ValueError: if no file was provided.
    """
    if input_file is None:
        raise ValueError('Please upload a CT scan (.nii.gz) first.')

    data = {'image': [_resolve_path(input_file)]}
    data = preproc_transforms(data)

    model.to(device)
    model.eval()
    with torch.no_grad():
        # Prepend a batch dimension before handing the volume to the inferer
        inputs = data['image'].to(device)[None, ...]
        data['pred'] = inferer(inputs=inputs, network=model)

    # .cpu() is a no-op for CPU tensors and makes .numpy() safe if the
    # preprocessing placed the image on a GPU.
    input_image = data['image'].detach().cpu().numpy()
    # Argmax over the channel dimension (dim=1) yields a per-voxel label map
    pred_image = torch.argmax(data['pred'], dim=1).cpu().numpy()

    z = _clamp_slice_index(z_axis, input_image.shape[-1])
    return input_image[0, :, :, z], pred_image[0, :, :, z] * 255


# Set up the demo interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='input file'),
        gr.Slider(0, 200, step=1, label='slice', value=50),
    ],
    outputs=['image', 'image'],
    title='Segment the Spleen using MONAI! 🩸',
    description=description,
    examples=examples,
)

# Launch the demo only when run as a script (e.g. `python app.py`)
if __name__ == '__main__':
    iface.launch()