import numpy as np
import gradio as gr
import torch
import torch.nn.functional as F

from cxas import CXAS
from cxas.visualize import visualize_mask
from cxas.label_mapper import label_mapper

# Load the pretrained segmentation model on CPU.
model = CXAS(model_name='UNet_ResNet50_default', gpus='cpu')


def predict(image, class_names):
    """Segment the selected anatomical structures and overlay them on the input chest X-ray."""
    with torch.no_grad():
        if isinstance(class_names, str):
            class_names = [class_names]
        # HWC uint8 image -> CHW float in [0, 1], then normalize and resize to the 512x512 model input.
        tensor = torch.tensor(np.transpose(image, [2, 0, 1]) / 255)
        normalized_image = {
            'data': F.interpolate(model.fileloader.normalize(tensor).unsqueeze(0), 512).float()
        }
        out = model(normalized_image)['segmentation_preds'][0].cpu().numpy()
        # Overlay the predicted masks for the requested classes onto the input image.
        return np.array(
            visualize_mask(
                class_names=class_names,
                mask=out,
                image=np.transpose(image, [2, 0, 1]),
                img_size=512,
                cat=True,
                axis=1,
            )
        )


# All anatomical structure names the model can segment.
dropdown_option_class_names = list(label_mapper.keys())

with gr.Blocks() as demo:
    with gr.Tab("Anatomy Segmentation"):
        with gr.Row():
            gr.Markdown('This is a demo for Chest X-Ray Anatomy Segmentation with 159 classes. \n Since it runs on CPU, inference may take a while. \n The demo only supports .jpg and .png files.')
        with gr.Row():
            gr.Markdown('Note: for lateral views, only patients facing to the right side of the image are supported.')
        with gr.Row():
            gr.Markdown('Not intended for clinical use!')
        with gr.Row():
            gr.Markdown("To use it locally and to segment other data types such as DICOM, check out 'pip install cxas'.")
        with gr.Row():
            class_options = gr.Dropdown(
                dropdown_option_class_names,
                value=['thoracic spine'],
                label="class_names",
                multiselect=True,
            )
        with gr.Row():
            image_input = gr.Image()
        with gr.Row():
            image_output = gr.Image()
        image_button = gr.Button("Segment Chest X-Ray Anatomy")
        image_button.click(predict, inputs=[image_input, class_options], outputs=image_output)

demo.launch()