|
# Public API of this module. Every name listed here must be defined in this
# file, otherwise `from <module> import *` raises AttributeError.
__all__ = ['ImagingModel']
|
|
|
import wandb |
|
import torch |
|
import numpy as np |
|
import functools |
|
import inspect |
|
import monai |
|
import random |
|
|
|
from tqdm import tqdm |
|
from functools import wraps |
|
from sklearn.base import BaseEstimator |
|
from sklearn.utils.validation import check_is_fitted |
|
from sklearn.model_selection import train_test_split |
|
from scipy.special import expit |
|
from copy import deepcopy |
|
from contextlib import suppress |
|
from typing import Any, Self, Type |
|
Tensor = Type[torch.Tensor] |
|
Module = Type[torch.nn.Module] |
|
from torch.utils.data import DataLoader |
|
from monai.utils.type_conversion import convert_to_tensor |
|
from monai.transforms import ( |
|
LoadImaged, |
|
Compose, |
|
CropForegroundd, |
|
CopyItemsd, |
|
SpatialPadd, |
|
EnsureChannelFirstd, |
|
Spacingd, |
|
OneOf, |
|
ScaleIntensityRanged, |
|
HistogramNormalized, |
|
RandSpatialCropSamplesd, |
|
RandSpatialCropd, |
|
CenterSpatialCropd, |
|
RandCoarseDropoutd, |
|
RandCoarseShuffled, |
|
Resized, |
|
) |
|
|
|
|
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
from .. import nn |
|
from ..utils.misc import ProgressBar |
|
from ..utils.misc import get_metrics_multitask, print_metrics_multitask |
|
from ..utils.misc import convert_args_kwargs_to_kwargs |
|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
def _manage_ctx_fit(func):
    """Decorator that temporarily pins an estimator to its first managed device.

    The wrapped callable must receive the estimator as its ``self`` argument.
    Call arguments are normalised to keywords with
    ``convert_args_kwargs_to_kwargs`` so the estimator can be looked up by name.

    * If ``self._device_ids`` is ``None`` the call is forwarded unchanged.
    * Otherwise ``self.device`` is set to ``self._device_ids[0]`` for the
      duration of the call, and ``self.to(<previous device>)`` is invoked
      afterwards -- also when ``func`` raises, so a failed call does not leave
      the estimator pointing at the temporary device.

    Returns:
        The wrapped function; it returns whatever ``func`` returns.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        call_kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
        estimator = call_kwargs['self']

        if estimator._device_ids is None:
            return func(**call_kwargs)

        previous_device = estimator.device
        estimator.device = estimator._device_ids[0]
        try:
            return func(**call_kwargs)
        finally:
            # Restore the caller-visible device no matter how func exits.
            estimator.to(previous_device)

    return wrapper
|
|
|
def collate_handle_corrupted(samples_list, dataset, labels, dtype=torch.half, max_retries=100):
    """Collate a batch while tolerating unusable samples.

    Samples that are ``None`` are dropped, and (when the samples carry an
    ``"image"`` entry) so are samples whose image contains NaN. If nothing
    usable is left, or collating raises, a replacement batch of the same size
    is drawn uniformly at random from ``dataset`` and the procedure is retried.

    Args:
        samples_list: list of samples; each usable sample is a mapping with an
            ``"image"`` entry and a ``"label"`` mapping indexed by ``labels``.
        dataset: indexable collection (supports ``len`` and ``[]``) used to
            draw replacement samples.
        labels: iterable of label names to collate.
        dtype: accepted for backward compatibility; it is not used.
        max_retries: maximum number of replacement batches to draw before
            giving up.

    Returns:
        dict with
        ``"image"``: the stacked image tensors,
        ``"label"``: ``{name: tensor}`` with missing (``None``) labels stored as 0,
        ``"mask"``:  ``{name: tensor}`` holding 1 where a label was present, else 0.

    Raises:
        RuntimeError: if no batch could be collated within ``max_retries``
            redraws; the last collation error (if any) is chained as the cause.
    """
    batch_len = len(samples_list)
    last_error = None

    for attempt in range(max(0, max_retries) + 1):
        if attempt:
            # Previous candidates were unusable: redraw a batch of equal size.
            samples_list = [
                dataset[random.randint(0, len(dataset) - 1)]
                for _ in range(batch_len)
            ]

        usable = [s for s in samples_list if s is not None]
        if not usable:
            continue

        try:
            return _collate_samples(usable, labels)
        except Exception as err:  # best effort: malformed samples trigger a redraw
            last_error = err

    raise RuntimeError(
        f"could not assemble a valid batch after {max_retries} resampling attempts"
    ) from last_error


def _collate_samples(samples, labels):
    """Stack images and build per-label value/mask tensors for non-None samples."""
    if "image" in samples[0]:
        samples = [s for s in samples if not torch.isnan(s["image"]).any()]

    # torch.stack raises on an empty list, which the caller treats as "retry".
    images = torch.stack([convert_to_tensor(s["image"]) for s in samples])

    targets, masks = {}, {}
    for name in labels:
        values = [s["label"][name] for s in samples]
        # torch.Tensor(list) is the legacy constructor and yields float32.
        targets[name] = torch.Tensor([0 if v is None else v for v in values])
        masks[name] = torch.Tensor([0 if v is None else 1 for v in values])

    return {"image": images, "label": targets, "mask": masks}
|
|
|
|
|
|
|
def get_backend(img_backend):
    """Return the backbone class registered under ``img_backend``.

    Args:
        img_backend: ``'C3D'`` (-> ``nn.C3D``) or ``'DenseNet'`` (-> ``nn.DenseNet``).

    Returns:
        The matching class from the project's ``nn`` package.

    Raises:
        ValueError: for any other value, including ``None``. Callers invoke
            the result immediately, so returning ``None`` would only surface
            later as an opaque "'NoneType' object is not callable".
    """
    if img_backend == 'C3D':
        return nn.C3D
    if img_backend == 'DenseNet':
        return nn.DenseNet
    raise ValueError(
        f"Unsupported img_backend {img_backend!r}; expected 'C3D' or 'DenseNet'."
    )
|
|
|
|
|
class ImagingModel(BaseEstimator):
    """Scikit-learn style estimator that trains an imaging backbone on several targets.

    The backbone class is selected with ``get_backend(img_backend)`` and is
    trained with one sigmoid focal loss per entry of ``tgt_modalities``.
    """

    def __init__(self,
        tgt_modalities: list[str],
        label_fractions: dict[str, float],
        num_epochs: int = 32,
        batch_size: int = 8,
        batch_size_multiplier: int = 1,
        lr: float = 1e-2,
        weight_decay: float = 0.0,
        beta: float = 0.9999,
        gamma: float = 2.0,
        bn_size: int = 4,
        growth_rate: int = 12,
        block_config: tuple = (3, 3, 3),
        compression: float = 0.5,
        num_init_features: int = 16,
        drop_rate: float = 0.2,
        criterion: str | None = None,
        device: str = 'cpu',
        cuda_devices: list | None = None,
        ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/dev/ckpt/ckpt.pt',
        load_from_ckpt: bool = True,
        save_intermediate_ckpts: bool = False,
        data_parallel: bool = False,
        verbose: int = 0,
        img_backend: str | None = None,
        label_distribution: dict | None = None,
        wandb_: int = 1,
        _device_ids: list | None = None,
        _dataloader_num_workers: int = 4,
        _amp_enabled: bool = False,
    ) -> None:
        """Store the hyper-parameters; no network is built until ``fit``/``load``.

        Args:
            tgt_modalities: names of the prediction targets; used as keys of
                the network outputs, labels and masks.
            label_fractions: per-target fraction used in ``fit`` to choose the
                focal-loss ``alpha`` (``-1`` when >= 0.3, else ``(1 - f) ** 2``).
            num_epochs: upper bound of the epoch range run by ``fit``.
            batch_size: batch size of the training/validation loaders.
            batch_size_multiplier: number of batches whose gradients are
                accumulated per optimizer step.
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay.
            beta: forwarded to ``nn.SigmoidFocalLossBeta`` by ``_init_loss_func``.
            gamma: focal-loss ``gamma``.
            bn_size, growth_rate, block_config, compression,
            num_init_features, drop_rate: forwarded unchanged to the backbone
                constructor.
            criterion: validation metric name used to pick the best epoch;
                ``'Loss'`` is minimised, any other metric maximised, ``None``
                disables the selection.
            device: ``'cpu'`` or a CUDA device string; a bare ``'cuda'`` is
                expanded to ``'cuda:<cuda_devices[0]>'``.
            cuda_devices: CUDA device indices (also passed to ``DataParallel``).
                ``None`` means ``[1]``, the historical default.
            ckpt_path: checkpoint file used for loading and saving.
            load_from_ckpt: try to resume from ``ckpt_path`` in ``fit``.
            save_intermediate_ckpts: write a checkpoint whenever a validation
                criterion improves.
            data_parallel: wrap the network in ``torch.nn.DataParallel`` when
                more than one GPU is visible.
            verbose: 1 shows an epoch progress bar, >1 per-batch progress,
                >2 additionally prints metrics.
            img_backend: backbone name understood by ``get_backend``.
            label_distribution: stored in checkpoints when non-empty.
                ``None`` means an empty dict.
            wandb_: ``1`` starts a Weights & Biases run in ``fit``; any other
                value initialises wandb in disabled mode.
            _device_ids: if not ``None``, ``fit`` temporarily runs on
                ``_device_ids[0]`` (see ``_manage_ctx_fit``).
            _dataloader_num_workers: ``num_workers`` of every ``DataLoader``.
            _amp_enabled: enable autocast and gradient scaling.
        """
        # Progress-bar bookkeeping: tqdm row and an optional lock guarding output.
        self._rank = 0
        self._lock = None

        self.tgt_modalities = tgt_modalities
        self.label_fractions = label_fractions
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.batch_size_multiplier = batch_size_multiplier
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta = beta
        self.gamma = gamma
        self.bn_size = bn_size
        self.growth_rate = growth_rate
        self.block_config = block_config
        self.compression = compression
        self.num_init_features = num_init_features
        self.drop_rate = drop_rate
        self.criterion = criterion
        self.device = device
        # None sentinels avoid sharing one mutable default between instances.
        self.cuda_devices = [1] if cuda_devices is None else cuda_devices
        self.ckpt_path = ckpt_path
        self.load_from_ckpt = load_from_ckpt
        self.save_intermediate_ckpts = save_intermediate_ckpts
        self.data_parallel = data_parallel
        self.verbose = verbose
        self.img_backend = img_backend
        self.label_distribution = {} if label_distribution is None else label_distribution
        self.wandb_ = wandb_
        self._device_ids = _device_ids
        self._dataloader_num_workers = _dataloader_num_workers
        self._amp_enabled = _amp_enabled
        # Always present so save()/load() can rely on the attribute.
        self.scaler = torch.cuda.amp.GradScaler()
|
|
|
@_manage_ctx_fit
def fit(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None) -> Self:
    """Train the network and validate it after every epoch.

    Steps: initialise wandb, build or resume the network (``_init_net``),
    build the data loaders, bind optimizer and scheduler to the network,
    create one ``nn.SigmoidFocalLoss`` per target, then run epochs
    ``self.start_epoch .. self.num_epochs - 1``.

    Args:
        trn_list: training records passed to ``_init_dataloader``.
        vld_list: validation records passed to ``_init_dataloader``.
        img_train_trans: transform for the training dataset.
        img_vld_trans: transform for the validation dataset.

    Returns:
        ``self``. The network keeps the weights of the last epoch; the best
        epochs are only persisted to disk, and only when
        ``save_intermediate_ckpts`` is set.
    """
    if self.wandb_ == 1:
        wandb.init(
            project="ADRD_main",
            config={
                # Report the actual configuration instead of fixed placeholders.
                "Model": self.img_backend,
                "Loss": 'Focalloss',
                "EMB": "ALL_EMB",
                "epochs": self.num_epochs,
            }
        )
        wandb.run.log_code("/home/skowshik/ADRD_repo/pipeline_v1_main/adrd_tool")
    else:
        wandb.init(mode="disabled")

    torch.set_num_threads(1)
    print(self.criterion)

    # Builds self.net_ (possibly from a checkpoint) and sets self.start_epoch.
    self._init_net()

    ldr_trn, ldr_vld = self._init_dataloader(
        trn_list, vld_list,
        img_train_trans=img_train_trans, img_vld_trans=img_vld_trans,
    )

    self._setup_optimization()

    # A scaler restored from a checkpoint is kept; otherwise start a fresh one.
    if self._amp_enabled and not self.load_from_ckpt:
        self.scaler = torch.cuda.amp.GradScaler()

    # One focal loss per target; rarer labels (fraction < 0.3) get a class weight.
    self.loss_fn = {}
    for k in self.tgt_modalities:
        fraction = self.label_fractions[k]
        alpha = -1 if fraction >= 0.3 else pow((1 - fraction), 2)
        self.loss_fn[k] = nn.SigmoidFocalLoss(
            alpha = alpha,
            gamma = self.gamma,
            reduction = 'none'
        )

    best_crit = None
    best_crit_AUPR = None

    pbr_epoch = None
    if self.verbose == 1:
        with self._lock if self._lock is not None else suppress():
            pbr_epoch = tqdm(
                desc = 'Rank {:02d}'.format(self._rank),
                total = self.num_epochs,
                position = self._rank,
                ascii = True,
                leave = False,
                bar_format='{l_bar}{r_bar}'
            )

    for epoch in range(self.start_epoch, self.num_epochs):
        self.train_one_epoch(ldr_trn, epoch)
        met_vld = self.validate_one_epoch(ldr_vld, epoch)

        print(self.ckpt_path.split('/')[-1])

        if self.criterion is not None:
            best_crit, best_crit_AUPR = self._update_best(
                met_vld, epoch, best_crit, best_crit_AUPR
            )

        # Advance the bar on every epoch, also when no criterion is tracked.
        if pbr_epoch is not None:
            with self._lock if self._lock is not None else suppress():
                pbr_epoch.update(1)
                pbr_epoch.refresh()

    if pbr_epoch is not None:
        with self._lock if self._lock is not None else suppress():
            pbr_epoch.close()

    return self


def _setup_optimization(self):
    """Bind a fresh optimizer and scheduler to the current network parameters.

    When resuming, an optimizer restored by ``load`` refers to the tensors
    that were unpickled with it, not to the parameters of the rebuilt
    ``self.net_``. Only its state is therefore carried over into a newly
    created optimizer. The scheduler is always (re)created because it is not
    part of a checkpoint.
    """
    restored = getattr(self, 'optimizer', None) if self.load_from_ckpt else None
    self.optimizer = self._init_optimizer()
    if restored is not None:
        state = restored.state_dict() if hasattr(restored, 'state_dict') else restored
        try:
            self.optimizer.load_state_dict(state)
        except (ValueError, KeyError, RuntimeError) as err:
            print(f"Could not restore optimizer state ({err}); using a fresh optimizer.")
    self.scheduler = self._init_scheduler(self.optimizer)


def _update_best(self, met_vld, epoch, best_crit, best_crit_AUPR):
    """Compare this epoch's validation metrics with the best so far.

    Averages ``self.criterion`` and ``"AUC (PR)"`` over all targets, saves
    checkpoints for improvements when ``save_intermediate_ckpts`` is set and
    returns the updated ``(best_crit, best_crit_AUPR)``.
    """
    n_tasks = len(self.tgt_modalities)
    curr_crit = np.mean([met_vld[i][self.criterion] for i in range(n_tasks)])
    curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(n_tasks)])

    # 'Loss' is minimised, every other criterion is maximised; ties count as better.
    if best_crit is None or np.isnan(best_crit):
        is_better = True
    elif self.criterion == 'Loss':
        is_better = best_crit >= curr_crit
    else:
        is_better = best_crit <= curr_crit

    is_better_AUPR = (
        best_crit_AUPR is None
        or np.isnan(best_crit_AUPR)
        or best_crit_AUPR <= curr_crit_AUPR
    )

    if is_better_AUPR:
        best_crit_AUPR = curr_crit_AUPR
        if self.save_intermediate_ckpts:
            print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
            self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch)

    if is_better:
        best_crit = curr_crit
        if self.save_intermediate_ckpts:
            print(f"Saving the model to {self.ckpt_path}...")
            self.save(self.ckpt_path, epoch)

    if self.verbose > 2:
        print('Best {}: {}'.format(self.criterion, best_crit))
        print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))

    return best_crit, best_crit_AUPR
|
|
|
def train_one_epoch(self, ldr_trn, epoch):
    """Run one training epoch and return the training metrics.

    Args:
        ldr_trn: loader yielding dicts with ``"image"``, ``"label"`` and
            ``"mask"`` entries (the last two keyed by target name).
        epoch: current epoch index; used for the scheduler position, progress
            output and as the wandb step.

    Returns:
        The result of ``get_metrics_multitask`` with a ``'Loss'`` entry added
        for every target index.

    Notes:
        * Gradients are accumulated over ``batch_size_multiplier`` batches; the
          last, possibly shorter, group of an epoch is stepped as well so no
          gradient leaks into the next epoch.
        * The reported ``'Loss'`` averages the masked per-sample losses, i.e.
          samples without a label contribute a zero.
    """
    if self.verbose > 1:
        pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))

    torch.set_grad_enabled(True)
    self.net_.train()

    on_cpu = self.device == 'cpu'
    # str() keeps the check valid for non-string device values.
    uses_cuda = "cuda" in str(self.device)

    scores_trn, y_true_trn, y_mask_trn = [], [], []
    losses_trn = [[] for _ in self.tgt_modalities]
    iters = len(ldr_trn)
    print(iters)

    # Start every epoch from clean gradients.
    self.optimizer.zero_grad()

    for n_iter, batch_data in enumerate(ldr_trn):
        x_batch = batch_data["image"].to(self.device, non_blocking=True)
        y_batch = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["label"].items()}
        y_mask = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["mask"].items()}

        with torch.autocast(
            device_type = 'cpu' if on_cpu else 'cuda',
            dtype = torch.bfloat16 if on_cpu else torch.float16,
            enabled = self._amp_enabled,
        ):
            outputs = self.net_(x_batch, shap=False)

            # Multi-task loss: per-task mean over the labelled samples only.
            loss = 0
            for i, k in enumerate(self.tgt_modalities):
                loss_task = self.loss_fn[k](outputs[k], y_batch[k])
                msk_loss_task = loss_task * y_mask[k]
                # A batch without any label for this task would give 0/0 = NaN
                # and poison every gradient; clamp so it contributes 0 instead.
                n_labelled = y_mask[k].sum().clamp(min=1)
                loss += msk_loss_task.sum() / n_labelled
                losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()

        if self._amp_enabled:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step after every `batch_size_multiplier` batches and at the end of
        # the epoch so trailing batches are not dropped.
        group_complete = (n_iter + 1) % self.batch_size_multiplier == 0
        if group_complete or n_iter + 1 == iters:
            if self._amp_enabled:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

        # Fractional epoch position for the warm-restart schedule.
        self.scheduler.step(epoch + n_iter / iters)

        # Stack in target order so column i always belongs to target i.
        outputs = torch.stack([outputs[k] for k in self.tgt_modalities], dim=1)
        y_batch = torch.stack([y_batch[k] for k in self.tgt_modalities], dim=1)
        y_mask = torch.stack([y_mask[k] for k in self.tgt_modalities], dim=1)

        scores_trn.append(outputs.detach().to(torch.float).cpu())
        y_true_trn.append(y_batch.cpu())
        y_mask_trn.append(y_mask.cpu())

        if self.verbose > 1:
            batch_size = len(x_batch)
            pbr_batch.update(batch_size, {})
            pbr_batch.refresh()

        if uses_cuda:
            torch.cuda.empty_cache()

    if self.verbose > 1:
        pbr_batch.close()

    scores_trn = torch.cat(scores_trn)
    y_true_trn = torch.cat(y_true_trn)
    y_mask_trn = torch.cat(y_mask_trn)
    y_pred_trn = (scores_trn > 0).to(torch.int)  # logit > 0 <=> probability > 0.5
    y_prob_trn = torch.sigmoid(scores_trn)
    met_trn = get_metrics_multitask(
        y_true_trn.numpy(),
        y_pred_trn.numpy(),
        y_prob_trn.numpy(),
        y_mask_trn.numpy()
    )

    for i in range(len(self.tgt_modalities)):
        met_trn[i]['Loss'] = np.mean(losses_trn[i])

    logged = (
        ("Train loss", 'Loss'),
        ("Train Balanced Accuracy", 'Balanced Accuracy'),
        ("Train AUC (ROC)", 'AUC (ROC)'),
        ("Train AUPR", 'AUC (PR)'),
    )
    for title, metric in logged:
        wandb.log(
            {f"{title} {k}": met_trn[i][metric] for i, k in enumerate(self.tgt_modalities)},
            step=epoch,
        )

    if self.verbose > 2:
        print_metrics_multitask(met_trn)

    return met_trn
|
|
|
|
|
def validate_one_epoch(self, ldr_vld, epoch):
    """Evaluate the network on the validation loader and return the metrics.

    Args:
        ldr_vld: loader yielding dicts with ``"image"``, ``"label"`` and
            ``"mask"`` entries (the last two keyed by target name).
        epoch: current epoch index; used for progress output and as the wandb
            step.

    Returns:
        The result of ``get_metrics_multitask`` with a ``'Loss'`` entry added
        for every target index. ``'Loss'`` averages the masked per-sample
        losses, i.e. samples without a label contribute a zero.

    Notes:
        Autograd is disabled only inside this method; the previous grad mode is
        restored on exit instead of being left switched off for the process.
    """
    if self.verbose > 1:
        pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))

    self.net_.eval()

    on_cpu = self.device == 'cpu'
    # str() keeps the check valid for non-string device values.
    uses_cuda = "cuda" in str(self.device)

    scores_vld, y_true_vld, y_mask_vld = [], [], []
    losses_vld = [[] for _ in self.tgt_modalities]

    with torch.no_grad():
        for batch_data in ldr_vld:
            x_batch = batch_data["image"].to(self.device, non_blocking=True)
            y_batch = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["label"].items()}
            y_mask = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["mask"].items()}

            with torch.autocast(
                device_type = 'cpu' if on_cpu else 'cuda',
                dtype = torch.bfloat16 if on_cpu else torch.float16,
                enabled = self._amp_enabled
            ):
                outputs = self.net_(x_batch, shap=False)

                for i, k in enumerate(self.tgt_modalities):
                    loss_task = self.loss_fn[k](outputs[k], y_batch[k])
                    msk_loss_task = loss_task * y_mask[k]
                    losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()

            # Stack in target order so column i always belongs to target i.
            outputs = torch.stack([outputs[k] for k in self.tgt_modalities], dim=1)
            y_batch = torch.stack([y_batch[k] for k in self.tgt_modalities], dim=1)
            y_mask = torch.stack([y_mask[k] for k in self.tgt_modalities], dim=1)

            scores_vld.append(outputs.detach().to(torch.float).cpu())
            y_true_vld.append(y_batch.cpu())
            y_mask_vld.append(y_mask.cpu())

            if self.verbose > 1:
                batch_size = len(x_batch)
                pbr_batch.update(batch_size, {})
                pbr_batch.refresh()

            if uses_cuda:
                torch.cuda.empty_cache()

    if self.verbose > 1:
        pbr_batch.close()

    scores_vld = torch.cat(scores_vld)
    y_true_vld = torch.cat(y_true_vld)
    y_mask_vld = torch.cat(y_mask_vld)
    y_pred_vld = (scores_vld > 0).to(torch.int)  # logit > 0 <=> probability > 0.5
    y_prob_vld = torch.sigmoid(scores_vld)
    met_vld = get_metrics_multitask(
        y_true_vld.numpy(),
        y_pred_vld.numpy(),
        y_prob_vld.numpy(),
        y_mask_vld.numpy()
    )

    for i in range(len(self.tgt_modalities)):
        met_vld[i]['Loss'] = np.mean(losses_vld[i])

    logged = (
        ("Validation loss", 'Loss'),
        ("Validation Balanced Accuracy", 'Balanced Accuracy'),
        ("Validation AUC (ROC)", 'AUC (ROC)'),
        ("Validation AUPR", 'AUC (PR)'),
    )
    for title, metric in logged:
        wandb.log(
            {f"{title} {k}": met_vld[i][metric] for i, k in enumerate(self.tgt_modalities)},
            step=epoch,
        )

    if self.verbose > 2:
        print_metrics_multitask(met_vld)

    return met_vld
|
|
|
|
|
def save(self, filepath: str, epoch: int = 0) -> None:
    """Write the network weights and the metadata needed by ``load`` to ``filepath``.

    The file is the network ``state_dict`` extended with these extra keys:
    ``tgt_modalities``, ``optimizer`` (the optimizer object, if one exists),
    the backbone hyper-parameters, ``epoch``, ``scaler`` (GradScaler state)
    and, when non-empty, ``label_distribution``.

    Args:
        filepath: destination passed to ``torch.save``.
        epoch: epoch index recorded in the checkpoint.
    """
    check_is_fitted(self)

    # Unwrap by inspecting the object: `data_parallel` may be set although
    # the network was never wrapped (e.g. only one GPU is visible).
    net = self.net_
    if isinstance(net, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        net = net.module
    state_dict = net.state_dict()

    state_dict['tgt_modalities'] = self.tgt_modalities
    # The optimizer only exists after fit(); a model restored from a
    # checkpoint without one can still be saved.
    optimizer = getattr(self, 'optimizer', None)
    if optimizer is not None:
        state_dict['optimizer'] = optimizer
    state_dict['bn_size'] = self.bn_size
    state_dict['growth_rate'] = self.growth_rate
    state_dict['block_config'] = self.block_config
    state_dict['compression'] = self.compression
    state_dict['num_init_features'] = self.num_init_features
    state_dict['drop_rate'] = self.drop_rate
    state_dict['epoch'] = epoch

    if self.scaler is not None:
        state_dict['scaler'] = self.scaler.state_dict()
    if self.label_distribution:
        state_dict['label_distribution'] = self.label_distribution

    torch.save(state_dict, filepath)
|
|
|
def load(self, filepath: str, map_location: str = 'cpu', how='latest') -> None:
    """Restore hyper-parameters and network weights from a checkpoint.

    Args:
        filepath: main checkpoint file written by ``save``.
        map_location: forwarded to every ``torch.load`` call.
        how: ``'latest'`` compares the ``epoch`` stored in ``filepath`` with
            the one in the sibling ``<filepath minus last 3 chars>_AUPR.pt``
            and loads whichever is newer (the sibling wins ties; if the
            sibling does not exist, ``filepath`` is used). Any other value
            loads ``filepath`` directly.

    Side effects:
        Overwrites ``tgt_modalities`` and, when present in the checkpoint,
        ``label_distribution``, ``optimizer``, the backbone hyper-parameters
        and ``start_epoch``; rebuilds ``self.net_`` and moves it to
        ``self.device``.

    Notes:
        ``torch.load`` unpickles arbitrary objects (the checkpoint stores an
        optimizer object) -- only load checkpoints from a trusted source.
    """
    if how == 'latest':
        # Each file is read once and always with map_location, so a GPU
        # checkpoint can be inspected on a CPU-only machine.
        state_dict = torch.load(filepath, map_location=map_location)
        try:
            aupr_state = torch.load(f'{filepath[:-3]}_AUPR.pt', map_location=map_location)
        except FileNotFoundError:
            aupr_state = None

        if aupr_state is None or state_dict.get('epoch', 0) > aupr_state.get('epoch', 0):
            print("Loading model saved using AUROC")
        else:
            print("Loading model saved using AUPR")
            state_dict = aupr_state
        del aupr_state
    else:
        state_dict = torch.load(filepath, map_location=map_location)

    # Strip the metadata so only network weights remain in state_dict.
    self.tgt_modalities: list[str] = state_dict.pop('tgt_modalities')
    optional_fields = (
        'label_distribution', 'optimizer', 'bn_size', 'growth_rate',
        'block_config', 'compression', 'num_init_features', 'drop_rate',
    )
    for name in optional_fields:
        if name in state_dict:
            setattr(self, name, state_dict.pop(name))
    if 'epoch' in state_dict:
        self.start_epoch = state_dict.pop('epoch')
        print(f'Epoch: {self.start_epoch}')
    # Always remove the scaler entry: an empty one (saved from a disabled
    # GradScaler) would otherwise be rejected as an unexpected weight key.
    scaler_state = state_dict.pop('scaler', None)

    self.net_ = get_backend(self.img_backend)(
        tgt_modalities = self.tgt_modalities,
        bn_size = self.bn_size,
        growth_rate=self.growth_rate,
        block_config=self.block_config,
        compression=self.compression,
        num_init_features=self.num_init_features,
        drop_rate=self.drop_rate,
        load_from_ckpt=self.load_from_ckpt
    )
    print(self.net_)

    if scaler_state:
        self.scaler.load_state_dict(scaler_state)
    self.net_.load_state_dict(state_dict)
    check_is_fitted(self)
    self.net_.to(self.device)
|
|
|
def to(self, device: str) -> Self: |
|
''' Mount model to the given device. ''' |
|
self.device = device |
|
if hasattr(self, 'model'): self.net_ = self.net_.to(device) |
|
return self |
|
|
|
@classmethod
def from_ckpt(cls, filepath: str, device='cpu', img_backend=None, load_from_ckpt=True, how='latest') -> Self:
    """Alternate constructor: build an estimator and restore it from ``filepath``.

    Args:
        filepath: checkpoint handed to ``load``.
        device: device for the estimator; a bare ``'cuda'`` is expanded with
            the first entry of ``cuda_devices``.
        img_backend: backbone name assigned before loading.
        load_from_ckpt: flag assigned before loading.
        how: checkpoint selection mode forwarded to ``load``.

    Returns:
        The restored estimator.
    """
    model = cls(None, None, None, device=device)
    if device == 'cuda':
        first_gpu = str(model.cuda_devices[0])
        model.device = f"{model.device}:{first_gpu}"
    print(model.device)

    model.img_backend = img_backend
    model.load_from_ckpt = load_from_ckpt
    model.load(filepath, map_location=model.device, how=how)
    return model
|
|
|
def _init_net(self):
    """Create ``self.net_`` either from the checkpoint or from scratch.

    * Resets ``self.start_epoch`` to 0 (``load`` may overwrite it).
    * Expands a bare ``'cuda'`` device to ``'cuda:<cuda_devices[0]>'``.
    * With ``load_from_ckpt`` set, tries ``load(self.ckpt_path)``. On any
      error the reason is printed, ``load_from_ckpt`` is switched off and a
      new network is built instead.
    * Moves the network to ``self.device`` and wraps it in
      ``torch.nn.DataParallel`` when ``data_parallel`` is set and more than
      one GPU is visible.
    """
    self.start_epoch = 0

    if self.device == 'cuda':
        self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))

    if self.load_from_ckpt:
        try:
            print("Loading model from checkpoint...")
            self.load(self.ckpt_path, map_location=self.device)
        except Exception as err:
            # Best-effort fallback, but let KeyboardInterrupt/SystemExit through
            # and tell the user why the checkpoint was not used.
            print("Cannot load from checkpoint. Initializing new model...")
            print(f"Reason: {type(err).__name__}: {err}")
            self.load_from_ckpt = False
            # load() may have set start_epoch before it failed.
            self.start_epoch = 0

    if not self.load_from_ckpt:
        self.net_ = get_backend(self.img_backend)(
            tgt_modalities = self.tgt_modalities,
            bn_size = self.bn_size,
            growth_rate=self.growth_rate,
            block_config=self.block_config,
            compression=self.compression,
            num_init_features=self.num_init_features,
            drop_rate=self.drop_rate,
            load_from_ckpt=self.load_from_ckpt
        )

    self.net_.to(self.device)

    if self.data_parallel and torch.cuda.device_count() > 1:
        print("Available", torch.cuda.device_count(), "GPUs!")
        self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
|
|
|
|
|
|
|
def _init_dataloader(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None):
    """Build the training and validation data loaders.

    Each record list is wrapped in a ``monai.data.Dataset`` with its
    transform and collated by ``collate_handle_corrupted`` bound to that
    dataset. Only the training loader shuffles; neither drops the last batch.

    Returns:
        ``(ldr_trn, ldr_vld)``.
    """
    def _make_loader(records, transform, shuffle):
        dataset = monai.data.Dataset(data=records, transform=transform)
        collate = functools.partial(
            collate_handle_corrupted,
            dataset=dataset,
            dtype=torch.FloatTensor,
            labels=self.tgt_modalities,
        )
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=False,
            num_workers=self._dataloader_num_workers,
            collate_fn=collate,
        )

    train_loader = _make_loader(trn_list, img_train_trans, True)
    valid_loader = _make_loader(vld_list, img_vld_trans, False)
    return train_loader, valid_loader
|
|
|
def _init_optimizer(self): |
|
""" ... """ |
|
params = list(self.net_.parameters()) |
|
|
|
|
|
return torch.optim.AdamW( |
|
params, |
|
lr = self.lr, |
|
betas = (0.9, 0.98), |
|
weight_decay = self.weight_decay |
|
) |
|
|
|
def _init_scheduler(self, optimizer): |
|
""" ... """ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( |
|
optimizer=optimizer, |
|
T_0=64, |
|
T_mult=2, |
|
eta_min = 0, |
|
verbose=(self.verbose > 2) |
|
) |
|
|
|
def _init_loss_func(self,
    num_per_cls: dict[str, tuple[int, int]],
) -> dict[str, Module]:
    """Create one ``nn.SigmoidFocalLossBeta`` per target.

    Args:
        num_per_cls: per-target value forwarded as ``num_per_cls`` to the loss.

    Returns:
        Mapping from target name to its loss module, built with ``self.beta``,
        ``self.gamma`` and ``reduction='none'``.
    """
    losses = {}
    for target in self.tgt_modalities:
        losses[target] = nn.SigmoidFocalLossBeta(
            beta=self.beta,
            gamma=self.gamma,
            num_per_cls=num_per_cls[target],
            reduction='none',
        )
    return losses
|
|
|
def _proc_fit(self):
    """Placeholder without an implementation; calling it does nothing and returns ``None``."""
|
|
|
def _init_test_dataloader(self, batch_size, tst_list, img_tst_trans=None):
    """Prepare the fitted model for inference and build the test data loader.

    Besides building the loader this prints ``self.device``, limits torch to
    one thread, disables autograd globally and puts the network in eval mode.

    Args:
        batch_size: batch size of the returned loader.
        tst_list: test records wrapped in a ``monai.data.Dataset``.
        img_tst_trans: transform applied to the test records.

    Returns:
        A non-shuffling ``DataLoader`` that keeps the last partial batch.
    """
    check_is_fitted(self)
    print(self.device)

    # Process-wide inference settings.
    torch.set_num_threads(1)
    torch.set_grad_enabled(False)
    self.net_.eval()

    test_set = monai.data.Dataset(data=tst_list, transform=img_tst_trans)
    collate = functools.partial(
        collate_handle_corrupted,
        dataset=test_set,
        dtype=torch.FloatTensor,
        labels=self.tgt_modalities,
    )
    return DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=self._dataloader_num_workers,
        collate_fn=collate,
    )
|
|
|
|
|
def predict_logits(self, |
|
ldr_tst: Any | None = None, |
|
) -> list[dict[str, float]]: |
|
|
|
|
|
logits: list[dict[str, float]] = [] |
|
for batch_data in tqdm(ldr_tst): |
|
|
|
if len(batch_data) == 0: |
|
continue |
|
x_batch = batch_data["image"].to(self.device, non_blocking=True) |
|
outputs = self.net_(x_batch, shap=False) |
|
|
|
|
|
tmp = {k: outputs[k].tolist() for k in self.tgt_modalities} |
|
tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))] |
|
logits += tmp |
|
|
|
return logits |
|
|
|
def predict_proba(self, |
|
ldr_tst: Any | None = None, |
|
temperature: float = 1.0, |
|
) -> list[dict[str, float]]: |
|
''' ... ''' |
|
logits = self.predict_logits(ldr_tst) |
|
print("got logits") |
|
return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits] |
|
|
|
def predict(self, |
|
ldr_tst: Any | None = None, |
|
) -> list[dict[str, int]]: |
|
''' ... ''' |
|
logits, proba = self.predict_proba(ldr_tst) |
|
print("got proba") |
|
return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba] |