code of stage1 & 3, remove large files
This view is limited to 50 files because it contains too many changes.
- 1_feature_extractor/1_main_training_IB.py +624 -0
- 1_feature_extractor/1_training_IB.sh +40 -0
- 1_feature_extractor/LICENSE +21 -0
- 1_feature_extractor/README copy.md +24 -0
- 1_feature_extractor/README.md +17 -0
- 1_feature_extractor/__pycache__/augmentations.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/datasets.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/losses_hint.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_IB.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_clip.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_dinov2.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_proteus_clip.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_proteus_dinov2.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_proteus_synclr.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/models_synclr.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/samplers.cpython-39.pyc +0 -0
- 1_feature_extractor/__pycache__/utils.cpython-39.pyc +0 -0
- 1_feature_extractor/augmentations.py +94 -0
- 1_feature_extractor/datasets.py +110 -0
- 1_feature_extractor/fast_vis.sh +37 -0
- 1_feature_extractor/fast_vis_proteus_feats.py +98 -0
- 1_feature_extractor/fast_vis_settings_all.py +548 -0
- 1_feature_extractor/log/DINOv2_training/log.txt +203 -0
- 1_feature_extractor/log/DINOv2_training/log/20240725_001002.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240725_084736.log +555 -0
- 1_feature_extractor/log/DINOv2_training/log/20240725_085916.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240726_110417.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240726_171814.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240728_153020.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240728_214526.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240729_102738.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240730_084148.log +301 -0
- 1_feature_extractor/log/DINOv2_training/log/20240730_085449.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240731_102940.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240801_091959.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240801_155326.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240803_163338.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240803_231933.log +0 -0
- 1_feature_extractor/log/DINOv2_training/log/20240804_144252.log +0 -0
- 1_feature_extractor/losses_hint.py +49 -0
- 1_feature_extractor/main.py +520 -0
- 1_feature_extractor/models_IB.py +40 -0
- 1_feature_extractor/models_clip.py +438 -0
- 1_feature_extractor/models_dinov2.py +907 -0
- 1_feature_extractor/models_proteus_clip.py +101 -0
- 1_feature_extractor/models_proteus_dinov2.py +200 -0
- 1_feature_extractor/models_proteus_synclr.py +161 -0
- 1_feature_extractor/models_synclr.py +500 -0
- 1_feature_extractor/original_images.png +0 -0
- 1_feature_extractor/requirements.txt +134 -0
1_feature_extractor/1_main_training_IB.py
ADDED
@@ -0,0 +1,624 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.backends.cudnn as cudnn
+import json
+
+from pathlib import Path
+
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler, get_state_dict, ModelEma
+from augmentations import collate_data_and_cast_aug
+from datasets import build_dataset
+
+from losses_hint import DistillationLoss
+from samplers import RASampler
+from functools import partial
+
+import importlib
+import utils
+import random
+import math
+from multiprocessing import Value
+from abc import ABC
+
+import sys
+from typing import Iterable, Optional
+from timm.data import Mixup
+from timm.utils import accuracy, ModelEma
+import utils
+import logging
+import torch.distributed as dist
+import os
+
+
+class MaskingGenerator(ABC):
+    def __init__(self, input_size):
+        if not isinstance(input_size, tuple):
+            input_size = (input_size,) * 2
+        self.height, self.width = input_size
+        self.num_patches = self.height * self.width
+
+    def __repr__(self):
+        raise NotImplementedError
+
+    def get_shape(self):
+        return self.height, self.width
+
+    def _mask(self, mask, max_mask_patches):
+        raise NotImplementedError
+
+    def get_none_mask(self):
+        return np.zeros(shape=self.get_shape(), dtype=bool)
+
+
+class RandomMaskingGenerator(MaskingGenerator):
+    def __init__(
+        self,
+        input_size,
+    ):
+        """
+        Args:
+            input_size: the size of the token map, e.g., 14x14
+        """
+        super().__init__(input_size)
+
+    def __repr__(self):
+        repr_str = f"Random Generator({self.height}, {self.width})"
+        return repr_str
+
+    def _mask(self, mask, max_mask_patches):
+        return super()._mask(mask, max_mask_patches)
+
+    def __call__(self, num_masking_patches=0):
+        if num_masking_patches <= 0:
+            return np.zeros(shape=self.get_shape(), dtype=bool)
+
+        mask = np.hstack([np.ones(num_masking_patches, dtype=bool),
+                          np.zeros(self.num_patches - num_masking_patches, dtype=bool)])
+        np.random.shuffle(mask)
+        mask = mask.reshape(self.get_shape())
+        return mask
+
+
+def setup_logger(log_dir, rank=0):
+    if rank != 0:
+        return  # only the main process (rank 0) configures the logger
+    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
+    root_logger = logging.getLogger()
+    root_logger.setLevel(logging.INFO)
+
+    log_file_handler = logging.FileHandler(log_dir, encoding='utf-8')
+    log_file_handler.setFormatter(log_formatter)
+    root_logger.addHandler(log_file_handler)
+
+    log_stream_handler = logging.StreamHandler(sys.stdout)
+    log_stream_handler.setFormatter(log_formatter)
+    root_logger.addHandler(log_stream_handler)
+
+    logging.info('Logging file is %s' % log_dir)
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+    parser.add_argument('--batch-size', default=64, type=int)
+    parser.add_argument('--epochs', default=300, type=int)
+    parser.add_argument('--bce-loss', action='store_true')
+    parser.add_argument('--unscale-lr', action='store_true')
+
+    # Model parameters
+    parser.add_argument('--model', default='deit_base_patch16_224', type=str)
+    parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
+    parser.add_argument('--input-size', default=224, type=int, help='images input size')
+
+    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                        help='Dropout rate (default: 0.)')
+    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
+                        help='Drop path rate (default: 0.1)')
+
+    parser.add_argument('--model-ema', action='store_true')
+    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
+    parser.set_defaults(model_ema=True)
+    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
+    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
+
+    # Optimizer parameters
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
+                        help='Clip gradient norm (default: None, no clipping)')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+    parser.add_argument('--weight-decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+    # Learning rate schedule parameters
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine")')
+    parser.add_argument('--lr', type=float, default=4e-4, metavar='LR',
+                        help='learning rate (default: 4e-4)')
+    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                        help='learning rate noise limit percent (default: 0.67)')
+    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                        help='learning rate noise std-dev (default: 1.0)')
+    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                        help='epoch interval to decay LR')
+    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                        help='patience epochs for Plateau LR scheduler (default: 10)')
+    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
+    # Augmentation parameters
+    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
+                        help='Color jitter factor (default: 0.3)')
+    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
+    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+    parser.add_argument('--train-interpolation', type=str, default='bicubic',
+                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+    parser.add_argument('--repeated-aug', action='store_true')
+    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
+    parser.set_defaults(repeated_aug=True)
+
+    parser.add_argument('--train-mode', action='store_true')
+    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
+    parser.set_defaults(train_mode=True)
+
+    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment
+
+    parser.add_argument('--src', action='store_true')  # simple random crop
+
+    # add dataset parameters
+    parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
+                        help="this should be equal to image size")
+    parser.add_argument('--patch_size', default=16, type=int,
+                        help="patch size for vit patch embedding")
+
+    # add masking parameter
+    parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
+                        help="mask ratio can be either a value or a range")
+    parser.add_argument('--mask_probability', default=0., type=float,
+                        help="how many samples will be applied with masking")
+    parser.add_argument('--mask_first_n', action='store_true',
+                        help="mask the first n samples to avoid shuffling. Needed for MAE-style encoder")
+    parser.add_argument('--clone_batch', default=1, type=int,
+                        help="how many times to clone the batch for masking (default: 1, not cloning)")
+
+    # * Random Erase params
+    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                        help='Random erase prob (default: 0.25)')
+    parser.add_argument('--remode', type=str, default='pixel',
+                        help='Random erase mode (default: "pixel")')
+    parser.add_argument('--recount', type=int, default=1,
+                        help='Random erase count (default: 1)')
+    parser.add_argument('--resplit', action='store_true', default=False,
+                        help='Do not random erase first (clean) augmentation split')
+
+    # * Mixup params
+    parser.add_argument('--mixup', type=float, default=0.8,
+                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+    parser.add_argument('--cutmix', type=float, default=1.0,
+                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+    parser.add_argument('--mixup-prob', type=float, default=1.0,
+                        help='Probability of performing mixup or cutmix when either/both is enabled')
+    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
+    parser.add_argument('--mixup-mode', type=str, default='batch',
+                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+    # Distillation parameters
+    parser.add_argument('--teacher-model', default='base', type=str)
+    parser.add_argument('--teacher-path', type=str, default='')
+    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
+    parser.add_argument('--lambda_token', type=float, default=1.0)
+    parser.add_argument('--lambda_fea', type=float, default=1.0)
+    parser.add_argument('--lambda_patch', type=float, default=1.0)
+
+    # * Cosub params
+    parser.add_argument('--cosub', action='store_true')
+
+    # * Finetuning params
+    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+    parser.add_argument('--attn-only', action='store_true')
+    parser.add_argument('--weight_inherit', default='')
+
+    # Dataset parameters
+    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
+                        help='dataset path')
+    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug', 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
+                        type=str, help='Image Net dataset path')
+    parser.add_argument('--inat-category', default='name',
+                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
+                        type=str, help='semantic granularity')
+
+    parser.add_argument('--output_dir', default='',
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--log_dir', default='/data1/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log',
+                        type=str, help='saving logging info every 20 iters')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
+    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--pin-mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
+                        help='')
+    parser.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    return parser
+
+
+def show_learnable_params(model):
+    enabled = set()
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            enabled.add(name)
+    # print("Parameters to be updated: ")
+    logging.info("Parameters to be updated: ")
+    for each in enabled:
+        # print('\t{}\n'.format(str(each)))
+        logging.info('\t{}\n'.format(str(each)))
+    # print('\n')
+    logging.info('\n')
+
+def show_unlearnable_params(model):
+    disabled = set()
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            disabled.add(name)
+
+    logging.info("Parameters that are not being updated: ")
+    for each in disabled:
+        logging.info('\t{}'.format(str(each)))
+    logging.info('\n')
+
+def main(args):
+    utils.init_distributed_mode(args)
+
+    print(args)
+
+    device = torch.device(args.device)
+    # get the rank of the current process
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    # set up logger
+    os.makedirs(args.log_dir, exist_ok=True)
+    setup_logger(args.log_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log', rank)
+    logging.info('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+    logging.info("{}".format(args).replace(', ', ',\n') + '\n')
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    # random.seed(seed)
+
+    cudnn.benchmark = True
+
+    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+    logging.info(dataset_train)
+
+    if args.distributed:
+        num_tasks = utils.get_world_size()
+        global_rank = utils.get_rank()
+        if args.repeated_aug:
+            sampler_train = RASampler(
+                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+            )
+        else:
+            sampler_train = torch.utils.data.DistributedSampler(
+                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+            )
+    else:
+        sampler_train = torch.utils.data.RandomSampler(dataset_train)
+    logging.info("Sampler_train = %s" % str(sampler_train))
+
+    n_tokens = (args.global_crops_size // args.patch_size) ** 2
+    mask_generator = RandomMaskingGenerator(
+        input_size=args.global_crops_size // args.patch_size,
+    )
+
+    collate_fn = partial(
+        collate_data_and_cast_aug,
+        mask_ratio=args.mask_ratio,
+        mask_probability=args.mask_probability,
+        dtype=torch.half,  # half precision
+        n_tokens=n_tokens,
+        mask_first_n=args.mask_first_n,
+        mask_generator=mask_generator,
+        clone_batch=args.clone_batch,
+    )
+
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, sampler=sampler_train,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        drop_last=True,
+        collate_fn=collate_fn,
+    )
+
+    mixup_fn = None
+
+    print(f"Creating model: {args.model}")  # models_proteus_dinov2
+    meta_arch_module = importlib.import_module(args.model)
+    MetaArch = meta_arch_module.MetaArch
+
+    model = MetaArch(args)
+    logging.info("Model = %s" % str(model))
+
+    if args.finetune:
+        checkpoint = torch.load(args.finetune, map_location='cpu')
+
+        if 'state_dict' in checkpoint:
+            pretrained_dict = checkpoint['state_dict']
+        elif 'model' in checkpoint:
+            pretrained_dict = checkpoint['model']
+        else:
+            pretrained_dict = checkpoint
+
+        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+        # print('missing_keys: ', missing_keys)
+        # print('unexpected_keys: ', unexpected_keys)
+        logging.info('Finetuning from %s' % args.finetune)
+        logging.info('missing_keys: %s' % str(missing_keys))
+        logging.info('unexpected_keys: %s' % str(unexpected_keys))
+
+    if args.attn_only:
+        for name_p, p in model.named_parameters():
+            if '.attn.' in name_p:
+                p.requires_grad = True
+            else:
+                p.requires_grad = False
+        try:
+            model.head.weight.requires_grad = True
+            model.head.bias.requires_grad = True
+        except:
+            model.fc.weight.requires_grad = True
+            model.fc.bias.requires_grad = True
+        try:
+            model.pos_embed.requires_grad = True
+        except:
+            print('no position encoding')
+        try:
+            for p in model.patch_embed.parameters():
+                p.requires_grad = False
+        except:
+            print('no patch embed')
+
+    model.to(device)
+
+    model_ema = None
+    if args.model_ema:
+        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+        model_ema = ModelEma(
+            model.student.backbone,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume='')
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+        model_without_ddp = model.module
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    # print('number of params:', n_parameters)
+    logging.info('number of params: %s' % n_parameters)
+
+    if not args.unscale_lr:
+        logging.info('base lr = %s' % args.lr)
+        linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
+        args.lr = linear_scaled_lr
+        logging.info('actual lr = %s' % linear_scaled_lr)
+
+    optimizer = create_optimizer(args, model_without_ddp)
+    loss_scaler = NativeScaler()
+
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+
+    output_dir = Path(args.output_dir)
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            if args.model_ema:
+                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+            if 'scaler' in checkpoint:
+                loss_scaler.load_state_dict(checkpoint['scaler'])
+        lr_scheduler.step(args.start_epoch)
+        logging.info('Resuming from %s' % args.resume)
+
+    # print(f"Start training for {args.epochs} epochs")
+    logging.info("Start training for %s epochs" % args.epochs)
+    start_time = time.time()
+    max_accuracy = 0.0
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            data_loader_train.sampler.set_epoch(epoch)
+
+        if epoch < 5:
+            # for the first 5 epochs, only unfreeze the entropy model parameters
+            for name, param in model.named_parameters():
+                if 'info_bottleneck' in name:
+                    param.requires_grad = True
+                else:
+                    param.requires_grad = False
+            if epoch == 0:
+                show_learnable_params(model)
+        else:
+            # for the remaining epochs, unfreeze all parameters but keep model.teacher frozen
+            for name, param in model.named_parameters():
+                param.requires_grad = True
+            for name, param in model.named_parameters():
+                if 'teacher' in name:
+                    param.requires_grad = False
+            if epoch == 5:
+                show_unlearnable_params(model)
+
+        train_stats = train_one_epoch(
+            model, data_loader_train,
+            optimizer, device, epoch, loss_scaler,
+            args.clip_grad, model_ema, mixup_fn,
+            set_training_mode=args.train_mode,  # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
+            args=args,
+        )
+
+        lr_scheduler.step(epoch)
+        if args.output_dir:
+            checkpoint_paths = [output_dir / 'checkpoint.pth']
+            for checkpoint_path in checkpoint_paths:
+                utils.save_on_master({
+                    'model': model_without_ddp.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'lr_scheduler': lr_scheduler.state_dict(),
+                    'epoch': epoch,
+                    'model_ema': get_state_dict(model_ema),
+                    'scaler': loss_scaler.state_dict(),
+                    'args': args,
+                }, checkpoint_path)
+        if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs):
+            checkpoint_path = output_dir / f'checkpoint{epoch:04}.pth'
+            utils.save_on_master({
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
+                'epoch': epoch,
+                'model_ema': get_state_dict(model_ema),
+                'scaler': loss_scaler.state_dict(),
+                'args': args,
+            }, checkpoint_path)
+
+        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                     'epoch': epoch,
+                     'n_parameters': n_parameters}
+
+        if args.output_dir and utils.is_main_process():
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    # print('Training time {}'.format(total_time_str))
+    logging.info('Training time %s' % total_time_str)
+
+
+def train_one_epoch(model: torch.nn.Module,
+                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                    set_training_mode=True, args=None):
+    model.train(set_training_mode)
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = 20
+
+    loader_len = len(data_loader)
+
+    for data_iter_step, inputs_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+        for k, v in inputs_dict.items():
+            if isinstance(v, torch.Tensor):
+                inputs_dict[k] = v.to(device, non_blocking=True)
+
+        with torch.cuda.amp.autocast():
+            loss_dict = model(inputs_dict)
+
+        loss = loss_dict["loss"]
+        patch_loss = loss_dict["patch_loss"]
+        fea_loss = loss_dict["fea_loss"]
+        token_loss = loss_dict["token_loss"]
+        bpp_loss = loss_dict["bpp_loss"]
+        task_loss = loss_dict["task_loss"]
+
+        patch_loss_value = patch_loss.item()
+        token_loss_value = token_loss.item()
+        fea_loss_value = fea_loss.item()
+        bpp_loss_value = bpp_loss.item()
+        task_loss_value = task_loss.item()
+        loss_value = loss.item()
+
+        if not math.isfinite(loss_value):
+            # print("Loss is {}, stopping training".format(loss_value))
+            logging.info("Loss is %s, stopping training" % loss_value)
+            logging.info("bpp_loss is {}, patch_loss is {}, token_loss is {}, fea_loss is {}".format(bpp_loss_value, patch_loss_value, token_loss_value, fea_loss_value))
+            sys.exit(1)
+
+        optimizer.zero_grad()
+
+        # this attribute is added by timm on one optimizer (adahessian)
+        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+        loss_scaler(loss, optimizer, clip_grad=max_norm,
+                    parameters=model.parameters(), create_graph=is_second_order)
+
+        torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model.module.student.backbone)
+
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(task_loss=task_loss_value)
+        metric_logger.update(bpp_loss=bpp_loss_value)
+        metric_logger.update(patch_loss=patch_loss_value)
+        metric_logger.update(token_loss=token_loss_value)
+        metric_logger.update(fea_loss=fea_loss_value)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    # print("Averaged stats:", metric_logger)
+    logging.info("Averaged stats: {}".format(metric_logger))
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
+    args = parser.parse_args()
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    main(args)
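The epoch loop in `main` above stages optimization: for the first five epochs only the information-bottleneck (entropy model) parameters receive gradients, after which everything except the frozen teacher is trained. A minimal standalone sketch of that schedule, assuming only that parameter names contain `info_bottleneck` and `teacher` as they do in this file:

```python
import torch.nn as nn

def set_stage_requires_grad(model: nn.Module, epoch: int, warmup_epochs: int = 5) -> None:
    if epoch < warmup_epochs:
        # warm-up: train only the information-bottleneck (entropy model) branch
        for name, param in model.named_parameters():
            param.requires_grad = 'info_bottleneck' in name
    else:
        # afterwards: train everything except the teacher, which stays frozen
        for name, param in model.named_parameters():
            param.requires_grad = 'teacher' not in name
```

This is behaviorally equivalent to the two per-epoch loops in `main`, just factored into a helper.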
1_feature_extractor/1_training_IB.sh
ADDED
@@ -0,0 +1,40 @@
+
+
+#### access DINOv2
+export CUDA_VISIBLE_DEVICES=0,1,2;
+
+python -m torch.distributed.launch --nproc_per_node=3 --use_env 1_main_training_IB.py \
+    --batch-size 48 --warmup-epochs 5 --epochs 200 \
+    --data-set IMNET --data-path '/data1/datasets/imagenet_fold' \
+    --teacher-model vit_large --target_model vit_base --model models_proteus_dinov2 \
+    --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
+    --lambda_token 1.0 --lambda_fea 1.05 --lambda_patch 1.05 \
+    --resume "/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0160.pth" \
+    --log_dir '/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log/' \
+    --output_dir log/DINOv2_training;
+
+
+
+#### access SynCLR
+
+# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
+#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
+#     --data-set IMNET --data-path imagenet_path \
+#     --teacher-model vit_large --target_model vit_base --model models_proteus_synclr \
+#     --teacher-path pretrained_synclr_path \
+#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
+#     --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
+#     --output_dir log/SynCLR_training;
+
+
+
+#### access CLIP
+
+# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
+#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
+#     --data-set IMNET --data-path imagenet_path \
+#     --teacher-model vit_large --target_model vit_base --model models_proteus_clip \
+#     --teacher-path pretrained_clip_path \
+#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
+#     --lambda_token 1.0 --lambda_fea 0.0 --lambda_patch 0.0 \
+#     --output_dir log/CLIP_training;
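The active command launches three workers with `torch.distributed.launch --use_env`, so each process receives its rank through environment variables rather than a `--local_rank` argument. `utils.py` is not included in this view; as a rough sketch under that assumption, `utils.init_distributed_mode` would read those variables along these conventional lines:

```python
import os
import torch
import torch.distributed as dist

def init_distributed_mode(args):
    # with --use_env, torch.distributed.launch exports RANK / WORLD_SIZE / LOCAL_RANK
    args.rank = int(os.environ.get('RANK', 0))
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
    args.gpu = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(args.gpu)
    dist.init_process_group(backend='nccl', init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()  # make sure every worker is up before training starts
```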
1_feature_extractor/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Yunpeng Qi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
1_feature_extractor/README copy.md
ADDED
@@ -0,0 +1,24 @@
+# Pre-training on ImageNet-1K
+
+
+## Installation
+Please follow the installation instructions in [DINOv2](https://github.com/facebookresearch/dinov2/tree/main?tab=readme-ov-file#installation) and install timm==0.9.16 as well.
+
+## Dataset
+We prepare ImageNet-1K following the instructions in [DeiT](https://github.com/facebookresearch/deit/blob/main/README_deit.md#data-preparation).
+
+## Training
+1. Specify the directory of the datasets with `data-path` in the training script `run_pretrain.sh`.
+2. Use the `teacher-model` and `target_model` parameters to select the appropriate teacher and student models.
+3. Specify the model choice with `model` to choose among DINOv2, SynCLR, and CLIP.
+4. For SynCLR and CLIP training, use the `teacher-path` parameter to indicate the path to the pre-trained teacher model.
+5. Simply run the training script as follows:
+
+```
+bash run_pretrain.sh
+```
+
+
+## Acknowledgment
+
+This part is heavily built upon [DeiT](https://github.com/facebookresearch/deit?tab=readme-ov-file), [DINOv2](https://github.com/facebookresearch/dinov2), and [SynCLR](https://github.com/google-research/syn-rep-learn/tree/main/SynCLR). We gratefully thank the authors for their wonderful work.
1_feature_extractor/README.md
ADDED
@@ -0,0 +1,17 @@
+# 1_feature_extractor
+Training the information bottleneck.
+
+IF Model: https://huggingface.co/Qiyp/1_feature_extractor
+
+Installation:
+Clone the repository and then use the provided requirements.txt to install the dependencies:
+pip install -r requirements.txt
+
+Put proteus_vitb_backbone.pth into ./ckpt.
+
+Load a pretrained feature extractor for further training: --resume + feature extractor path
+
+Load a pretrained feature extractor for feature extraction: --finetune + feature extractor path
+
+Training script: 1_training_IB.sh
+Visualize the features' semantic information: train_dec.sh
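The `--resume` / `--finetune` distinction above maps directly onto two branches of `main` in `1_main_training_IB.py` (shown earlier in this diff); condensed here for reference, with `args`, `model`, `optimizer`, and `lr_scheduler` assumed to be in scope:

```python
import torch

# --finetune: weights only, loaded non-strictly (feature extraction)
checkpoint = torch.load(args.finetune, map_location='cpu')
state = checkpoint.get('state_dict', checkpoint.get('model', checkpoint))
model.load_state_dict(state, strict=False)

# --resume: weights plus optimizer, LR scheduler, and epoch counter (continue training)
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
```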
1_feature_extractor/__pycache__/augmentations.cpython-39.pyc
ADDED
Binary file (2.08 kB).

1_feature_extractor/__pycache__/datasets.cpython-39.pyc
ADDED
Binary file (3.12 kB).

1_feature_extractor/__pycache__/losses_hint.cpython-39.pyc
ADDED
Binary file (2.17 kB).

1_feature_extractor/__pycache__/models_IB.cpython-39.pyc
ADDED
Binary file (1.62 kB).

1_feature_extractor/__pycache__/models_clip.cpython-39.pyc
ADDED
Binary file (12.6 kB).

1_feature_extractor/__pycache__/models_dinov2.cpython-39.pyc
ADDED
Binary file (27.1 kB).

1_feature_extractor/__pycache__/models_proteus_clip.cpython-39.pyc
ADDED
Binary file (2.26 kB).

1_feature_extractor/__pycache__/models_proteus_dinov2.cpython-39.pyc
ADDED
Binary file (4.61 kB).

1_feature_extractor/__pycache__/models_proteus_synclr.cpython-39.pyc
ADDED
Binary file (3.58 kB).

1_feature_extractor/__pycache__/models_synclr.cpython-39.pyc
ADDED
Binary file (16.1 kB).

1_feature_extractor/__pycache__/samplers.cpython-39.pyc
ADDED
Binary file (2.25 kB).

1_feature_extractor/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (7.63 kB).
1_feature_extractor/augmentations.py
ADDED
@@ -0,0 +1,94 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import random
+from torchvision import transforms
+import torch
+
+logger = logging.getLogger("dinov2")
+
+
+
+def collate_data_and_cast_aug(
+    samples_list,
+    mask_ratio,
+    mask_probability,
+    dtype,
+    n_tokens=None,
+    mask_first_n=False,
+    mask_generator=None,
+    clone_batch=1,
+):
+    # dtype = torch.half  # TODO: Remove
+
+    n_global_crops = 1
+
+    assert n_global_crops > 0, "global crops number should be > 0"
+    collated_global_crops = torch.stack([s[i] for i in range(n_global_crops) for s in samples_list])
+
+    labels = [s[1] for s in samples_list]
+    labels = torch.LongTensor(labels)
+    collated_global_labels = labels.repeat(n_global_crops)
+
+    B = len(collated_global_crops)
+    N = n_tokens
+    n_samples_masked = int(B * mask_probability)
+
+    masks_list = []
+    upperbound = 0
+
+    masks_enc = torch.full((1,), 0, dtype=torch.int32)
+    masks_pred = torch.full((1,), 0, dtype=torch.int32)
+    # specify the number of masks to append
+    number_masks = n_samples_masked * clone_batch
+    # do per-sample masking
+    if isinstance(mask_ratio, (tuple, list)) and len(mask_ratio) == 2:
+        probs = torch.linspace(*mask_ratio, number_masks + 1)
+        for i in range(0, number_masks):
+            prob_min = probs[i]
+            prob_max = probs[i + 1]
+            masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
+            upperbound += int(N * prob_max)
+    else:
+        mask_ratio = mask_ratio[0]
+        # apply the same mask ratio to all images
+        for i in range(0, number_masks):
+            masks_list.append(torch.BoolTensor(mask_generator(int(N * mask_ratio))))
+            upperbound += int(N * mask_ratio)
+
+    # append masks for unmasked samples
+    for i in range(n_samples_masked, B):
+        # masks_list.append(torch.BoolTensor(mask_generator(0)))
+        masks_list.append(torch.BoolTensor(mask_generator.get_none_mask()))
+
+    if not mask_first_n and mask_probability > 0.0:  # shuffle masking -- not shuffling for mae-style
+        random.shuffle(masks_list)
+
+    collated_masks = torch.stack(masks_list).flatten(1)
+    mask_indices_list = collated_masks.flatten().nonzero().flatten()
+
+    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
+
+    return {
+        "collated_global_crops": collated_global_crops.to(dtype),
+        "collated_global_labels": collated_global_labels,
+        "collated_masks": collated_masks,
+        "mask_indices_list": mask_indices_list,
+        "masks_weight": masks_weight,
+        "upperbound": upperbound,
+        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
+        "masks_enc": masks_enc,
+        "masks_pred": masks_pred,
+    }
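A small usage sketch of `collate_data_and_cast_aug`, wired up the same way `main` in `1_main_training_IB.py` does (that module's name starts with a digit, so it has to be imported via `importlib`; this assumes the repo modules are importable and the random tensors stand in for dataset samples):

```python
import importlib
import torch
from functools import partial

from augmentations import collate_data_and_cast_aug

ib_main = importlib.import_module('1_main_training_IB')

# 224x224 crops with patch size 14 -> a 16x16 token map, 256 tokens per image
mask_generator = ib_main.RandomMaskingGenerator(input_size=224 // 14)
collate_fn = partial(
    collate_data_and_cast_aug,
    mask_ratio=(0.1, 0.5),
    mask_probability=0.5,
    dtype=torch.half,
    n_tokens=(224 // 14) ** 2,
    mask_first_n=True,
    mask_generator=mask_generator,
)
batch = collate_fn([(torch.randn(3, 224, 224), 0) for _ in range(4)])
print(batch["collated_masks"].shape)   # torch.Size([4, 256]); half the rows carry masks
print(int(batch["n_masked_patches"]))  # total number of masked token positions
```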
1_feature_extractor/datasets.py
ADDED
@@ -0,0 +1,110 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+import os
+import json
+from torchvision.datasets import DatasetFolder
+from torchvision.io import read_image
+from torchvision import datasets, transforms
+from torchvision.datasets.folder import ImageFolder, default_loader
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import create_transform
+from PIL import Image
+
+
+class INatDataset(ImageFolder):
+    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
+                 category='name', loader=default_loader):
+        self.transform = transform
+        self.loader = loader
+        self.target_transform = target_transform
+        self.year = year
+        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
+        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
+        with open(path_json) as json_file:
+            data = json.load(json_file)
+
+        with open(os.path.join(root, 'categories.json')) as json_file:
+            data_catg = json.load(json_file)
+
+        path_json_for_targeter = os.path.join(root, f"train{year}.json")
+
+        with open(path_json_for_targeter) as json_file:
+            data_for_targeter = json.load(json_file)
+
+        targeter = {}
+        indexer = 0
+        for elem in data_for_targeter['annotations']:
+            king = []
+            king.append(data_catg[int(elem['category_id'])][category])
+            if king[0] not in targeter.keys():
+                targeter[king[0]] = indexer
+                indexer += 1
+        self.nb_classes = len(targeter)
+
+        self.samples = []
+        for elem in data['images']:
+            cut = elem['file_name'].split('/')
+            target_current = int(cut[2])
+            path_current = os.path.join(root, cut[0], cut[2], cut[3])
+
+            categors = data_catg[target_current]
+            target_current_true = targeter[categors[category]]
+            self.samples.append((path_current, target_current_true))
+
+    # __getitem__ and __len__ inherited from ImageFolder
+
+
+def build_dataset(is_train, args):
+    transform = build_transform(is_train, args)
+
+    if args.data_set == 'CIFAR':
+        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
+        nb_classes = 100
+    elif args.data_set == 'IMNET':
+        root = os.path.join(args.data_path, 'train' if is_train else 'val')
+        dataset = datasets.ImageFolder(root, transform=transform)
+        nb_classes = 1000
+    elif args.data_set == 'INAT':
+        dataset = INatDataset(args.data_path, train=is_train, year=2018,
+                              category=args.inat_category, transform=transform)
+        nb_classes = dataset.nb_classes
+    elif args.data_set == 'INAT19':
+        dataset = INatDataset(args.data_path, train=is_train, year=2019,
+                              category=args.inat_category, transform=transform)
+        nb_classes = dataset.nb_classes
+
+    return dataset, nb_classes
+
+
+def build_transform(is_train, args):
+    resize_im = args.input_size > 32
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=args.input_size,
+            is_training=True,
+            color_jitter=args.color_jitter,
+            auto_augment=args.aa,
+            interpolation=args.train_interpolation,
+            re_prob=args.reprob,
+            re_mode=args.remode,
+            re_count=args.recount,
+        )
+        if not resize_im:
+            # replace RandomResizedCropAndInterpolation with
+            # RandomCrop
+            transform.transforms[0] = transforms.RandomCrop(
+                args.input_size, padding=4)
+        return transform
+
+    t = []
+    if resize_im:
+        size = int(args.input_size / args.eval_crop_ratio)
+        t.append(
+            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
+        )
+        t.append(transforms.CenterCrop(args.input_size))
+
+    t.append(transforms.ToTensor())
+    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+    return transforms.Compose(t)
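For reference, `build_dataset` only needs the argparse fields consumed by `build_transform` and the dataset branch; a quick sketch mirroring the parser defaults from `1_main_training_IB.py` (the dataset path is a placeholder):

```python
import argparse

from datasets import build_dataset

args = argparse.Namespace(
    data_set='IMNET',
    data_path='/path/to/imagenet',  # placeholder: expects train/ and val/ subfolders
    input_size=224,
    color_jitter=0.3,
    aa='rand-m9-mstd0.5-inc1',
    train_interpolation='bicubic',
    reprob=0.25, remode='pixel', recount=1,
    eval_crop_ratio=0.875,
)

dataset_train, nb_classes = build_dataset(is_train=True, args=args)   # nb_classes == 1000
dataset_val, _ = build_dataset(is_train=False, args=args)             # resize + center-crop pipeline
```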
1_feature_extractor/fast_vis.sh
ADDED
@@ -0,0 +1,37 @@
+
+
+#### access DINOv2
+export CUDA_VISIBLE_DEVICES=6;
+
+python fast_vis_settings_all.py \
+    --batch-size 64 --warmup-epochs 5 --epochs 300 \
+    --data-set IMNET --data-path '/data1/datasets/imagenet_fold' \
+    --teacher-model vit_large --target_model vit_base --model models_proteus_dinov2 \
+    --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
+    --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
+    --finetune "/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0160.pth" \
+    --output_dir log/DINOv2_training;
+
+#### access SynCLR
+
+# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
+#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
+#     --data-set IMNET --data-path imagenet_path \
+#     --teacher-model vit_large --target_model vit_base --model models_proteus_synclr \
+#     --teacher-path pretrained_synclr_path \
+#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
+#     --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
+#     --output_dir log/SynCLR_training;
+
+
+
+#### access CLIP
+
+# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
+#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
+#     --data-set IMNET --data-path imagenet_path \
+#     --teacher-model vit_large --target_model vit_base --model models_proteus_clip \
+#     --teacher-path pretrained_clip_path \
+#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
+#     --lambda_token 1.0 --lambda_fea 0.0 --lambda_patch 0.0 \
+#     --output_dir log/CLIP_training;
1_feature_extractor/fast_vis_proteus_feats.py
ADDED
@@ -0,0 +1,98 @@
import math
import matplotlib.pyplot as plt
import numpy as np
import einops
import torch
from PIL import Image
import torchvision
import models_dinov2

device = "cuda" if torch.cuda.is_available() else "cpu"
patch_size = 14

# feat_extractor = getattr(models_dinov2, 'vit_base')
feat_extractor = getattr(models_dinov2, 'vit_large')
model = feat_extractor(img_size=224,
                       patch_size=14,
                       init_values=1.0,
                       ffn_layer='mlp',
                       block_chunks=0,
                       num_register_tokens=0,
                       interpolate_antialias=False,
                       interpolate_offset=0.1)

# checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth'  # replace with the actual checkpoint path
checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitl_backbone.pth'  # replace with the actual checkpoint path
# load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# load the model weights
if 'state_dict' in checkpoint:
    pretrained_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
    pretrained_dict = checkpoint['model']
else:
    pretrained_dict = checkpoint

# only load the weights that match the student model
model.load_state_dict(pretrained_dict, strict=False)
model.to(device)
patch_h = 224 // 14
patch_w = 224 // 14
feat_dim = 768  # note: unused below (view() infers the dim); vit_large patch features are actually 1024-dim

def visualize_features(features, output_path='./feature_visualization.png'):
    # Assuming features are of shape (batch_size, num_features, height, width)
    batch_size, num_features, height, width = features.shape

    # Normalize the feature maps to the range [0, 1]
    vis = features.mean(dim=1, keepdim=True)
    vis = vis - vis.min()
    vis = vis / vis.max()

    # Squeeze the channel dimension
    vis = vis.squeeze(1).cpu().detach().numpy()

    # Apply a colormap (e.g., viridis) to convert it to RGB
    vis_colored = np.zeros((batch_size, height, width, 3))
    for i in range(batch_size):
        vis_colored[i] = plt.cm.viridis(vis[i])[:, :, :3]  # Drop the alpha channel

    # Convert vis_colored to a tensor and save using torchvision
    vis_colored = torch.tensor(vis_colored).permute(0, 3, 1, 2)  # Convert to (batch, channels, height, width)

    # Save the image
    torchvision.utils.save_image(vis_colored, output_path, normalize=True)

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize the image
    transforms.ToTensor(),  # convert to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
])

images = [
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/0-anime_boy_sticker__holding_kitten__happy.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/62-Deadpool_minion.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/79-with_Wooden_carved_bear__salmon_and_gold_mini_ball_surround_the_blank_signboard__illustrate.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/99-Akira_toriyama_motorbike__cheatah__puma__japanese_classic_car__collectable_figure__shiny_plastic_.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/124-crowded_1920s_Chicago_street_with_lots_of_model_T_cars_and_people.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/157-steampunk_girl_with_pink_hair_riding_in_a_hot_air_balloon__hot_air_balloon_resembles_gold_and_si.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00008636.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00010240.png"),
]
# inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
tensors = [transform(img) for img in images]
batched_tensors = torch.stack(tensors).to(device)
with torch.no_grad():
    outputs = model(batched_tensors, is_training=True)
    features = outputs['x_norm_patchtokens']  # (batch_size, num_patches, feat_dim)
    print(features.shape)

features = features.view(-1, patch_h, patch_w, features.shape[2])  # [B, h, w, c]
features = features.permute(0, 3, 1, 2)
visualize_features(features)

# pooled_output = outputs.pooler_output  # pooled CLS states.
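The saved map is only 16x16 per image. A minimal follow-up sketch (assumption: standard torch.nn.functional API, not code from this repo) that upsamples the channel-mean map back to the 224x224 input for overlaying:

    import torch.nn.functional as F

    # features: (B, C, 16, 16) as produced above.
    heat = features.mean(dim=1, keepdim=True)            # (B, 1, 16, 16)
    heat = F.interpolate(heat, size=(224, 224),
                         mode='bilinear', align_corners=False)
    heat = heat - heat.amin(dim=(2, 3), keepdim=True)    # per-image min
    heat = heat / (heat.amax(dim=(2, 3), keepdim=True) + 1e-8)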
1_feature_extractor/fast_vis_settings_all.py
ADDED
@@ -0,0 +1,548 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json

from pathlib import Path

from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from augmentations import collate_data_and_cast_aug
from datasets import build_dataset

from losses_hint import DistillationLoss
from samplers import RASampler
from functools import partial

import importlib
import utils
import random
import math
from multiprocessing import Value
from abc import ABC

import sys
from typing import Iterable, Optional
from timm.data import Mixup
from timm.utils import accuracy
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "5"


class MaskingGenerator(ABC):
    def __init__(self, input_size):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size
        self.num_patches = self.height * self.width

    def __repr__(self):
        raise NotImplementedError

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        raise NotImplementedError

    def get_none_mask(self):
        return np.zeros(shape=self.get_shape(), dtype=bool)


class RandomMaskingGenerator(MaskingGenerator):
    def __init__(
        self,
        input_size,
    ):
        """
        Args:
            input_size: the size of the token map, e.g., 14x14
        """
        super().__init__(input_size)

    def __repr__(self):
        repr_str = f"Random Generator({self.height}, {self.width})"
        return repr_str

    def _mask(self, mask, max_mask_patches):
        return super()._mask(mask, max_mask_patches)

    def __call__(self, num_masking_patches=0):
        if num_masking_patches <= 0:
            return np.zeros(shape=self.get_shape(), dtype=bool)

        mask = np.hstack([np.ones(num_masking_patches, dtype=bool),
                          np.zeros(self.num_patches - num_masking_patches, dtype=bool)])
        np.random.shuffle(mask)
        mask = mask.reshape(self.get_shape())
        return mask

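A quick sketch of what the generator yields for the 224/14 = 16-token grid used in this script (illustrative only):

    gen = RandomMaskingGenerator(input_size=16)  # 16x16 token map, 256 patches
    mask = gen(num_masking_patches=128)          # boolean (16, 16) array
    assert mask.shape == (16, 16) and mask.sum() == 128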
def get_args_parser():
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # add dataset parameters
    parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
                        help="this should be equal to image size")
    parser.add_argument('--patch_size', default=16, type=int,
                        help="patch size for vit patch embedding")

    # add masking parameter
    parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
                        help="mask ratio can be either a value or a range")
    parser.add_argument('--mask_probability', default=0., type=float,
                        help="how many samples will be applied with masking")
    parser.add_argument('--mask_first_n', action='store_true',
                        help="mask the first n samples to avoid shuffling. Needed for MAE-style encoder")
    parser.add_argument('--clone_batch', default=1, type=int,
                        help="how many times to clone the batch for masking (default: 1, not cloning)")

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='base', type=str)
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
    parser.add_argument('--lambda_token', type=float, default=1.0)
    parser.add_argument('--lambda_fea', type=float, default=1.0)
    parser.add_argument('--lambda_patch', type=float, default=1.0)

    # * Cosub params
    parser.add_argument('--cosub', action='store_true')

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')
    parser.add_argument('--weight_inherit', default='')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug', 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
                        type=str, help='ImageNet dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser

import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

def visualize_features(features, output_path='./feature_visualization_w_ib.png'):
    # Assuming features are of shape (batch_size, num_features, height, width)
    batch_size, num_features, height, width = features.shape

    # Normalize the feature maps to the range [0, 1]
    vis = features.mean(dim=1, keepdim=True)
    vis = vis - vis.min()
    vis = vis / vis.max()

    # Squeeze the channel dimension
    vis = vis.squeeze(1).cpu().detach().numpy()

    # Apply a colormap (e.g., viridis) to convert it to RGB
    vis_colored = np.zeros((batch_size, height, width, 3))
    for i in range(batch_size):
        vis_colored[i] = plt.cm.viridis(vis[i])[:, :, :3]  # Drop the alpha channel

    # Convert vis_colored to a tensor and save using torchvision
    vis_colored = torch.tensor(vis_colored).permute(0, 3, 1, 2)  # Convert to (batch, channels, height, width)

    # Save the image
    torchvision.utils.save_image(vis_colored, output_path, normalize=True)

def save_original_images(tensors, output_path='./original_images.png'):
    # invert the ImageNet normalization
    unnormalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )
    unnormalized_tensors = [unnormalize(tensor) for tensor in tensors]
    unnormalized_batch = torch.stack(unnormalized_tensors)
    torchvision.utils.save_image(unnormalized_batch, output_path, nrow=4, normalize=True)

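Note that visualize_features normalizes with a single min/max over the whole batch, so one high-activation image can compress the others' dynamic range. A per-image variant would look like the sketch below (an alternative, not what the function above does):

    vis = features.mean(dim=1, keepdim=True)                  # (B, 1, H, W)
    vis = vis - vis.amin(dim=(2, 3), keepdim=True)            # per-image min
    vis = vis / (vis.amax(dim=(2, 3), keepdim=True) + 1e-8)   # per-image max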
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    n_tokens = (args.global_crops_size // args.patch_size) ** 2
    mask_generator = RandomMaskingGenerator(
        input_size=args.global_crops_size // args.patch_size,
    )

    collate_fn = partial(
        collate_data_and_cast_aug,
        mask_ratio=args.mask_ratio,
        mask_probability=args.mask_probability,
        dtype=torch.half,  # half precision
        n_tokens=n_tokens,
        mask_first_n=args.mask_first_n,
        mask_generator=mask_generator,
        clone_batch=args.clone_batch,
    )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        collate_fn=collate_fn,
    )

    mixup_fn = None

    print(f"Creating model: {args.model}")
    meta_arch_module = importlib.import_module(args.model)
    MetaArch = meta_arch_module.MetaArch

    model = MetaArch(args)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    if args.attn_only:
        for name_p, p in model.named_parameters():
            if '.attn.' in name_p:
                p.requires_grad = True
            else:
                p.requires_grad = False
        try:
            model.head.weight.requires_grad = True
            model.head.bias.requires_grad = True
        except AttributeError:
            model.fc.weight.requires_grad = True
            model.fc.bias.requires_grad = True
        try:
            model.pos_embed.requires_grad = True
        except AttributeError:
            print('no position encoding')
        try:
            for p in model.patch_embed.parameters():
                p.requires_grad = False
        except AttributeError:
            print('no patch embed')

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model.student.backbone,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    if not args.unscale_lr:
        linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
        args.lr = linear_scaled_lr

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.step(args.start_epoch)

    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # resize the image
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])
    from PIL import Image
    images = [
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/0-anime_boy_sticker__holding_kitten__happy.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/62-Deadpool_minion.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/79-with_Wooden_carved_bear__salmon_and_gold_mini_ball_surround_the_blank_signboard__illustrate.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/99-Akira_toriyama_motorbike__cheatah__puma__japanese_classic_car__collectable_figure__shiny_plastic_.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/124-crowded_1920s_Chicago_street_with_lots_of_model_T_cars_and_people.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/157-steampunk_girl_with_pink_hair_riding_in_a_hot_air_balloon__hot_air_balloon_resembles_gold_and_si.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00008636.png"),
        Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00010240.png"),
    ]

    tensors = [transform(img) for img in images]
    batched_tensors = torch.stack(tensors).to(device)
    save_original_images(batched_tensors, output_path='./original_images.png')
    with torch.no_grad():
        outputs = model.student.backbone(batched_tensors, is_training=True)
        features = outputs['x_norm_patchtokens']  # (batch_size, num_patches, feat_dim)
        print(features.shape)
        features, _ = model.info_bottleneck(features, is_training=False)

    features = features.view(-1, 16, 16, features.shape[2])  # [B, h, w, c] with a 224/14 = 16x16 patch grid
    features = features.permute(0, 3, 1, 2)
    visualize_features(features)

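For the linear lr scaling in main() (active unless --unscale-lr is passed), a worked example with this script's defaults:

    # linear_scaled_lr = base_lr * batch_size * world_size / 512
    base_lr, batch_size, world_size = 5e-4, 64, 1
    print(base_lr * batch_size * world_size / 512.0)  # 6.25e-05 on a single GPU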
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args=None):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    loader_len = len(data_loader)

    for data_iter_step, inputs_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        for k, v in inputs_dict.items():
            if isinstance(v, torch.Tensor):
                inputs_dict[k] = v.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss_dict = model(inputs_dict)

        loss = loss_dict["loss"]
        patch_loss = loss_dict["patch_loss"]
        fea_loss = loss_dict["fea_loss"]
        token_loss = loss_dict["token_loss"]

        patch_loss_value = patch_loss.item()
        token_loss_value = token_loss.item()
        fea_loss_value = fea_loss.item()
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model.module.student.backbone)

        metric_logger.update(loss=loss_value)
        metric_logger.update(patch_loss=patch_loss_value)
        metric_logger.update(token_loss=token_loss_value)
        metric_logger.update(fea_loss=fea_loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
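The loss_scaler(...) call in train_one_epoch delegates to timm's NativeScaler; roughly, it wraps the standard torch.cuda.amp pattern sketched below (for orientation only, not a drop-in replacement):

    scaler = torch.cuda.amp.GradScaler()
    scaler.scale(loss).backward()
    if max_norm:                       # optional clipping on unscaled grads
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)
    scaler.update()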
1_feature_extractor/log/DINOv2_training/log.txt
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"train_lr": 1.0000000000000353e-06, "train_loss": 9.738612308347825, "train_task_loss": 4.1162851987411075, "train_bpp_loss": 28.111635084060744, "train_patch_loss": 2.1926947419615765, "train_token_loss": 0.6686155230319091, "train_fea_loss": 1.254974935296104, "epoch": 0, "n_parameters": 144845568}
|
| 2 |
+
{"train_lr": 1.0000000000000353e-06, "train_loss": 9.232701392577802, "train_task_loss": 3.6589209317660734, "train_bpp_loss": 27.86890183459464, "train_patch_loss": 1.9597956623069102, "train_token_loss": 0.6684419002523906, "train_fea_loss": 1.030683369156149, "epoch": 1, "n_parameters": 144845568}
|
| 3 |
+
{"train_lr": 7.579999999999412e-05, "train_loss": 7.21169285546485, "train_task_loss": 3.139855162607585, "train_bpp_loss": 20.35918810360914, "train_patch_loss": 1.6089605613649607, "train_token_loss": 0.6681283143200843, "train_fea_loss": 0.8627662869969861, "epoch": 2, "n_parameters": 144845568}
|
| 4 |
+
{"train_lr": 0.0001506000000000007, "train_loss": 4.554179431234809, "train_task_loss": 3.242361048226067, "train_bpp_loss": 6.559091770915772, "train_patch_loss": 1.606790432570983, "train_token_loss": 0.6681644691438745, "train_fea_loss": 0.9674061452158444, "epoch": 3, "n_parameters": 144845568}
|
| 5 |
+
{"train_lr": 0.00022540000000000762, "train_loss": 3.735120667863807, "train_task_loss": 3.387666613328085, "train_bpp_loss": 1.7372702604485788, "train_patch_loss": 1.6483151540636636, "train_token_loss": 0.6679934723992095, "train_fea_loss": 1.0713579858914555, "epoch": 4, "n_parameters": 144845568}
|
| 6 |
+
{"train_lr": 1.0000000000000597e-06, "train_loss": 5.909162417715259, "train_task_loss": 3.1000390200824333, "train_bpp_loss": 28.091233481177323, "train_patch_loss": 1.2870068884135293, "train_token_loss": 0.6689227937365607, "train_fea_loss": 1.1441093232500539, "epoch": 0, "n_parameters": 144845568}
|
| 7 |
+
{"train_lr": 1.0000000000000597e-06, "train_loss": 5.4208089515959905, "train_task_loss": 2.6417781315404447, "train_bpp_loss": 27.790307712640693, "train_patch_loss": 1.0698811880355295, "train_token_loss": 0.6684176272018064, "train_fea_loss": 0.9034793063801684, "epoch": 1, "n_parameters": 144845568}
|
| 8 |
+
{"train_lr": 2.330000000000072e-05, "train_loss": 4.90253949746025, "train_task_loss": 2.468064440561713, "train_bpp_loss": 24.344750106834965, "train_patch_loss": 0.9905159351565569, "train_token_loss": 0.6681102167086885, "train_fea_loss": 0.809438281672464, "epoch": 2, "n_parameters": 144845568}
|
| 9 |
+
{"train_lr": 4.5600000000004605e-05, "train_loss": 4.039170976486995, "train_task_loss": 2.4127648418648637, "train_bpp_loss": 16.264061019539263, "train_patch_loss": 0.9645157243489183, "train_token_loss": 0.668388756130525, "train_fea_loss": 0.7798603551274688, "epoch": 3, "n_parameters": 144845568}
|
| 10 |
+
{"train_lr": 6.790000000000497e-05, "train_loss": 3.154224511977437, "train_task_loss": 2.4628324630252605, "train_bpp_loss": 6.913920321606761, "train_patch_loss": 0.9734525120118075, "train_token_loss": 0.668419569109877, "train_fea_loss": 0.8209603751900134, "epoch": 4, "n_parameters": 144845568}
|
| 11 |
+
{"train_lr": 9.019999999999779e-05, "train_loss": 2.855963341034145, "train_task_loss": 2.595376729947343, "train_bpp_loss": 2.605866110010399, "train_patch_loss": 1.0060626347426422, "train_token_loss": 0.7175173823357665, "train_fea_loss": 0.8717967045071314, "epoch": 5, "n_parameters": 144845568}
|
| 12 |
+
{"train_lr": 0.00011234201335381617, "train_loss": 2.8053760061786472, "train_task_loss": 2.636510731433507, "train_bpp_loss": 1.6886527469109216, "train_patch_loss": 1.019217302998175, "train_token_loss": 0.7347769178539443, "train_fea_loss": 0.8825165003926742, "epoch": 6, "n_parameters": 144845568}
|
| 13 |
+
{"train_lr": 0.00011227255068590994, "train_loss": 2.761493076880773, "train_task_loss": 2.6242000526464957, "train_bpp_loss": 1.3729302407391994, "train_patch_loss": 1.0161546610956855, "train_token_loss": 0.735874281325143, "train_fea_loss": 0.8721711006757024, "epoch": 7, "n_parameters": 144845568}
|
| 14 |
+
{"train_lr": 0.0001121904989670886, "train_loss": 2.7252962913819783, "train_task_loss": 2.6068851682994008, "train_bpp_loss": 1.1841112330144972, "train_patch_loss": 1.01152529233812, "train_token_loss": 0.7345748117854662, "train_fea_loss": 0.8607850549831915, "epoch": 8, "n_parameters": 144845568}
|
| 15 |
+
{"train_lr": 0.00011209587844235662, "train_loss": 2.6992331013548, "train_task_loss": 2.594223390594661, "train_bpp_loss": 1.0500971039855402, "train_patch_loss": 1.00772097030525, "train_token_loss": 0.7340729326429627, "train_fea_loss": 0.8524294771026257, "epoch": 9, "n_parameters": 144845568}
|
| 16 |
+
{"train_lr": 0.0001119887124579783, "train_loss": 2.678213749384637, "train_task_loss": 2.5821204649494423, "train_bpp_loss": 0.9609328405931592, "train_patch_loss": 1.0029366761982013, "train_token_loss": 0.7342181133186753, "train_fea_loss": 0.8449656662421761, "epoch": 10, "n_parameters": 144845568}
|
| 17 |
+
{"train_lr": 0.00011186902745551124, "train_loss": 2.662179193613555, "train_task_loss": 2.571508611331312, "train_bpp_loss": 0.9067058223556522, "train_patch_loss": 0.9990729609732147, "train_token_loss": 0.7330601164686987, "train_fea_loss": 0.8393755242634985, "epoch": 11, "n_parameters": 144845568}
|
| 18 |
+
{"train_lr": 0.00011173685296543875, "train_loss": 2.6492657475530814, "train_task_loss": 2.562294816915437, "train_bpp_loss": 0.8697093047259702, "train_patch_loss": 0.995477003563526, "train_token_loss": 0.7320167214611916, "train_fea_loss": 0.834801082157617, "epoch": 12, "n_parameters": 144845568}
|
| 19 |
+
{"train_lr": 0.00011159222159984347, "train_loss": 2.6404250215033263, "train_task_loss": 2.5562614543492987, "train_bpp_loss": 0.8416356684604316, "train_patch_loss": 0.9931515277579999, "train_token_loss": 0.7319663158244938, "train_fea_loss": 0.8311436018064546, "epoch": 13, "n_parameters": 144845568}
|
| 20 |
+
{"train_lr": 0.00011143516904437194, "train_loss": 2.631879109013781, "train_task_loss": 2.5498513719625326, "train_bpp_loss": 0.8202773725337011, "train_patch_loss": 0.99033929754665, "train_token_loss": 0.7316777065275706, "train_fea_loss": 0.8278343591714006, "epoch": 14, "n_parameters": 144845568}
|
| 21 |
+
{"train_lr": 0.00011126573404935182, "train_loss": 2.6259225274054265, "train_task_loss": 2.5454919715579467, "train_bpp_loss": 0.8043055602369441, "train_patch_loss": 0.9886904231225737, "train_token_loss": 0.731208170509542, "train_fea_loss": 0.8255933674758108, "epoch": 15, "n_parameters": 144845568}
|
| 22 |
+
{"train_lr": 0.0001110839584203689, "train_loss": 2.6198472417533685, "train_task_loss": 2.5406655372529148, "train_bpp_loss": 0.791817048280913, "train_patch_loss": 0.9866770549277536, "train_token_loss": 0.7308970645976831, "train_fea_loss": 0.8230914076035328, "epoch": 16, "n_parameters": 144845568}
|
| 23 |
+
{"train_lr": 0.0001108898870078482, "train_loss": 2.6147061569584933, "train_task_loss": 2.5367300099552534, "train_bpp_loss": 0.7797614707582468, "train_patch_loss": 0.985000457371221, "train_token_loss": 0.7312123553148455, "train_fea_loss": 0.8205171879849953, "epoch": 17, "n_parameters": 144845568}
|
| 24 |
+
{"train_lr": 0.00011068356769595686, "train_loss": 2.609768410336128, "train_task_loss": 2.5328920738910026, "train_bpp_loss": 0.7687633632072549, "train_patch_loss": 0.9837793158322609, "train_token_loss": 0.7307359146476864, "train_fea_loss": 0.8183768322217844, "epoch": 18, "n_parameters": 144845568}
|
| 25 |
+
{"train_lr": 0.0001104650513909484, "train_loss": 2.6053047203313677, "train_task_loss": 2.52939695734486, "train_bpp_loss": 0.7590776344502179, "train_patch_loss": 0.9826806558796363, "train_token_loss": 0.7301465437459431, "train_fea_loss": 0.8165697485534585, "epoch": 19, "n_parameters": 144845568}
|
| 26 |
+
{"train_lr": 0.00011023439200841275, "train_loss": 2.60126930419847, "train_task_loss": 2.526219468670998, "train_bpp_loss": 0.7504983498341888, "train_patch_loss": 0.981161863478325, "train_token_loss": 0.7301577162488901, "train_fea_loss": 0.8148998804368096, "epoch": 20, "n_parameters": 144845568}
|
| 27 |
+
{"train_lr": 0.00010973687463990103, "train_loss": 2.6637201429449684, "train_task_loss": 2.5465192173382074, "train_bpp_loss": 0.5860046291750648, "train_patch_loss": 0.9867626110592084, "train_token_loss": 0.7298668318939545, "train_fea_loss": 0.8298897645157054, "epoch": 21, "n_parameters": 144845568}
|
| 28 |
+
{"train_lr": 0.00010973687463990103, "train_loss": 2.6604584804970584, "train_task_loss": 2.5488337897550433, "train_bpp_loss": 0.5581234564080084, "train_patch_loss": 0.98777336684643, "train_token_loss": 0.730239200644749, "train_fea_loss": 0.8308212128188684, "epoch": 22, "n_parameters": 144845568}
|
| 29 |
+
{"train_lr": 0.00010947013940891518, "train_loss": 2.6580159922512316, "train_task_loss": 2.5488927058553954, "train_bpp_loss": 0.5456164319233464, "train_patch_loss": 0.9876678835765885, "train_token_loss": 0.7302756367019004, "train_fea_loss": 0.8309491750956189, "epoch": 23, "n_parameters": 144845568}
|
| 30 |
+
{"train_lr": 0.00010919150658002209, "train_loss": 2.6552707054864446, "train_task_loss": 2.548027373189263, "train_bpp_loss": 0.5362166606305208, "train_patch_loss": 0.987526344177838, "train_token_loss": 0.7304425147100747, "train_fea_loss": 0.830058504461045, "epoch": 24, "n_parameters": 144845568}
|
| 31 |
+
{"train_lr": 0.00010890104490178482, "train_loss": 2.6532317997365116, "train_task_loss": 2.5475874447243676, "train_bpp_loss": 0.5282217790852848, "train_patch_loss": 0.9868854832560788, "train_token_loss": 0.7307839357099838, "train_fea_loss": 0.8299180160475363, "epoch": 25, "n_parameters": 144845568}
|
| 32 |
+
{"train_lr": 0.00010828492456631287, "train_loss": 2.6756018440750577, "train_task_loss": 2.554671129796931, "train_bpp_loss": 0.48372285632077133, "train_patch_loss": 0.9893747716762429, "train_token_loss": 0.7302002353723244, "train_fea_loss": 0.83509611297506, "epoch": 26, "n_parameters": 144845568}
|
| 33 |
+
{"train_lr": 0.00010828492456631287, "train_loss": 2.672080845176745, "train_task_loss": 2.5541286950649544, "train_bpp_loss": 0.4718086026637978, "train_patch_loss": 0.9892294842401426, "train_token_loss": 0.7304161941873835, "train_fea_loss": 0.8344830075631754, "epoch": 27, "n_parameters": 144845568}
|
| 34 |
+
{"train_lr": 0.00010795941792757138, "train_loss": 2.6682109270039365, "train_task_loss": 2.552076613383113, "train_bpp_loss": 0.4645372550583911, "train_patch_loss": 0.9893483498738967, "train_token_loss": 0.7297189640332398, "train_fea_loss": 0.8330092908327671, "epoch": 28, "n_parameters": 144845568}
|
| 35 |
+
{"train_lr": 0.00010762238643889585, "train_loss": 2.6648726280626778, "train_task_loss": 2.5502606650068462, "train_bpp_loss": 0.458447851656045, "train_patch_loss": 0.988596663794679, "train_token_loss": 0.7296670780559595, "train_fea_loss": 0.8319969125855848, "epoch": 29, "n_parameters": 144845568}
|
| 36 |
+
{"train_lr": 0.00010727391325772412, "train_loss": 2.66122928966614, "train_task_loss": 2.5479887132521726, "train_bpp_loss": 0.4529623082800985, "train_patch_loss": 0.9877295848607517, "train_token_loss": 0.7298175230917099, "train_fea_loss": 0.8304415963706281, "epoch": 30, "n_parameters": 144845568}
|
| 37 |
+
{"train_lr": 0.00010691408436465084, "train_loss": 2.6585902343425962, "train_task_loss": 2.546459772737978, "train_bpp_loss": 0.44852184575292275, "train_patch_loss": 0.9874212175062734, "train_token_loss": 0.7292956591402884, "train_fea_loss": 0.8297428859540705, "epoch": 31, "n_parameters": 144845568}
|
| 38 |
+
{"train_lr": 0.00010654298854205488, "train_loss": 2.655904775418395, "train_task_loss": 2.5447790948875086, "train_bpp_loss": 0.44450271940216135, "train_patch_loss": 0.9865645243444567, "train_token_loss": 0.7294259656095605, "train_fea_loss": 0.8287885952718573, "epoch": 32, "n_parameters": 144845568}
|
| 39 |
+
{"train_lr": 0.0001061607173522522, "train_loss": 2.653865991727554, "train_task_loss": 2.5435698100184796, "train_bpp_loss": 0.4411847302678981, "train_patch_loss": 0.98615532235535, "train_token_loss": 0.7293780767236003, "train_fea_loss": 0.8280364017178722, "epoch": 33, "n_parameters": 144845568}
|
| 40 |
+
{"train_lr": 0.00010576736511496153, "train_loss": 2.649993774818359, "train_task_loss": 2.540434418590092, "train_bpp_loss": 0.4382374218332289, "train_patch_loss": 0.9851716845977506, "train_token_loss": 0.7283954293907082, "train_fea_loss": 0.8268672945759565, "epoch": 34, "n_parameters": 144845568}
|
| 41 |
+
{"train_lr": 0.00010536302888396166, "train_loss": 2.64796845856116, "train_task_loss": 2.539067151297411, "train_bpp_loss": 0.4356052296792271, "train_patch_loss": 0.9849777258128571, "train_token_loss": 0.7279274276853036, "train_fea_loss": 0.8261619895200828, "epoch": 35, "n_parameters": 144845568}
|
| 42 |
+
{"train_lr": 0.00010494780842314013, "train_loss": 2.646916040002728, "train_task_loss": 2.538628663867712, "train_bpp_loss": 0.43314950562102333, "train_patch_loss": 0.9843870765022594, "train_token_loss": 0.7287507657747831, "train_fea_loss": 0.8254908122640434, "epoch": 36, "n_parameters": 144845568}
|
| 43 |
+
{"train_lr": 0.00010452180618197858, "train_loss": 2.6438271829979025, "train_task_loss": 2.5360889318678304, "train_bpp_loss": 0.43095300335781, "train_patch_loss": 0.9832675016509722, "train_token_loss": 0.7280113412965116, "train_fea_loss": 0.8248100794753511, "epoch": 37, "n_parameters": 144845568}
|
| 44 |
+
{"train_lr": 0.00010408512727011787, "train_loss": 2.6436403147715457, "train_task_loss": 2.536438628754241, "train_bpp_loss": 0.42880674428697363, "train_patch_loss": 0.9832872114673579, "train_token_loss": 0.7284368306166119, "train_fea_loss": 0.8247145772900442, "epoch": 38, "n_parameters": 144845568}
|
| 45 |
+
{"train_lr": 0.00010363787943157281, "train_loss": 2.640265642149414, "train_task_loss": 2.5335381660583636, "train_bpp_loss": 0.42690990286783204, "train_patch_loss": 0.982539902213398, "train_token_loss": 0.7274603552767913, "train_fea_loss": 0.8235378990137427, "epoch": 39, "n_parameters": 144845568}
|
| 46 |
+
{"train_lr": 0.0001031801730180277, "train_loss": 2.638479518384742, "train_task_loss": 2.532214540898871, "train_bpp_loss": 0.4250599086704514, "train_patch_loss": 0.9822922450729268, "train_token_loss": 0.7268436730546738, "train_fea_loss": 0.8230786147199088, "epoch": 40, "n_parameters": 144845568}
|
| 47 |
+
{"train_lr": 0.00010271212096170505, "train_loss": 2.6368867533473517, "train_task_loss": 2.5310158089464374, "train_bpp_loss": 0.4234837778348127, "train_patch_loss": 0.9816671004675204, "train_token_loss": 0.7268876022875881, "train_fea_loss": 0.8224610961522523, "epoch": 41, "n_parameters": 144845568}
|
| 48 |
+
{"train_lr": 0.00010223383874746677, "train_loss": 2.6354142183826554, "train_task_loss": 2.529919751140354, "train_bpp_loss": 0.42197786968612344, "train_patch_loss": 0.9807341171861839, "train_token_loss": 0.7272372794665879, "train_fea_loss": 0.8219483458689398, "epoch": 42, "n_parameters": 144845568}
|
| 49 |
+
{"train_lr": 0.00010174544438424974, "train_loss": 2.634054534628594, "train_task_loss": 2.528918670253645, "train_bpp_loss": 0.4205434575589816, "train_patch_loss": 0.9803530661996713, "train_token_loss": 0.7265599571498964, "train_fea_loss": 0.8220056383613702, "epoch": 43, "n_parameters": 144845568}
|
| 50 |
+
{"train_lr": 0.00010124705837609591, "train_loss": 2.6349923593892184, "train_task_loss": 2.5301807783186723, "train_bpp_loss": 0.4192463266741422, "train_patch_loss": 0.9804664164607366, "train_token_loss": 0.726979716709978, "train_fea_loss": 0.8227346358883331, "epoch": 44, "n_parameters": 144845568}
|
| 51 |
+
{"train_lr": 0.00010073880369226542, "train_loss": 2.63449551127583, "train_task_loss": 2.529952161675163, "train_bpp_loss": 0.4181733996109318, "train_patch_loss": 0.9800947665212323, "train_token_loss": 0.7259607360401242, "train_fea_loss": 0.8238966495861753, "epoch": 45, "n_parameters": 144845568}
|
| 52 |
+
{"train_lr": 0.00010022080573700511, "train_loss": 2.6345730214078222, "train_task_loss": 2.530297286192076, "train_bpp_loss": 0.4171029411935263, "train_patch_loss": 0.9800362460766372, "train_token_loss": 0.7260628667832029, "train_fea_loss": 0.8241981637554703, "epoch": 46, "n_parameters": 144845568}
|
| 53 |
+
{"train_lr": 9.969319231856176e-05, "train_loss": 2.6355131778874985, "train_task_loss": 2.5314376252464394, "train_bpp_loss": 0.4163022163889105, "train_patch_loss": 0.9805744725929426, "train_token_loss": 0.7263783447029029, "train_fea_loss": 0.8244847989500427, "epoch": 47, "n_parameters": 144845568}
|
| 54 |
+
{"train_lr": 9.915609361765753e-05, "train_loss": 2.634658775285637, "train_task_loss": 2.5308349787099043, "train_bpp_loss": 0.4152951848266615, "train_patch_loss": 0.9807003051248029, "train_token_loss": 0.725708744122381, "train_fea_loss": 0.8244259209490473, "epoch": 48, "n_parameters": 144845568}
|
| 55 |
+
{"train_lr": 9.860964215535301e-05, "train_loss": 2.6395780852765776, "train_task_loss": 2.536209069951761, "train_bpp_loss": 0.4134760623657071, "train_patch_loss": 0.9824451866852258, "train_token_loss": 0.7274219641725508, "train_fea_loss": 0.8263419094569177, "epoch": 49, "n_parameters": 144845568}
|
| 56 |
+
{"train_lr": 9.805397276035986e-05, "train_loss": 2.639889271710988, "train_task_loss": 2.536885856259927, "train_bpp_loss": 0.4120136631007264, "train_patch_loss": 0.9825187112367768, "train_token_loss": 0.7275711355386949, "train_fea_loss": 0.8267959996776508, "epoch": 50, "n_parameters": 144845568}
|
| 57 |
+
{"train_lr": 9.748922253581646e-05, "train_loss": 2.6434670704320893, "train_task_loss": 2.5408570495137064, "train_bpp_loss": 0.4104400824474023, "train_patch_loss": 0.9840436073498046, "train_token_loss": 0.7284569417749592, "train_fea_loss": 0.828356491024349, "epoch": 51, "n_parameters": 144845568}
|
| 58 |
+
{"train_lr": 9.691553082535863e-05, "train_loss": 2.6544609911453954, "train_task_loss": 2.5523642709644005, "train_bpp_loss": 0.40838687800259577, "train_patch_loss": 0.9880639841150252, "train_token_loss": 0.7311584631057023, "train_fea_loss": 0.8331418140172708, "epoch": 52, "n_parameters": 144845568}
|
| 59 |
+
{"train_lr": 9.633303917884302e-05, "train_loss": 2.6421383294395264, "train_task_loss": 2.5401510994175642, "train_bpp_loss": 0.40794892040498965, "train_patch_loss": 0.9832984244683199, "train_token_loss": 0.7280591241032671, "train_fea_loss": 0.8287935424283397, "epoch": 53, "n_parameters": 144845568}
|
| 60 |
+
{"train_lr": 9.574189131737902e-05, "train_loss": 2.6519711823504175, "train_task_loss": 2.5503779566384477, "train_bpp_loss": 0.40637290150338085, "train_patch_loss": 0.987097349643207, "train_token_loss": 0.7309697274904814, "train_fea_loss": 0.832310870747647, "epoch": 54, "n_parameters": 144845568}
|
| 61 |
+
{"train_lr": 9.514223309782753e-05, "train_loss": 2.649527451563939, "train_task_loss": 2.5482009406730852, "train_bpp_loss": 0.4053060460971932, "train_patch_loss": 0.9869010643251865, "train_token_loss": 0.7299437226093323, "train_fea_loss": 0.8313561448876521, "epoch": 55, "n_parameters": 144845568}
{"train_lr": 9.453421247691757e-05, "train_loss": 2.643370286264508, "train_task_loss": 2.5421072299049485, "train_bpp_loss": 0.40505222469563784, "train_patch_loss": 0.9843411268666387, "train_token_loss": 0.7289866340465004, "train_fea_loss": 0.8287794603441331, "epoch": 56, "n_parameters": 144845568}
{"train_lr": 9.391797947461475e-05, "train_loss": 2.6461283357577217, "train_task_loss": 2.5451179012316736, "train_bpp_loss": 0.40404173961842804, "train_patch_loss": 0.9857150221699111, "train_token_loss": 0.7291283701508523, "train_fea_loss": 0.8302744993274447, "epoch": 57, "n_parameters": 144845568}
{"train_lr": 9.329368613720009e-05, "train_loss": 2.652468080438084, "train_task_loss": 2.551866788899513, "train_bpp_loss": 0.40240516668471465, "train_patch_loss": 0.9880323676700643, "train_token_loss": 0.7304665103925468, "train_fea_loss": 0.8333679016130112, "epoch": 58, "n_parameters": 144845568}
{"train_lr": 9.266148649972007e-05, "train_loss": 2.650359462124767, "train_task_loss": 2.5499309687752376, "train_bpp_loss": 0.401713975702973, "train_patch_loss": 0.9875440163452991, "train_token_loss": 0.7306666403307581, "train_fea_loss": 0.8317203026117502, "epoch": 59, "n_parameters": 144845568}
{"train_lr": 9.202153654795684e-05, "train_loss": 2.6472997942613326, "train_task_loss": 2.547006531638636, "train_bpp_loss": 0.4011730529006115, "train_patch_loss": 0.986042402963862, "train_token_loss": 0.7303420982616161, "train_fea_loss": 0.8306220213389046, "epoch": 60, "n_parameters": 144845568}
{"train_lr": 9.137399417998249e-05, "train_loss": 2.635524819071273, "train_task_loss": 2.5351910238875615, "train_bpp_loss": 0.40133518058520784, "train_patch_loss": 0.9820189805460169, "train_token_loss": 0.7271724717727275, "train_fea_loss": 0.8259995626531345, "epoch": 61, "n_parameters": 144845568}
{"train_lr": 9.071901916722404e-05, "train_loss": 2.648858120589019, "train_task_loss": 2.5489517275056395, "train_bpp_loss": 0.39962557505043467, "train_patch_loss": 0.986155892487803, "train_token_loss": 0.73087697630029, "train_fea_loss": 0.8319188496388262, "epoch": 62, "n_parameters": 144845568}
{"train_lr": 9.071901916722404e-05, "train_loss": 2.7465927910497436, "train_task_loss": 2.5472986304449092, "train_bpp_loss": 0.3870172494111194, "train_patch_loss": 0.986047504940997, "train_token_loss": 0.7293615024804855, "train_fea_loss": 0.9150785957809261, "epoch": 61, "n_parameters": 144845568}
{"train_lr": 9.071901916722404e-05, "train_loss": 2.732339255324156, "train_task_loss": 2.5342044824247454, "train_bpp_loss": 0.3851228334161518, "train_patch_loss": 0.9823006979486854, "train_token_loss": 0.725924804654839, "train_fea_loss": 0.9085768888921117, "epoch": 62, "n_parameters": 144845568}
{"train_lr": 9.005677311491453e-05, "train_loss": 2.733937658449943, "train_task_loss": 2.5361717377456543, "train_bpp_loss": 0.38358666918550927, "train_patch_loss": 0.9829422572395391, "train_token_loss": 0.7263305272473741, "train_fea_loss": 0.9095888588916984, "epoch": 63, "n_parameters": 144845568}
{"train_lr": 8.938741942239847e-05, "train_loss": 2.735222797715764, "train_task_loss": 2.5376820612332516, "train_bpp_loss": 0.382704222326924, "train_patch_loss": 0.9836816218920605, "train_token_loss": 0.7267059973671103, "train_fea_loss": 0.9100238969839877, "epoch": 64, "n_parameters": 144845568}
{"train_lr": 8.871112324267081e-05, "train_loss": 2.7348346617742836, "train_task_loss": 2.537643769224413, "train_bpp_loss": 0.38169531316588, "train_patch_loss": 0.9843459360355096, "train_token_loss": 0.726475094608087, "train_fea_loss": 0.9095050229219724, "epoch": 65, "n_parameters": 144845568}
{"train_lr": 8.733837255720078e-05, "train_loss": 2.7936204858570934, "train_task_loss": 2.558094645236179, "train_bpp_loss": 0.3368677387395941, "train_patch_loss": 0.9911075603882435, "train_token_loss": 0.7276336732755998, "train_fea_loss": 0.9232887631746755, "epoch": 66, "n_parameters": 144845568}
{"train_lr": 8.733837255720078e-05, "train_loss": 2.7915395827209064, "train_task_loss": 2.558850450686914, "train_bpp_loss": 0.3302956995413565, "train_patch_loss": 0.9912216057375574, "train_token_loss": 0.7270683532455664, "train_fea_loss": 0.9246165518954896, "epoch": 67, "n_parameters": 144845568}
{"train_lr": 8.66422567571558e-05, "train_loss": 2.8010428122443547, "train_task_loss": 2.569311464486791, "train_bpp_loss": 0.3271386409915394, "train_patch_loss": 0.9944994980447012, "train_token_loss": 0.729622550842484, "train_fea_loss": 0.929708367548615, "epoch": 68, "n_parameters": 144845568}
{"train_lr": 8.59398757977085e-05, "train_loss": 2.8097731665634424, "train_task_loss": 2.5788058781241485, "train_bpp_loss": 0.3244854515335775, "train_patch_loss": 0.9977880177184117, "train_token_loss": 0.7315296879824367, "train_fea_loss": 0.9344370006996807, "epoch": 69, "n_parameters": 144845568}
{"train_lr": 8.523140298084917e-05, "train_loss": 2.7927491063747905, "train_task_loss": 2.562695958050821, "train_bpp_loss": 0.3240329879190889, "train_patch_loss": 0.9920173354035945, "train_token_loss": 0.7282957452923345, "train_fea_loss": 0.9266211757105853, "epoch": 70, "n_parameters": 144845568}
{"train_lr": 8.451701311164659e-05, "train_loss": 2.816440251421371, "train_task_loss": 2.5866508388297733, "train_bpp_loss": 0.32121099593403707, "train_patch_loss": 1.0007479091105667, "train_token_loss": 0.7334584623767282, "train_fea_loss": 0.9376889240237878, "epoch": 71, "n_parameters": 144845568}
{"train_lr": 8.379688245511898e-05, "train_loss": 2.8120930993692768, "train_task_loss": 2.5828994919570514, "train_bpp_loss": 0.32026763344004894, "train_patch_loss": 0.9992939126554975, "train_token_loss": 0.732874036779989, "train_fea_loss": 0.9358047071388729, "epoch": 72, "n_parameters": 144845568}
{"train_lr": 8.307118869271464e-05, "train_loss": 2.8027984519763818, "train_task_loss": 2.5742826347978567, "train_bpp_loss": 0.3194215351233272, "train_patch_loss": 0.9962847879741618, "train_token_loss": 0.7302367620156716, "train_fea_loss": 0.932537203840018, "epoch": 73, "n_parameters": 144845568}
{"train_lr": 8.234011087850579e-05, "train_loss": 2.832480542442138, "train_task_loss": 2.6040796041131307, "train_bpp_loss": 0.3165612864526931, "train_patch_loss": 1.0063268801385015, "train_token_loss": 0.7382693186433589, "train_fea_loss": 0.945431755996383, "epoch": 74, "n_parameters": 144845568}
{"train_lr": 8.160382939503717e-05, "train_loss": 2.799674530850826, "train_task_loss": 2.5721579685592824, "train_bpp_loss": 0.3175276949637674, "train_patch_loss": 0.9966744806246756, "train_token_loss": 0.7291926598819576, "train_fea_loss": 0.9309199213762757, "epoch": 75, "n_parameters": 144845568}
{"train_lr": 8.011638332509435e-05, "train_loss": 2.8255829363078666, "train_task_loss": 2.578499776871799, "train_bpp_loss": 0.2999681402433315, "train_patch_loss": 0.9987926142361703, "train_token_loss": 0.7287038154274725, "train_fea_loss": 0.9361036925038381, "epoch": 76, "n_parameters": 144845568}
{"train_lr": 8.011638332509435e-05, "train_loss": 2.840959037218591, "train_task_loss": 2.5953164593778926, "train_bpp_loss": 0.2960671754754919, "train_patch_loss": 1.0047289559068224, "train_token_loss": 0.7329247621774709, "train_fea_loss": 0.9434290263216666, "epoch": 77, "n_parameters": 144845568}
{"train_lr": 7.93655857436786e-05, "train_loss": 2.8381004938922767, "train_task_loss": 2.5934427621946345, "train_bpp_loss": 0.29446334153194903, "train_patch_loss": 1.0050567829807242, "train_token_loss": 0.7319110037146999, "train_fea_loss": 0.9421224841419206, "epoch": 78, "n_parameters": 144845568}
{"train_lr": 7.86103184125689e-05, "train_loss": 2.847826922665254, "train_task_loss": 2.603588223600273, "train_bpp_loss": 0.29271904649034286, "train_patch_loss": 1.0064198724385038, "train_token_loss": 0.7354645022689152, "train_fea_loss": 0.9478742439049682, "epoch": 79, "n_parameters": 144845568}
{"train_lr": 7.785076768264985e-05, "train_loss": 2.8496559623369784, "train_task_loss": 2.6059850953447876, "train_bpp_loss": 0.29151453778183434, "train_patch_loss": 1.0084720284899147, "train_token_loss": 0.7349831929711772, "train_fea_loss": 0.9487828727950045, "epoch": 80, "n_parameters": 144845568}
{"train_lr": 7.708712096171631e-05, "train_loss": 2.8463599768867023, "train_task_loss": 2.6032515673292913, "train_bpp_loss": 0.2907897324716712, "train_patch_loss": 1.0084952734437564, "train_token_loss": 0.7339370518864737, "train_fea_loss": 0.9469011765759554, "epoch": 81, "n_parameters": 144845568}
{"train_lr": 7.631956666815207e-05, "train_loss": 2.8331575906015845, "train_task_loss": 2.5902766323418356, "train_bpp_loss": 0.29128668203445185, "train_patch_loss": 1.0040304742896442, "train_token_loss": 0.7303849485070764, "train_fea_loss": 0.9414473412786242, "epoch": 82, "n_parameters": 144845568}
{"train_lr": 7.554829418450765e-05, "train_loss": 6.022797273306681, "train_task_loss": 5.747255473551776, "train_bpp_loss": 0.15595893754092843, "train_patch_loss": 1.9295167815905037, "train_token_loss": 1.9044995418676047, "train_fea_loss": 2.1045631034237196, "epoch": 83, "n_parameters": 144845568}
{"train_lr": 7.477349381072652e-05, "train_loss": 6.33198429772751, "train_task_loss": 6.103147584757359, "train_bpp_loss": 0.04776455266578447, "train_patch_loss": 2.038426331773722, "train_token_loss": 2.034283296100134, "train_fea_loss": 2.2334817904058837, "epoch": 84, "n_parameters": 144845568}
{"train_lr": 7.399535671720344e-05, "train_loss": 6.269092192288211, "train_task_loss": 6.053333911088874, "train_bpp_loss": 0.029455166824030238, "train_patch_loss": 2.0307816479676704, "train_token_loss": 2.024027929974081, "train_fea_loss": 2.1983768022049675, "epoch": 85, "n_parameters": 144845568}
{"train_lr": 7.321407489761549e-05, "train_loss": 6.215953613303119, "train_task_loss": 6.003636568552442, "train_bpp_loss": 0.026147063460792316, "train_patch_loss": 2.018133198333194, "train_token_loss": 2.0035276831500677, "train_fea_loss": 2.1801732855288387, "epoch": 86, "n_parameters": 144845568}
{"train_lr": 7.242984112156774e-05, "train_loss": 5.868365104446451, "train_task_loss": 5.527729460429088, "train_bpp_loss": 0.2684132926173316, "train_patch_loss": 1.9746971274469254, "train_token_loss": 1.5961082754723674, "train_fea_loss": 2.1526165047361197, "epoch": 87, "n_parameters": 144845568}
{"train_lr": 7.16428488870196e-05, "train_loss": 5.4557995484958735, "train_task_loss": 5.110851676522685, "train_bpp_loss": 0.2943931522651305, "train_patch_loss": 1.8687771737611265, "train_token_loss": 1.3823193566252787, "train_fea_loss": 2.0457307090067105, "epoch": 88, "n_parameters": 144845568}
{"train_lr": 7.085329237251759e-05, "train_loss": 6.162637532811513, "train_task_loss": 5.88701508848144, "train_bpp_loss": 0.13411895593832643, "train_patch_loss": 2.0334672068672286, "train_token_loss": 1.821566376444265, "train_fea_loss": 2.2351796914660316, "epoch": 89, "n_parameters": 144845568}
{"train_lr": 7.006136638931818e-05, "train_loss": 5.024170764112215, "train_task_loss": 4.809058998057025, "train_bpp_loss": 0.05317469353166999, "train_patch_loss": 1.8792113221967393, "train_token_loss": 1.0658738444262235, "train_fea_loss": 2.0503712694118206, "epoch": 90, "n_parameters": 144845568}
{"train_lr": 6.926726633331106e-05, "train_loss": 4.723379994860942, "train_task_loss": 4.515086256485763, "train_bpp_loss": 0.057305316362173765, "train_patch_loss": 1.798007144451999, "train_token_loss": 0.9435908788395264, "train_fea_loss": 1.9508371142764314, "epoch": 91, "n_parameters": 144845568}
{"train_lr": 6.847118813679865e-05, "train_loss": 5.031990508435013, "train_task_loss": 4.765979826932759, "train_bpp_loss": 0.15900398099812293, "train_patch_loss": 1.824176159184828, "train_token_loss": 1.1403188319715987, "train_fea_loss": 1.981633366195025, "epoch": 92, "n_parameters": 144845568}
{"train_lr": 7.631956666815207e-05, "train_loss": 2.814900835471259, "train_task_loss": 2.5911692696259463, "train_bpp_loss": 0.30722877928625336, "train_patch_loss": 1.003347651021247, "train_token_loss": 0.7330356562840674, "train_fea_loss": 0.854785952643364, "epoch": 81, "n_parameters": 144845568}
{"train_lr": 7.631956666815207e-05, "train_loss": 2.8142806054126446, "train_task_loss": 2.5897073253131597, "train_bpp_loss": 0.30940049388993096, "train_patch_loss": 1.0041070785507453, "train_token_loss": 0.7321698584502263, "train_fea_loss": 0.8534303789900289, "epoch": 82, "n_parameters": 144845568}
{"train_lr": 7.554829418450765e-05, "train_loss": 2.816245504384704, "train_task_loss": 2.591411675659301, "train_bpp_loss": 0.30987349456864044, "train_patch_loss": 1.003780739015634, "train_token_loss": 0.7337235686563545, "train_fea_loss": 0.8539073579683364, "epoch": 83, "n_parameters": 144845568}
{"train_lr": 7.477349381072652e-05, "train_loss": 2.800685787649392, "train_task_loss": 2.5759294760688176, "train_bpp_loss": 0.3108195169511691, "train_patch_loss": 0.9979896100653376, "train_token_loss": 0.729064737578376, "train_fea_loss": 0.8488751188081374, "epoch": 84, "n_parameters": 144845568}
{"train_lr": 7.399535671720344e-05, "train_loss": 2.78943671736357, "train_task_loss": 2.5648709468239788, "train_bpp_loss": 0.31156991790573624, "train_patch_loss": 0.9950137249611729, "train_token_loss": 0.7262643160526272, "train_fea_loss": 0.8435928968185608, "epoch": 85, "n_parameters": 144845568}
{"train_lr": 7.321407489761549e-05, "train_loss": 2.844264631443244, "train_task_loss": 2.614422207854563, "train_bpp_loss": 0.319169871819516, "train_patch_loss": 1.0113446999082176, "train_token_loss": 0.7409178713242784, "train_fea_loss": 0.8621596260023096, "epoch": 86, "n_parameters": 144845568}
{"train_lr": 7.242984112156774e-05, "train_loss": 2.851643653814312, "train_task_loss": 2.6122646205198707, "train_bpp_loss": 0.33907441596994153, "train_patch_loss": 1.01671234990651, "train_token_loss": 0.7275969981968545, "train_fea_loss": 0.8679552621183564, "epoch": 87, "n_parameters": 144845568}
{"train_lr": 7.16428488870196e-05, "train_loss": 2.778331270660285, "train_task_loss": 2.552063302813674, "train_bpp_loss": 0.31613819805538945, "train_patch_loss": 0.9926303882571719, "train_token_loss": 0.7193752784768848, "train_fea_loss": 0.8400576262571793, "epoch": 88, "n_parameters": 144845568}
{"train_lr": 7.085329237251759e-05, "train_loss": 2.7897735733261926, "train_task_loss": 2.565236364247845, "train_bpp_loss": 0.311233962220486, "train_patch_loss": 0.9957534326776434, "train_token_loss": 0.7246638578329262, "train_fea_loss": 0.8448190648729602, "epoch": 89, "n_parameters": 144845568}
{"train_lr": 7.006136638931818e-05, "train_loss": 2.7982831528093173, "train_task_loss": 2.574218000189292, "train_bpp_loss": 0.3094736575059664, "train_patch_loss": 1.000583418184887, "train_token_loss": 0.7256146864831555, "train_fea_loss": 0.8480198857010447, "epoch": 90, "n_parameters": 144845568}
{"train_lr": 6.926726633331106e-05, "train_loss": 3.415653905788843, "train_task_loss": 2.975828631783275, "train_bpp_loss": 0.7560260864588872, "train_patch_loss": 1.1288673198967278, "train_token_loss": 0.8508261374780814, "train_fea_loss": 0.9961351647445981, "epoch": 91, "n_parameters": 144845568}
{"train_lr": 6.847118813679865e-05, "train_loss": 6.568326573541982, "train_task_loss": 6.1273969157779815, "train_bpp_loss": 0.5257196640345934, "train_patch_loss": 2.043614938980241, "train_token_loss": 2.040224435022302, "train_fea_loss": 2.04355752422548, "epoch": 92, "n_parameters": 144845568}
{"train_lr": 6.767332822016792e-05, "train_loss": 6.867288129757062, "train_task_loss": 6.089932662566646, "train_bpp_loss": 1.2734335873903297, "train_patch_loss": 2.0432088341144063, "train_token_loss": 2.003620763126162, "train_fea_loss": 2.043103043592984, "epoch": 93, "n_parameters": 144845568}
{"train_lr": 6.687388344341571e-05, "train_loss": 6.613793869366606, "train_task_loss": 4.8460818514299335, "train_bpp_loss": 3.475327093109524, "train_patch_loss": 2.043151048612859, "train_token_loss": 0.7647827833252167, "train_fea_loss": 2.038148040657671, "epoch": 94, "n_parameters": 144845568}
{"train_lr": 6.607305105757049e-05, "train_loss": 6.583523838229174, "train_task_loss": 4.818476951540374, "train_bpp_loss": 3.4695032598624986, "train_patch_loss": 2.0428966481319004, "train_token_loss": 0.737876279511862, "train_fea_loss": 2.037704043197546, "epoch": 95, "n_parameters": 144845568}
{"train_lr": 6.5271028656055e-05, "train_loss": 6.5769114312794, "train_task_loss": 4.8131763801979215, "train_bpp_loss": 3.46658337029586, "train_patch_loss": 2.0428622997046517, "train_token_loss": 0.7325888925904243, "train_fea_loss": 2.0377252054997057, "epoch": 96, "n_parameters": 144845568}
{"train_lr": 6.446801412587525e-05, "train_loss": 6.572428462227329, "train_task_loss": 4.809816579738681, "train_bpp_loss": 3.4640789968754464, "train_patch_loss": 2.0429170189662087, "train_token_loss": 0.7291364202284806, "train_fea_loss": 2.0377631590699408, "epoch": 97, "n_parameters": 144845568}
{"train_lr": 6.36642055988671e-05, "train_loss": 6.5666217358349614, "train_task_loss": 4.80531757448217, "train_bpp_loss": 3.461199431673443, "train_patch_loss": 2.0427514293746984, "train_token_loss": 0.7249221600507267, "train_fea_loss": 2.0376440027742078, "epoch": 98, "n_parameters": 144845568}
{"train_lr": 6.285980140274965e-05, "train_loss": 6.492165406455668, "train_task_loss": 4.802894322652754, "train_bpp_loss": 3.3011300552377314, "train_patch_loss": 2.042857281175806, "train_token_loss": 0.72241166340549, "train_fea_loss": 2.0376253973210243, "epoch": 99, "n_parameters": 144845568}
{"train_lr": 6.205500001222403e-05, "train_loss": 6.469163187473512, "train_task_loss": 4.799823627161751, "train_bpp_loss": 3.2568444883723338, "train_patch_loss": 2.042693032090255, "train_token_loss": 0.7195353948642953, "train_fea_loss": 2.0375952215226505, "epoch": 100, "n_parameters": 144845568}
{"train_lr": 6.847118813679865e-05, "train_loss": 2.708930113305934, "train_task_loss": 2.5679492560668673, "train_bpp_loss": 0.3132908021231635, "train_patch_loss": 0.9976560541538062, "train_token_loss": 0.7208710700648723, "train_fea_loss": 0.8494221233278155, "epoch": 91, "n_parameters": 144845568}
{"train_lr": 6.847118813679865e-05, "train_loss": 2.699149013628705, "train_task_loss": 2.5613451289723246, "train_bpp_loss": 0.3062308639035576, "train_patch_loss": 0.9944496202400978, "train_token_loss": 0.7207848698646724, "train_fea_loss": 0.846110629630264, "epoch": 92, "n_parameters": 144845568}
{"train_lr": 6.767332822016792e-05, "train_loss": 2.695356560687867, "train_task_loss": 2.5580616601490433, "train_bpp_loss": 0.3050997874931805, "train_patch_loss": 0.9943495049304385, "train_token_loss": 0.7195412099754496, "train_fea_loss": 0.8441709351862113, "epoch": 93, "n_parameters": 144845568}
{"train_lr": 6.687388344341571e-05, "train_loss": 2.709555460635921, "train_task_loss": 2.572864127852362, "train_bpp_loss": 0.3037585253364963, "train_patch_loss": 0.9997199533967306, "train_token_loss": 0.7228693755843835, "train_fea_loss": 0.850274788731669, "epoch": 94, "n_parameters": 144845568}
{"train_lr": 6.607305105757049e-05, "train_loss": 2.7318896783841886, "train_task_loss": 2.59231706789965, "train_bpp_loss": 0.31016136522402615, "train_patch_loss": 1.0061367933327954, "train_token_loss": 0.7272278621984865, "train_fea_loss": 0.858952402501262, "epoch": 95, "n_parameters": 144845568}
{"train_lr": 6.5271028656055e-05, "train_loss": 2.7111547499162545, "train_task_loss": 2.569625718514625, "train_bpp_loss": 0.31450896638912335, "train_patch_loss": 1.0005275200364543, "train_token_loss": 0.7153408098586207, "train_fea_loss": 0.8537573801527778, "epoch": 96, "n_parameters": 144845568}
{"train_lr": 6.446801412587525e-05, "train_loss": 2.707958362782173, "train_task_loss": 2.570670375998, "train_bpp_loss": 0.30508442372374006, "train_patch_loss": 0.9985086321964752, "train_token_loss": 0.7216268923039809, "train_fea_loss": 0.8505348416952063, "epoch": 97, "n_parameters": 144845568}
{"train_lr": 6.36642055988671e-05, "train_loss": 2.697063536604317, "train_task_loss": 2.560573679866265, "train_bpp_loss": 0.3033108016983014, "train_patch_loss": 0.9955314651368369, "train_token_loss": 0.7185008986127213, "train_fea_loss": 0.846541307080421, "epoch": 98, "n_parameters": 144845568}
{"train_lr": 6.285980140274965e-05, "train_loss": 2.679979503681834, "train_task_loss": 2.5433403860768684, "train_bpp_loss": 0.30364249174190544, "train_patch_loss": 0.989279427558934, "train_token_loss": 0.7147660507581575, "train_fea_loss": 0.8392948986810568, "epoch": 99, "n_parameters": 144845568}
{"train_lr": 6.205500001222403e-05, "train_loss": 2.818446820222145, "train_task_loss": 2.6755781511751584, "train_bpp_loss": 0.3174859403087635, "train_patch_loss": 1.0392651655052385, "train_token_loss": 0.7372288242098596, "train_fea_loss": 0.8990841494868699, "epoch": 100, "n_parameters": 144845568}
{"train_lr": 6.12500000000064e-05, "train_loss": 2.7441475540310214, "train_task_loss": 2.597907278704629, "train_bpp_loss": 0.32497839667171025, "train_patch_loss": 1.0195179717119436, "train_token_loss": 0.7045413070298034, "train_fea_loss": 0.8738479904709853, "epoch": 101, "n_parameters": 144845568}
{"train_lr": 6.044499998777186e-05, "train_loss": 2.7079060795299297, "train_task_loss": 2.565669994283137, "train_bpp_loss": 0.31608019693479905, "train_patch_loss": 1.0030313897252583, "train_token_loss": 0.7082743480091365, "train_fea_loss": 0.8543642482986089, "epoch": 102, "n_parameters": 144845568}
{"train_lr": 5.964019859724661e-05, "train_loss": 2.696671325293519, "train_task_loss": 2.557040445006294, "train_bpp_loss": 0.3102908528876676, "train_patch_loss": 0.9976166626081157, "train_token_loss": 0.7102920439817922, "train_fea_loss": 0.8491317290227584, "epoch": 103, "n_parameters": 144845568}
{"train_lr": 5.8835794401133974e-05, "train_loss": 2.691563926851578, "train_task_loss": 2.553079867954377, "train_bpp_loss": 0.3077423636526834, "train_patch_loss": 0.9958474041438468, "train_token_loss": 0.7099756244815689, "train_fea_loss": 0.8472568294439885, "epoch": 104, "n_parameters": 144845568}
{"train_lr": 5.8031985874119795e-05, "train_loss": 2.692192362849232, "train_task_loss": 2.5543478446916565, "train_bpp_loss": 0.30632116043788316, "train_patch_loss": 0.9955702859291927, "train_token_loss": 0.7111749301689099, "train_fea_loss": 0.8476026195617352, "epoch": 105, "n_parameters": 144845568}
{"train_lr": 5.722897134394433e-05, "train_loss": 2.6670049736075265, "train_task_loss": 2.5291183430889097, "train_bpp_loss": 0.30641474233207694, "train_patch_loss": 0.9873280670623378, "train_token_loss": 0.7049247939295942, "train_fea_loss": 0.8368654731232271, "epoch": 106, "n_parameters": 144845568}
{"train_lr": 5.642694894242339e-05, "train_loss": 2.6757864850066264, "train_task_loss": 2.538502600463889, "train_bpp_loss": 0.30507530629348983, "train_patch_loss": 0.9907261180832124, "train_token_loss": 0.7074770695045233, "train_fea_loss": 0.8402994043602467, "epoch": 107, "n_parameters": 144845568}
{"train_lr": 5.562611655657961e-05, "train_loss": 2.7075679923811977, "train_task_loss": 2.5688214136613645, "train_bpp_loss": 0.30832573917316947, "train_patch_loss": 1.00075649227545, "train_token_loss": 0.7133955942328385, "train_fea_loss": 0.8546693176053447, "epoch": 108, "n_parameters": 144845568}
{"train_lr": 5.642694894242339e-05, "train_loss": 2.668979725284542, "train_task_loss": 2.5311407839669453, "train_bpp_loss": 0.30630876569577653, "train_patch_loss": 0.9885714473649502, "train_token_loss": 0.7049617424553676, "train_fea_loss": 0.8376075848758363, "epoch": 106, "n_parameters": 144845568}
{"train_lr": 5.642694894242339e-05, "train_loss": 2.683824887423278, "train_task_loss": 2.5468173437535193, "train_bpp_loss": 0.3044612155992725, "train_patch_loss": 0.9947493373275661, "train_token_loss": 0.708695182786714, "train_fea_loss": 0.8433728142076438, "epoch": 107, "n_parameters": 144845568}
{"train_lr": 5.562611655657961e-05, "train_loss": 2.6808553397548285, "train_task_loss": 2.544291126874568, "train_bpp_loss": 0.30347603688025465, "train_patch_loss": 0.9940348935472093, "train_token_loss": 0.7075042654525676, "train_fea_loss": 0.8427519599440322, "epoch": 108, "n_parameters": 144845568}
{"train_lr": 5.482667177983261e-05, "train_loss": 4.911457611171962, "train_task_loss": 4.756159201309049, "train_bpp_loss": 0.3451075883260068, "train_patch_loss": 1.923462125296287, "train_token_loss": 0.9228969771633867, "train_fea_loss": 1.909800109163159, "epoch": 109, "n_parameters": 144845568}
{"train_lr": 5.402881186319929e-05, "train_loss": 6.119308372142075, "train_task_loss": 6.05821056978451, "train_bpp_loss": 0.13577289702430786, "train_patch_loss": 2.0250163967821666, "train_token_loss": 2.0047141713507886, "train_fea_loss": 2.028479987103269, "epoch": 110, "n_parameters": 144845568}
{"train_lr": 5.402881186319929e-05, "train_loss": 2.8206423869092974, "train_task_loss": 2.5848962855514146, "train_bpp_loss": 0.3167879025416105, "train_patch_loss": 1.0065730715067767, "train_token_loss": 0.721063517386929, "train_fea_loss": 0.8572596858927815, "epoch": 109, "n_parameters": 144845568}
{"train_lr": 5.402881186319929e-05, "train_loss": 2.7694720208591264, "train_task_loss": 2.536599607177847, "train_bpp_loss": 0.31375666602561586, "train_patch_loss": 0.9914828701552263, "train_token_loss": 0.7029595081839595, "train_fea_loss": 0.8421572199877467, "epoch": 110, "n_parameters": 144845568}
{"train_lr": 5.323273366669127e-05, "train_loss": 2.754181052810497, "train_task_loss": 2.52081877488098, "train_bpp_loss": 0.31649657639471607, "train_patch_loss": 0.9853969901527683, "train_token_loss": 0.7020405733832817, "train_fea_loss": 0.8333812024247803, "epoch": 111, "n_parameters": 144845568}
{"train_lr": 5.24386336106797e-05, "train_loss": 2.778903913664089, "train_task_loss": 2.54480576113188, "train_bpp_loss": 0.31611214793115244, "train_patch_loss": 0.9951485451065201, "train_token_loss": 0.7078502580961235, "train_fea_loss": 0.8418069494334306, "epoch": 112, "n_parameters": 144845568}
{"train_lr": 5.1646707627478925e-05, "train_loss": 2.8131009751854896, "train_task_loss": 2.5725743175803615, "train_bpp_loss": 0.32791142306922616, "train_patch_loss": 1.0041894356025567, "train_token_loss": 0.7132421619497793, "train_fea_loss": 0.8551427105003946, "epoch": 113, "n_parameters": 144845568}
{"train_lr": 5.0857151112976574e-05, "train_loss": 2.7399135867147137, "train_task_loss": 2.5091776530072987, "train_bpp_loss": 0.31148358721875674, "train_patch_loss": 0.9819968594774175, "train_token_loss": 0.6978094259017913, "train_fea_loss": 0.8293713586163713, "epoch": 114, "n_parameters": 144845568}
{"train_lr": 5.007015887842505e-05, "train_loss": 2.8684129846825015, "train_task_loss": 2.616686708519427, "train_bpp_loss": 0.3489265601872076, "train_patch_loss": 1.0203611283837783, "train_token_loss": 0.7224983579727063, "train_fea_loss": 0.8738272108821025, "epoch": 115, "n_parameters": 144845568}
{"train_lr": 4.928592510238729e-05, "train_loss": 2.739907768231144, "train_task_loss": 2.50749061155698, "train_bpp_loss": 0.3151009082838857, "train_patch_loss": 0.9832527024260492, "train_token_loss": 0.695053851908649, "train_fea_loss": 0.8291840486706411, "epoch": 116, "n_parameters": 144845568}
{"train_lr": 4.850464328279906e-05, "train_loss": 2.7357478904620494, "train_task_loss": 2.5043268277079798, "train_bpp_loss": 0.31352339563690285, "train_patch_loss": 0.9800204708680904, "train_token_loss": 0.6976143020144898, "train_fea_loss": 0.8266920464814639, "epoch": 117, "n_parameters": 144845568}
{"train_lr": 4.7726506189276635e-05, "train_loss": 2.755742895097541, "train_task_loss": 2.5216859226211086, "train_bpp_loss": 0.31784185127531595, "train_patch_loss": 0.9866870453634815, "train_token_loss": 0.701121356496833, "train_fea_loss": 0.833877512173419, "epoch": 118, "n_parameters": 144845568}
{"train_lr": 4.69517058154867e-05, "train_loss": 2.7318975441247155, "train_task_loss": 2.5014077417141527, "train_bpp_loss": 0.31166150058537045, "train_patch_loss": 0.9790000828715573, "train_token_loss": 0.6965634005313976, "train_fea_loss": 0.8258442498243256, "epoch": 119, "n_parameters": 144845568}
{"train_lr": 4.6180433331847694e-05, "train_loss": 2.7518172350307877, "train_task_loss": 2.5213089466577383, "train_bpp_loss": 0.31003568373504375, "train_patch_loss": 0.9861942503458376, "train_token_loss": 0.7014625150290468, "train_fea_loss": 0.8336521727021793, "epoch": 120, "n_parameters": 144845568}
{"train_lr": 4.541287903828179e-05, "train_loss": 2.7335522819390827, "train_task_loss": 2.50380109216598, "train_bpp_loss": 0.3096768988276289, "train_patch_loss": 0.9810024427442099, "train_token_loss": 0.6958675646743507, "train_fea_loss": 0.8269310759501063, "epoch": 121, "n_parameters": 144845568}
{"train_lr": 4.4649232317341524e-05, "train_loss": 2.733630260037218, "train_task_loss": 2.503839423107229, "train_bpp_loss": 0.309786735956805, "train_patch_loss": 0.9804957953708278, "train_token_loss": 0.6961014977024256, "train_fea_loss": 0.8272421213170643, "epoch": 122, "n_parameters": 144845568}
{"train_lr": 4.3889681587425266e-05, "train_loss": 2.8017713390469408, "train_task_loss": 2.557888786087362, "train_bpp_loss": 0.3367434704309995, "train_patch_loss": 0.9978989840482922, "train_token_loss": 0.7109271357421109, "train_fea_loss": 0.8490626564633539, "epoch": 123, "n_parameters": 144845568}
{"train_lr": 4.313441425631543e-05, "train_loss": 2.7663599788803133, "train_task_loss": 2.515742145719931, "train_bpp_loss": 0.35449836285320296, "train_patch_loss": 0.9864850308973905, "train_token_loss": 0.6938689482427365, "train_fea_loss": 0.8353881579388281, "epoch": 124, "n_parameters": 144845568}
{"train_lr": 4.238361667491207e-05, "train_loss": 2.7239683468153864, "train_task_loss": 2.4941051970068497, "train_bpp_loss": 0.31064632278029547, "train_patch_loss": 0.9776195450965092, "train_token_loss": 0.6926572750338155, "train_fea_loss": 0.8238283673957955, "epoch": 125, "n_parameters": 144845568}
{"train_lr": 4.1637474091286196e-05, "train_loss": 2.731808961188193, "train_task_loss": 2.502355479927872, "train_bpp_loss": 0.30911144960618064, "train_patch_loss": 0.9799622440648129, "train_token_loss": 0.6952871139065253, "train_fea_loss": 0.827106113864971, "epoch": 126, "n_parameters": 144845568}
{"train_lr": 4.089617060496659e-05, "train_loss": 2.7330640450530916, "train_task_loss": 2.5038258624519947, "train_bpp_loss": 0.3084725750156855, "train_patch_loss": 0.9800150552474195, "train_token_loss": 0.6953135876708644, "train_fea_loss": 0.8284972119714609, "epoch": 127, "n_parameters": 144845568}
{"train_lr": 4.015988912148501e-05, "train_loss": 2.7349733328558417, "train_task_loss": 2.4990430673856814, "train_bpp_loss": 0.3236985370981322, "train_patch_loss": 0.9784446784811054, "train_token_loss": 0.6937227977542497, "train_fea_loss": 0.8268755828197911, "epoch": 128, "n_parameters": 144845568}
{"train_lr": 3.942881130728865e-05, "train_loss": 2.7178904607856302, "train_task_loss": 2.4889944930085175, "train_bpp_loss": 0.3089114015997885, "train_patch_loss": 0.9757260797004834, "train_token_loss": 0.6912759416524491, "train_fea_loss": 0.821992463240181, "epoch": 129, "n_parameters": 144845568}
{"train_lr": 3.870311754488397e-05, "train_loss": 2.7325405643301472, "train_task_loss": 2.5016696595584604, "train_bpp_loss": 0.31216120841294276, "train_patch_loss": 0.9799974317856722, "train_token_loss": 0.6937006032407034, "train_fea_loss": 0.8279716155806677, "epoch": 130, "n_parameters": 144845568}
{"train_lr": 3.798298688834852e-05, "train_loss": 2.7105408391917494, "train_task_loss": 2.4821574514736233, "train_bpp_loss": 0.30818621284714776, "train_patch_loss": 0.9738334429783108, "train_token_loss": 0.6881637847457501, "train_fea_loss": 0.820160214387202, "epoch": 131, "n_parameters": 144845568}
{"train_lr": 3.726859701914403e-05, "train_loss": 2.7591426904252967, "train_task_loss": 2.5228572249930683, "train_bpp_loss": 0.3226083974237833, "train_patch_loss": 0.9864406956994026, "train_token_loss": 0.7006216509741165, "train_fea_loss": 0.8357948688053184, "epoch": 132, "n_parameters": 144845568}
{"train_lr": 3.656012420228689e-05, "train_loss": 2.68967245043003, "train_task_loss": 2.460287703357512, "train_bpp_loss": 0.3120911338152163, "train_patch_loss": 0.9669189435883249, "train_token_loss": 0.6814111765313278, "train_fea_loss": 0.81195757473982, "epoch": 133, "n_parameters": 144845568}
{"train_lr": 3.5857743242838835e-05, "train_loss": 2.7014216272432883, "train_task_loss": 2.472743094181843, "train_bpp_loss": 0.30973817068163295, "train_patch_loss": 0.9706215003429414, "train_token_loss": 0.6868141930465468, "train_fea_loss": 0.8153073933484744, "epoch": 134, "n_parameters": 144845568}
{"train_lr": 3.516162744279572e-05, "train_loss": 2.705487959069028, "train_task_loss": 2.4765327244067934, "train_bpp_loss": 0.3098553310503527, "train_patch_loss": 0.9728047358189269, "train_token_loss": 0.68612423896057, "train_fea_loss": 0.8176037415066998, "epoch": 135, "n_parameters": 144845568}
{"train_lr": 3.447194855830639e-05, "train_loss": 2.7113154404383484, "train_task_loss": 2.478181616079321, "train_bpp_loss": 0.31905700233530443, "train_patch_loss": 0.9727772305422514, "train_token_loss": 0.68701637881234, "train_fea_loss": 0.818387999384702, "epoch": 136, "n_parameters": 144845568}
{"train_lr": 3.378887675732868e-05, "train_loss": 2.6797968578620925, "train_task_loss": 2.451133315422409, "train_bpp_loss": 0.3113255005260612, "train_patch_loss": 0.9640113939207307, "train_token_loss": 0.6797902039285723, "train_fea_loss": 0.8073317096557477, "epoch": 137, "n_parameters": 144845568}
{"train_lr": 3.311258057759679e-05, "train_loss": 2.694638561546731, "train_task_loss": 2.4661759902551164, "train_bpp_loss": 0.3096901442793062, "train_patch_loss": 0.9691699242545403, "train_token_loss": 0.6841340864242481, "train_fea_loss": 0.8128719716969368, "epoch": 138, "n_parameters": 144845568}
{"train_lr": 3.244322688507758e-05, "train_loss": 2.697916217782586, "train_task_loss": 2.4694298465224764, "train_bpp_loss": 0.3092930535785854, "train_patch_loss": 0.9713080364482687, "train_token_loss": 0.6833381301698853, "train_fea_loss": 0.8147836721902021, "epoch": 139, "n_parameters": 144845568}
{"train_lr": 3.1780980832784374e-05, "train_loss": 2.697745393452456, "train_task_loss": 2.465554346515835, "train_bpp_loss": 0.3177864467847994, "train_patch_loss": 0.9703194021450565, "train_token_loss": 0.6818096557480219, "train_fea_loss": 0.8134252810515732, "epoch": 140, "n_parameters": 144845568}
{"train_lr": 3.112600582001298e-05, "train_loss": 2.692317468710512, "train_task_loss": 2.4639319612039365, "train_bpp_loss": 0.3095171678927448, "train_patch_loss": 0.9694184379415404, "train_token_loss": 0.6818745398862684, "train_fea_loss": 0.8126389763923251, "epoch": 141, "n_parameters": 144845568}
{"train_lr": 3.047846345205177e-05, "train_loss": 2.7029235281175277, "train_task_loss": 2.47447449285135, "train_bpp_loss": 0.30879852916527206, "train_patch_loss": 0.9717401454551257, "train_token_loss": 0.6846787373957445, "train_fea_loss": 0.8180556024426965, "epoch": 142, "n_parameters": 144845568}
{"train_lr": 2.9838513500286588e-05, "train_loss": 2.6954629459469723, "train_task_loss": 2.461703634662308, "train_bpp_loss": 0.3215753604463643, "train_patch_loss": 0.9685670451911019, "train_token_loss": 0.6806938354801432, "train_fea_loss": 0.8124427453947546, "epoch": 143, "n_parameters": 144845568}
{"train_lr": 2.920631386279756e-05, "train_loss": 2.6789074894978846, "train_task_loss": 2.4506222840836296, "train_bpp_loss": 0.3103648301312249, "train_patch_loss": 0.9643291673049045, "train_token_loss": 0.6781998446358622, "train_fea_loss": 0.8080932640043987, "epoch": 144, "n_parameters": 144845568}
{"train_lr": 2.8582020525382766e-05, "train_loss": 2.698681264338519, "train_task_loss": 2.465655054361057, "train_bpp_loss": 0.3195096456470321, "train_patch_loss": 0.9703498526188217, "train_token_loss": 0.6807158637626601, "train_fea_loss": 0.8145893300555164, "epoch": 145, "n_parameters": 144845568}
{"train_lr": 2.7965787523079142e-05, "train_loss": 2.6876495937845832, "train_task_loss": 2.4591161769142538, "train_bpp_loss": 0.3100968778760974, "train_patch_loss": 0.9675739941743317, "train_token_loss": 0.6793179688975215, "train_fea_loss": 0.8122242059786435, "epoch": 146, "n_parameters": 144845568}
{"train_lr": 2.7357766902161244e-05, "train_loss": 2.686950199314945, "train_task_loss": 2.4579590937520246, "train_bpp_loss": 0.31120623359166316, "train_patch_loss": 0.9676960334248788, "train_token_loss": 0.6789913053403227, "train_fea_loss": 0.8112717479337855, "epoch": 147, "n_parameters": 144845568}
{"train_lr": 2.616696082115359e-05, "train_loss": 2.6700563189937627, "train_task_loss": 2.441784028437355, "train_bpp_loss": 0.31092876969329675, "train_patch_loss": 0.9618561120337541, "train_token_loss": 0.6746953771956604, "train_fea_loss": 0.8052325316769577, "epoch": 148, "n_parameters": 144845568}
{"train_lr": 2.616696082115359e-05, "train_loss": 2.6674406088138225, "train_task_loss": 2.4389501352229424, "train_bpp_loss": 0.3116684112110459, "train_patch_loss": 0.9607490962326563, "train_token_loss": 0.6741545937120093, "train_fea_loss": 0.8040464378075955, "epoch": 149, "n_parameters": 144845568}
{"train_lr": 2.558446917464184e-05, "train_loss": 2.6966678162129947, "train_task_loss": 2.4605731371191624, "train_bpp_loss": 0.32649256723287473, "train_patch_loss": 0.9711854177922439, "train_token_loss": 0.6771108717668792, "train_fea_loss": 0.81227684003129, "epoch": 150, "n_parameters": 144845568}
{"train_lr": 2.5010777464192224e-05, "train_loss": 2.649356059051103, "train_task_loss": 2.4205873620649463, "train_bpp_loss": 0.3137229763348844, "train_patch_loss": 0.9557904277322925, "train_token_loss": 0.6687184504994885, "train_fea_loss": 0.7960784757974372, "epoch": 151, "n_parameters": 144845568}
{"train_lr": 2.444602723963776e-05, "train_loss": 2.6504098884582663, "train_task_loss": 2.421797823491428, "train_bpp_loss": 0.3133131659475805, "train_patch_loss": 0.9554594567903191, "train_token_loss": 0.6693732499343743, "train_fea_loss": 0.7969651087689743, "epoch": 152, "n_parameters": 144845568}
{"train_lr": 2.3343906382349e-05, "train_loss": 2.7114228718429447, "train_task_loss": 2.458482252813572, "train_bpp_loss": 0.34088637049106674, "train_patch_loss": 0.9697519682405843, "train_token_loss": 0.6721772563775523, "train_fea_loss": 0.8165530218369812, "epoch": 153, "n_parameters": 144845568}
{"train_lr": 2.3343906382349e-05, "train_loss": 2.650369150729345, "train_task_loss": 2.4105861190197277, "train_bpp_loss": 0.31759558117231557, "train_patch_loss": 0.9528444103351671, "train_token_loss": 0.6638412714524896, "train_fea_loss": 0.7939004297111388, "epoch": 154, "n_parameters": 144845568}
{"train_lr": 2.280680768143689e-05, "train_loss": 2.658987253216459, "train_task_loss": 2.421917999400009, "train_bpp_loss": 0.31114791267567116, "train_patch_loss": 0.9561528907575124, "train_token_loss": 0.6675511025901821, "train_fea_loss": 0.7982139991355135, "epoch": 155, "n_parameters": 144845568}
{"train_lr": 2.2279194262997928e-05, "train_loss": 2.695716130978269, "train_task_loss": 2.4542629720031215, "train_bpp_loss": 0.31765079001362306, "train_patch_loss": 0.9665458119615055, "train_token_loss": 0.6746455496869428, "train_fea_loss": 0.8130716029219598, "epoch": 156, "n_parameters": 144845568}
{"train_lr": 2.1761196307742086e-05, "train_loss": 2.6704102983202436, "train_task_loss": 2.4334238473442102, "train_bpp_loss": 0.30996573550349693, "train_patch_loss": 0.9596143975134769, "train_token_loss": 0.6693641086682963, "train_fea_loss": 0.8044453353511629, "epoch": 157, "n_parameters": 144845568}
{"train_lr": 2.1252941623912577e-05, "train_loss": 2.6702361341735585, "train_task_loss": 2.4332374961917207, "train_bpp_loss": 0.3099473466777312, "train_patch_loss": 0.9602035106019043, "train_token_loss": 0.6687574936278003, "train_fea_loss": 0.8042764851122165, "epoch": 158, "n_parameters": 144845568}
{"train_lr": 2.0754555615745688e-05, "train_loss": 2.6676239087480864, "train_task_loss": 2.430506248768571, "train_bpp_loss": 0.31043111976030635, "train_patch_loss": 0.9591814159645701, "train_token_loss": 0.6682900374808864, "train_fea_loss": 0.8030347887748223, "epoch": 159, "n_parameters": 144845568}
{"train_lr": 2.0266161252534863e-05, "train_loss": 2.682328796143726, "train_task_loss": 2.4449838221484095, "train_bpp_loss": 0.3097276722284244, "train_patch_loss": 0.9639880869501858, "train_token_loss": 0.6714682123563487, "train_fea_loss": 0.8095275162444483, "epoch": 160, "n_parameters": 144845568}
{"train_lr": 1.9787879038283694e-05, "train_loss": 2.794825274291799, "train_task_loss": 2.5297502170381643, "train_bpp_loss": 0.3605138300236859, "train_patch_loss": 0.9950330435751761, "train_token_loss": 0.6891800253216526, "train_fea_loss": 0.8455371403780808, "epoch": 161, "n_parameters": 144845568}
{"train_lr": 1.9319826981968032e-05, "train_loss": 2.6933876661052234, "train_task_loss": 2.4523086957990836, "train_bpp_loss": 0.3166517836579995, "train_patch_loss": 0.9668755698480087, "train_token_loss": 0.6705846348107397, "train_fea_loss": 0.8148484852687596, "epoch": 162, "n_parameters": 144845568}
{"train_lr": 1.8862120568428674e-05, "train_loss": 2.7029373263426537, "train_task_loss": 2.4646165441873547, "train_bpp_loss": 0.31019026768110175, "train_patch_loss": 0.9700850130145927, "train_token_loss": 0.6760256816487875, "train_fea_loss": 0.8185058429220812, "epoch": 163, "n_parameters": 144845568}
{"train_lr": 1.8414872729877464e-05, "train_loss": 2.7035999098252192, "train_task_loss": 2.4655938773311012, "train_bpp_loss": 0.30946591779214444, "train_patch_loss": 0.9705398866317148, "train_token_loss": 0.6763441974645574, "train_fea_loss": 0.8187097842231107, "epoch": 164, "n_parameters": 144845568}
{"train_lr": 1.9319826981970397e-05, "train_loss": 2.740814351822904, "train_task_loss": 2.492047573066801, "train_bpp_loss": 0.329818556624592, "train_patch_loss": 0.9823994228093744, "train_token_loss": 0.6829682969436836, "train_fea_loss": 0.8266798452512144, "epoch": 161, "n_parameters": 144845568}
{"train_lr": 1.9319826981970397e-05, "train_loss": 2.7282523382946455, "train_task_loss": 2.48745193989943, "train_bpp_loss": 0.3139055231898773, "train_patch_loss": 0.9767762538854131, "train_token_loss": 0.6849351603363594, "train_fea_loss": 0.8257405170858156, "epoch": 162, "n_parameters": 144845568}
{"train_lr": 1.8862120568426702e-05, "train_loss": 2.7373931265259674, "train_task_loss": 2.493901286015122, "train_bpp_loss": 0.3186548652218954, "train_patch_loss": 0.9806770237019773, "train_token_loss": 0.6831493455675437, "train_fea_loss": 0.8300749089220445, "epoch": 163, "n_parameters": 144845568}
{"train_lr": 1.84148727298801e-05, "train_loss": 2.7362954941370505, "train_task_loss": 2.4902694111319184, "train_bpp_loss": 0.3243905476437395, "train_patch_loss": 0.9797143076258383, "train_token_loss": 0.6838951857773949, "train_fea_loss": 0.8266599099073622, "epoch": 164, "n_parameters": 144845568}
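Two regularities in these records are worth calling out: train_task_loss equals the sum of the patch, token, and feature terms to float precision, and train_loss adds the bpp term with an apparent weight of 0.25. Both are inferred from the logged numbers, not read from the training script, and the decomposition drifts during the collapsed stretch around epochs 83-92. The repeated epoch indices with a held learning rate (e.g. 61-62, 81-92, 106-110, 161-164) right after loss spikes suggest the run was resumed from an earlier checkpoint each time it diverged. A minimal JSONL sanity check, assuming the records sit one per line in log.txt:

```python
import json

# Minimal sketch, assuming the records above are plain JSONL (one per epoch).
# The 0.25 weight on the bpp term is inferred from the logged values; the
# actual coefficient lives in the training script, not in this log.
with open("1_feature_extractor/log/DINOv2_training/log.txt") as f:
    records = [json.loads(line) for line in f if line.lstrip().startswith("{")]

for r in records:
    task = r["train_patch_loss"] + r["train_token_loss"] + r["train_fea_loss"]
    total = r["train_task_loss"] + 0.25 * r["train_bpp_loss"]
    if abs(task - r["train_task_loss"]) > 1e-4 or abs(total - r["train_loss"]) > 1e-4:
        # Holds for most epochs; breaks in the collapsed stretch around 83-92.
        print("decomposition off at epoch", r["epoch"])
```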
1_feature_extractor/log/DINOv2_training/log/20240725_001002.log
ADDED
The diff for this file is too large to render.
See raw diff
1_feature_extractor/log/DINOv2_training/log/20240725_084736.log
ADDED
@@ -0,0 +1,555 @@
2024-07-25 08:47:36,358 [INFO ] Logging file is /data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log//20240725_084736.log
2024-07-25 08:47:36,359 [INFO ] job dir: /data0/qiyp/Proteus-pytorch/pretrain
2024-07-25 08:47:36,359 [INFO ] Namespace(batch_size=48,
epochs=200,
bce_loss=False,
unscale_lr=False,
model='models_proteus_dinov2',
target_model='vit_base',
input_size=224,
drop=0.0,
drop_path=0.1,
model_ema=True,
model_ema_decay=0.99996,
model_ema_force_cpu=False,
opt='adamw',
opt_eps=1e-08,
opt_betas=None,
clip_grad=None,
momentum=0.9,
weight_decay=0.05,
sched='cosine',
lr=0.0002,
lr_noise=None,
lr_noise_pct=0.67,
lr_noise_std=1.0,
warmup_lr=1e-06,
min_lr=1e-05,
decay_epochs=30,
warmup_epochs=5,
cooldown_epochs=10,
patience_epochs=10,
decay_rate=0.1,
color_jitter=0.3,
aa='rand-m9-mstd0.5-inc1',
smoothing=0.1,
train_interpolation='bicubic',
repeated_aug=True,
train_mode=True,
ThreeAugment=False,
src=False,
global_crops_size=224,
patch_size=14,
mask_ratio=[0.5],
mask_probability=0.5,
mask_first_n=True,
clone_batch=1,
reprob=0.25,
remode='pixel',
recount=1,
resplit=False,
mixup=0.8,
cutmix=1.0,
cutmix_minmax=None,
mixup_prob=1.0,
mixup_switch_prob=0.5,
mixup_mode='batch',
teacher_model='vit_large',
teacher_path='',
distillation_type='none',
distillation_alpha=0.5,
distillation_tau=1.0,
lambda_token=1.0,
lambda_fea=1.0,
lambda_patch=1.0,
cosub=False,
finetune='/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth',
attn_only=False,
weight_inherit='',
data_path='/data1/datasets/imagenet_fold',
data_set='IMNET',
inat_category='name',
output_dir='log/DINOv2_training',
log_dir='/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log/',
device='cuda',
seed=0,
resume='',
start_epoch=0,
eval=False,
eval_crop_ratio=0.875,
dist_eval=False,
num_workers=10,
pin_mem=True,
distributed=True,
world_size=6,
dist_url='env://',
rank=0,
gpu=0,
dist_backend='nccl')
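A note for anyone replaying this configuration: the train_lr values in the records above are consistent with DeiT-style linear scaling of the base rate by the global batch size over 512 (0.0002 * 48 * 6 / 512, about 1.125e-4) followed by cosine decay toward min_lr. The sketch below reproduces them to within a couple of percent; timm's actual 'cosine' scheduler handles warmup, cooldown, and noise differently in detail, so treat this as an approximation rather than the training code.

```python
import math

# Sketch: approximate the logged train_lr from the Namespace above, assuming
# DeiT-style linear lr scaling (lr * batch_size * world_size / 512) and a
# plain cosine decay to min_lr after warmup_epochs.
base_lr, min_lr, warmup_lr, warmup, total = 2e-4, 1e-5, 1e-6, 5, 200
scaled_lr = base_lr * 48 * 6 / 512  # ~1.125e-4

def lr_at(epoch):
    if epoch < warmup:  # linear warmup from warmup_lr
        return warmup_lr + (scaled_lr - warmup_lr) * epoch / warmup
    t = (epoch - warmup) / (total - warmup)
    return min_lr + 0.5 * (scaled_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at(44))  # ~1.03e-4 vs 1.0124e-4 logged for epoch 44
```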

2024-07-25 08:47:38,916 [INFO ] Dataset ImageFolder
Number of datapoints: 1281167
Root location: /data1/datasets/imagenet_fold/train
StandardTransform
Transform: Compose(
RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
RandomHorizontalFlip(p=0.5)
RandAugment(n=2, ops=
AugmentOp(name=AutoContrast, p=0.5, m=9, mstd=0.5)
AugmentOp(name=Equalize, p=0.5, m=9, mstd=0.5)
AugmentOp(name=Invert, p=0.5, m=9, mstd=0.5)
AugmentOp(name=Rotate, p=0.5, m=9, mstd=0.5)
AugmentOp(name=PosterizeIncreasing, p=0.5, m=9, mstd=0.5)
AugmentOp(name=SolarizeIncreasing, p=0.5, m=9, mstd=0.5)
AugmentOp(name=SolarizeAdd, p=0.5, m=9, mstd=0.5)
AugmentOp(name=ColorIncreasing, p=0.5, m=9, mstd=0.5)
AugmentOp(name=ContrastIncreasing, p=0.5, m=9, mstd=0.5)
AugmentOp(name=BrightnessIncreasing, p=0.5, m=9, mstd=0.5)
AugmentOp(name=SharpnessIncreasing, p=0.5, m=9, mstd=0.5)
AugmentOp(name=ShearX, p=0.5, m=9, mstd=0.5)
AugmentOp(name=ShearY, p=0.5, m=9, mstd=0.5)
AugmentOp(name=TranslateXRel, p=0.5, m=9, mstd=0.5)
AugmentOp(name=TranslateYRel, p=0.5, m=9, mstd=0.5))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
RandomErasing(p=0.25, mode=pixel, count=(1, 1))
)
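The Compose dump above is recognizably the transform that timm's factory builds from the Namespace flags (aa='rand-m9-mstd0.5-inc1', color_jitter=0.3, reprob=0.25, and so on). A standalone sketch, assuming a reasonably recent timm:

```python
from timm.data import create_transform

# Sketch: rebuild the training transform printed above from the flags in the
# Namespace dump. The mean/std shown are timm's ImageNet defaults.
transform = create_transform(
    input_size=224,
    is_training=True,
    color_jitter=0.3,
    auto_augment="rand-m9-mstd0.5-inc1",
    interpolation="bicubic",
    re_prob=0.25,
    re_mode="pixel",
    re_count=1,
)
print(transform)
```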
2024-07-25 08:47:38,926 [INFO ] Sampler_train = <samplers.RASampler object at 0x7fc548b7f3a0>
2024-07-25 08:47:38,928 [INFO ] Created a temporary directory at /tmp/tmpe0qnk5ic
2024-07-25 08:47:38,929 [INFO ] Writing /tmp/tmpe0qnk5ic/_remote_module_non_scriptable.py
2024-07-25 08:47:40,640 [INFO ] using MLP layer as FFN
2024-07-25 08:49:51,514 [INFO ] using MLP layer as FFN
2024-07-25 08:49:57,349 [INFO ] Model = MetaArch(
(student): ModuleDict(
(backbone): DinoVisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
(norm): Identity()
)
(blocks): ModuleList(
(0-11): 12 x Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): MemEffAttention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): LayerScale()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
(ls2): LayerScale()
(drop_path2): Identity()
)
)
(norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(head): Identity()
)
)
(teacher): ModuleDict(
(backbone): DinoVisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
(norm): Identity()
)
(blocks): ModuleList(
(0-23): 24 x NestedTensorBlock(
(norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(attn): MemEffAttention(
(qkv): Linear(in_features=1024, out_features=3072, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=1024, out_features=1024, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): LayerScale()
(drop_path1): Identity()
(norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=1024, out_features=4096, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=4096, out_features=1024, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
(ls2): LayerScale()
(drop_path2): Identity()
)
)
(norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
(head): Identity()
)
)
(patch_head): Sequential(
(0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=768, out_features=1024, bias=True)
)
(token_head): Sequential(
(0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=768, out_features=1024, bias=True)
)
(fea_head): Sequential(
(0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=768, out_features=1024, bias=True)
)
(soft_criterion): MSELoss()
(info_bottleneck): IF_Module(
(encoder_blocks): ModuleList(
(0-3): 4 x Block(
(norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
)
(encoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(decoder_blocks): ModuleList(
(0-3): 4 x Block(
(norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
)
(decoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(entropy_bottleneck): EntropyBottleneck(
(likelihood_lower_bound): LowerBound()
)
)
)
2024-07-25 08:49:59,702 [INFO ] Finetuning from /data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth
2024-07-25 08:49:59,702 [INFO ] Finetuning from /data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth
|
| 261 |
+
2024-07-25 08:49:59,703 [INFO ] missing_keys: ['patch_head.0.weight', 'patch_head.0.bias', 'patch_head.1.weight', 'patch_head.1.bias', 'info_bottleneck.encoder_blocks.0.norm1.weight', 'info_bottleneck.encoder_blocks.0.norm1.bias', 'info_bottleneck.encoder_blocks.0.attn.qkv.weight', 'info_bottleneck.encoder_blocks.0.attn.qkv.bias', 'info_bottleneck.encoder_blocks.0.attn.proj.weight', 'info_bottleneck.encoder_blocks.0.attn.proj.bias', 'info_bottleneck.encoder_blocks.0.norm2.weight', 'info_bottleneck.encoder_blocks.0.norm2.bias', 'info_bottleneck.encoder_blocks.0.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.0.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.0.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.0.mlp.fc2.bias', 'info_bottleneck.encoder_blocks.1.norm1.weight', 'info_bottleneck.encoder_blocks.1.norm1.bias', 'info_bottleneck.encoder_blocks.1.attn.qkv.weight', 'info_bottleneck.encoder_blocks.1.attn.qkv.bias', 'info_bottleneck.encoder_blocks.1.attn.proj.weight', 'info_bottleneck.encoder_blocks.1.attn.proj.bias', 'info_bottleneck.encoder_blocks.1.norm2.weight', 'info_bottleneck.encoder_blocks.1.norm2.bias', 'info_bottleneck.encoder_blocks.1.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.1.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.1.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.1.mlp.fc2.bias', 'info_bottleneck.encoder_blocks.2.norm1.weight', 'info_bottleneck.encoder_blocks.2.norm1.bias', 'info_bottleneck.encoder_blocks.2.attn.qkv.weight', 'info_bottleneck.encoder_blocks.2.attn.qkv.bias', 'info_bottleneck.encoder_blocks.2.attn.proj.weight', 'info_bottleneck.encoder_blocks.2.attn.proj.bias', 'info_bottleneck.encoder_blocks.2.norm2.weight', 'info_bottleneck.encoder_blocks.2.norm2.bias', 'info_bottleneck.encoder_blocks.2.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.2.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.2.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.2.mlp.fc2.bias', 'info_bottleneck.encoder_blocks.3.norm1.weight', 'info_bottleneck.encoder_blocks.3.norm1.bias', 'info_bottleneck.encoder_blocks.3.attn.qkv.weight', 'info_bottleneck.encoder_blocks.3.attn.qkv.bias', 'info_bottleneck.encoder_blocks.3.attn.proj.weight', 'info_bottleneck.encoder_blocks.3.attn.proj.bias', 'info_bottleneck.encoder_blocks.3.norm2.weight', 'info_bottleneck.encoder_blocks.3.norm2.bias', 'info_bottleneck.encoder_blocks.3.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.3.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.3.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.3.mlp.fc2.bias', 'info_bottleneck.encoder_norm.weight', 'info_bottleneck.encoder_norm.bias', 'info_bottleneck.decoder_blocks.0.norm1.weight', 'info_bottleneck.decoder_blocks.0.norm1.bias', 'info_bottleneck.decoder_blocks.0.attn.qkv.weight', 'info_bottleneck.decoder_blocks.0.attn.qkv.bias', 'info_bottleneck.decoder_blocks.0.attn.proj.weight', 'info_bottleneck.decoder_blocks.0.attn.proj.bias', 'info_bottleneck.decoder_blocks.0.norm2.weight', 'info_bottleneck.decoder_blocks.0.norm2.bias', 'info_bottleneck.decoder_blocks.0.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.0.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.0.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.0.mlp.fc2.bias', 'info_bottleneck.decoder_blocks.1.norm1.weight', 'info_bottleneck.decoder_blocks.1.norm1.bias', 'info_bottleneck.decoder_blocks.1.attn.qkv.weight', 'info_bottleneck.decoder_blocks.1.attn.qkv.bias', 'info_bottleneck.decoder_blocks.1.attn.proj.weight', 'info_bottleneck.decoder_blocks.1.attn.proj.bias', 'info_bottleneck.decoder_blocks.1.norm2.weight', 
'info_bottleneck.decoder_blocks.1.norm2.bias', 'info_bottleneck.decoder_blocks.1.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.1.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.1.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.1.mlp.fc2.bias', 'info_bottleneck.decoder_blocks.2.norm1.weight', 'info_bottleneck.decoder_blocks.2.norm1.bias', 'info_bottleneck.decoder_blocks.2.attn.qkv.weight', 'info_bottleneck.decoder_blocks.2.attn.qkv.bias', 'info_bottleneck.decoder_blocks.2.attn.proj.weight', 'info_bottleneck.decoder_blocks.2.attn.proj.bias', 'info_bottleneck.decoder_blocks.2.norm2.weight', 'info_bottleneck.decoder_blocks.2.norm2.bias', 'info_bottleneck.decoder_blocks.2.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.2.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.2.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.2.mlp.fc2.bias', 'info_bottleneck.decoder_blocks.3.norm1.weight', 'info_bottleneck.decoder_blocks.3.norm1.bias', 'info_bottleneck.decoder_blocks.3.attn.qkv.weight', 'info_bottleneck.decoder_blocks.3.attn.qkv.bias', 'info_bottleneck.decoder_blocks.3.attn.proj.weight', 'info_bottleneck.decoder_blocks.3.attn.proj.bias', 'info_bottleneck.decoder_blocks.3.norm2.weight', 'info_bottleneck.decoder_blocks.3.norm2.bias', 'info_bottleneck.decoder_blocks.3.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.3.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.3.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.3.mlp.fc2.bias', 'info_bottleneck.decoder_norm.weight', 'info_bottleneck.decoder_norm.bias', 'info_bottleneck.entropy_bottleneck._matrix0', 'info_bottleneck.entropy_bottleneck._bias0', 'info_bottleneck.entropy_bottleneck._factor0', 'info_bottleneck.entropy_bottleneck._matrix1', 'info_bottleneck.entropy_bottleneck._bias1', 'info_bottleneck.entropy_bottleneck._factor1', 'info_bottleneck.entropy_bottleneck._matrix2', 'info_bottleneck.entropy_bottleneck._bias2', 'info_bottleneck.entropy_bottleneck._factor2', 'info_bottleneck.entropy_bottleneck._matrix3', 'info_bottleneck.entropy_bottleneck._bias3', 'info_bottleneck.entropy_bottleneck._factor3', 'info_bottleneck.entropy_bottleneck._matrix4', 'info_bottleneck.entropy_bottleneck._bias4', 'info_bottleneck.entropy_bottleneck.quantiles', 'info_bottleneck.entropy_bottleneck._offset', 'info_bottleneck.entropy_bottleneck._quantized_cdf', 'info_bottleneck.entropy_bottleneck._cdf_length', 'info_bottleneck.entropy_bottleneck.target', 'info_bottleneck.entropy_bottleneck.likelihood_lower_bound.bound']
2024-07-25 08:49:59,703 [INFO ] unexpected_keys: ['ibot_head.0.weight', 'ibot_head.0.bias', 'ibot_head.1.weight', 'ibot_head.1.bias']
2024-07-25 08:50:00,636 [INFO ] number of params: 144845568
2024-07-25 08:50:00,636 [INFO ] base lr = 0.0002
2024-07-25 08:50:00,636 [INFO ] actual lr = 0.00011250000000000001
2024-07-25 08:50:00,639 [INFO ] Start training for 200 epochs
2024-07-25 08:50:00,642 [INFO ] Parameters to be updated:
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.norm2.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.2.norm2.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.proj.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.qkv.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.norm2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.norm2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc2.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._bias0
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.proj.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.0.norm1.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix0
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._factor2
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.proj.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.proj.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.norm1.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix4
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.3.norm1.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.norm1.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.norm2.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.3.norm1.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._bias2
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.qkv.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.qkv.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.norm1.bias
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.qkv.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.0.norm2.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.proj.weight
2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.proj.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.norm1.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc1.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck._bias1
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.norm2.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.qkv.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.2.norm2.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.qkv.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.norm1.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc1.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc1.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc1.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck._bias3
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.norm2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc1.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.norm2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.norm1.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.qkv.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.norm2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc2.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc1.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.proj.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_norm.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_norm.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.norm2.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.1.norm1.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck.quantiles
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix2
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.proj.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.qkv.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc1.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.proj.weight
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.norm2.bias
2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc2.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc2.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.proj.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.qkv.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix1
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.proj.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.qkv.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc2.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc2.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.1.norm1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.proj.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc2.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc1.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.norm1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.qkv.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.3.norm1.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.qkv.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.proj.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.2.norm1.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc2.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.qkv.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix3
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_norm.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.1.norm2.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc2.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc1.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.proj.bias
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._factor0
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.proj.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._bias4
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._factor3
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.norm1.weight
2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.3.norm2.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_norm.weight
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.0.norm2.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.qkv.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.2.norm1.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.qkv.weight
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc1.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc1.weight
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.proj.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.qkv.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc1.bias
2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.entropy_bottleneck._factor1
2024-07-25 08:50:00,645 [INFO ]
2024-07-25 08:50:06,715 [INFO ] Epoch: [0] [ 0/4448] eta: 7:29:47 lr: 0.000001 loss: 8.3333 (8.3333) task_loss: 5.5091 (5.5091) bpp_loss: 28.2426 (28.2426) patch_loss: 2.4117 (2.4117) token_loss: 0.6234 (0.6234) fea_loss: 2.4741 (2.4741) time: 6.0673 data: 2.4582 max mem: 10534
2024-07-25 08:50:16,045 [INFO ] Epoch: [0] [ 10/4448] eta: 1:43:31 lr: 0.000001 loss: 8.3085 (8.3065) task_loss: 5.4842 (5.4822) bpp_loss: 28.2425 (28.2426) patch_loss: 2.4007 (2.4001) token_loss: 0.6673 (0.6727) fea_loss: 2.4212 (2.4095) time: 1.3996 data: 0.2236 max mem: 11190
2024-07-25 08:50:25,369 [INFO ] Epoch: [0] [ 20/4448] eta: 1:26:52 lr: 0.000001 loss: 8.2112 (8.2280) task_loss: 5.3869 (5.4038) bpp_loss: 28.2425 (28.2427) patch_loss: 2.3943 (2.3977) token_loss: 0.6668 (0.6682) fea_loss: 2.3140 (2.3378) time: 0.9326 data: 0.0001 max mem: 11191
2024-07-25 08:50:34,728 [INFO ] Epoch: [0] [ 30/4448] eta: 1:20:56 lr: 0.000001 loss: 8.0968 (8.1642) task_loss: 5.2725 (5.3399) bpp_loss: 28.2428 (28.2427) patch_loss: 2.3875 (2.3914) token_loss: 0.6681 (0.6704) fea_loss: 2.2053 (2.2782) time: 0.9341 data: 0.0001 max mem: 11191
2024-07-25 08:50:44,122 [INFO ] Epoch: [0] [ 40/4448] eta: 1:17:53 lr: 0.000001 loss: 7.9714 (8.1093) task_loss: 5.1472 (5.2851) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3728 (2.3870) token_loss: 0.6831 (0.6721) fea_loss: 2.0939 (2.2259) time: 0.9375 data: 0.0001 max mem: 11191
2024-07-25 08:50:53,546 [INFO ] Epoch: [0] [ 50/4448] eta: 1:16:01 lr: 0.000001 loss: 7.8988 (8.0590) task_loss: 5.0745 (5.2347) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3717 (2.3832) token_loss: 0.6707 (0.6729) fea_loss: 2.0184 (2.1786) time: 0.9408 data: 0.0001 max mem: 11192
2024-07-25 08:51:02,993 [INFO ] Epoch: [0] [ 60/4448] eta: 1:14:44 lr: 0.000001 loss: 7.8305 (8.0147) task_loss: 5.0062 (5.1904) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3574 (2.3787) token_loss: 0.6676 (0.6735) fea_loss: 1.9486 (2.1383) time: 0.9435 data: 0.0001 max mem: 11192
2024-07-25 08:51:12,484 [INFO ] Epoch: [0] [ 70/4448] eta: 1:13:49 lr: 0.000001 loss: 7.7545 (7.9697) task_loss: 4.9303 (5.1454) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3489 (2.3734) token_loss: 0.6659 (0.6675) fea_loss: 1.9142 (2.1045) time: 0.9469 data: 0.0001 max mem: 11192
2024-07-25 08:51:21,991 [INFO ] Epoch: [0] [ 80/4448] eta: 1:13:06 lr: 0.000001 loss: 7.6890 (7.9328) task_loss: 4.8647 (5.1085) bpp_loss: 28.2428 (28.2427) patch_loss: 2.3404 (2.3694) token_loss: 0.6583 (0.6661) fea_loss: 1.8710 (2.0731) time: 0.9498 data: 0.0001 max mem: 11194
2024-07-25 08:51:31,508 [INFO ] Epoch: [0] [ 90/4448] eta: 1:12:30 lr: 0.000001 loss: 7.6696 (7.9028) task_loss: 4.8453 (5.0785) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3398 (2.3655) token_loss: 0.6613 (0.6665) fea_loss: 1.8399 (2.0465) time: 0.9511 data: 0.0001 max mem: 11194
2024-07-25 08:51:41,113 [INFO ] Epoch: [0] [ 100/4448] eta: 1:12:04 lr: 0.000001 loss: 7.6273 (7.8742) task_loss: 4.8031 (5.0499) bpp_loss: 28.2425 (28.2427) patch_loss: 2.3329 (2.3620) token_loss: 0.6699 (0.6671) fea_loss: 1.8089 (2.0209) time: 0.9560 data: 0.0001 max mem: 11194
2024-07-25 08:51:50,804 [INFO ] Epoch: [0] [ 110/4448] eta: 1:11:44 lr: 0.000001 loss: 7.5991 (7.8487) task_loss: 4.7749 (5.0244) bpp_loss: 28.2424 (28.2427) patch_loss: 2.3215 (2.3584) token_loss: 0.6780 (0.6676) fea_loss: 1.7759 (1.9984) time: 0.9647 data: 0.0001 max mem: 11194
2024-07-25 08:52:00,570 [INFO ] Epoch: [0] [ 120/4448] eta: 1:11:29 lr: 0.000001 loss: 7.5780 (7.8257) task_loss: 4.7538 (5.0014) bpp_loss: 28.2425 (28.2427) patch_loss: 2.3218 (2.3558) token_loss: 0.6774 (0.6677) fea_loss: 1.7559 (1.9780) time: 0.9728 data: 0.0001 max mem: 11194
2024-07-25 08:52:10,528 [INFO ] Epoch: [0] [ 130/4448] eta: 1:11:20 lr: 0.000001 loss: 7.5670 (7.8045) task_loss: 4.7427 (4.9802) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3239 (2.3528) token_loss: 0.6669 (0.6682) fea_loss: 1.7497 (1.9592) time: 0.9861 data: 0.0001 max mem: 11194
2024-07-25 08:52:20,550 [INFO ] Epoch: [0] [ 140/4448] eta: 1:11:14 lr: 0.000001 loss: 7.5367 (7.7836) task_loss: 4.7125 (4.9594) bpp_loss: 28.2423 (28.2427) patch_loss: 2.3120 (2.3499) token_loss: 0.6791 (0.6681) fea_loss: 1.7283 (1.9414) time: 0.9989 data: 0.0001 max mem: 11194
2024-07-25 08:52:30,565 [INFO ] Epoch: [0] [ 150/4448] eta: 1:11:06 lr: 0.000001 loss: 7.4978 (7.7638) task_loss: 4.6735 (4.9396) bpp_loss: 28.2423 (28.2426) patch_loss: 2.3131 (2.3477) token_loss: 0.6641 (0.6675) fea_loss: 1.6953 (1.9244) time: 1.0018 data: 0.0001 max mem: 11194
2024-07-25 08:52:40,399 [INFO ] Epoch: [0] [ 160/4448] eta: 1:10:54 lr: 0.000001 loss: 7.4593 (7.7455) task_loss: 4.6350 (4.9213) bpp_loss: 28.2423 (28.2426) patch_loss: 2.3100 (2.3450) token_loss: 0.6513 (0.6671) fea_loss: 1.6766 (1.9092) time: 0.9924 data: 0.0001 max mem: 11194
2024-07-25 08:52:50,157 [INFO ] Epoch: [0] [ 170/4448] eta: 1:10:40 lr: 0.000001 loss: 7.4549 (7.7282) task_loss: 4.6306 (4.9040) bpp_loss: 28.2425 (28.2426) patch_loss: 2.3005 (2.3425) token_loss: 0.6528 (0.6669) fea_loss: 1.6654 (1.8945) time: 0.9795 data: 0.0001 max mem: 11194
2024-07-25 08:52:59,911 [INFO ] Epoch: [0] [ 180/4448] eta: 1:10:26 lr: 0.000001 loss: 7.4523 (7.7123) task_loss: 4.6281 (4.8880) bpp_loss: 28.2424 (28.2426) patch_loss: 2.3051 (2.3407) token_loss: 0.6546 (0.6663) fea_loss: 1.6590 (1.8810) time: 0.9755 data: 0.0001 max mem: 11194
2024-07-25 08:53:09,637 [INFO ] Epoch: [0] [ 190/4448] eta: 1:10:12 lr: 0.000001 loss: 7.4077 (7.6963) task_loss: 4.5834 (4.8720) bpp_loss: 28.2424 (28.2426) patch_loss: 2.3067 (2.3386) token_loss: 0.6544 (0.6660) fea_loss: 1.6313 (1.8674) time: 0.9739 data: 0.0001 max mem: 11194
2024-07-25 08:53:19,360 [INFO ] Epoch: [0] [ 200/4448] eta: 1:09:59 lr: 0.000001 loss: 7.4077 (7.6829) task_loss: 4.5834 (4.8586) bpp_loss: 28.2423 (28.2426) patch_loss: 2.3058 (2.3366) token_loss: 0.6651 (0.6670) fea_loss: 1.6217 (1.8550) time: 0.9723 data: 0.0001 max mem: 11194
2024-07-25 08:53:29,073 [INFO ] Epoch: [0] [ 210/4448] eta: 1:09:45 lr: 0.000001 loss: 7.4075 (7.6695) task_loss: 4.5833 (4.8453) bpp_loss: 28.2422 (28.2425) patch_loss: 2.2973 (2.3348) token_loss: 0.6806 (0.6674) fea_loss: 1.6116 (1.8431) time: 0.9717 data: 0.0001 max mem: 11195
2024-07-25 08:53:38,788 [INFO ] Epoch: [0] [ 220/4448] eta: 1:09:32 lr: 0.000001 loss: 7.3916 (7.6572) task_loss: 4.5674 (4.8329) bpp_loss: 28.2419 (28.2425) patch_loss: 2.2990 (2.3332) token_loss: 0.6721 (0.6680) fea_loss: 1.5945 (1.8317) time: 0.9713 data: 0.0001 max mem: 11195
2024-07-25 08:53:48,511 [INFO ] Epoch: [0] [ 230/4448] eta: 1:09:20 lr: 0.000001 loss: 7.3755 (7.6451) task_loss: 4.5514 (4.8209) bpp_loss: 28.2417 (28.2425) patch_loss: 2.2998 (2.3316) token_loss: 0.6759 (0.6679) fea_loss: 1.5874 (1.8213) time: 0.9718 data: 0.0001 max mem: 11195
2024-07-25 08:53:58,215 [INFO ] Epoch: [0] [ 240/4448] eta: 1:09:07 lr: 0.000001 loss: 7.3750 (7.6335) task_loss: 4.5508 (4.8092) bpp_loss: 28.2416 (28.2425) patch_loss: 2.3008 (2.3302) token_loss: 0.6651 (0.6679) fea_loss: 1.5759 (1.8111) time: 0.9713 data: 0.0001 max mem: 11195
2024-07-25 08:54:07,932 [INFO ] Epoch: [0] [ 250/4448] eta: 1:08:55 lr: 0.000001 loss: 7.3393 (7.6212) task_loss: 4.5152 (4.7970) bpp_loss: 28.2414 (28.2424) patch_loss: 2.2951 (2.3289) token_loss: 0.6563 (0.6671) fea_loss: 1.5655 (1.8010) time: 0.9710 data: 0.0001 max mem: 11195
2024-07-25 08:54:17,651 [INFO ] Epoch: [0] [ 260/4448] eta: 1:08:43 lr: 0.000001 loss: 7.3268 (7.6096) task_loss: 4.5027 (4.7854) bpp_loss: 28.2413 (28.2424) patch_loss: 2.2931 (2.3274) token_loss: 0.6561 (0.6665) fea_loss: 1.5615 (1.7915) time: 0.9717 data: 0.0001 max mem: 11195
2024-07-25 08:54:27,368 [INFO ] Epoch: [0] [ 270/4448] eta: 1:08:31 lr: 0.000001 loss: 7.3396 (7.5996) task_loss: 4.5156 (4.7754) bpp_loss: 28.2411 (28.2423) patch_loss: 2.2837 (2.3257) token_loss: 0.6719 (0.6672) fea_loss: 1.5560 (1.7825) time: 0.9717 data: 0.0001 max mem: 11195
2024-07-25 08:54:37,070 [INFO ] Epoch: [0] [ 280/4448] eta: 1:08:19 lr: 0.000001 loss: 7.3154 (7.5896) task_loss: 4.4914 (4.7654) bpp_loss: 28.2404 (28.2422) patch_loss: 2.2837 (2.3246) token_loss: 0.6619 (0.6668) fea_loss: 1.5447 (1.7740) time: 0.9709 data: 0.0001 max mem: 11195
2024-07-25 08:54:46,786 [INFO ] Epoch: [0] [ 290/4448] eta: 1:08:08 lr: 0.000001 loss: 7.3109 (7.5802) task_loss: 4.4868 (4.7560) bpp_loss: 28.2396 (28.2421) patch_loss: 2.2924 (2.3236) token_loss: 0.6524 (0.6669) fea_loss: 1.5375 (1.7655) time: 0.9708 data: 0.0001 max mem: 11195
2024-07-25 08:54:56,541 [INFO ] Epoch: [0] [ 300/4448] eta: 1:07:57 lr: 0.000001 loss: 7.3166 (7.5713) task_loss: 4.4926 (4.7471) bpp_loss: 28.2385 (28.2420) patch_loss: 2.2950 (2.3225) token_loss: 0.6756 (0.6672) fea_loss: 1.5210 (1.7574) time: 0.9735 data: 0.0001 max mem: 11195
2024-07-25 08:55:06,276 [INFO ] Epoch: [0] [ 310/4448] eta: 1:07:46 lr: 0.000001 loss: 7.3026 (7.5628) task_loss: 4.4788 (4.7386) bpp_loss: 28.2378 (28.2418) patch_loss: 2.2952 (2.3217) token_loss: 0.6699 (0.6673) fea_loss: 1.5175 (1.7496) time: 0.9744 data: 0.0001 max mem: 11195
2024-07-25 08:55:15,984 [INFO ] Epoch: [0] [ 320/4448] eta: 1:07:34 lr: 0.000001 loss: 7.2847 (7.5538) task_loss: 4.4616 (4.7296) bpp_loss: 28.2320 (28.2415) patch_loss: 2.2890 (2.3206) token_loss: 0.6612 (0.6672) fea_loss: 1.5070 (1.7418) time: 0.9720 data: 0.0001 max mem: 11195
2024-07-25 08:55:25,715 [INFO ] Epoch: [0] [ 330/4448] eta: 1:07:23 lr: 0.000001 loss: 7.2884 (7.5463) task_loss: 4.4653 (4.7222) bpp_loss: 28.2311 (28.2412) patch_loss: 2.2923 (2.3199) token_loss: 0.6727 (0.6678) fea_loss: 1.4998 (1.7344) time: 0.9719 data: 0.0001 max mem: 11195
2024-07-25 08:55:35,470 [INFO ] Epoch: [0] [ 340/4448] eta: 1:07:13 lr: 0.000001 loss: 7.2818 (7.5382) task_loss: 4.4587 (4.7141) bpp_loss: 28.2306 (28.2408) patch_loss: 2.2966 (2.3192) token_loss: 0.6657 (0.6677) fea_loss: 1.4959 (1.7273) time: 0.9742 data: 0.0001 max mem: 11195
2024-07-25 08:55:45,224 [INFO ] Epoch: [0] [ 350/4448] eta: 1:07:02 lr: 0.000001 loss: 7.2619 (7.5306) task_loss: 4.4388 (4.7066) bpp_loss: 28.2131 (28.2400) patch_loss: 2.2874 (2.3183) token_loss: 0.6561 (0.6679) fea_loss: 1.4892 (1.7205) time: 0.9754 data: 0.0001 max mem: 11195
2024-07-25 08:55:54,962 [INFO ] Epoch: [0] [ 360/4448] eta: 1:06:51 lr: 0.000001 loss: 7.2578 (7.5231) task_loss: 4.4367 (4.6991) bpp_loss: 28.2120 (28.2392) patch_loss: 2.2885 (2.3175) token_loss: 0.6579 (0.6678) fea_loss: 1.4836 (1.7138) time: 0.9745 data: 0.0001 max mem: 11195
2024-07-25 08:56:04,713 [INFO ] Epoch: [0] [ 370/4448] eta: 1:06:41 lr: 0.000001 loss: 7.2578 (7.5154) task_loss: 4.4367 (4.6916) bpp_loss: 28.2115 (28.2385) patch_loss: 2.2920 (2.3168) token_loss: 0.6579 (0.6676) fea_loss: 1.4752 (1.7072) time: 0.9743 data: 0.0001 max mem: 11195
2024-07-25 08:56:14,459 [INFO ] Epoch: [0] [ 380/4448] eta: 1:06:30 lr: 0.000001 loss: 7.2504 (7.5086) task_loss: 4.4293 (4.6848) bpp_loss: 28.2110 (28.2377) patch_loss: 2.2893 (2.3161) token_loss: 0.6630 (0.6677) fea_loss: 1.4696 (1.7010) time: 0.9748 data: 0.0001 max mem: 11195
2024-07-25 08:56:24,203 [INFO ] Epoch: [0] [ 390/4448] eta: 1:06:20 lr: 0.000001 loss: 7.2421 (7.5016) task_loss: 4.4212 (4.6779) bpp_loss: 28.2100 (28.2370) patch_loss: 2.2893 (2.3154) token_loss: 0.6630 (0.6676) fea_loss: 1.4607 (1.6949) time: 0.9744 data: 0.0001 max mem: 11195
2024-07-25 08:56:33,961 [INFO ] Epoch: [0] [ 400/4448] eta: 1:06:10 lr: 0.000001 loss: 7.2267 (7.4947) task_loss: 4.4057 (4.6711) bpp_loss: 28.2090 (28.2363) patch_loss: 2.2825 (2.3147) token_loss: 0.6591 (0.6674) fea_loss: 1.4549 (1.6890) time: 0.9750 data: 0.0001 max mem: 11195
2024-07-25 08:56:43,733 [INFO ] Epoch: [0] [ 410/4448] eta: 1:05:59 lr: 0.000001 loss: 7.2241 (7.4887) task_loss: 4.4032 (4.6651) bpp_loss: 28.2080 (28.2356) patch_loss: 2.2854 (2.3141) token_loss: 0.6674 (0.6676) fea_loss: 1.4618 (1.6834) time: 0.9764 data: 0.0001 max mem: 11195
2024-07-25 08:56:53,446 [INFO ] Epoch: [0] [ 420/4448] eta: 1:05:49 lr: 0.000001 loss: 7.2331 (7.4827) task_loss: 4.4122 (4.6592) bpp_loss: 28.2068 (28.2349) patch_loss: 2.2857 (2.3134) token_loss: 0.6711 (0.6679) fea_loss: 1.4545 (1.6779) time: 0.9742 data: 0.0001 max mem: 11195
2024-07-25 08:57:03,201 [INFO ] Epoch: [0] [ 430/4448] eta: 1:05:38 lr: 0.000001 loss: 7.2254 (7.4765) task_loss: 4.4047 (4.6531) bpp_loss: 28.2055 (28.2342) patch_loss: 2.2863 (2.3127) token_loss: 0.6718 (0.6679) fea_loss: 1.4462 (1.6725) time: 0.9733 data: 0.0001 max mem: 11195
2024-07-25 08:57:12,938 [INFO ] Epoch: [0] [ 440/4448] eta: 1:05:28 lr: 0.000001 loss: 7.2275 (7.4714) task_loss: 4.4071 (4.6480) bpp_loss: 28.2046 (28.2335) patch_loss: 2.2863 (2.3122) token_loss: 0.6750 (0.6683) fea_loss: 1.4481 (1.6675) time: 0.9745 data: 0.0001 max mem: 11195
2024-07-25 08:57:22,692 [INFO ] Epoch: [0] [ 450/4448] eta: 1:05:18 lr: 0.000001 loss: 7.2270 (7.4658) task_loss: 4.4067 (4.6425) bpp_loss: 28.2033 (28.2329) patch_loss: 2.2936 (2.3118) token_loss: 0.6782 (0.6684) fea_loss: 1.4495 (1.6624) time: 0.9745 data: 0.0001 max mem: 11195
2024-07-25 08:57:32,428 [INFO ] Epoch: [0] [ 460/4448] eta: 1:05:07 lr: 0.000001 loss: 7.2030 (7.4601) task_loss: 4.3827 (4.6368) bpp_loss: 28.2019 (28.2322) patch_loss: 2.2834 (2.3110) token_loss: 0.6762 (0.6684) fea_loss: 1.4385 (1.6574) time: 0.9744 data: 0.0001 max mem: 11195
2024-07-25 08:57:42,172 [INFO ] Epoch: [0] [ 470/4448] eta: 1:04:57 lr: 0.000001 loss: 7.2220 (7.4547) task_loss: 4.4020 (4.6315) bpp_loss: 28.2006 (28.2315) patch_loss: 2.2808 (2.3105) token_loss: 0.6535 (0.6685) fea_loss: 1.4261 (1.6526) time: 0.9739 data: 0.0001 max mem: 11195
2024-07-25 08:57:51,902 [INFO ] Epoch: [0] [ 480/4448] eta: 1:04:47 lr: 0.000001 loss: 7.2008 (7.4490) task_loss: 4.3808 (4.6259) bpp_loss: 28.1997 (28.2308) patch_loss: 2.2765 (2.3099) token_loss: 0.6534 (0.6682) fea_loss: 1.4257 (1.6478) time: 0.9736 data: 0.0001 max mem: 11195
2024-07-25 08:58:01,653 [INFO ] Epoch: [0] [ 490/4448] eta: 1:04:37 lr: 0.000001 loss: 7.2050 (7.4444) task_loss: 4.3852 (4.6214) bpp_loss: 28.1985 (28.2301) patch_loss: 2.2831 (2.3095) token_loss: 0.6743 (0.6686) fea_loss: 1.4232 (1.6432) time: 0.9740 data: 0.0001 max mem: 11195
2024-07-25 08:58:11,373 [INFO ] Epoch: [0] [ 500/4448] eta: 1:04:26 lr: 0.000001 loss: 7.2024 (7.4392) task_loss: 4.3827 (4.6162) bpp_loss: 28.1970 (28.2295) patch_loss: 2.2877 (2.3091) token_loss: 0.6743 (0.6684) fea_loss: 1.4212 (1.6387) time: 0.9735 data: 0.0001 max mem: 11195
2024-07-25 08:58:21,089 [INFO ] Epoch: [0] [ 510/4448] eta: 1:04:16 lr: 0.000001 loss: 7.1806 (7.4341) task_loss: 4.3611 (4.6112) bpp_loss: 28.1960 (28.2288) patch_loss: 2.2865 (2.3086) token_loss: 0.6606 (0.6683) fea_loss: 1.4148 (1.6343) time: 0.9717 data: 0.0001 max mem: 11195
2024-07-25 08:58:30,784 [INFO ] Epoch: [0] [ 520/4448] eta: 1:04:05 lr: 0.000001 loss: 7.1758 (7.4292) task_loss: 4.3564 (4.6064) bpp_loss: 28.1949 (28.2282) patch_loss: 2.2844 (2.3081) token_loss: 0.6636 (0.6683) fea_loss: 1.4082 (1.6299) time: 0.9705 data: 0.0001 max mem: 11195
2024-07-25 08:58:40,522 [INFO ] Epoch: [0] [ 530/4448] eta: 1:03:55 lr: 0.000001 loss: 7.1598 (7.4244) task_loss: 4.3403 (4.6017) bpp_loss: 28.1940 (28.2275) patch_loss: 2.2851 (2.3077) token_loss: 0.6642 (0.6683) fea_loss: 1.4067 (1.6258) time: 0.9715 data: 0.0001 max mem: 11195
2024-07-25 08:58:50,244 [INFO ] Epoch: [0] [ 540/4448] eta: 1:03:45 lr: 0.000001 loss: 7.1524 (7.4195) task_loss: 4.3331 (4.5969) bpp_loss: 28.1934 (28.2269) patch_loss: 2.2790 (2.3072) token_loss: 0.6654 (0.6681) fea_loss: 1.4014 (1.6216) time: 0.9729 data: 0.0001 max mem: 11195
2024-07-25 08:59:00,201 [INFO ] Epoch: [0] [ 550/4448] eta: 1:03:36 lr: 0.000001 loss: 7.1595 (7.4148) task_loss: 4.3401 (4.5922) bpp_loss: 28.1927 (28.2262) patch_loss: 2.2758 (2.3066) token_loss: 0.6682 (0.6682) fea_loss: 1.3959 (1.6174) time: 0.9839 data: 0.0001 max mem: 11195
1_feature_extractor/log/DINOv2_training/log/20240725_085916.log
ADDED (diff too large to render; see raw diff)

1_feature_extractor/log/DINOv2_training/log/20240726_110417.log
ADDED (diff too large to render; see raw diff)

1_feature_extractor/log/DINOv2_training/log/20240726_171814.log
ADDED (diff too large to render; see raw diff)

1_feature_extractor/log/DINOv2_training/log/20240728_153020.log
ADDED (diff too large to render; see raw diff)

1_feature_extractor/log/DINOv2_training/log/20240728_214526.log
ADDED (diff too large to render; see raw diff)

1_feature_extractor/log/DINOv2_training/log/20240729_102738.log
ADDED (diff too large to render; see raw diff)

1_feature_extractor/log/DINOv2_training/log/20240730_084148.log
ADDED
@@ -0,0 +1,301 @@
2024-07-30 08:41:48,409 [INFO ] Logging file is /data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log//20240730_084148.log
2024-07-30 08:41:48,410 [INFO ] job dir: /data0/qiyp/Proteus-pytorch/pretrain
2024-07-30 08:41:48,410 [INFO ] Namespace(batch_size=48,
    epochs=200,
    bce_loss=False,
    unscale_lr=False,
    model='models_proteus_dinov2',
    target_model='vit_base',
    input_size=224,
    drop=0.0,
    drop_path=0.1,
    model_ema=True,
    model_ema_decay=0.99996,
    model_ema_force_cpu=False,
    opt='adamw',
    opt_eps=1e-08,
    opt_betas=None,
    clip_grad=None,
    momentum=0.9,
    weight_decay=0.05,
    sched='cosine',
    lr=0.0002,
    lr_noise=None,
    lr_noise_pct=0.67,
    lr_noise_std=1.0,
    warmup_lr=1e-06,
    min_lr=1e-05,
    decay_epochs=30,
    warmup_epochs=5,
    cooldown_epochs=10,
    patience_epochs=10,
    decay_rate=0.1,
    color_jitter=0.3,
    aa='rand-m9-mstd0.5-inc1',
    smoothing=0.1,
    train_interpolation='bicubic',
    repeated_aug=True,
    train_mode=True,
    ThreeAugment=False,
    src=False,
    global_crops_size=224,
    patch_size=14,
    mask_ratio=[0.5],
    mask_probability=0.5,
    mask_first_n=True,
    clone_batch=1,
    reprob=0.25,
    remode='pixel',
    recount=1,
    resplit=False,
    mixup=0.8,
    cutmix=1.0,
    cutmix_minmax=None,
    mixup_prob=1.0,
    mixup_switch_prob=0.5,
    mixup_mode='batch',
    teacher_model='vit_large',
    teacher_path='',
    distillation_type='none',
    distillation_alpha=0.5,
    distillation_tau=1.0,
    lambda_token=1.0,
    lambda_fea=1.1,
    lambda_patch=1.0,
    cosub=False,
    finetune='',
    attn_only=False,
    weight_inherit='',
    data_path='/data1/datasets/imagenet_fold',
    data_set='IMNET',
    inat_category='name',
    output_dir='log/DINOv2_training',
    log_dir='/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log/',
    device='cuda',
    seed=0,
    resume='/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0080.pth',
    start_epoch=0,
    eval=False,
    eval_crop_ratio=0.875,
    dist_eval=False,
    num_workers=10,
    pin_mem=True,
    distributed=True,
    world_size=6,
    dist_url='env://',
    rank=0,
    gpu=0,
    dist_backend='nccl')
2024-07-30 08:41:52,685 [INFO ] Dataset ImageFolder
    Number of datapoints: 1281167
    Root location: /data1/datasets/imagenet_fold/train
    StandardTransform
Transform: Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
    RandomHorizontalFlip(p=0.5)
    RandAugment(n=2, ops=
        AugmentOp(name=AutoContrast, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Equalize, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Invert, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=Rotate, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=PosterizeIncreasing, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=SolarizeIncreasing, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=SolarizeAdd, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=ColorIncreasing, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=ContrastIncreasing, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=BrightnessIncreasing, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=SharpnessIncreasing, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=ShearX, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=ShearY, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=TranslateXRel, p=0.5, m=9, mstd=0.5)
        AugmentOp(name=TranslateYRel, p=0.5, m=9, mstd=0.5))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
    RandomErasing(p=0.25, mode=pixel, count=(1, 1))
)
2024-07-30 08:41:52,760 [INFO ] Sampler_train = <samplers.RASampler object at 0x7efd92aaf3a0>
2024-07-30 08:41:52,792 [INFO ] Created a temporary directory at /tmp/tmp1r345n4h
2024-07-30 08:41:52,793 [INFO ] Writing /tmp/tmp1r345n4h/_remote_module_non_scriptable.py
2024-07-30 08:42:03,348 [INFO ] using MLP layer as FFN
2024-07-30 08:42:05,496 [INFO ] using MLP layer as FFN
2024-07-30 08:42:13,281 [INFO ] Model = MetaArch(
  (student): ModuleDict(
    (backbone): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
          (ls2): LayerScale()
          (drop_path2): Identity()
        )
      )
      (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (head): Identity()
    )
  )
  (teacher): ModuleDict(
    (backbone): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x NestedTensorBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
          (ls2): LayerScale()
          (drop_path2): Identity()
        )
      )
      (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (head): Identity()
    )
  )
  (ibot_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=1024, bias=True)
  )
  (token_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=1024, bias=True)
  )
  (fea_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=1024, bias=True)
  )
  (soft_criterion): MSELoss()
  (info_bottleneck): IF_Module(
    (encoder_blocks): ModuleList(
      (0-3): 4 x Block(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop2): Dropout(p=0.0, inplace=False)
        )
        (ls2): Identity()
        (drop_path2): Identity()
      )
    )
    (encoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (decoder_blocks): ModuleList(
      (0-3): 4 x Block(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop2): Dropout(p=0.0, inplace=False)
        )
        (ls2): Identity()
        (drop_path2): Identity()
      )
    )
    (decoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (entropy_bottleneck): EntropyBottleneck(
      (likelihood_lower_bound): LowerBound()
    )
  )
)
2024-07-30 08:42:16,057 [INFO ] number of params: 144845568
2024-07-30 08:42:16,057 [INFO ] base lr = 0.0002
2024-07-30 08:42:16,058 [INFO ] actual lr = 0.00011250000000000001
2024-07-30 08:42:35,018 [INFO ] Loaded state_dict_ema
2024-07-30 08:42:35,076 [INFO ] Resuming from /data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0080.pth
2024-07-30 08:42:35,076 [INFO ] Start training for 200 epochs
2024-07-30 08:42:47,484 [INFO ] Epoch: [81] [ 0/4448] eta: 15:19:18 lr: 0.000076 loss: 2.8205 (2.8205) task_loss: 2.6200 (2.6200) bpp_loss: 0.2883 (0.2883) patch_loss: 0.9649 (0.9649) token_loss: 0.8029 (0.8029) fea_loss: 0.8522 (0.8522) time: 12.4007 data: 3.4410 max mem: 20527
2024-07-30 08:43:06,985 [INFO ] Epoch: [81] [ 20/4448] eta: 1:52:06 lr: 0.000076 loss: 2.8058 (2.8015) task_loss: 2.5992 (2.5967) bpp_loss: 0.2986 (0.2988) patch_loss: 0.9937 (0.9960) token_loss: 0.7355 (0.7480) fea_loss: 0.8571 (0.8527) time: 0.9749 data: 0.0001 max mem: 21093
2024-07-30 08:43:26,649 [INFO ] Epoch: [81] [ 40/4448] eta: 1:32:23 lr: 0.000076 loss: 2.7488 (2.7856) task_loss: 2.5421 (2.5804) bpp_loss: 0.3020 (0.3006) patch_loss: 0.9832 (0.9945) token_loss: 0.7257 (0.7364) fea_loss: 0.8415 (0.8495) time: 0.9831 data: 0.0001 max mem: 21095
2024-07-30 08:43:46,729 [INFO ] Epoch: [81] [ 60/4448] eta: 1:25:53 lr: 0.000076 loss: 2.7729 (2.7857) task_loss: 2.5642 (2.5803) bpp_loss: 0.3040 (0.3017) patch_loss: 1.0076 (0.9998) token_loss: 0.7200 (0.7333) fea_loss: 0.8399 (0.8472) time: 1.0040 data: 0.0001 max mem: 21095
2024-07-30 08:44:07,288 [INFO ] Epoch: [81] [ 80/4448] eta: 1:22:51 lr: 0.000076 loss: 2.8271 (2.7941) task_loss: 2.6186 (2.5881) bpp_loss: 0.3040 (0.3025) patch_loss: 1.0128 (1.0035) token_loss: 0.7397 (0.7346) fea_loss: 0.8573 (0.8500) time: 1.0279 data: 0.0001 max mem: 21095
2024-07-30 08:44:27,381 [INFO ] Epoch: [81] [ 100/4448] eta: 1:20:34 lr: 0.000076 loss: 2.8005 (2.8008) task_loss: 2.5913 (2.5942) bpp_loss: 0.3074 (0.3034) patch_loss: 1.0188 (1.0078) token_loss: 0.7295 (0.7347) fea_loss: 0.8552 (0.8517) time: 1.0046 data: 0.0001 max mem: 21095
2024-07-30 08:44:47,266 [INFO ] Epoch: [81] [ 120/4448] eta: 1:18:47 lr: 0.000076 loss: 2.7752 (2.7995) task_loss: 2.5660 (2.5928) bpp_loss: 0.3070 (0.3040) patch_loss: 0.9996 (1.0060) token_loss: 0.7348 (0.7359) fea_loss: 0.8485 (0.8509) time: 0.9942 data: 0.0001 max mem: 21095
2024-07-30 08:45:07,249 [INFO ] Epoch: [81] [ 140/4448] eta: 1:17:28 lr: 0.000076 loss: 2.8033 (2.8006) task_loss: 2.5930 (2.5937) bpp_loss: 0.3064 (0.3044) patch_loss: 1.0073 (1.0067) token_loss: 0.7371 (0.7357) fea_loss: 0.8503 (0.8513) time: 0.9991 data: 0.0002 max mem: 21096
2024-07-30 08:45:27,250 [INFO ] Epoch: [81] [ 160/4448] eta: 1:16:25 lr: 0.000076 loss: 2.7980 (2.8014) task_loss: 2.5895 (2.5944) bpp_loss: 0.3069 (0.3047) patch_loss: 1.0087 (1.0062) token_loss: 0.7382 (0.7361) fea_loss: 0.8631 (0.8520) time: 1.0000 data: 0.0001 max mem: 21096
2024-07-30 08:45:47,340 [INFO ] Epoch: [81] [ 180/4448] eta: 1:15:33 lr: 0.000076 loss: 2.8292 (2.8018) task_loss: 2.6184 (2.5945) bpp_loss: 0.3091 (0.3052) patch_loss: 1.0158 (1.0076) token_loss: 0.7273 (0.7351) fea_loss: 0.8446 (0.8518) time: 1.0044 data: 0.0001 max mem: 21096
2024-07-30 08:46:07,943 [INFO ] Epoch: [81] [ 200/4448] eta: 1:14:58 lr: 0.000076 loss: 2.7974 (2.7999) task_loss: 2.5896 (2.5926) bpp_loss: 0.3082 (0.3055) patch_loss: 0.9949 (1.0069) token_loss: 0.7182 (0.7340) fea_loss: 0.8525 (0.8516) time: 1.0301 data: 0.0001 max mem: 21096
2024-07-30 08:46:28,477 [INFO ] Epoch: [81] [ 220/4448] eta: 1:14:24 lr: 0.000076 loss: 2.7864 (2.8008) task_loss: 2.5794 (2.5933) bpp_loss: 0.3078 (0.3057) patch_loss: 1.0100 (1.0072) token_loss: 0.7306 (0.7342) fea_loss: 0.8441 (0.8518) time: 1.0266 data: 0.0001 max mem: 21096
2024-07-30 08:46:48,558 [INFO ] Epoch: [81] [ 240/4448] eta: 1:13:45 lr: 0.000076 loss: 2.8172 (2.8027) task_loss: 2.6068 (2.5950) bpp_loss: 0.3095 (0.3060) patch_loss: 1.0120 (1.0081) token_loss: 0.7279 (0.7344) fea_loss: 0.8583 (0.8525) time: 1.0040 data: 0.0001 max mem: 21096
2024-07-30 08:47:08,665 [INFO ] Epoch: [81] [ 260/4448] eta: 1:13:09 lr: 0.000076 loss: 2.7650 (2.8013) task_loss: 2.5573 (2.5936) bpp_loss: 0.3086 (0.3063) patch_loss: 0.9966 (1.0075) token_loss: 0.7286 (0.7338) fea_loss: 0.8450 (0.8523) time: 1.0053 data: 0.0001 max mem: 21096
2024-07-30 08:47:28,757 [INFO ] Epoch: [81] [ 280/4448] eta: 1:12:35 lr: 0.000076 loss: 2.7513 (2.7988) task_loss: 2.5452 (2.5910) bpp_loss: 0.3100 (0.3066) patch_loss: 0.9790 (1.0058) token_loss: 0.7214 (0.7333) fea_loss: 0.8448 (0.8518) time: 1.0045 data: 0.0001 max mem: 21096
2024-07-30 08:47:48,814 [INFO ] Epoch: [81] [ 300/4448] eta: 1:12:03 lr: 0.000076 loss: 2.7910 (2.7984) task_loss: 2.5833 (2.5906) bpp_loss: 0.3097 (0.3068) patch_loss: 0.9979 (1.0056) token_loss: 0.7220 (0.7334) fea_loss: 0.8453 (0.8515) time: 1.0028 data: 0.0001 max mem: 21096
2024-07-30 08:48:08,893 [INFO ] Epoch: [81] [ 320/4448] eta: 1:11:32 lr: 0.000076 loss: 2.7957 (2.7991) task_loss: 2.5868 (2.5911) bpp_loss: 0.3101 (0.3071) patch_loss: 1.0096 (1.0065) token_loss: 0.7125 (0.7326) fea_loss: 0.8559 (0.8520) time: 1.0039 data: 0.0001 max mem: 21096
2024-07-30 08:48:29,072 [INFO ] Epoch: [81] [ 340/4448] eta: 1:11:04 lr: 0.000076 loss: 2.7660 (2.7990) task_loss: 2.5585 (2.5909) bpp_loss: 0.3108 (0.3072) patch_loss: 0.9918 (1.0063) token_loss: 0.7228 (0.7327) fea_loss: 0.8468 (0.8519) time: 1.0089 data: 0.0001 max mem: 21096
2024-07-30 08:48:49,101 [INFO ] Epoch: [81] [ 360/4448] eta: 1:10:35 lr: 0.000076 loss: 2.8059 (2.7999) task_loss: 2.5933 (2.5917) bpp_loss: 0.3136 (0.3076) patch_loss: 1.0057 (1.0066) token_loss: 0.7353 (0.7330) fea_loss: 0.8496 (0.8521) time: 1.0014 data: 0.0001 max mem: 21096
2024-07-30 08:49:09,150 [INFO ] Epoch: [81] [ 380/4448] eta: 1:10:07 lr: 0.000076 loss: 2.7807 (2.7979) task_loss: 2.5695 (2.5897) bpp_loss: 0.3108 (0.3078) patch_loss: 0.9834 (1.0055) token_loss: 0.7254 (0.7327) fea_loss: 0.8475 (0.8515) time: 1.0024 data: 0.0001 max mem: 21096
2024-07-30 08:49:29,184 [INFO ] Epoch: [81] [ 400/4448] eta: 1:09:40 lr: 0.000076 loss: 2.7576 (2.7964) task_loss: 2.5457 (2.5881) bpp_loss: 0.3135 (0.3081) patch_loss: 0.9908 (1.0050) token_loss: 0.7106 (0.7318) fea_loss: 0.8434 (0.8513) time: 1.0016 data: 0.0001 max mem: 21096
2024-07-30 08:49:49,309 [INFO ] Epoch: [81] [ 420/4448] eta: 1:09:14 lr: 0.000076 loss: 2.7548 (2.7950) task_loss: 2.5465 (2.5866) bpp_loss: 0.3124 (0.3083) patch_loss: 0.9950 (1.0045) token_loss: 0.7273 (0.7312) fea_loss: 0.8426 (0.8510) time: 1.0062 data: 0.0002 max mem: 21096
2024-07-30 08:50:09,308 [INFO ] Epoch: [81] [ 440/4448] eta: 1:08:47 lr: 0.000076 loss: 2.7648 (2.7940) task_loss: 2.5540 (2.5855) bpp_loss: 0.3125 (0.3085) patch_loss: 0.9801 (1.0036) token_loss: 0.7217 (0.7311) fea_loss: 0.8493 (0.8508) time: 0.9999 data: 0.0002 max mem: 21096
2024-07-30 08:50:29,354 [INFO ] Epoch: [81] [ 460/4448] eta: 1:08:22 lr: 0.000076 loss: 2.7843 (2.7932) task_loss: 2.5758 (2.5847) bpp_loss: 0.3126 (0.3087) patch_loss: 0.9979 (1.0033) token_loss: 0.7206 (0.7308) fea_loss: 0.8484 (0.8506) time: 1.0022 data: 0.0002 max mem: 21096
2024-07-30 08:50:49,382 [INFO ] Epoch: [81] [ 480/4448] eta: 1:07:57 lr: 0.000076 loss: 2.7398 (2.7920) task_loss: 2.5274 (2.5834) bpp_loss: 0.3156 (0.3090) patch_loss: 0.9877 (1.0032) token_loss: 0.7010 (0.7300) fea_loss: 0.8394 (0.8502) time: 1.0013 data: 0.0002 max mem: 21096
2024-07-30 08:51:09,296 [INFO ] Epoch: [81] [ 500/4448] eta: 1:07:31 lr: 0.000076 loss: 2.7482 (2.7908) task_loss: 2.5427 (2.5822) bpp_loss: 0.3129 (0.3091) patch_loss: 0.9996 (1.0030) token_loss: 0.6940 (0.7294) fea_loss: 0.8401 (0.8499) time: 0.9956 data: 0.0002 max mem: 21096
2024-07-30 08:51:29,262 [INFO ] Epoch: [81] [ 520/4448] eta: 1:07:07 lr: 0.000076 loss: 2.7785 (2.7908) task_loss: 2.5668 (2.5821) bpp_loss: 0.3141 (0.3093) patch_loss: 1.0087 (1.0028) token_loss: 0.7261 (0.7294) fea_loss: 0.8467 (0.8498) time: 0.9982 data: 0.0002 max mem: 21096
2024-07-30 08:51:49,227 [INFO ] Epoch: [81] [ 540/4448] eta: 1:06:42 lr: 0.000076 loss: 2.8121 (2.7921) task_loss: 2.5996 (2.5833) bpp_loss: 0.3148 (0.3096) patch_loss: 0.9975 (1.0033) token_loss: 0.7236 (0.7298) fea_loss: 0.8581 (0.8502) time: 0.9982 data: 0.0002 max mem: 21096
2024-07-30 08:52:09,093 [INFO ] Epoch: [81] [ 560/4448] eta: 1:06:17 lr: 0.000076 loss: 2.7914 (2.7926) task_loss: 2.5781 (2.5837) bpp_loss: 0.3144 (0.3097) patch_loss: 1.0098 (1.0037) token_loss: 0.7028 (0.7295) fea_loss: 0.8557 (0.8505) time: 0.9932 data: 0.0002 max mem: 21096
2024-07-30 08:52:29,024 [INFO ] Epoch: [81] [ 580/4448] eta: 1:05:53 lr: 0.000076 loss: 2.8253 (2.7942) task_loss: 2.6127 (2.5851) bpp_loss: 0.3147 (0.3099) patch_loss: 1.0152 (1.0044) token_loss: 0.7303 (0.7297) fea_loss: 0.8694 (0.8511) time: 0.9965 data: 0.0002 max mem: 21096
2024-07-30 08:52:48,932 [INFO ] Epoch: [81] [ 600/4448] eta: 1:05:30 lr: 0.000076 loss: 2.8256 (2.7949) task_loss: 2.6144 (2.5858) bpp_loss: 0.3141 (0.3101) patch_loss: 0.9969 (1.0043) token_loss: 0.7540 (0.7303) fea_loss: 0.8563 (0.8512) time: 0.9953 data: 0.0002 max mem: 21096
2024-07-30 08:53:08,892 [INFO ] Epoch: [81] [ 620/4448] eta: 1:05:06 lr: 0.000076 loss: 2.7807 (2.7948) task_loss: 2.5716 (2.5856) bpp_loss: 0.3145 (0.3102) patch_loss: 1.0056 (1.0042) token_loss: 0.7300 (0.7303) fea_loss: 0.8505 (0.8510) time: 0.9980 data: 0.0002 max mem: 21096
2024-07-30 08:53:28,936 [INFO ] Epoch: [81] [ 640/4448] eta: 1:04:44 lr: 0.000076 loss: 2.8072 (2.7954) task_loss: 2.5994 (2.5861) bpp_loss: 0.3138 (0.3103) patch_loss: 1.0003 (1.0044) token_loss: 0.7397 (0.7304) fea_loss: 0.8647 (0.8514) time: 1.0021 data: 0.0002 max mem: 21096
2024-07-30 08:53:48,900 [INFO ] Epoch: [81] [ 660/4448] eta: 1:04:21 lr: 0.000076 loss: 2.7934 (2.7958) task_loss: 2.5801 (2.5865) bpp_loss: 0.3144 (0.3104) patch_loss: 0.9936 (1.0043) token_loss: 0.7337 (0.7307) fea_loss: 0.8504 (0.8515) time: 0.9982 data: 0.0002 max mem: 21096
2024-07-30 08:54:08,946 [INFO ] Epoch: [81] [ 680/4448] eta: 1:03:58 lr: 0.000076 loss: 2.7859 (2.7956) task_loss: 2.5747 (2.5862) bpp_loss: 0.3140 (0.3106) patch_loss: 0.9925 (1.0043) token_loss: 0.7216 (0.7306) fea_loss: 0.8422 (0.8514) time: 1.0022 data: 0.0002 max mem: 21096
2024-07-30 08:54:28,905 [INFO ] Epoch: [81] [ 700/4448] eta: 1:03:36 lr: 0.000076 loss: 2.7273 (2.7944) task_loss: 2.5182 (2.5851) bpp_loss: 0.3142 (0.3107) patch_loss: 0.9854 (1.0038) token_loss: 0.7153 (0.7301) fea_loss: 0.8333 (0.8511) time: 0.9979 data: 0.0001 max mem: 21096
1_feature_extractor/log/DINOv2_training/log/20240730_085449.log
ADDED
The diff for this file is too large to render.
See raw diff

1_feature_extractor/log/DINOv2_training/log/20240731_102940.log
ADDED
The diff for this file is too large to render.
See raw diff

1_feature_extractor/log/DINOv2_training/log/20240801_091959.log
ADDED
The diff for this file is too large to render.
See raw diff

1_feature_extractor/log/DINOv2_training/log/20240801_155326.log
ADDED
The diff for this file is too large to render.
See raw diff

1_feature_extractor/log/DINOv2_training/log/20240803_163338.log
ADDED
The diff for this file is too large to render.
See raw diff

1_feature_extractor/log/DINOv2_training/log/20240803_231933.log
ADDED
The diff for this file is too large to render.
See raw diff

1_feature_extractor/log/DINOv2_training/log/20240804_144252.log
ADDED
The diff for this file is too large to render.
See raw diff
1_feature_extractor/losses_hint.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, lambda_token: float, lambda_fea: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.lambda_token = lambda_token
        self.lambda_fea = lambda_fea
        self.soft_criterion = torch.nn.MSELoss()

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_token, outputs_fea = outputs

        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model.backbone.get_intermediate_layers(inputs, n=4, return_class_token=True)
            teacher_outputs_token = teacher_outputs[3][1]
            teacher_outputs_fea = torch.cat((teacher_outputs_token.unsqueeze(1), teacher_outputs[3][0]), dim=1)

        distillation_loss_token = self.soft_criterion(outputs_token, teacher_outputs_token)
        distillation_loss_fea = self.soft_criterion(outputs_fea, teacher_outputs_fea)

        token_loss = self.lambda_token * distillation_loss_token
        fea_loss = self.lambda_fea * distillation_loss_fea

        return token_loss, fea_loss
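A minimal usage sketch (not part of the commit) showing how DistillationLoss is wired up. StubBackbone and StubTeacher are hypothetical stand-ins that only mimic the get_intermediate_layers interface the loss relies on; shapes assume a ViT-B teacher at 224 px (196 patch tokens, width 768).

import torch

class StubBackbone(torch.nn.Module):
    def get_intermediate_layers(self, x, n=4, return_class_token=True):
        b = x.shape[0]
        # one (patch_tokens, class_token) pair per requested layer
        return [(torch.randn(b, 196, 768), torch.randn(b, 768)) for _ in range(n)]

class StubTeacher(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = StubBackbone()

criterion = DistillationLoss(
    base_criterion=torch.nn.CrossEntropyLoss(),
    teacher_model=StubTeacher(),
    distillation_type='soft',
    lambda_token=1.0,
    lambda_fea=1.0,
)
images = torch.randn(2, 3, 224, 224)
student_token = torch.randn(2, 768)        # student CLS token
student_fea = torch.randn(2, 197, 768)     # CLS token + 196 patch features
token_loss, fea_loss = criterion(images, (student_token, student_fea), labels=None)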
1_feature_extractor/main.py
ADDED
@@ -0,0 +1,520 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json

from pathlib import Path

from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from augmentations import collate_data_and_cast_aug
from datasets import build_dataset

from losses_hint import DistillationLoss
from samplers import RASampler
from functools import partial

import importlib
import utils
import random
import math
from multiprocessing import Value
from abc import ABC

import sys
from typing import Iterable, Optional
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
import utils


class MaskingGenerator(ABC):
    def __init__(self, input_size):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size
        self.num_patches = self.height * self.width

    def __repr__(self):
        raise NotImplementedError

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        raise NotImplementedError

    def get_none_mask(self):
        return np.zeros(shape=self.get_shape(), dtype=bool)


class RandomMaskingGenerator(MaskingGenerator):
    def __init__(
        self,
        input_size,
    ):
        """
        Args:
            input_size: the size of the token map, e.g., 14x14
        """
        super().__init__(input_size)

    def __repr__(self):
        repr_str = f"Random Generator({self.height}, {self.width})"
        return repr_str

    def _mask(self, mask, max_mask_patches):
        return super()._mask(mask, max_mask_patches)

    def __call__(self, num_masking_patches=0):
        if num_masking_patches <= 0:
            return np.zeros(shape=self.get_shape(), dtype=bool)

        mask = np.hstack([np.ones(num_masking_patches, dtype=bool),
                          np.zeros(self.num_patches - num_masking_patches, dtype=bool)])
        np.random.shuffle(mask)
        mask = mask.reshape(self.get_shape())
        return mask


def get_args_parser():
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # add dataset parameters
    parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
                        help="this should be equal to image size")
    parser.add_argument('--patch_size', default=16, type=int,
                        help="patch size for vit patch embedding")

    # add masking parameter
    parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
                        help="mask ratio can be either a value or a range")
    parser.add_argument('--mask_probability', default=0., type=float,
                        help="how many samples will have masking applied")
    parser.add_argument('--mask_first_n', action='store_true',
                        help="mask the first n samples to avoid shuffling. Needed for MAE-style encoder")
    parser.add_argument('--clone_batch', default=1, type=int,
                        help="how many times to clone the batch for masking (default: 1, not cloning)")

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='base', type=str)
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
    parser.add_argument('--lambda_token', type=float, default=1.0)
    parser.add_argument('--lambda_fea', type=float, default=1.0)
    parser.add_argument('--lambda_patch', type=float, default=1.0)

    # * Cosub params
    parser.add_argument('--cosub', action='store_true')

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')
    parser.add_argument('--weight_inherit', default='')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug', 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
                        type=str, help='ImageNet dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    n_tokens = (args.global_crops_size // args.patch_size) ** 2
    mask_generator = RandomMaskingGenerator(
        input_size=args.global_crops_size // args.patch_size,
    )

    collate_fn = partial(
        collate_data_and_cast_aug,
        mask_ratio=args.mask_ratio,
        mask_probability=args.mask_probability,
        dtype=torch.half,  # half precision
        n_tokens=n_tokens,
        mask_first_n=args.mask_first_n,
        mask_generator=mask_generator,
        clone_batch=args.clone_batch,
    )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        collate_fn=collate_fn,
    )

    mixup_fn = None

    print(f"Creating model: {args.model}")
    meta_arch_module = importlib.import_module(args.model)
    MetaArch = meta_arch_module.MetaArch

    model = MetaArch(args)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    if args.attn_only:
        for name_p, p in model.named_parameters():
            if '.attn.' in name_p:
                p.requires_grad = True
            else:
                p.requires_grad = False
        try:
            model.head.weight.requires_grad = True
            model.head.bias.requires_grad = True
        except:
            model.fc.weight.requires_grad = True
            model.fc.bias.requires_grad = True
        try:
            model.pos_embed.requires_grad = True
        except:
            print('no position encoding')
        try:
            for p in model.patch_embed.parameters():
                p.requires_grad = False
        except:
            print('no patch embed')

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model.student.backbone,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    if not args.unscale_lr:
        linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
        args.lr = linear_scaled_lr

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.step(args.start_epoch)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.train_mode,  # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
            args=args,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': get_state_dict(model_ema),
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args=None):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    loader_len = len(data_loader)

    for data_iter_step, inputs_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        for k, v in inputs_dict.items():
            if isinstance(v, torch.Tensor):
                inputs_dict[k] = v.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss_dict = model(inputs_dict)

        loss = loss_dict["loss"]
        patch_loss = loss_dict["patch_loss"]
        fea_loss = loss_dict["fea_loss"]
        token_loss = loss_dict["token_loss"]

        patch_loss_value = patch_loss.item()
        token_loss_value = token_loss.item()
        fea_loss_value = fea_loss.item()
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model.module.student.backbone)

        metric_logger.update(loss=loss_value)
        metric_logger.update(patch_loss=patch_loss_value)
        metric_logger.update(token_loss=token_loss_value)
        metric_logger.update(fea_loss=fea_loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
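A quick illustration (hypothetical, not in the commit) of RandomMaskingGenerator, run in the context of the module above: it scatters exactly num_masking_patches True entries over the token map.

gen = RandomMaskingGenerator(input_size=14)    # 14x14 token map, e.g. ViT-B/16 at 224 px
mask = gen(num_masking_patches=98)             # mask half of the 196 patches
print(mask.shape, mask.dtype, mask.sum())      # (14, 14) bool 98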
1_feature_extractor/models_IB.py
ADDED
@@ -0,0 +1,40 @@
import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck
from timm.models.vision_transformer import Block

class IF_Module(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio, depth=4, norm_layer=nn.LayerNorm):
        super(IF_Module, self).__init__()

        self.encoder_blocks = nn.ModuleList([  # 4 layers, embed_dim=768, num_heads=12
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.encoder_norm = norm_layer(embed_dim)

        self.decoder_blocks = nn.ModuleList([  # 4 layers, embed_dim=768, num_heads=12
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])

        self.decoder_norm = norm_layer(embed_dim)
        self.entropy_bottleneck = EntropyBottleneck(embed_dim)

    def forward(self, x, is_training=False):
        # ViT analysis transform
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        if is_training:
            x = x.permute(0, 2, 1)
            x_hat, x_likelihood = self.entropy_bottleneck(x)
            x_hat = x_hat.permute(0, 2, 1)
        else:
            x_hat = x
            x_likelihood = None

        # ViT synthesis transform
        for blk in self.decoder_blocks:
            x_hat = blk(x_hat)
        x_hat = self.decoder_norm(x_hat)

        return x_hat, x_likelihood
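A hypothetical smoke test for IF_Module (assumes timm and compressai are installed; not part of the commit). The last line is one plausible way to turn the returned likelihoods into a bits-per-token rate term; the commit's actual bpp_loss is computed elsewhere.

import torch

ib = IF_Module(embed_dim=768, num_heads=12, mlp_ratio=4)
tokens = torch.randn(2, 197, 768)                 # a batch of ViT token sequences
x_hat, likelihood = ib(tokens, is_training=True)  # likelihood: (2, 768, 197)
print(x_hat.shape)                                # torch.Size([2, 197, 768])
rate = -torch.log2(likelihood).sum() / (2 * 197)  # assumed: average bits per token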
1_feature_extractor/models_clip.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union, Callable
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
| 12 |
+
if not depth_first and include_root:
|
| 13 |
+
fn(module=module, name=name)
|
| 14 |
+
for child_name, child_module in module.named_children():
|
| 15 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 16 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 17 |
+
if depth_first and include_root:
|
| 18 |
+
fn(module=module, name=name)
|
| 19 |
+
return module
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Bottleneck(nn.Module):
|
| 23 |
+
expansion = 4
|
| 24 |
+
|
| 25 |
+
def __init__(self, inplanes, planes, stride=1):
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
| 29 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 30 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 31 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 32 |
+
|
| 33 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 34 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 35 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 36 |
+
|
| 37 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
| 38 |
+
|
| 39 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
| 40 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 41 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 42 |
+
|
| 43 |
+
self.downsample = None
|
| 44 |
+
self.stride = stride
|
| 45 |
+
|
| 46 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
| 47 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
| 48 |
+
self.downsample = nn.Sequential(OrderedDict([
|
| 49 |
+
("-1", nn.AvgPool2d(stride)),
|
| 50 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
| 51 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
| 52 |
+
]))
|
| 53 |
+
|
| 54 |
+
def forward(self, x: torch.Tensor):
|
| 55 |
+
identity = x
|
| 56 |
+
|
| 57 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
| 58 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
| 59 |
+
out = self.avgpool(out)
|
| 60 |
+
out = self.bn3(self.conv3(out))
|
| 61 |
+
|
| 62 |
+
if self.downsample is not None:
|
| 63 |
+
identity = self.downsample(x)
|
| 64 |
+
|
| 65 |
+
out += identity
|
| 66 |
+
out = self.relu3(out)
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class AttentionPool2d(nn.Module):
|
| 71 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
| 74 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 75 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 76 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 77 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
| 78 |
+
self.num_heads = num_heads
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
| 82 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
| 83 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
| 84 |
+
x, _ = F.multi_head_attention_forward(
|
| 85 |
+
query=x[:1], key=x, value=x,
|
| 86 |
+
embed_dim_to_check=x.shape[-1],
|
| 87 |
+
num_heads=self.num_heads,
|
| 88 |
+
q_proj_weight=self.q_proj.weight,
|
| 89 |
+
k_proj_weight=self.k_proj.weight,
|
| 90 |
+
v_proj_weight=self.v_proj.weight,
|
| 91 |
+
in_proj_weight=None,
|
| 92 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
| 93 |
+
bias_k=None,
|
| 94 |
+
bias_v=None,
|
| 95 |
+
add_zero_attn=False,
|
| 96 |
+
dropout_p=0,
|
| 97 |
+
out_proj_weight=self.c_proj.weight,
|
| 98 |
+
out_proj_bias=self.c_proj.bias,
|
| 99 |
+
use_separate_proj_weight=True,
|
| 100 |
+
training=self.training,
|
| 101 |
+
need_weights=False
|
| 102 |
+
)
|
| 103 |
+
return x.squeeze(0)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class ModifiedResNet(nn.Module):
|
| 107 |
+
"""
|
| 108 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
| 109 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
| 110 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
| 111 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.output_dim = output_dim
|
| 117 |
+
self.input_resolution = input_resolution
|
| 118 |
+
|
| 119 |
+
# the 3-layer stem
|
| 120 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
| 121 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
| 122 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 123 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
| 124 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
| 125 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 126 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
| 127 |
+
self.bn3 = nn.BatchNorm2d(width)
|
| 128 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 129 |
+
self.avgpool = nn.AvgPool2d(2)
|
| 130 |
+
|
| 131 |
+
# residual layers
|
| 132 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
| 133 |
+
self.layer1 = self._make_layer(width, layers[0])
|
| 134 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
| 135 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
| 136 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
| 137 |
+
|
| 138 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
| 139 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
| 140 |
+
|
| 141 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 142 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
| 143 |
+
|
| 144 |
+
self._inplanes = planes * Bottleneck.expansion
|
| 145 |
+
for _ in range(1, blocks):
|
| 146 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
| 147 |
+
|
| 148 |
+
return nn.Sequential(*layers)
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
def stem(x):
|
| 152 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
| 153 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
| 154 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
| 155 |
+
x = self.avgpool(x)
|
| 156 |
+
return x
|
| 157 |
+
|
| 158 |
+
x = x.type(self.conv1.weight.dtype)
|
| 159 |
+
x = stem(x)
|
| 160 |
+
x = self.layer1(x)
|
| 161 |
+
x = self.layer2(x)
|
| 162 |
+
x = self.layer3(x)
|
| 163 |
+
x = self.layer4(x)
|
| 164 |
+
x = self.attnpool(x)
|
| 165 |
+
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class LayerNorm(nn.LayerNorm):
|
| 170 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 171 |
+
|
| 172 |
+
def forward(self, x: torch.Tensor):
|
| 173 |
+
orig_type = x.dtype
|
| 174 |
+
ret = super().forward(x.type(torch.float32))
|
| 175 |
+
return ret.type(orig_type)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class QuickGELU(nn.Module):
|
| 179 |
+
def forward(self, x: torch.Tensor):
|
| 180 |
+
return x * torch.sigmoid(1.702 * x)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class ResidualAttentionBlock(nn.Module):
|
| 184 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
| 185 |
+
super().__init__()
|
| 186 |
+
|
| 187 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 188 |
+
self.ln_1 = LayerNorm(d_model)
|
| 189 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 190 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 191 |
+
("gelu", QuickGELU()),
|
| 192 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 193 |
+
]))
|
| 194 |
+
self.ln_2 = LayerNorm(d_model)
|
| 195 |
+
self.attn_mask = attn_mask
|
| 196 |
+
|
| 197 |
+
def attention(self, x: torch.Tensor):
|
| 198 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 199 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
| 200 |
+
|
| 201 |
+
def forward(self, x: torch.Tensor):
|
| 202 |
+
x = x + self.attention(self.ln_1(x))
|
| 203 |
+
x = x + self.mlp(self.ln_2(x))
|
| 204 |
+
return x
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Transformer(nn.Module):
|
| 208 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.width = width
|
| 211 |
+
self.layers = layers
|
| 212 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
| 213 |
+
|
| 214 |
+
def forward(self, x: torch.Tensor):
|
| 215 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 216 |
+
x = self.resblocks(x)
|
| 217 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 218 |
+
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class VisionTransformer(nn.Module):
|
| 223 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.input_resolution = input_resolution
|
| 226 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 227 |
+
|
| 228 |
+
scale = width ** -0.5
|
| 229 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 230 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
| 231 |
+
self.ln_pre = LayerNorm(width)
|
| 232 |
+
|
| 233 |
+
self.transformer = Transformer(width, layers, heads)
|
| 234 |
+
|
| 235 |
+
self.mask_token = nn.Parameter(torch.zeros(1, width))
|
| 236 |
+
|
| 237 |
+
self.ln_post = LayerNorm(width)
|
| 238 |
+
|
| 239 |
+
self.embed_dim = width
|
| 240 |
+
self.patch_size = patch_size
|
| 241 |
+
|
| 242 |
+
self.init_weights()
|
| 243 |
+
|
| 244 |
+
def init_weights(self):
|
| 245 |
+
trunc_normal_(self.positional_embedding, std=0.02)
|
| 246 |
+
nn.init.normal_(self.class_embedding, std=1e-6)
|
| 247 |
+
named_apply(init_weights_vit_timm, self)
|
| 248 |
+
|
| 249 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 250 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
| 251 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
| 252 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
| 253 |
+
|
| 254 |
+
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        all_x = [self.transformer(t) for t in x]

        output = []
        for x, masks in zip(all_x, masks_list):
            output.append(
                {
                    "x_norm_clstoken": self.ln_post(x[:, 0]),
                    "x_norm_patchtokens": x[:, 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward(self, x: torch.Tensor, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        x = self.transformer(x)

        return {
            "x_norm_clstoken": self.ln_post(x[:, 0]),
            "x_norm_patchtokens": x[:, 1:],
            "x_prenorm": x,
            "masks": masks,
        }


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=14, teacher_path=None):
    model = VisionTransformer(
        input_resolution=224,
        patch_size=patch_size,
        width=384,
        layers=12,
        heads=6
    )

    if teacher_path is not None:
        checkpoint = torch.load(teacher_path, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    return model


def vit_base(patch_size=14, teacher_path=None):
    model = VisionTransformer(
        input_resolution=224,
        patch_size=patch_size,
        width=768,
        layers=12,
        heads=12
    )

    if teacher_path is not None:
        checkpoint = torch.load(teacher_path, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    return model


def vit_large(patch_size=14, teacher_path=None):
    model = VisionTransformer(
        input_resolution=224,
        patch_size=patch_size,
        width=1024,
        layers=24,
        heads=16
    )

    if teacher_path is not None:
        checkpoint = torch.load(teacher_path, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    return model


if __name__ == "__main__":
    import argparse
    import clip
    import open_clip
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
    parser = argparse.ArgumentParser(description='PyTorch resnet Training')
    args = parser.parse_args()

    # with torch.no_grad():
    #     print(clip.available_models())

    #     device = "cuda" if torch.cuda.is_available() else "cpu"
    #     model, preprocess = clip.load('ViT-L/14', device)
    #     print(model.visual)

    # model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
    # model = model.to('cuda')

    # for k,v in model.visual.named_parameters():
    #     print(k, v.shape)

    # self_model = VisionTransformer(
    #     input_resolution=224,
    #     patch_size=32,
    #     width=768,
    #     layers=12,
    #     heads=12
    # )

    # print(self_model)

    # for k,v in self_model.named_parameters():
    #     print(k, v.shape)

    # new_ckpt = OrderedDict()
    # for k,v in model.visual.named_parameters():
    #     if 'proj' != k:
    #         print(k)
    #         new_ckpt[k] = v
    #     new_ckpt[k] = v

    # torch.save(new_ckpt, '/home/qw/yitian/TA-KD/clip_model/clip_l_14_400m.pth')

    # model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion400m_e32')
    # print(model.visual)

    # model = clip_base_32()
    # model = clip_base_14()

    # print(parameter_count_table(model))

    # tensor = torch.rand(1, 3, 224, 224)
    # flops = FlopCountAnalysis(model, tensor)
    # print("FLOPs: ", flops.total()/1e9)
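Editor's note: a minimal usage sketch for the factory functions above. This is a sketch, not part of the commit; it assumes the file is importable as models_clip and that a converted CLIP checkpoint exists at the placeholder path.

    import torch
    from models_clip import vit_base, vit_large

    student = vit_base(patch_size=14)                                      # randomly initialized student
    teacher = vit_large(patch_size=14, teacher_path='clip_l_14_400m.pth')  # placeholder path

    out = teacher(torch.randn(2, 3, 224, 224))
    # 224 / 14 = 16 patches per side -> 256 patch tokens; ViT-L width is 1024
    print(out["x_norm_clstoken"].shape)     # torch.Size([2, 1024])
    print(out["x_norm_patchtokens"].shape)  # torch.Size([2, 256, 1024])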
1_feature_extractor/models_dinov2.py
ADDED
@@ -0,0 +1,907 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable, Dict, Optional, Any, List
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F  # added: SwiGLUFFN.forward uses F.silu but this import was missing
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
import os
import warnings

logger = logging.getLogger("dinov2")


from xformers.ops import memory_efficient_attention, unbind, fmha, scaled_index_add, index_select_cat
from xformers.ops import SwiGLU

XFORMERS_AVAILABLE = True
warnings.warn("xFormers is available (Attention)")

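# NOTE (editor): the unconditional imports above make xFormers a hard dependency even
# though the code below still checks XFORMERS_AVAILABLE. A guarded import in the spirit
# of upstream DINOv2 would look like this (sketch, not part of the original commit):
#
#     try:
#         from xformers.ops import memory_efficient_attention, unbind, fmha, scaled_index_add, index_select_cat
#         from xformers.ops import SwiGLU
#         XFORMERS_AVAILABLE = True
#     except ImportError:
#         XFORMERS_AVAILABLE = False
#         warnings.warn("xFormers not available (Attention)")
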
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

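# NOTE (editor): shape sketch for the two attention variants above (not part of the
# original file). Both map (B, N, C) -> (B, N, C); MemEffAttention additionally accepts
# an xFormers attn_bias (e.g. a BlockDiagonalMask) so sequences of different lengths
# can be packed into one call:
#
#     attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
#     y = attn(torch.randn(2, 257, 384))   # dense path, no bias
#     assert y.shape == (2, 257, 384)
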
class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

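# NOTE (editor): worked example for PatchEmbed above (sketch, not part of the original
# file). With img_size=224 and patch_size=14 the grid is 224 // 14 = 16 patches per side,
# so num_patches = 16 * 16 = 256:
#
#     pe = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=768)
#     tokens = pe(torch.randn(1, 3, 224, 224))   # -> (1, 256, 768)
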
class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Basic_Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)

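# NOTE (editor): worked example for the batch-level stochastic depth above (not part of
# the original file). With b = 8 and sample_drop_ratio = 0.25:
#     sample_subset_size = max(int(8 * 0.75), 1) = 6
#     residual_scale_factor = 8 / 6 ≈ 1.33
# so the residual branch runs on a random 6-of-8 subset of the batch and index_add scales
# the result back up, keeping the expected magnitude of the residual unchanged.
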
def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class Block(Basic_Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x

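# NOTE (editor): Block.forward above dispatches on input type: a single Tensor takes the
# plain Basic_Block path, while a list of tensors is packed into one long sequence with a
# BlockDiagonalMask so crops of different token counts share a single attention call.
# Sketch (assumes xFormers and a device it supports):
#
#     blk = Block(dim=384, num_heads=6, qkv_bias=True, attn_class=MemEffAttention)
#     outs = blk([torch.randn(2, 257, 384), torch.randn(2, 197, 384)])
#     # -> a list of two tensors with the original shapes
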
class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_tiny(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


if __name__ == "__main__":
    import argparse
    from fvcore.nn import FlopCountAnalysis, parameter_count_table

    with torch.no_grad():
        model = vit_base(img_size=224,
                         patch_size=14,
                         init_values=1.0,
                         ffn_layer='mlp',
                         block_chunks=0,
                         num_register_tokens=0,
                         interpolate_antialias=False,
                         interpolate_offset=0.1)

        for name, param in model.named_parameters():
            print(name, param)

        # print(parameter_count_table(model))

        # tensor = torch.rand(1, 3, 224, 224)
        # flops = FlopCountAnalysis(model, tensor)
        # print("FLOPs: ", flops.total()/1e9)
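Editor's note: a minimal usage sketch for the DINOv2-style backbone above. This is a sketch, not part of the commit; it assumes xFormers is installed and, since MemEffAttention is the default block attention, is best run on a CUDA device.

    import torch
    import models_dinov2

    model = models_dinov2.vit_base(img_size=224, patch_size=14, init_values=1.0,
                                   ffn_layer='mlp', block_chunks=0).cuda().eval()

    x = torch.randn(1, 3, 224, 224, device='cuda')
    with torch.no_grad():
        cls_feat = model(x)                 # (1, 768): CLS token through the nn.Identity() head
        feats = model(x, is_training=True)  # dict with x_norm_clstoken / x_norm_patchtokens / ...
        maps = model.get_intermediate_layers(x, n=4, reshape=True)  # four (1, 768, 16, 16) maps
    print(cls_feat.shape, feats["x_norm_patchtokens"].shape, maps[0].shape)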
1_feature_extractor/models_proteus_clip.py
ADDED
@@ -0,0 +1,101 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed.nn
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
import models_clip


class MetaArch(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        student_model_dict = dict()
        teacher_model_dict = dict()

        import_student = getattr(models_clip, cfg.target_model)
        student = import_student()

        embed_dim = student.embed_dim

        import_teacher = getattr(models_clip, cfg.teacher_model)
        teacher_backbone = import_teacher(teacher_path=cfg.teacher_path)
        teacher_backbone.eval()

        student_model_dict['backbone'] = student
        teacher_model_dict['backbone'] = teacher_backbone

        self.embed_dim = embed_dim

        # initialize parameters and checks
        self.total_n_global_crops = cfg.batch_size

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)

        teacher_embed_dim = teacher_backbone.embed_dim

        self.token_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.soft_criterion = torch.nn.MSELoss()

        for param in self.teacher.backbone.parameters():
            param.requires_grad = False

    ## we explicitly remove the patch and feature learning objectives for CLIP training following the original design
    def forward(self, inputs):
        global_crops = inputs["collated_global_crops"]

        # compute teacher output
        # @torch.no_grad()
        def compute_teacher_output():
            with torch.no_grad():
                teacher_backbone_output_dict = self.teacher.backbone(global_crops)
                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]

            return teacher_cls_tokens

        # get the teacher outputs
        teacher_cls_tokens = compute_teacher_output()

        student_backbone_output_dict_unmask = self.student.backbone(global_crops)

        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]

        ## projection head
        student_cls_token_unmask = self.token_head(student_cls_token_unmask)

        ## token objective
        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)

        # coefficient
        token_loss = self.cfg.lambda_token * distillation_loss_token

        # compute the total loss
        total_loss = token_loss

        # return the final loss dict
        loss_dict = {"patch_loss": torch.tensor(0.0), "fea_loss": torch.tensor(0.0), "token_loss": token_loss, "loss": total_loss}

        return loss_dict
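Editor's note: a sketch of how this MetaArch might be driven. The cfg fields shown (target_model, teacher_model, teacher_path, batch_size, lambda_token) are exactly the ones the class reads; SimpleNamespace and the checkpoint path stand in for whatever config object and file the training script actually passes.

    import torch
    from types import SimpleNamespace
    from models_proteus_clip import MetaArch

    cfg = SimpleNamespace(target_model='vit_small', teacher_model='vit_large',
                          teacher_path='clip_l_14_400m.pth',   # placeholder path
                          batch_size=32, lambda_token=1.0)
    arch = MetaArch(cfg)

    inputs = {"collated_global_crops": torch.randn(32, 3, 224, 224)}
    loss_dict = arch(inputs)
    loss_dict["loss"].backward()   # only the student and token_head receive gradients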
1_feature_extractor/models_proteus_dinov2.py
ADDED
@@ -0,0 +1,200 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed.nn
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
import models_dinov2
from models_IB import IF_Module
import math


class MetaArch(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        student_model_dict = dict()
        teacher_model_dict = dict()

        import_student = getattr(models_dinov2, cfg.target_model)
        student = import_student(img_size=224,
                                 patch_size=cfg.patch_size,
                                 init_values=1.0,
                                 ffn_layer='mlp',
                                 block_chunks=0,
                                 num_register_tokens=0,
                                 interpolate_antialias=False,
                                 interpolate_offset=0.1)

        embed_dim = student.embed_dim

        if cfg.teacher_model == 'vit_base':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
        elif cfg.teacher_model == 'vit_small':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
        elif cfg.teacher_model == 'vit_large':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
        elif cfg.teacher_model == 'vit_giant':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')
        teacher_backbone.eval()

        student_model_dict['backbone'] = student
        teacher_model_dict['backbone'] = teacher_backbone.backbone

        self.embed_dim = embed_dim

        # initialize parameters and checks
        self.total_n_global_crops = cfg.batch_size

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)

        teacher_embed_dim = teacher_backbone.backbone.embed_dim
        self.ibot_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.token_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.fea_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.soft_criterion = torch.nn.MSELoss()

        self.info_bottleneck = IF_Module(embed_dim=embed_dim, num_heads=12, mlp_ratio=4, depth=4)

        for param in self.teacher.backbone.parameters():
            param.requires_grad = False

    def cal_bpp(self, image, unmask_likelihood, mask_likelihood):
        b, _, h, w = image.size()
        num_pixels = b * h * w
        log_unmask_likelihoods = torch.log(unmask_likelihood)
        log_mask_likelihoods = torch.log(mask_likelihood)
        bpp = (log_unmask_likelihoods.sum() + log_mask_likelihoods.sum()) / (-math.log(2) * num_pixels * 1.5)
        return bpp

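    # NOTE (editor): sketch of the arithmetic in cal_bpp above. For a batch of b crops of
    # size h x w, num_pixels = b * h * w; summing the log-likelihoods of the masked and
    # unmasked branches and dividing by -log(2) converts nats to bits, so bpp is
    # (total bits) / (num_pixels * 1.5), where 1.5 is an author-chosen normalization
    # constant over the two branches.
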
def forward(self, inputs):
|
| 98 |
+
global_crops = inputs["collated_global_crops"]
|
| 99 |
+
|
| 100 |
+
masks = inputs["collated_masks"]
|
| 101 |
+
mask_indices_list = inputs["mask_indices_list"]
|
| 102 |
+
n_masked_patches = mask_indices_list.shape[0]
|
| 103 |
+
upperbound = inputs["upperbound"]
|
| 104 |
+
|
| 105 |
+
n_global_crops = 1
|
| 106 |
+
|
| 107 |
+
# compute teacher output
|
| 108 |
+
# @torch.no_grad()
|
| 109 |
+
def compute_teacher_output():
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True)
|
| 112 |
+
teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
|
| 113 |
+
teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
|
| 114 |
+
_dim = teacher_patch_tokens.shape[-1]
|
| 115 |
+
|
| 116 |
+
# mask teacher patch tokens
|
| 117 |
+
buffer_tensor_teacher = teacher_patch_tokens.new_zeros(upperbound, _dim)
|
| 118 |
+
torch.index_select(
|
| 119 |
+
teacher_patch_tokens.flatten(0, 1),
|
| 120 |
+
dim=0,
|
| 121 |
+
index=mask_indices_list,
|
| 122 |
+
out=buffer_tensor_teacher[:n_masked_patches],
|
| 123 |
+
)
|
| 124 |
+
teacher_patch_tokens_masked = buffer_tensor_teacher[:n_masked_patches]
|
| 125 |
+
|
| 126 |
+
return teacher_cls_tokens, teacher_patch_tokens, teacher_patch_tokens_masked
|

        # get the teacher outputs
        (
            teacher_cls_tokens,
            teacher_patch_tokens,
            teacher_patch_tokens_masked
        ) = compute_teacher_output()

        cur_masks = masks if self.cfg.mask_probability > 0 else None

        student_backbone_output_dict, student_backbone_output_dict_unmask = self.student.backbone(
            [global_crops, global_crops], masks=[cur_masks, None], is_training=True
        )

        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
        student_patch_tokens_unmask = student_backbone_output_dict_unmask["x_norm_patchtokens"]
        student_patch_tokens = student_backbone_output_dict["x_norm_patchtokens"]

        # calculate bitrate
        student_patch_tokens_unmask, unmask_likelihood = self.info_bottleneck(student_patch_tokens_unmask, is_training=True)
        student_patch_tokens, mask_likelihood = self.info_bottleneck(student_patch_tokens, is_training=True)
        bpp = self.cal_bpp(global_crops, unmask_likelihood, mask_likelihood)

        # mask student patch tokens
        _dim = student_patch_tokens.shape[-1]

        buffer_tensor_student = student_patch_tokens.new_zeros(upperbound, _dim)
        buffer_tensor_student[:n_masked_patches].copy_(
            torch.index_select(student_patch_tokens.flatten(0, 1),
                               dim=0,
                               index=mask_indices_list)
        )

        ## projection head
        student_patch_tokens_unmask = self.fea_head(student_patch_tokens_unmask)

        student_cls_token_unmask = self.token_head(student_cls_token_unmask)

        tokens_after_head = self.ibot_head(buffer_tensor_student)
        student_patch_tokens_masked = tokens_after_head[:n_masked_patches]

        ## token objective
        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)

        ## fea objective
        student_whole_fea = torch.cat((student_cls_token_unmask.unsqueeze(1), student_patch_tokens_unmask), dim=1)
        teacher_whole_fea = torch.cat((teacher_cls_tokens.unsqueeze(1), teacher_patch_tokens), dim=1)
        distillation_loss_fea = self.soft_criterion(student_whole_fea, teacher_whole_fea)

        ## patch objective
        patch_loss = self.soft_criterion(student_patch_tokens_masked, teacher_patch_tokens_masked)

        # coefficient
        token_loss = self.cfg.lambda_token * distillation_loss_token
        fea_loss = self.cfg.lambda_fea * distillation_loss_fea
        patch_loss_weighted = self.cfg.lambda_patch * patch_loss
        # print(f"self.cfg: {self.cfg}")
        # print(f"self.cfg.lambda_token: {self.cfg.lambda_token}, self.cfg.lambda_fea: {self.cfg.lambda_fea}, self.cfg.lambda_patch: {self.cfg.lambda_patch}")

        # compute the total loss
        total_loss = patch_loss_weighted + fea_loss + token_loss + 0.48 * bpp
        # task_loss = patch_loss + fea_loss + token_loss
        task_loss = patch_loss + distillation_loss_fea + distillation_loss_token

        # return the final loss dict
        loss_dict = {"bpp_loss": bpp,
                     "patch_loss": patch_loss,
                     "fea_loss": distillation_loss_fea,
                     "token_loss": token_loss,
                     "loss": total_loss,
                     "task_loss": task_loss,
                     }

        return loss_dict
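For orientation, a minimal sketch of one optimization step consuming this loss dict; `model`, `optimizer`, and `batch` are hypothetical names, and the assumption is that only the "loss" entry (the lambda-weighted sum plus 0.48 * bpp) is backpropagated while the remaining entries are logged:

    # hypothetical training step over MetaArch's loss dict (illustrative only)
    loss_dict = model(batch)             # batch: collated global crops + masks
    optimizer.zero_grad()
    loss_dict["loss"].backward()         # weighted patch/fea/token losses + 0.48 * bpp
    optimizer.step()
    log = {k: v.item() for k, v in loss_dict.items()}  # bpp/patch/fea/token/task terms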
1_feature_extractor/models_proteus_synclr.py ADDED
@@ -0,0 +1,161 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed.nn
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
import models_synclr


class MetaArch(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        student_model_dict = dict()
        teacher_model_dict = dict()

        import_student = getattr(models_synclr, cfg.target_model)
        student = import_student(patch_size=cfg.patch_size, num_classes=0, mask_style='ibot')

        embed_dim = student.embed_dim

        import_teacher = getattr(models_synclr, cfg.teacher_model)
        teacher_backbone = import_teacher(patch_size=cfg.patch_size, teacher_path=cfg.teacher_path, num_classes=0, mask_style='ibot')
        teacher_backbone.eval()

        student_model_dict['backbone'] = student
        teacher_model_dict['backbone'] = teacher_backbone

        self.embed_dim = embed_dim

        # initialize parameters and checks
        self.total_n_global_crops = cfg.batch_size

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)

        teacher_embed_dim = teacher_backbone.embed_dim
        self.patch_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.token_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.fea_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.soft_criterion = torch.nn.MSELoss()

        for param in self.teacher.backbone.parameters():
            param.requires_grad = False

    def forward(self, inputs):
        global_crops = inputs["collated_global_crops"]

        masks = inputs["collated_masks"]
        mask_indices_list = inputs["mask_indices_list"]
        n_masked_patches = mask_indices_list.shape[0]
        upperbound = inputs["upperbound"]

        n_global_crops = 1

        # compute teacher output
        # @torch.no_grad()
        def compute_teacher_output():
            with torch.no_grad():
                teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True)
                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
                teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
                _dim = teacher_patch_tokens.shape[-1]

                # mask teacher patch tokens
                buffer_tensor_teacher = teacher_patch_tokens.new_zeros(upperbound, _dim)
                torch.index_select(
                    teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[:n_masked_patches],
                )
                teacher_patch_tokens_masked = buffer_tensor_teacher[:n_masked_patches]

            return teacher_cls_tokens, teacher_patch_tokens, teacher_patch_tokens_masked

        # get the teacher outputs
        (
            teacher_cls_tokens,
            teacher_patch_tokens,
            teacher_patch_tokens_masked
        ) = compute_teacher_output()

        cur_masks = masks if self.cfg.mask_probability > 0 else None

        student_backbone_output_dict, student_backbone_output_dict_unmask = self.student.backbone(
            [global_crops, global_crops], masks=[cur_masks, None], is_training=True
        )

        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
        student_patch_tokens_unmask = student_backbone_output_dict_unmask["x_norm_patchtokens"]
        student_patch_tokens = student_backbone_output_dict["x_norm_patchtokens"]

        # mask student patch tokens
        _dim = student_patch_tokens.shape[-1]

        buffer_tensor_student = student_patch_tokens.new_zeros(upperbound, _dim)
        buffer_tensor_student[:n_masked_patches].copy_(
            torch.index_select(student_patch_tokens.flatten(0, 1),
                               dim=0,
                               index=mask_indices_list)
        )

        ## projection head
        student_patch_tokens_unmask = self.fea_head(student_patch_tokens_unmask)

        student_cls_token_unmask = self.token_head(student_cls_token_unmask)

        tokens_after_head = self.patch_head(buffer_tensor_student)
        student_patch_tokens_masked = tokens_after_head[:n_masked_patches]

        ## token objective
        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)

        ## fea objective
        student_whole_fea = torch.cat((student_cls_token_unmask.unsqueeze(1), student_patch_tokens_unmask), dim=1)
        teacher_whole_fea = torch.cat((teacher_cls_tokens.unsqueeze(1), teacher_patch_tokens), dim=1)
        distillation_loss_fea = self.soft_criterion(student_whole_fea, teacher_whole_fea)

        ## patch objective
        patch_loss = self.soft_criterion(student_patch_tokens_masked, teacher_patch_tokens_masked)

        # coefficient
        token_loss = self.cfg.lambda_token * distillation_loss_token
        fea_loss = self.cfg.lambda_fea * distillation_loss_fea
        patch_loss = self.cfg.lambda_patch * patch_loss

        # compute the total loss
        total_loss = patch_loss + fea_loss + token_loss

        # return the final loss dict
        loss_dict = {"patch_loss": patch_loss, "fea_loss": fea_loss, "token_loss": token_loss, "loss": total_loss}

        return loss_dict
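A minimal construction sketch for this MetaArch; the config fields mirror the cfg.* attributes read above, but the concrete values and the SimpleNamespace wrapper are illustrative assumptions, not part of this commit:

    from types import SimpleNamespace

    cfg = SimpleNamespace(
        target_model="vit_small",     # student factory name in models_synclr
        teacher_model="vit_base",     # teacher factory name in models_synclr
        teacher_path=None,            # optional SynCLR checkpoint path
        patch_size=16,
        batch_size=64,
        mask_probability=0.5,
        lambda_token=1.0,
        lambda_fea=1.0,
        lambda_patch=1.0,
    )
    model = MetaArch(cfg)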
1_feature_extractor/models_synclr.py ADDED
@@ -0,0 +1,500 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn

from timm.models.layers import trunc_normal_, lecun_normal_, to_2tuple
from timm.models.vision_transformer import Attention
from timm.models.layers import Mlp, DropPath
from timm.models.helpers import named_apply


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ffn_targets=False,
                 return_layer_targets=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # specify the targets for feature regression
        self.ffn_targets = ffn_targets
        self.return_layer_targets = return_layer_targets

    def forward(self, x):
        if isinstance(x, tuple):
            x = x[0]

        x = x + self.drop_path(self.attn(self.norm1(x)))
        ffn_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(ffn_out)

        target = ffn_out if self.ffn_targets else x

        if self.return_layer_targets:
            return x, target
        return x


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        patch_H, patch_W = self.patch_size
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
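
# Note: with the defaults above (img_size=224, patch_size=16, embed_dim=768),
# PatchEmbed maps (B, 3, 224, 224) -> Conv2d stride 16 -> (B, 768, 14, 14)
# -> flatten/transpose -> (B, 196, 768), i.e. num_patches = (224 // 16) ** 2 = 196.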


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init='', ffn_targets=False, return_layer_targets=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
            ffn_targets (bool): whether we use ffn output or block end as the feature targets
            return_layer_targets (bool): whether we return every layer targets
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.ffn_targets = ffn_targets
        self.return_layer_targets = return_layer_targets
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
                ffn_targets=ffn_targets, return_layer_targets=return_layer_targets,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
            return x


def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)


def compute_gather_ids(masks):
    unmask_indices = masks.logical_not().nonzero(as_tuple=False)
    ids_keep = unmask_indices[:, -1].reshape(masks.shape[0], -1)
    return ids_keep
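
# Note: compute_gather_ids assumes every row of `masks` keeps the same number
# of unmasked tokens, so the flat (row, col) indices from nonzero() reshape
# cleanly to [B, n_keep]. E.g. masks of shape (2, 4) with two False entries
# per row give ids_keep of shape (2, 2); it feeds torch.gather in the 'mae'
# branch of MaskedTransformer.prepare_tokens_with_masks below.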


class MaskedTransformer(VisionTransformer):
    """Inherit vision transformer from timm"""

    def __init__(self, mask_style='ibot', **kwargs):
        super().__init__(**kwargs)
        assert mask_style in ["ibot", "mae", "none"], "mask_style must be `ibot`, `mae`, or `none`"

        self.patch_size = self.patch_embed.patch_size
        if isinstance(self.patch_size, tuple):
            self.patch_size = self.patch_size[0]

        nn.init.normal_(self.cls_token, std=1e-6)

        self.mask_style = mask_style
        if self.mask_style == "ibot":
            self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
            torch.nn.init.normal_(self.mask_token, std=.02)

    def interpolate_pos_encoding(self, x, w, h, npatch):
        previous_dtype = x.dtype
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )

        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """
        Args:
            x: data w/ shape [b, c, h, w]
            masks: shape [b, n], n is the number of tokens, 1 means masked, 0 means unmasked
        """
        b, c, h, w = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            if self.mask_style == 'ibot':
                x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype), x)
            elif self.mask_style == 'mae':  # only gather unmasked patches
                # add pos_embed before shuffle
                pos_embed = self.interpolate_pos_encoding(x, w, h, npatch=x.shape[1])
                x = x + pos_embed[:, 1:, :]
                ids_keep = compute_gather_ids(masks)
                x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
                # x = x[masks.logical_not()]
                # x = x.reshape(b, -1, x.size(-1))
            else:
                raise NotImplementedError(f"mask style {self.mask_style} is not supported")

        if (masks is None) or (self.mask_style != "mae"):
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.interpolate_pos_encoding(x, w, h, npatch=x.shape[1]-1)
        else:
            # mae-style masking, only need to add cls tokens w/ pos embedding
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        num_data = len(x)
        if self.return_layer_targets:
            all_layer_results = [[] for _ in range(num_data)]
            for i, blk in enumerate(self.blocks):
                out = [blk(t) for t in x]
                x = [o[0] for o in out]
                # store layer targets
                for j in range(num_data):
                    all_layer_results[j].append(out[j][1])
            all_x = x
        else:
            all_x = [self.blocks(t) for t in x]
            all_layer_results = [None for _ in range(num_data)]

        output = []
        for x, masks, layer_results in zip(all_x, masks_list, all_layer_results):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm": x_norm,
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_patchtokens": x_norm[:, 1:],
                    "masks": masks,
                    "layer_results": layer_results,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        if self.return_layer_targets:
            layer_results = []
            for i, blk in enumerate(self.blocks):
                x, lr = blk(x)
                layer_results.append(lr)
        else:
            x = self.blocks(x)
            layer_results = None

        x_norm = self.norm(x)
        return {
            "x_norm": x_norm,
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_patchtokens": x_norm[:, 1:],
            "masks": masks,
            "layer_results": layer_results,
        }

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return ret["x_norm_clstoken"]


def vit_small(patch_size=16, teacher_path=None, **kwargs):
    model = MaskedTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, **kwargs)

    if teacher_path is not None:
        checkpoint = torch.load(teacher_path, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        pretrained_dict = {k.replace("module.visual.", ""): v for k, v in pretrained_dict.items()}

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    return model


def vit_base(patch_size=16, teacher_path=None, **kwargs):
    model = MaskedTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, **kwargs)

    if teacher_path is not None:
        checkpoint = torch.load(teacher_path, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        pretrained_dict = {k.replace("module.visual.", ""): v for k, v in pretrained_dict.items()}

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    return model


def vit_large(patch_size=14, teacher_path=None, **kwargs):
    model = MaskedTransformer(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, **kwargs)

    if teacher_path is not None:
        checkpoint = torch.load(teacher_path, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        pretrained_dict = {k.replace("module.visual.", ""): v for k, v in pretrained_dict.items()}

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    return model


if __name__ == '__main__':
    import argparse
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
    parser = argparse.ArgumentParser(description='PyTorch resnet Training')
    args = parser.parse_args()

    with torch.no_grad():
        model = vit_base(patch_size=14, num_classes=0, mask_style='ibot')

        # x = torch.randn(1, 3, 224, 224)
        # out = model(x)
        # print(out.shape)

        print(parameter_count_table(model))

        tensor = torch.rand(1, 3, 224, 224)
        flops = FlopCountAnalysis(model, tensor)
        print("FLOPs: ", flops.total()/1e9)
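For reference, a hedged sketch of the iBOT-style masked forward path defined in this file (shapes assume 224x224 inputs at patch size 14, so (224/14)**2 = 256 tokens; `is_training=True` returns the feature dict rather than just the CLS token):

    model = vit_base(patch_size=14, num_classes=0, mask_style='ibot')
    x = torch.randn(2, 3, 224, 224)
    masks = torch.zeros(2, 256, dtype=torch.bool)
    masks[:, :64] = True                          # replace first 64 tokens with mask_token
    out = model(x, masks=masks, is_training=True)
    print(out["x_norm_clstoken"].shape)           # torch.Size([2, 768])
    print(out["x_norm_patchtokens"].shape)        # torch.Size([2, 256, 768])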
1_feature_extractor/original_images.png ADDED
1_feature_extractor/requirements.txt ADDED
@@ -0,0 +1,134 @@
absl-py==2.1.0
accelerate==0.33.0
aiohttp==3.9.5
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
async-timeout==4.0.3
attrs==23.2.0
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
cmake==3.30.1
compressai==1.2.6
contourpy==1.2.1
cubinlinker-cu11==0.3.0.post2
cuda-python==11.8.2
cudf-cu11==24.6.1
cuml-cu11==24.6.1
cupy-cuda11x==13.2.0
cycler==0.12.1
Cython==3.0.10
dask==2024.5.1
dask-cuda==24.6.0
dask-cudf-cu11==24.6.1
dask-expr==1.1.1
diffusers==0.30.0
distributed==2024.5.1
distributed-ucxx-cu11==0.38.0
einops==0.8.0
fastrlock==0.8.2
filelock==3.15.4
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
ftfy==6.2.0
future==1.0.0
fvcore==0.1.5.post20221221
grpcio==1.65.2
h5py==3.11.0
huggingface-hub==0.24.5
idna==3.7
importlib_metadata==8.0.0
importlib_resources==6.4.0
iopath==0.1.10
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.5
libucx-cu11==1.15.0.post1
lightning-utilities==0.11.6
lit==18.1.8
llvmlite==0.43.0
locket==1.0.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
mypy-extensions==1.0.0
networkx==3.2.1
numba==0.60.0
numpy==1.26.4
nvtx==0.2.10
omegaconf==2.3.0
open-clip-torch==2.0.2
opencv-python==4.10.0.84
packaging==24.1
pandas==2.2.2
partd==1.4.2
pillow==10.4.0
portalocker==2.10.1
protobuf==4.25.4
psutil==6.0.0
ptxcompiler-cu11==0.8.1.post1
pyarrow==16.1.0
pycocotools==2.0.8
pyDeprecate==0.3.1
Pygments==2.18.0
pylibraft-cu11==24.6.0
pynvml==11.4.1
pyparsing==3.1.2
pyre-extensions==0.0.23
python-dateutil==2.9.0.post0
pytorch-lightning==1.5.0
pytorch-msssim==1.0.0
pytz==2024.1
PyYAML==6.0.1
raft-dask-cu11==24.6.0
rapids-dask-dependency==24.6.0
regex==2024.7.24
requests==2.32.3
rich==13.7.1
rmm-cu11==24.6.0
safetensors==0.4.3
scikit-learn==1.5.1
scipy==1.13.1
six==1.16.0
sortedcontainers==2.4.0
submitit==1.5.1
sympy==1.13.1
tabulate==0.9.0
tblib==3.0.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
termcolor==2.4.0
threadpoolctl==3.5.0
timm==1.0.7
tokenizers==0.13.3
toolz==0.12.1
torch==2.0.0+cu117
torch_geometric==2.5.3
torchmetrics==0.10.3
torchvision==0.15.0+cu117
tornado==6.4.1
tqdm==4.66.4
transformers==4.27.0
treelite==4.1.2
triton==2.0.0
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.1
ucx-py-cu11==0.38.0
ucxx-cu11==0.38.0
urllib3==2.2.2
wcwidth==0.2.13
Werkzeug==3.0.3
xformers==0.0.18
yacs==0.1.8
yarl==1.9.4
zict==3.0.0
zipp==3.19.2