import numpy as np
import gradio as gr
from PIL import Image
import torch
from torch import nn
from einops.layers.torch import Rearrange
from torchvision import transforms
from models.unet_model import Unet
from models.datasetDM_model import DatasetDM
from skimage import measure, segmentation
import cv2
from tqdm import tqdm
from einops import repeat

img_size = 128
font = cv2.FONT_HERSHEY_SIMPLEX

## %%
def load_img(img_file):
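    """Convert an input image (numpy array, file path, or PIL image) into a (1, C, H, W) float tensor in [0, 1] at img_size resolution."""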
    # accept numpy arrays, file paths, or PIL images
    if isinstance(img_file, np.ndarray):
        img = torch.Tensor(img_file).float()
        # make sure img is between 0 and 1
        if img.max() > 1:
            img /= 255
        # resize
        img = transforms.Resize(img_size)(img)
    elif isinstance(img_file, str):
        img = Image.open(img_file).convert('L').resize((img_size, img_size))
        img = transforms.ToTensor()(img).float()
    elif isinstance(img_file, Image.Image):
        img = img_file.convert('L').resize((img_size, img_size))
        img = transforms.ToTensor()(img).float()
    else:
        raise TypeError("Input must be a numpy array, PIL image, or filepath")
    # add batch (and channel) dimensions
    if len(img.shape) == 2:
        img = img[None, None]
    elif len(img.shape) == 3:
        img = img[None]
    else:
        raise ValueError("Input must be a 2D or 3D array")
    return img

def predict_baseline(img, checkpoint_path):
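    """Load the supervised U-Net baseline from checkpoint_path (on CPU) and return its thresholded (0/1) segmentation as a numpy array."""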
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = checkpoint["config"]
    baseline = Unet(**vars(config))
    baseline.load_state_dict(checkpoint["model_state_dict"])
    baseline.eval()
    return (torch.sigmoid(baseline(img)) > .5).float().squeeze().numpy()

def predict_LEDM(img, checkpoint_path):
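    """Load an LEDM/LEDMe DatasetDM model from checkpoint_path (on CPU) and return its thresholded (0/1) segmentation as a numpy array."""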
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = checkpoint["config"]
    config.verbose = False
    LEDM = DatasetDM(config)
    LEDM.load_state_dict(checkpoint["model_state_dict"])
    LEDM.eval()
    return (torch.sigmoid(LEDM(img)) > .5).float().squeeze().numpy()

def predict_TEDM(img, checkpoint_path):
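    """Load a DatasetDM model with the TEDM per-step classifier head, average the per-step predictions, and return the thresholded (0/1) mask."""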
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    config = checkpoint["config"]
    config.verbose = False
    TEDM = DatasetDM(config)
    # replace the classifier head so that each diffusion step is predicted separately
    TEDM.classifier = nn.Sequential(
        Rearrange('b (step act) h w -> (b step) act h w', step=len(TEDM.steps)),
        nn.Conv2d(960, 128, 1),
        nn.ReLU(),
        nn.BatchNorm2d(128),
        nn.Conv2d(128, 32, 1),
        nn.ReLU(),
        nn.BatchNorm2d(32),
        nn.Conv2d(32, 1, config.out_channels)
    )
    TEDM.load_state_dict(checkpoint["model_state_dict"])
    TEDM.eval()
    return (torch.sigmoid(TEDM(img)).mean(0) > .5).float().squeeze().numpy()

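# map each model name shown in the UI to its prediction function and to its checkpoint folder under logs/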
predictors = {'Baseline': predict_baseline,
              'Global CL': predict_baseline,
              'Global & Local CL': predict_baseline,
              'LEDM': predict_LEDM,
              'LEDMe': predict_LEDM,
              'TEDM': predict_TEDM}
model_folders = {
    'Baseline': 'baseline',
    'Global CL': 'global_finetune',
    'Global & Local CL': 'glob_loc_finetune',
    'LEDM': 'LEDM',
    'LEDMe': 'LEDMe',
    'TEDM': 'TEDM'
}

def postprocess(pred, img):
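    """Keep the two largest non-background connected components of pred and mark their outlines on img, returned as an RGB array."""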
    all_labels = measure.label(pred, background=0)
    _, cn = np.unique(all_labels, return_counts=True)
    # find the two largest connected components that are not the background
    if len(cn) >= 3:
        lungs = np.argsort(cn[1:])[-2:] + 1
        all_labels[(all_labels != lungs[0]) & (all_labels != lungs[1])] = 0
        all_labels[(all_labels == lungs[0]) | (all_labels == lungs[1])] = 1
    # overlay the component boundaries on the input image
    if len(cn) > 1:
        img = segmentation.mark_boundaries(img, all_labels, color=(1, 0, 0), mode='outer', background_label=0)
    else:
        img = repeat(img, 'h w -> h w c', c=3)
    return img

def predict(img_file, models: list, training_sizes: list, seg_img=False, progress=gr.Progress()):
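    """Run every selected model at every selected training size and tile the results into one image: one row per model, one column per training size."""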
    max_progress = len(models) * len(training_sizes)
    n_progress = 0
    progress((n_progress, max_progress), desc="Starting")
    img = load_img(img_file)
    print(img.shape)
    preds = []
    # sort models so that rows appear as Baseline - Global CL - Global & Local CL - LEDM - LEDMe - TEDM
    models = sorted(models, key=lambda x: 0 if x == 'Baseline' else 1 if x == 'Global CL' else 2 if x == 'Global & Local CL' else 3 if x == 'LEDM' else 4 if x == 'LEDMe' else 5)
    for model in models:
        print(model)
        model_preds = []
        for training_size in sorted(training_sizes):
            #if n_progress < max_progress:
            progress((n_progress, max_progress), desc=f"Predicting {model} {training_size}")
            n_progress += 1
            print(training_size)
            out = predictors[model](img, f"logs/{model_folders[model]}/{training_size}/best_model.pt")
            writing_colour = (.5, .5, .5)
            if seg_img:
                out = postprocess(out, img.squeeze().numpy())
                writing_colour = (1, 1, 1)
            out = cv2.putText(np.array(out), f"{model} {training_size}", (5, 125), font, .5, writing_colour, 1, cv2.LINE_AA)
            #ImageDraw.Draw(out).text((0,128), f"{model} {training_size}", fill=(255,0,0))
            model_preds.append(np.asarray(out))
        preds.append(np.concatenate(model_preds, axis=1))
    prediction = np.concatenate(preds, axis=0)
    # pad narrow results to roughly 330 pixels wide with white borders
    if prediction.shape[1] <= 128 * 2:
        pad = (330 - prediction.shape[1]) // 2
        if len(prediction.shape) == 2:
            prediction = np.pad(prediction, ((0, 0), (pad, pad)), 'constant', constant_values=1)
        else:
            prediction = np.pad(prediction, ((0, 0), (pad, pad), (0, 0)), 'constant', constant_values=1)
    return prediction

## %%
input = gr.Image(label="Chest X-ray", shape=(img_size, img_size), type="pil")
output = gr.Image(label="Segmentation", shape=(img_size, img_size))

## %%
demo = gr.Interface(
    fn=predict,
    inputs=[input,
            gr.CheckboxGroup(["Baseline", "Global CL", "Global & Local CL", "LEDM", "LEDMe", "TEDM"], label="Model", value=["Baseline", "LEDM", "LEDMe", "TEDM"]),
            gr.CheckboxGroup([1, 3, 6, 12, 197], label="Training size", value=[1, 3, 6, 12, 197]),
            gr.Checkbox(label="Show masked image (otherwise show binary segmentation)", value=True)],
    outputs=output,
    examples=[
        ['img_examples/NIH_0006.png'],
        ['img_examples/NIH_0076.png'],
        ['img_examples/00016568_041.png'],
        ['img_examples/NIH_0024.png'],
        ['img_examples/00015548_000.png'],
        ['img_examples/NIH_0019.png'],
        ['img_examples/NIH_0094.png'],
        ['img_examples/NIH_0051.png'],
        ['img_examples/NIH_0012.png'],
        ['img_examples/NIH_0014.png'],
        ['img_examples/NIH_0055.png'],
        ['img_examples/NIH_0035.png'],
    ],
title="Chest X-ray Segmentation with TEDM.", | |
description="""<img src="file/img_examples/TEDM_n_LEDM.drawio.png" | |
alt="Markdown Monster icon" | |
style="margin-right: 10px;" />"""+ | |
"\nMedical image segmentation is a challenging task, made more difficult by many datasets' limited size and annotations. Denoising diffusion probabilistic models (DDPM) have recently shown promise in modelling " + | |
"the distribution of natural images and were successfully applied to various medical imaging tasks. This work focuses on semi-supervised image segmentation using diffusion models, particularly addressing domain " + | |
"generalisation. Firstly, we demonstrate that smaller diffusion steps generate latent representations that are more robust for downstream tasks than larger steps. Secondly, we use this insight to propose an improved " + | |
"esembling scheme that leverages information-dense small steps and the regularising effect of larger steps to generate predictions. Our model shows significantly better performance in domain-shifted settings while " + | |
"retaining competitive performance in-domain. Overall, this work highlights the potential of DDPMs for semi-supervised medical image segmentation and provides insights into optimising their performance under domain shift."+ | |
"\n\n\n When choosing 'Show masked image', we post-process the segmentation by choosing up to two largest connected components and drawing their outline. "+ | |
"\nNote that each model takes 10-35 seconds to run on CPU. Choosing all models and all training sizes will take some time. "+ | |
"We noticed that gradio sometimes fails on the first try. If it doesn't work, try again.", | |
    cache_examples=False,
)
#demo.queue().launch(share=True) | |
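
# A minimal sketch for running the demo locally (not part of the original Space setup;
# uncomment if an explicit launch is needed in your environment):
# if __name__ == "__main__":
#     demo.queue().launch()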