#!/usr/bin/env python
# coding=utf-8
# Author: Jun Ma

import os
import numpy as np
import torch
import tifffile as tif
from skimage import io, exposure, segmentation, morphology
import gradio as gr

from models.unicell_modules import MiT_B2_UNet_MultiHead
from utils.postprocess import watershed_post
from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference

join = os.path.join

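def normalize_channel(img, lower=0.1, upper=99.9):
    """Stretch one image channel to uint8 using percentiles of its non-zero pixels.

    Intensities at or below the `lower` percentile map to 0, those at or above
    the `upper` percentile map to 255, and values in between are scaled
    linearly. A channel whose percentile range is (nearly) flat is returned
    unchanged.
    """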
    non_zero_vals = img[np.nonzero(img)]
    percentiles = np.percentile(non_zero_vals, [lower, upper])
    if percentiles[1] - percentiles[0] > 0.001:
        img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
    else:
        img_norm = img
    return img_norm

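def preprocess(img_data):
    """Convert an image to 3 channels and normalize each channel to uint8.

    Grayscale images are replicated to 3 channels, images with more than 3
    channels keep only the first 3, and missing channels are left as zeros.
    Each non-empty channel is stretched between its 1st and 99th percentiles.
    Note: `predict` below does its own normalization (0.1/99.9 percentiles)
    instead of calling this function.
    """
    if img_data.ndim == 2:
        img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
    elif img_data.ndim == 3 and img_data.shape[-1] > 3:
        img_data = img_data[:, :, :3]
    pre_img_data = np.zeros((img_data.shape[0], img_data.shape[1], 3), dtype=np.uint8)
    for i in range(min(3, img_data.shape[-1])):
        img_channel_i = img_data[:, :, i]
        # skip all-zero channels; they stay 0 in the output
        if np.count_nonzero(img_channel_i) > 0:
            pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
    return pre_img_data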


def inference(pre_img_data):
    """Segment a normalized H x W x 3 image.

    Returns the instance label mask (uint16) and the input image with the
    predicted cell boundaries drawn in green. The overlay is drawn in place
    on `pre_img_data`.
    """
    # scale to [0, 1]; the max(..., 1) guards against an all-zero image
    test_npy = pre_img_data / max(np.max(pre_img_data), 1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # the model is rebuilt and the checkpoint reloaded on every call
    model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256).to(device)
    checkpoint = torch.load('./model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    with torch.no_grad():
        # H x W x C numpy array -> 1 x C x H x W float32 tensor
        test_tensor = torch.from_numpy(np.expand_dims(test_npy, 0)).permute(0, 3, 1, 2).float().to(device)

        # run the two-head model over 256 x 256 sliding windows
        val_pred, val_pred_dist = multi_task_sliding_window_inference(inputs=test_tensor, roi_size=(256, 256), sw_batch_size=8, predictor=model)

        # watershed postprocessing on the regression output and channel 1 of the segmentation output
        val_seg_inst = watershed_post(val_pred_dist.squeeze(1).cpu().numpy(), val_pred.squeeze(1).cpu().numpy()[:, 1])
        test_pred_mask = val_seg_inst.squeeze().astype(np.uint16)

        # overlay: draw 1-pixel-dilated inner instance boundaries in green
        boundary = segmentation.find_boundaries(test_pred_mask, connectivity=1, mode='inner')
        boundary = morphology.binary_dilation(boundary, morphology.disk(1))
        pre_img_data[boundary] = (0, 255, 0)

    return test_pred_mask, pre_img_data


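def predict(img):
    """Gradio callback: read the uploaded image, segment it, and return the
    overlay plus the path of the saved label mask."""
    # gr.File may pass either a file path string or a temp-file object with a
    # `.name` attribute, depending on the Gradio version
    img_name = img if isinstance(img, str) else img.name
    print('Input file:', img_name)
    if img_name.lower().endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)
    if img_data.ndim == 2:
        # grayscale: normalize once and replicate to 3 channels
        pre_img_data = normalize_channel(img_data, lower=0.1, upper=99.9)
        pre_img_data = np.repeat(np.expand_dims(pre_img_data, -1), repeats=3, axis=-1)
    else:
        # multi-channel: keep at most the first 3 channels; missing channels stay 0
        pre_img_data = np.zeros((img_data.shape[0], img_data.shape[1], 3), dtype=np.uint8)
        for i in range(min(3, img_data.shape[-1])):
            img_channel_i = img_data[:, :, i]
            if np.count_nonzero(img_channel_i) > 0:
                pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=0.1, upper=99.9)

    seg_labels, seg_overlay = inference(pre_img_data)

    # save the uint16 instance labels as a zlib-compressed TIFF for download
    seg_path = join(os.getcwd(), 'segmentation.tiff')
    tif.imwrite(seg_path, seg_labels, compression='zlib')

    return seg_overlay, seg_path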

unicell_api = gr.Interface(
    predict,
    inputs=gr.File(label="Input image (png, bmp, jpg, tif, tiff)"),
    outputs=[gr.Image(label="Segmentation overlay"), gr.File(label="Download segmentation")],
    title="UniCell Online Demo",
    examples=['demo.png', 'demo.tif'],
)

unicell_api.launch()