| | import torch |
| | from torch import nn |
| | from torch.optim import Optimizer |
| | from torch.utils.data import DataLoader |
| | from torch.amp import GradScaler, autocast |
| | import numpy as np |
| | from tqdm import tqdm |
| | from typing import Dict, Tuple, Union |
| | from copy import deepcopy |
| |
|
| | from utils import barrier, reduce_mean, update_loss_info |
| | from evaluate import evaluate |
| |
|
| |
|
def train(
    model: nn.Module,
    data_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: Optimizer,
    grad_scaler: Union[GradScaler, None],
    device: torch.device = torch.device("cuda"),
    rank: int = 0,
    nprocs: int = 1,
    **kwargs,
) -> Tuple[
    nn.Module,
    Optimizer,
    Union[GradScaler, None],
    Dict[str, float],
    Union[Dict[str, float], None],
    Union[Dict[str, dict], None],
]:
    """Run one training epoch, optionally evaluating periodically within it.

    Supports single-process and DDP (``nprocs > 1``) training, with optional
    mixed precision when ``grad_scaler`` is provided and enabled.

    Args:
        model: Model to train. Under DDP the underlying module is accessed via
            ``model.module`` (e.g. for the ``zero_inflated`` flag and weights).
        data_loader: Yields ``(image, gt_points, gt_den_map)`` batches.
        loss_fn: Loss module; returns ``(total_loss, loss_info_dict)``.
        optimizer: Optimizer stepped once per batch.
        grad_scaler: AMP gradient scaler, or ``None`` to train in full precision.
        device: Target device for batches and autocast (default CUDA).
        rank: Process rank; only rank 0 shows a progress bar.
        nprocs: Number of processes; ``> 1`` enables DDP behavior.
        **kwargs: If ``eval_data_loader`` is present, in-epoch evaluation is
            enabled and ``eval_freq`` (float in (0, 1)), ``sliding_window``,
            ``max_input_size``, ``window_size``, ``stride`` and
            ``max_num_windows`` must also be supplied.

    Returns:
        ``(model, optimizer, grad_scaler, mean_loss_info, best_scores,
        best_weights)``. The last two are ``None`` when in-epoch evaluation
        is disabled.
    """
    info = None
    data_iter = tqdm(data_loader) if rank == 0 else data_loader
    ddp = nprocs > 1
    # AMP is active only when a scaler was passed in and it is enabled.
    amp_enabled = grad_scaler is not None and grad_scaler.is_enabled()

    if "eval_data_loader" in kwargs:
        assert "eval_freq" in kwargs and 0 < kwargs["eval_freq"] < 1, f"eval_freq should be a float between 0 and 1, but got {kwargs['eval_freq']}"
        assert "sliding_window" in kwargs, "sliding_window should be provided in kwargs"
        assert "max_input_size" in kwargs, "max_input_size should be provided in kwargs"
        assert "window_size" in kwargs, "window_size should be provided in kwargs"
        assert "stride" in kwargs, "stride should be provided in kwargs"
        assert "max_num_windows" in kwargs, "max_num_windows should be provided in kwargs"

        eval_within_epoch = True
        eval_data_loader = kwargs["eval_data_loader"]
        # Clamp to at least 1: for short loaders int(freq * len) can be 0,
        # which would make the modulo below raise ZeroDivisionError.
        eval_freq = max(1, int(kwargs["eval_freq"] * len(data_loader)))
        sliding_window = kwargs["sliding_window"]
        max_input_size = kwargs["max_input_size"]
        window_size = kwargs["window_size"]
        stride = kwargs["stride"]
        max_num_windows = kwargs["max_num_windows"]

        best_scores = {}
        best_weights = {}

    else:
        eval_within_epoch = False
        best_scores = None
        best_weights = None

    for batch_idx, (image, gt_points, gt_den_map) in enumerate(data_iter):
        image = image.to(device)
        gt_points = [p.to(device) for p in gt_points]
        gt_den_map = gt_den_map.to(device)
        model.train()
        with torch.set_grad_enabled(True):
            # Use the actual target device type rather than hard-coded "cuda"
            # so the same loop also works on CPU (autocast is a no-op when
            # amp_enabled is False).
            with autocast(device_type=device.type, enabled=amp_enabled):
                if (model.module.zero_inflated if ddp else model.zero_inflated):
                    pred_logit_pi_map, pred_logit_map, pred_lambda_map, pred_den_map = model(image)
                    total_loss, total_loss_info = loss_fn(
                        pred_logit_pi_map=pred_logit_pi_map,
                        pred_logit_map=pred_logit_map,
                        pred_lambda_map=pred_lambda_map,
                        pred_den_map=pred_den_map,
                        gt_den_map=gt_den_map,
                        gt_points=gt_points,
                    )
                else:
                    pred_logit_map, pred_den_map = model(image)
                    total_loss, total_loss_info = loss_fn(
                        pred_logit_map=pred_logit_map,
                        pred_den_map=pred_den_map,
                        gt_den_map=gt_den_map,
                        gt_points=gt_points,
                    )

            optimizer.zero_grad()
            if grad_scaler is not None:
                # Scaled backward + step + scale update for mixed precision.
                grad_scaler.scale(total_loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

        # Average each loss term across processes under DDP, then accumulate.
        total_loss_info = {k: reduce_mean(v.detach(), nprocs).item() if ddp else v.detach().item() for k, v in total_loss_info.items()}
        info = update_loss_info(info, total_loss_info)
        barrier(ddp)

        # Evaluate every eval_freq batches and always on the final batch.
        if eval_within_epoch and ((batch_idx + 1) % eval_freq == 0 or batch_idx == len(data_loader) - 1):
            batch_scores = evaluate(
                model=model,
                data_loader=eval_data_loader,
                sliding_window=sliding_window,
                max_input_size=max_input_size,
                window_size=window_size,
                stride=stride,
                max_num_windows=max_num_windows,
                device=device,
                amp=amp_enabled,
                local_rank=rank,
                nprocs=nprocs,
                progress_bar=False,
            )
            # Track the best (lowest) value per metric, snapshotting weights.
            for k, v in batch_scores.items():
                if k not in best_scores or v < best_scores[k]:
                    best_scores[k] = v
                    best_weights[k] = deepcopy(model.module.state_dict() if ddp else model.state_dict())

        barrier(ddp)

    torch.cuda.empty_cache()
    # Guard against an empty data loader, in which case `info` was never set.
    mean_info = {k: np.mean(v) for k, v in info.items()} if info is not None else {}
    return model, optimizer, grad_scaler, mean_info, best_scores, best_weights
| |
|