# UniCell / app.py
# junma's picture
# Update app.py
# 2088d61
#!/usr/bin/env python
# coding=utf-8
# Author: Jun Ma
import os
join = os.path.join
import argparse
import numpy as np
import torch
import torch.nn as nn
import tifffile as tif
import monai
from tqdm import tqdm
from utils.postprocess import mask_overlay
from monai.transforms import Activations, AddChanneld, AsChannelFirstd, AsDiscrete, Compose, EnsureTyped, EnsureType
from models.unicell_modules import MiT_B2_UNet_MultiHead, MiT_B3_UNet_MultiHead
import matplotlib.pyplot as plt
from skimage import io, exposure, segmentation, morphology
from utils.postprocess import watershed_post
from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference
import gradio as gr
def normalize_channel(img, lower=0.1, upper=99.9):
    """Contrast-stretch a single image channel to the 0..255 range.

    The ``lower``/``upper`` percentiles are computed over the non-zero pixels
    only, so zero-valued (e.g. background) pixels do not affect the bounds.

    Args:
        img: array holding one channel (callers pass 2-D slices).
        lower: percentile (0-100) used as the low end of the input range.
        upper: percentile (0-100) used as the high end of the input range.

    Returns:
        The rescaled channel (``out_range='uint8'``). Returns ``img`` itself,
        unchanged, when it has no non-zero pixels or when the percentile window
        is narrower than 0.001, because there is nothing meaningful to stretch.
    """
    non_zero_vals = img[np.nonzero(img)]
    if non_zero_vals.size == 0:
        # np.percentile cannot work on an empty selection; nothing to stretch.
        return img
    low, high = np.percentile(non_zero_vals, [lower, upper])
    if high - low > 0.001:
        return exposure.rescale_intensity(img, in_range=(low, high), out_range='uint8')
    return img
def preprocess(img_data, lower=1, upper=99):
    """Convert an image to a normalized 3-channel uint8 array.

    Args:
        img_data: 2-D grayscale array ``(H, W)`` or channels-last array
            ``(H, W, C)``.
        lower: low percentile forwarded to ``normalize_channel``. The default
            keeps the previously hard-coded value.
        upper: high percentile forwarded to ``normalize_channel``. The default
            keeps the previously hard-coded value.

    Returns:
        uint8 array of shape ``(H, W, 3)``:
        - Grayscale or single-channel input is replicated to all three channels.
        - Input with fewer than three channels is padded with all-zero channels.
        - Input with more than three channels keeps only the first three.
        - Channels with no non-zero pixel stay zero.

    Raises:
        ValueError: if ``img_data`` is neither 2-D nor 3-D.
    """
    img_data = np.asarray(img_data)
    if img_data.ndim == 2:
        img_data = img_data[:, :, np.newaxis]
    elif img_data.ndim != 3:
        raise ValueError(f"expected an (H, W) or (H, W, C) image, got shape {img_data.shape}")

    n_channels = img_data.shape[-1]
    if n_channels == 1:
        img_data = np.repeat(img_data, 3, axis=-1)
    elif n_channels < 3:
        # Pad missing channels with zeros instead of indexing past the end below.
        img_data = np.pad(img_data, ((0, 0), (0, 0), (0, 3 - n_channels)))
    else:
        img_data = img_data[:, :, :3]

    pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
    for i in range(3):
        img_channel_i = img_data[:, :, i]
        # normalize_channel needs at least one non-zero pixel to compute percentiles.
        if np.any(img_channel_i):
            pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=lower, upper=upper)
    return pre_img_data
# Per-device cache of the loaded network, so each request does not rebuild the
# model and re-read ./model.pth from disk.
_MODEL_CACHE = {}


def _load_model(device):
    """Return the UniCell network with ``./model.pth`` weights on *device*, loading it once per device."""
    key = str(device)
    if key not in _MODEL_CACHE:
        model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256).to(device)
        checkpoint = torch.load('./model.pth', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        _MODEL_CACHE[key] = model
    return _MODEL_CACHE[key]


def inference(pre_img_data):
    """Segment an image and draw the instance boundaries on a copy of it.

    Args:
        pre_img_data: preprocessed image, presumably ``(H, W, 3)`` uint8 as
            built by ``predict`` (TODO confirm for other callers). It is not
            modified.

    Returns:
        ``(test_pred_mask, overlay)``:
        - ``test_pred_mask`` is the uint16 instance-label mask produced by
          watershed post-processing.
        - ``overlay`` is a copy of ``pre_img_data`` with dilated instance
          boundaries painted green (0, 255, 0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(device)

    # Scale to [0, 1]; an all-zero image would otherwise become NaN (0 / 0).
    max_val = np.max(pre_img_data)
    if max_val > 0:
        test_npy = pre_img_data / max_val
    else:
        test_npy = np.zeros(pre_img_data.shape)

    with torch.no_grad():
        # (H, W, C) -> (1, C, H, W) float tensor on the inference device.
        test_tensor = torch.from_numpy(np.expand_dims(test_npy, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        val_pred, val_pred_dist = multi_task_sliding_window_inference(
            inputs=test_tensor, roi_size=(256, 256), sw_batch_size=8, predictor=model
        )
        # Watershed post-processing on the regression map and class-1 probability map.
        val_seg_inst = watershed_post(val_pred_dist.squeeze(1).cpu().numpy(), val_pred.squeeze(1).cpu().numpy()[:, 1])
    test_pred_mask = val_seg_inst.squeeze().astype(np.uint16)

    # Overlay: paint dilated inner boundaries green on a copy, leaving the caller's array intact.
    boundary = segmentation.find_boundaries(test_pred_mask, connectivity=1, mode='inner')
    boundary = morphology.binary_dilation(boundary, morphology.disk(1))
    overlay = np.array(pre_img_data, copy=True)
    overlay[boundary, 0] = 0
    overlay[boundary, 1] = 255
    overlay[boundary, 2] = 0
    return test_pred_mask, overlay
def predict(img):
    """Gradio callback: segment an uploaded image file.

    Args:
        img: the value delivered by the ``gr.File`` input. This is either an
            object whose ``.name`` attribute is the uploaded file's path, or
            the path string itself.

    Returns:
        ``(overlay, path)``:
        - ``overlay`` is the input image with segmentation boundaries drawn on it.
        - ``path`` points to ``segmentation.tiff`` in the current working
          directory, which holds the zlib-compressed uint16 label mask.

    Raises:
        ValueError: if no file was provided or the decoded image is neither
            2-D nor 3-D.
    """
    if img is None:
        raise ValueError("no input image was provided")
    # Some Gradio versions pass a tempfile wrapper (path in ``.name``), others the path itself.
    img_name = getattr(img, 'name', img)
    print('##########', img_name)
    # Case-insensitive so .TIF / .TIFF files also go through tifffile.
    if img_name.lower().endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)

    if img_data.ndim not in (2, 3):
        raise ValueError(f"expected an (H, W) or (H, W, C) image, got shape {img_data.shape}")
    height, width = img_data.shape[:2]
    pre_img_data = np.zeros((height, width, 3), dtype=np.uint8)
    if img_data.ndim == 2 or img_data.shape[-1] == 1:
        # Grayscale: normalize once and replicate into all three channels.
        gray = img_data.reshape(height, width)
        # normalize_channel needs at least one non-zero pixel.
        if np.any(gray):
            pre_img_data[:] = normalize_channel(gray, lower=0.1, upper=99.9)[:, :, np.newaxis]
    else:
        # Keep at most the first three channels; channels that are missing or all-zero stay zero.
        for i in range(min(3, img_data.shape[-1])):
            img_channel_i = img_data[:, :, i]
            if np.any(img_channel_i):
                pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=0.1, upper=99.9)

    seg_labels, seg_overlay = inference(pre_img_data)
    out_path = join(os.getcwd(), 'segmentation.tiff')
    tif.imwrite(out_path, seg_labels, compression='zlib')
    return seg_overlay, out_path
# Gradio front-end: a single uploaded image file goes in; the boundary overlay
# and a downloadable instance-label TIFF come out.
_input_component = gr.File(label="Input image (png, bmp, jpg, tif, tiff)")
_output_components = [
    gr.Image(label="Segmentation overlay"),
    gr.File(label="Download segmentation"),
]
unicell_api = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_components,
    title="UniCell Online Demo",
    examples=['demo.png', 'demo.tif'],
)
# Starts the web server as soon as this module runs.
unicell_api.launch()