import pickle import torch from torchvision import models import random import logging import numpy as np import json def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True def set_logger(log_path): logger = logging.getLogger() logger.setLevel(logging.INFO) if not logger.handlers: # Logging to a file file_handler = logging.FileHandler(log_path) file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) logger.addHandler(file_handler) # Logging to console stream_handler = logging.StreamHandler() stream_handler.setFormatter(logging.Formatter('%(message)s')) logger.addHandler(stream_handler) def to_np(x): return x.data.cpu().numpy() def get_ids(length_dataset): ids = list(range(length_dataset)) random.shuffle(ids) train_split = round(0.6 * length_dataset) t_v_spplit = (length_dataset - train_split) // 2 train_ids = ids[:train_split] valid_ids = ids[train_split:train_split+t_v_spplit] test_ids = ids[train_split+t_v_spplit:] return train_ids, valid_ids, test_ids def dice_score(y, y_pred, smooth=1.0, thres=0.9): n = y.shape[0] y = y.view(n, -1) y_pred = y_pred.view(n, -1) # y_pred_[y_pred>=thres] = 1.0 # y_pred_[y_pred=thres] = 1.0 y_pred[y_pred