#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fine-tune a ConvNeXt-backed FlexibleUNet on StarDist-style targets.

Adapted from MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch

The network returns two maps per image (ray distances and object
probability). Training targets are built on the fly from instance label
images with ``stardist.star_dist`` and ``stardist.edt_prob``; validation
turns the predictions back into instance labels with StarDist's
non-maximum suppression and scores them with an F1 metric.
"""

import argparse
import os

join = os.path.join
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from stardist import star_dist, edt_prob
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
from stardist import random_label_cmap, ray_angles
import monai
from collections import OrderedDict
from compute_metric import eval_tp_fp_fn, remove_boundary_cells
from monai.data import decollate_batch, PILReader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    SpatialPadd,
    RandSpatialCropd,
    RandRotate90d,
    ScaleIntensityd,
    RandAxisFlipd,
    RandZoomd,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandHistogramShiftd,
    EnsureTyped,
    EnsureType,
)
from monai.visualize import plot_2d_or_3d_image
import matplotlib.pyplot as plt
from datetime import datetime
import shutil
import tqdm
from models.unetr2d import UNETR2D
from models.swin_unetr import SwinUNETR
from models.flexible_unet import FlexibleUNet
from models.flexible_unet_convext import FlexibleUNetConvext

print("Successfully imported all requirements!")
torch.backends.cudnn.enabled = False
# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"


def _parse_args():
    """Build the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
    # Dataset parameters
    parser.add_argument(
        "--data_path",
        default="",
        type=str,
        help="training data path; subfolders: images, labels",
    )
    parser.add_argument(
        "--work_dir",
        default="/mntnfs/med_data5/louwei/nips_comp/stardist_finetune1/",
        help="path where to save models and logs",
    )
    parser.add_argument(
        "--model_dir", default="/", help="path where to load pretrained model"
    )
    parser.add_argument("--seed", default=2022, type=int)
    parser.add_argument("--num_workers", default=4, type=int)

    # Model parameters
    parser.add_argument(
        "--model_name",
        default="efficientunet",
        help="model to train; only 'efficientunet' is supported",
    )
    parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
    parser.add_argument(
        "--input_size",
        default=512,
        type=int,
        help="spatial size of the padded/cropped training patches",
    )

    # Training parameters
    parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
    parser.add_argument("--max_epochs", default=2000, type=int)
    parser.add_argument("--val_interval", default=10, type=int)
    parser.add_argument(
        "--val_start_epoch",
        default=40,
        type=int,
        help="first epoch at which checkpoints are saved and validation is run",
    )
    parser.add_argument("--epoch_tolerance", default=100, type=int)
    parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
    return parser.parse_args()


def _list_image_label_pairs(img_dir, gt_dir):
    """Pair every file in ``img_dir`` with ``<stem>.tif`` in ``gt_dir``.

    The stem is everything before the first ``.`` of the image file name.
    """
    img_names = sorted(os.listdir(img_dir))
    return [
        {
            "img": join(img_dir, img_name),
            "label": join(gt_dir, img_name.split(".")[0] + ".tif"),
        }
        for img_name in img_names
    ]


def _build_transforms(input_size):
    """Return the ``(train_transforms, val_transforms)`` MONAI pipelines."""
    train_transforms = Compose(
        [
            LoadImaged(
                keys=["img", "label"], reader=PILReader, dtype=np.float32
            ),  # image three channels (H, W, 3); label: (H, W)
            AddChanneld(keys=["label"], allow_missing_keys=True),  # label: (1, H, W)
            AsChannelFirstd(
                keys=["img"], channel_dim=-1, allow_missing_keys=True
            ),  # image: (3, H, W)
            # Intensities are deliberately left unscaled (no ScaleIntensityd).
            SpatialPadd(keys=["img", "label"], spatial_size=input_size),
            RandSpatialCropd(
                keys=["img", "label"], roi_size=input_size, random_size=False
            ),
            RandAxisFlipd(keys=["img", "label"], prob=0.5),
            RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
            # intensity transforms (image only)
            RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
            RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
            RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
            RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
            RandZoomd(
                keys=["img", "label"],
                prob=0.15,
                min_zoom=0.5,
                max_zoom=2,
                mode=["area", "nearest"],
            ),
            EnsureTyped(keys=["img", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
            EnsureTyped(keys=["img", "label"]),
        ]
    )
    return train_transforms, val_transforms


def _encode_stardist_targets(labels, n_rays):
    """Convert a batch of instance label maps into stacked StarDist targets.

    For every sample, the ``n_rays`` distance maps from ``star_dist`` are
    moved to the leading axis and the ``edt_prob`` map is appended as the
    last channel; the per-sample arrays are then stacked along a new batch
    axis.

    NOTE(review): ``labels`` must expose ``.astype`` on its per-sample
    slices (as used below) -- confirm this holds for the batch type produced
    by the installed MONAI version.
    """
    targets = []
    for i in range(labels.shape[0]):
        label = labels[i][0]
        distances = np.transpose(star_dist(label, n_rays, mode="opencl"), (2, 0, 1))
        obj_probabilities = np.expand_dims(edt_prob(label.astype(int)), 0)
        targets.append(np.concatenate((distances, obj_probabilities), axis=0))
    return np.stack(targets)


def _f1_score(tp, fp, fn):
    """Return the F1 score rounded to 4 decimals (0 when there are no TPs)."""
    if tp == 0:
        f1 = 0
    else:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * (precision * recall) / (precision + recall)
    return np.round(f1, 4)


def _validate(model, val_loader, device):
    """Run sliding-window inference on ``val_loader`` and return the mean F1.

    Predictions are converted to instance labels with StarDist NMS
    (``prob_thresh=0.5``, ``nms_thresh=0.4``); cells touching the border are
    removed from both prediction and ground truth before matching at an
    overlap threshold of 0.5. Only the first sample of each batch is scored,
    which matches the validation loader's batch size of 1.
    """
    roi_size = (512, 512)
    sw_batch_size = 4
    seg_metric = OrderedDict()
    seg_metric["F1_Score"] = []

    model.eval()
    with torch.no_grad():
        for val_data in tqdm.tqdm(val_loader):
            val_images = val_data["img"].to(device)
            # The label is only needed as a NumPy array, so it stays on the CPU.
            val_labels = val_data["label"][0][0].cpu().numpy()

            output_dist, output_prob = sliding_window_inference(
                val_images, roi_size, sw_batch_size, model
            )
            prob = output_prob[0][0].cpu().numpy()
            dist = np.transpose(output_dist[0].cpu().numpy(), (1, 2, 0))
            # Clamp to a small positive value before handing distances to NMS.
            dist = np.maximum(1e-3, dist)

            points, probi, disti = non_maximum_suppression(
                dist, prob, prob_thresh=0.5, nms_thresh=0.4
            )
            star_label = polygons_to_label(disti, points, prob=probi, shape=prob.shape)

            gt = remove_boundary_cells(val_labels.astype(np.int32))
            seg = remove_boundary_cells(star_label.astype(np.int32))
            tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
            seg_metric["F1_Score"].append(_f1_score(tp, fp, fn))

    return np.mean(seg_metric["F1_Score"])


def main():
    """Fine-tune the StarDist-style model and keep the best-F1 checkpoint.

    Reads ``train/{images,tif}`` and ``valid/{images,tif}`` below
    ``--data_path``, loads the pretrained weights given by ``--model_dir``,
    and writes checkpoints, TensorBoard events and ``train_log.npz`` into
    ``<work_dir>/<model_name>_3class``.

    Raises:
        ValueError: if ``--model_name`` is not a supported model.
        FileNotFoundError: if the pretrained checkpoint file does not exist.
    """
    args = _parse_args()

    monai.config.print_config()
    n_rays = 32
    pre_trained = True

    # %% set up output directory and data file lists
    np.random.seed(args.seed)
    pre_trained_path = args.model_dir
    model_path = join(args.work_dir, args.model_name + "_3class")
    os.makedirs(model_path, exist_ok=True)

    # NOTE: update this path whenever a different model definition is used,
    # so the snapshot stored next to the checkpoints matches the trained net.
    model_file = "models/flexible_unet_convext.py"
    shutil.copyfile(__file__, join(model_path, os.path.basename(__file__)))
    shutil.copyfile(model_file, join(model_path, os.path.basename(model_file)))

    train_files = _list_image_label_pairs(
        join(args.data_path, "train/images"), join(args.data_path, "train/tif")
    )
    val_files = _list_image_label_pairs(
        join(args.data_path, "valid/images"), join(args.data_path, "valid/tif")
    )
    print(
        f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
    )

    # %% define transforms for image and segmentation
    train_transforms, val_transforms = _build_transforms(args.input_size)

    # %% define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
    check_data = monai.utils.misc.first(check_loader)
    print(
        "sanity check:",
        check_data["img"].shape,
        torch.max(check_data["img"]),
        check_data["label"].shape,
        torch.max(check_data["label"]),
    )

    # %% create the training and validation data loaders
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)

    # %% model, losses and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_name.lower() == "efficientunet":
        model = FlexibleUNetConvext(
            in_channels=3,
            out_channels=n_rays + 1,
            backbone="convnext_small",
            pretrained=False,
        ).to(device)
    else:
        # Previously an unsupported name surfaced later as a NameError on `model`.
        raise ValueError(
            f"Unsupported --model_name {args.model_name!r}; expected 'efficientunet'"
        )

    loss_dice = monai.losses.DiceLoss(squared_pred=True, jaccard=True)
    loss_bce = nn.BCELoss()
    loss_dist_mae = nn.L1Loss()

    # The encoder is trained with a 10x smaller learning rate than the rest.
    initial_lr = args.initial_lr
    encoder_param_ids = {id(p) for p in model.encoder.parameters()}
    base_params = [p for p in model.parameters() if id(p) not in encoder_param_ids]
    params = [
        {"params": base_params, "lr": initial_lr},
        {"params": model.encoder.parameters(), "lr": initial_lr * 0.1},
    ]
    optimizer = torch.optim.AdamW(params, initial_lr)

    if pre_trained:
        if not os.path.isfile(pre_trained_path):
            raise FileNotFoundError(
                f"Pretrained checkpoint not found: {pre_trained_path!r} "
                "(pass a checkpoint file via --model_dir)"
            )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(pre_trained_path, map_location=torch.device(device))
        model.load_state_dict(checkpoint["model_state_dict"])
        print("Load pretrained weights...")

    max_epochs = args.max_epochs
    val_interval = args.val_interval
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(model_path)
    max_f1 = 0

    try:
        for epoch in range(0, max_epochs):
            model.train()
            epoch_loss = 0
            epoch_loss_prob = 0
            epoch_loss_dist_2 = 0
            epoch_loss_dist_1 = 0
            step = 0
            for step, batch_data in enumerate(train_loader, 1):
                inputs, labels = batch_data["img"], batch_data["label"]
                labels = _encode_stardist_targets(labels, n_rays)
                inputs = torch.tensor(inputs).to(device)
                labels = torch.tensor(labels).to(device)

                optimizer.zero_grad()
                dist_output, prob_output = model(inputs)
                dist_label = labels[:, :n_rays, :, :]
                prob_label = torch.unsqueeze(labels[:, -1, :, :], 1)

                # Distance losses are weighted by the target probability map.
                loss_dist_1 = loss_dice(
                    dist_output * prob_label, dist_label * prob_label
                )
                loss_prob = loss_bce(prob_output, prob_label)
                loss_dist_2 = loss_dist_mae(
                    dist_output * prob_label, dist_label * prob_label
                )
                loss = loss_prob + loss_dist_2 * 0.3 + loss_dist_1
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_loss_prob += loss_prob.item()
                epoch_loss_dist_2 += loss_dist_2.item()
                epoch_loss_dist_1 += loss_dist_1.item()

            # Guard against an empty training loader (step stays 0).
            num_steps = max(step, 1)
            epoch_loss /= num_steps
            epoch_loss_prob /= num_steps
            epoch_loss_dist_2 /= num_steps
            epoch_loss_dist_1 /= num_steps
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
            writer.add_scalar("train_loss", epoch_loss, epoch)
            print(
                "dist dice: "
                + str(epoch_loss_dist_1)
                + " dist mae: "
                + str(epoch_loss_dist_2)
                + " prob bce: "
                + str(epoch_loss_prob)
            )

            # Checkpointing/validation only from --val_start_epoch onwards,
            # and then only every `val_interval` epochs.
            if epoch < args.val_start_epoch:
                continue
            if not (epoch > 1 and epoch % val_interval == 0):
                continue

            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss_values,
            }
            torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))

            avg_f1 = _validate(model, val_loader, device)
            # Record the validation curve so train_log.npz is not left empty.
            metric_values.append(avg_f1)
            writer.add_scalar("val_f1score", avg_f1, epoch)
            if avg_f1 > max_f1:
                max_f1 = avg_f1
                print(str(epoch) + "f1 score: " + str(max_f1))
                torch.save(checkpoint, join(model_path, "best_model.pth"))
    finally:
        # Flush and release the TensorBoard event file even on failure.
        writer.close()

    # The "val_dice" key is kept for compatibility with existing log readers;
    # it holds the mean validation F1 per validation run.
    np.savez_compressed(
        join(model_path, "train_log.npz"),
        val_dice=metric_values,
        epoch_loss=epoch_loss_values,
    )


if __name__ == "__main__":
    main()