|
import numpy as np |
|
import gradio as gr |
|
from cxas import CXAS |
|
import torch |
|
import torch.nn.functional as F |
|
from cxas.visualize import visualize_mask |
|
from cxas.label_mapper import label_mapper |
|
|
|
# Single shared segmentation model, built once at import time and reused by
# every request. The demo is configured to run on CPU.
model = CXAS(model_name='UNet_ResNet50_default', gpus='cpu')
|
|
|
def predict(image, class_names):
    """Segment an uploaded image and render the requested classes.

    Args:
        image: Value delivered by the ``gr.Image`` input. It is transposed
            with axes ``[2, 0, 1]`` and divided by 255, so a channel-last
            array is required (assumed H x W x C with 8-bit values --
            TODO confirm against the Gradio component configuration).
        class_names: One class name, or a list of class names, forwarded to
            ``visualize_mask``.

    Returns:
        ``np.ndarray`` built from whatever ``visualize_mask`` returns.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Clicking the button before uploading anything hands us None;
        # without this guard np.transpose fails with an opaque axes error.
        raise gr.Error("Please upload a chest X-ray image first.")

    # Accept a bare string (including str subclasses) as a one-item selection.
    if isinstance(class_names, str):
        class_names = [class_names]

    # Channel-first copy of the input, shared by the model path and the
    # visualisation call below.
    channels_first = np.transpose(image, [2, 0, 1])

    with torch.no_grad():
        scaled = torch.tensor(channels_first / 255)
        batch = model.fileloader.normalize(scaled).unsqueeze(0)
        # Resize to 512 before inference and cast to float32 for the model.
        normalized_image = {'data': F.interpolate(batch, 512).float()}

        out = model(normalized_image)['segmentation_preds'][0].cpu().numpy()

        return np.array(
            visualize_mask(
                class_names=class_names,
                mask=out,
                image=channels_first,
                img_size=512,
                cat=True,
                axis=1,
            )
        )
|
|
|
# Choices offered in the class dropdown: every key of the cxas label mapper,
# in the mapper's own iteration order.
dropdown_option_class_names = list(label_mapper)
|
|
|
# Informational notes shown at the top of the tab, one Markdown row each.
_INTRO_NOTES = (
    'This is a demo for Chest X-Ray Anatomy Segmentation with 159 classes. \n As it is running on CPU, it is not that fast. \n The demo only supports .jpg and .png files.',
    'Note: for lateral views, only patients facing to the right side of the image are supported.',
    'Not intended for clinical use!',
    "To use it locally and to segment other data types like dicom, check out 'pip install cxas'.",
)

# Layout: notes, class selector, input image, output image, then the trigger
# button. NOTE(review): the original indentation was lost; the button is
# assumed to sit at tab level below the output row -- confirm the layout.
with gr.Blocks() as demo:
    with gr.Tab("Anatomy Segmentation"):
        for note in _INTRO_NOTES:
            with gr.Row():
                gr.Markdown(note)

        with gr.Row():
            class_options = gr.Dropdown(
                choices=dropdown_option_class_names,
                value=['thoracic spine'],
                label="class_names",
                multiselect=True,
            )
        with gr.Row():
            image_input = gr.Image()
        with gr.Row():
            image_output = gr.Image()
        image_button = gr.Button("Segment Chest X-Ray Anatomy")

    # Event listeners have to be registered while the Blocks context is open.
    image_button.click(
        fn=predict,
        inputs=[image_input, class_options],
        outputs=image_output,
    )

demo.launch()
|
|