# cxas-demo / app.py
# Last commit by cmseibold: "Update app.py" (d0c2ef0)
import numpy as np
import gradio as gr
from cxas import CXAS
import torch
import torch.nn.functional as F
from cxas.visualize import visualize_mask
from cxas.label_mapper import label_mapper
# Build the pretrained CXAS anatomy-segmentation model a single time at import
# so that every call to `predict` reuses the same instance. gpus='cpu' pins it
# to the CPU (the demo text below also warns that it runs on CPU).
model = CXAS(model_name="UNet_ResNet50_default", gpus="cpu")
def predict(image, class_names):
    """Segment chest X-ray anatomy in ``image`` and render the selected classes.

    Args:
        image: The upload delivered by ``gr.Image()`` as a numpy array, or
            ``None`` when the user clicks the button without uploading
            anything. Assumed to be ``(H, W, C)`` with values in 0..255
            (``/ 255`` below presumes uint8) -- TODO confirm for unusual
            uploads. A 2-D grayscale array or a 4-channel RGBA array is
            coerced to three channels.
        class_names: A single class name or a list of class names (entries of
            the dropdown built from ``label_mapper``) to draw.

    Returns:
        A numpy array holding the rendering produced by ``visualize_mask``
        (called with ``cat=True, axis=1``).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard np.transpose(None, ...) fails with an opaque
        # "axes don't match array" error that Gradio shows as a generic failure.
        raise gr.Error('Please upload a chest X-ray image (.jpg or .png).')
    if isinstance(class_names, str):
        class_names = [class_names]

    image = np.asarray(image)
    if image.ndim == 2:
        # Grayscale array: replicate to three channels so the HWC -> CHW
        # transpose below is valid (gr.Image's default "RGB" mode normally
        # does this already).
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # Drop an alpha channel if one slipped through.
        image = image[..., :3]

    # Channels-first view, shared by the model input and the visualization.
    image_chw = np.transpose(image, [2, 0, 1])

    with torch.no_grad():
        normalized = model.fileloader.normalize(torch.tensor(image_chw / 255))
        # Add a batch dimension and resize spatially to 512x512 before inference.
        batch = F.interpolate(normalized.unsqueeze(0), size=512).float()
        out = model({'data': batch})['segmentation_preds'][0].cpu().numpy()
        return np.array(visualize_mask(
            class_names=class_names,
            mask=out,
            image=image_chw,
            img_size=512,
            cat=True,
            axis=1,
        ))
# Every key of label_mapper is offered as a selectable class in the dropdown.
dropdown_option_class_names = list(label_mapper)

with gr.Blocks() as demo:
    with gr.Tab("Anatomy Segmentation"):
        # Informational notes shown above the controls.
        with gr.Row():
            gr.Markdown('This is a demo for Chest X-Ray Anatomy Segmentation with 159 classes. \n As it is running on CPU, it is not that fast. \n The demo only supports .jpg and .png files.')
        with gr.Row():
            gr.Markdown('Note: for lateral views, only patients facing to the right side of the image are supported.')
        with gr.Row():
            gr.Markdown('Not intended for clinical use!')
        with gr.Row():
            gr.Markdown('To use it locally and to segment other data types like dicom, check out \'pip install cxas\'.')

        # Controls: class selection, input image, output rendering.
        with gr.Row():
            class_options = gr.Dropdown(
                dropdown_option_class_names,
                value=['thoracic spine'],
                label="class_names",
                multiselect=True,
            )
        with gr.Row():
            image_input = gr.Image()
        with gr.Row():
            image_output = gr.Image()

        # Event listeners must be registered inside the Blocks context.
        image_button = gr.Button("Segment Chest X-Ray Anatomy")
        image_button.click(predict, inputs=[image_input, class_options], outputs=image_output)

# Start the server only when executed as a script (e.g. `python app.py`, which
# is how Spaces runs it), so importing this module has no launch side effect.
if __name__ == "__main__":
    demo.launch()