#!/usr/bin/env python
# coding=utf-8
# Author: Jun Ma
"""Gradio demo that segments an uploaded image and returns a boundary overlay.

The uploaded file is read, normalized per channel to uint8, passed through a
``MiT_B2_UNet_MultiHead`` model with sliding-window inference, post-processed
with ``watershed_post``, and returned both as an overlay image and as a
downloadable label TIFF.
"""
import os
join = os.path.join
import argparse
import functools
import numpy as np
import torch
import torch.nn as nn
import tifffile as tif
import monai
from tqdm import tqdm
from utils.postprocess import mask_overlay
from monai.transforms import Activations, AddChanneld, AsChannelFirstd, AsDiscrete, Compose, EnsureTyped, EnsureType
from models.unicell_modules import MiT_B2_UNet_MultiHead, MiT_B3_UNet_MultiHead
import matplotlib.pyplot as plt
from skimage import io, exposure, segmentation, morphology
from utils.postprocess import watershed_post
from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference
import gradio as gr


def normalize_channel(img, lower=0.1, upper=99.9):
    """Rescale one channel to the uint8 range using percentiles of its non-zero pixels.

    Args:
        img: Array holding a single image channel.
        lower: Lower percentile (0-100) used as the input-range minimum.
        upper: Upper percentile (0-100) used as the input-range maximum.

    Returns:
        The rescaled channel, or ``img`` itself (unchanged, original dtype)
        when it has no non-zero pixels or when the two percentiles differ by
        no more than 0.001.
    """
    non_zero_vals = img[np.nonzero(img)]
    if non_zero_vals.size == 0:
        # Nothing to compute percentiles from; taking percentiles of an empty
        # array is undefined, so hand the channel back untouched.
        return img
    low, high = np.percentile(non_zero_vals, [lower, upper])
    if high - low > 0.001:
        return exposure.rescale_intensity(img, in_range=(low, high), out_range='uint8')
    # Practically constant channel: rescaling would divide by ~0.
    return img


def _normalize_channels(img_data, out, lower, upper):
    """Write normalized versions of up to the first three channels of ``img_data`` into ``out``."""
    # min() keeps images with fewer than three channels from raising
    # IndexError; the channels of ``out`` that are not written stay as they are.
    for i in range(min(3, img_data.shape[2])):
        channel = img_data[:, :, i]
        # All-zero channels are skipped and keep whatever ``out`` holds.
        if np.any(channel):
            out[:, :, i] = normalize_channel(channel, lower=lower, upper=upper)
    return out


def preprocess(img_data):
    """Convert an image to a per-channel normalized uint8 array.

    A 2-D image is repeated into three identical channels; a 3-D image with
    more than three channels is cropped to its first three. Each non-empty
    channel is then normalized with the 1st/99th percentiles.

    Args:
        img_data: 2-D array, or 3-D array with the channel on the last axis.

    Returns:
        A uint8 array with the same shape as the (expanded or cropped) input.
    """
    if img_data.ndim == 2:
        img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
    elif img_data.ndim == 3 and img_data.shape[-1] > 3:
        img_data = img_data[:, :, :3]
    pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
    return _normalize_channels(img_data, pre_img_data, lower=1, upper=99)


@functools.lru_cache(maxsize=None)
def _load_model(device_type):
    """Build the network, load ``./model.pth`` onto ``device_type`` and cache the result."""
    device = torch.device(device_type)
    model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256).to(device)
    checkpoint = torch.load('./model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def inference(pre_img_data):
    """Segment a pre-processed image and draw the instance boundaries on a copy of it.

    Args:
        pre_img_data: Array indexed as (row, column, channel) with three
            channels (the model is built with ``in_channels=3``).

    Returns:
        A tuple ``(test_pred_mask, overlay)``: the uint16 label mask produced
        by ``watershed_post``, and a copy of ``pre_img_data`` whose boundary
        pixels are set to (0, 255, 0).
    """
    # Scale by the global maximum; an all-zero image stays all-zero instead of
    # turning into NaN/inf through a division by zero.
    peak = np.max(pre_img_data)
    if peak != 0:
        test_npy = pre_img_data / peak
    else:
        test_npy = np.zeros(pre_img_data.shape, dtype=np.float64)

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_type)
    # The checkpoint is read once per device and reused by later requests.
    model = _load_model(device_type)

    with torch.no_grad():
        # Add a batch axis and move channels first: (1, C, H, W).
        test_tensor = torch.from_numpy(np.expand_dims(test_npy, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        # NOTE(review): the names suggest a class prediction and a distance
        # map; confirm against multi_task_sliding_window_inference.
        val_pred, val_pred_dist = multi_task_sliding_window_inference(
            inputs=test_tensor, roi_size=(256, 256), sw_batch_size=8, predictor=model
        )

    # watershed postprocessing
    val_seg_inst = watershed_post(val_pred_dist.squeeze(1).cpu().numpy(), val_pred.squeeze(1).cpu().numpy()[:, 1])
    test_pred_mask = val_seg_inst.squeeze().astype(np.uint16)

    # overlay: paint on a copy so the caller's array is left untouched
    boundary = segmentation.find_boundaries(test_pred_mask, connectivity=1, mode='inner')
    boundary = morphology.binary_dilation(boundary, morphology.disk(1))
    overlay = pre_img_data.copy()
    overlay[boundary] = (0, 255, 0)
    return test_pred_mask, overlay


def predict(img):
    """Gradio callback: read the uploaded file, segment it and save the label mask.

    Args:
        img: Uploaded-file object exposing the file path as ``img.name``.

    Returns:
        A tuple ``(seg_overlay, path)``: the overlay image and the path of the
        ``segmentation.tiff`` written to the current working directory.
    """
    print('##########', img.name)
    img_name = img.name
    # Case-insensitive so that ".TIF" / ".TIFF" uploads also use tifffile.
    if img_name.lower().endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)

    if img_data.ndim == 2:
        pre_img_data = normalize_channel(img_data, lower=0.1, upper=99.9)
        pre_img_data = np.repeat(np.expand_dims(pre_img_data, -1), repeats=3, axis=-1)
    else:
        # Always hand three channels to the model; channels the input lacks
        # (or that are entirely zero) stay zero.
        pre_img_data = np.zeros((img_data.shape[0], img_data.shape[1], 3), dtype=np.uint8)
        pre_img_data = _normalize_channels(img_data, pre_img_data, lower=0.1, upper=99.9)

    seg_labels, seg_overlay = inference(pre_img_data)

    out_path = join(os.getcwd(), 'segmentation.tiff')
    tif.imwrite(out_path, seg_labels, compression='zlib')
    return seg_overlay, out_path


unicell_api = gr.Interface(
    predict,
    inputs=gr.File(label="Input image (png, bmp, jpg, tif, tiff)"),
    outputs=[gr.Image(label="Segmentation overlay"), gr.File(label="Download segmentation")],
    title="UniCell Online Demo",
    examples=['demo.png', 'demo.tif']
)
unicell_api.launch()