import numpy as np
import gradio as gr
from PIL import Image
import torch
from torch import nn
from einops.layers.torch import Rearrange
from torchvision import transforms
from models.unet_model import Unet
from models.datasetDM_model import DatasetDM
from skimage import measure, segmentation
import cv2
from tqdm import tqdm
from einops import repeat

img_size = 128
font = cv2.FONT_HERSHEY_SIMPLEX


## %%
def load_img(img_file):
    # dispatch on input type and return a (1, 1, img_size, img_size) float tensor in [0, 1]
    if isinstance(img_file, np.ndarray):
        img = torch.Tensor(img_file).float()
        # make sure img is between 0 and 1
        if img.max() > 1:
            img /= 255
        # resize
        img = transforms.Resize(img_size)(img)
    elif isinstance(img_file, str):
        img = Image.open(img_file).convert('L').resize((img_size, img_size))
        img = transforms.ToTensor()(img).float()
    elif isinstance(img_file, Image.Image):
        img = img_file.convert('L').resize((img_size, img_size))
        img = transforms.ToTensor()(img).float()
    else:
        raise TypeError("Input must be a numpy array, PIL image, or filepath")
    # add the missing batch/channel dimensions
    if len(img.shape) == 2:
        img = img[None, None]
    elif len(img.shape) == 3:
        img = img[None]
    else:
        raise ValueError("Input must be a 2D or 3D array")
    return img


def predict_baseline(img, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = checkpoint["config"]
    baseline = Unet(**vars(config))
    baseline.load_state_dict(checkpoint["model_state_dict"])
    baseline.eval()
    return (torch.sigmoid(baseline(img)) > .5).float().squeeze().numpy()


def predict_LEDM(img, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = checkpoint["config"]
    config.verbose = False
    LEDM = DatasetDM(config)
    LEDM.load_state_dict(checkpoint["model_state_dict"])
    LEDM.eval()
    return (torch.sigmoid(LEDM(img)) > .5).float().squeeze().numpy()


def predict_TEDM(img, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = checkpoint["config"]
    config.verbose = False
    TEDM = DatasetDM(config)
    TEDM.classifier = nn.Sequential(
        Rearrange('b (step act) h w -> (b step) act h w', step=len(TEDM.steps)),
        nn.Conv2d(960, 128, 1),
        nn.ReLU(),
        nn.BatchNorm2d(128),
        nn.Conv2d(128, 32, 1),
        nn.ReLU(),
        nn.BatchNorm2d(32),
        nn.Conv2d(32, 1, config.out_channels)
    )
    TEDM.load_state_dict(checkpoint["model_state_dict"])
    TEDM.eval()
    # average the per-step predictions before thresholding
    return (torch.sigmoid(TEDM(img)).mean(0) > .5).float().squeeze().numpy()


predictors = {'Baseline': predict_baseline,
              'Global CL': predict_baseline,
              'Global & Local CL': predict_baseline,
              'LEDM': predict_LEDM,
              'LEDMe': predict_LEDM,
              'TEDM': predict_TEDM}

model_folders = {'Baseline': 'baseline',
                 'Global CL': 'global_finetune',
                 'Global & Local CL': 'glob_loc_finetune',
                 'LEDM': 'LEDM',
                 'LEDMe': 'LEDMe',
                 'TEDM': 'TEDM'}


def postprocess(pred, img):
    all_labels = measure.label(pred, background=0)
    _, cn = np.unique(all_labels, return_counts=True)
    # keep only the two largest connected components that are not the background
    if len(cn) >= 3:
        lungs = np.argsort(cn[1:])[-2:] + 1
        all_labels[(all_labels != lungs[0]) & (all_labels != lungs[1])] = 0
        all_labels[(all_labels == lungs[0]) | (all_labels == lungs[1])] = 1
    # draw the component outlines on the input image
    if len(cn) > 1:
        img = segmentation.mark_boundaries(img, all_labels, color=(1, 0, 0), mode='outer', background_label=0)
    else:
        img = repeat(img, 'h w -> h w c', c=3)
    return img
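
# Minimal usage sketch (kept as comments so nothing runs at import time): load one
# example image and call a single predictor directly, assuming a checkpoint exists
# at logs/baseline/197/best_model.pt (the path pattern used by predict() below).
#
#   img = load_img('img_examples/NIH_0006.png')          # (1, 1, 128, 128) tensor in [0, 1]
#   mask = predict_baseline(img, 'logs/baseline/197/best_model.pt')
#   overlay = postprocess(mask, img.squeeze().numpy())   # outlines up to two lung components
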
def predict(img_file, models: list, training_sizes: list, seg_img=False, progress=gr.Progress()):
    max_progress = len(models) * len(training_sizes)
    n_progress = 0
    progress((n_progress, max_progress), desc="Starting")
    img = load_img(img_file)
    print(img.shape)
    preds = []
    # sort models so rows always appear as Baseline - Global CL - Global & Local CL - LEDM - LEDMe - TEDM
    model_order = ['Baseline', 'Global CL', 'Global & Local CL', 'LEDM', 'LEDMe', 'TEDM']
    models = sorted(models, key=lambda x: model_order.index(x) if x in model_order else len(model_order))
    for model in models:
        print(model)
        model_preds = []
        for training_size in sorted(training_sizes):
            progress((n_progress, max_progress), desc=f"Predicting {model} {training_size}")
            n_progress += 1
            print(training_size)
            out = predictors[model](img, f"logs/{model_folders[model]}/{training_size}/best_model.pt")
            writing_colour = (.5, .5, .5)
            if seg_img:
                out = postprocess(out, img.squeeze().numpy())
                writing_colour = (1, 1, 1)
            # label each tile with its model name and training size
            out = cv2.putText(np.array(out), f"{model} {training_size}", (5, 125), font, .5, writing_colour, 1, cv2.LINE_AA)
            model_preds.append(np.asarray(out))
        preds.append(np.concatenate(model_preds, axis=1))
    prediction = np.concatenate(preds, axis=0)
    # pad narrow grids so the output is at least 330 pixels wide
    if prediction.shape[1] <= 128 * 2:
        pad = (330 - prediction.shape[1]) // 2
        if len(prediction.shape) == 2:
            prediction = np.pad(prediction, ((0, 0), (pad, pad)), 'constant', constant_values=1)
        else:
            prediction = np.pad(prediction, ((0, 0), (pad, pad), (0, 0)), 'constant', constant_values=1)
    return prediction


## %%
input = gr.Image(label="Chest X-ray", shape=(img_size, img_size), type="pil")
output = gr.Image(label="Segmentation", shape=(img_size, img_size))

## %%
demo = gr.Interface(
    fn=predict,
    inputs=[
        input,
        gr.CheckboxGroup(["Baseline", "Global CL", "Global & Local CL", "LEDM", "LEDMe", "TEDM"],
                         label="Model", value=["Baseline", "LEDM", "LEDMe", "TEDM"]),
        gr.CheckboxGroup([1, 3, 6, 12, 197], label="Training size", value=[1, 3, 6, 12, 197]),
        gr.Checkbox(label="Show masked image (otherwise show binary segmentation)", value=True),
    ],
    outputs=output,
    examples=[
        ['img_examples/NIH_0006.png'],
        ['img_examples/NIH_0076.png'],
        ['img_examples/00016568_041.png'],
        ['img_examples/NIH_0024.png'],
        ['img_examples/00015548_000.png'],
        ['img_examples/NIH_0019.png'],
        ['img_examples/NIH_0094.png'],
        ['img_examples/NIH_0051.png'],
        ['img_examples/NIH_0012.png'],
        ['img_examples/NIH_0014.png'],
        ['img_examples/NIH_0055.png'],
        ['img_examples/NIH_0035.png'],
    ],
    title="Chest X-ray Segmentation with TEDM",
    description=(
        "Medical image segmentation is a challenging task, made more difficult by many datasets' limited size and annotations. "
        "Denoising diffusion probabilistic models (DDPM) have recently shown promise in modelling the distribution of natural images "
        "and were successfully applied to various medical imaging tasks. This work focuses on semi-supervised image segmentation using "
        "diffusion models, particularly addressing domain generalisation. Firstly, we demonstrate that smaller diffusion steps generate "
        "latent representations that are more robust for downstream tasks than larger steps. Secondly, we use this insight to propose an "
        "improved ensembling scheme that leverages information-dense small steps and the regularising effect of larger steps to generate "
        "predictions. Our model shows significantly better performance in domain-shifted settings while retaining competitive performance "
        "in-domain. Overall, this work highlights the potential of DDPMs for semi-supervised medical image segmentation and provides "
        "insights into optimising their performance under domain shift."
        "\n\n\nWhen choosing 'Show masked image', we post-process the segmentation by keeping up to the two largest connected components "
        "and drawing their outlines."
        "\nNote that each model takes 10-35 seconds to run on CPU. Choosing all models and all training sizes will take some time. "
        "We noticed that Gradio sometimes fails on the first try. If it doesn't work, try again."
    ),
    cache_examples=False,
)

# demo.queue().launch(share=True)
demo.queue().launch()
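
# Hedged sketch of running the pipeline without the web UI, e.g. to save the composite
# figure to disk. It assumes the checkpoints are present under logs/ and passes a no-op
# progress callback in case gr.Progress() is unavailable outside a Gradio request; it is
# left commented out because launch() above blocks.
#
#   grid = predict('img_examples/NIH_0076.png', models=['Baseline', 'TEDM'],
#                  training_sizes=[197], seg_img=True,
#                  progress=lambda *args, **kwargs: None)
#   Image.fromarray((grid * 255).astype(np.uint8)).save('prediction_grid.png')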