""" Training Engine Script ver: Feb 8th 16:00 Based on MAE code. https://github.com/facebookresearch/mae """ import math import sys from typing import Iterable import os import torch from torchvision.transforms import ToPILImage import SSL_structures.misc as misc import utils.schedulers as lr_sched from utils.visual_usage import unpatchify, patchify, Draw_tri_fig def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, fix_position_ratio_scheduler=None, puzzle_patch_size_scheduler=None, check_samples=1, print_freq=20, log_writer=None, args=None): model.train(True) # update logger metric_logger = misc.MetricLogger(delimiter=" ") # 初始化学习率记录 metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) accum_iter = args.accum_iter optimizer.zero_grad() if log_writer is not None: # Tensorboard PATH print('log_dir: {}'.format(args.log_dir)) # Iteration for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): # per iteration lr scheduler基于中间epoch位置 # 来实现更精确的调节学习率:data_iter_step / len(data_loader) + epoch if data_iter_step % accum_iter == 0: lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) # 拿数据 samples = samples.to(device, non_blocking=True) with torch.cuda.amp.autocast(): # 使用自动混合精度加速训练 if fix_position_ratio_scheduler is not None and puzzle_patch_size_scheduler is not None: # SAE fix_position_ratio = fix_position_ratio_scheduler(epoch) puzzle_patch_size = puzzle_patch_size_scheduler(epoch) else: fix_position_ratio = None puzzle_patch_size = None if args.model[0:3] == 'sae': loss, pred, imgs_puzzled_patches = model(samples, fix_position_ratio=fix_position_ratio, puzzle_patch_size=puzzle_patch_size) # SAE else: # args.model[0:3] == 'mae' loss, pred, mask_patch_indicators = model(samples, mask_ratio=args.mask_ratio) # MAE # fixme mae curriculum maybe not good enough for future if args.DDP_distributed: loss_value = loss.item() else: loss_value = float(loss.cpu().detach().numpy()) \ if torch.cuda.device_count() == 1 else sum(loss.cpu().detach().numpy()) if not math.isfinite(loss_value): # 检查确保没有loss爆炸 print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) loss = loss / accum_iter # 计算的是每个minibatch的loss,如果有梯度累加则需要减少占比,loss在loss_scaler里面会进行叠加 # loss backward 核心(不要怕,其实就是功能上集成了loss.backward+opt.step,然后引入了梯度裁剪) loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0) if (data_iter_step + 1) % accum_iter == 0: optimizer.zero_grad() torch.cuda.synchronize() # 等待当前设备上所有流中的所有核心完成 # 更新记录 metric_logger.update(loss=loss_value) lr = optimizer.param_groups[0]["lr"] metric_logger.update(lr=lr) # 计算平均在单卡上的loss loss_value_reduce = misc.all_reduce_mean(loss_value) if log_writer is not None: log_writer.add_scalar('train_loss', loss_value_reduce, epoch) log_writer.add_scalar('lr', lr, epoch) if fix_position_ratio is not None and puzzle_patch_size is not None: log_writer.add_scalar('puzzle_patch_size', puzzle_patch_size, epoch) log_writer.add_scalar('fix_position_ratio', fix_position_ratio, epoch) # gather the stats from all processes metric_logger.synchronize_between_processes() if fix_position_ratio is not None and puzzle_patch_size is not None: print("Averaged stats:", metric_logger, 'fix_position_ratio:', fix_position_ratio, ' puzzle_patch_size:', puzzle_patch_size) else: print("Averaged stats:", metric_logger) # 
    # TODO: currently we only paint at the end of each training epoch
    if args.model[0:3] == 'sae':  # SAE
        imgs_puzzled_batch = unpatchify(imgs_puzzled_patches, patch_size=16)
    else:  # MAE
        sample_img_patches = patchify(samples, patch_size=16)  # on GPU
        masked_img_patches = sample_img_patches * \
            mask_patch_indicators.unsqueeze(-1).expand(-1, -1, sample_img_patches.shape[-1])
        masked_img_batch = unpatchify(masked_img_patches, patch_size=16)

    # paint images at the end of each epoch, on the main process only
    if misc.is_main_process():
        for sampleIDX in range(check_samples):
            sample_img = samples.cpu()[sampleIDX]
            sample_img = ToPILImage()(sample_img)
            sample_img.save(os.path.join(args.output_dir, 'figs',
                                         'sample_e_' + str(epoch) + '_sampleIDX_' + str(sampleIDX) + '.jpg'))

            recons_img_batch = unpatchify(pred, patch_size=16)
            recons_img = recons_img_batch.cpu()[sampleIDX]
            recons_img = ToPILImage()(recons_img)
            recons_img.save(os.path.join(args.output_dir, 'figs',
                                         'recons_e_' + str(epoch) + '_sampleIDX_' + str(sampleIDX) + '.jpg'))

            if args.model[0:3] == 'sae':  # SAE
                puzzled_img = imgs_puzzled_batch.cpu()[sampleIDX]
                puzzled_img = ToPILImage()(puzzled_img)
                picpath = os.path.join(args.output_dir, 'figs',
                                       'puzzled_e_' + str(epoch) + '_sampleIDX_' + str(sampleIDX) + '.jpg')
                puzzled_img.save(picpath)
                Draw_tri_fig(sample_img, puzzled_img, recons_img, picpath)
            else:  # MAE
                masked_img = masked_img_batch.cpu()[sampleIDX]  # put on CPU
                masked_img = ToPILImage()(masked_img)
                picpath = os.path.join(args.output_dir, 'figs',
                                       'masked_e_' + str(epoch) + '_sampleIDX_' + str(sampleIDX) + '.jpg')
                masked_img.save(picpath)
                Draw_tri_fig(sample_img, masked_img, recons_img, picpath)

    # return the recorded averages; everything else has already been accumulated inside the logger
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
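
# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): how a pretraining
# script might drive train_one_epoch. `build_model` and `build_loader` are
# hypothetical placeholders; `NativeScaler` is assumed to be available from
# SSL_structures.misc as in the upstream MAE repo, matching the
# loss_scaler(loss, optimizer, parameters=..., update_grad=...) call above.
#
#   device = torch.device('cuda')
#   model = build_model(args).to(device)          # hypothetical factory
#   data_loader = build_loader(args)              # hypothetical factory
#   optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
#   loss_scaler = misc.NativeScaler()
#   for epoch in range(args.epochs):
#       train_stats = train_one_epoch(model, data_loader, optimizer, device,
#                                     epoch, loss_scaler,
#                                     fix_position_ratio_scheduler=fix_position_ratio_scheduler,
#                                     puzzle_patch_size_scheduler=puzzle_patch_size_scheduler,
#                                     log_writer=log_writer, args=args)
# -----------------------------------------------------------------------------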