Lewislou commited on
Commit
0ca2a11
·
1 Parent(s): ad5cbcb

Upload 24 files

Browse files
README.md CHANGED
@@ -1,3 +1,72 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Solution of Team Sribd-med for NeurIPS-CellSeg Challenge
2
+ This repository provides the solution of team Sribd-med for [NeurIPS-CellSeg](https://neurips22-cellseg.grand-challenge.org/) Challenge. The details of our method are described in our paper [Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images]. Some parts of the codes are from the baseline codes of the [NeurIPS-CellSeg-Baseline](https://github.com/JunMa11/NeurIPS-CellSeg) repository,
3
+
4
+ You can reproduce our method as follows step by step:
5
+
6
+ ## Environments and Requirements:
7
+ Install requirements by
8
+
9
+ ```shell
10
+ python -m pip install -r requirements.txt
11
+ ```
12
+
13
+ ## Dataset
14
+ The competition training and tuning data can be downloaded from https://neurips22-cellseg.grand-challenge.org/dataset/
15
+ Besides, you can download three publiced data from the following link:
16
+ Cellpose: https://www.cellpose.org/dataset 
17
+ Omnipose: http://www.cellpose.org/dataset_omnipose
18
+ Sartorius: https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview 
19
+
20
+ ## Automatic cell classification
21
+ You can classify the cells into four classes in this step.
22
+ Put all the images (competition + Cellpose + Omnipose + Sartorius) in one folder (data/allimages).
23
+ Run classification code:
24
+
25
+ ```shell
26
+ python classification/unsup_classification.py
27
+ ```
28
+ The results can be stored in data/classification_results/
29
+
30
+ ## CNN-base classification model training
31
+ Using the classified images in data/classification_results/. A resnet18 is trained:
32
+ ```shell
33
+ python classification/train_classification.py
34
+ ```
35
+ ## Segmentation Training
36
+ Pre-training convnext-stardist using all the images (data/allimages).
37
+ ```shell
38
+ python train_convnext_stardist.py
39
+ ```
40
+ For class 0,2,3 finetune on the classified data (Take class1 as a example):
41
+ ```shell
42
+ python finetune_convnext_stardist.py model_dir=(The pretrained convnext-stardist model) data_dir='data/classification_results/class1'
43
+ ```
44
+ For class 1 train the convnext-hover from scratch using classified class 3 data.
45
+ ```shell
46
+ python train_convnext_hover.py data_dir='data/classification_results/class3'
47
+ ```
48
+
49
+ Finally, four segmentation models will be trained.
50
+
51
+ ## Trained models
52
+ The models can be downloaded from this link:
53
+ https://drive.google.com/drive/folders/1MkEOpgmdkg5Yqw6Ng5PoOhtmo9xPPwIj?usp=sharing
54
+
55
+ ## Inference
56
+ The inference process includes classification and segmentation.
57
+ ```shell
58
+ python predict.py -i input_path -o output_path --model_path './models'
59
+ ```
60
+
61
+ ## Evaluation
62
+ Calculate the F-score for evaluation:
63
+ ```shell
64
+ python compute_metric.py --gt_path path_to_labels --seg_path output_path
65
+ ```
66
+
67
+ ## Results
68
+ The tuning set F1 score of our method is 0.8795. The rank running time of our method on all the 101 cases in the tuning set is zero in our local
69
+ workstation.
70
+ ## Acknowledgement
71
+ We thank for the contributors of public datasets.
72
+
cellseg_time_eval.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The code was adapted from the MICCAI FLARE Challenge
3
+ https://flare22.grand-challenge.org/
4
+
5
+ The testing images will be evaluated one by one.
6
+ To compensate for the Docker container startup time, we give a time tolerance for the running time.
7
+ https://neurips22-cellseg.grand-challenge.org/metrics/
8
+ """
9
+
10
+ import os
11
+ join = os.path.join
12
+ import sys
13
+ import shutil
14
+ import time
15
+ import torch
16
+ import argparse
17
+ from collections import OrderedDict
18
+ from skimage import io
19
+ import tifffile as tif
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+ parser = argparse.ArgumentParser('Segmentation efficiency eavluation for docker containers', add_help=False)
24
+ parser.add_argument('-i', '--test_img_path', default='./val-imgs-30/', type=str, help='testing data path')
25
+ parser.add_argument('-o','--save_path', default='./val_team_seg', type=str, help='segmentation output path')
26
+ parser.add_argument('-d','--docker_folder_path', default='./team_docker', type=str, help='team docker path')
27
+ args = parser.parse_args()
28
+
29
+ test_img_path = args.test_img_path
30
+ save_path = args.save_path
31
+ docker_path = args.docker_folder_path
32
+
33
+ input_temp = './inputs/'
34
+ output_temp = './outputs'
35
+ os.makedirs(save_path, exist_ok=True)
36
+
37
+ dockers = sorted(os.listdir(docker_path))
38
+ test_cases = sorted(os.listdir(test_img_path))
39
+
40
+ for docker in dockers:
41
+ try:
42
+ # create temp folers for inference one-by-one
43
+ if os.path.exists(input_temp):
44
+ shutil.rmtree(input_temp)
45
+ if os.path.exists(output_temp):
46
+ shutil.rmtree(output_temp)
47
+ os.makedirs(input_temp)
48
+ os.makedirs(output_temp)
49
+ # load docker and create a new folder to save segmentation results
50
+ teamname = docker.split('.')[0].lower()
51
+ print('teamname docker: ', docker)
52
+ # os.system('docker image load < {}'.format(join(docker_path, docker)))
53
+ team_outpath = join(save_path, teamname)
54
+ if os.path.exists(team_outpath):
55
+ shutil.rmtree(team_outpath)
56
+ os.mkdir(team_outpath)
57
+ metric = OrderedDict()
58
+ metric['Img Name'] = []
59
+ metric['Real Running Time'] = []
60
+ metric['Rank Running Time'] = []
61
+ # To obtain the running time for each case, we inference the testing case one-by-one
62
+ for case in test_cases:
63
+ shutil.copy(join(test_img_path, case), input_temp)
64
+ if case.endswith('.tif') or case.endswith('.tiff'):
65
+ img = tif.imread(join(input_temp, case))
66
+ else:
67
+ img = io.imread(join(input_temp, case))
68
+ pix_num = img.shape[0] * img.shape[1]
69
+ cmd = 'docker container run --gpus="device=0" -m 28g --name {} --rm -v $PWD/inputs/:/workspace/inputs/ -v $PWD/outputs/:/workspace/outputs/ {}:latest /bin/bash -c "sh predict.sh" '.format(teamname, teamname)
70
+ print(teamname, ' docker command:', cmd, '\n', 'testing image name:', case)
71
+ start_time = time.time()
72
+ os.system(cmd)
73
+ real_running_time = time.time() - start_time
74
+ print(f"{case} finished! Inference time: {real_running_time}")
75
+ # save metrics
76
+ metric['Img Name'].append(case)
77
+ metric['Real Running Time'].append(real_running_time)
78
+ if pix_num <= 1000000:
79
+ rank_running_time = np.max([0, real_running_time-10])
80
+ else:
81
+ rank_running_time = np.max([0, real_running_time-10*(pix_num/1000000)])
82
+ metric['Rank Running Time'].append(rank_running_time)
83
+ os.remove(join(input_temp, case))
84
+ seg_name = case.split('.')[0] + '_label.tiff'
85
+ try:
86
+ os.rename(join(output_temp, seg_name), join(team_outpath, seg_name))
87
+ except:
88
+ print(f"{join(output_temp, seg_name)}, {join(team_outpath, seg_name)}")
89
+ print("Wrong segmentation name!!! It should be image_name.split(\'.\')[0] + \'_label.tiff\' ")
90
+ metric_df = pd.DataFrame(metric)
91
+ metric_df.to_csv(join(team_outpath, teamname + '_running_time.csv'), index=False)
92
+ torch.cuda.empty_cache()
93
+ # os.system("docker rmi {}:latest".format(teamname))
94
+ shutil.rmtree(input_temp)
95
+ shutil.rmtree(output_temp)
96
+ except Exception as e:
97
+ print(e)
classification/train_classification.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, glob, time, random, shutil, copy
2
+ from tqdm import tqdm
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import torchvision
7
+ from torchvision import datasets, models, transforms
8
+ import torch.utils.data as data
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.optim import lr_scheduler
12
+ import torch.nn.functional as F
13
+ from torchsummary import summary
14
+ from matplotlib import pyplot as plt
15
+ from torchvision.models import resnet18, ResNet18_Weights # do not import
16
+ from PIL import Image, ImageFile
17
+ from skimage import io
18
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
19
+
20
+ # Set the train and validation directory paths
21
+ train_directory = 'dataset/train'
22
+ valid_directory = 'dataset/val'
23
+
24
+ # Batch size
25
+ bs = 64
26
+ # Number of epochs
27
+ num_epochs = 20
28
+ # Number of classes
29
+ num_classes = 4
30
+ # Number of workers
31
+ num_cpu = 8
32
+
33
+ # Applying transforms to the data
34
+ image_transforms = {
35
+ 'train': transforms.Compose([
36
+ transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
37
+ transforms.RandomRotation(degrees=15),
38
+ transforms.RandomHorizontalFlip(),
39
+ transforms.CenterCrop(size=224),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.485, 0.456, 0.406],
42
+ [0.229, 0.224, 0.225])
43
+ ]),
44
+ 'valid': transforms.Compose([
45
+ transforms.Resize(size=256),
46
+ transforms.CenterCrop(size=224),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize([0.485, 0.456, 0.406],
49
+ [0.229, 0.224, 0.225])
50
+ ])
51
+ }
52
+
53
+ # Load data from folders
54
+ dataset = {
55
+ 'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
56
+ 'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
57
+ }
58
+
59
+ # Size of train and validation data
60
+ dataset_sizes = {
61
+ 'train':len(dataset['train']),
62
+ 'valid':len(dataset['valid'])
63
+ }
64
+
65
+ # Create iterators for data loading
66
+ dataloaders = {
67
+ 'train':data.DataLoader(dataset['train'], batch_size=bs, shuffle=True,
68
+ num_workers=num_cpu, pin_memory=True, drop_last=False),
69
+ 'valid':data.DataLoader(dataset['valid'], batch_size=bs, shuffle=False,
70
+ num_workers=num_cpu, pin_memory=True, drop_last=False)
71
+ }
72
+
73
+ # Class names or target labels
74
+ class_names = dataset['train'].classes
75
+ print("Classes:", class_names)
76
+
77
+ # Print the train and validation data sizes
78
+ print("Training-set size:",dataset_sizes['train'],
79
+ "\nValidation-set size:", dataset_sizes['valid'])
80
+
81
+ modelname = 'resnet18'
82
+
83
+ # Set default device as gpu, if available
84
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85
+
86
+ weights = ResNet18_Weights.DEFAULT
87
+ model = resnet18(weights=None)
88
+ num_ftrs = model.fc.in_features
89
+ model.fc = nn.Linear(num_ftrs, num_classes)
90
+
91
+
92
+ # Transfer the model to GPU
93
+ model = model.to(device)
94
+
95
+ # Print model summary
96
+ print('Model Summary:-\n')
97
+ for num, (name, param) in enumerate(model.named_parameters()):
98
+ print(num, name, param.requires_grad )
99
+ summary(model, input_size=(3, 224, 224))
100
+
101
+ # Loss function
102
+ criterion = nn.CrossEntropyLoss()
103
+
104
+ # Optimizer
105
+ optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
106
+
107
+ # Learning rate decay
108
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
109
+
110
+ since = time.time()
111
+
112
+ best_model_wts = copy.deepcopy(model.state_dict())
113
+ best_acc = 0.0
114
+
115
+ for epoch in range(1, num_epochs+1):
116
+ print('Epoch {}/{}'.format(epoch, num_epochs))
117
+ print('-' * 10)
118
+
119
+ # Each epoch has a training and validation phase
120
+ for phase in ['train', 'valid']:
121
+ if phase == 'train':
122
+ model.train() # Set model to training mode
123
+ else:
124
+ model.eval() # Set model to evaluate mode
125
+
126
+ running_loss = 0.0
127
+ running_corrects = 0
128
+
129
+ # Iterate over data.
130
+ n = 0
131
+ stream = tqdm(dataloaders[phase])
132
+ for i, (inputs, labels) in enumerate(stream, start=1):
133
+ inputs = inputs.to(device)
134
+ labels = labels.to(device)
135
+
136
+ # zero the parameter gradients
137
+ optimizer.zero_grad()
138
+
139
+ # forward
140
+ # track history if only in train
141
+ with torch.set_grad_enabled(phase == 'train'):
142
+ outputs = model(inputs)
143
+ _, preds = torch.max(outputs, 1)
144
+ loss = criterion(outputs, labels)
145
+
146
+ # backward + optimize only if in training phase
147
+ if phase == 'train':
148
+ loss.backward()
149
+ optimizer.step()
150
+
151
+ # statistics
152
+ n += inputs.shape[0]
153
+ running_loss += loss.item() * inputs.size(0)
154
+ running_corrects += torch.sum(preds == labels.data)
155
+
156
+ stream.set_description(f'Batch {i}/{len(dataloaders[phase])} | Loss: {running_loss/n:.4f}, Acc: {running_corrects/n:.4f}')
157
+
158
+ if phase == 'train':
159
+ scheduler.step()
160
+
161
+ epoch_loss = running_loss / dataset_sizes[phase]
162
+ epoch_acc = running_corrects.double() / dataset_sizes[phase]
163
+
164
+ print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format(
165
+ epoch, phase, epoch_loss, epoch_acc))
166
+
167
+ # deep copy the model
168
+ if phase == 'valid' and epoch_acc >= best_acc:
169
+ best_acc = epoch_acc
170
+ best_model_wts = copy.deepcopy(model.state_dict())
171
+ print('Update best model!')
172
+
173
+ time_elapsed = time.time() - since
174
+ print('Training complete in {:.0f}m {:.0f}s'.format(
175
+ time_elapsed // 60, time_elapsed % 60))
176
+ print('Best val Acc: {:4f}'.format(best_acc))
177
+
178
+ # load best model weights
179
+ model.load_state_dict(best_model_wts)
180
+ torch.save(model, 'logs/resnet18_4class.pth')
181
+ torch.save(model.state_dict(), 'logs/resnet18_4class.tar')
classification/unsup_classification.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import os
8
+ import numpy as np
9
+ import shutil
10
+ import torch
11
+ import torch.nn
12
+ import torchvision.models as models
13
+ from torch.autograd import Variable
14
+ import torch.cuda
15
+ import torchvision.transforms as transforms
16
+ from PIL import Image
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from sklearn.datasets import make_blobs
20
+ from sklearn.cluster import KMeans
21
+ from sklearn.metrics import silhouette_score
22
+ from sklearn.preprocessing import StandardScaler
23
+ from sklearn.metrics import pairwise_distances_argmin_min
24
+ from scipy.spatial.distance import pdist, squareform
25
+ from skimage import io, segmentation, morphology, exposure
26
+ from skimage.color import rgb2hsv
27
+ img_to_tensor = transforms.ToTensor()
28
+ import random
29
+ import tifffile as tif
30
+ path = '/data1/partitionA/CUHKSZ/histopath_2022/grand_competition/Train_Labeled/images/'
31
+ files = os.listdir(path)
32
+ binary_path = '0/'
33
+ gray_path = '1/'
34
+ colored_path = 'colored/'
35
+ os.makedirs(binary_path, exist_ok=True)
36
+ os.makedirs(colored_path, exist_ok=True)
37
+ os.makedirs(gray_path, exist_ok=True)
38
+ for img_name in files:
39
+ img_path = path + str(img_name)
40
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
41
+ img_data = tif.imread(img_path)
42
+ else:
43
+ img_data = io.imread(img_path)
44
+ if len(img_data.shape) == 2 or (len(img_data.shape) == 3 and img_data.shape[-1] == 1):
45
+ shutil.copyfile(path + img_name, binary_path + img_name)
46
+ elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
47
+ shutil.copyfile(path + img_name, colored_path + img_name)
48
+ else:
49
+ hsv_img = rgb2hsv(img_data)
50
+ s = hsv_img[:,:,1]
51
+ v = hsv_img[:,:,2]
52
+ print(img_name,s.mean(),v.mean())
53
+ if s.mean() > 0.1 or (v.mean()<0.1 or v.mean() > 0.6):
54
+ shutil.copyfile(path + img_name, colored_path + img_name)
55
+ else:
56
+ shutil.copyfile(path + img_name, gray_path + img_name)
57
+
58
+
59
+
60
+ # In[3]:
61
+
62
+
63
+ ####Phrase 2 clustering by cell size
64
+ from skimage import measure
65
+ colored_path = 'colored/'
66
+ label_path = 'allimages/tif/'
67
+ big_path = '2/'
68
+ small_path = '3/'
69
+ files = os.listdir(colored_path)
70
+ os.makedirs(big_path, exist_ok=True)
71
+ os.makedirs(small_path, exist_ok=True)
72
+ for img_name in files:
73
+ label = tif.imread(label_path + img_name.split('.')[0]+'.tif')
74
+ props = measure.regionprops(label)
75
+ num_pix = []
76
+ for idx in range(len(props)):
77
+ num_pix.append(props[idx].area)
78
+ max_area = max(num_pix)
79
+ print(max_area)
80
+ if max_area > 30000:
81
+ shutil.copyfile(path + img_name, big_path + img_name)
82
+ else:
83
+ shutil.copyfile(path + img_name, small_path + img_name)
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
compute_metric.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Created on Thu Mar 31 18:10:52 2022
3
+ adapted form https://github.com/stardist/stardist/blob/master/stardist/matching.py
4
+ Thanks the authors of Stardist for sharing the great code
5
+
6
+ """
7
+
8
+ import argparse
9
+ import numpy as np
10
+ from numba import jit
11
+ from scipy.optimize import linear_sum_assignment
12
+ from collections import OrderedDict
13
+ import pandas as pd
14
+ from skimage import segmentation
15
+ import tifffile as tif
16
+ import os
17
+ join = os.path.join
18
+ from tqdm import tqdm
19
+
20
+ def _intersection_over_union(masks_true, masks_pred):
21
+ """ intersection over union of all mask pairs
22
+
23
+ Parameters
24
+ ------------
25
+
26
+ masks_true: ND-array, int
27
+ ground truth masks, where 0=NO masks; 1,2... are mask labels
28
+ masks_pred: ND-array, int
29
+ predicted masks, where 0=NO masks; 1,2... are mask labels
30
+ """
31
+ overlap = _label_overlap(masks_true, masks_pred)
32
+ n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
33
+ n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
34
+ iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
35
+ iou[np.isnan(iou)] = 0.0
36
+ return iou
37
+
38
+ @jit(nopython=True)
39
+ def _label_overlap(x, y):
40
+ """ fast function to get pixel overlaps between masks in x and y
41
+
42
+ Parameters
43
+ ------------
44
+
45
+ x: ND-array, int
46
+ where 0=NO masks; 1,2... are mask labels
47
+ y: ND-array, int
48
+ where 0=NO masks; 1,2... are mask labels
49
+
50
+ Returns
51
+ ------------
52
+
53
+ overlap: ND-array, int
54
+ matrix of pixel overlaps of size [x.max()+1, y.max()+1]
55
+
56
+ """
57
+ x = x.ravel()
58
+ y = y.ravel()
59
+
60
+ # preallocate a 'contact map' matrix
61
+ overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
62
+
63
+ # loop over the labels in x and add to the corresponding
64
+ # overlap entry. If label A in x and label B in y share P
65
+ # pixels, then the resulting overlap is P
66
+ # len(x)=len(y), the number of pixels in the whole image
67
+ for i in range(len(x)):
68
+ overlap[x[i],y[i]] += 1
69
+ return overlap
70
+
71
+ def _true_positive(iou, th):
72
+ """ true positive at threshold th
73
+
74
+ Parameters
75
+ ------------
76
+
77
+ iou: float, ND-array
78
+ array of IOU pairs
79
+ th: float
80
+ threshold on IOU for positive label
81
+
82
+ Returns
83
+ ------------
84
+
85
+ tp: float
86
+ number of true positives at threshold
87
+ """
88
+ n_min = min(iou.shape[0], iou.shape[1])
89
+ costs = -(iou >= th).astype(float) - iou / (2*n_min)
90
+ true_ind, pred_ind = linear_sum_assignment(costs)
91
+ match_ok = iou[true_ind, pred_ind] >= th
92
+ tp = match_ok.sum()
93
+ return tp
94
+
95
+ def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5):
96
+ num_inst_gt = np.max(masks_true)
97
+ num_inst_seg = np.max(masks_pred)
98
+ if num_inst_seg>0:
99
+ iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
100
+ # for k,th in enumerate(threshold):
101
+ tp = _true_positive(iou, threshold)
102
+ fp = num_inst_seg - tp
103
+ fn = num_inst_gt - tp
104
+ else:
105
+ print('No segmentation results!')
106
+ tp = 0
107
+ fp = 0
108
+ fn = 0
109
+
110
+ return tp, fp, fn
111
+
112
+ def remove_boundary_cells(mask):
113
+ W, H = mask.shape
114
+ bd = np.ones((W, H))
115
+ bd[2:W-2, 2:H-2] = 0
116
+ bd_cells = np.unique(mask*bd)
117
+ for i in bd_cells[1:]:
118
+ mask[mask==i] = 0
119
+ new_label,_,_ = segmentation.relabel_sequential(mask)
120
+ return new_label
121
+
122
+ def main():
123
+ parser = argparse.ArgumentParser('Compute F1 score for cell segmentation results', add_help=False)
124
+ # Dataset parameters
125
+ parser.add_argument('--gt_path', type=str, help='path to ground truth; file names end with _label.tiff', required=True)
126
+ parser.add_argument('--seg_path', type=str, help='path to segmentation results; file names are the same as ground truth', required=True)
127
+ parser.add_argument('--save_path', default='./', help='path where to save metrics')
128
+ args = parser.parse_args()
129
+
130
+ gt_path = args.gt_path
131
+ seg_path = args.seg_path
132
+ names = sorted(os.listdir(seg_path))
133
+ seg_metric = OrderedDict()
134
+ seg_metric['Names'] = []
135
+ seg_metric['F1_Score'] = []
136
+ for name in tqdm(names):
137
+ assert name.endswith('_label.tiff'), 'The suffix of label name should be _label.tiff'
138
+
139
+ # Load the images for this case
140
+ gt = tif.imread(join(gt_path, name))
141
+ seg = tif.imread(join(seg_path, name))
142
+
143
+ # Score the cases
144
+ # do not consider cells on the boundaries during evaluation
145
+ if np.prod(gt.shape)<25000000:
146
+ gt = remove_boundary_cells(gt.astype(np.int32))
147
+ seg = remove_boundary_cells(seg.astype(np.int32))
148
+ tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
149
+ else: # for large images (>5000x5000), the F1 score is computed by a patch-based way
150
+ H, W = gt.shape
151
+ roi_size = 2000
152
+
153
+ if H % roi_size != 0:
154
+ n_H = H // roi_size + 1
155
+ new_H = roi_size * n_H
156
+ else:
157
+ n_H = H // roi_size
158
+ new_H = H
159
+
160
+ if W % roi_size != 0:
161
+ n_W = W // roi_size + 1
162
+ new_W = roi_size * n_W
163
+ else:
164
+ n_W = W // roi_size
165
+ new_W = W
166
+
167
+ gt_pad = np.zeros((new_H, new_W), dtype=gt.dtype)
168
+ seg_pad = np.zeros((new_H, new_W), dtype=gt.dtype)
169
+ gt_pad[:H, :W] = gt
170
+ seg_pad[:H, :W] = seg
171
+
172
+ tp = 0
173
+ fp = 0
174
+ fn = 0
175
+ for i in range(n_H):
176
+ for j in range(n_W):
177
+ gt_roi = remove_boundary_cells(gt_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)])
178
+ seg_roi = remove_boundary_cells(seg_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)])
179
+ tp_i, fp_i, fn_i = eval_tp_fp_fn(gt_roi, seg_roi, threshold=0.5)
180
+ tp += tp_i
181
+ fp += fp_i
182
+ fn += fn_i
183
+
184
+ if tp == 0:
185
+ precision = 0
186
+ recall = 0
187
+ f1 = 0
188
+ else:
189
+ precision = tp / (tp + fp)
190
+ recall = tp / (tp + fn)
191
+ f1 = 2*(precision * recall)/ (precision + recall)
192
+ seg_metric['Names'].append(name)
193
+ seg_metric['F1_Score'].append(np.round(f1, 4))
194
+
195
+
196
+ seg_metric_df = pd.DataFrame(seg_metric)
197
+ seg_metric_df.to_csv(join(args.save_path, 'seg_metric.csv'), index=False)
198
+ print('mean F1 Score:', np.mean(seg_metric['F1_Score']))
199
+
200
+ if __name__ == '__main__':
201
+ main()
finetune_convnext_stardist.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Adapted form MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+
10
+ join = os.path.join
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from stardist import star_dist,edt_prob
18
+ from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
19
+ from stardist import random_label_cmap,ray_angles
20
+ import monai
21
+ from collections import OrderedDict
22
+ from compute_metric import eval_tp_fp_fn,remove_boundary_cells
23
+ from monai.data import decollate_batch, PILReader
24
+ from monai.inferers import sliding_window_inference
25
+ from monai.metrics import DiceMetric
26
+ from monai.transforms import (
27
+ Activations,
28
+ AsChannelFirstd,
29
+ AddChanneld,
30
+ AsDiscrete,
31
+ Compose,
32
+ LoadImaged,
33
+ SpatialPadd,
34
+ RandSpatialCropd,
35
+ RandRotate90d,
36
+ ScaleIntensityd,
37
+ RandAxisFlipd,
38
+ RandZoomd,
39
+ RandGaussianNoised,
40
+ RandAdjustContrastd,
41
+ RandGaussianSmoothd,
42
+ RandHistogramShiftd,
43
+ EnsureTyped,
44
+ EnsureType,
45
+ )
46
+ from monai.visualize import plot_2d_or_3d_image
47
+ import matplotlib.pyplot as plt
48
+ from datetime import datetime
49
+ import shutil
50
+ import tqdm
51
+ from models.unetr2d import UNETR2D
52
+ from models.swin_unetr import SwinUNETR
53
+ from models.flexible_unet import FlexibleUNet
54
+ from models.flexible_unet_convext import FlexibleUNetConvext
55
+ print("Successfully imported all requirements!")
56
+ torch.backends.cudnn.enabled =False
57
+ #os.environ["OMP_NUM_THREADS"] = "1"
58
+ #os.environ["MKL_NUM_THREADS"] = "1"
59
+ def main():
60
+ parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
61
+ # Dataset parameters
62
+ parser.add_argument(
63
+ "--data_path",
64
+ default="",
65
+ type=str,
66
+ help="training data path; subfolders: images, labels",
67
+ )
68
+ parser.add_argument(
69
+ "--work_dir", default="/mntnfs/med_data5/louwei/nips_comp/stardist_finetune1/", help="path where to save models and logs"
70
+ )
71
+ parser.add_argument(
72
+ "--model_dir", default="/", help="path where to load pretrained model"
73
+ )
74
+ parser.add_argument("--seed", default=2022, type=int)
75
+ # parser.add_argument("--resume", default=False, help="resume from checkpoint")
76
+ parser.add_argument("--num_workers", default=4, type=int)
77
+ #parser.add_argument("--local_rank", type=int)
78
+ # Model parameters
79
+ parser.add_argument(
80
+ "--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
81
+ )
82
+ parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
83
+ parser.add_argument(
84
+ "--input_size", default=512, type=int, help="segmentation classes"
85
+ )
86
+ # Training parameters
87
+ parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
88
+ parser.add_argument("--max_epochs", default=2000, type=int)
89
+ parser.add_argument("--val_interval", default=10, type=int)
90
+ parser.add_argument("--epoch_tolerance", default=100, type=int)
91
+ parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
92
+
93
+ args = parser.parse_args()
94
+ #torch.cuda.set_device(args.local_rank)
95
+ #torch.distributed.init_process_group(backend='nccl')
96
+ monai.config.print_config()
97
+ n_rays = 32
98
+ pre_trained = True
99
+ #%% set training/validation split
100
+ np.random.seed(args.seed)
101
+ pre_trained_path = args.model_dir
102
+ model_path = join(args.work_dir, args.model_name + "_3class")
103
+ os.makedirs(model_path, exist_ok=True)
104
+ run_id = datetime.now().strftime("%Y%m%d-%H%M")
105
+ # This must be change every runing time ! ! ! ! ! ! ! ! ! ! !
106
+ model_file = "models/flexible_unet_convext.py"
107
+ shutil.copyfile(
108
+ __file__, join(model_path, os.path.basename(__file__))
109
+ )
110
+ shutil.copyfile(
111
+ model_file, join(model_path, os.path.basename(model_file))
112
+ )
113
+ img_path = join(args.data_path, "train/images")
114
+ gt_path = join(args.data_path, "train/tif")
115
+ val_img_path = join(args.data_path, "valid/images")
116
+ val_gt_path = join(args.data_path, "valid/tif")
117
+ img_names = sorted(os.listdir(img_path))
118
+ gt_names = [img_name.split(".")[0] + ".tif" for img_name in img_names]
119
+ img_num = len(img_names)
120
+ val_frac = 0.1
121
+ val_img_names = sorted(os.listdir(val_img_path))
122
+ val_gt_names = [img_name.split(".")[0] + ".tif" for img_name in val_img_names]
123
+ #indices = np.arange(img_num)
124
+ #np.random.shuffle(indices)
125
+ #val_split = int(img_num * val_frac)
126
+ #train_indices = indices[val_split:]
127
+ #val_indices = indices[:val_split]
128
+
129
+ train_files = [
130
+ {"img": join(img_path, img_names[i]), "label": join(gt_path, gt_names[i])}
131
+ for i in range(len(img_names))
132
+ ]
133
+ val_files = [
134
+ {"img": join(val_img_path, val_img_names[i]), "label": join(val_gt_path, val_gt_names[i])}
135
+ for i in range(len(val_img_names))
136
+ ]
137
+ print(
138
+ f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
139
+ )
140
+ #%% define transforms for image and segmentation
141
+ train_transforms = Compose(
142
+ [
143
+ LoadImaged(
144
+ keys=["img", "label"], reader=PILReader, dtype=np.float32
145
+ ), # image three channels (H, W, 3); label: (H, W)
146
+ AddChanneld(keys=["label"], allow_missing_keys=True), # label: (1, H, W)
147
+ AsChannelFirstd(
148
+ keys=["img"], channel_dim=-1, allow_missing_keys=True
149
+ ), # image: (3, H, W)
150
+ #ScaleIntensityd(
151
+ #keys=["img"], allow_missing_keys=True
152
+ #), # Do not scale label
153
+ SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
154
+ RandSpatialCropd(
155
+ keys=["img", "label"], roi_size=args.input_size, random_size=False
156
+ ),
157
+ RandAxisFlipd(keys=["img", "label"], prob=0.5),
158
+ RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
159
+ # # intensity transform
160
+ RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
161
+ RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
162
+ RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
163
+ RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
164
+ RandZoomd(
165
+ keys=["img", "label"],
166
+ prob=0.15,
167
+ min_zoom=0.5,
168
+ max_zoom=2,
169
+ mode=["area", "nearest"],
170
+ ),
171
+ EnsureTyped(keys=["img", "label"]),
172
+ ]
173
+ )
174
+
175
+ val_transforms = Compose(
176
+ [
177
+ LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
178
+ AddChanneld(keys=["label"], allow_missing_keys=True),
179
+ AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
180
+ #ScaleIntensityd(keys=["img"], allow_missing_keys=True),
181
+ # AsDiscreted(keys=['label'], to_onehot=3),
182
+ EnsureTyped(keys=["img", "label"]),
183
+ ]
184
+ )
185
+
186
+ #% define dataset, data loader
187
+ check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
188
+ check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
189
+ check_data = monai.utils.misc.first(check_loader)
190
+ print(
191
+ "sanity check:",
192
+ check_data["img"].shape,
193
+ torch.max(check_data["img"]),
194
+ check_data["label"].shape,
195
+ torch.max(check_data["label"]),
196
+ )
197
+
198
+ #%% create a training data loader
199
+ train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
200
+ # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
201
+ train_loader = DataLoader(
202
+ train_ds,
203
+ batch_size=args.batch_size,
204
+ shuffle=True,
205
+ num_workers=args.num_workers,
206
+ pin_memory =True,
207
+ )
208
+ # create a validation data loader
209
+ val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
210
+ val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)
211
+
212
+ dice_metric = DiceMetric(
213
+ include_background=False, reduction="mean", get_not_nans=False
214
+ )
215
+
216
+ post_pred = Compose(
217
+ [EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
218
+ )
219
+ post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])
220
+ # create UNet, DiceLoss and Adam optimizer
221
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
222
+
223
+ if args.model_name.lower() == "efficientunet":
224
+ model = FlexibleUNetConvext(
225
+ in_channels=3,
226
+ out_channels=n_rays+1,
227
+ backbone='convnext_small',
228
+ pretrained=False,
229
+ ).to(device)
230
+
231
+ #loss_masked_dice = monai.losses.DiceCELoss(softmax=True)
232
+ loss_dice = monai.losses.DiceLoss(squared_pred=True,jaccard=True)
233
+ loss_bce = nn.BCELoss()
234
+ loss_dist_mae = nn.L1Loss()
235
+ activatation = nn.ReLU()
236
+ sigmoid = nn.Sigmoid()
237
+ #loss_dist_mae = monai.losses.DiceCELoss(softmax=True)
238
+ initial_lr = args.initial_lr
239
+ encoder = list(map(id, model.encoder.parameters()))
240
+ base_params = filter(lambda p: id(p) not in encoder, model.parameters())
241
+ params = [
242
+ {"params": base_params, "lr":initial_lr},
243
+ {"params": model.encoder.parameters(), "lr": initial_lr * 0.1},
244
+ ]
245
+ optimizer = torch.optim.AdamW(params, initial_lr)
246
+ if pre_trained == True:
247
+
248
+ checkpoint = torch.load(pre_trained_path, map_location=torch.device(device))
249
+ model.load_state_dict(checkpoint['model_state_dict'])
250
+ print('Load pretrained weights...')
251
+ max_epochs = args.max_epochs
252
+ epoch_tolerance = args.epoch_tolerance
253
+ val_interval = args.val_interval
254
+ best_metric = -1
255
+ best_metric_epoch = -1
256
+ epoch_loss_values = list()
257
+ metric_values = list()
258
+ writer = SummaryWriter(model_path)
259
+ max_f1 = 0
260
+ for epoch in range(0, max_epochs):
261
+ model.train()
262
+ epoch_loss = 0
263
+ epoch_loss_prob = 0
264
+ epoch_loss_dist_2 = 0
265
+ epoch_loss_dist_1 = 0
266
+ for step, batch_data in enumerate(train_loader, 1):
267
+ print(step)
268
+ inputs, labels = batch_data["img"],batch_data["label"]
269
+
270
+ processes_labels = []
271
+
272
+ for i in range(labels.shape[0]):
273
+ label = labels[i][0]
274
+ distances = star_dist(label,n_rays,mode='opencl')
275
+ distances = np.transpose(distances,(2,0,1))
276
+ #print(distances.shape)
277
+ obj_probabilities = edt_prob(label.astype(int))
278
+ obj_probabilities = np.expand_dims(obj_probabilities,0)
279
+ #print(obj_probabilities.shape)
280
+ final_label = np.concatenate((distances,obj_probabilities),axis=0)
281
+ #print(final_label.shape)
282
+ processes_labels.append(final_label)
283
+
284
+ labels = np.stack(processes_labels)
285
+
286
+ #print(inputs.shape,labels.shape)
287
+ inputs, labels = torch.tensor(inputs).to(device), torch.tensor(labels).to(device)
288
+ #print(inputs.shape,labels.shape)
289
+ optimizer.zero_grad()
290
+ output_dist,output_prob = model(inputs)
291
+ #print(outputs.shape)
292
+ dist_output = output_dist
293
+ prob_output = output_prob
294
+ dist_label = labels[:,:n_rays,:,:]
295
+ prob_label = torch.unsqueeze(labels[:,-1,:,:], 1)
296
+ #print(dist_output.shape,prob_output.shape,dist_label.shape)
297
+ #labels_onehot = monai.networks.one_hot(
298
+ #labels, args.num_class
299
+ #) # (b,cls,256,256)
300
+ #print(prob_label.max(),prob_label.min())
301
+ loss_dist_1 = loss_dice(dist_output*prob_label,dist_label*prob_label)
302
+ #print(loss_dist_1)
303
+ loss_prob = loss_bce(prob_output,prob_label)
304
+ #print(prob_label.shape,dist_output.shape)
305
+ loss_dist_2 = loss_dist_mae(dist_output*prob_label,dist_label*prob_label)
306
+ #print(loss_dist_2)
307
+ loss = loss_prob + loss_dist_2*0.3 + loss_dist_1
308
+ loss.backward()
309
+ optimizer.step()
310
+ epoch_loss += loss.item()
311
+ epoch_loss_prob += loss_prob.item()
312
+ epoch_loss_dist_2 += loss_dist_2.item()
313
+ epoch_loss_dist_1 += loss_dist_1.item()
314
+ epoch_len = len(train_ds) // train_loader.batch_size
315
+
316
+ epoch_loss /= step
317
+ epoch_loss_prob /= step
318
+ epoch_loss_dist_2 /= step
319
+ epoch_loss_dist_1 /= step
320
+ epoch_loss_values.append(epoch_loss)
321
+ print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
322
+ writer.add_scalar("train_loss", epoch_loss, epoch)
323
+ print('dist dice: '+str(epoch_loss_dist_1)+' dist mae: '+str(epoch_loss_dist_2)+' prob bce: '+str(epoch_loss_prob))
324
+ checkpoint = {
325
+ "epoch": epoch,
326
+ "model_state_dict": model.state_dict(),
327
+ "optimizer_state_dict": optimizer.state_dict(),
328
+ "loss": epoch_loss_values,
329
+ }
330
+ if epoch < 40:
331
+ continue
332
+ if epoch > 1 and epoch % val_interval == 0:
333
+ torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))
334
+ model.eval()
335
+ with torch.no_grad():
336
+ val_images = None
337
+ val_labels = None
338
+ val_outputs = None
339
+ seg_metric = OrderedDict()
340
+ seg_metric['F1_Score'] = []
341
+ for val_data in tqdm.tqdm(val_loader):
342
+ val_images, val_labels = val_data["img"].to(device), val_data[
343
+ "label"
344
+ ].to(device)
345
+ roi_size = (512, 512)
346
+ sw_batch_size = 4
347
+ output_dist,output_prob = sliding_window_inference(
348
+ val_images, roi_size, sw_batch_size, model
349
+ )
350
+ val_labels = val_labels[0][0].cpu().numpy()
351
+ prob = output_prob[0][0].cpu().numpy()
352
+ dist = output_dist[0].cpu().numpy()
353
+ #print(val_labels.shape,prob.shape,dist.shape)
354
+ dist = np.transpose(dist,(1,2,0))
355
+ dist = np.maximum(1e-3, dist)
356
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
357
+
358
+ coord = dist_to_coord(disti,points)
359
+
360
+ star_label = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
361
+ gt = remove_boundary_cells(val_labels.astype(np.int32))
362
+ seg = remove_boundary_cells(star_label.astype(np.int32))
363
+ tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
364
+ if tp == 0:
365
+ precision = 0
366
+ recall = 0
367
+ f1 = 0
368
+ else:
369
+ precision = tp / (tp + fp)
370
+ recall = tp / (tp + fn)
371
+ f1 = 2*(precision * recall)/ (precision + recall)
372
+ f1 = np.round(f1, 4)
373
+ seg_metric['F1_Score'].append(np.round(f1, 4))
374
+ avg_f1 = np.mean(seg_metric['F1_Score'])
375
+ writer.add_scalar("val_f1score", avg_f1, epoch)
376
+ if avg_f1 > max_f1:
377
+ max_f1 = avg_f1
378
+ print(str(epoch) + 'f1 score: ' + str(max_f1))
379
+ torch.save(checkpoint, join(model_path, "best_model.pth"))
380
+ np.savez_compressed(
381
+ join(model_path, "train_log.npz"),
382
+ val_dice=metric_values,
383
+ epoch_loss=epoch_loss_values,
384
+ )
385
+
386
+
387
+ if __name__ == "__main__":
388
+ main()
models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Sun Mar 20 14:23:55 2022
5
+
6
+ @author: jma
7
+ """
8
+
9
+ from .unetr2d import UNETR2D
10
+ from .swin_unetr import SwinUNETR
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (282 Bytes). View file
 
models/__pycache__/convnext.cpython-38.pyc ADDED
Binary file (9.12 kB). View file
 
models/__pycache__/flexible_unet.cpython-38.pyc ADDED
Binary file (10.3 kB). View file
 
models/__pycache__/flexible_unet_convext.cpython-38.pyc ADDED
Binary file (10.4 kB). View file
 
models/__pycache__/swin_unetr.cpython-38.pyc ADDED
Binary file (30 kB). View file
 
models/__pycache__/unetr2d.cpython-38.pyc ADDED
Binary file (3.74 kB). View file
 
models/convnext.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from functools import partial
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.registry import register_model
14
+ from monai.networks.layers.factories import Act, Conv, Pad, Pool
15
+ from monai.networks.layers.utils import get_norm_layer
16
+ from monai.utils.module import look_up_option
17
+ from typing import List, NamedTuple, Optional, Tuple, Type, Union
18
+ class Block(nn.Module):
19
+ r""" ConvNeXt Block. There are two equivalent implementations:
20
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
21
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
22
+ We use (2) as we find it slightly faster in PyTorch
23
+
24
+ Args:
25
+ dim (int): Number of input channels.
26
+ drop_path (float): Stochastic depth rate. Default: 0.0
27
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
28
+ """
29
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
30
+ super().__init__()
31
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
32
+ self.norm = LayerNorm(dim, eps=1e-6)
33
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
34
+ self.act = nn.GELU()
35
+ self.pwconv2 = nn.Linear(4 * dim, dim)
36
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
37
+ requires_grad=True) if layer_scale_init_value > 0 else None
38
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
39
+
40
+ def forward(self, x):
41
+ input = x
42
+ x = self.dwconv(x)
43
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
44
+ x = self.norm(x)
45
+ x = self.pwconv1(x)
46
+ x = self.act(x)
47
+ x = self.pwconv2(x)
48
+ if self.gamma is not None:
49
+ x = self.gamma * x
50
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
51
+
52
+ x = input + self.drop_path(x)
53
+ return x
54
+
55
+ class ConvNeXt(nn.Module):
56
+ r""" ConvNeXt
57
+ A PyTorch impl of : `A ConvNet for the 2020s` -
58
+ https://arxiv.org/pdf/2201.03545.pdf
59
+
60
+ Args:
61
+ in_chans (int): Number of input image channels. Default: 3
62
+ num_classes (int): Number of classes for classification head. Default: 1000
63
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
64
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
65
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
66
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
67
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
68
+ """
69
+ def __init__(self, in_chans=3, num_classes=21841,
70
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
71
+ layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3],
72
+ ):
73
+ super().__init__()
74
+ # conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2]
75
+ # self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False)
76
+ # self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)
77
+
78
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
79
+ stem = nn.Sequential(
80
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
81
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
82
+ )
83
+ self.downsample_layers.append(stem)
84
+ for i in range(3):
85
+ downsample_layer = nn.Sequential(
86
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
87
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
88
+ )
89
+ self.downsample_layers.append(downsample_layer)
90
+
91
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
92
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
93
+ cur = 0
94
+ for i in range(4):
95
+ stage = nn.Sequential(
96
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
97
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
98
+ )
99
+ self.stages.append(stage)
100
+ cur += depths[i]
101
+
102
+
103
+ self.out_indices = out_indices
104
+
105
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
106
+ for i_layer in range(4):
107
+ layer = norm_layer(dims[i_layer])
108
+ layer_name = f'norm{i_layer}'
109
+ self.add_module(layer_name, layer)
110
+ self.apply(self._init_weights)
111
+
112
+
113
+ def _init_weights(self, m):
114
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
115
+ trunc_normal_(m.weight, std=.02)
116
+ nn.init.constant_(m.bias, 0)
117
+
118
+ def forward_features(self, x):
119
+ outs = []
120
+
121
+ for i in range(4):
122
+ x = self.downsample_layers[i](x)
123
+ x = self.stages[i](x)
124
+ if i in self.out_indices:
125
+ norm_layer = getattr(self, f'norm{i}')
126
+ x_out = norm_layer(x)
127
+
128
+ outs.append(x_out)
129
+
130
+ return tuple(outs)
131
+
132
+ def forward(self, x):
133
+ x = self.forward_features(x)
134
+
135
+ return x
136
+
137
+ class LayerNorm(nn.Module):
138
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
139
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
140
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
141
+ with shape (batch_size, channels, height, width).
142
+ """
143
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
144
+ super().__init__()
145
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
146
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
147
+ self.eps = eps
148
+ self.data_format = data_format
149
+ if self.data_format not in ["channels_last", "channels_first"]:
150
+ raise NotImplementedError
151
+ self.normalized_shape = (normalized_shape, )
152
+
153
+ def forward(self, x):
154
+ if self.data_format == "channels_last":
155
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
156
+ elif self.data_format == "channels_first":
157
+ u = x.mean(1, keepdim=True)
158
+ s = (x - u).pow(2).mean(1, keepdim=True)
159
+ x = (x - u) / torch.sqrt(s + self.eps)
160
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
161
+ return x
162
+
163
+
164
+ model_urls = {
165
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
166
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
167
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
168
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
169
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
170
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
171
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
172
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
173
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
174
+ }
175
+
176
+ @register_model
177
+ def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
178
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
179
+ if pretrained:
180
+ url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
181
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
182
+ model.load_state_dict(checkpoint["model"])
183
+ return model
184
+
185
+ @register_model
186
+ def convnext_small(pretrained=False,in_22k=False, **kwargs):
187
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
188
+ if pretrained:
189
+ url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
190
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
191
+ model.load_state_dict(checkpoint["model"], strict=False)
192
+ return model
193
+
194
+ @register_model
195
+ def convnext_base(pretrained=False, in_22k=False, **kwargs):
196
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
197
+ if pretrained:
198
+ url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
199
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
200
+ model.load_state_dict(checkpoint["model"], strict=False)
201
+ return model
202
+
203
+ @register_model
204
+ def convnext_large(pretrained=False, in_22k=False, **kwargs):
205
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
206
+ if pretrained:
207
+ url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
208
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
209
+ model.load_state_dict(checkpoint["model"])
210
+ return model
211
+
212
+ @register_model
213
+ def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
214
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
215
+ if pretrained:
216
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
217
+ url = model_urls['convnext_xlarge_22k']
218
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
219
+ model.load_state_dict(checkpoint["model"])
220
+ return model
models/flexible_unet.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import List, Optional, Sequence, Tuple, Union
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+ from monai.networks.blocks import UpSample
18
+ from monai.networks.layers.factories import Conv
19
+ from monai.networks.layers.utils import get_act_layer
20
+ from monai.networks.nets import EfficientNetBNFeatures
21
+ from monai.networks.nets.basic_unet import UpCat
22
+ from monai.utils import InterpolateMode
23
+
24
+ __all__ = ["FlexibleUNet"]
25
+
26
+ encoder_feature_channel = {
27
+ "efficientnet-b0": (16, 24, 40, 112, 320),
28
+ "efficientnet-b1": (16, 24, 40, 112, 320),
29
+ "efficientnet-b2": (16, 24, 48, 120, 352),
30
+ "efficientnet-b3": (24, 32, 48, 136, 384),
31
+ "efficientnet-b4": (24, 32, 56, 160, 448),
32
+ "efficientnet-b5": (24, 40, 64, 176, 512),
33
+ "efficientnet-b6": (32, 40, 72, 200, 576),
34
+ "efficientnet-b7": (32, 48, 80, 224, 640),
35
+ "efficientnet-b8": (32, 56, 88, 248, 704),
36
+ "efficientnet-l2": (72, 104, 176, 480, 1376),
37
+ }
38
+
39
+
40
+ def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
41
+ """
42
+ Get the encoder output channels by given backbone name.
43
+
44
+ Args:
45
+ backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
46
+ in_channels: channel of input tensor, default to 3.
47
+
48
+ Returns:
49
+ A tuple of output feature map channels' length .
50
+ """
51
+ encoder_channel_tuple = encoder_feature_channel[backbone]
52
+ encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
53
+ encoder_channel = tuple(encoder_channel_list)
54
+ return encoder_channel
55
+
56
+
57
+ class UNetDecoder(nn.Module):
58
+ """
59
+ UNet Decoder.
60
+ This class refers to `segmentation_models.pytorch
61
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
62
+
63
+ Args:
64
+ spatial_dims: number of spatial dimensions.
65
+ encoder_channels: number of output channels for all feature maps in encoder.
66
+ `len(encoder_channels)` should be no less than 2.
67
+ decoder_channels: number of output channels for all feature maps in decoder.
68
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
69
+ act: activation type and arguments.
70
+ norm: feature normalization type and arguments.
71
+ dropout: dropout ratio.
72
+ bias: whether to have a bias term in convolution blocks in this decoder.
73
+ upsample: upsampling mode, available options are
74
+ ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
75
+ pre_conv: a conv block applied before upsampling.
76
+ Only used in the "nontrainable" or "pixelshuffle" mode.
77
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
78
+ Only used in the "nontrainable" mode.
79
+ align_corners: set the align_corners parameter for upsample. Defaults to True.
80
+ Only used in the "nontrainable" mode.
81
+ is_pad: whether to pad upsampling features to fit the encoder spatial dims.
82
+
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ spatial_dims: int,
88
+ encoder_channels: Sequence[int],
89
+ decoder_channels: Sequence[int],
90
+ act: Union[str, tuple],
91
+ norm: Union[str, tuple],
92
+ dropout: Union[float, tuple],
93
+ bias: bool,
94
+ upsample: str,
95
+ pre_conv: Optional[str],
96
+ interp_mode: str,
97
+ align_corners: Optional[bool],
98
+ is_pad: bool,
99
+ ):
100
+
101
+ super().__init__()
102
+ if len(encoder_channels) < 2:
103
+ raise ValueError("the length of `encoder_channels` should be no less than 2.")
104
+ if len(decoder_channels) != len(encoder_channels) - 1:
105
+ raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")
106
+
107
+ in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
108
+ skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
109
+ halves = [True] * (len(skip_channels) - 1)
110
+ halves.append(False)
111
+ blocks = []
112
+ for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
113
+ blocks.append(
114
+ UpCat(
115
+ spatial_dims=spatial_dims,
116
+ in_chns=in_chn,
117
+ cat_chns=skip_chn,
118
+ out_chns=out_chn,
119
+ act=act,
120
+ norm=norm,
121
+ dropout=dropout,
122
+ bias=bias,
123
+ upsample=upsample,
124
+ pre_conv=pre_conv,
125
+ interp_mode=interp_mode,
126
+ align_corners=align_corners,
127
+ halves=halve,
128
+ is_pad=is_pad,
129
+ )
130
+ )
131
+ self.blocks = nn.ModuleList(blocks)
132
+
133
+ def forward(self, features: List[torch.Tensor], skip_connect: int = 4):
134
+ skips = features[:-1][::-1]
135
+ features = features[1:][::-1]
136
+
137
+ x = features[0]
138
+ for i, block in enumerate(self.blocks):
139
+ if i < skip_connect:
140
+ skip = skips[i]
141
+ else:
142
+ skip = None
143
+ x = block(x, skip)
144
+
145
+ return x
146
+
147
+
148
+ class SegmentationHead(nn.Sequential):
149
+ """
150
+ Segmentation head.
151
+ This class refers to `segmentation_models.pytorch
152
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
153
+
154
+ Args:
155
+ spatial_dims: number of spatial dimensions.
156
+ in_channels: number of input channels for the block.
157
+ out_channels: number of output channels for the block.
158
+ kernel_size: kernel size for the conv layer.
159
+ act: activation type and arguments.
160
+ scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
161
+
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ spatial_dims: int,
167
+ in_channels: int,
168
+ out_channels: int,
169
+ kernel_size: int = 3,
170
+ act: Optional[Union[Tuple, str]] = None,
171
+ scale_factor: float = 1.0,
172
+ ):
173
+
174
+ conv_layer = Conv[Conv.CONV, spatial_dims](
175
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
176
+ )
177
+ up_layer: nn.Module = nn.Identity()
178
+ if scale_factor > 1.0:
179
+ up_layer = UpSample(
180
+ spatial_dims=spatial_dims,
181
+ scale_factor=scale_factor,
182
+ mode="nontrainable",
183
+ pre_conv=None,
184
+ interp_mode=InterpolateMode.LINEAR,
185
+ )
186
+ if act is not None:
187
+ act_layer = get_act_layer(act)
188
+ else:
189
+ act_layer = nn.Identity()
190
+ super().__init__(conv_layer, up_layer, act_layer)
191
+
192
+
193
+ class FlexibleUNet(nn.Module):
194
+ """
195
+ A flexible implementation of UNet-like encoder-decoder architecture.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ in_channels: int,
201
+ out_channels: int,
202
+ backbone: str,
203
+ pretrained: bool = False,
204
+ decoder_channels: Tuple = (256, 128, 64, 32, 16),
205
+ spatial_dims: int = 2,
206
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
207
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
208
+ dropout: Union[float, tuple] = 0.0,
209
+ decoder_bias: bool = False,
210
+ upsample: str = "nontrainable",
211
+ interp_mode: str = "nearest",
212
+ is_pad: bool = True,
213
+ ) -> None:
214
+ """
215
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
216
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
217
+ and the spatial size of each dimension must be a multiple of 32 if is pad parameter
218
+ is False
219
+
220
+ Args:
221
+ in_channels: number of input channels.
222
+ out_channels: number of output channels.
223
+ backbone: name of backbones to initialize, only support efficientnet right now,
224
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
225
+ pretrained: whether to initialize pretrained ImageNet weights, only available
226
+ for spatial_dims=2 and batch norm is used, default to False.
227
+ decoder_channels: number of output channels for all feature maps in decoder.
228
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
229
+ to (256, 128, 64, 32, 16).
230
+ spatial_dims: number of spatial dimensions, default to 2.
231
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
232
+ "momentum": 0.1}).
233
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
234
+ dropout: dropout ratio, default to 0.0.
235
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
236
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
237
+ ``"nontrainable"``.
238
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
239
+ Only used in the "nontrainable" mode.
240
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
241
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
242
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
243
+ """
244
+ super().__init__()
245
+
246
+ if backbone not in encoder_feature_channel:
247
+ raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
248
+
249
+ if spatial_dims not in (2, 3):
250
+ raise ValueError("spatial_dims can only be 2 or 3.")
251
+
252
+ adv_prop = "ap" in backbone
253
+
254
+ self.backbone = backbone
255
+ self.spatial_dims = spatial_dims
256
+ model_name = backbone
257
+ encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
258
+ self.encoder = EfficientNetBNFeatures(
259
+ model_name=model_name,
260
+ pretrained=pretrained,
261
+ in_channels=in_channels,
262
+ spatial_dims=spatial_dims,
263
+ norm=norm,
264
+ adv_prop=adv_prop,
265
+ )
266
+ self.decoder = UNetDecoder(
267
+ spatial_dims=spatial_dims,
268
+ encoder_channels=encoder_channels,
269
+ decoder_channels=decoder_channels,
270
+ act=act,
271
+ norm=norm,
272
+ dropout=dropout,
273
+ bias=decoder_bias,
274
+ upsample=upsample,
275
+ interp_mode=interp_mode,
276
+ pre_conv=None,
277
+ align_corners=None,
278
+ is_pad=is_pad,
279
+ )
280
+ self.dist_head = SegmentationHead(
281
+ spatial_dims=spatial_dims,
282
+ in_channels=decoder_channels[-1],
283
+ out_channels=32,
284
+ kernel_size=1,
285
+ act='relu',
286
+ )
287
+ self.prob_head = SegmentationHead(
288
+ spatial_dims=spatial_dims,
289
+ in_channels=decoder_channels[-1],
290
+ out_channels=1,
291
+ kernel_size=1,
292
+ act='sigmoid',
293
+ )
294
+
295
+ def forward(self, inputs: torch.Tensor):
296
+ """
297
+ Do a typical encoder-decoder-header inference.
298
+
299
+ Args:
300
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
301
+ N is defined by `dimensions`.
302
+
303
+ Returns:
304
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
305
+
306
+ """
307
+ x = inputs
308
+ enc_out = self.encoder(x)
309
+ decoder_out = self.decoder(enc_out)
310
+ dist = self.dist_head(decoder_out)
311
+ prob = self.prob_head(decoder_out)
312
+ return dist,prob
models/flexible_unet_convext.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import List, Optional, Sequence, Tuple, Union
13
+
14
+ import torch
15
+ from torch import nn
16
+ from . import convnext
17
+ from monai.networks.blocks import UpSample
18
+ from monai.networks.layers.factories import Conv
19
+ from monai.networks.layers.utils import get_act_layer
20
+ from monai.networks.nets import EfficientNetBNFeatures
21
+ from monai.networks.nets.basic_unet import UpCat
22
+ from monai.utils import InterpolateMode
23
+
24
+ __all__ = ["FlexibleUNet"]
25
+
26
+ encoder_feature_channel = {
27
+ "efficientnet-b0": (16, 24, 40, 112, 320),
28
+ "efficientnet-b1": (16, 24, 40, 112, 320),
29
+ "efficientnet-b2": (16, 24, 48, 120, 352),
30
+ "efficientnet-b3": (24, 32, 48, 136, 384),
31
+ "efficientnet-b4": (24, 32, 56, 160, 448),
32
+ "efficientnet-b5": (24, 40, 64, 176, 512),
33
+ "efficientnet-b6": (32, 40, 72, 200, 576),
34
+ "efficientnet-b7": (32, 48, 80, 224, 640),
35
+ "efficientnet-b8": (32, 56, 88, 248, 704),
36
+ "efficientnet-l2": (72, 104, 176, 480, 1376),
37
+ "convnext_small": (96, 192, 384, 768),
38
+ "convnext_base": (128, 256, 512, 1024),
39
+ "van_b2": (64, 128, 320, 512),
40
+ "van_b1": (64, 128, 320, 512),
41
+ }
42
+
43
+
44
+ def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
45
+ """
46
+ Get the encoder output channels by given backbone name.
47
+
48
+ Args:
49
+ backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
50
+ in_channels: channel of input tensor, default to 3.
51
+
52
+ Returns:
53
+ A tuple of output feature map channels' length .
54
+ """
55
+ encoder_channel_tuple = encoder_feature_channel[backbone]
56
+ encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
57
+ encoder_channel = tuple(encoder_channel_list)
58
+ return encoder_channel
59
+
60
+
61
+ class UNetDecoder(nn.Module):
62
+ """
63
+ UNet Decoder.
64
+ This class refers to `segmentation_models.pytorch
65
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
66
+
67
+ Args:
68
+ spatial_dims: number of spatial dimensions.
69
+ encoder_channels: number of output channels for all feature maps in encoder.
70
+ `len(encoder_channels)` should be no less than 2.
71
+ decoder_channels: number of output channels for all feature maps in decoder.
72
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
73
+ act: activation type and arguments.
74
+ norm: feature normalization type and arguments.
75
+ dropout: dropout ratio.
76
+ bias: whether to have a bias term in convolution blocks in this decoder.
77
+ upsample: upsampling mode, available options are
78
+ ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
79
+ pre_conv: a conv block applied before upsampling.
80
+ Only used in the "nontrainable" or "pixelshuffle" mode.
81
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
82
+ Only used in the "nontrainable" mode.
83
+ align_corners: set the align_corners parameter for upsample. Defaults to True.
84
+ Only used in the "nontrainable" mode.
85
+ is_pad: whether to pad upsampling features to fit the encoder spatial dims.
86
+
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ spatial_dims: int,
92
+ encoder_channels: Sequence[int],
93
+ decoder_channels: Sequence[int],
94
+ act: Union[str, tuple],
95
+ norm: Union[str, tuple],
96
+ dropout: Union[float, tuple],
97
+ bias: bool,
98
+ upsample: str,
99
+ pre_conv: Optional[str],
100
+ interp_mode: str,
101
+ align_corners: Optional[bool],
102
+ is_pad: bool,
103
+ ):
104
+
105
+ super().__init__()
106
+ if len(encoder_channels) < 2:
107
+ raise ValueError("the length of `encoder_channels` should be no less than 2.")
108
+ if len(decoder_channels) != len(encoder_channels) - 1:
109
+ raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")
110
+
111
+ in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
112
+ skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
113
+ halves = [True] * (len(skip_channels) - 1)
114
+ halves.append(False)
115
+ blocks = []
116
+ for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
117
+ blocks.append(
118
+ UpCat(
119
+ spatial_dims=spatial_dims,
120
+ in_chns=in_chn,
121
+ cat_chns=skip_chn,
122
+ out_chns=out_chn,
123
+ act=act,
124
+ norm=norm,
125
+ dropout=dropout,
126
+ bias=bias,
127
+ upsample=upsample,
128
+ pre_conv=pre_conv,
129
+ interp_mode=interp_mode,
130
+ align_corners=align_corners,
131
+ halves=halve,
132
+ is_pad=is_pad,
133
+ )
134
+ )
135
+ self.blocks = nn.ModuleList(blocks)
136
+
137
+ def forward(self, features: List[torch.Tensor], skip_connect: int = 3):
138
+ skips = features[:-1][::-1]
139
+ features = features[1:][::-1]
140
+
141
+ x = features[0]
142
+ for i, block in enumerate(self.blocks):
143
+ if i < skip_connect:
144
+ skip = skips[i]
145
+ else:
146
+ skip = None
147
+ x = block(x, skip)
148
+
149
+ return x
150
+
151
+
152
+ class SegmentationHead(nn.Sequential):
153
+ """
154
+ Segmentation head.
155
+ This class refers to `segmentation_models.pytorch
156
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
157
+
158
+ Args:
159
+ spatial_dims: number of spatial dimensions.
160
+ in_channels: number of input channels for the block.
161
+ out_channels: number of output channels for the block.
162
+ kernel_size: kernel size for the conv layer.
163
+ act: activation type and arguments.
164
+ scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
165
+
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ spatial_dims: int,
171
+ in_channels: int,
172
+ out_channels: int,
173
+ kernel_size: int = 3,
174
+ act: Optional[Union[Tuple, str]] = None,
175
+ scale_factor: float = 1.0,
176
+ ):
177
+
178
+ conv_layer = Conv[Conv.CONV, spatial_dims](
179
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
180
+ )
181
+ up_layer: nn.Module = nn.Identity()
182
+ # if scale_factor > 1.0:
183
+ # up_layer = UpSample(
184
+ # in_channels=out_channels,
185
+ # spatial_dims=spatial_dims,
186
+ # scale_factor=scale_factor,
187
+ # mode="deconv",
188
+ # pre_conv=None,
189
+ # interp_mode=InterpolateMode.LINEAR,
190
+ # )
191
+ if scale_factor > 1.0:
192
+ up_layer = UpSample(
193
+ spatial_dims=spatial_dims,
194
+ scale_factor=scale_factor,
195
+ mode="nontrainable",
196
+ pre_conv=None,
197
+ interp_mode=InterpolateMode.LINEAR,
198
+ )
199
+ if act is not None:
200
+ act_layer = get_act_layer(act)
201
+ else:
202
+ act_layer = nn.Identity()
203
+ super().__init__(conv_layer, up_layer, act_layer)
204
+
205
+
206
+ class FlexibleUNetConvext(nn.Module):
207
+ """
208
+ A flexible implementation of UNet-like encoder-decoder architecture.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ in_channels: int,
214
+ out_channels: int,
215
+ backbone: str,
216
+ pretrained: bool = False,
217
+ decoder_channels: Tuple = (1024, 512, 256, 128),
218
+ spatial_dims: int = 2,
219
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
220
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
221
+ dropout: Union[float, tuple] = 0.0,
222
+ decoder_bias: bool = False,
223
+ upsample: str = "nontrainable",
224
+ interp_mode: str = "nearest",
225
+ is_pad: bool = True,
226
+ ) -> None:
227
+ """
228
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
229
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
230
+ and the spatial size of each dimension must be a multiple of 32 if is pad parameter
231
+ is False
232
+
233
+ Args:
234
+ in_channels: number of input channels.
235
+ out_channels: number of output channels.
236
+ backbone: name of backbones to initialize, only support efficientnet right now,
237
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
238
+ pretrained: whether to initialize pretrained ImageNet weights, only available
239
+ for spatial_dims=2 and batch norm is used, default to False.
240
+ decoder_channels: number of output channels for all feature maps in decoder.
241
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
242
+ to (256, 128, 64, 32, 16).
243
+ spatial_dims: number of spatial dimensions, default to 2.
244
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
245
+ "momentum": 0.1}).
246
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
247
+ dropout: dropout ratio, default to 0.0.
248
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
249
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
250
+ ``"nontrainable"``.
251
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
252
+ Only used in the "nontrainable" mode.
253
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
254
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
255
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
256
+ """
257
+ super().__init__()
258
+
259
+ if backbone not in encoder_feature_channel:
260
+ raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
261
+
262
+ if spatial_dims not in (2, 3):
263
+ raise ValueError("spatial_dims can only be 2 or 3.")
264
+
265
+ adv_prop = "ap" in backbone
266
+
267
+ self.backbone = backbone
268
+ self.spatial_dims = spatial_dims
269
+ model_name = backbone
270
+ encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
271
+
272
+ self.encoder = convnext.convnext_small(pretrained=True,in_22k=True)
273
+ # self.encoder = VAN(embed_dims=[64, 128, 320, 512],
274
+ # depths=[3, 3, 12, 3],
275
+ # init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b2.pth'),
276
+ # norm_cfg=dict(type='BN', requires_grad=True)
277
+ # )
278
+ # self.encoder = VAN(embed_dims=[64, 128, 320, 512],
279
+ # depths=[2, 2, 4, 2],
280
+ # init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b1.pth'),
281
+ # norm_cfg=dict(type='BN', requires_grad=True)
282
+ # )
283
+ # self.encoder.init_weights()
284
+ self.decoder = UNetDecoder(
285
+ spatial_dims=spatial_dims,
286
+ encoder_channels=encoder_channels,
287
+ decoder_channels=decoder_channels,
288
+ act=act,
289
+ norm=norm,
290
+ dropout=dropout,
291
+ bias=decoder_bias,
292
+ upsample=upsample,
293
+ interp_mode=interp_mode,
294
+ pre_conv=None,
295
+ align_corners=None,
296
+ is_pad=is_pad,
297
+ )
298
+ self.dist_head = SegmentationHead(
299
+ spatial_dims=spatial_dims,
300
+ in_channels=decoder_channels[-1],
301
+ out_channels=64,
302
+ kernel_size=1,
303
+ act='relu',
304
+ scale_factor = 2,
305
+ )
306
+ self.prob_head = SegmentationHead(
307
+ spatial_dims=spatial_dims,
308
+ in_channels=decoder_channels[-1],
309
+ out_channels=1,
310
+ kernel_size=1,
311
+ act='sigmoid',
312
+ scale_factor = 2,
313
+ )
314
+
315
+ def forward(self, inputs: torch.Tensor):
316
+ """
317
+ Do a typical encoder-decoder-header inference.
318
+
319
+ Args:
320
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
321
+ N is defined by `dimensions`.
322
+
323
+ Returns:
324
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
325
+
326
+ """
327
+ x = inputs
328
+ enc_out = self.encoder(x)
329
+ decoder_out = self.decoder(enc_out)
330
+
331
+ dist = self.dist_head(decoder_out)
332
+ prob = self.prob_head(decoder_out)
333
+
334
+ return dist,prob
335
+ class FlexibleUNet_hv(nn.Module):
336
+ """
337
+ A flexible implementation of UNet-like encoder-decoder architecture.
338
+ """
339
+
340
+ def __init__(
341
+ self,
342
+ in_channels: int,
343
+ out_channels: int,
344
+ backbone: str,
345
+ pretrained: bool = False,
346
+ decoder_channels: Tuple = (1024, 512, 256, 128),
347
+ spatial_dims: int = 2,
348
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
349
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
350
+ dropout: Union[float, tuple] = 0.0,
351
+ decoder_bias: bool = False,
352
+ upsample: str = "nontrainable",
353
+ interp_mode: str = "nearest",
354
+ is_pad: bool = True,
355
+ n_rays: int = 32,
356
+ prob_out_channels: int = 1,
357
+ ) -> None:
358
+ """
359
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
360
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
361
+ and the spatial size of each dimension must be a multiple of 32 if is pad parameter
362
+ is False
363
+
364
+ Args:
365
+ in_channels: number of input channels.
366
+ out_channels: number of output channels.
367
+ backbone: name of backbones to initialize, only support efficientnet right now,
368
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
369
+ pretrained: whether to initialize pretrained ImageNet weights, only available
370
+ for spatial_dims=2 and batch norm is used, default to False.
371
+ decoder_channels: number of output channels for all feature maps in decoder.
372
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
373
+ to (256, 128, 64, 32, 16).
374
+ spatial_dims: number of spatial dimensions, default to 2.
375
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
376
+ "momentum": 0.1}).
377
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
378
+ dropout: dropout ratio, default to 0.0.
379
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
380
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
381
+ ``"nontrainable"``.
382
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
383
+ Only used in the "nontrainable" mode.
384
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
385
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
386
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
387
+ """
388
+ super().__init__()
389
+
390
+ if backbone not in encoder_feature_channel:
391
+ raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
392
+
393
+ if spatial_dims not in (2, 3):
394
+ raise ValueError("spatial_dims can only be 2 or 3.")
395
+
396
+ adv_prop = "ap" in backbone
397
+
398
+ self.backbone = backbone
399
+ self.spatial_dims = spatial_dims
400
+ model_name = backbone
401
+ encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
402
+ self.encoder = convnext.convnext_small(pretrained=True,in_22k=True)
403
+ self.decoder = UNetDecoder(
404
+ spatial_dims=spatial_dims,
405
+ encoder_channels=encoder_channels,
406
+ decoder_channels=decoder_channels,
407
+ act=act,
408
+ norm=norm,
409
+ dropout=dropout,
410
+ bias=decoder_bias,
411
+ upsample=upsample,
412
+ interp_mode=interp_mode,
413
+ pre_conv=None,
414
+ align_corners=None,
415
+ is_pad=is_pad,
416
+ )
417
+ self.dist_head = SegmentationHead(
418
+ spatial_dims=spatial_dims,
419
+ in_channels=decoder_channels[-1],
420
+ out_channels=n_rays,
421
+ kernel_size=1,
422
+ act=None,
423
+ scale_factor = 2,
424
+ )
425
+ self.prob_head = SegmentationHead(
426
+ spatial_dims=spatial_dims,
427
+ in_channels=decoder_channels[-1],
428
+ out_channels=prob_out_channels,
429
+ kernel_size=1,
430
+ act='sigmoid',
431
+ scale_factor = 2,
432
+ )
433
+
434
+ def forward(self, inputs: torch.Tensor):
435
+ """
436
+ Do a typical encoder-decoder-header inference.
437
+
438
+ Args:
439
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
440
+ N is defined by `dimensions`.
441
+
442
+ Returns:
443
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
444
+
445
+ """
446
+ x = inputs
447
+ enc_out = self.encoder(x)
448
+ decoder_out = self.decoder(enc_out)
449
+ dist = self.dist_head(decoder_out)
450
+ prob = self.prob_head(decoder_out)
451
+ return dist,prob
overlay.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ ###overlay
5
+ import cv2
6
+ import math
7
+ import random
8
+ import colorsys
9
+ import numpy as np
10
+ import itertools
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib import cm
13
+ import os
14
+ import scipy.io as io
15
+ def get_bounding_box(img):
16
+ """Get bounding box coordinate information."""
17
+ rows = np.any(img, axis=1)
18
+ cols = np.any(img, axis=0)
19
+ rmin, rmax = np.where(rows)[0][[0, -1]]
20
+ cmin, cmax = np.where(cols)[0][[0, -1]]
21
+ # due to python indexing, need to add 1 to max
22
+ # else accessing will be 1px in the box, not out
23
+ rmax += 1
24
+ cmax += 1
25
+ return [rmin, rmax, cmin, cmax]
26
+ ####
27
+ def colorize(ch, vmin, vmax):
28
+ """Will clamp value value outside the provided range to vmax and vmin."""
29
+ cmap = plt.get_cmap("jet")
30
+ ch = np.squeeze(ch.astype("float32"))
31
+ vmin = vmin if vmin is not None else ch.min()
32
+ vmax = vmax if vmax is not None else ch.max()
33
+ ch[ch > vmax] = vmax # clamp value
34
+ ch[ch < vmin] = vmin
35
+ ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
36
+ # take RGB from RGBA heat map
37
+ ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
38
+ return ch_cmap
39
+
40
+
41
+ ####
42
+ def random_colors(N, bright=True):
43
+ """Generate random colors.
44
+
45
+ To get visually distinct colors, generate them in HSV space then
46
+ convert to RGB.
47
+ """
48
+ brightness = 1.0 if bright else 0.7
49
+ hsv = [(i / N, 1, brightness) for i in range(N)]
50
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
51
+ random.shuffle(colors)
52
+ return colors
53
+
54
+
55
+ ####
56
+ def visualize_instances_map(
57
+ input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
58
+ ):
59
+ """Overlays segmentation results on image as contours.
60
+
61
+ Args:
62
+ input_image: input image
63
+ inst_map: instance mask with unique value for every object
64
+ type_map: type mask with unique value for every class
65
+ type_colour: a dict of {type : colour} , `type` is from 0-N
66
+ and `colour` is a tuple of (R, G, B)
67
+ line_thickness: line thickness of contours
68
+
69
+ Returns:
70
+ overlay: output image with segmentation overlay as contours
71
+ """
72
+ overlay = np.copy((input_image).astype(np.uint8))
73
+
74
+ inst_list = list(np.unique(inst_map)) # get list of instances
75
+ inst_list.remove(0) # remove background
76
+
77
+ inst_rng_colors = random_colors(len(inst_list))
78
+ inst_rng_colors = np.array(inst_rng_colors) * 255
79
+ inst_rng_colors = inst_rng_colors.astype(np.uint8)
80
+
81
+ for inst_idx, inst_id in enumerate(inst_list):
82
+ inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object
83
+ y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
84
+ y1 = y1 - 2 if y1 - 2 >= 0 else y1
85
+ x1 = x1 - 2 if x1 - 2 >= 0 else x1
86
+ x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
87
+ y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
88
+ inst_map_crop = inst_map_mask[y1:y2, x1:x2]
89
+ contours_crop = cv2.findContours(
90
+ inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
91
+ )
92
+ # only has 1 instance per map, no need to check #contour detected by opencv
93
+ #print(contours_crop)
94
+ contours_crop = np.squeeze(
95
+ contours_crop[0][0].astype("int32")
96
+ ) # * opencv protocol format may break
97
+
98
+ if len(contours_crop.shape) == 1:
99
+ contours_crop = contours_crop.reshape(1,-1)
100
+ #print(contours_crop.shape)
101
+ contours_crop += np.asarray([[x1, y1]]) # index correction
102
+ if type_map is not None:
103
+ type_map_crop = type_map[y1:y2, x1:x2]
104
+ type_id = np.unique(type_map_crop).max() # non-zero
105
+ inst_colour = type_colour[type_id]
106
+ else:
107
+ inst_colour = (inst_rng_colors[inst_idx]).tolist()
108
+ cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
109
+ return overlay
110
+
111
+
112
+ # In[ ]:
113
+
114
+
115
+
116
+
predict.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ join = os.path.join
4
+ import argparse
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from collections import OrderedDict
9
+ from torchvision import datasets, models, transforms
10
+ from classifiers import resnet10, resnet18
11
+
12
+ from utils_modify import sliding_window_inference,sliding_window_inference_large,__proc_np_hv
13
+ from PIL import Image
14
+ import torch.nn.functional as F
15
+ from skimage import io, segmentation, morphology, measure, exposure
16
+ import tifffile as tif
17
+ from models.flexible_unet_convnext import FlexibleUNet_star,FlexibleUNet_hv
18
+ #from overlay import visualize_instances_map
19
+
20
+ def normalize_channel(img, lower=1, upper=99):
21
+ non_zero_vals = img[np.nonzero(img)]
22
+ percentiles = np.percentile(non_zero_vals, [lower, upper])
23
+ if percentiles[1] - percentiles[0] > 0.001:
24
+ img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
25
+ else:
26
+ img_norm = img
27
+ return img_norm.astype(np.uint8)
28
+ #torch.cuda.synchronize()
29
+ parser = argparse.ArgumentParser('Baseline for Microscopy image segmentation', add_help=False)
30
+ # Dataset parameters
31
+ parser.add_argument('-i', '--input_path', default='./inputs', type=str, help='training data path; subfolders: images, labels')
32
+ parser.add_argument("-o", '--output_path', default='./outputs', type=str, help='output path')
33
+ parser.add_argument('--model_path', default='./models', help='path where to save models and segmentation results')
34
+ parser.add_argument('--show_overlay', required=False, default=False, action="store_true", help='save segmentation overlay')
35
+
36
+ # Model parameters
37
+ parser.add_argument('--model_name', default='efficientunet', help='select mode: unet, unetr, swinunetr')
38
+ parser.add_argument('--input_size', default=512, type=int, help='segmentation classes')
39
+ args = parser.parse_args()
40
+ input_path = args.input_path
41
+ output_path = args.output_path
42
+ model_path = args.model_path
43
+ os.makedirs(output_path, exist_ok=True)
44
+ #overlay_path = 'overlays/'
45
+ #print(input_path)
46
+
47
+ img_names = sorted(os.listdir(join(input_path)))
48
+ #print(img_names)
49
+
50
+
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+
53
+
54
+ preprocess=transforms.Compose([
55
+ transforms.Resize(size=256),
56
+ transforms.CenterCrop(size=224),
57
+ transforms.ToTensor(),
58
+ transforms.Normalize([0.485, 0.456, 0.406],
59
+ [0.229, 0.224, 0.225])
60
+ ])
61
+ roi_size = (512, 512)
62
+ overlap = 0.5
63
+ np_thres, ksize, overall_thres, obj_size_thres = 0.6, 15, 0.4, 100
64
+ n_rays = 32
65
+ sw_batch_size = 4
66
+ num_classes= 4
67
+ block_size = 2048
68
+ min_overlap = 128
69
+ context = 128
70
+ with torch.no_grad():
71
+ for img_name in img_names:
72
+ #print(img_name)
73
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
74
+ img_data = tif.imread(join(input_path, img_name))
75
+ else:
76
+ img_data = io.imread(join(input_path, img_name))
77
+ # normalize image data
78
+ if len(img_data.shape) == 2:
79
+ img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
80
+ elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
81
+ img_data = img_data[:,:, :3]
82
+ else:
83
+ pass
84
+ pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
85
+ for i in range(3):
86
+ img_channel_i = img_data[:,:,i]
87
+ if len(img_channel_i[np.nonzero(img_channel_i)])>0:
88
+ pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
89
+ inputs=preprocess(Image.fromarray(pre_img_data)).unsqueeze(0).to(device)
90
+ cls_MODEL = model_path + '/cls/resnet18_4class_all_modified.tar'
91
+ model = resnet18().to(device)
92
+ model.load_state_dict(torch.load(cls_MODEL))
93
+ model.eval()
94
+ outputs = model(inputs)
95
+ _, preds = torch.max(outputs, 1)
96
+ label=preds[0].cpu().numpy()
97
+ #print(label)
98
+ test_npy01 = pre_img_data
99
+ if label in [0,1,2] or img_data.shape[0] > 4000:
100
+ if label == 0:
101
+ model = FlexibleUNet_star(in_channels=3,out_channels=n_rays+1,backbone='convnext_small',pretrained=False,n_rays=n_rays,prob_out_channels=1,).to(device)
102
+ checkpoint = torch.load(model_path+'/0/best_model.pth', map_location=torch.device(device))
103
+ model.load_state_dict(checkpoint['model_state_dict'])
104
+ model.eval()
105
+
106
+ output_label = sliding_window_inference_large(test_npy01,block_size,min_overlap,context, roi_size,sw_batch_size,predictor=model,device=device)
107
+ tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)
108
+
109
+ elif label == 1:
110
+ model = FlexibleUNet_star(in_channels=3,out_channels=n_rays+1,backbone='convnext_small',pretrained=False,n_rays=n_rays,prob_out_channels=1,).to(device)
111
+ checkpoint = torch.load(model_path+'/1/best_model.pth', map_location=torch.device(device))
112
+ model.load_state_dict(checkpoint['model_state_dict'])
113
+ model.eval()
114
+
115
+ output_label = sliding_window_inference_large(test_npy01,block_size,min_overlap,context, roi_size,sw_batch_size,predictor=model,device=device)
116
+ tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)
117
+ elif label == 2:
118
+ model = FlexibleUNet_star(in_channels=3,out_channels=n_rays+1,backbone='convnext_small',pretrained=False,n_rays=n_rays,prob_out_channels=1,).to(device)
119
+ checkpoint = torch.load(model_path+'/2/best_model.pth', map_location=torch.device(device))
120
+ model.load_state_dict(checkpoint['model_state_dict'])
121
+ model.eval()
122
+
123
+ output_label = sliding_window_inference_large(test_npy01,block_size,min_overlap,context, roi_size,sw_batch_size,predictor=model,device=device)
124
+ tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)
125
+
126
+
127
+ else:
128
+ model = FlexibleUNet_hv(in_channels=3,out_channels=2+2,backbone='convnext_small',pretrained=False,n_rays=2,prob_out_channels=2,).to(device)
129
+ checkpoint = torch.load(model_path+'/3/best_model_converted.pth', map_location=torch.device(device))
130
+ #model.load_state_dict(checkpoint['model_state_dict'])
131
+ #od = OrderedDict()
132
+ #for k, v in checkpoint['model_state_dict'].items():
133
+ #od[k.replace('module.', '')] = v
134
+ model.load_state_dict(checkpoint)
135
+ model.to(device)
136
+ model.eval()
137
+ test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
138
+ if isinstance(roi_size, tuple):
139
+ roi = roi_size
140
+
141
+ output_hv, output_np = sliding_window_inference(test_tensor, roi, sw_batch_size, model, overlap=overlap)
142
+ pred_dict = {'np': output_np, 'hv': output_hv}
143
+ pred_dict = OrderedDict(
144
+ [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] # NHWC
145
+ )
146
+ pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
147
+ pred_output = torch.cat(list(pred_dict.values()), -1).cpu().numpy() # NHW3
148
+ pred_map = np.squeeze(pred_output) # HW3
149
+ pred_inst = __proc_np_hv(pred_map, np_thres, ksize, overall_thres, obj_size_thres)
150
+ raw_pred_shape = pred_inst.shape[:2]
151
+ output_label = pred_inst
152
+
153
+ tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)
154
+
155
+
156
+
157
+
predict_unet_convnext.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ join = os.path.join
4
+ import argparse
5
+ import numpy as np
6
+ import torch
7
+ import monai
8
+ import torch.nn as nn
9
+
10
+ from utils import sliding_window_inference
11
+ #from baseline.models.unetr2d import UNETR2D
12
+ import time
13
+ from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
14
+ from stardist import random_label_cmap,ray_angles
15
+ from stardist import star_dist,edt_prob
16
+ from skimage import io, segmentation, morphology, measure, exposure
17
+ import tifffile as tif
18
+ import cv2
19
+ from overlay import visualize_instances_map
20
+ from models.flexible_unet import FlexibleUNet
21
+ from models.flexible_unet_convext import FlexibleUNetConvext
22
+ def normalize_channel(img, lower=1, upper=99):
23
+ non_zero_vals = img[np.nonzero(img)]
24
+ percentiles = np.percentile(non_zero_vals, [lower, upper])
25
+ if percentiles[1] - percentiles[0] > 0.001:
26
+ img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
27
+ else:
28
+ img_norm = img
29
+ return img_norm.astype(np.uint8)
30
+
31
+ def main():
32
+ parser = argparse.ArgumentParser('Baseline for Microscopy image segmentation', add_help=False)
33
+ # Dataset parameters
34
+ #parser.add_argument('-i', '--input_path', default='./inputs', type=str, help='training data path; subfolders: images, labels')
35
+ #parser.add_argument("-o", '--output_path', default='./outputs', type=str, help='output path')
36
+ parser.add_argument('--model_path', default='./work_dir/swinunetr_3class', help='path where to save models and segmentation results')
37
+ parser.add_argument('--show_overlay', required=False, default=False, action="store_true", help='save segmentation overlay')
38
+
39
+ # Model parameters
40
+ parser.add_argument('--model_name', default='efficientunet', help='select mode: unet, unetr, swinunetr')
41
+ parser.add_argument('--num_class', default=3, type=int, help='segmentation classes')
42
+ parser.add_argument('--input_size', default=512, type=int, help='segmentation classes')
43
+ args = parser.parse_args()
44
+
45
+ input_path = '/home/data/TuningSet/'
46
+ output_path = '/home/data/output/'
47
+ overlay_path = '/home/data/overlay/'
48
+
49
+
50
+ img_names = sorted(os.listdir(join(input_path)))
51
+ n_rays = 32
52
+
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+
56
+
57
+ if args.model_name.lower() == "efficientunet":
58
+ model = FlexibleUNetConvext(
59
+ in_channels=3,
60
+ out_channels=n_rays+1,
61
+ backbone='convnext_small',
62
+ pretrained=True,
63
+ ).to(device)
64
+
65
+
66
+
67
+ sigmoid = nn.Sigmoid()
68
+ checkpoint = torch.load('/home/louwei/stardist_convnext/efficientunet_3class/best_model.pth', map_location=torch.device(device))
69
+ model.load_state_dict(checkpoint['model_state_dict'])
70
+ #%%
71
+ roi_size = (args.input_size, args.input_size)
72
+ sw_batch_size = 4
73
+ model.eval()
74
+ with torch.no_grad():
75
+ for img_name in img_names:
76
+ print(img_name)
77
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
78
+ img_data = tif.imread(join(input_path, img_name))
79
+ else:
80
+ img_data = io.imread(join(input_path, img_name))
81
+ # normalize image data
82
+ if len(img_data.shape) == 2:
83
+ img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
84
+ elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
85
+ img_data = img_data[:,:, :3]
86
+ else:
87
+ pass
88
+ pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
89
+ for i in range(3):
90
+ img_channel_i = img_data[:,:,i]
91
+ if len(img_channel_i[np.nonzero(img_channel_i)])>0:
92
+ pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
93
+
94
+ t0 = time.time()
95
+ #test_npy01 = pre_img_data/np.max(pre_img_data)
96
+ test_npy01 = pre_img_data
97
+ test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
98
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model)
99
+ #test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1) # (B, C, H, W)
100
+ prob = output_prob[0][0].cpu().numpy()
101
+ dist = output_dist[0].cpu().numpy()
102
+
103
+
104
+ dist = np.transpose(dist,(1,2,0))
105
+ dist = np.maximum(1e-3, dist)
106
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
107
+
108
+ coord = dist_to_coord(disti,points)
109
+
110
+ star_label = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
111
+ tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), star_label)
112
+ overlay = visualize_instances_map(pre_img_data,star_label)
113
+ cv2.imwrite(join(overlay_path, img_name.split('.')[0]+'.png'), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
114
+
115
+
116
+
117
+ if __name__ == "__main__":
118
+ main()
119
+
120
+
121
+
122
+
123
+
requirements.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gputools==0.2.13
2
+ h5py==3.7.0
3
+ huggingface-hub==0.10.1
4
+ imagecodecs
5
+ imageio==2.22.2
6
+ importlib-metadata==5.0.0
7
+ kiwisolver==1.4.4
8
+ llvmlite==0.39.1
9
+ Mako==1.2.3
10
+ Markdown==3.4.1
11
+ MarkupSafe==2.1.1
12
+ matplotlib==3.6.1
13
+ mkl-fft==1.3.1
14
+ mkl-service==2.4.0
15
+ monai==1.0.0
16
+ networkx==2.8.7
17
+ numba==0.56.3
18
+ numexpr
19
+ numpy
20
+ oauthlib==3.2.2
21
+ opencv-python==4.6.0.66
22
+ packaging
23
+ pandas==1.4.4
24
+ Pillow==9.2.0
25
+ scikit-image==0.19.3
26
+ scipy==1.9.2
27
+ stardist==0.8.3
28
+ tensorboard==2.10.1
29
+ tensorboard-data-server==0.6.1
30
+ tensorboard-plugin-wit==1.8.1
31
+ tifffile==2022.10.10
32
+ timm==0.6.11
33
+ torch==1.12.1
34
+ torchaudio==0.12.1
35
+ torchvision==0.13.1
36
+ tqdm==4.64.1
37
+
train_convnext_hover..py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Adapted form MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
5
+ """
6
+
7
+ import argparse
8
+ import os, sys
9
+
10
+ join = os.path.join
11
+ #sys.path.append('/data2/yuxinyi/stardist_pytorch')
12
+
13
+ from tqdm import tqdm
14
+ import numpy as np
15
+ import pandas as pd
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import DataParallel
20
+ from torch.utils.data import Dataset, DataLoader
21
+ from torch.utils.tensorboard import SummaryWriter
22
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
23
+ from stardist import star_dist, edt_prob
24
+ from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
25
+ from stardist import random_label_cmap, ray_angles
26
+ import monai
27
+ from collections import OrderedDict
28
+ from compute_metric import eval_tp_fp_fn, remove_boundary_cells
29
+ from monai.data import decollate_batch, PILReader
30
+ from monai.inferers import sliding_window_inference
31
+ from monai.metrics import DiceMetric
32
+ from monai.transforms import (
33
+ Activations,
34
+ AsChannelFirstd,
35
+ AddChanneld,
36
+ AsDiscrete,
37
+ CenterSpatialCropd,
38
+ Compose,
39
+ Lambdad,
40
+ LoadImaged,
41
+ # LoadImaged_modified,
42
+ SpatialPadd,
43
+ RandSpatialCropd,
44
+ RandRotate90d,
45
+ ScaleIntensityd,
46
+ RandAxisFlipd,
47
+ RandZoomd,
48
+ RandGaussianNoised,
49
+ RandAdjustContrastd,
50
+ RandGaussianSmoothd,
51
+ RandHistogramShiftd,
52
+ EnsureTyped,
53
+ EnsureType,
54
+ apply_transform,
55
+ )
56
+ from monai.visualize import plot_2d_or_3d_image
57
+ import matplotlib.pyplot as plt
58
+ from datetime import datetime
59
+ import shutil
60
+ from skimage import io
61
+ from skimage.color import gray2rgb
62
+
63
+ from models.unetr2d import UNETR2D
64
+ from models.swin_unetr import SwinUNETR
65
+ from models.flexible_unet_convext import FlexibleUNet_hv
66
+
67
+ from utils import cropping_center, gen_targets, xentropy_loss, dice_loss, mse_loss, msge_loss
68
+
69
+ import warnings
70
+ warnings.filterwarnings("ignore")
71
+
72
+ print("Successfully imported all requirements!")
73
+ torch.backends.cudnn.enabled = False
74
+
75
+ def rm_n_mkdir(dir_path):
76
+ """Remove and make directory."""
77
+ if os.path.isdir(dir_path):
78
+ shutil.rmtree(dir_path)
79
+ os.makedirs(dir_path)
80
+
81
+ class HoverDataset(Dataset):
82
+ def __init__(self, data, transform, mask_shape):
83
+ self.data = data
84
+ self.transform = transform
85
+ self.mask_shape = mask_shape
86
+
87
+ def __len__(self) -> int:
88
+ return len(self.data)
89
+
90
+ def _transform(self, index):
91
+ data_i = self.data[index]
92
+ return apply_transform(self.transform, data_i) if self.transform is not None else data_i
93
+
94
+ def __getitem__(self, index):
95
+ ret = self._transform(index)
96
+ # print(target_dict['img'].dtype, target_dict['label'].dtype)
97
+ # gen targets
98
+ inst_map = np.squeeze(ret['label'].numpy()).astype('int32') # 1HW -> HW
99
+ target_dict = gen_targets(inst_map, inst_map.shape[:2]) # original code: self.mask_shape -> current code: aug_size
100
+ np_map, hv_map = target_dict['np_map'], target_dict['hv_map']
101
+ np_map = cropping_center(np_map, self.mask_shape) # HW
102
+ hv_map = cropping_center(hv_map, self.mask_shape) # HW2
103
+ target_dict['np_map'] = torch.tensor(np_map)
104
+ target_dict['hv_map'] = torch.tensor(hv_map)
105
+ # centercrop img
106
+ img = cropping_center(ret['img'].permute(1,2,0), self.mask_shape).permute(2,0,1) # CHW -> HWC -> CHW
107
+ ret['img'] = img
108
+ ret.update(target_dict)
109
+ return ret
110
+
111
+ def valid_step(model, batch_data):
112
+
113
+ model.eval() # infer mode
114
+
115
+ ####
116
+ imgs = batch_data["img"]
117
+ true_np = batch_data["np_map"]
118
+ true_hv = batch_data["hv_map"]
119
+
120
+ imgs_gpu = imgs.to("cuda").type(torch.float32) # NCHW
121
+
122
+ # HWC
123
+ true_np = torch.squeeze(true_np).type(torch.int64)
124
+ true_hv = torch.squeeze(true_hv).type(torch.float32)
125
+
126
+ true_dict = {
127
+ "np": true_np,
128
+ "hv": true_hv,
129
+ }
130
+
131
+ # --------------------------------------------------------------
132
+ with torch.no_grad(): # dont compute gradient
133
+ preds = model(imgs_gpu)
134
+ pred_dict = {'np': preds[1], 'hv': preds[0]}
135
+ pred_dict = OrderedDict(
136
+ [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()]
137
+ )
138
+ pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1]
139
+
140
+ # * Its up to user to define the protocol to process the raw output per step!
141
+ result_dict = { # protocol for contents exchange within `raw`
142
+ "raw": {
143
+ "imgs": imgs.numpy(),
144
+ "true_np": true_dict["np"].numpy(),
145
+ "true_hv": true_dict["hv"].numpy(),
146
+ "prob_np": pred_dict["np"].cpu().numpy(),
147
+ "pred_hv": pred_dict["hv"].cpu().numpy(),
148
+ }
149
+ }
150
+
151
+ return result_dict
152
+
153
+ def proc_valid_step_output(raw_data, nr_types=None):
154
+
155
+ track_dict = {}
156
+
157
+ def _dice_info(true, pred, label):
158
+ true = np.array(true == label, np.int32)
159
+ pred = np.array(pred == label, np.int32)
160
+ inter = (pred * true).sum()
161
+ total = (pred + true).sum()
162
+ return inter, total
163
+
164
+ over_inter = 0
165
+ over_total = 0
166
+ over_correct = 0
167
+ prob_np = raw_data["prob_np"]
168
+ true_np = raw_data["true_np"]
169
+ for idx in range(len(raw_data["true_np"])):
170
+ patch_prob_np = prob_np[idx]
171
+ patch_true_np = true_np[idx]
172
+ patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32)
173
+ inter, total = _dice_info(patch_true_np, patch_pred_np, 1)
174
+ correct = (patch_pred_np == patch_true_np).sum()
175
+ over_inter += inter
176
+ over_total += total
177
+ over_correct += correct
178
+ nr_pixels = len(true_np) * np.size(true_np[0])
179
+ acc_np = over_correct / nr_pixels
180
+ dice_np = 2 * over_inter / (over_total + 1.0e-8)
181
+ track_dict['np_acc'] = acc_np
182
+ track_dict['np_dice'] = dice_np
183
+
184
+ # * HV regression statistic
185
+ pred_hv = raw_data["pred_hv"]
186
+ true_hv = raw_data["true_hv"]
187
+
188
+ over_squared_error = 0
189
+ for idx in range(len(raw_data["true_np"])):
190
+ patch_pred_hv = pred_hv[idx]
191
+ patch_true_hv = true_hv[idx]
192
+ squared_error = patch_pred_hv - patch_true_hv
193
+ squared_error = squared_error * squared_error
194
+ over_squared_error += squared_error.sum()
195
+ mse = over_squared_error / nr_pixels
196
+ track_dict['hv_mse'] = mse
197
+
198
+ return track_dict
199
+
200
+ def main():
201
+
202
+ # class Args:
203
+ # def __init__(self, data_path, seed, num_workers, model_name, input_size, mask_size, batch_size, max_epochs,
204
+ # val_interval, save_interval, initial_lr, gpu_id, n_rays):
205
+ # self.data_path = data_path
206
+ # self.seed = seed
207
+ # self.num_workers = num_workers
208
+ # self.model_name = model_name
209
+ # self.input_size = input_size
210
+ # self.mask_size = mask_size
211
+ # self.batch_size = batch_size
212
+ # self.max_epochs = max_epochs
213
+ # self.val_interval = val_interval
214
+ # self.save_interval = save_interval
215
+ # self.initial_lr = initial_lr
216
+ # self.gpu_id = gpu_id
217
+ # self.n_rays = n_rays
218
+
219
+ # args = Args('/data2/yuxinyi/stardist_pytorch/dataset/class3_seed2', 2022, 4, 'efficientunet', 512, 256, 16, 600,
220
+ # 1, 10, 1e-4, '4', 32)
221
+ modelname = 'star-hover'
222
+ strategy = 'aug256_out256'
223
+ parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
224
+ # Dataset parameters
225
+ parser.add_argument(
226
+ "--data_path",
227
+ default=f"/mntnfs/med_data5/louwei/consep/",
228
+ type=str,
229
+ help="training data path; subfolders: images, labels",
230
+ )
231
+ parser.add_argument("--seed", default=10, type=int)
232
+ # parser.add_argument("--resume", default=False, help="resume from checkpoint")
233
+ parser.add_argument("--num_workers", default=4, type=int)
234
+
235
+ # Model parameters
236
+ parser.add_argument(
237
+ "--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
238
+ )
239
+ parser.add_argument("--input_size", default=512, type=int, help="after rand crop")
240
+ parser.add_argument("--mask_size", default=256, type=int, help="after gen target")
241
+ # Training parameters
242
+ parser.add_argument("--batch_size", default=12, type=int, help="Batch size per GPU")
243
+ parser.add_argument("--max_epochs", default=800, type=int)
244
+ parser.add_argument("--val_interval", default=1, type=int)
245
+ parser.add_argument("--save_interval", default=10, type=int)
246
+ parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
247
+ parser.add_argument('--gpu_id', type=str, default='0', help='gpu id')
248
+
249
+ args = parser.parse_args()
250
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
251
+
252
+ work_dir = f'/mntnfs/med_data5/louwei/hover_stardist/class_{modelname}_{strategy}'
253
+
254
+ # monai.config.print_config()
255
+ pre_trained = False
256
+ # %% set training/validation split
257
+ np.random.seed(args.seed)
258
+ model_path = join(work_dir)
259
+ rm_n_mkdir(model_path)
260
+ run_id = datetime.now().strftime("%Y%m%d-%H%M")
261
+ shutil.copyfile(
262
+ __file__, join(model_path, run_id + "_" + os.path.basename(__file__))
263
+ )
264
+ img_path = join(args.data_path, "Train/Images_3channels")
265
+ gt_path = join(args.data_path, "Train/tif")
266
+ val_img_path = join(args.data_path, "Test/Images_3channels")
267
+ val_gt_path = join(args.data_path, "Test/tif")
268
+ img_names = sorted(os.listdir(img_path))
269
+ gt_names = [img_name.replace('.png', '.tif') for img_name in img_names]
270
+ img_num = len(img_names)
271
+ val_frac = 0.1
272
+ val_img_names = sorted(os.listdir(val_img_path))
273
+ val_gt_names = [img_name.replace('.png', '.tif') for img_name in val_img_names]
274
+
275
+ train_files = [
276
+ {"img": join(img_path, img_names[i]), "label": join(gt_path, gt_names[i]), 'name': img_names[i]}
277
+ for i in range(len(img_names))
278
+ ]
279
+ val_files = [
280
+ {"img": join(val_img_path, val_img_names[i]), "label": join(val_gt_path, val_gt_names[i]),
281
+ 'name': val_img_names[i]}
282
+ for i in range(len(val_img_names))
283
+ ]
284
+ print(
285
+ f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
286
+ )
287
+
288
+ def load_img(img):
289
+ ret = io.imread(img)
290
+ if len(ret.shape) == 2:
291
+ ret = gray2rgb(ret)
292
+ return ret.astype('float32')
293
+
294
+ def load_ann(ann):
295
+ ret = np.squeeze(io.imread(ann)).astype('float32')
296
+ return ret
297
+
298
+ # %% define transforms for image and segmentation
299
+ train_transforms = Compose(
300
+ [
301
+ Lambdad(('img',), load_img),
302
+ Lambdad(('label',), load_ann),
303
+ # LoadImaged(
304
+ # keys=["img", "label"], reader=PILReader, dtype=np.float32
305
+ # ), # image three channels (H, W, 3); label: (H, W)
306
+ AddChanneld(keys=["label"], allow_missing_keys=True), # label: (1, H, W)
307
+ AsChannelFirstd(
308
+ keys=["img"], channel_dim=-1, allow_missing_keys=True
309
+ ), # image: (3, H, W)
310
+ # ScaleIntensityd(
311
+ # keys=["img"], allow_missing_keys=True
312
+ # ), # Do not scale label
313
+ # SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
314
+ # RandSpatialCropd(
315
+ # keys=["img", "label"], roi_size=args.input_size, random_size=False
316
+ # ),
317
+ RandAxisFlipd(keys=["img", "label"], prob=0.5),
318
+ RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
319
+ # # intensity transform
320
+ RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
321
+ RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
322
+ RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
323
+ RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
324
+ RandZoomd(
325
+ keys=["img", "label"],
326
+ prob=0.15,
327
+ min_zoom=0.5,
328
+ max_zoom=2.0,
329
+ mode=["area", "nearest"],
330
+ ),
331
+ EnsureTyped(keys=["img", "label"]),
332
+ ]
333
+ )
334
+
335
+ val_transforms = Compose(
336
+ [
337
+ Lambdad(('img',), load_img),
338
+ Lambdad(('label',), load_ann),
339
+ # LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
340
+ AddChanneld(keys=["label"], allow_missing_keys=True),
341
+ AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
342
+ # ScaleIntensityd(keys=["img"], allow_missing_keys=True),
343
+ # AsDiscreted(keys=['label'], to_onehot=3),
344
+ # CenterSpatialCropd(
345
+ # keys=["img", "label"], roi_size=args.input_size
346
+ # ),
347
+ EnsureTyped(keys=["img", "label"]),
348
+ ]
349
+ )
350
+
351
+ # % define dataset, data loader
352
+ # check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
353
+ check_ds = HoverDataset(data=train_files, transform=train_transforms, mask_shape=(args.mask_size, args.mask_size))
354
+ print(len(check_ds))
355
+ tmp = check_ds[0]
356
+ print(tmp['img'].shape, tmp['label'].shape, tmp['hv_map'].shape, tmp['np_map'].shape)
357
+ check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
358
+ check_data = monai.utils.misc.first(check_loader)
359
+ print(
360
+ "sanity check:",
361
+ check_data["img"].shape,
362
+ torch.max(check_data["img"]),
363
+ check_data["label"].shape,
364
+ torch.max(check_data["label"]),
365
+ check_data["hv_map"].shape,
366
+ torch.max(check_data["hv_map"]),
367
+ check_data["np_map"].shape,
368
+ torch.max(check_data["np_map"]),
369
+ )
370
+
371
+ # %% create a training data loader
372
+ # train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
373
+ train_ds = HoverDataset(data=train_files, transform=train_transforms, mask_shape=(args.mask_size, args.mask_size))
374
+ print(len(train_ds))
375
+ # example = train_ds[0]
376
+ # plt.imshow(np.array(example['img']).transpose(1,2,0).astype('uint8'))
377
+ # plt.imshow(np.squeeze(example['np_map'].numpy()).astype('uint8'), 'gray')
378
+ # plt.imshow(example['hv_map'].numpy()[...,0])
379
+ # plt.imshow(example['hv_map'].numpy()[..., 1])
380
+ # plt.show()
381
+ # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
382
+ train_loader = DataLoader(
383
+ train_ds,
384
+ batch_size=args.batch_size,
385
+ shuffle=True,
386
+ num_workers=args.num_workers,
387
+ pin_memory=torch.cuda.is_available(),
388
+ )
389
+ # create a validation data loader
390
+ # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
391
+ val_ds = HoverDataset(data=val_files, transform=val_transforms, mask_shape=(args.mask_size, args.mask_size))
392
+ val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4)
393
+
394
+ model = FlexibleUNet_hv(
395
+ in_channels=3,
396
+ out_channels=2+2,
397
+ backbone='convnext_small',
398
+ pretrained=True,
399
+ n_rays=2,
400
+ prob_out_channels=2,
401
+ )
402
+
403
+ activatation = nn.ReLU()
404
+ sigmoid = nn.Sigmoid()
405
+ initial_lr = args.initial_lr
406
+ optimizer = torch.optim.AdamW(model.parameters(), initial_lr)
407
+ scheduler = StepLR(optimizer, 100, 0.1)
408
+ #if pre_trained == True:
409
+ #print('Load pretrained weights...')
410
+ #checkpoint = torch.load('/data2/yuxinyi/stardist_pytorch/pretrained/overall/330.pth')
411
+ #model.load_state_dict(checkpoint['model_state_dict'])
412
+ # model = DataParallel(model)
413
+ model = model.to('cuda')
414
+ # start a typical PyTorch training
415
+ max_epochs = args.max_epochs
416
+ val_interval = args.val_interval
417
+ save_interval = args.save_interval
418
+ epoch_loss_values = []
419
+ writer = SummaryWriter(model_path)
420
+
421
+ #*# record loss and f1
422
+ loss_file = f'{work_dir}/train_loss.txt'
423
+ f1_file = f'{work_dir}/train_loss.txt'
424
+ if os.path.exists(loss_file):
425
+ os.remove(loss_file)
426
+ if os.path.exists(f1_file):
427
+ os.remove(f1_file)
428
+ #*#
429
+
430
+ for epoch in range(1, args.max_epochs):
431
+ model.train()
432
+ epoch_loss = 0
433
+ running_np_1, running_np_2, running_hv_1, running_hv_2 = 0.0, 0.0, 0.0, 0.0
434
+ stream = tqdm(train_loader)
435
+ for step, batch_data in enumerate(stream, start=1):
436
+
437
+ #*# hv map
438
+ inputs, true_np, true_hv = batch_data["img"], batch_data["np_map"], batch_data['hv_map']
439
+ true_np = true_np.to("cuda").type(torch.int64) # NHW
440
+ true_hv = true_hv.to("cuda").type(torch.float32) # NHWC
441
+ true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) # NHWC
442
+ inputs = torch.tensor(inputs).to('cuda')
443
+ # print(inputs.shape, true_np.shape, true_hv.shape)
444
+
445
+ optimizer.zero_grad()
446
+ pred_hv, pred_np = model(inputs) # NCHW
447
+ pred_hv = pred_hv.permute(0, 2, 3, 1).contiguous() # NHWC
448
+ pred_np = pred_np.permute(0, 2, 3, 1).contiguous() # NHWC
449
+ pred_np = F.softmax(pred_np, dim=-1)
450
+
451
+ # losses
452
+ loss_np_1 = xentropy_loss(true_np_onehot, pred_np) # bce
453
+ loss_np_2 = dice_loss(true_np_onehot, pred_np) # dice
454
+ loss_hv_1 = mse_loss(true_hv, pred_hv) # mse
455
+ loss_hv_2 = msge_loss(true_hv, pred_hv, true_np_onehot[...,1]) # msge
456
+ loss = loss_np_1 + loss_np_2 + loss_hv_1 + loss_hv_2
457
+ loss.backward()
458
+ optimizer.step()
459
+ epoch_loss += loss.item()
460
+ epoch_len = len(train_ds) // train_loader.batch_size
461
+
462
+ running_np_1 += loss_np_1.item()
463
+ running_np_2 += loss_np_2.item()
464
+ running_hv_1 += loss_hv_1.item()
465
+ running_hv_2 += loss_hv_2.item()
466
+ #*#
467
+
468
+ stream.set_description(
469
+ f'Epoch {epoch} | np bce: {running_np_1 / step:.4f}, np dice: {running_np_2 / step:.4f}, hv mse: {running_hv_1 / step:.4f}, hv msge: {running_hv_2 / step:.4f}')
470
+
471
+ epoch_loss /= step
472
+ epoch_loss_values.append(epoch_loss)
473
+ writer.add_scalar("train_loss", epoch_loss, epoch)
474
+ writer.add_scalar("np_bce", running_np_1 / step, epoch)
475
+ writer.add_scalar("np_dice", running_np_2 / step, epoch)
476
+ writer.add_scalar("hv_mse", running_hv_1 / step, epoch)
477
+ writer.add_scalar("hv_msge", running_hv_2 / step, epoch)
478
+ print(f"epoch {epoch} average loss: {epoch_loss:.4f}, lr: {optimizer.param_groups[0]['lr']}")
479
+
480
+ #*# record
481
+ with open(loss_file, 'a') as f:
482
+ f.write(f'Epoch{epoch}\tloss:{epoch_loss:.4f}\tnp_bce:{running_np_1/step:.4f}\tnp_dice:{running_np_2/step:.4f}\thv_mse:{running_hv_1/step:.4f}\thv_msge:{running_hv_2/step:.4f}\n')
483
+ #*#
484
+
485
+ checkpoint = {
486
+ "epoch": epoch,
487
+ "model_state_dict": model.state_dict(),
488
+ "optimizer_state_dict": optimizer.state_dict(),
489
+ "loss": epoch_loss_values,
490
+ }
491
+ if epoch % save_interval == 0:
492
+ torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))
493
+
494
+ running_np_acc, running_np_dice, running_hv_mse = 0.0, 0.0, 0.0
495
+ stream_val = tqdm(val_loader)
496
+ for step, batch_data in enumerate(stream_val, start=1):
497
+ raw_data = valid_step(model, batch_data)['raw']
498
+ track_dict = proc_valid_step_output(raw_data)
499
+ running_np_acc += track_dict['np_acc']
500
+ running_np_dice += track_dict['np_dice']
501
+ running_hv_mse += track_dict['hv_mse']
502
+ stream.set_description(f'Epoch {epoch} | np acc: {running_np_acc / step:.4f}, np dice: {running_np_dice / step:.4f}, hv mse: {running_hv_mse / step:.4f}')
503
+ writer.add_scalar("np_acc", running_np_acc / step, epoch)
504
+ writer.add_scalar("np_dice", running_np_dice / step, epoch)
505
+ writer.add_scalar("hv_mse", running_hv_mse / step, epoch)
506
+ print(f'Epoch {epoch} | np acc: {running_np_acc / step:.4f}, np dice: {running_np_dice / step:.4f}, hv mse: {running_hv_mse / step:.4f}')
507
+
508
+ #*# record
509
+ with open(loss_file, 'a') as f:
510
+ f.write(f'Validation | Epoch{epoch}\tloss:{epoch_loss:.4f}\tnp_acc:{running_np_acc/step:.4f}\tnp_dice:{running_np_dice/step:.4f}\thv_mse:{running_hv_mse/step:.4f}\n')
511
+ #*#
512
+
513
+ scheduler.step()
514
+
515
+ if __name__ == "__main__":
516
+ main()
train_convnext_stardist.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Adapted form MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+
10
+ join = os.path.join
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from stardist import star_dist,edt_prob
18
+ from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
19
+ from stardist import random_label_cmap,ray_angles
20
+ import monai
21
+ from collections import OrderedDict
22
+ from compute_metric import eval_tp_fp_fn,remove_boundary_cells
23
+ from monai.data import decollate_batch, PILReader
24
+ from monai.inferers import sliding_window_inference
25
+ from monai.metrics import DiceMetric
26
+ from monai.transforms import (
27
+ Activations,
28
+ AsChannelFirstd,
29
+ AddChanneld,
30
+ AsDiscrete,
31
+ Compose,
32
+ LoadImaged,
33
+ SpatialPadd,
34
+ RandSpatialCropd,
35
+ RandRotate90d,
36
+ ScaleIntensityd,
37
+ RandAxisFlipd,
38
+ RandZoomd,
39
+ RandGaussianNoised,
40
+ RandAdjustContrastd,
41
+ RandGaussianSmoothd,
42
+ RandHistogramShiftd,
43
+ EnsureTyped,
44
+ EnsureType,
45
+ )
46
+ from monai.visualize import plot_2d_or_3d_image
47
+ import matplotlib.pyplot as plt
48
+ from datetime import datetime
49
+ import shutil
50
+ import tqdm
51
+ from models.unetr2d import UNETR2D
52
+ from models.swin_unetr import SwinUNETR
53
+ from models.flexible_unet import FlexibleUNet
54
+ from models.flexible_unet_convext import FlexibleUNetConvext
55
+ print("Successfully imported all requirements!")
56
+ torch.backends.cudnn.enabled =False
57
+
58
+ def main():
59
+ parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
60
+ # Dataset parameters
61
+ parser.add_argument(
62
+ "--data_path",
63
+ default="/data2/liuchenyu/external_processed/split",
64
+ type=str,
65
+ help="training data path; subfolders: images, labels",
66
+ )
67
+ parser.add_argument(
68
+ "--work_dir", default="/data/louwei/nips_comp/convnext_fold0", help="path where to save models and logs"
69
+ )
70
+ parser.add_argument("--seed", default=2022, type=int)
71
+ # parser.add_argument("--resume", default=False, help="resume from checkpoint")
72
+ parser.add_argument("--num_workers", default=8, type=int)
73
+ parser.add_argument("--local_rank", type=int)
74
+ # Model parameters
75
+ parser.add_argument(
76
+ "--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
77
+ )
78
+ parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
79
+ parser.add_argument(
80
+ "--input_size", default=512, type=int, help="segmentation classes"
81
+ )
82
+ # Training parameters
83
+ parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
84
+ parser.add_argument("--max_epochs", default=2000, type=int)
85
+ parser.add_argument("--val_interval", default=5, type=int)
86
+ parser.add_argument("--epoch_tolerance", default=100, type=int)
87
+ parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
88
+
89
+ args = parser.parse_args()
90
+ torch.cuda.set_device(args.local_rank)
91
+ torch.distributed.init_process_group(backend='nccl')
92
+ monai.config.print_config()
93
+ n_rays = 32
94
+ pre_trained = True
95
+ #%% set training/validation split
96
+ np.random.seed(args.seed)
97
+ model_path = join(args.work_dir, args.model_name + "_3class")
98
+ os.makedirs(model_path, exist_ok=True)
99
+ run_id = datetime.now().strftime("%Y%m%d-%H%M")
100
+ # This must be change every runing time ! ! ! ! ! ! ! ! ! ! !
101
+ model_file = "models/flexible_unet_convext.py"
102
+ shutil.copyfile(
103
+ __file__, join(model_path, os.path.basename(__file__))
104
+ )
105
+ shutil.copyfile(
106
+ model_file, join(model_path, os.path.basename(model_file))
107
+ )
108
+ all_image_path = '/data/louwei/nips_comp/train_cellpose_multi0/'
109
+ all_img_path = join(all_image_path, "train/images")
110
+ all_gt_path = join(all_image_path, "train/tif")
111
+
112
+ all_img_names = sorted(os.listdir(all_img_path))
113
+ all_gt_names = [img_name.split(".")[0] + ".tif" for img_name in all_img_names]
114
+ all_img_files = [join(all_img_path, all_img_names[i]) for i in range(len(all_img_names))]
115
+ all_gt_files = [join(all_gt_path, all_gt_names[i]) for i in range(len(all_img_names))]
116
+ img_path = join(args.data_path, "train/images")
117
+ gt_path = join(args.data_path, "train/tif")
118
+ val_img_path = join(args.data_path, "test/images")
119
+ val_gt_path = join(args.data_path, "test/tif")
120
+ img_names = sorted(os.listdir(img_path))
121
+ gt_names = [img_name.split(".")[0] + ".tif" for img_name in img_names]
122
+ train_img_files = [join(img_path, img_names[i]) for i in range(len(img_names))]
123
+ train_gt_files = [join(gt_path, gt_names[i]) for i in range(len(img_names))]
124
+ cat_img_files = train_img_files + all_img_files
125
+ cat_gt_files = train_gt_files + all_gt_files
126
+ img_num = len(img_names)
127
+ val_frac = 0.1
128
+ val_img_names = sorted(os.listdir(val_img_path))
129
+ val_gt_names = [img_name.split(".")[0] + ".tif" for img_name in val_img_names]
130
+ #indices = np.arange(img_num)
131
+ #np.random.shuffle(indices)
132
+ #val_split = int(img_num * val_frac)
133
+ #train_indices = indices[val_split:]
134
+ #val_indices = indices[:val_split]
135
+
136
+ train_files = [
137
+ {"img": cat_img_files[i], "label": cat_gt_files[i]}
138
+ for i in range(len(cat_img_files))
139
+ ]
140
+ val_files = [
141
+ {"img": join(val_img_path, val_img_names[i]), "label": join(val_gt_path, val_gt_names[i])}
142
+ for i in range(len(val_img_names))
143
+ ]
144
+ print(
145
+ f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
146
+ )
147
+ #%% define transforms for image and segmentation
148
+ train_transforms = Compose(
149
+ [
150
+ LoadImaged(
151
+ keys=["img", "label"], reader=PILReader, dtype=np.float32
152
+ ), # image three channels (H, W, 3); label: (H, W)
153
+ AddChanneld(keys=["label"], allow_missing_keys=True), # label: (1, H, W)
154
+ AsChannelFirstd(
155
+ keys=["img"], channel_dim=-1, allow_missing_keys=True
156
+ ), # image: (3, H, W)
157
+ #ScaleIntensityd(
158
+ #keys=["img"], allow_missing_keys=True
159
+ #), # Do not scale label
160
+ SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
161
+ RandSpatialCropd(
162
+ keys=["img", "label"], roi_size=args.input_size, random_size=False
163
+ ),
164
+ RandAxisFlipd(keys=["img", "label"], prob=0.5),
165
+ RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
166
+ # # intensity transform
167
+ RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
168
+ RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
169
+ RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
170
+ RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
171
+ RandZoomd(
172
+ keys=["img", "label"],
173
+ prob=0.15,
174
+ min_zoom=0.5,
175
+ max_zoom=2,
176
+ mode=["area", "nearest"],
177
+ ),
178
+ EnsureTyped(keys=["img", "label"]),
179
+ ]
180
+ )
181
+
182
+ val_transforms = Compose(
183
+ [
184
+ LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
185
+ AddChanneld(keys=["label"], allow_missing_keys=True),
186
+ AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
187
+ #ScaleIntensityd(keys=["img"], allow_missing_keys=True),
188
+ # AsDiscreted(keys=['label'], to_onehot=3),
189
+ EnsureTyped(keys=["img", "label"]),
190
+ ]
191
+ )
192
+
193
+ #% define dataset, data loader
194
+ check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
195
+ check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
196
+ check_data = monai.utils.misc.first(check_loader)
197
+ print(
198
+ "sanity check:",
199
+ check_data["img"].shape,
200
+ torch.max(check_data["img"]),
201
+ check_data["label"].shape,
202
+ torch.max(check_data["label"]),
203
+ )
204
+
205
+ #%% create a training data loader
206
+ train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
207
+ # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
208
+ train_loader = DataLoader(
209
+ train_ds,
210
+ batch_size=args.batch_size,
211
+ shuffle=True,
212
+ num_workers=args.num_workers,
213
+ pin_memory=torch.cuda.is_available(),
214
+ )
215
+ # create a validation data loader
216
+ val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
217
+ val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)
218
+
219
+ dice_metric = DiceMetric(
220
+ include_background=False, reduction="mean", get_not_nans=False
221
+ )
222
+
223
+ post_pred = Compose(
224
+ [EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
225
+ )
226
+ post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])
227
+ # create UNet, DiceLoss and Adam optimizer
228
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
229
+ if args.model_name.lower() == "unet":
230
+ model = monai.networks.nets.UNet(
231
+ spatial_dims=2,
232
+ in_channels=3,
233
+ out_channels=args.num_class,
234
+ channels=(16, 32, 64, 128, 256),
235
+ strides=(2, 2, 2, 2),
236
+ num_res_units=2,
237
+ ).to(device)
238
+
239
+ if args.model_name.lower() == "efficientunet":
240
+ model = FlexibleUNetConvext(
241
+ in_channels=3,
242
+ out_channels=n_rays+1,
243
+ backbone='convnext_small',
244
+ pretrained=True,
245
+ ).to(device)
246
+
247
+ if args.model_name.lower() == "swinunetr":
248
+ model = SwinUNETR(
249
+ img_size=(args.input_size, args.input_size),
250
+ in_channels=3,
251
+ out_channels=n_rays+1,
252
+ feature_size=24, # should be divisible by 12
253
+ spatial_dims=2,
254
+ ).to(device)
255
+
256
+ #loss_masked_dice = monai.losses.DiceCELoss(softmax=True)
257
+ loss_dice = monai.losses.DiceLoss(squared_pred=True,jaccard=True)
258
+ loss_bce = nn.BCELoss()
259
+ loss_dist_mae = nn.L1Loss()
260
+ activatation = nn.ReLU()
261
+ sigmoid = nn.Sigmoid()
262
+ #loss_dist_mae = monai.losses.DiceCELoss(softmax=True)
263
+ initial_lr = args.initial_lr
264
+ encoder = list(map(id, model.encoder.parameters()))
265
+ base_params = filter(lambda p: id(p) not in encoder, model.parameters())
266
+ params = [
267
+ {"params": base_params, "lr":initial_lr},
268
+ {"params": model.encoder.parameters(), "lr": initial_lr * 0.1},
269
+ ]
270
+ optimizer = torch.optim.AdamW(params, initial_lr)
271
+ #if pre_trained == True:
272
+ #print('Load pretrained weights...')
273
+ #checkpoint = torch.load('/mntnfs/med_data5/louwei/nips_comp/swin_stardist/swinunetr_3class/40.pth', map_location=torch.device(device))
274
+ #model.load_state_dict(checkpoint['model_state_dict'])
275
+ # start a typical PyTorch training
276
+ #checkpoint = torch.load("/data2/liuchenyu/log/convnextsmall/efficientunet_3class/510.pth", map_location=torch.device(device))
277
+ #model.load_state_dict(checkpoint['model_state_dict'])
278
+ print('distributed model')
279
+ model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
280
+ print('successful model')
281
+ max_epochs = args.max_epochs
282
+ epoch_tolerance = args.epoch_tolerance
283
+ val_interval = args.val_interval
284
+ best_metric = -1
285
+ best_metric_epoch = -1
286
+ epoch_loss_values = list()
287
+ metric_values = list()
288
+ writer = SummaryWriter(model_path)
289
+ max_f1 = 0
290
+ for epoch in range(0, max_epochs):
291
+ model.train()
292
+ epoch_loss = 0
293
+ epoch_loss_prob = 0
294
+ epoch_loss_dist_2 = 0
295
+ epoch_loss_dist_1 = 0
296
+ for step, batch_data in enumerate(tqdm.tqdm(train_loader), 1):
297
+ inputs, labels = batch_data["img"],batch_data["label"]
298
+ print(step)
299
+ processes_labels = []
300
+
301
+ for i in range(labels.shape[0]):
302
+ label = labels[i][0]
303
+ distances = star_dist(label,n_rays)
304
+ distances = np.transpose(distances,(2,0,1))
305
+ #print(distances.shape)
306
+ obj_probabilities = edt_prob(label.astype(int))
307
+ obj_probabilities = np.expand_dims(obj_probabilities,0)
308
+ #print(obj_probabilities.shape)
309
+ final_label = np.concatenate((distances,obj_probabilities),axis=0)
310
+ #print(final_label.shape)
311
+ processes_labels.append(final_label)
312
+
313
+ labels = np.stack(processes_labels)
314
+
315
+ #print(inputs.shape,labels.shape)
316
+ inputs, labels = torch.tensor(inputs).to(device), torch.tensor(labels).to(device)
317
+ #print(inputs.shape,labels.shape)
318
+ optimizer.zero_grad()
319
+ output_dist,output_prob = model(inputs)
320
+ #print(outputs.shape)
321
+ dist_output = output_dist
322
+ prob_output = output_prob
323
+ dist_label = labels[:,:n_rays,:,:]
324
+ prob_label = torch.unsqueeze(labels[:,-1,:,:], 1)
325
+ #print(dist_output.shape,prob_output.shape,dist_label.shape)
326
+ #labels_onehot = monai.networks.one_hot(
327
+ #labels, args.num_class
328
+ #) # (b,cls,256,256)
329
+ #print(prob_label.max(),prob_label.min())
330
+ loss_dist_1 = loss_dice(dist_output*prob_label,dist_label*prob_label)
331
+ #print(loss_dist_1)
332
+ loss_prob = loss_bce(prob_output,prob_label)
333
+ #print(prob_label.shape,dist_output.shape)
334
+ loss_dist_2 = loss_dist_mae(dist_output*prob_label,dist_label*prob_label)
335
+ #print(loss_dist_2)
336
+ loss = loss_prob + loss_dist_2*0.3 + loss_dist_1
337
+ loss.backward()
338
+ optimizer.step()
339
+ epoch_loss += loss.item()
340
+ epoch_loss_prob += loss_prob.item()
341
+ epoch_loss_dist_2 += loss_dist_2.item()
342
+ epoch_loss_dist_1 += loss_dist_1.item()
343
+ epoch_len = len(train_ds) // train_loader.batch_size
344
+
345
+ epoch_loss /= step
346
+ epoch_loss_prob /= step
347
+ epoch_loss_dist_2 /= step
348
+ epoch_loss_dist_1 /= step
349
+ epoch_loss_values.append(epoch_loss)
350
+ print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
351
+ writer.add_scalar("train_loss", epoch_loss, epoch)
352
+ print('dist dice: '+str(epoch_loss_dist_1)+' dist mae: '+str(epoch_loss_dist_2)+' prob bce: '+str(epoch_loss_prob))
353
+ checkpoint = {
354
+ "epoch": epoch,
355
+ "model_state_dict": model.module.state_dict(),
356
+ "optimizer_state_dict": optimizer.state_dict(),
357
+ "loss": epoch_loss_values,
358
+ }
359
+ if epoch < 8:
360
+ continue
361
+ if epoch > 1 and epoch % val_interval == 0:
362
+ torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))
363
+ model.eval()
364
+ with torch.no_grad():
365
+ val_images = None
366
+ val_labels = None
367
+ val_outputs = None
368
+ seg_metric = OrderedDict()
369
+ seg_metric['F1_Score'] = []
370
+ for val_data in tqdm.tqdm(val_loader):
371
+ val_images, val_labels = val_data["img"].to(device), val_data[
372
+ "label"
373
+ ].to(device)
374
+ roi_size = (512, 512)
375
+ sw_batch_size = 4
376
+ output_dist,output_prob = sliding_window_inference(
377
+ val_images, roi_size, sw_batch_size, model
378
+ )
379
+ val_labels = val_labels[0][0].cpu().numpy()
380
+ prob = output_prob[0][0].cpu().numpy()
381
+ dist = output_dist[0].cpu().numpy()
382
+ #print(val_labels.shape,prob.shape,dist.shape)
383
+ dist = np.transpose(dist,(1,2,0))
384
+ dist = np.maximum(1e-3, dist)
385
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
386
+
387
+ coord = dist_to_coord(disti,points)
388
+
389
+ star_label = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
390
+ gt = remove_boundary_cells(val_labels.astype(np.int32))
391
+ seg = remove_boundary_cells(star_label.astype(np.int32))
392
+ tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
393
+ if tp == 0:
394
+ precision = 0
395
+ recall = 0
396
+ f1 = 0
397
+ else:
398
+ precision = tp / (tp + fp)
399
+ recall = tp / (tp + fn)
400
+ f1 = 2*(precision * recall)/ (precision + recall)
401
+ f1 = np.round(f1, 4)
402
+ seg_metric['F1_Score'].append(np.round(f1, 4))
403
+ avg_f1 = np.mean(seg_metric['F1_Score'])
404
+ writer.add_scalar("val_f1score", avg_f1, epoch)
405
+ if avg_f1 > max_f1:
406
+ max_f1 = avg_f1
407
+ print(str(epoch) + 'f1 score: ' + str(max_f1))
408
+ torch.save(checkpoint, join(model_path, "best_model.pth"))
409
+ np.savez_compressed(
410
+ join(model_path, "train_log.npz"),
411
+ val_dice=metric_values,
412
+ epoch_loss=epoch_loss_values,
413
+ )
414
+
415
+
416
+ if __name__ == "__main__":
417
+ main()
utils.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import warnings
13
+ from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
14
+
15
+ import cv2
16
+ import math
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import colorsys
21
+ import itertools
22
+ import matplotlib.pyplot as plt
23
+ from matplotlib import cm
24
+
25
+ from monai.data.meta_tensor import MetaTensor
26
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
27
+ from monai.transforms import Resize
28
+ from monai.utils import (
29
+ BlendMode,
30
+ PytorchPadMode,
31
+ convert_data_type,
32
+ convert_to_dst_type,
33
+ ensure_tuple,
34
+ fall_back_tuple,
35
+ look_up_option,
36
+ optional_import,
37
+ )
38
+
39
+ from scipy import ndimage
40
+ from scipy.ndimage.filters import gaussian_filter
41
+ from scipy.ndimage.interpolation import affine_transform, map_coordinates
42
+
43
+ from skimage import morphology as morph
44
+ from scipy.ndimage import filters, measurements
45
+ from scipy.ndimage.morphology import (
46
+ binary_dilation,
47
+ binary_fill_holes,
48
+ distance_transform_cdt,
49
+ distance_transform_edt,
50
+ )
51
+
52
+ from skimage.segmentation import watershed
53
+ from skimage.exposure import rescale_intensity
54
+ from skimage.filters import sobel_h, sobel_v, gaussian
55
+ from skimage.morphology import disk, binary_opening
56
+
57
+ tqdm, _ = optional_import("tqdm", name="tqdm")
58
+
59
+ __all__ = ["sliding_window_inference"]
60
+
61
+ ####
62
+ def normalize(mask, dtype=np.uint8):
63
+ return (255 * mask / np.amax(mask)).astype(dtype)
64
+
65
+ def fix_mirror_padding(ann):
66
+ """Deal with duplicated instances due to mirroring in interpolation
67
+ during shape augmentation (scale, rotation etc.).
68
+
69
+ """
70
+ current_max_id = np.amax(ann)
71
+ inst_list = list(np.unique(ann))
72
+ if 0 in inst_list:
73
+ inst_list.remove(0) # 0 is background
74
+ for inst_id in inst_list:
75
+ inst_map = np.array(ann == inst_id, np.uint8)
76
+ remapped_ids = measurements.label(inst_map)[0]
77
+ remapped_ids[remapped_ids > 1] += current_max_id
78
+ ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
79
+ current_max_id = np.amax(ann)
80
+ return ann
81
+
82
+ ####
83
+ def get_bounding_box(img):
84
+ """Get bounding box coordinate information."""
85
+ rows = np.any(img, axis=1)
86
+ cols = np.any(img, axis=0)
87
+ rmin, rmax = np.where(rows)[0][[0, -1]]
88
+ cmin, cmax = np.where(cols)[0][[0, -1]]
89
+ # due to python indexing, need to add 1 to max
90
+ # else accessing will be 1px in the box, not out
91
+ rmax += 1
92
+ cmax += 1
93
+ return [rmin, rmax, cmin, cmax]
94
+
95
+
96
+ ####
97
+ def cropping_center(x, crop_shape, batch=False):
98
+ """Crop an input image at the centre.
99
+
100
+ Args:
101
+ x: input array
102
+ crop_shape: dimensions of cropped array
103
+
104
+ Returns:
105
+ x: cropped array
106
+
107
+ """
108
+ orig_shape = x.shape
109
+ if not batch:
110
+ h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
111
+ w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
112
+ x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
113
+ else:
114
+ h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
115
+ w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
116
+ x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
117
+ return x
118
+
119
+ def gen_instance_hv_map(ann, crop_shape):
120
+ """Input annotation must be of original shape.
121
+
122
+ The map is calculated only for instances within the crop portion
123
+ but based on the original shape in original image.
124
+
125
+ Perform following operation:
126
+ Obtain the horizontal and vertical distance maps for each
127
+ nuclear instance.
128
+
129
+ """
130
+ orig_ann = ann.copy() # instance ID map
131
+ fixed_ann = fix_mirror_padding(orig_ann)
132
+ # re-cropping with fixed instance id map
133
+ crop_ann = cropping_center(fixed_ann, crop_shape)
134
+ # TODO: deal with 1 label warning
135
+ crop_ann = morph.remove_small_objects(crop_ann, min_size=30)
136
+
137
+ x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
138
+ y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
139
+
140
+ inst_list = list(np.unique(crop_ann))
141
+ if 0 in inst_list:
142
+ inst_list.remove(0) # 0 is background
143
+ for inst_id in inst_list:
144
+ inst_map = np.array(fixed_ann == inst_id, np.uint8)
145
+ inst_box = get_bounding_box(inst_map) # rmin, rmax, cmin, cmax
146
+
147
+ # expand the box by 2px
148
+ # Because we first pad the ann at line 207, the bboxes
149
+ # will remain valid after expansion
150
+ inst_box[0] -= 2
151
+ inst_box[2] -= 2
152
+ inst_box[1] += 2
153
+ inst_box[3] += 2
154
+
155
+ # fix inst_box
156
+ inst_box[0] = max(inst_box[0], 0)
157
+ inst_box[2] = max(inst_box[2], 0)
158
+ # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
159
+ # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])
160
+
161
+ inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
162
+
163
+ if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
164
+ print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
165
+ continue
166
+
167
+ # instance center of mass, rounded to nearest pixel
168
+ inst_com = list(measurements.center_of_mass(inst_map))
169
+ if np.isnan(measurements.center_of_mass(inst_map)).any():
170
+ print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
171
+ print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
172
+ print(inst_map)
173
+ print(inst_list)
174
+ print(inst_box)
175
+ print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))
176
+
177
+ inst_com[0] = int(inst_com[0] + 0.5)
178
+ inst_com[1] = int(inst_com[1] + 0.5)
179
+
180
+ inst_x_range = np.arange(1, inst_map.shape[1] + 1)
181
+ inst_y_range = np.arange(1, inst_map.shape[0] + 1)
182
+ # shifting center of pixels grid to instance center of mass
183
+ inst_x_range -= inst_com[1]
184
+ inst_y_range -= inst_com[0]
185
+
186
+ inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)
187
+
188
+ # remove coord outside of instance
189
+ inst_x[inst_map == 0] = 0
190
+ inst_y[inst_map == 0] = 0
191
+ inst_x = inst_x.astype("float32")
192
+ inst_y = inst_y.astype("float32")
193
+
194
+ # normalize min into -1 scale
195
+ if np.min(inst_x) < 0:
196
+ inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
197
+ if np.min(inst_y) < 0:
198
+ inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
199
+ # normalize max into +1 scale
200
+ if np.max(inst_x) > 0:
201
+ inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
202
+ if np.max(inst_y) > 0:
203
+ inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])
204
+
205
+ ####
206
+ x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
207
+ x_map_box[inst_map > 0] = inst_x[inst_map > 0]
208
+
209
+ y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
210
+ y_map_box[inst_map > 0] = inst_y[inst_map > 0]
211
+
212
+ hv_map = np.dstack([x_map, y_map])
213
+ return hv_map
214
+
215
+ def remove_small_objects(pred, min_size=64, connectivity=1):
216
+ """Remove connected components smaller than the specified size.
217
+
218
+ This function is taken from skimage.morphology.remove_small_objects, but the warning
219
+ is removed when a single label is provided.
220
+
221
+ Args:
222
+ pred: input labelled array
223
+ min_size: minimum size of instance in output array
224
+ connectivity: The connectivity defining the neighborhood of a pixel.
225
+
226
+ Returns:
227
+ out: output array with instances removed under min_size
228
+
229
+ """
230
+ out = pred
231
+
232
+ if min_size == 0: # shortcut for efficiency
233
+ return out
234
+
235
+ if out.dtype == bool:
236
+ selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
237
+ ccs = np.zeros_like(pred, dtype=np.int32)
238
+ ndimage.label(pred, selem, output=ccs)
239
+ else:
240
+ ccs = out
241
+
242
+ try:
243
+ component_sizes = np.bincount(ccs.ravel())
244
+ except ValueError:
245
+ raise ValueError(
246
+ "Negative value labels are not supported. Try "
247
+ "relabeling the input with `scipy.ndimage.label` or "
248
+ "`skimage.morphology.label`."
249
+ )
250
+
251
+ too_small = component_sizes < min_size
252
+ too_small_mask = too_small[ccs]
253
+ out[too_small_mask] = 0
254
+
255
+ return out
256
+
257
+ ####
258
+ def gen_targets(ann, crop_shape, **kwargs):
259
+ """Generate the targets for the network."""
260
+ hv_map = gen_instance_hv_map(ann, crop_shape)
261
+ np_map = ann.copy()
262
+ np_map[np_map > 0] = 1
263
+
264
+ hv_map = cropping_center(hv_map, crop_shape)
265
+ np_map = cropping_center(np_map, crop_shape)
266
+
267
+ target_dict = {
268
+ "hv_map": hv_map,
269
+ "np_map": np_map,
270
+ }
271
+
272
+ return target_dict
273
+
274
+ ####
275
+ def xentropy_loss(true, pred, reduction="mean"):
276
+ """Cross entropy loss. Assumes NHWC!
277
+
278
+ Args:
279
+ pred: prediction array
280
+ true: ground truth array
281
+
282
+ Returns:
283
+ cross entropy loss
284
+
285
+ """
286
+ epsilon = 10e-8
287
+ # scale preds so that the class probs of each sample sum to 1
288
+ pred = pred / torch.sum(pred, -1, keepdim=True)
289
+ # manual computation of crossentropy
290
+ pred = torch.clamp(pred, epsilon, 1.0 - epsilon)
291
+ loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True)
292
+ loss = loss.mean() if reduction == "mean" else loss.sum()
293
+ return loss
294
+
295
+
296
+ ####
297
+ def dice_loss(true, pred, smooth=1e-3):
298
+ """`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC."""
299
+ inse = torch.sum(pred * true, (0, 1, 2))
300
+ l = torch.sum(pred, (0, 1, 2))
301
+ r = torch.sum(true, (0, 1, 2))
302
+ loss = 1.0 - (2.0 * inse + smooth) / (l + r + smooth)
303
+ loss = torch.sum(loss)
304
+ return loss
305
+
306
+
307
+ ####
308
+ def mse_loss(true, pred):
309
+ """Calculate mean squared error loss.
310
+
311
+ Args:
312
+ true: ground truth of combined horizontal
313
+ and vertical maps
314
+ pred: prediction of combined horizontal
315
+ and vertical maps
316
+
317
+ Returns:
318
+ loss: mean squared error
319
+
320
+ """
321
+ loss = pred - true
322
+ loss = (loss * loss).mean()
323
+ return loss
324
+
325
+
326
+ ####
327
+ def msge_loss(true, pred, focus):
328
+ """Calculate the mean squared error of the gradients of
329
+ horizontal and vertical map predictions. Assumes
330
+ channel 0 is Vertical and channel 1 is Horizontal.
331
+
332
+ Args:
333
+ true: ground truth of combined horizontal
334
+ and vertical maps
335
+ pred: prediction of combined horizontal
336
+ and vertical maps
337
+ focus: area where to apply loss (we only calculate
338
+ the loss within the nuclei)
339
+
340
+ Returns:
341
+ loss: mean squared error of gradients
342
+
343
+ """
344
+
345
+ def get_sobel_kernel(size):
346
+ """Get sobel kernel with a given size."""
347
+ assert size % 2 == 1, "Must be odd, get size=%d" % size
348
+
349
+ h_range = torch.arange(
350
+ -size // 2 + 1,
351
+ size // 2 + 1,
352
+ dtype=torch.float32,
353
+ device="cuda",
354
+ requires_grad=False,
355
+ )
356
+ v_range = torch.arange(
357
+ -size // 2 + 1,
358
+ size // 2 + 1,
359
+ dtype=torch.float32,
360
+ device="cuda",
361
+ requires_grad=False,
362
+ )
363
+ h, v = torch.meshgrid(h_range, v_range)
364
+ kernel_h = h / (h * h + v * v + 1.0e-15)
365
+ kernel_v = v / (h * h + v * v + 1.0e-15)
366
+ return kernel_h, kernel_v
367
+
368
+ ####
369
+ def get_gradient_hv(hv):
370
+ """For calculating gradient."""
371
+ kernel_h, kernel_v = get_sobel_kernel(5)
372
+ kernel_h = kernel_h.view(1, 1, 5, 5) # constant
373
+ kernel_v = kernel_v.view(1, 1, 5, 5) # constant
374
+
375
+ h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW
376
+ v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW
377
+
378
+ # can only apply in NCHW mode
379
+ h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2)
380
+ v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2)
381
+ dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1)
382
+ dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC
383
+ return dhv
384
+
385
+ focus = (focus[..., None]).float() # assume input NHW
386
+ focus = torch.cat([focus, focus], axis=-1)
387
+ true_grad = get_gradient_hv(true)
388
+ pred_grad = get_gradient_hv(pred)
389
+ loss = pred_grad - true_grad
390
+ loss = focus * (loss * loss)
391
+ # artificial reduce_mean with focused region
392
+ loss = loss.sum() / (focus.sum() + 1.0e-8)
393
+ return loss
394
+
395
+
396
+ def __proc_np_hv(pred, np_thres, ksize, overall_thres, obj_size_thres):
397
+ """Process Nuclei Prediction with XY Coordinate Map.
398
+
399
+ Args:
400
+ pred: prediction output, assuming
401
+ channel 0 contain probability map of nuclei
402
+ channel 1 containing the regressed X-map
403
+ channel 2 containing the regressed Y-map
404
+
405
+ """
406
+ pred = np.array(pred, dtype=np.float32)
407
+
408
+ blb_raw = pred[..., 0]
409
+ h_dir_raw = pred[..., 1]
410
+ v_dir_raw = pred[..., 2]
411
+
412
+ # processing
413
+ blb = np.array(blb_raw >= np_thres, dtype=np.int32)
414
+
415
+ blb = measurements.label(blb)[0]
416
+ blb = remove_small_objects(blb, min_size=10)
417
+ blb[blb > 0] = 1 # background is 0 already
418
+
419
+ h_dir = cv2.normalize(
420
+ h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
421
+ )
422
+ v_dir = cv2.normalize(
423
+ v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
424
+ )
425
+
426
+ sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
427
+ sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
428
+
429
+ sobelh = 1 - (
430
+ cv2.normalize(
431
+ sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
432
+ )
433
+ )
434
+ sobelv = 1 - (
435
+ cv2.normalize(
436
+ sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
437
+ )
438
+ )
439
+
440
+ overall = np.maximum(sobelh, sobelv)
441
+ overall = overall - (1 - blb)
442
+ overall[overall < 0] = 0
443
+
444
+ dist = (1.0 - overall) * blb
445
+ ## nuclei values form mountains so inverse to get basins
446
+ dist = -cv2.GaussianBlur(dist, (3, 3), 0)
447
+
448
+ overall = np.array(overall >= overall_thres, dtype=np.int32)
449
+
450
+ marker = blb - overall
451
+ marker[marker < 0] = 0
452
+ marker = binary_fill_holes(marker).astype("uint8")
453
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
454
+ marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
455
+ marker = measurements.label(marker)[0]
456
+ marker = remove_small_objects(marker, min_size=obj_size_thres)
457
+
458
+ proced_pred = watershed(dist, markers=marker, mask=blb)
459
+
460
+ return proced_pred
461
+
462
+ def __proc_np_hv_2(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
463
+ """Process Nuclei Prediction with XY Coordinate Map.
464
+
465
+ Args:
466
+ pred: prediction output, assuming
467
+ channel 0 contain probability map of nuclei
468
+ channel 1 containing the regressed X-map
469
+ channel 2 containing the regressed Y-map
470
+
471
+ """
472
+ pred = np.array(pred, dtype=np.float32)
473
+
474
+ blb_raw = pred[..., 0]
475
+ h_dir_raw = pred[..., 1]
476
+ v_dir_raw = pred[..., 2]
477
+
478
+ # processing
479
+ blb = np.array(blb_raw >= np_thres, dtype=np.int32)
480
+
481
+ blb = measurements.label(blb)[0]
482
+ blb = remove_small_objects(blb, min_size=10)
483
+ blb[blb > 0] = 1 # background is 0 already
484
+
485
+ h_dir = rescale_intensity(h_dir_raw, out_range=(0, 1)).astype('float32')
486
+ v_dir = rescale_intensity(v_dir_raw, out_range=(0, 1)).astype('float32')
487
+
488
+ sobelh = sobel_v(h_dir).astype('float64')
489
+ sobelv = sobel_h(v_dir).astype('float64')
490
+
491
+ sobelh = 1 - rescale_intensity(sobelh, out_range=(0, 1)).astype('float32')
492
+ sobelv = 1 - rescale_intensity(sobelv, out_range=(0, 1)).astype('float32')
493
+
494
+ overall = np.maximum(sobelh, sobelv)
495
+ overall = overall - (1 - blb)
496
+ overall[overall < 0] = 0
497
+
498
+ dist = (1.0 - overall) * blb
499
+ ## nuclei values form mountains so inverse to get basins
500
+ dist = - gaussian(dist, sigma=0.8)
501
+
502
+ overall = np.array(overall >= overall_thres, dtype=np.int32)
503
+
504
+ marker = blb - overall
505
+ marker[marker < 0] = 0
506
+ marker = binary_fill_holes(marker).astype("uint8")
507
+ kernel = disk(2)
508
+ marker = binary_opening(marker, kernel)
509
+ marker = measurements.label(marker)[0]
510
+ marker = remove_small_objects(marker, min_size=obj_size_thres)
511
+
512
+ proced_pred = watershed(dist, markers=marker, mask=blb)
513
+
514
+ return proced_pred
515
+
516
+
517
+ ####
518
+ def colorize(ch, vmin, vmax):
519
+ """Will clamp value value outside the provided range to vmax and vmin."""
520
+ cmap = plt.get_cmap("jet")
521
+ ch = np.squeeze(ch.astype("float32"))
522
+ vmin = vmin if vmin is not None else ch.min()
523
+ vmax = vmax if vmax is not None else ch.max()
524
+ ch[ch > vmax] = vmax # clamp value
525
+ ch[ch < vmin] = vmin
526
+ ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
527
+ # take RGB from RGBA heat map
528
+ ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
529
+ return ch_cmap
530
+
531
+
532
+ ####
533
+ def random_colors(N, bright=True):
534
+ """Generate random colors.
535
+
536
+ To get visually distinct colors, generate them in HSV space then
537
+ convert to RGB.
538
+ """
539
+ brightness = 1.0 if bright else 0.7
540
+ hsv = [(i / N, 1, brightness) for i in range(N)]
541
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
542
+ random.shuffle(colors)
543
+ return colors
544
+
545
+
546
+ ####
547
+ def visualize_instances_map(
548
+ input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
549
+ ):
550
+ """Overlays segmentation results on image as contours.
551
+
552
+ Args:
553
+ input_image: input image
554
+ inst_map: instance mask with unique value for every object
555
+ type_map: type mask with unique value for every class
556
+ type_colour: a dict of {type : colour} , `type` is from 0-N
557
+ and `colour` is a tuple of (R, G, B)
558
+ line_thickness: line thickness of contours
559
+
560
+ Returns:
561
+ overlay: output image with segmentation overlay as contours
562
+ """
563
+ overlay = np.copy((input_image).astype(np.uint8))
564
+
565
+ inst_list = list(np.unique(inst_map)) # get list of instances
566
+ inst_list.remove(0) # remove background
567
+
568
+ inst_rng_colors = random_colors(len(inst_list))
569
+ inst_rng_colors = np.array(inst_rng_colors) * 255
570
+ inst_rng_colors = inst_rng_colors.astype(np.uint8)
571
+
572
+ for inst_idx, inst_id in enumerate(inst_list):
573
+ inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object
574
+ y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
575
+ y1 = y1 - 2 if y1 - 2 >= 0 else y1
576
+ x1 = x1 - 2 if x1 - 2 >= 0 else x1
577
+ x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
578
+ y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
579
+ inst_map_crop = inst_map_mask[y1:y2, x1:x2]
580
+ contours_crop = cv2.findContours(
581
+ inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
582
+ )
583
+ # only has 1 instance per map, no need to check #contour detected by opencv
584
+ contours_crop = np.squeeze(
585
+ contours_crop[0][0].astype("int32")
586
+ ) # * opencv protocol format may break
587
+ contours_crop += np.asarray([[x1, y1]]) # index correction
588
+ if type_map is not None:
589
+ type_map_crop = type_map[y1:y2, x1:x2]
590
+ type_id = np.unique(type_map_crop).max() # non-zero
591
+ inst_colour = type_colour[type_id]
592
+ else:
593
+ inst_colour = (inst_rng_colors[inst_idx]).tolist()
594
+ cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
595
+ return overlay
596
+
597
+
598
+ def sliding_window_inference(
599
+ inputs: torch.Tensor,
600
+ roi_size: Union[Sequence[int], int],
601
+ sw_batch_size: int,
602
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
603
+ overlap: float = 0.25,
604
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
605
+ sigma_scale: Union[Sequence[float], float] = 0.125,
606
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
607
+ cval: float = 0.0,
608
+ sw_device: Union[torch.device, str, None] = None,
609
+ device: Union[torch.device, str, None] = None,
610
+ progress: bool = False,
611
+ roi_weight_map: Union[torch.Tensor, None] = None,
612
+ *args: Any,
613
+ **kwargs: Any,
614
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
615
+ """
616
+ Sliding window inference on `inputs` with `predictor`.
617
+
618
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
619
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
620
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
621
+ could be ([128,64,256], [64,32,128]).
622
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
623
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
624
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
625
+
626
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
627
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
628
+
629
+ Args:
630
+ inputs: input image to be processed (assuming NCHW[D])
631
+ roi_size: the spatial window size for inferences.
632
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
633
+ if the components of the `roi_size` are non-positive values, the transform will use the
634
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
635
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
636
+ sw_batch_size: the batch size to run window slices.
637
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
638
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
639
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
640
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
641
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
642
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
643
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
644
+ to ensure the scaled output ROI sizes are still integers.
645
+ If the `predictor`'s input and output spatial sizes are different,
646
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
647
+ overlap: Amount of overlap between scans.
648
+ mode: {``"constant"``, ``"gaussian"``}
649
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
650
+
651
+ - ``"constant``": gives equal weight to all predictions.
652
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
653
+
654
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
655
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
656
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
657
+ spatial dimensions.
658
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
659
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
660
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
661
+ cval: fill value for 'constant' padding mode. Default: 0
662
+ sw_device: device for the window data.
663
+ By default the device (and accordingly the memory) of the `inputs` is used.
664
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
665
+ device: device for the stitched output prediction.
666
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
667
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
668
+ `inputs` and `roi_size`. Output is on the `device`.
669
+ progress: whether to print a `tqdm` progress bar.
670
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
671
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
672
+ args: optional args to be passed to ``predictor``.
673
+ kwargs: optional keyword args to be passed to ``predictor``.
674
+
675
+ Note:
676
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
677
+
678
+ """
679
+ compute_dtype = inputs.dtype
680
+ num_spatial_dims = len(inputs.shape) - 2
681
+ if overlap < 0 or overlap >= 1:
682
+ raise ValueError("overlap must be >= 0 and < 1.")
683
+
684
+ # determine image spatial size and batch size
685
+ # Note: all input images must have the same image size and batch size
686
+ batch_size, _, *image_size_ = inputs.shape
687
+
688
+ if device is None:
689
+ device = inputs.device
690
+ if sw_device is None:
691
+ sw_device = inputs.device
692
+
693
+ roi_size = fall_back_tuple(roi_size, image_size_)
694
+ # in case that image size is smaller than roi size
695
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
696
+ pad_size = []
697
+ for k in range(len(inputs.shape) - 1, 1, -1):
698
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
699
+ half = diff // 2
700
+ pad_size.extend([half, diff - half])
701
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
702
+
703
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
704
+
705
+ # Store all slices in list
706
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
707
+ num_win = len(slices) # number of windows per image
708
+ total_slices = num_win * batch_size # total number of windows
709
+
710
+ # Create window-level importance map
711
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
712
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
713
+ importance_map = roi_weight_map
714
+ else:
715
+ try:
716
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
717
+ except BaseException as e:
718
+ raise RuntimeError(
719
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
720
+ ) from e
721
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
722
+ # handle non-positive weights
723
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
724
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
725
+
726
+ # Perform predictions
727
+ dict_key, output_image_list, count_map_list = None, [], []
728
+ _initialized_ss = -1
729
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
730
+
731
+ # for each patch
732
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
733
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
734
+ unravel_slice = [
735
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
736
+ for idx in slice_range
737
+ ]
738
+ window_data = torch.cat(
739
+ [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
740
+ ).to(sw_device)
741
+ seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
742
+
743
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
744
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
745
+ if isinstance(seg_prob_out, torch.Tensor):
746
+ seg_prob_tuple = (seg_prob_out,)
747
+ elif isinstance(seg_prob_out, Mapping):
748
+ if dict_key is None:
749
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
750
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
751
+ is_tensor_output = False
752
+ else:
753
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
754
+ is_tensor_output = False
755
+
756
+ # for each output in multi-output list
757
+ for ss, seg_prob in enumerate(seg_prob_tuple):
758
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
759
+
760
+ # compute zoom scale: out_roi_size/in_roi_size
761
+ zoom_scale = []
762
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
763
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
764
+ ):
765
+ _scale = out_w_i / float(in_w_i)
766
+ if not (img_s_i * _scale).is_integer():
767
+ warnings.warn(
768
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
769
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
770
+ )
771
+ zoom_scale.append(_scale)
772
+
773
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
774
+ # construct multi-resolution outputs
775
+ output_classes = seg_prob.shape[1]
776
+ output_shape = [batch_size, output_classes] + [
777
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
778
+ ]
779
+ # allocate memory to store the full output and the count for overlapping parts
780
+ output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device='cpu'))
781
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device='cpu'))
782
+ _initialized_ss += 1
783
+
784
+ # resizing the importance_map
785
+ resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
786
+
787
+ # store the result in the proper location of the full output. Apply weights from importance map.
788
+ for idx, original_idx in zip(slice_range, unravel_slice):
789
+ # zoom roi
790
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
791
+ for axis in range(2, len(original_idx_zoom)):
792
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
793
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
794
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
795
+ warnings.warn(
796
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
797
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
798
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
799
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
800
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
801
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
802
+ )
803
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
804
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
805
+ # store results and weights
806
+ #print(output_image_list[ss][original_idx_zoom].device,importance_map_zoom.cpu().device,seg_prob.cpu().device)
807
+ output_image_list[ss][original_idx_zoom] += importance_map_zoom.cpu() * seg_prob[idx - slice_g].cpu()
808
+ count_map_list[ss][original_idx_zoom] += (
809
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape).cpu()
810
+ )
811
+
812
+ # account for any overlapping sections
813
+ for ss in range(len(output_image_list)):
814
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
815
+
816
+ # remove padding if image_size smaller than roi_size
817
+ for ss, output_i in enumerate(output_image_list):
818
+ if torch.isnan(output_i).any() or torch.isinf(output_i).any():
819
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
820
+
821
+ zoom_scale = [
822
+ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
823
+ ]
824
+
825
+ final_slicing: List[slice] = []
826
+ for sp in range(num_spatial_dims):
827
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
828
+ slice_dim = slice(
829
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
830
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
831
+ )
832
+ final_slicing.insert(0, slice_dim)
833
+ while len(final_slicing) < len(output_i.shape):
834
+ final_slicing.insert(0, slice(None))
835
+ output_image_list[ss] = output_i[final_slicing]
836
+
837
+ if dict_key is not None: # if output of predictor is a dict
838
+ final_output = dict(zip(dict_key, output_image_list))
839
+ else:
840
+ final_output = tuple(output_image_list) # type: ignore
841
+ final_output = final_output[0] if is_tensor_output else final_output # type: ignore
842
+ if isinstance(inputs, MetaTensor):
843
+ final_output = convert_to_dst_type(final_output, inputs)[0] # type: ignore
844
+ return final_output
845
+
846
+
847
+ def _get_scan_interval(
848
+ image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
849
+ ) -> Tuple[int, ...]:
850
+ """
851
+ Compute scan interval according to the image size, roi size and overlap.
852
+ Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
853
+ use 1 instead to make sure sliding window works.
854
+
855
+ """
856
+ if len(image_size) != num_spatial_dims:
857
+ raise ValueError("image coord different from spatial dims.")
858
+ if len(roi_size) != num_spatial_dims:
859
+ raise ValueError("roi coord different from spatial dims.")
860
+
861
+ scan_interval = []
862
+ for i in range(num_spatial_dims):
863
+ if roi_size[i] == image_size[i]:
864
+ scan_interval.append(int(roi_size[i]))
865
+ else:
866
+ interval = int(roi_size[i] * (1 - overlap))
867
+ scan_interval.append(interval if interval > 0 else 1)
868
+ return tuple(scan_interval)
utils_modify.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import warnings
13
+ from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from stardist.big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
18
+ from stardist.matching import relabel_sequential
19
+ from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
20
+ from stardist import random_label_cmap,ray_angles
21
+ from stardist import star_dist,edt_prob
22
+ from monai.data.meta_tensor import MetaTensor
23
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
24
+ from monai.transforms import Resize
25
+ from monai.utils import (
26
+ BlendMode,
27
+ PytorchPadMode,
28
+ convert_data_type,
29
+ convert_to_dst_type,
30
+ ensure_tuple,
31
+ fall_back_tuple,
32
+ look_up_option,
33
+ optional_import,
34
+ )
35
+
36
+ tqdm, _ = optional_import("tqdm", name="tqdm")
37
+
38
+ __all__ = ["sliding_window_inference"]
39
+
40
+
41
+ def sliding_window_inference_large(inputs,block_size,min_overlap,context,roi_size,sw_batch_size,predictor,device):
42
+
43
+ h,w = inputs.shape[0],inputs.shape[1]
44
+ if h < 5000 or w < 5000:
45
+ test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
46
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
47
+ prob = output_prob[0][0].cpu().numpy()
48
+ dist = output_dist[0].cpu().numpy()
49
+ dist = np.transpose(dist,(1,2,0))
50
+ dist = np.maximum(1e-3, dist)
51
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
52
+
53
+ coord = dist_to_coord(disti,points)
54
+
55
+ labels_out = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
56
+ else:
57
+ n = inputs.ndim
58
+ axes = 'YXC'
59
+ grid = (1,1,1)
60
+ if np.isscalar(block_size): block_size = n*[block_size]
61
+ if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
62
+ if np.isscalar(context): context = n*[context]
63
+ shape_out = (inputs.shape[0],inputs.shape[1])
64
+ labels_out = np.zeros(shape_out, dtype=np.uint64)
65
+ #print(inputs.dtype)
66
+ block_size[2] = inputs.shape[2]
67
+ min_overlap[2] = context[2] = 0
68
+ block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v,g,a in zip(block_size, grid,axes))
69
+ min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
70
+ context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v,g,a in zip(context, grid,axes))
71
+ print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
72
+ blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
73
+ label_offset = 1
74
+ blocks = tqdm(blocks)
75
+ for block in blocks:
76
+ image = block.read(inputs, axes=axes)
77
+ test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
78
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
79
+ prob = output_prob[0][0].cpu().numpy()
80
+ dist = output_dist[0].cpu().numpy()
81
+ dist = np.transpose(dist,(1,2,0))
82
+ dist = np.maximum(1e-3, dist)
83
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
84
+
85
+ coord = dist_to_coord(disti,points)
86
+ polys = dict(coord=coord, points=points, prob=probi)
87
+ labels = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
88
+ labels = block.crop_context(labels, axes='YX')
89
+ labels, polys = block.filter_objects(labels, polys, axes='YX')
90
+ labels = relabel_sequential(labels, label_offset)[0]
91
+ if labels_out is not None:
92
+ block.write(labels_out, labels, axes='YX')
93
+ #for k,v in polys.items():
94
+ #polys_all.setdefault(k,[]).append(v)
95
+ label_offset += len(polys['prob'])
96
+ del labels
97
+ #polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}
98
+ return labels_out
99
+ def sliding_window_inference(
100
+ inputs: torch.Tensor,
101
+ roi_size: Union[Sequence[int], int],
102
+ sw_batch_size: int,
103
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
104
+ overlap: float = 0.25,
105
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
106
+ sigma_scale: Union[Sequence[float], float] = 0.125,
107
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
108
+ cval: float = 0.0,
109
+ sw_device: Union[torch.device, str, None] = None,
110
+ device: Union[torch.device, str, None] = None,
111
+ progress: bool = False,
112
+ roi_weight_map: Union[torch.Tensor, None] = None,
113
+ *args: Any,
114
+ **kwargs: Any,
115
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
116
+ """
117
+ Sliding window inference on `inputs` with `predictor`.
118
+
119
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
120
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
121
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
122
+ could be ([128,64,256], [64,32,128]).
123
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
124
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
125
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
126
+
127
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
128
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
129
+
130
+ Args:
131
+ inputs: input image to be processed (assuming NCHW[D])
132
+ roi_size: the spatial window size for inferences.
133
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
134
+ if the components of the `roi_size` are non-positive values, the transform will use the
135
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
136
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
137
+ sw_batch_size: the batch size to run window slices.
138
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
139
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
140
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
141
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
142
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
143
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
144
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
145
+ to ensure the scaled output ROI sizes are still integers.
146
+ If the `predictor`'s input and output spatial sizes are different,
147
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
148
+ overlap: Amount of overlap between scans.
149
+ mode: {``"constant"``, ``"gaussian"``}
150
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
151
+
152
+ - ``"constant``": gives equal weight to all predictions.
153
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
154
+
155
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
156
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
157
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
158
+ spatial dimensions.
159
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
160
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
161
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
162
+ cval: fill value for 'constant' padding mode. Default: 0
163
+ sw_device: device for the window data.
164
+ By default the device (and accordingly the memory) of the `inputs` is used.
165
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
166
+ device: device for the stitched output prediction.
167
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
168
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
169
+ `inputs` and `roi_size`. Output is on the `device`.
170
+ progress: whether to print a `tqdm` progress bar.
171
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
172
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
173
+ args: optional args to be passed to ``predictor``.
174
+ kwargs: optional keyword args to be passed to ``predictor``.
175
+
176
+ Note:
177
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
178
+
179
+ """
180
+ compute_dtype = inputs.dtype
181
+ num_spatial_dims = len(inputs.shape) - 2
182
+ if overlap < 0 or overlap >= 1:
183
+ raise ValueError("overlap must be >= 0 and < 1.")
184
+
185
+ # determine image spatial size and batch size
186
+ # Note: all input images must have the same image size and batch size
187
+ batch_size, _, *image_size_ = inputs.shape
188
+
189
+ if device is None:
190
+ device = inputs.device
191
+ if sw_device is None:
192
+ sw_device = inputs.device
193
+
194
+ roi_size = fall_back_tuple(roi_size, image_size_)
195
+ # in case that image size is smaller than roi size
196
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
197
+ pad_size = []
198
+ for k in range(len(inputs.shape) - 1, 1, -1):
199
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
200
+ half = diff // 2
201
+ pad_size.extend([half, diff - half])
202
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
203
+
204
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
205
+
206
+ # Store all slices in list
207
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
208
+ num_win = len(slices) # number of windows per image
209
+ total_slices = num_win * batch_size # total number of windows
210
+
211
+ # Create window-level importance map
212
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
213
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
214
+ importance_map = roi_weight_map
215
+ else:
216
+ try:
217
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
218
+ except BaseException as e:
219
+ raise RuntimeError(
220
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
221
+ ) from e
222
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
223
+ # handle non-positive weights
224
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
225
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
226
+
227
+ # Perform predictions
228
+ dict_key, output_image_list, count_map_list = None, [], []
229
+ _initialized_ss = -1
230
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
231
+
232
+ # for each patch
233
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
234
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
235
+ unravel_slice = [
236
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
237
+ for idx in slice_range
238
+ ]
239
+ window_data = torch.cat(
240
+ [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
241
+ ).to(sw_device)
242
+ seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
243
+
244
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
245
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
246
+ if isinstance(seg_prob_out, torch.Tensor):
247
+ seg_prob_tuple = (seg_prob_out,)
248
+ elif isinstance(seg_prob_out, Mapping):
249
+ if dict_key is None:
250
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
251
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
252
+ is_tensor_output = False
253
+ else:
254
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
255
+ is_tensor_output = False
256
+
257
+ # for each output in multi-output list
258
+ for ss, seg_prob in enumerate(seg_prob_tuple):
259
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
260
+
261
+ # compute zoom scale: out_roi_size/in_roi_size
262
+ zoom_scale = []
263
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
264
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
265
+ ):
266
+ _scale = out_w_i / float(in_w_i)
267
+ if not (img_s_i * _scale).is_integer():
268
+ warnings.warn(
269
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
270
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
271
+ )
272
+ zoom_scale.append(_scale)
273
+
274
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
275
+ # construct multi-resolution outputs
276
+ output_classes = seg_prob.shape[1]
277
+ output_shape = [batch_size, output_classes] + [
278
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
279
+ ]
280
+ # allocate memory to store the full output and the count for overlapping parts
281
+ output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
282
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
283
+ _initialized_ss += 1
284
+
285
+ # resizing the importance_map
286
+ resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
287
+
288
+ # store the result in the proper location of the full output. Apply weights from importance map.
289
+ for idx, original_idx in zip(slice_range, unravel_slice):
290
+ # zoom roi
291
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
292
+ for axis in range(2, len(original_idx_zoom)):
293
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
294
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
295
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
296
+ warnings.warn(
297
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
298
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
299
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
300
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
301
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
302
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
303
+ )
304
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
305
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
306
+ # store results and weights
307
+ output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
308
+ count_map_list[ss][original_idx_zoom] += (
309
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
310
+ )
311
+
312
+ # account for any overlapping sections
313
+ for ss in range(len(output_image_list)):
314
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
315
+
316
+ # remove padding if image_size smaller than roi_size
317
+ for ss, output_i in enumerate(output_image_list):
318
+ if torch.isnan(output_i).any() or torch.isinf(output_i).any():
319
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
320
+
321
+ zoom_scale = [
322
+ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
323
+ ]
324
+
325
+ final_slicing: List[slice] = []
326
+ for sp in range(num_spatial_dims):
327
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
328
+ slice_dim = slice(
329
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
330
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
331
+ )
332
+ final_slicing.insert(0, slice_dim)
333
+ while len(final_slicing) < len(output_i.shape):
334
+ final_slicing.insert(0, slice(None))
335
+ output_image_list[ss] = output_i[final_slicing]
336
+
337
+ if dict_key is not None: # if output of predictor is a dict
338
+ final_output = dict(zip(dict_key, output_image_list))
339
+ else:
340
+ final_output = tuple(output_image_list) # type: ignore
341
+ final_output = final_output[0] if is_tensor_output else final_output
342
+
343
+ if isinstance(inputs, MetaTensor):
344
+ final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore
345
+ return final_output
346
+
347
+
348
+ def _get_scan_interval(
349
+ image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
350
+ ) -> Tuple[int, ...]:
351
+ """
352
+ Compute scan interval according to the image size, roi size and overlap.
353
+ Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
354
+ use 1 instead to make sure sliding window works.
355
+
356
+ """
357
+ if len(image_size) != num_spatial_dims:
358
+ raise ValueError("image coord different from spatial dims.")
359
+ if len(roi_size) != num_spatial_dims:
360
+ raise ValueError("roi coord different from spatial dims.")
361
+
362
+ scan_interval = []
363
+ for i in range(num_spatial_dims):
364
+ if roi_size[i] == image_size[i]:
365
+ scan_interval.append(int(roi_size[i]))
366
+ else:
367
+ interval = int(roi_size[i] * (1 - overlap))
368
+ scan_interval.append(interval if interval > 0 else 1)
369
+ return tuple(scan_interval)