Qiyp committed
Commit 1633fcc · Parent: 3e5c029

code of stage1 & 3, remove large files

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. 1_feature_extractor/1_main_training_IB.py +624 -0
  2. 1_feature_extractor/1_training_IB.sh +40 -0
  3. 1_feature_extractor/LICENSE +21 -0
  4. 1_feature_extractor/README copy.md +24 -0
  5. 1_feature_extractor/README.md +17 -0
  6. 1_feature_extractor/__pycache__/augmentations.cpython-39.pyc +0 -0
  7. 1_feature_extractor/__pycache__/datasets.cpython-39.pyc +0 -0
  8. 1_feature_extractor/__pycache__/losses_hint.cpython-39.pyc +0 -0
  9. 1_feature_extractor/__pycache__/models_IB.cpython-39.pyc +0 -0
  10. 1_feature_extractor/__pycache__/models_clip.cpython-39.pyc +0 -0
  11. 1_feature_extractor/__pycache__/models_dinov2.cpython-39.pyc +0 -0
  12. 1_feature_extractor/__pycache__/models_proteus_clip.cpython-39.pyc +0 -0
  13. 1_feature_extractor/__pycache__/models_proteus_dinov2.cpython-39.pyc +0 -0
  14. 1_feature_extractor/__pycache__/models_proteus_synclr.cpython-39.pyc +0 -0
  15. 1_feature_extractor/__pycache__/models_synclr.cpython-39.pyc +0 -0
  16. 1_feature_extractor/__pycache__/samplers.cpython-39.pyc +0 -0
  17. 1_feature_extractor/__pycache__/utils.cpython-39.pyc +0 -0
  18. 1_feature_extractor/augmentations.py +94 -0
  19. 1_feature_extractor/datasets.py +110 -0
  20. 1_feature_extractor/fast_vis.sh +37 -0
  21. 1_feature_extractor/fast_vis_proteus_feats.py +98 -0
  22. 1_feature_extractor/fast_vis_settings_all.py +548 -0
  23. 1_feature_extractor/log/DINOv2_training/log.txt +203 -0
  24. 1_feature_extractor/log/DINOv2_training/log/20240725_001002.log +0 -0
  25. 1_feature_extractor/log/DINOv2_training/log/20240725_084736.log +555 -0
  26. 1_feature_extractor/log/DINOv2_training/log/20240725_085916.log +0 -0
  27. 1_feature_extractor/log/DINOv2_training/log/20240726_110417.log +0 -0
  28. 1_feature_extractor/log/DINOv2_training/log/20240726_171814.log +0 -0
  29. 1_feature_extractor/log/DINOv2_training/log/20240728_153020.log +0 -0
  30. 1_feature_extractor/log/DINOv2_training/log/20240728_214526.log +0 -0
  31. 1_feature_extractor/log/DINOv2_training/log/20240729_102738.log +0 -0
  32. 1_feature_extractor/log/DINOv2_training/log/20240730_084148.log +301 -0
  33. 1_feature_extractor/log/DINOv2_training/log/20240730_085449.log +0 -0
  34. 1_feature_extractor/log/DINOv2_training/log/20240731_102940.log +0 -0
  35. 1_feature_extractor/log/DINOv2_training/log/20240801_091959.log +0 -0
  36. 1_feature_extractor/log/DINOv2_training/log/20240801_155326.log +0 -0
  37. 1_feature_extractor/log/DINOv2_training/log/20240803_163338.log +0 -0
  38. 1_feature_extractor/log/DINOv2_training/log/20240803_231933.log +0 -0
  39. 1_feature_extractor/log/DINOv2_training/log/20240804_144252.log +0 -0
  40. 1_feature_extractor/losses_hint.py +49 -0
  41. 1_feature_extractor/main.py +520 -0
  42. 1_feature_extractor/models_IB.py +40 -0
  43. 1_feature_extractor/models_clip.py +438 -0
  44. 1_feature_extractor/models_dinov2.py +907 -0
  45. 1_feature_extractor/models_proteus_clip.py +101 -0
  46. 1_feature_extractor/models_proteus_dinov2.py +200 -0
  47. 1_feature_extractor/models_proteus_synclr.py +161 -0
  48. 1_feature_extractor/models_synclr.py +500 -0
  49. 1_feature_extractor/original_images.png +0 -0
  50. 1_feature_extractor/requirements.txt +134 -0
1_feature_extractor/1_main_training_IB.py ADDED
@@ -0,0 +1,624 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json

from pathlib import Path

from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from augmentations import collate_data_and_cast_aug
from datasets import build_dataset

from losses_hint import DistillationLoss
from samplers import RASampler
from functools import partial

import importlib
import utils
import random
import math
from multiprocessing import Value
from abc import ABC

import sys
from typing import Iterable, Optional
from timm.data import Mixup
from timm.utils import accuracy
import logging
import torch.distributed as dist
import os


class MaskingGenerator(ABC):
    def __init__(self, input_size):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size
        self.num_patches = self.height * self.width

    def __repr__(self):
        raise NotImplementedError

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        raise NotImplementedError

    def get_none_mask(self):
        return np.zeros(shape=self.get_shape(), dtype=bool)


class RandomMaskingGenerator(MaskingGenerator):
    def __init__(
        self,
        input_size,
    ):
        """
        Args:
            input_size: the size of the token map, e.g., 14x14
        """
        super().__init__(input_size)

    def __repr__(self):
        repr_str = f"Random Generator({self.height}, {self.width})"
        return repr_str

    def _mask(self, mask, max_mask_patches):
        return super()._mask(mask, max_mask_patches)

    def __call__(self, num_masking_patches=0):
        if num_masking_patches <= 0:
            return np.zeros(shape=self.get_shape(), dtype=bool)

        mask = np.hstack([np.ones(num_masking_patches, dtype=bool),
                          np.zeros(self.num_patches - num_masking_patches, dtype=bool)])
        np.random.shuffle(mask)
        mask = mask.reshape(self.get_shape())
        return mask


def setup_logger(log_dir, rank=0):
    if rank != 0:
        return  # only the main process (rank 0) configures the logger
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    log_file_handler = logging.FileHandler(log_dir, encoding='utf-8')
    log_file_handler.setFormatter(log_formatter)
    root_logger.addHandler(log_file_handler)

    log_stream_handler = logging.StreamHandler(sys.stdout)
    log_stream_handler.setFormatter(log_formatter)
    root_logger.addHandler(log_stream_handler)

    logging.info('Logging file is %s' % log_dir)


def get_args_parser():
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=4e-4, metavar='LR',
                        help='learning rate (default: 4e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy, e.g. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # dataset parameters
    parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
                        help="this should be equal to image size")
    parser.add_argument('--patch_size', default=16, type=int,
                        help="patch size for vit patch embedding")

    # masking parameters
    parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
                        help="mask ratio can be either a value or a range")
    parser.add_argument('--mask_probability', default=0., type=float,
                        help="how many samples will be masked")
    parser.add_argument('--mask_first_n', action='store_true',
                        help="mask the first n samples to avoid shuffling. Needed for MAE-style encoder")
    parser.add_argument('--clone_batch', default=1, type=int,
                        help="how many times to clone the batch for masking (default: 1, not cloning)")

    # Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='base', type=str)
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
    parser.add_argument('--lambda_token', type=float, default=1.0)
    parser.add_argument('--lambda_fea', type=float, default=1.0)
    parser.add_argument('--lambda_patch', type=float, default=1.0)

    # Cosub params
    parser.add_argument('--cosub', action='store_true')

    # Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')
    parser.add_argument('--weight_inherit', default='')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET',
                        choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug',
                                 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
                        type=str, help='ImageNet dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='/data1/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log',
                        type=str, help='save logging info every 20 iters')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def show_learnable_params(model):
    enabled = set()
    for name, param in model.named_parameters():
        if param.requires_grad:
            enabled.add(name)
    logging.info("Parameters to be updated: ")
    for each in enabled:
        logging.info('\t{}\n'.format(str(each)))
    logging.info('\n')


def show_unlearnable_params(model):
    disabled = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            disabled.add(name)

    logging.info("Parameters that are not being updated: ")
    for each in disabled:
        logging.info('\t{}'.format(str(each)))
    logging.info('\n')


def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)
    # get the rank of the current process
    rank = dist.get_rank() if dist.is_initialized() else 0
    # set up logger
    os.makedirs(args.log_dir, exist_ok=True)
    setup_logger(args.log_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log', rank)
    logging.info('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    logging.info("{}".format(args).replace(', ', ',\n') + '\n')

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    logging.info(dataset_train)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
    logging.info("Sampler_train = %s" % str(sampler_train))

    n_tokens = (args.global_crops_size // args.patch_size) ** 2
    mask_generator = RandomMaskingGenerator(
        input_size=args.global_crops_size // args.patch_size,
    )

    collate_fn = partial(
        collate_data_and_cast_aug,
        mask_ratio=args.mask_ratio,
        mask_probability=args.mask_probability,
        dtype=torch.half,  # half precision
        n_tokens=n_tokens,
        mask_first_n=args.mask_first_n,
        mask_generator=mask_generator,
        clone_batch=args.clone_batch,
    )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        collate_fn=collate_fn,
    )

    mixup_fn = None

    print(f"Creating model: {args.model}")  # e.g., models_proteus_dinov2
    meta_arch_module = importlib.import_module(args.model)
    MetaArch = meta_arch_module.MetaArch

    model = MetaArch(args)
    logging.info("Model = %s" % str(model))

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint

        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        logging.info('Finetuning from %s' % args.finetune)
        logging.info('missing_keys: %s' % str(missing_keys))
        logging.info('unexpected_keys: %s' % str(unexpected_keys))

    if args.attn_only:
        for name_p, p in model.named_parameters():
            if '.attn.' in name_p:
                p.requires_grad = True
            else:
                p.requires_grad = False
        try:
            model.head.weight.requires_grad = True
            model.head.bias.requires_grad = True
        except AttributeError:
            model.fc.weight.requires_grad = True
            model.fc.bias.requires_grad = True
        try:
            model.pos_embed.requires_grad = True
        except AttributeError:
            print('no position encoding')
        try:
            for p in model.patch_embed.parameters():
                p.requires_grad = False
        except AttributeError:
            print('no patch embed')

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model.student.backbone,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info('number of params: %s' % n_parameters)

    if not args.unscale_lr:
        logging.info('base lr = %s' % args.lr)
        linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
        args.lr = linear_scaled_lr
        logging.info('actual lr = %s' % linear_scaled_lr)

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.step(args.start_epoch)
        logging.info('Resuming from %s' % args.resume)

    logging.info("Start training for %s epochs" % args.epochs)
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if epoch < 5:
            # for the first 5 epochs, only the entropy-model (info_bottleneck) parameters are trainable
            for name, param in model.named_parameters():
                if 'info_bottleneck' in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
            if epoch == 0:
                show_learnable_params(model)
        else:
            # for the remaining epochs, unfreeze all parameters but keep model.teacher frozen
            for name, param in model.named_parameters():
                param.requires_grad = True
            for name, param in model.named_parameters():
                if 'teacher' in name:
                    param.requires_grad = False
            if epoch == 5:
                show_unlearnable_params(model)

        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.train_mode,  # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
            args=args,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': get_state_dict(model_ema),
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)
        if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs):
            checkpoint_path = output_dir / f'checkpoint{epoch:04}.pth'
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema),
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, checkpoint_path)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info('Training time %s' % total_time_str)


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args=None):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    loader_len = len(data_loader)

    for data_iter_step, inputs_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        for k, v in inputs_dict.items():
            if isinstance(v, torch.Tensor):
                inputs_dict[k] = v.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss_dict = model(inputs_dict)

        loss = loss_dict["loss"]
        patch_loss = loss_dict["patch_loss"]
        fea_loss = loss_dict["fea_loss"]
        token_loss = loss_dict["token_loss"]
        bpp_loss = loss_dict["bpp_loss"]
        task_loss = loss_dict["task_loss"]

        patch_loss_value = patch_loss.item()
        token_loss_value = token_loss.item()
        fea_loss_value = fea_loss.item()
        bpp_loss_value = bpp_loss.item()
        task_loss_value = task_loss.item()
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            logging.info("Loss is %s, stopping training" % loss_value)
            logging.info("bpp_loss is {}, patch_loss is {}, token_loss is {}, fea_loss is {}".format(
                bpp_loss_value, patch_loss_value, token_loss_value, fea_loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model.module.student.backbone)

        metric_logger.update(loss=loss_value)
        metric_logger.update(task_loss=task_loss_value)
        metric_logger.update(bpp_loss=bpp_loss_value)
        metric_logger.update(patch_loss=patch_loss_value)
        metric_logger.update(token_loss=token_loss_value)
        metric_logger.update(fea_loss=fea_loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    logging.info("Averaged stats: {}".format(metric_logger))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
1_feature_extractor/1_training_IB.sh ADDED
@@ -0,0 +1,40 @@
#### access DINOv2
export CUDA_VISIBLE_DEVICES=0,1,2;

python -m torch.distributed.launch --nproc_per_node=3 --use_env 1_main_training_IB.py \
    --batch-size 48 --warmup-epochs 5 --epochs 200 \
    --data-set IMNET --data-path '/data1/datasets/imagenet_fold' \
    --teacher-model vit_large --target_model vit_base --model models_proteus_dinov2 \
    --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
    --lambda_token 1.0 --lambda_fea 1.05 --lambda_patch 1.05 \
    --resume "/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0160.pth" \
    --log_dir '/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log/' \
    --output_dir log/DINOv2_training;


#### access SynCLR

# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
#     --data-set IMNET --data-path imagenet_path \
#     --teacher-model vit_large --target_model vit_base --model models_proteus_synclr \
#     --teacher-path pretrained_synclr_path \
#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
#     --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
#     --output_dir log/SynCLR_training;


#### access CLIP

# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
#     --data-set IMNET --data-path imagenet_path \
#     --teacher-model vit_large --target_model vit_base --model models_proteus_clip \
#     --teacher-path pretrained_clip_path \
#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
#     --lambda_token 1.0 --lambda_fea 0.0 --lambda_patch 0.0 \
#     --output_dir log/CLIP_training;
1_feature_extractor/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Yunpeng Qi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
1_feature_extractor/README copy.md ADDED
@@ -0,0 +1,24 @@
# Pre-training on ImageNet-1K

## Installation
Please follow the installation instructions in [DINOv2](https://github.com/facebookresearch/dinov2/tree/main?tab=readme-ov-file#installation) and install timm==0.9.16 as well.

## Dataset
We prepare ImageNet-1K following the instructions in [DeiT](https://github.com/facebookresearch/deit/blob/main/README_deit.md#data-preparation).

## Training
1. Specify the directory of datasets with `data-path` in the training script `run_pretrain.sh`.
2. Use the `teacher-model` and `target_model` parameters to select the appropriate teacher and student models.
3. Specify the model choice with `model` to choose from DINOv2, SynCLR, and CLIP.
4. For SynCLR and CLIP training, use the `teacher-path` parameter to indicate the path to the pre-trained teacher model.
5. Simply run the training script as follows (a concrete sketch is given below):

```
bash run_pretrain.sh
```
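For concreteness, here is a minimal sketch of what such a script can contain, modeled on the DINOv2 command in `1_training_IB.sh` from this repository; the dataset path and GPU count below are placeholder assumptions, not tested values:

```
# Sketch of a run_pretrain.sh-style launch for DINOv2 distillation.
# /path/to/imagenet and the 4-GPU setup are placeholders -- adjust for your machine.
export CUDA_VISIBLE_DEVICES=0,1,2,3

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
    --batch-size 128 --warmup-epochs 5 --epochs 300 \
    --data-set IMNET --data-path /path/to/imagenet \
    --teacher-model vit_large --target_model vit_base --model models_proteus_dinov2 \
    --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
    --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
    --output_dir log/DINOv2_training
```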

## Acknowledgment

This part is heavily built upon [DeiT](https://github.com/facebookresearch/deit?tab=readme-ov-file), [DINOv2](https://github.com/facebookresearch/dinov2), and [SynCLR](https://github.com/google-research/syn-rep-learn/tree/main/SynCLR). We gratefully thank the authors for their wonderful works.
1_feature_extractor/README.md ADDED
@@ -0,0 +1,17 @@
# 1_feature_extractor
Training the information bottleneck.

IF Model: https://huggingface.co/Qiyp/1_feature_extractor

Installation:
Clone the repository, then install the dependencies with the provided requirements.txt:
pip install -r requirements.txt

Put proteus_vitb_backbone.pth into ./ckpt.

To load a pretrained Feature Extractor for further training: --resume + Feature Extractor path

To load a pretrained Feature Extractor for feature extraction: --finetune + Feature Extractor path

Training script: 1_training_IB.sh (example invocations below)
Visualize the features' semantic information: train_dec.sh
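For reference, hedged examples of the two loading modes, based on the flags actually used in `1_training_IB.sh` and `fast_vis.sh`; the checkpoint path and the elided arguments (`...`) are placeholders:

```
# Continue training from a pretrained Feature Extractor (cf. 1_training_IB.sh):
python -m torch.distributed.launch --nproc_per_node=3 --use_env 1_main_training_IB.py \
    ... --resume /path/to/feature_extractor.pth

# Load a pretrained Feature Extractor for feature extraction (cf. fast_vis.sh):
python fast_vis_settings_all.py \
    ... --finetune /path/to/feature_extractor.pth
```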
1_feature_extractor/__pycache__/augmentations.cpython-39.pyc ADDED
Binary file (2.08 kB).
1_feature_extractor/__pycache__/datasets.cpython-39.pyc ADDED
Binary file (3.12 kB).
1_feature_extractor/__pycache__/losses_hint.cpython-39.pyc ADDED
Binary file (2.17 kB).
1_feature_extractor/__pycache__/models_IB.cpython-39.pyc ADDED
Binary file (1.62 kB).
1_feature_extractor/__pycache__/models_clip.cpython-39.pyc ADDED
Binary file (12.6 kB).
1_feature_extractor/__pycache__/models_dinov2.cpython-39.pyc ADDED
Binary file (27.1 kB).
1_feature_extractor/__pycache__/models_proteus_clip.cpython-39.pyc ADDED
Binary file (2.26 kB).
1_feature_extractor/__pycache__/models_proteus_dinov2.cpython-39.pyc ADDED
Binary file (4.61 kB).
1_feature_extractor/__pycache__/models_proteus_synclr.cpython-39.pyc ADDED
Binary file (3.58 kB).
1_feature_extractor/__pycache__/models_synclr.cpython-39.pyc ADDED
Binary file (16.1 kB).
1_feature_extractor/__pycache__/samplers.cpython-39.pyc ADDED
Binary file (2.25 kB).
1_feature_extractor/__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.63 kB).
1_feature_extractor/augmentations.py ADDED
@@ -0,0 +1,94 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from torchvision import transforms
import torch

logger = logging.getLogger("dinov2")


def collate_data_and_cast_aug(
    samples_list,
    mask_ratio,
    mask_probability,
    dtype,
    n_tokens=None,
    mask_first_n=False,
    mask_generator=None,
    clone_batch=1,
):
    # dtype = torch.half  # TODO: Remove

    n_global_crops = 1

    assert n_global_crops > 0, "global crops number should be > 0"
    collated_global_crops = torch.stack([s[i] for i in range(n_global_crops) for s in samples_list])

    labels = [s[1] for s in samples_list]
    labels = torch.LongTensor(labels)
    collated_global_labels = labels.repeat(n_global_crops)

    B = len(collated_global_crops)
    N = n_tokens
    n_samples_masked = int(B * mask_probability)

    masks_list = []
    upperbound = 0

    masks_enc = torch.full((1,), 0, dtype=torch.int32)
    masks_pred = torch.full((1,), 0, dtype=torch.int32)
    # specify the number of masks to append
    number_masks = n_samples_masked * clone_batch
    # do per-sample masking
    if isinstance(mask_ratio, (tuple, list)) and len(mask_ratio) == 2:
        probs = torch.linspace(*mask_ratio, number_masks + 1)
        for i in range(0, number_masks):
            prob_min = probs[i]
            prob_max = probs[i + 1]
            masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
            upperbound += int(N * prob_max)
    else:
        mask_ratio = mask_ratio[0]
        # apply the same mask ratio to all images
        for i in range(0, number_masks):
            masks_list.append(torch.BoolTensor(mask_generator(int(N * mask_ratio))))
            upperbound += int(N * mask_ratio)

    # append masks for unmasked samples
    for i in range(n_samples_masked, B):
        # masks_list.append(torch.BoolTensor(mask_generator(0)))
        masks_list.append(torch.BoolTensor(mask_generator.get_none_mask()))

    if not mask_first_n and mask_probability > 0.0:  # shuffle masking -- not shuffling for mae-style
        random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_global_labels": collated_global_labels,
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
        "masks_enc": masks_enc,
        "masks_pred": masks_pred,
    }
1_feature_extractor/datasets.py ADDED
@@ -0,0 +1,110 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import os
import json
from torchvision.datasets import DatasetFolder
from torchvision.io import read_image
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from PIL import Image


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int(args.input_size / args.eval_crop_ratio)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
1_feature_extractor/fast_vis.sh ADDED
@@ -0,0 +1,37 @@
#### access DINOv2
export CUDA_VISIBLE_DEVICES=6;

python fast_vis_settings_all.py \
    --batch-size 64 --warmup-epochs 5 --epochs 300 \
    --data-set IMNET --data-path '/data1/datasets/imagenet_fold' \
    --teacher-model vit_large --target_model vit_base --model models_proteus_dinov2 \
    --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
    --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
    --finetune "/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0160.pth" \
    --output_dir log/DINOv2_training;

#### access SynCLR

# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
#     --data-set IMNET --data-path imagenet_path \
#     --teacher-model vit_large --target_model vit_base --model models_proteus_synclr \
#     --teacher-path pretrained_synclr_path \
#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
#     --lambda_token 1.0 --lambda_fea 1.0 --lambda_patch 1.0 \
#     --output_dir log/SynCLR_training;


#### access CLIP

# python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
#     --batch-size 128 --warmup-epochs 5 --epochs 300 \
#     --data-set IMNET --data-path imagenet_path \
#     --teacher-model vit_large --target_model vit_base --model models_proteus_clip \
#     --teacher-path pretrained_clip_path \
#     --patch_size 14 --mask_probability 0.5 --mask_ratio 0.5 --mask_first_n \
#     --lambda_token 1.0 --lambda_fea 0.0 --lambda_patch 0.0 \
#     --output_dir log/CLIP_training;
1_feature_extractor/fast_vis_proteus_feats.py ADDED
@@ -0,0 +1,98 @@
import math
import matplotlib.pyplot as plt
import numpy as np
import einops
import torch
from PIL import Image
import torchvision
import models_dinov2

device = "cuda" if torch.cuda.is_available() else "cpu"
patch_size = 14

# feat_extractor = getattr(models_dinov2, 'vit_base')
feat_extractor = getattr(models_dinov2, 'vit_large')
model = feat_extractor(img_size=224,
                       patch_size=14,
                       init_values=1.0,
                       ffn_layer='mlp',
                       block_chunks=0,
                       num_register_tokens=0,
                       interpolate_antialias=False,
                       interpolate_offset=0.1)

# checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth'  # replace with the actual checkpoint path
checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitl_backbone.pth'  # replace with the actual checkpoint path
# load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# load the model weights
if 'state_dict' in checkpoint:
    pretrained_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
    pretrained_dict = checkpoint['model']
else:
    pretrained_dict = checkpoint

# only load the weights that match the student model
model.load_state_dict(pretrained_dict, strict=False)
model.to(device)
patch_h = 224 // 14
patch_w = 224 // 14
feat_dim = 768


def visualize_features(features, output_path='./feature_visualization.png'):
    # Assuming features are of shape (batch_size, num_features, height, width)
    batch_size, num_features, height, width = features.shape

    # Normalize the feature maps to the range [0, 1]
    vis = features.mean(dim=1, keepdim=True)
    vis = vis - vis.min()
    vis = vis / vis.max()

    # Squeeze the channel dimension
    vis = vis.squeeze(1).cpu().detach().numpy()

    # Apply a colormap (e.g., viridis) to convert it to RGB
    vis_colored = np.zeros((batch_size, height, width, 3))
    for i in range(batch_size):
        vis_colored[i] = plt.cm.viridis(vis[i])[:, :, :3]  # Drop the alpha channel

    # Convert vis_colored to a tensor and save using torchvision
    vis_colored = torch.tensor(vis_colored).permute(0, 3, 1, 2)  # Convert to (batch, channels, height, width)

    # Save the image
    torchvision.utils.save_image(vis_colored, output_path, normalize=True)


from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize the image
    transforms.ToTensor(),  # convert to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
])

images = [
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/0-anime_boy_sticker__holding_kitten__happy.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/62-Deadpool_minion.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/79-with_Wooden_carved_bear__salmon_and_gold_mini_ball_surround_the_blank_signboard__illustrate.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/99-Akira_toriyama_motorbike__cheatah__puma__japanese_classic_car__collectable_figure__shiny_plastic_.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/124-crowded_1920s_Chicago_street_with_lots_of_model_T_cars_and_people.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/157-steampunk_girl_with_pink_hair_riding_in_a_hot_air_balloon__hot_air_balloon_resembles_gold_and_si.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00008636.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00010240.png"),
]
# inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
tensors = [transform(img) for img in images]
batched_tensors = torch.stack(tensors).to(device)
with torch.no_grad():
    outputs = model(batched_tensors, is_training=True)
    features = outputs['x_norm_patchtokens']  # (batch_size, num_patches, feat_dim)
    print(features.shape)

features = features.view(-1, patch_h, patch_w, features.shape[2])  # [B, h, w, c]
features = features.permute(0, 3, 1, 2)
visualize_features(features)

# pooled_output = outputs.pooler_output  # pooled CLS states.
1_feature_extractor/fast_vis_settings_all.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ import argparse
4
+ import datetime
5
+ import numpy as np
6
+ import time
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ import json
10
+
11
+ from pathlib import Path
12
+
13
+ from timm.models import create_model
14
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
15
+ from timm.scheduler import create_scheduler
16
+ from timm.optim import create_optimizer
17
+ from timm.utils import NativeScaler, get_state_dict, ModelEma
18
+ from augmentations import collate_data_and_cast_aug
19
+ from datasets import build_dataset
20
+
21
+ from losses_hint import DistillationLoss
22
+ from samplers import RASampler
23
+ from functools import partial
24
+
25
+ import importlib
26
+ import utils
27
+ import random
28
+ import math
29
+ from multiprocessing import Value
30
+ from abc import ABC
31
+
32
+ import sys
33
+ from typing import Iterable, Optional
34
+ from timm.data import Mixup
35
+ from timm.utils import accuracy, ModelEma
36
+ import utils
37
+ import os
38
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
39
+
40
+
41
+ class MaskingGenerator(ABC):
42
+ def __init__(self, input_size):
43
+ if not isinstance(input_size, tuple):
44
+ input_size = (input_size,) * 2
45
+ self.height, self.width = input_size
46
+ self.num_patches = self.height * self.width
47
+
48
+ def __repr__(self):
49
+ raise NotImplementedError
50
+
51
+ def get_shape(self):
52
+ return self.height, self.width
53
+
54
+ def _mask(self, mask, max_mask_patches):
55
+ raise NotImplementedError
56
+
57
+ def get_none_mask(self):
58
+ return np.zeros(shape=self.get_shape(), dtype=bool)
59
+
60
+
61
+
62
+ class RandomMaskingGenerator(MaskingGenerator):
63
+ def __init__(
64
+ self,
65
+ input_size,
66
+ ):
67
+ """
68
+ Args:
69
+ input_size: the size of the token map, e.g., 14x14
70
+ """
71
+ super().__init__(input_size)
72
+
73
+ def __repr__(self):
74
+ repr_str = f"Random Generator({self.height}, {self.width})"
75
+ return repr_str
76
+
77
+ def _mask(self, mask, max_mask_patches):
78
+ return super()._mask(mask, max_mask_patches)
79
+
80
+ def __call__(self, num_masking_patches=0):
81
+ if num_masking_patches <= 0:
82
+ return np.zeros(shape=self.get_shape(), dtype=bool)
83
+
84
+ mask = np.hstack([np.ones(num_masking_patches, dtype=bool),
85
+ np.zeros(self.num_patches - num_masking_patches, dtype=bool)])
86
+ np.random.shuffle(mask)
87
+ mask = mask.reshape(self.get_shape())
88
+ return mask
89
+
90
+
91
+ def get_args_parser():
+     parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+     parser.add_argument('--batch-size', default=64, type=int)
+     parser.add_argument('--epochs', default=300, type=int)
+     parser.add_argument('--bce-loss', action='store_true')
+     parser.add_argument('--unscale-lr', action='store_true')
+
+     # Model parameters
+     parser.add_argument('--model', default='deit_base_patch16_224', type=str)
+     parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
+     parser.add_argument('--input-size', default=224, type=int, help='images input size')
+
+     parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                         help='Dropout rate (default: 0.)')
+     parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
+                         help='Drop path rate (default: 0.1)')
+
+     parser.add_argument('--model-ema', action='store_true')
+     parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
+     parser.set_defaults(model_ema=True)
+     parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
+     parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
+
+     # Optimizer parameters
+     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                         help='Optimizer (default: "adamw")')
+     parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                         help='Optimizer Epsilon (default: 1e-8)')
+     parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+                         help='Optimizer Betas (default: None, use opt default)')
+     parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
+                         help='Clip gradient norm (default: None, no clipping)')
+     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                         help='SGD momentum (default: 0.9)')
+     parser.add_argument('--weight-decay', type=float, default=0.05,
+                         help='weight decay (default: 0.05)')
+     # Learning rate schedule parameters
+     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                         help='LR scheduler (default: "cosine")')
+     parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
+                         help='learning rate (default: 5e-4)')
+     parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                         help='learning rate noise on/off epoch percentages')
+     parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                         help='learning rate noise limit percent (default: 0.67)')
+     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                         help='learning rate noise std-dev (default: 1.0)')
+     parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+                         help='warmup learning rate (default: 1e-6)')
+     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+     parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                         help='epoch interval to decay LR')
+     parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+                         help='epochs to warmup LR, if scheduler supports')
+     parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+                         help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                         help='patience epochs for Plateau LR scheduler (default: 10)')
+     parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                         help='LR decay rate (default: 0.1)')
+
+     # Augmentation parameters
+     parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
+                         help='Color jitter factor (default: 0.3)')
+     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
+     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+     parser.add_argument('--train-interpolation', type=str, default='bicubic',
+                         help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+     parser.add_argument('--repeated-aug', action='store_true')
+     parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
+     parser.set_defaults(repeated_aug=True)
+
+     parser.add_argument('--train-mode', action='store_true')
+     parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
+     parser.set_defaults(train_mode=True)
+
+     parser.add_argument('--ThreeAugment', action='store_true')  # 3augment
+
+     parser.add_argument('--src', action='store_true')  # simple random crop
+
+     # dataset parameters
+     parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
+                         help="this should be equal to image size")
+     parser.add_argument('--patch_size', default=16, type=int,
+                         help="patch size for vit patch embedding")
+
+     # masking parameters
+     parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
+                         help="mask ratio can be either a value or a range")
+     parser.add_argument('--mask_probability', default=0., type=float,
+                         help="fraction of samples that masking will be applied to")
+     parser.add_argument('--mask_first_n', action='store_true',
+                         help="mask the first n samples to avoid shuffling. Needed for MAE-style encoder")
+     parser.add_argument('--clone_batch', default=1, type=int,
+                         help="how many times to clone the batch for masking (default: 1, not cloning)")
+
+     # * Random Erase params
+     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                         help='Random erase prob (default: 0.25)')
+     parser.add_argument('--remode', type=str, default='pixel',
+                         help='Random erase mode (default: "pixel")')
+     parser.add_argument('--recount', type=int, default=1,
+                         help='Random erase count (default: 1)')
+     parser.add_argument('--resplit', action='store_true', default=False,
+                         help='Do not random erase first (clean) augmentation split')
+
+     # * Mixup params
+     parser.add_argument('--mixup', type=float, default=0.8,
+                         help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+     parser.add_argument('--cutmix', type=float, default=1.0,
+                         help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+     parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+                         help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+     parser.add_argument('--mixup-prob', type=float, default=1.0,
+                         help='Probability of performing mixup or cutmix when either/both is enabled')
+     parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+                         help='Probability of switching to cutmix when both mixup and cutmix enabled')
+     parser.add_argument('--mixup-mode', type=str, default='batch',
+                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+     # Distillation parameters
+     parser.add_argument('--teacher-model', default='base', type=str)
+     parser.add_argument('--teacher-path', type=str, default='')
+     parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+     parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+     parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
+     parser.add_argument('--lambda_token', type=float, default=1.0)
+     parser.add_argument('--lambda_fea', type=float, default=1.0)
+     parser.add_argument('--lambda_patch', type=float, default=1.0)
+
+     # * Cosub params
+     parser.add_argument('--cosub', action='store_true')
+
+     # * Finetuning params
+     parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+     parser.add_argument('--attn-only', action='store_true')
+     parser.add_argument('--weight_inherit', default='')
+
+     # Dataset parameters
+     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
+                         help='dataset path')
+     parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug', 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
+                         type=str, help='ImageNet dataset path')
+     parser.add_argument('--inat-category', default='name',
+                         choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
+                         type=str, help='semantic granularity')
+
+     parser.add_argument('--output_dir', default='',
+                         help='path where to save, empty for no saving')
+     parser.add_argument('--device', default='cuda',
+                         help='device to use for training / testing')
+     parser.add_argument('--seed', default=0, type=int)
+     parser.add_argument('--resume', default='', help='resume from checkpoint')
+     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                         help='start epoch')
+     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+     parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
+     parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
+     parser.add_argument('--num_workers', default=10, type=int)
+     parser.add_argument('--pin-mem', action='store_true',
+                         help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+     parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
+                         help='')
+     parser.set_defaults(pin_mem=True)
+
+     # distributed training parameters
+     parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
+     parser.add_argument('--world_size', default=1, type=int,
+                         help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     return parser
+
+ import torchvision
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+
+
+ def visualize_features(features, output_path='./feature_visualization_w_ib.png'):
+     # Assuming features are of shape (batch_size, num_features, height, width)
+     batch_size, num_features, height, width = features.shape
+
+     # Normalize the feature maps to the range [0, 1]
+     vis = features.mean(dim=1, keepdim=True)
+     vis = vis - vis.min()
+     vis = vis / vis.max()
+
+     # Squeeze the channel dimension
+     vis = vis.squeeze(1).cpu().detach().numpy()
+
+     # Apply a colormap (e.g., viridis) to convert it to RGB
+     vis_colored = np.zeros((batch_size, height, width, 3))
+     for i in range(batch_size):
+         vis_colored[i] = plt.cm.viridis(vis[i])[:, :, :3]  # Drop the alpha channel
+
+     # Convert vis_colored to a tensor of shape (batch, channels, height, width)
+     vis_colored = torch.tensor(vis_colored).permute(0, 3, 1, 2)
+
+     # Save the image
+     torchvision.utils.save_image(vis_colored, output_path, normalize=True)
+
+ def save_original_images(tensors, output_path='./original_images.png'):
+     # Undo the ImageNet normalization so the saved images look natural
+     unnormalize = transforms.Normalize(
+         mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+         std=[1/0.229, 1/0.224, 1/0.225]
+     )
+     unnormalized_tensors = [unnormalize(tensor) for tensor in tensors]
+     unnormalized_batch = torch.stack(unnormalized_tensors)
+     torchvision.utils.save_image(unnormalized_batch, output_path, nrow=4, normalize=True)
+
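+ # Why those constants invert the normalization (a sketch of the arithmetic):
+ # Normalize(mean=m, std=s) computes y = (x - m) / s, so applying
+ # Normalize(mean=-m/s, std=1/s) to y gives (y + m/s) * s = y*s + m = x,
+ # i.e. the original pixel values; e.g. for the red channel,
+ # (-0.485/0.229, 1/0.229) undoes (0.485, 0.229).
+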
+ def main(args):
+     utils.init_distributed_mode(args)
+
+     print(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     # random.seed(seed)
+
+     cudnn.benchmark = True
+
+     dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         if args.repeated_aug:
+             sampler_train = RASampler(
+                 dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+             )
+         else:
+             sampler_train = torch.utils.data.DistributedSampler(
+                 dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+             )
+     else:
+         sampler_train = torch.utils.data.RandomSampler(dataset_train)
+
+     n_tokens = (args.global_crops_size // args.patch_size) ** 2
+     mask_generator = RandomMaskingGenerator(
+         input_size=args.global_crops_size // args.patch_size,
+     )
+
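+     # With the defaults (--global_crops_size 224, --patch_size 16) the token map
+     # is 224 // 16 = 14 per side, so n_tokens = 14 ** 2 = 196 and the generator
+     # above produces 14x14 boolean masks for the collate function below.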
+     collate_fn = partial(
+         collate_data_and_cast_aug,
+         mask_ratio=args.mask_ratio,
+         mask_probability=args.mask_probability,
+         dtype=torch.half,  # half precision
+         n_tokens=n_tokens,
+         mask_first_n=args.mask_first_n,
+         mask_generator=mask_generator,
+         clone_batch=args.clone_batch,
+     )
+
+     data_loader_train = torch.utils.data.DataLoader(
+         dataset_train, sampler=sampler_train,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         pin_memory=args.pin_mem,
+         drop_last=True,
+         collate_fn=collate_fn,
+     )
+
+     mixup_fn = None
+
+     print(f"Creating model: {args.model}")
+     meta_arch_module = importlib.import_module(args.model)
+     MetaArch = meta_arch_module.MetaArch
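+     # Note: --model is treated as an importable module name here (e.g. one of the
+     # models_proteus_* files in this directory), not a timm model id; the module
+     # must expose a MetaArch class, which presumably wraps the student backbone
+     # and the distillation teacher.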
+
+     model = MetaArch(args)
+
+     if args.finetune:
+         checkpoint = torch.load(args.finetune, map_location='cpu')
+
+         if 'state_dict' in checkpoint:
+             pretrained_dict = checkpoint['state_dict']
+         elif 'model' in checkpoint:
+             pretrained_dict = checkpoint['model']
+         else:
+             pretrained_dict = checkpoint
+
+         missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, strict=False)
+         print('missing_keys: ', missing_keys)
+         print('unexpected_keys: ', unexpected_keys)
+
+     if args.attn_only:
+         for name_p, p in model.named_parameters():
+             if '.attn.' in name_p:
+                 p.requires_grad = True
+             else:
+                 p.requires_grad = False
+         try:
+             model.head.weight.requires_grad = True
+             model.head.bias.requires_grad = True
+         except AttributeError:
+             model.fc.weight.requires_grad = True
+             model.fc.bias.requires_grad = True
+         try:
+             model.pos_embed.requires_grad = True
+         except AttributeError:
+             print('no position encoding')
+         try:
+             for p in model.patch_embed.parameters():
+                 p.requires_grad = False
+         except AttributeError:
+             print('no patch embed')
+
+     model.to(device)
+
+     model_ema = None
+     if args.model_ema:
+         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+         model_ema = ModelEma(
+             model.student.backbone,
+             decay=args.model_ema_decay,
+             device='cpu' if args.model_ema_force_cpu else '',
+             resume='')
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+         model_without_ddp = model.module
+     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print('number of params:', n_parameters)
+
+     if not args.unscale_lr:
+         linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
+         args.lr = linear_scaled_lr
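+         # Linear scaling rule, as a worked example: with the default --lr 5e-4,
+         # --batch-size 64 per process, and 8 processes, the effective lr is
+         # 5e-4 * (64 * 8) / 512 = 5e-4; halving the global batch halves the lr.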
+
+     optimizer = create_optimizer(args, model_without_ddp)
+     loss_scaler = NativeScaler()
+
+     lr_scheduler, _ = create_scheduler(args, optimizer)
+
+     output_dir = Path(args.output_dir)
+     if args.resume:
+         if args.resume.startswith('https'):
+             checkpoint = torch.hub.load_state_dict_from_url(
+                 args.resume, map_location='cpu', check_hash=True)
+         else:
+             checkpoint = torch.load(args.resume, map_location='cpu')
+
+         model_without_ddp.load_state_dict(checkpoint['model'])
+         if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+             optimizer.load_state_dict(checkpoint['optimizer'])
+             lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+             args.start_epoch = checkpoint['epoch'] + 1
+             if args.model_ema:
+                 utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+             if 'scaler' in checkpoint:
+                 loss_scaler.load_state_dict(checkpoint['scaler'])
+         lr_scheduler.step(args.start_epoch)
+
+
+     from torchvision import transforms
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),  # resize the image
+         transforms.ToTensor(),  # convert to a tensor
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
+     ])
+     from PIL import Image
+     images = [
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/0-anime_boy_sticker__holding_kitten__happy.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/62-Deadpool_minion.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/79-with_Wooden_carved_bear__salmon_and_gold_mini_ball_surround_the_blank_signboard__illustrate.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/99-Akira_toriyama_motorbike__cheatah__puma__japanese_classic_car__collectable_figure__shiny_plastic_.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/124-crowded_1920s_Chicago_street_with_lots_of_model_T_cars_and_people.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/157-steampunk_girl_with_pink_hair_riding_in_a_hot_air_balloon__hot_air_balloon_resembles_gold_and_si.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00008636.png"),
+         Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00010240.png"),
+     ]
+
+     tensors = [transform(img) for img in images]
+     batched_tensors = torch.stack(tensors).to(device)
+     save_original_images(batched_tensors, output_path='./original_images.png')
+     with torch.no_grad():
+         outputs = model.student.backbone(batched_tensors, is_training=True)
+         features = outputs['x_norm_patchtokens']  # (batch_size, num_patches, feat_dim)
+         print(features.shape)
+         features, _ = model.info_bottleneck(features, is_training=False)
+
+     features = features.view(-1, 16, 16, features.shape[2])  # [B, h, w, c]
+     features = features.permute(0, 3, 1, 2)
+     visualize_features(features)
+
+
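+ # train_one_epoch: one pass over the loader. Each batch arrives as a dict from
+ # the collate function; tensor entries are moved to the GPU, the model returns
+ # a dict of losses (total, patch, token, feature), and the AMP loss scaler
+ # performs the backward pass and optimizer step.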
+ def train_one_epoch(model: torch.nn.Module,
+                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+                     model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                     set_training_mode=True, args=None):
+     model.train(set_training_mode)
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+     header = 'Epoch: [{}]'.format(epoch)
+     print_freq = 10
+
+     loader_len = len(data_loader)
+
+     for data_iter_step, inputs_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+         for k, v in inputs_dict.items():
+             if isinstance(v, torch.Tensor):
+                 inputs_dict[k] = v.to(device, non_blocking=True)
+
+         with torch.cuda.amp.autocast():
+             loss_dict = model(inputs_dict)
+
+         loss = loss_dict["loss"]
+         patch_loss = loss_dict["patch_loss"]
+         fea_loss = loss_dict["fea_loss"]
+         token_loss = loss_dict["token_loss"]
+
+         patch_loss_value = patch_loss.item()
+         token_loss_value = token_loss.item()
+         fea_loss_value = fea_loss.item()
+         loss_value = loss.item()
+
+         if not math.isfinite(loss_value):
+             print("Loss is {}, stopping training".format(loss_value))
+             sys.exit(1)
+
+         optimizer.zero_grad()
+
+         # this attribute is added by timm on one optimizer (adahessian)
+         is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+         loss_scaler(loss, optimizer, clip_grad=max_norm,
+                     parameters=model.parameters(), create_graph=is_second_order)
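+         # NativeScaler wraps torch.cuda.amp.GradScaler: it scales the loss,
+         # runs backward, optionally unscales and clips gradients when a clip
+         # value is given, then steps the optimizer and updates the scale factor.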
+
+         torch.cuda.synchronize()
+         if model_ema is not None:
+             model_ema.update(model.module.student.backbone)
+
+         metric_logger.update(loss=loss_value)
+         metric_logger.update(patch_loss=patch_loss_value)
+         metric_logger.update(token_loss=token_loss_value)
+         metric_logger.update(fea_loss=fea_loss_value)
+         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger)
+     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
+     args = parser.parse_args()
+     if args.output_dir:
+         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     main(args)
1_feature_extractor/log/DINOv2_training/log.txt ADDED
@@ -0,0 +1,203 @@
+ {"train_lr": 1.0000000000000353e-06, "train_loss": 9.738612308347825, "train_task_loss": 4.1162851987411075, "train_bpp_loss": 28.111635084060744, "train_patch_loss": 2.1926947419615765, "train_token_loss": 0.6686155230319091, "train_fea_loss": 1.254974935296104, "epoch": 0, "n_parameters": 144845568}
+ {"train_lr": 1.0000000000000353e-06, "train_loss": 9.232701392577802, "train_task_loss": 3.6589209317660734, "train_bpp_loss": 27.86890183459464, "train_patch_loss": 1.9597956623069102, "train_token_loss": 0.6684419002523906, "train_fea_loss": 1.030683369156149, "epoch": 1, "n_parameters": 144845568}
+ {"train_lr": 7.579999999999412e-05, "train_loss": 7.21169285546485, "train_task_loss": 3.139855162607585, "train_bpp_loss": 20.35918810360914, "train_patch_loss": 1.6089605613649607, "train_token_loss": 0.6681283143200843, "train_fea_loss": 0.8627662869969861, "epoch": 2, "n_parameters": 144845568}
+ {"train_lr": 0.0001506000000000007, "train_loss": 4.554179431234809, "train_task_loss": 3.242361048226067, "train_bpp_loss": 6.559091770915772, "train_patch_loss": 1.606790432570983, "train_token_loss": 0.6681644691438745, "train_fea_loss": 0.9674061452158444, "epoch": 3, "n_parameters": 144845568}
+ {"train_lr": 0.00022540000000000762, "train_loss": 3.735120667863807, "train_task_loss": 3.387666613328085, "train_bpp_loss": 1.7372702604485788, "train_patch_loss": 1.6483151540636636, "train_token_loss": 0.6679934723992095, "train_fea_loss": 1.0713579858914555, "epoch": 4, "n_parameters": 144845568}
+ {"train_lr": 1.0000000000000597e-06, "train_loss": 5.909162417715259, "train_task_loss": 3.1000390200824333, "train_bpp_loss": 28.091233481177323, "train_patch_loss": 1.2870068884135293, "train_token_loss": 0.6689227937365607, "train_fea_loss": 1.1441093232500539, "epoch": 0, "n_parameters": 144845568}
+ {"train_lr": 1.0000000000000597e-06, "train_loss": 5.4208089515959905, "train_task_loss": 2.6417781315404447, "train_bpp_loss": 27.790307712640693, "train_patch_loss": 1.0698811880355295, "train_token_loss": 0.6684176272018064, "train_fea_loss": 0.9034793063801684, "epoch": 1, "n_parameters": 144845568}
+ {"train_lr": 2.330000000000072e-05, "train_loss": 4.90253949746025, "train_task_loss": 2.468064440561713, "train_bpp_loss": 24.344750106834965, "train_patch_loss": 0.9905159351565569, "train_token_loss": 0.6681102167086885, "train_fea_loss": 0.809438281672464, "epoch": 2, "n_parameters": 144845568}
+ {"train_lr": 4.5600000000004605e-05, "train_loss": 4.039170976486995, "train_task_loss": 2.4127648418648637, "train_bpp_loss": 16.264061019539263, "train_patch_loss": 0.9645157243489183, "train_token_loss": 0.668388756130525, "train_fea_loss": 0.7798603551274688, "epoch": 3, "n_parameters": 144845568}
+ {"train_lr": 6.790000000000497e-05, "train_loss": 3.154224511977437, "train_task_loss": 2.4628324630252605, "train_bpp_loss": 6.913920321606761, "train_patch_loss": 0.9734525120118075, "train_token_loss": 0.668419569109877, "train_fea_loss": 0.8209603751900134, "epoch": 4, "n_parameters": 144845568}
+ {"train_lr": 9.019999999999779e-05, "train_loss": 2.855963341034145, "train_task_loss": 2.595376729947343, "train_bpp_loss": 2.605866110010399, "train_patch_loss": 1.0060626347426422, "train_token_loss": 0.7175173823357665, "train_fea_loss": 0.8717967045071314, "epoch": 5, "n_parameters": 144845568}
+ {"train_lr": 0.00011234201335381617, "train_loss": 2.8053760061786472, "train_task_loss": 2.636510731433507, "train_bpp_loss": 1.6886527469109216, "train_patch_loss": 1.019217302998175, "train_token_loss": 0.7347769178539443, "train_fea_loss": 0.8825165003926742, "epoch": 6, "n_parameters": 144845568}
+ {"train_lr": 0.00011227255068590994, "train_loss": 2.761493076880773, "train_task_loss": 2.6242000526464957, "train_bpp_loss": 1.3729302407391994, "train_patch_loss": 1.0161546610956855, "train_token_loss": 0.735874281325143, "train_fea_loss": 0.8721711006757024, "epoch": 7, "n_parameters": 144845568}
+ {"train_lr": 0.0001121904989670886, "train_loss": 2.7252962913819783, "train_task_loss": 2.6068851682994008, "train_bpp_loss": 1.1841112330144972, "train_patch_loss": 1.01152529233812, "train_token_loss": 0.7345748117854662, "train_fea_loss": 0.8607850549831915, "epoch": 8, "n_parameters": 144845568}
+ {"train_lr": 0.00011209587844235662, "train_loss": 2.6992331013548, "train_task_loss": 2.594223390594661, "train_bpp_loss": 1.0500971039855402, "train_patch_loss": 1.00772097030525, "train_token_loss": 0.7340729326429627, "train_fea_loss": 0.8524294771026257, "epoch": 9, "n_parameters": 144845568}
+ {"train_lr": 0.0001119887124579783, "train_loss": 2.678213749384637, "train_task_loss": 2.5821204649494423, "train_bpp_loss": 0.9609328405931592, "train_patch_loss": 1.0029366761982013, "train_token_loss": 0.7342181133186753, "train_fea_loss": 0.8449656662421761, "epoch": 10, "n_parameters": 144845568}
+ {"train_lr": 0.00011186902745551124, "train_loss": 2.662179193613555, "train_task_loss": 2.571508611331312, "train_bpp_loss": 0.9067058223556522, "train_patch_loss": 0.9990729609732147, "train_token_loss": 0.7330601164686987, "train_fea_loss": 0.8393755242634985, "epoch": 11, "n_parameters": 144845568}
+ {"train_lr": 0.00011173685296543875, "train_loss": 2.6492657475530814, "train_task_loss": 2.562294816915437, "train_bpp_loss": 0.8697093047259702, "train_patch_loss": 0.995477003563526, "train_token_loss": 0.7320167214611916, "train_fea_loss": 0.834801082157617, "epoch": 12, "n_parameters": 144845568}
+ {"train_lr": 0.00011159222159984347, "train_loss": 2.6404250215033263, "train_task_loss": 2.5562614543492987, "train_bpp_loss": 0.8416356684604316, "train_patch_loss": 0.9931515277579999, "train_token_loss": 0.7319663158244938, "train_fea_loss": 0.8311436018064546, "epoch": 13, "n_parameters": 144845568}
+ {"train_lr": 0.00011143516904437194, "train_loss": 2.631879109013781, "train_task_loss": 2.5498513719625326, "train_bpp_loss": 0.8202773725337011, "train_patch_loss": 0.99033929754665, "train_token_loss": 0.7316777065275706, "train_fea_loss": 0.8278343591714006, "epoch": 14, "n_parameters": 144845568}
+ {"train_lr": 0.00011126573404935182, "train_loss": 2.6259225274054265, "train_task_loss": 2.5454919715579467, "train_bpp_loss": 0.8043055602369441, "train_patch_loss": 0.9886904231225737, "train_token_loss": 0.731208170509542, "train_fea_loss": 0.8255933674758108, "epoch": 15, "n_parameters": 144845568}
+ {"train_lr": 0.0001110839584203689, "train_loss": 2.6198472417533685, "train_task_loss": 2.5406655372529148, "train_bpp_loss": 0.791817048280913, "train_patch_loss": 0.9866770549277536, "train_token_loss": 0.7308970645976831, "train_fea_loss": 0.8230914076035328, "epoch": 16, "n_parameters": 144845568}
+ {"train_lr": 0.0001108898870078482, "train_loss": 2.6147061569584933, "train_task_loss": 2.5367300099552534, "train_bpp_loss": 0.7797614707582468, "train_patch_loss": 0.985000457371221, "train_token_loss": 0.7312123553148455, "train_fea_loss": 0.8205171879849953, "epoch": 17, "n_parameters": 144845568}
+ {"train_lr": 0.00011068356769595686, "train_loss": 2.609768410336128, "train_task_loss": 2.5328920738910026, "train_bpp_loss": 0.7687633632072549, "train_patch_loss": 0.9837793158322609, "train_token_loss": 0.7307359146476864, "train_fea_loss": 0.8183768322217844, "epoch": 18, "n_parameters": 144845568}
+ {"train_lr": 0.0001104650513909484, "train_loss": 2.6053047203313677, "train_task_loss": 2.52939695734486, "train_bpp_loss": 0.7590776344502179, "train_patch_loss": 0.9826806558796363, "train_token_loss": 0.7301465437459431, "train_fea_loss": 0.8165697485534585, "epoch": 19, "n_parameters": 144845568}
+ {"train_lr": 0.00011023439200841275, "train_loss": 2.60126930419847, "train_task_loss": 2.526219468670998, "train_bpp_loss": 0.7504983498341888, "train_patch_loss": 0.981161863478325, "train_token_loss": 0.7301577162488901, "train_fea_loss": 0.8148998804368096, "epoch": 20, "n_parameters": 144845568}
+ {"train_lr": 0.00010973687463990103, "train_loss": 2.6637201429449684, "train_task_loss": 2.5465192173382074, "train_bpp_loss": 0.5860046291750648, "train_patch_loss": 0.9867626110592084, "train_token_loss": 0.7298668318939545, "train_fea_loss": 0.8298897645157054, "epoch": 21, "n_parameters": 144845568}
+ {"train_lr": 0.00010973687463990103, "train_loss": 2.6604584804970584, "train_task_loss": 2.5488337897550433, "train_bpp_loss": 0.5581234564080084, "train_patch_loss": 0.98777336684643, "train_token_loss": 0.730239200644749, "train_fea_loss": 0.8308212128188684, "epoch": 22, "n_parameters": 144845568}
+ {"train_lr": 0.00010947013940891518, "train_loss": 2.6580159922512316, "train_task_loss": 2.5488927058553954, "train_bpp_loss": 0.5456164319233464, "train_patch_loss": 0.9876678835765885, "train_token_loss": 0.7302756367019004, "train_fea_loss": 0.8309491750956189, "epoch": 23, "n_parameters": 144845568}
+ {"train_lr": 0.00010919150658002209, "train_loss": 2.6552707054864446, "train_task_loss": 2.548027373189263, "train_bpp_loss": 0.5362166606305208, "train_patch_loss": 0.987526344177838, "train_token_loss": 0.7304425147100747, "train_fea_loss": 0.830058504461045, "epoch": 24, "n_parameters": 144845568}
+ {"train_lr": 0.00010890104490178482, "train_loss": 2.6532317997365116, "train_task_loss": 2.5475874447243676, "train_bpp_loss": 0.5282217790852848, "train_patch_loss": 0.9868854832560788, "train_token_loss": 0.7307839357099838, "train_fea_loss": 0.8299180160475363, "epoch": 25, "n_parameters": 144845568}
+ {"train_lr": 0.00010828492456631287, "train_loss": 2.6756018440750577, "train_task_loss": 2.554671129796931, "train_bpp_loss": 0.48372285632077133, "train_patch_loss": 0.9893747716762429, "train_token_loss": 0.7302002353723244, "train_fea_loss": 0.83509611297506, "epoch": 26, "n_parameters": 144845568}
+ {"train_lr": 0.00010828492456631287, "train_loss": 2.672080845176745, "train_task_loss": 2.5541286950649544, "train_bpp_loss": 0.4718086026637978, "train_patch_loss": 0.9892294842401426, "train_token_loss": 0.7304161941873835, "train_fea_loss": 0.8344830075631754, "epoch": 27, "n_parameters": 144845568}
+ {"train_lr": 0.00010795941792757138, "train_loss": 2.6682109270039365, "train_task_loss": 2.552076613383113, "train_bpp_loss": 0.4645372550583911, "train_patch_loss": 0.9893483498738967, "train_token_loss": 0.7297189640332398, "train_fea_loss": 0.8330092908327671, "epoch": 28, "n_parameters": 144845568}
+ {"train_lr": 0.00010762238643889585, "train_loss": 2.6648726280626778, "train_task_loss": 2.5502606650068462, "train_bpp_loss": 0.458447851656045, "train_patch_loss": 0.988596663794679, "train_token_loss": 0.7296670780559595, "train_fea_loss": 0.8319969125855848, "epoch": 29, "n_parameters": 144845568}
+ {"train_lr": 0.00010727391325772412, "train_loss": 2.66122928966614, "train_task_loss": 2.5479887132521726, "train_bpp_loss": 0.4529623082800985, "train_patch_loss": 0.9877295848607517, "train_token_loss": 0.7298175230917099, "train_fea_loss": 0.8304415963706281, "epoch": 30, "n_parameters": 144845568}
+ {"train_lr": 0.00010691408436465084, "train_loss": 2.6585902343425962, "train_task_loss": 2.546459772737978, "train_bpp_loss": 0.44852184575292275, "train_patch_loss": 0.9874212175062734, "train_token_loss": 0.7292956591402884, "train_fea_loss": 0.8297428859540705, "epoch": 31, "n_parameters": 144845568}
+ {"train_lr": 0.00010654298854205488, "train_loss": 2.655904775418395, "train_task_loss": 2.5447790948875086, "train_bpp_loss": 0.44450271940216135, "train_patch_loss": 0.9865645243444567, "train_token_loss": 0.7294259656095605, "train_fea_loss": 0.8287885952718573, "epoch": 32, "n_parameters": 144845568}
+ {"train_lr": 0.0001061607173522522, "train_loss": 2.653865991727554, "train_task_loss": 2.5435698100184796, "train_bpp_loss": 0.4411847302678981, "train_patch_loss": 0.98615532235535, "train_token_loss": 0.7293780767236003, "train_fea_loss": 0.8280364017178722, "epoch": 33, "n_parameters": 144845568}
+ {"train_lr": 0.00010576736511496153, "train_loss": 2.649993774818359, "train_task_loss": 2.540434418590092, "train_bpp_loss": 0.4382374218332289, "train_patch_loss": 0.9851716845977506, "train_token_loss": 0.7283954293907082, "train_fea_loss": 0.8268672945759565, "epoch": 34, "n_parameters": 144845568}
+ {"train_lr": 0.00010536302888396166, "train_loss": 2.64796845856116, "train_task_loss": 2.539067151297411, "train_bpp_loss": 0.4356052296792271, "train_patch_loss": 0.9849777258128571, "train_token_loss": 0.7279274276853036, "train_fea_loss": 0.8261619895200828, "epoch": 35, "n_parameters": 144845568}
+ {"train_lr": 0.00010494780842314013, "train_loss": 2.646916040002728, "train_task_loss": 2.538628663867712, "train_bpp_loss": 0.43314950562102333, "train_patch_loss": 0.9843870765022594, "train_token_loss": 0.7287507657747831, "train_fea_loss": 0.8254908122640434, "epoch": 36, "n_parameters": 144845568}
+ {"train_lr": 0.00010452180618197858, "train_loss": 2.6438271829979025, "train_task_loss": 2.5360889318678304, "train_bpp_loss": 0.43095300335781, "train_patch_loss": 0.9832675016509722, "train_token_loss": 0.7280113412965116, "train_fea_loss": 0.8248100794753511, "epoch": 37, "n_parameters": 144845568}
+ {"train_lr": 0.00010408512727011787, "train_loss": 2.6436403147715457, "train_task_loss": 2.536438628754241, "train_bpp_loss": 0.42880674428697363, "train_patch_loss": 0.9832872114673579, "train_token_loss": 0.7284368306166119, "train_fea_loss": 0.8247145772900442, "epoch": 38, "n_parameters": 144845568}
+ {"train_lr": 0.00010363787943157281, "train_loss": 2.640265642149414, "train_task_loss": 2.5335381660583636, "train_bpp_loss": 0.42690990286783204, "train_patch_loss": 0.982539902213398, "train_token_loss": 0.7274603552767913, "train_fea_loss": 0.8235378990137427, "epoch": 39, "n_parameters": 144845568}
+ {"train_lr": 0.0001031801730180277, "train_loss": 2.638479518384742, "train_task_loss": 2.532214540898871, "train_bpp_loss": 0.4250599086704514, "train_patch_loss": 0.9822922450729268, "train_token_loss": 0.7268436730546738, "train_fea_loss": 0.8230786147199088, "epoch": 40, "n_parameters": 144845568}
+ {"train_lr": 0.00010271212096170505, "train_loss": 2.6368867533473517, "train_task_loss": 2.5310158089464374, "train_bpp_loss": 0.4234837778348127, "train_patch_loss": 0.9816671004675204, "train_token_loss": 0.7268876022875881, "train_fea_loss": 0.8224610961522523, "epoch": 41, "n_parameters": 144845568}
+ {"train_lr": 0.00010223383874746677, "train_loss": 2.6354142183826554, "train_task_loss": 2.529919751140354, "train_bpp_loss": 0.42197786968612344, "train_patch_loss": 0.9807341171861839, "train_token_loss": 0.7272372794665879, "train_fea_loss": 0.8219483458689398, "epoch": 42, "n_parameters": 144845568}
+ {"train_lr": 0.00010174544438424974, "train_loss": 2.634054534628594, "train_task_loss": 2.528918670253645, "train_bpp_loss": 0.4205434575589816, "train_patch_loss": 0.9803530661996713, "train_token_loss": 0.7265599571498964, "train_fea_loss": 0.8220056383613702, "epoch": 43, "n_parameters": 144845568}
+ {"train_lr": 0.00010124705837609591, "train_loss": 2.6349923593892184, "train_task_loss": 2.5301807783186723, "train_bpp_loss": 0.4192463266741422, "train_patch_loss": 0.9804664164607366, "train_token_loss": 0.726979716709978, "train_fea_loss": 0.8227346358883331, "epoch": 44, "n_parameters": 144845568}
+ {"train_lr": 0.00010073880369226542, "train_loss": 2.63449551127583, "train_task_loss": 2.529952161675163, "train_bpp_loss": 0.4181733996109318, "train_patch_loss": 0.9800947665212323, "train_token_loss": 0.7259607360401242, "train_fea_loss": 0.8238966495861753, "epoch": 45, "n_parameters": 144845568}
+ {"train_lr": 0.00010022080573700511, "train_loss": 2.6345730214078222, "train_task_loss": 2.530297286192076, "train_bpp_loss": 0.4171029411935263, "train_patch_loss": 0.9800362460766372, "train_token_loss": 0.7260628667832029, "train_fea_loss": 0.8241981637554703, "epoch": 46, "n_parameters": 144845568}
+ {"train_lr": 9.969319231856176e-05, "train_loss": 2.6355131778874985, "train_task_loss": 2.5314376252464394, "train_bpp_loss": 0.4163022163889105, "train_patch_loss": 0.9805744725929426, "train_token_loss": 0.7263783447029029, "train_fea_loss": 0.8244847989500427, "epoch": 47, "n_parameters": 144845568}
+ {"train_lr": 9.915609361765753e-05, "train_loss": 2.634658775285637, "train_task_loss": 2.5308349787099043, "train_bpp_loss": 0.4152951848266615, "train_patch_loss": 0.9807003051248029, "train_token_loss": 0.725708744122381, "train_fea_loss": 0.8244259209490473, "epoch": 48, "n_parameters": 144845568}
+ {"train_lr": 9.860964215535301e-05, "train_loss": 2.6395780852765776, "train_task_loss": 2.536209069951761, "train_bpp_loss": 0.4134760623657071, "train_patch_loss": 0.9824451866852258, "train_token_loss": 0.7274219641725508, "train_fea_loss": 0.8263419094569177, "epoch": 49, "n_parameters": 144845568}
+ {"train_lr": 9.805397276035986e-05, "train_loss": 2.639889271710988, "train_task_loss": 2.536885856259927, "train_bpp_loss": 0.4120136631007264, "train_patch_loss": 0.9825187112367768, "train_token_loss": 0.7275711355386949, "train_fea_loss": 0.8267959996776508, "epoch": 50, "n_parameters": 144845568}
+ {"train_lr": 9.748922253581646e-05, "train_loss": 2.6434670704320893, "train_task_loss": 2.5408570495137064, "train_bpp_loss": 0.4104400824474023, "train_patch_loss": 0.9840436073498046, "train_token_loss": 0.7284569417749592, "train_fea_loss": 0.828356491024349, "epoch": 51, "n_parameters": 144845568}
+ {"train_lr": 9.691553082535863e-05, "train_loss": 2.6544609911453954, "train_task_loss": 2.5523642709644005, "train_bpp_loss": 0.40838687800259577, "train_patch_loss": 0.9880639841150252, "train_token_loss": 0.7311584631057023, "train_fea_loss": 0.8331418140172708, "epoch": 52, "n_parameters": 144845568}
+ {"train_lr": 9.633303917884302e-05, "train_loss": 2.6421383294395264, "train_task_loss": 2.5401510994175642, "train_bpp_loss": 0.40794892040498965, "train_patch_loss": 0.9832984244683199, "train_token_loss": 0.7280591241032671, "train_fea_loss": 0.8287935424283397, "epoch": 53, "n_parameters": 144845568}
+ {"train_lr": 9.574189131737902e-05, "train_loss": 2.6519711823504175, "train_task_loss": 2.5503779566384477, "train_bpp_loss": 0.40637290150338085, "train_patch_loss": 0.987097349643207, "train_token_loss": 0.7309697274904814, "train_fea_loss": 0.832310870747647, "epoch": 54, "n_parameters": 144845568}
+ {"train_lr": 9.514223309782753e-05, "train_loss": 2.649527451563939, "train_task_loss": 2.5482009406730852, "train_bpp_loss": 0.4053060460971932, "train_patch_loss": 0.9869010643251865, "train_token_loss": 0.7299437226093323, "train_fea_loss": 0.8313561448876521, "epoch": 55, "n_parameters": 144845568}
+ {"train_lr": 9.453421247691757e-05, "train_loss": 2.643370286264508, "train_task_loss": 2.5421072299049485, "train_bpp_loss": 0.40505222469563784, "train_patch_loss": 0.9843411268666387, "train_token_loss": 0.7289866340465004, "train_fea_loss": 0.8287794603441331, "epoch": 56, "n_parameters": 144845568}
+ {"train_lr": 9.391797947461475e-05, "train_loss": 2.6461283357577217, "train_task_loss": 2.5451179012316736, "train_bpp_loss": 0.40404173961842804, "train_patch_loss": 0.9857150221699111, "train_token_loss": 0.7291283701508523, "train_fea_loss": 0.8302744993274447, "epoch": 57, "n_parameters": 144845568}
+ {"train_lr": 9.329368613720009e-05, "train_loss": 2.652468080438084, "train_task_loss": 2.551866788899513, "train_bpp_loss": 0.40240516668471465, "train_patch_loss": 0.9880323676700643, "train_token_loss": 0.7304665103925468, "train_fea_loss": 0.8333679016130112, "epoch": 58, "n_parameters": 144845568}
+ {"train_lr": 9.266148649972007e-05, "train_loss": 2.650359462124767, "train_task_loss": 2.5499309687752376, "train_bpp_loss": 0.401713975702973, "train_patch_loss": 0.9875440163452991, "train_token_loss": 0.7306666403307581, "train_fea_loss": 0.8317203026117502, "epoch": 59, "n_parameters": 144845568}
+ {"train_lr": 9.202153654795684e-05, "train_loss": 2.6472997942613326, "train_task_loss": 2.547006531638636, "train_bpp_loss": 0.4011730529006115, "train_patch_loss": 0.986042402963862, "train_token_loss": 0.7303420982616161, "train_fea_loss": 0.8306220213389046, "epoch": 60, "n_parameters": 144845568}
+ {"train_lr": 9.137399417998249e-05, "train_loss": 2.635524819071273, "train_task_loss": 2.5351910238875615, "train_bpp_loss": 0.40133518058520784, "train_patch_loss": 0.9820189805460169, "train_token_loss": 0.7271724717727275, "train_fea_loss": 0.8259995626531345, "epoch": 61, "n_parameters": 144845568}
+ {"train_lr": 9.071901916722404e-05, "train_loss": 2.648858120589019, "train_task_loss": 2.5489517275056395, "train_bpp_loss": 0.39962557505043467, "train_patch_loss": 0.986155892487803, "train_token_loss": 0.73087697630029, "train_fea_loss": 0.8319188496388262, "epoch": 62, "n_parameters": 144845568}
+ {"train_lr": 9.071901916722404e-05, "train_loss": 2.7465927910497436, "train_task_loss": 2.5472986304449092, "train_bpp_loss": 0.3870172494111194, "train_patch_loss": 0.986047504940997, "train_token_loss": 0.7293615024804855, "train_fea_loss": 0.9150785957809261, "epoch": 61, "n_parameters": 144845568}
+ {"train_lr": 9.071901916722404e-05, "train_loss": 2.732339255324156, "train_task_loss": 2.5342044824247454, "train_bpp_loss": 0.3851228334161518, "train_patch_loss": 0.9823006979486854, "train_token_loss": 0.725924804654839, "train_fea_loss": 0.9085768888921117, "epoch": 62, "n_parameters": 144845568}
+ {"train_lr": 9.005677311491453e-05, "train_loss": 2.733937658449943, "train_task_loss": 2.5361717377456543, "train_bpp_loss": 0.38358666918550927, "train_patch_loss": 0.9829422572395391, "train_token_loss": 0.7263305272473741, "train_fea_loss": 0.9095888588916984, "epoch": 63, "n_parameters": 144845568}
+ {"train_lr": 8.938741942239847e-05, "train_loss": 2.735222797715764, "train_task_loss": 2.5376820612332516, "train_bpp_loss": 0.382704222326924, "train_patch_loss": 0.9836816218920605, "train_token_loss": 0.7267059973671103, "train_fea_loss": 0.9100238969839877, "epoch": 64, "n_parameters": 144845568}
+ {"train_lr": 8.871112324267081e-05, "train_loss": 2.7348346617742836, "train_task_loss": 2.537643769224413, "train_bpp_loss": 0.38169531316588, "train_patch_loss": 0.9843459360355096, "train_token_loss": 0.726475094608087, "train_fea_loss": 0.9095050229219724, "epoch": 65, "n_parameters": 144845568}
+ {"train_lr": 8.733837255720078e-05, "train_loss": 2.7936204858570934, "train_task_loss": 2.558094645236179, "train_bpp_loss": 0.3368677387395941, "train_patch_loss": 0.9911075603882435, "train_token_loss": 0.7276336732755998, "train_fea_loss": 0.9232887631746755, "epoch": 66, "n_parameters": 144845568}
+ {"train_lr": 8.733837255720078e-05, "train_loss": 2.7915395827209064, "train_task_loss": 2.558850450686914, "train_bpp_loss": 0.3302956995413565, "train_patch_loss": 0.9912216057375574, "train_token_loss": 0.7270683532455664, "train_fea_loss": 0.9246165518954896, "epoch": 67, "n_parameters": 144845568}
+ {"train_lr": 8.66422567571558e-05, "train_loss": 2.8010428122443547, "train_task_loss": 2.569311464486791, "train_bpp_loss": 0.3271386409915394, "train_patch_loss": 0.9944994980447012, "train_token_loss": 0.729622550842484, "train_fea_loss": 0.929708367548615, "epoch": 68, "n_parameters": 144845568}
+ {"train_lr": 8.59398757977085e-05, "train_loss": 2.8097731665634424, "train_task_loss": 2.5788058781241485, "train_bpp_loss": 0.3244854515335775, "train_patch_loss": 0.9977880177184117, "train_token_loss": 0.7315296879824367, "train_fea_loss": 0.9344370006996807, "epoch": 69, "n_parameters": 144845568}
+ {"train_lr": 8.523140298084917e-05, "train_loss": 2.7927491063747905, "train_task_loss": 2.562695958050821, "train_bpp_loss": 0.3240329879190889, "train_patch_loss": 0.9920173354035945, "train_token_loss": 0.7282957452923345, "train_fea_loss": 0.9266211757105853, "epoch": 70, "n_parameters": 144845568}
+ {"train_lr": 8.451701311164659e-05, "train_loss": 2.816440251421371, "train_task_loss": 2.5866508388297733, "train_bpp_loss": 0.32121099593403707, "train_patch_loss": 1.0007479091105667, "train_token_loss": 0.7334584623767282, "train_fea_loss": 0.9376889240237878, "epoch": 71, "n_parameters": 144845568}
+ {"train_lr": 8.379688245511898e-05, "train_loss": 2.8120930993692768, "train_task_loss": 2.5828994919570514, "train_bpp_loss": 0.32026763344004894, "train_patch_loss": 0.9992939126554975, "train_token_loss": 0.732874036779989, "train_fea_loss": 0.9358047071388729, "epoch": 72, "n_parameters": 144845568}
+ {"train_lr": 8.307118869271464e-05, "train_loss": 2.8027984519763818, "train_task_loss": 2.5742826347978567, "train_bpp_loss": 0.3194215351233272, "train_patch_loss": 0.9962847879741618, "train_token_loss": 0.7302367620156716, "train_fea_loss": 0.932537203840018, "epoch": 73, "n_parameters": 144845568}
+ {"train_lr": 8.234011087850579e-05, "train_loss": 2.832480542442138, "train_task_loss": 2.6040796041131307, "train_bpp_loss": 0.3165612864526931, "train_patch_loss": 1.0063268801385015, "train_token_loss": 0.7382693186433589, "train_fea_loss": 0.945431755996383, "epoch": 74, "n_parameters": 144845568}
+ {"train_lr": 8.160382939503717e-05, "train_loss": 2.799674530850826, "train_task_loss": 2.5721579685592824, "train_bpp_loss": 0.3175276949637674, "train_patch_loss": 0.9966744806246756, "train_token_loss": 0.7291926598819576, "train_fea_loss": 0.9309199213762757, "epoch": 75, "n_parameters": 144845568}
+ {"train_lr": 8.011638332509435e-05, "train_loss": 2.8255829363078666, "train_task_loss": 2.578499776871799, "train_bpp_loss": 0.2999681402433315, "train_patch_loss": 0.9987926142361703, "train_token_loss": 0.7287038154274725, "train_fea_loss": 0.9361036925038381, "epoch": 76, "n_parameters": 144845568}
+ {"train_lr": 8.011638332509435e-05, "train_loss": 2.840959037218591, "train_task_loss": 2.5953164593778926, "train_bpp_loss": 0.2960671754754919, "train_patch_loss": 1.0047289559068224, "train_token_loss": 0.7329247621774709, "train_fea_loss": 0.9434290263216666, "epoch": 77, "n_parameters": 144845568}
+ {"train_lr": 7.93655857436786e-05, "train_loss": 2.8381004938922767, "train_task_loss": 2.5934427621946345, "train_bpp_loss": 0.29446334153194903, "train_patch_loss": 1.0050567829807242, "train_token_loss": 0.7319110037146999, "train_fea_loss": 0.9421224841419206, "epoch": 78, "n_parameters": 144845568}
+ {"train_lr": 7.86103184125689e-05, "train_loss": 2.847826922665254, "train_task_loss": 2.603588223600273, "train_bpp_loss": 0.29271904649034286, "train_patch_loss": 1.0064198724385038, "train_token_loss": 0.7354645022689152, "train_fea_loss": 0.9478742439049682, "epoch": 79, "n_parameters": 144845568}
+ {"train_lr": 7.785076768264985e-05, "train_loss": 2.8496559623369784, "train_task_loss": 2.6059850953447876, "train_bpp_loss": 0.29151453778183434, "train_patch_loss": 1.0084720284899147, "train_token_loss": 0.7349831929711772, "train_fea_loss": 0.9487828727950045, "epoch": 80, "n_parameters": 144845568}
+ {"train_lr": 7.708712096171631e-05, "train_loss": 2.8463599768867023, "train_task_loss": 2.6032515673292913, "train_bpp_loss": 0.2907897324716712, "train_patch_loss": 1.0084952734437564, "train_token_loss": 0.7339370518864737, "train_fea_loss": 0.9469011765759554, "epoch": 81, "n_parameters": 144845568}
+ {"train_lr": 7.631956666815207e-05, "train_loss": 2.8331575906015845, "train_task_loss": 2.5902766323418356, "train_bpp_loss": 0.29128668203445185, "train_patch_loss": 1.0040304742896442, "train_token_loss": 0.7303849485070764, "train_fea_loss": 0.9414473412786242, "epoch": 82, "n_parameters": 144845568}
+ {"train_lr": 7.554829418450765e-05, "train_loss": 6.022797273306681, "train_task_loss": 5.747255473551776, "train_bpp_loss": 0.15595893754092843, "train_patch_loss": 1.9295167815905037, "train_token_loss": 1.9044995418676047, "train_fea_loss": 2.1045631034237196, "epoch": 83, "n_parameters": 144845568}
+ {"train_lr": 7.477349381072652e-05, "train_loss": 6.33198429772751, "train_task_loss": 6.103147584757359, "train_bpp_loss": 0.04776455266578447, "train_patch_loss": 2.038426331773722, "train_token_loss": 2.034283296100134, "train_fea_loss": 2.2334817904058837, "epoch": 84, "n_parameters": 144845568}
+ {"train_lr": 7.399535671720344e-05, "train_loss": 6.269092192288211, "train_task_loss": 6.053333911088874, "train_bpp_loss": 0.029455166824030238, "train_patch_loss": 2.0307816479676704, "train_token_loss": 2.024027929974081, "train_fea_loss": 2.1983768022049675, "epoch": 85, "n_parameters": 144845568}
+ {"train_lr": 7.321407489761549e-05, "train_loss": 6.215953613303119, "train_task_loss": 6.003636568552442, "train_bpp_loss": 0.026147063460792316, "train_patch_loss": 2.018133198333194, "train_token_loss": 2.0035276831500677, "train_fea_loss": 2.1801732855288387, "epoch": 86, "n_parameters": 144845568}
+ {"train_lr": 7.242984112156774e-05, "train_loss": 5.868365104446451, "train_task_loss": 5.527729460429088, "train_bpp_loss": 0.2684132926173316, "train_patch_loss": 1.9746971274469254, "train_token_loss": 1.5961082754723674, "train_fea_loss": 2.1526165047361197, "epoch": 87, "n_parameters": 144845568}
+ {"train_lr": 7.16428488870196e-05, "train_loss": 5.4557995484958735, "train_task_loss": 5.110851676522685, "train_bpp_loss": 0.2943931522651305, "train_patch_loss": 1.8687771737611265, "train_token_loss": 1.3823193566252787, "train_fea_loss": 2.0457307090067105, "epoch": 88, "n_parameters": 144845568}
+ {"train_lr": 7.085329237251759e-05, "train_loss": 6.162637532811513, "train_task_loss": 5.88701508848144, "train_bpp_loss": 0.13411895593832643, "train_patch_loss": 2.0334672068672286, "train_token_loss": 1.821566376444265, "train_fea_loss": 2.2351796914660316, "epoch": 89, "n_parameters": 144845568}
+ {"train_lr": 7.006136638931818e-05, "train_loss": 5.024170764112215, "train_task_loss": 4.809058998057025, "train_bpp_loss": 0.05317469353166999, "train_patch_loss": 1.8792113221967393, "train_token_loss": 1.0658738444262235, "train_fea_loss": 2.0503712694118206, "epoch": 90, "n_parameters": 144845568}
+ {"train_lr": 6.926726633331106e-05, "train_loss": 4.723379994860942, "train_task_loss": 4.515086256485763, "train_bpp_loss": 0.057305316362173765, "train_patch_loss": 1.798007144451999, "train_token_loss": 0.9435908788395264, "train_fea_loss": 1.9508371142764314, "epoch": 91, "n_parameters": 144845568}
+ {"train_lr": 6.847118813679865e-05, "train_loss": 5.031990508435013, "train_task_loss": 4.765979826932759, "train_bpp_loss": 0.15900398099812293, "train_patch_loss": 1.824176159184828, "train_token_loss": 1.1403188319715987, "train_fea_loss": 1.981633366195025, "epoch": 92, "n_parameters": 144845568}
+ {"train_lr": 7.631956666815207e-05, "train_loss": 2.814900835471259, "train_task_loss": 2.5911692696259463, "train_bpp_loss": 0.30722877928625336, "train_patch_loss": 1.003347651021247, "train_token_loss": 0.7330356562840674, "train_fea_loss": 0.854785952643364, "epoch": 81, "n_parameters": 144845568}
+ {"train_lr": 7.631956666815207e-05, "train_loss": 2.8142806054126446, "train_task_loss": 2.5897073253131597, "train_bpp_loss": 0.30940049388993096, "train_patch_loss": 1.0041070785507453, "train_token_loss": 0.7321698584502263, "train_fea_loss": 0.8534303789900289, "epoch": 82, "n_parameters": 144845568}
+ {"train_lr": 7.554829418450765e-05, "train_loss": 2.816245504384704, "train_task_loss": 2.591411675659301, "train_bpp_loss": 0.30987349456864044, "train_patch_loss": 1.003780739015634, "train_token_loss": 0.7337235686563545, "train_fea_loss": 0.8539073579683364, "epoch": 83, "n_parameters": 144845568}
+ {"train_lr": 7.477349381072652e-05, "train_loss": 2.800685787649392, "train_task_loss": 2.5759294760688176, "train_bpp_loss": 0.3108195169511691, "train_patch_loss": 0.9979896100653376, "train_token_loss": 0.729064737578376, "train_fea_loss": 0.8488751188081374, "epoch": 84, "n_parameters": 144845568}
+ {"train_lr": 7.399535671720344e-05, "train_loss": 2.78943671736357, "train_task_loss": 2.5648709468239788, "train_bpp_loss": 0.31156991790573624, "train_patch_loss": 0.9950137249611729, "train_token_loss": 0.7262643160526272, "train_fea_loss": 0.8435928968185608, "epoch": 85, "n_parameters": 144845568}
+ {"train_lr": 7.321407489761549e-05, "train_loss": 2.844264631443244, "train_task_loss": 2.614422207854563, "train_bpp_loss": 0.319169871819516, "train_patch_loss": 1.0113446999082176, "train_token_loss": 0.7409178713242784, "train_fea_loss": 0.8621596260023096, "epoch": 86, "n_parameters": 144845568}
+ {"train_lr": 7.242984112156774e-05, "train_loss": 2.851643653814312, "train_task_loss": 2.6122646205198707, "train_bpp_loss": 0.33907441596994153, "train_patch_loss": 1.01671234990651, "train_token_loss": 0.7275969981968545, "train_fea_loss": 0.8679552621183564, "epoch": 87, "n_parameters": 144845568}
+ {"train_lr": 7.16428488870196e-05, "train_loss": 2.778331270660285, "train_task_loss": 2.552063302813674, "train_bpp_loss": 0.31613819805538945, "train_patch_loss": 0.9926303882571719, "train_token_loss": 0.7193752784768848, "train_fea_loss": 0.8400576262571793, "epoch": 88, "n_parameters": 144845568}
+ {"train_lr": 7.085329237251759e-05, "train_loss": 2.7897735733261926, "train_task_loss": 2.565236364247845, "train_bpp_loss": 0.311233962220486, "train_patch_loss": 0.9957534326776434, "train_token_loss": 0.7246638578329262, "train_fea_loss": 0.8448190648729602, "epoch": 89, "n_parameters": 144845568}
+ {"train_lr": 7.006136638931818e-05, "train_loss": 2.7982831528093173, "train_task_loss": 2.574218000189292, "train_bpp_loss": 0.3094736575059664, "train_patch_loss": 1.000583418184887, "train_token_loss": 0.7256146864831555, "train_fea_loss": 0.8480198857010447, "epoch": 90, "n_parameters": 144845568}
+ {"train_lr": 6.926726633331106e-05, "train_loss": 3.415653905788843, "train_task_loss": 2.975828631783275, "train_bpp_loss": 0.7560260864588872, "train_patch_loss": 1.1288673198967278, "train_token_loss": 0.8508261374780814, "train_fea_loss": 0.9961351647445981, "epoch": 91, "n_parameters": 144845568}
+ {"train_lr": 6.847118813679865e-05, "train_loss": 6.568326573541982, "train_task_loss": 6.1273969157779815, "train_bpp_loss": 0.5257196640345934, "train_patch_loss": 2.043614938980241, "train_token_loss": 2.040224435022302, "train_fea_loss": 2.04355752422548, "epoch": 92, "n_parameters": 144845568}
+ {"train_lr": 6.767332822016792e-05, "train_loss": 6.867288129757062, "train_task_loss": 6.089932662566646, "train_bpp_loss": 1.2734335873903297, "train_patch_loss": 2.0432088341144063, "train_token_loss": 2.003620763126162, "train_fea_loss": 2.043103043592984, "epoch": 93, "n_parameters": 144845568}
+ {"train_lr": 6.687388344341571e-05, "train_loss": 6.613793869366606, "train_task_loss": 4.8460818514299335, "train_bpp_loss": 3.475327093109524, "train_patch_loss": 2.043151048612859, "train_token_loss": 0.7647827833252167, "train_fea_loss": 2.038148040657671, "epoch": 94, "n_parameters": 144845568}
+ {"train_lr": 6.607305105757049e-05, "train_loss": 6.583523838229174, "train_task_loss": 4.818476951540374, "train_bpp_loss": 3.4695032598624986, "train_patch_loss": 2.0428966481319004, "train_token_loss": 0.737876279511862, "train_fea_loss": 2.037704043197546, "epoch": 95, "n_parameters": 144845568}
+ {"train_lr": 6.5271028656055e-05, "train_loss": 6.5769114312794, "train_task_loss": 4.8131763801979215, "train_bpp_loss": 3.46658337029586, "train_patch_loss": 2.0428622997046517, "train_token_loss": 0.7325888925904243, "train_fea_loss": 2.0377252054997057, "epoch": 96, "n_parameters": 144845568}
+ {"train_lr": 6.446801412587525e-05, "train_loss": 6.572428462227329, "train_task_loss": 4.809816579738681, "train_bpp_loss": 3.4640789968754464, "train_patch_loss": 2.0429170189662087, "train_token_loss": 0.7291364202284806, "train_fea_loss": 2.0377631590699408, "epoch": 97, "n_parameters": 144845568}
118
+ {"train_lr": 6.36642055988671e-05, "train_loss": 6.5666217358349614, "train_task_loss": 4.80531757448217, "train_bpp_loss": 3.461199431673443, "train_patch_loss": 2.0427514293746984, "train_token_loss": 0.7249221600507267, "train_fea_loss": 2.0376440027742078, "epoch": 98, "n_parameters": 144845568}
119
+ {"train_lr": 6.285980140274965e-05, "train_loss": 6.492165406455668, "train_task_loss": 4.802894322652754, "train_bpp_loss": 3.3011300552377314, "train_patch_loss": 2.042857281175806, "train_token_loss": 0.72241166340549, "train_fea_loss": 2.0376253973210243, "epoch": 99, "n_parameters": 144845568}
120
+ {"train_lr": 6.205500001222403e-05, "train_loss": 6.469163187473512, "train_task_loss": 4.799823627161751, "train_bpp_loss": 3.2568444883723338, "train_patch_loss": 2.042693032090255, "train_token_loss": 0.7195353948642953, "train_fea_loss": 2.0375952215226505, "epoch": 100, "n_parameters": 144845568}
121
+ {"train_lr": 6.847118813679865e-05, "train_loss": 2.708930113305934, "train_task_loss": 2.5679492560668673, "train_bpp_loss": 0.3132908021231635, "train_patch_loss": 0.9976560541538062, "train_token_loss": 0.7208710700648723, "train_fea_loss": 0.8494221233278155, "epoch": 91, "n_parameters": 144845568}
122
+ {"train_lr": 6.847118813679865e-05, "train_loss": 2.699149013628705, "train_task_loss": 2.5613451289723246, "train_bpp_loss": 0.3062308639035576, "train_patch_loss": 0.9944496202400978, "train_token_loss": 0.7207848698646724, "train_fea_loss": 0.846110629630264, "epoch": 92, "n_parameters": 144845568}
123
+ {"train_lr": 6.767332822016792e-05, "train_loss": 2.695356560687867, "train_task_loss": 2.5580616601490433, "train_bpp_loss": 0.3050997874931805, "train_patch_loss": 0.9943495049304385, "train_token_loss": 0.7195412099754496, "train_fea_loss": 0.8441709351862113, "epoch": 93, "n_parameters": 144845568}
124
+ {"train_lr": 6.687388344341571e-05, "train_loss": 2.709555460635921, "train_task_loss": 2.572864127852362, "train_bpp_loss": 0.3037585253364963, "train_patch_loss": 0.9997199533967306, "train_token_loss": 0.7228693755843835, "train_fea_loss": 0.850274788731669, "epoch": 94, "n_parameters": 144845568}
125
+ {"train_lr": 6.607305105757049e-05, "train_loss": 2.7318896783841886, "train_task_loss": 2.59231706789965, "train_bpp_loss": 0.31016136522402615, "train_patch_loss": 1.0061367933327954, "train_token_loss": 0.7272278621984865, "train_fea_loss": 0.858952402501262, "epoch": 95, "n_parameters": 144845568}
126
+ {"train_lr": 6.5271028656055e-05, "train_loss": 2.7111547499162545, "train_task_loss": 2.569625718514625, "train_bpp_loss": 0.31450896638912335, "train_patch_loss": 1.0005275200364543, "train_token_loss": 0.7153408098586207, "train_fea_loss": 0.8537573801527778, "epoch": 96, "n_parameters": 144845568}
127
+ {"train_lr": 6.446801412587525e-05, "train_loss": 2.707958362782173, "train_task_loss": 2.570670375998, "train_bpp_loss": 0.30508442372374006, "train_patch_loss": 0.9985086321964752, "train_token_loss": 0.7216268923039809, "train_fea_loss": 0.8505348416952063, "epoch": 97, "n_parameters": 144845568}
128
+ {"train_lr": 6.36642055988671e-05, "train_loss": 2.697063536604317, "train_task_loss": 2.560573679866265, "train_bpp_loss": 0.3033108016983014, "train_patch_loss": 0.9955314651368369, "train_token_loss": 0.7185008986127213, "train_fea_loss": 0.846541307080421, "epoch": 98, "n_parameters": 144845568}
129
+ {"train_lr": 6.285980140274965e-05, "train_loss": 2.679979503681834, "train_task_loss": 2.5433403860768684, "train_bpp_loss": 0.30364249174190544, "train_patch_loss": 0.989279427558934, "train_token_loss": 0.7147660507581575, "train_fea_loss": 0.8392948986810568, "epoch": 99, "n_parameters": 144845568}
130
+ {"train_lr": 6.205500001222403e-05, "train_loss": 2.818446820222145, "train_task_loss": 2.6755781511751584, "train_bpp_loss": 0.3174859403087635, "train_patch_loss": 1.0392651655052385, "train_token_loss": 0.7372288242098596, "train_fea_loss": 0.8990841494868699, "epoch": 100, "n_parameters": 144845568}
131
+ {"train_lr": 6.12500000000064e-05, "train_loss": 2.7441475540310214, "train_task_loss": 2.597907278704629, "train_bpp_loss": 0.32497839667171025, "train_patch_loss": 1.0195179717119436, "train_token_loss": 0.7045413070298034, "train_fea_loss": 0.8738479904709853, "epoch": 101, "n_parameters": 144845568}
132
+ {"train_lr": 6.044499998777186e-05, "train_loss": 2.7079060795299297, "train_task_loss": 2.565669994283137, "train_bpp_loss": 0.31608019693479905, "train_patch_loss": 1.0030313897252583, "train_token_loss": 0.7082743480091365, "train_fea_loss": 0.8543642482986089, "epoch": 102, "n_parameters": 144845568}
133
+ {"train_lr": 5.964019859724661e-05, "train_loss": 2.696671325293519, "train_task_loss": 2.557040445006294, "train_bpp_loss": 0.3102908528876676, "train_patch_loss": 0.9976166626081157, "train_token_loss": 0.7102920439817922, "train_fea_loss": 0.8491317290227584, "epoch": 103, "n_parameters": 144845568}
134
+ {"train_lr": 5.8835794401133974e-05, "train_loss": 2.691563926851578, "train_task_loss": 2.553079867954377, "train_bpp_loss": 0.3077423636526834, "train_patch_loss": 0.9958474041438468, "train_token_loss": 0.7099756244815689, "train_fea_loss": 0.8472568294439885, "epoch": 104, "n_parameters": 144845568}
135
+ {"train_lr": 5.8031985874119795e-05, "train_loss": 2.692192362849232, "train_task_loss": 2.5543478446916565, "train_bpp_loss": 0.30632116043788316, "train_patch_loss": 0.9955702859291927, "train_token_loss": 0.7111749301689099, "train_fea_loss": 0.8476026195617352, "epoch": 105, "n_parameters": 144845568}
136
+ {"train_lr": 5.722897134394433e-05, "train_loss": 2.6670049736075265, "train_task_loss": 2.5291183430889097, "train_bpp_loss": 0.30641474233207694, "train_patch_loss": 0.9873280670623378, "train_token_loss": 0.7049247939295942, "train_fea_loss": 0.8368654731232271, "epoch": 106, "n_parameters": 144845568}
137
+ {"train_lr": 5.642694894242339e-05, "train_loss": 2.6757864850066264, "train_task_loss": 2.538502600463889, "train_bpp_loss": 0.30507530629348983, "train_patch_loss": 0.9907261180832124, "train_token_loss": 0.7074770695045233, "train_fea_loss": 0.8402994043602467, "epoch": 107, "n_parameters": 144845568}
138
+ {"train_lr": 5.562611655657961e-05, "train_loss": 2.7075679923811977, "train_task_loss": 2.5688214136613645, "train_bpp_loss": 0.30832573917316947, "train_patch_loss": 1.00075649227545, "train_token_loss": 0.7133955942328385, "train_fea_loss": 0.8546693176053447, "epoch": 108, "n_parameters": 144845568}
139
+ {"train_lr": 5.642694894242339e-05, "train_loss": 2.668979725284542, "train_task_loss": 2.5311407839669453, "train_bpp_loss": 0.30630876569577653, "train_patch_loss": 0.9885714473649502, "train_token_loss": 0.7049617424553676, "train_fea_loss": 0.8376075848758363, "epoch": 106, "n_parameters": 144845568}
140
+ {"train_lr": 5.642694894242339e-05, "train_loss": 2.683824887423278, "train_task_loss": 2.5468173437535193, "train_bpp_loss": 0.3044612155992725, "train_patch_loss": 0.9947493373275661, "train_token_loss": 0.708695182786714, "train_fea_loss": 0.8433728142076438, "epoch": 107, "n_parameters": 144845568}
141
+ {"train_lr": 5.562611655657961e-05, "train_loss": 2.6808553397548285, "train_task_loss": 2.544291126874568, "train_bpp_loss": 0.30347603688025465, "train_patch_loss": 0.9940348935472093, "train_token_loss": 0.7075042654525676, "train_fea_loss": 0.8427519599440322, "epoch": 108, "n_parameters": 144845568}
142
+ {"train_lr": 5.482667177983261e-05, "train_loss": 4.911457611171962, "train_task_loss": 4.756159201309049, "train_bpp_loss": 0.3451075883260068, "train_patch_loss": 1.923462125296287, "train_token_loss": 0.9228969771633867, "train_fea_loss": 1.909800109163159, "epoch": 109, "n_parameters": 144845568}
143
+ {"train_lr": 5.402881186319929e-05, "train_loss": 6.119308372142075, "train_task_loss": 6.05821056978451, "train_bpp_loss": 0.13577289702430786, "train_patch_loss": 2.0250163967821666, "train_token_loss": 2.0047141713507886, "train_fea_loss": 2.028479987103269, "epoch": 110, "n_parameters": 144845568}
144
+ {"train_lr": 5.402881186319929e-05, "train_loss": 2.8206423869092974, "train_task_loss": 2.5848962855514146, "train_bpp_loss": 0.3167879025416105, "train_patch_loss": 1.0065730715067767, "train_token_loss": 0.721063517386929, "train_fea_loss": 0.8572596858927815, "epoch": 109, "n_parameters": 144845568}
145
+ {"train_lr": 5.402881186319929e-05, "train_loss": 2.7694720208591264, "train_task_loss": 2.536599607177847, "train_bpp_loss": 0.31375666602561586, "train_patch_loss": 0.9914828701552263, "train_token_loss": 0.7029595081839595, "train_fea_loss": 0.8421572199877467, "epoch": 110, "n_parameters": 144845568}
146
+ {"train_lr": 5.323273366669127e-05, "train_loss": 2.754181052810497, "train_task_loss": 2.52081877488098, "train_bpp_loss": 0.31649657639471607, "train_patch_loss": 0.9853969901527683, "train_token_loss": 0.7020405733832817, "train_fea_loss": 0.8333812024247803, "epoch": 111, "n_parameters": 144845568}
147
+ {"train_lr": 5.24386336106797e-05, "train_loss": 2.778903913664089, "train_task_loss": 2.54480576113188, "train_bpp_loss": 0.31611214793115244, "train_patch_loss": 0.9951485451065201, "train_token_loss": 0.7078502580961235, "train_fea_loss": 0.8418069494334306, "epoch": 112, "n_parameters": 144845568}
148
+ {"train_lr": 5.1646707627478925e-05, "train_loss": 2.8131009751854896, "train_task_loss": 2.5725743175803615, "train_bpp_loss": 0.32791142306922616, "train_patch_loss": 1.0041894356025567, "train_token_loss": 0.7132421619497793, "train_fea_loss": 0.8551427105003946, "epoch": 113, "n_parameters": 144845568}
149
+ {"train_lr": 5.0857151112976574e-05, "train_loss": 2.7399135867147137, "train_task_loss": 2.5091776530072987, "train_bpp_loss": 0.31148358721875674, "train_patch_loss": 0.9819968594774175, "train_token_loss": 0.6978094259017913, "train_fea_loss": 0.8293713586163713, "epoch": 114, "n_parameters": 144845568}
150
+ {"train_lr": 5.007015887842505e-05, "train_loss": 2.8684129846825015, "train_task_loss": 2.616686708519427, "train_bpp_loss": 0.3489265601872076, "train_patch_loss": 1.0203611283837783, "train_token_loss": 0.7224983579727063, "train_fea_loss": 0.8738272108821025, "epoch": 115, "n_parameters": 144845568}
151
+ {"train_lr": 4.928592510238729e-05, "train_loss": 2.739907768231144, "train_task_loss": 2.50749061155698, "train_bpp_loss": 0.3151009082838857, "train_patch_loss": 0.9832527024260492, "train_token_loss": 0.695053851908649, "train_fea_loss": 0.8291840486706411, "epoch": 116, "n_parameters": 144845568}
152
+ {"train_lr": 4.850464328279906e-05, "train_loss": 2.7357478904620494, "train_task_loss": 2.5043268277079798, "train_bpp_loss": 0.31352339563690285, "train_patch_loss": 0.9800204708680904, "train_token_loss": 0.6976143020144898, "train_fea_loss": 0.8266920464814639, "epoch": 117, "n_parameters": 144845568}
153
+ {"train_lr": 4.7726506189276635e-05, "train_loss": 2.755742895097541, "train_task_loss": 2.5216859226211086, "train_bpp_loss": 0.31784185127531595, "train_patch_loss": 0.9866870453634815, "train_token_loss": 0.701121356496833, "train_fea_loss": 0.833877512173419, "epoch": 118, "n_parameters": 144845568}
154
+ {"train_lr": 4.69517058154867e-05, "train_loss": 2.7318975441247155, "train_task_loss": 2.5014077417141527, "train_bpp_loss": 0.31166150058537045, "train_patch_loss": 0.9790000828715573, "train_token_loss": 0.6965634005313976, "train_fea_loss": 0.8258442498243256, "epoch": 119, "n_parameters": 144845568}
155
+ {"train_lr": 4.6180433331847694e-05, "train_loss": 2.7518172350307877, "train_task_loss": 2.5213089466577383, "train_bpp_loss": 0.31003568373504375, "train_patch_loss": 0.9861942503458376, "train_token_loss": 0.7014625150290468, "train_fea_loss": 0.8336521727021793, "epoch": 120, "n_parameters": 144845568}
156
+ {"train_lr": 4.541287903828179e-05, "train_loss": 2.7335522819390827, "train_task_loss": 2.50380109216598, "train_bpp_loss": 0.3096768988276289, "train_patch_loss": 0.9810024427442099, "train_token_loss": 0.6958675646743507, "train_fea_loss": 0.8269310759501063, "epoch": 121, "n_parameters": 144845568}
157
+ {"train_lr": 4.4649232317341524e-05, "train_loss": 2.733630260037218, "train_task_loss": 2.503839423107229, "train_bpp_loss": 0.309786735956805, "train_patch_loss": 0.9804957953708278, "train_token_loss": 0.6961014977024256, "train_fea_loss": 0.8272421213170643, "epoch": 122, "n_parameters": 144845568}
158
+ {"train_lr": 4.3889681587425266e-05, "train_loss": 2.8017713390469408, "train_task_loss": 2.557888786087362, "train_bpp_loss": 0.3367434704309995, "train_patch_loss": 0.9978989840482922, "train_token_loss": 0.7109271357421109, "train_fea_loss": 0.8490626564633539, "epoch": 123, "n_parameters": 144845568}
159
+ {"train_lr": 4.313441425631543e-05, "train_loss": 2.7663599788803133, "train_task_loss": 2.515742145719931, "train_bpp_loss": 0.35449836285320296, "train_patch_loss": 0.9864850308973905, "train_token_loss": 0.6938689482427365, "train_fea_loss": 0.8353881579388281, "epoch": 124, "n_parameters": 144845568}
160
+ {"train_lr": 4.238361667491207e-05, "train_loss": 2.7239683468153864, "train_task_loss": 2.4941051970068497, "train_bpp_loss": 0.31064632278029547, "train_patch_loss": 0.9776195450965092, "train_token_loss": 0.6926572750338155, "train_fea_loss": 0.8238283673957955, "epoch": 125, "n_parameters": 144845568}
161
+ {"train_lr": 4.1637474091286196e-05, "train_loss": 2.731808961188193, "train_task_loss": 2.502355479927872, "train_bpp_loss": 0.30911144960618064, "train_patch_loss": 0.9799622440648129, "train_token_loss": 0.6952871139065253, "train_fea_loss": 0.827106113864971, "epoch": 126, "n_parameters": 144845568}
162
+ {"train_lr": 4.089617060496659e-05, "train_loss": 2.7330640450530916, "train_task_loss": 2.5038258624519947, "train_bpp_loss": 0.3084725750156855, "train_patch_loss": 0.9800150552474195, "train_token_loss": 0.6953135876708644, "train_fea_loss": 0.8284972119714609, "epoch": 127, "n_parameters": 144845568}
163
+ {"train_lr": 4.015988912148501e-05, "train_loss": 2.7349733328558417, "train_task_loss": 2.4990430673856814, "train_bpp_loss": 0.3236985370981322, "train_patch_loss": 0.9784446784811054, "train_token_loss": 0.6937227977542497, "train_fea_loss": 0.8268755828197911, "epoch": 128, "n_parameters": 144845568}
164
+ {"train_lr": 3.942881130728865e-05, "train_loss": 2.7178904607856302, "train_task_loss": 2.4889944930085175, "train_bpp_loss": 0.3089114015997885, "train_patch_loss": 0.9757260797004834, "train_token_loss": 0.6912759416524491, "train_fea_loss": 0.821992463240181, "epoch": 129, "n_parameters": 144845568}
165
+ {"train_lr": 3.870311754488397e-05, "train_loss": 2.7325405643301472, "train_task_loss": 2.5016696595584604, "train_bpp_loss": 0.31216120841294276, "train_patch_loss": 0.9799974317856722, "train_token_loss": 0.6937006032407034, "train_fea_loss": 0.8279716155806677, "epoch": 130, "n_parameters": 144845568}
166
+ {"train_lr": 3.798298688834852e-05, "train_loss": 2.7105408391917494, "train_task_loss": 2.4821574514736233, "train_bpp_loss": 0.30818621284714776, "train_patch_loss": 0.9738334429783108, "train_token_loss": 0.6881637847457501, "train_fea_loss": 0.820160214387202, "epoch": 131, "n_parameters": 144845568}
167
+ {"train_lr": 3.726859701914403e-05, "train_loss": 2.7591426904252967, "train_task_loss": 2.5228572249930683, "train_bpp_loss": 0.3226083974237833, "train_patch_loss": 0.9864406956994026, "train_token_loss": 0.7006216509741165, "train_fea_loss": 0.8357948688053184, "epoch": 132, "n_parameters": 144845568}
168
+ {"train_lr": 3.656012420228689e-05, "train_loss": 2.68967245043003, "train_task_loss": 2.460287703357512, "train_bpp_loss": 0.3120911338152163, "train_patch_loss": 0.9669189435883249, "train_token_loss": 0.6814111765313278, "train_fea_loss": 0.81195757473982, "epoch": 133, "n_parameters": 144845568}
169
+ {"train_lr": 3.5857743242838835e-05, "train_loss": 2.7014216272432883, "train_task_loss": 2.472743094181843, "train_bpp_loss": 0.30973817068163295, "train_patch_loss": 0.9706215003429414, "train_token_loss": 0.6868141930465468, "train_fea_loss": 0.8153073933484744, "epoch": 134, "n_parameters": 144845568}
170
+ {"train_lr": 3.516162744279572e-05, "train_loss": 2.705487959069028, "train_task_loss": 2.4765327244067934, "train_bpp_loss": 0.3098553310503527, "train_patch_loss": 0.9728047358189269, "train_token_loss": 0.68612423896057, "train_fea_loss": 0.8176037415066998, "epoch": 135, "n_parameters": 144845568}
171
+ {"train_lr": 3.447194855830639e-05, "train_loss": 2.7113154404383484, "train_task_loss": 2.478181616079321, "train_bpp_loss": 0.31905700233530443, "train_patch_loss": 0.9727772305422514, "train_token_loss": 0.68701637881234, "train_fea_loss": 0.818387999384702, "epoch": 136, "n_parameters": 144845568}
172
+ {"train_lr": 3.378887675732868e-05, "train_loss": 2.6797968578620925, "train_task_loss": 2.451133315422409, "train_bpp_loss": 0.3113255005260612, "train_patch_loss": 0.9640113939207307, "train_token_loss": 0.6797902039285723, "train_fea_loss": 0.8073317096557477, "epoch": 137, "n_parameters": 144845568}
173
+ {"train_lr": 3.311258057759679e-05, "train_loss": 2.694638561546731, "train_task_loss": 2.4661759902551164, "train_bpp_loss": 0.3096901442793062, "train_patch_loss": 0.9691699242545403, "train_token_loss": 0.6841340864242481, "train_fea_loss": 0.8128719716969368, "epoch": 138, "n_parameters": 144845568}
174
+ {"train_lr": 3.244322688507758e-05, "train_loss": 2.697916217782586, "train_task_loss": 2.4694298465224764, "train_bpp_loss": 0.3092930535785854, "train_patch_loss": 0.9713080364482687, "train_token_loss": 0.6833381301698853, "train_fea_loss": 0.8147836721902021, "epoch": 139, "n_parameters": 144845568}
175
+ {"train_lr": 3.1780980832784374e-05, "train_loss": 2.697745393452456, "train_task_loss": 2.465554346515835, "train_bpp_loss": 0.3177864467847994, "train_patch_loss": 0.9703194021450565, "train_token_loss": 0.6818096557480219, "train_fea_loss": 0.8134252810515732, "epoch": 140, "n_parameters": 144845568}
176
+ {"train_lr": 3.112600582001298e-05, "train_loss": 2.692317468710512, "train_task_loss": 2.4639319612039365, "train_bpp_loss": 0.3095171678927448, "train_patch_loss": 0.9694184379415404, "train_token_loss": 0.6818745398862684, "train_fea_loss": 0.8126389763923251, "epoch": 141, "n_parameters": 144845568}
177
+ {"train_lr": 3.047846345205177e-05, "train_loss": 2.7029235281175277, "train_task_loss": 2.47447449285135, "train_bpp_loss": 0.30879852916527206, "train_patch_loss": 0.9717401454551257, "train_token_loss": 0.6846787373957445, "train_fea_loss": 0.8180556024426965, "epoch": 142, "n_parameters": 144845568}
178
+ {"train_lr": 2.9838513500286588e-05, "train_loss": 2.6954629459469723, "train_task_loss": 2.461703634662308, "train_bpp_loss": 0.3215753604463643, "train_patch_loss": 0.9685670451911019, "train_token_loss": 0.6806938354801432, "train_fea_loss": 0.8124427453947546, "epoch": 143, "n_parameters": 144845568}
179
+ {"train_lr": 2.920631386279756e-05, "train_loss": 2.6789074894978846, "train_task_loss": 2.4506222840836296, "train_bpp_loss": 0.3103648301312249, "train_patch_loss": 0.9643291673049045, "train_token_loss": 0.6781998446358622, "train_fea_loss": 0.8080932640043987, "epoch": 144, "n_parameters": 144845568}
180
+ {"train_lr": 2.8582020525382766e-05, "train_loss": 2.698681264338519, "train_task_loss": 2.465655054361057, "train_bpp_loss": 0.3195096456470321, "train_patch_loss": 0.9703498526188217, "train_token_loss": 0.6807158637626601, "train_fea_loss": 0.8145893300555164, "epoch": 145, "n_parameters": 144845568}
181
+ {"train_lr": 2.7965787523079142e-05, "train_loss": 2.6876495937845832, "train_task_loss": 2.4591161769142538, "train_bpp_loss": 0.3100968778760974, "train_patch_loss": 0.9675739941743317, "train_token_loss": 0.6793179688975215, "train_fea_loss": 0.8122242059786435, "epoch": 146, "n_parameters": 144845568}
182
+ {"train_lr": 2.7357766902161244e-05, "train_loss": 2.686950199314945, "train_task_loss": 2.4579590937520246, "train_bpp_loss": 0.31120623359166316, "train_patch_loss": 0.9676960334248788, "train_token_loss": 0.6789913053403227, "train_fea_loss": 0.8112717479337855, "epoch": 147, "n_parameters": 144845568}
183
+ {"train_lr": 2.616696082115359e-05, "train_loss": 2.6700563189937627, "train_task_loss": 2.441784028437355, "train_bpp_loss": 0.31092876969329675, "train_patch_loss": 0.9618561120337541, "train_token_loss": 0.6746953771956604, "train_fea_loss": 0.8052325316769577, "epoch": 148, "n_parameters": 144845568}
184
+ {"train_lr": 2.616696082115359e-05, "train_loss": 2.6674406088138225, "train_task_loss": 2.4389501352229424, "train_bpp_loss": 0.3116684112110459, "train_patch_loss": 0.9607490962326563, "train_token_loss": 0.6741545937120093, "train_fea_loss": 0.8040464378075955, "epoch": 149, "n_parameters": 144845568}
185
+ {"train_lr": 2.558446917464184e-05, "train_loss": 2.6966678162129947, "train_task_loss": 2.4605731371191624, "train_bpp_loss": 0.32649256723287473, "train_patch_loss": 0.9711854177922439, "train_token_loss": 0.6771108717668792, "train_fea_loss": 0.81227684003129, "epoch": 150, "n_parameters": 144845568}
186
+ {"train_lr": 2.5010777464192224e-05, "train_loss": 2.649356059051103, "train_task_loss": 2.4205873620649463, "train_bpp_loss": 0.3137229763348844, "train_patch_loss": 0.9557904277322925, "train_token_loss": 0.6687184504994885, "train_fea_loss": 0.7960784757974372, "epoch": 151, "n_parameters": 144845568}
187
+ {"train_lr": 2.444602723963776e-05, "train_loss": 2.6504098884582663, "train_task_loss": 2.421797823491428, "train_bpp_loss": 0.3133131659475805, "train_patch_loss": 0.9554594567903191, "train_token_loss": 0.6693732499343743, "train_fea_loss": 0.7969651087689743, "epoch": 152, "n_parameters": 144845568}
188
+ {"train_lr": 2.3343906382349e-05, "train_loss": 2.7114228718429447, "train_task_loss": 2.458482252813572, "train_bpp_loss": 0.34088637049106674, "train_patch_loss": 0.9697519682405843, "train_token_loss": 0.6721772563775523, "train_fea_loss": 0.8165530218369812, "epoch": 153, "n_parameters": 144845568}
189
+ {"train_lr": 2.3343906382349e-05, "train_loss": 2.650369150729345, "train_task_loss": 2.4105861190197277, "train_bpp_loss": 0.31759558117231557, "train_patch_loss": 0.9528444103351671, "train_token_loss": 0.6638412714524896, "train_fea_loss": 0.7939004297111388, "epoch": 154, "n_parameters": 144845568}
190
+ {"train_lr": 2.280680768143689e-05, "train_loss": 2.658987253216459, "train_task_loss": 2.421917999400009, "train_bpp_loss": 0.31114791267567116, "train_patch_loss": 0.9561528907575124, "train_token_loss": 0.6675511025901821, "train_fea_loss": 0.7982139991355135, "epoch": 155, "n_parameters": 144845568}
191
+ {"train_lr": 2.2279194262997928e-05, "train_loss": 2.695716130978269, "train_task_loss": 2.4542629720031215, "train_bpp_loss": 0.31765079001362306, "train_patch_loss": 0.9665458119615055, "train_token_loss": 0.6746455496869428, "train_fea_loss": 0.8130716029219598, "epoch": 156, "n_parameters": 144845568}
192
+ {"train_lr": 2.1761196307742086e-05, "train_loss": 2.6704102983202436, "train_task_loss": 2.4334238473442102, "train_bpp_loss": 0.30996573550349693, "train_patch_loss": 0.9596143975134769, "train_token_loss": 0.6693641086682963, "train_fea_loss": 0.8044453353511629, "epoch": 157, "n_parameters": 144845568}
193
+ {"train_lr": 2.1252941623912577e-05, "train_loss": 2.6702361341735585, "train_task_loss": 2.4332374961917207, "train_bpp_loss": 0.3099473466777312, "train_patch_loss": 0.9602035106019043, "train_token_loss": 0.6687574936278003, "train_fea_loss": 0.8042764851122165, "epoch": 158, "n_parameters": 144845568}
194
+ {"train_lr": 2.0754555615745688e-05, "train_loss": 2.6676239087480864, "train_task_loss": 2.430506248768571, "train_bpp_loss": 0.31043111976030635, "train_patch_loss": 0.9591814159645701, "train_token_loss": 0.6682900374808864, "train_fea_loss": 0.8030347887748223, "epoch": 159, "n_parameters": 144845568}
195
+ {"train_lr": 2.0266161252534863e-05, "train_loss": 2.682328796143726, "train_task_loss": 2.4449838221484095, "train_bpp_loss": 0.3097276722284244, "train_patch_loss": 0.9639880869501858, "train_token_loss": 0.6714682123563487, "train_fea_loss": 0.8095275162444483, "epoch": 160, "n_parameters": 144845568}
196
+ {"train_lr": 1.9787879038283694e-05, "train_loss": 2.794825274291799, "train_task_loss": 2.5297502170381643, "train_bpp_loss": 0.3605138300236859, "train_patch_loss": 0.9950330435751761, "train_token_loss": 0.6891800253216526, "train_fea_loss": 0.8455371403780808, "epoch": 161, "n_parameters": 144845568}
197
+ {"train_lr": 1.9319826981968032e-05, "train_loss": 2.6933876661052234, "train_task_loss": 2.4523086957990836, "train_bpp_loss": 0.3166517836579995, "train_patch_loss": 0.9668755698480087, "train_token_loss": 0.6705846348107397, "train_fea_loss": 0.8148484852687596, "epoch": 162, "n_parameters": 144845568}
198
+ {"train_lr": 1.8862120568428674e-05, "train_loss": 2.7029373263426537, "train_task_loss": 2.4646165441873547, "train_bpp_loss": 0.31019026768110175, "train_patch_loss": 0.9700850130145927, "train_token_loss": 0.6760256816487875, "train_fea_loss": 0.8185058429220812, "epoch": 163, "n_parameters": 144845568}
199
+ {"train_lr": 1.8414872729877464e-05, "train_loss": 2.7035999098252192, "train_task_loss": 2.4655938773311012, "train_bpp_loss": 0.30946591779214444, "train_patch_loss": 0.9705398866317148, "train_token_loss": 0.6763441974645574, "train_fea_loss": 0.8187097842231107, "epoch": 164, "n_parameters": 144845568}
200
+ {"train_lr": 1.9319826981970397e-05, "train_loss": 2.740814351822904, "train_task_loss": 2.492047573066801, "train_bpp_loss": 0.329818556624592, "train_patch_loss": 0.9823994228093744, "train_token_loss": 0.6829682969436836, "train_fea_loss": 0.8266798452512144, "epoch": 161, "n_parameters": 144845568}
201
+ {"train_lr": 1.9319826981970397e-05, "train_loss": 2.7282523382946455, "train_task_loss": 2.48745193989943, "train_bpp_loss": 0.3139055231898773, "train_patch_loss": 0.9767762538854131, "train_token_loss": 0.6849351603363594, "train_fea_loss": 0.8257405170858156, "epoch": 162, "n_parameters": 144845568}
202
+ {"train_lr": 1.8862120568426702e-05, "train_loss": 2.7373931265259674, "train_task_loss": 2.493901286015122, "train_bpp_loss": 0.3186548652218954, "train_patch_loss": 0.9806770237019773, "train_token_loss": 0.6831493455675437, "train_fea_loss": 0.8300749089220445, "epoch": 163, "n_parameters": 144845568}
203
+ {"train_lr": 1.84148727298801e-05, "train_loss": 2.7362954941370505, "train_task_loss": 2.4902694111319184, "train_bpp_loss": 0.3243905476437395, "train_patch_loss": 0.9797143076258383, "train_token_loss": 0.6838951857773949, "train_fea_loss": 0.8266599099073622, "epoch": 164, "n_parameters": 144845568}
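Each record in this log.txt hunk is one JSON object per line, and the fields decompose cleanly: train_task_loss equals train_patch_loss + train_token_loss + train_fea_loss exactly, while train_loss adds a weighted train_bpp_loss on top (the implied rate weight differs between the restarted runs, so it was apparently retuned). A minimal sketch for loading and plotting these curves — the file path and the plotting details are illustrative, not part of the repo:

```python
import json
import matplotlib.pyplot as plt

# Path assumed from this diff; adjust to wherever log.txt lives locally.
records = []
with open("1_feature_extractor/log/DINOv2_training/log.txt") as f:
    for line in f:
        line = line.strip()
        if line.startswith("{"):  # guard against blank or partial lines
            records.append(json.loads(line))

epochs = [r["epoch"] for r in records]
for key in ("train_loss", "train_task_loss", "train_bpp_loss"):
    plt.plot(epochs, [r[key] for r in records], label=key)
plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.show()
```

Note that duplicate epoch numbers (e.g. epochs 91–100 and 161–164 each appear twice) mark resumed runs, so a faithful plot should keep only the last record per epoch.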
1_feature_extractor/log/DINOv2_training/log/20240725_001002.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240725_084736.log ADDED
@@ -0,0 +1,555 @@
1
+ 2024-07-25 08:47:36,358 [INFO ] Logging file is /data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log//20240725_084736.log
2
+ 2024-07-25 08:47:36,359 [INFO ] job dir: /data0/qiyp/Proteus-pytorch/pretrain
3
+ 2024-07-25 08:47:36,359 [INFO ] Namespace(batch_size=48,
4
+ epochs=200,
5
+ bce_loss=False,
6
+ unscale_lr=False,
7
+ model='models_proteus_dinov2',
8
+ target_model='vit_base',
9
+ input_size=224,
10
+ drop=0.0,
11
+ drop_path=0.1,
12
+ model_ema=True,
13
+ model_ema_decay=0.99996,
14
+ model_ema_force_cpu=False,
15
+ opt='adamw',
16
+ opt_eps=1e-08,
17
+ opt_betas=None,
18
+ clip_grad=None,
19
+ momentum=0.9,
20
+ weight_decay=0.05,
21
+ sched='cosine',
22
+ lr=0.0002,
23
+ lr_noise=None,
24
+ lr_noise_pct=0.67,
25
+ lr_noise_std=1.0,
26
+ warmup_lr=1e-06,
27
+ min_lr=1e-05,
28
+ decay_epochs=30,
29
+ warmup_epochs=5,
30
+ cooldown_epochs=10,
31
+ patience_epochs=10,
32
+ decay_rate=0.1,
33
+ color_jitter=0.3,
34
+ aa='rand-m9-mstd0.5-inc1',
35
+ smoothing=0.1,
36
+ train_interpolation='bicubic',
37
+ repeated_aug=True,
38
+ train_mode=True,
39
+ ThreeAugment=False,
40
+ src=False,
41
+ global_crops_size=224,
42
+ patch_size=14,
43
+ mask_ratio=[0.5],
44
+ mask_probability=0.5,
45
+ mask_first_n=True,
46
+ clone_batch=1,
47
+ reprob=0.25,
48
+ remode='pixel',
49
+ recount=1,
50
+ resplit=False,
51
+ mixup=0.8,
52
+ cutmix=1.0,
53
+ cutmix_minmax=None,
54
+ mixup_prob=1.0,
55
+ mixup_switch_prob=0.5,
56
+ mixup_mode='batch',
57
+ teacher_model='vit_large',
58
+ teacher_path='',
59
+ distillation_type='none',
60
+ distillation_alpha=0.5,
61
+ distillation_tau=1.0,
62
+ lambda_token=1.0,
63
+ lambda_fea=1.0,
64
+ lambda_patch=1.0,
65
+ cosub=False,
66
+ finetune='/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth',
67
+ attn_only=False,
68
+ weight_inherit='',
69
+ data_path='/data1/datasets/imagenet_fold',
70
+ data_set='IMNET',
71
+ inat_category='name',
72
+ output_dir='log/DINOv2_training',
73
+ log_dir='/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log/',
74
+ device='cuda',
75
+ seed=0,
76
+ resume='',
77
+ start_epoch=0,
78
+ eval=False,
79
+ eval_crop_ratio=0.875,
80
+ dist_eval=False,
81
+ num_workers=10,
82
+ pin_mem=True,
83
+ distributed=True,
84
+ world_size=6,
85
+ dist_url='env://',
86
+ rank=0,
87
+ gpu=0,
88
+ dist_backend='nccl')
89
+
90
+ 2024-07-25 08:47:38,916 [INFO ] Dataset ImageFolder
91
+ Number of datapoints: 1281167
92
+ Root location: /data1/datasets/imagenet_fold/train
93
+ StandardTransform
94
+ Transform: Compose(
95
+ RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
96
+ RandomHorizontalFlip(p=0.5)
97
+ RandAugment(n=2, ops=
98
+ AugmentOp(name=AutoContrast, p=0.5, m=9, mstd=0.5)
99
+ AugmentOp(name=Equalize, p=0.5, m=9, mstd=0.5)
100
+ AugmentOp(name=Invert, p=0.5, m=9, mstd=0.5)
101
+ AugmentOp(name=Rotate, p=0.5, m=9, mstd=0.5)
102
+ AugmentOp(name=PosterizeIncreasing, p=0.5, m=9, mstd=0.5)
103
+ AugmentOp(name=SolarizeIncreasing, p=0.5, m=9, mstd=0.5)
104
+ AugmentOp(name=SolarizeAdd, p=0.5, m=9, mstd=0.5)
105
+ AugmentOp(name=ColorIncreasing, p=0.5, m=9, mstd=0.5)
106
+ AugmentOp(name=ContrastIncreasing, p=0.5, m=9, mstd=0.5)
107
+ AugmentOp(name=BrightnessIncreasing, p=0.5, m=9, mstd=0.5)
108
+ AugmentOp(name=SharpnessIncreasing, p=0.5, m=9, mstd=0.5)
109
+ AugmentOp(name=ShearX, p=0.5, m=9, mstd=0.5)
110
+ AugmentOp(name=ShearY, p=0.5, m=9, mstd=0.5)
111
+ AugmentOp(name=TranslateXRel, p=0.5, m=9, mstd=0.5)
112
+ AugmentOp(name=TranslateYRel, p=0.5, m=9, mstd=0.5))
113
+ ToTensor()
114
+ Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
115
+ RandomErasing(p=0.25, mode=pixel, count=(1, 1))
116
+ )
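The transform stack printed above (RandomResizedCropAndInterpolation, RandAugment with rand-m9-mstd0.5-inc1, Normalize, RandomErasing) is timm's standard training pipeline, and its parameters match the augmentation fields in the Namespace dump. A sketch of building the same Compose via timm's create_transform factory — the factory call is an assumption inferred from the printed class names, not quoted from this repo:

```python
from timm.data import create_transform

# Values mirror the logged args: input_size=224, color_jitter=0.3,
# aa='rand-m9-mstd0.5-inc1', train_interpolation='bicubic',
# reprob=0.25, remode='pixel', recount=1.
transform = create_transform(
    input_size=224,
    is_training=True,
    color_jitter=0.3,
    auto_augment="rand-m9-mstd0.5-inc1",
    interpolation="bicubic",
    re_prob=0.25,
    re_mode="pixel",
    re_count=1,
)
print(transform)  # prints a Compose like the one logged above
```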
117
+ 2024-07-25 08:47:38,926 [INFO ] Sampler_train = <samplers.RASampler object at 0x7fc548b7f3a0>
118
+ 2024-07-25 08:47:38,928 [INFO ] Created a temporary directory at /tmp/tmpe0qnk5ic
119
+ 2024-07-25 08:47:38,929 [INFO ] Writing /tmp/tmpe0qnk5ic/_remote_module_non_scriptable.py
120
+ 2024-07-25 08:47:40,640 [INFO ] using MLP layer as FFN
121
+ 2024-07-25 08:49:51,514 [INFO ] using MLP layer as FFN
122
+ 2024-07-25 08:49:57,349 [INFO ] Model = MetaArch(
123
+ (student): ModuleDict(
124
+ (backbone): DinoVisionTransformer(
125
+ (patch_embed): PatchEmbed(
126
+ (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
127
+ (norm): Identity()
128
+ )
129
+ (blocks): ModuleList(
130
+ (0-11): 12 x Block(
131
+ (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
132
+ (attn): MemEffAttention(
133
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
134
+ (attn_drop): Dropout(p=0.0, inplace=False)
135
+ (proj): Linear(in_features=768, out_features=768, bias=True)
136
+ (proj_drop): Dropout(p=0.0, inplace=False)
137
+ )
138
+ (ls1): LayerScale()
139
+ (drop_path1): Identity()
140
+ (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
141
+ (mlp): Mlp(
142
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
143
+ (act): GELU(approximate='none')
144
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
145
+ (drop): Dropout(p=0.0, inplace=False)
146
+ )
147
+ (ls2): LayerScale()
148
+ (drop_path2): Identity()
149
+ )
150
+ )
151
+ (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
152
+ (head): Identity()
153
+ )
154
+ )
155
+ (teacher): ModuleDict(
156
+ (backbone): DinoVisionTransformer(
157
+ (patch_embed): PatchEmbed(
158
+ (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
159
+ (norm): Identity()
160
+ )
161
+ (blocks): ModuleList(
162
+ (0-23): 24 x NestedTensorBlock(
163
+ (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
164
+ (attn): MemEffAttention(
165
+ (qkv): Linear(in_features=1024, out_features=3072, bias=True)
166
+ (attn_drop): Dropout(p=0.0, inplace=False)
167
+ (proj): Linear(in_features=1024, out_features=1024, bias=True)
168
+ (proj_drop): Dropout(p=0.0, inplace=False)
169
+ )
170
+ (ls1): LayerScale()
171
+ (drop_path1): Identity()
172
+ (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
173
+ (mlp): Mlp(
174
+ (fc1): Linear(in_features=1024, out_features=4096, bias=True)
175
+ (act): GELU(approximate='none')
176
+ (fc2): Linear(in_features=4096, out_features=1024, bias=True)
177
+ (drop): Dropout(p=0.0, inplace=False)
178
+ )
179
+ (ls2): LayerScale()
180
+ (drop_path2): Identity()
181
+ )
182
+ )
183
+ (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
184
+ (head): Identity()
185
+ )
186
+ )
187
+ (patch_head): Sequential(
188
+ (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
189
+ (1): Linear(in_features=768, out_features=1024, bias=True)
190
+ )
191
+ (token_head): Sequential(
192
+ (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
193
+ (1): Linear(in_features=768, out_features=1024, bias=True)
194
+ )
195
+ (fea_head): Sequential(
196
+ (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
197
+ (1): Linear(in_features=768, out_features=1024, bias=True)
198
+ )
199
+ (soft_criterion): MSELoss()
200
+ (info_bottleneck): IF_Module(
201
+ (encoder_blocks): ModuleList(
202
+ (0-3): 4 x Block(
203
+ (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
204
+ (attn): Attention(
205
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
206
+ (q_norm): Identity()
207
+ (k_norm): Identity()
208
+ (attn_drop): Dropout(p=0.0, inplace=False)
209
+ (proj): Linear(in_features=768, out_features=768, bias=True)
210
+ (proj_drop): Dropout(p=0.0, inplace=False)
211
+ )
212
+ (ls1): Identity()
213
+ (drop_path1): Identity()
214
+ (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
215
+ (mlp): Mlp(
216
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
217
+ (act): GELU(approximate='none')
218
+ (drop1): Dropout(p=0.0, inplace=False)
219
+ (norm): Identity()
220
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
221
+ (drop2): Dropout(p=0.0, inplace=False)
222
+ )
223
+ (ls2): Identity()
224
+ (drop_path2): Identity()
225
+ )
226
+ )
227
+ (encoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
228
+ (decoder_blocks): ModuleList(
229
+ (0-3): 4 x Block(
230
+ (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
231
+ (attn): Attention(
232
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
233
+ (q_norm): Identity()
234
+ (k_norm): Identity()
235
+ (attn_drop): Dropout(p=0.0, inplace=False)
236
+ (proj): Linear(in_features=768, out_features=768, bias=True)
237
+ (proj_drop): Dropout(p=0.0, inplace=False)
238
+ )
239
+ (ls1): Identity()
240
+ (drop_path1): Identity()
241
+ (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
242
+ (mlp): Mlp(
243
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
244
+ (act): GELU(approximate='none')
245
+ (drop1): Dropout(p=0.0, inplace=False)
246
+ (norm): Identity()
247
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
248
+ (drop2): Dropout(p=0.0, inplace=False)
249
+ )
250
+ (ls2): Identity()
251
+ (drop_path2): Identity()
252
+ )
253
+ )
254
+ (decoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
255
+ (entropy_bottleneck): EntropyBottleneck(
256
+ (likelihood_lower_bound): LowerBound()
257
+ )
258
+ )
259
+ )
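The EntropyBottleneck printed inside IF_Module — with its _matrix0.._matrix4, _bias, _factor, quantiles and likelihood_lower_bound buffers — matches CompressAI's compressai.entropy_models.EntropyBottleneck with the default (3, 3, 3, 3) filters. A sketch of the rate term such a module yields; the 4-D reshape and the per-token normalization are assumptions, not the repo's exact code:

```python
import torch
from compressai.entropy_models import EntropyBottleneck

dim = 768                      # token width of the ViT-B student above
eb = EntropyBottleneck(dim)

tokens = torch.randn(8, 257, dim)          # (B, N, C) dummy features
y = tokens.permute(0, 2, 1).unsqueeze(-1)  # EntropyBottleneck expects channels at dim 1
y_hat, likelihoods = eb(y)                 # quantize (additive noise in training) + p(y_hat)

# Bits needed to code the features, normalized per token here (assumption).
bpp_loss = (-torch.log2(likelihoods)).sum() / (tokens.shape[0] * tokens.shape[1])
aux_loss = eb.loss()                       # auxiliary loss that fits the learned CDF quantiles
```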
260
+ 2024-07-25 08:49:59,702 [INFO ] Finetuning from /data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth
261
+ 2024-07-25 08:49:59,703 [INFO ] missing_keys: ['patch_head.0.weight', 'patch_head.0.bias', 'patch_head.1.weight', 'patch_head.1.bias', 'info_bottleneck.encoder_blocks.0.norm1.weight', 'info_bottleneck.encoder_blocks.0.norm1.bias', 'info_bottleneck.encoder_blocks.0.attn.qkv.weight', 'info_bottleneck.encoder_blocks.0.attn.qkv.bias', 'info_bottleneck.encoder_blocks.0.attn.proj.weight', 'info_bottleneck.encoder_blocks.0.attn.proj.bias', 'info_bottleneck.encoder_blocks.0.norm2.weight', 'info_bottleneck.encoder_blocks.0.norm2.bias', 'info_bottleneck.encoder_blocks.0.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.0.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.0.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.0.mlp.fc2.bias', 'info_bottleneck.encoder_blocks.1.norm1.weight', 'info_bottleneck.encoder_blocks.1.norm1.bias', 'info_bottleneck.encoder_blocks.1.attn.qkv.weight', 'info_bottleneck.encoder_blocks.1.attn.qkv.bias', 'info_bottleneck.encoder_blocks.1.attn.proj.weight', 'info_bottleneck.encoder_blocks.1.attn.proj.bias', 'info_bottleneck.encoder_blocks.1.norm2.weight', 'info_bottleneck.encoder_blocks.1.norm2.bias', 'info_bottleneck.encoder_blocks.1.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.1.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.1.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.1.mlp.fc2.bias', 'info_bottleneck.encoder_blocks.2.norm1.weight', 'info_bottleneck.encoder_blocks.2.norm1.bias', 'info_bottleneck.encoder_blocks.2.attn.qkv.weight', 'info_bottleneck.encoder_blocks.2.attn.qkv.bias', 'info_bottleneck.encoder_blocks.2.attn.proj.weight', 'info_bottleneck.encoder_blocks.2.attn.proj.bias', 'info_bottleneck.encoder_blocks.2.norm2.weight', 'info_bottleneck.encoder_blocks.2.norm2.bias', 'info_bottleneck.encoder_blocks.2.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.2.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.2.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.2.mlp.fc2.bias', 'info_bottleneck.encoder_blocks.3.norm1.weight', 'info_bottleneck.encoder_blocks.3.norm1.bias', 'info_bottleneck.encoder_blocks.3.attn.qkv.weight', 'info_bottleneck.encoder_blocks.3.attn.qkv.bias', 'info_bottleneck.encoder_blocks.3.attn.proj.weight', 'info_bottleneck.encoder_blocks.3.attn.proj.bias', 'info_bottleneck.encoder_blocks.3.norm2.weight', 'info_bottleneck.encoder_blocks.3.norm2.bias', 'info_bottleneck.encoder_blocks.3.mlp.fc1.weight', 'info_bottleneck.encoder_blocks.3.mlp.fc1.bias', 'info_bottleneck.encoder_blocks.3.mlp.fc2.weight', 'info_bottleneck.encoder_blocks.3.mlp.fc2.bias', 'info_bottleneck.encoder_norm.weight', 'info_bottleneck.encoder_norm.bias', 'info_bottleneck.decoder_blocks.0.norm1.weight', 'info_bottleneck.decoder_blocks.0.norm1.bias', 'info_bottleneck.decoder_blocks.0.attn.qkv.weight', 'info_bottleneck.decoder_blocks.0.attn.qkv.bias', 'info_bottleneck.decoder_blocks.0.attn.proj.weight', 'info_bottleneck.decoder_blocks.0.attn.proj.bias', 'info_bottleneck.decoder_blocks.0.norm2.weight', 'info_bottleneck.decoder_blocks.0.norm2.bias', 'info_bottleneck.decoder_blocks.0.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.0.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.0.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.0.mlp.fc2.bias', 'info_bottleneck.decoder_blocks.1.norm1.weight', 'info_bottleneck.decoder_blocks.1.norm1.bias', 'info_bottleneck.decoder_blocks.1.attn.qkv.weight', 'info_bottleneck.decoder_blocks.1.attn.qkv.bias', 'info_bottleneck.decoder_blocks.1.attn.proj.weight', 'info_bottleneck.decoder_blocks.1.attn.proj.bias', 'info_bottleneck.decoder_blocks.1.norm2.weight', 'info_bottleneck.decoder_blocks.1.norm2.bias', 'info_bottleneck.decoder_blocks.1.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.1.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.1.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.1.mlp.fc2.bias', 'info_bottleneck.decoder_blocks.2.norm1.weight', 'info_bottleneck.decoder_blocks.2.norm1.bias', 'info_bottleneck.decoder_blocks.2.attn.qkv.weight', 'info_bottleneck.decoder_blocks.2.attn.qkv.bias', 'info_bottleneck.decoder_blocks.2.attn.proj.weight', 'info_bottleneck.decoder_blocks.2.attn.proj.bias', 'info_bottleneck.decoder_blocks.2.norm2.weight', 'info_bottleneck.decoder_blocks.2.norm2.bias', 'info_bottleneck.decoder_blocks.2.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.2.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.2.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.2.mlp.fc2.bias', 'info_bottleneck.decoder_blocks.3.norm1.weight', 'info_bottleneck.decoder_blocks.3.norm1.bias', 'info_bottleneck.decoder_blocks.3.attn.qkv.weight', 'info_bottleneck.decoder_blocks.3.attn.qkv.bias', 'info_bottleneck.decoder_blocks.3.attn.proj.weight', 'info_bottleneck.decoder_blocks.3.attn.proj.bias', 'info_bottleneck.decoder_blocks.3.norm2.weight', 'info_bottleneck.decoder_blocks.3.norm2.bias', 'info_bottleneck.decoder_blocks.3.mlp.fc1.weight', 'info_bottleneck.decoder_blocks.3.mlp.fc1.bias', 'info_bottleneck.decoder_blocks.3.mlp.fc2.weight', 'info_bottleneck.decoder_blocks.3.mlp.fc2.bias', 'info_bottleneck.decoder_norm.weight', 'info_bottleneck.decoder_norm.bias', 'info_bottleneck.entropy_bottleneck._matrix0', 'info_bottleneck.entropy_bottleneck._bias0', 'info_bottleneck.entropy_bottleneck._factor0', 'info_bottleneck.entropy_bottleneck._matrix1', 'info_bottleneck.entropy_bottleneck._bias1', 'info_bottleneck.entropy_bottleneck._factor1', 'info_bottleneck.entropy_bottleneck._matrix2', 'info_bottleneck.entropy_bottleneck._bias2', 'info_bottleneck.entropy_bottleneck._factor2', 'info_bottleneck.entropy_bottleneck._matrix3', 'info_bottleneck.entropy_bottleneck._bias3', 'info_bottleneck.entropy_bottleneck._factor3', 'info_bottleneck.entropy_bottleneck._matrix4', 'info_bottleneck.entropy_bottleneck._bias4', 'info_bottleneck.entropy_bottleneck.quantiles', 'info_bottleneck.entropy_bottleneck._offset', 'info_bottleneck.entropy_bottleneck._quantized_cdf', 'info_bottleneck.entropy_bottleneck._cdf_length', 'info_bottleneck.entropy_bottleneck.target', 'info_bottleneck.entropy_bottleneck.likelihood_lower_bound.bound']
262
+ 2024-07-25 08:49:59,703 [INFO ] unexpected_keys: ['ibot_head.0.weight', 'ibot_head.0.bias', 'ibot_head.1.weight', 'ibot_head.1.bias']
263
+ 2024-07-25 08:50:00,636 [INFO ] number of params: 144845568
264
+ 2024-07-25 08:50:00,636 [INFO ] base lr = 0.0002
265
+ 2024-07-25 08:50:00,636 [INFO ] actural lr = 0.00011250000000000001
266
+ 2024-07-25 08:50:00,639 [INFO ] Start training for 200 epochs
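The logged "actural lr" is consistent with DeiT-style linear scaling of the base rate by the global batch size over 512, using batch_size=48 and world_size=6 from the Namespace above; the /512 divisor is inferred from the exact numerical match, not printed in this log:

```python
base_lr, batch_size, world_size = 2e-4, 48, 6
actual_lr = base_lr * (batch_size * world_size) / 512
print(actual_lr)  # 0.0001125 -- exactly the value logged above
```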
267
+ 2024-07-25 08:50:00,642 [INFO ] Parameters to be updated:
268
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.norm2.weight
269
+
270
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.2.norm2.weight
271
+
272
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.proj.bias
273
+
274
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.qkv.weight
275
+
276
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.norm2.bias
277
+
278
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.norm2.bias
279
+
280
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc2.weight
281
+
282
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._bias0
283
+
284
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.proj.bias
285
+
286
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc2.bias
287
+
288
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.0.norm1.bias
289
+
290
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix0
291
+
292
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._factor2
293
+
294
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.proj.bias
295
+
296
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.proj.bias
297
+
298
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.norm1.weight
299
+
300
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc2.bias
301
+
302
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix4
303
+
304
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc2.bias
305
+
306
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.3.norm1.weight
307
+
308
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc2.bias
309
+
310
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.0.norm1.bias
311
+
312
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.norm2.bias
313
+
314
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.3.norm1.bias
315
+
316
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.entropy_bottleneck._bias2
317
+
318
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.qkv.weight
319
+
320
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.qkv.bias
321
+
322
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.norm1.bias
323
+
324
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.qkv.weight
325
+
326
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.0.norm2.weight
327
+
328
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.proj.weight
329
+
330
+ 2024-07-25 08:50:00,642 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc2.weight
331
+
332
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.proj.weight
333
+
334
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc2.weight
335
+
336
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.norm1.weight
337
+
338
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc1.bias
339
+
340
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck._bias1
341
+
342
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.norm2.bias
343
+
344
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.qkv.weight
345
+
346
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.2.norm2.bias
347
+
348
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.qkv.bias
349
+
350
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.norm1.bias
351
+
352
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc1.bias
353
+
354
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc1.bias
355
+
356
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc1.weight
357
+
358
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck._bias3
359
+
360
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.norm2.weight
361
+
362
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc1.bias
363
+
364
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.1.norm2.weight
365
+
366
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.norm1.weight
367
+
368
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.qkv.bias
369
+
370
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.3.norm2.weight
371
+
372
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc2.bias
373
+
374
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc1.weight
375
+
376
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.proj.weight
377
+
378
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_norm.bias
379
+
380
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_norm.weight
381
+
382
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc2.weight
383
+
384
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.norm2.weight
385
+
386
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.1.norm1.weight
387
+
388
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck.quantiles
389
+
390
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix2
391
+
392
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.proj.weight
393
+
394
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.qkv.bias
395
+
396
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc1.weight
397
+
398
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.proj.weight
399
+
400
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.3.norm2.bias
401
+
402
+ 2024-07-25 08:50:00,643 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc2.bias
403
+
404
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc2.bias
405
+
406
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.proj.bias
407
+
408
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.qkv.weight
409
+
410
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix1
411
+
412
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.proj.weight
413
+
414
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.qkv.bias
415
+
416
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc2.weight
417
+
418
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc2.bias
419
+
420
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.1.norm1.weight
421
+
422
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.proj.weight
423
+
424
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc2.weight
425
+
426
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.2.mlp.fc1.bias
427
+
428
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.1.mlp.fc1.weight
429
+
430
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.norm1.weight
431
+
432
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.3.attn.qkv.weight
433
+
434
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.3.norm1.bias
435
+
436
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.attn.qkv.bias
437
+
438
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.2.attn.proj.bias
439
+
440
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.2.norm1.bias
441
+
442
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc2.weight
443
+
444
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.1.mlp.fc1.weight
445
+
446
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.mlp.fc1.weight
447
+
448
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.3.attn.qkv.weight
449
+
450
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._matrix3
451
+
452
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_norm.bias
453
+
454
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.1.norm2.weight
455
+
456
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc2.weight
457
+
458
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.mlp.fc1.bias
459
+
460
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.proj.bias
461
+
462
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._factor0
463
+
464
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.proj.weight
465
+
466
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._bias4
467
+
468
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.entropy_bottleneck._factor3
469
+
470
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.encoder_blocks.0.mlp.fc1.weight
471
+
472
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.2.norm1.weight
473
+
474
+ 2024-07-25 08:50:00,644 [INFO ] module.info_bottleneck.decoder_blocks.3.norm2.bias
475
+
476
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_norm.weight
477
+
478
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.0.norm2.bias
479
+
480
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.1.attn.qkv.bias
481
+
482
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.2.norm1.bias
483
+
484
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.0.attn.qkv.weight
485
+
486
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.encoder_blocks.3.mlp.fc1.bias
487
+
488
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc1.weight
489
+
490
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.encoder_blocks.0.attn.proj.bias
491
+
492
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.encoder_blocks.1.attn.qkv.bias
493
+
494
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.decoder_blocks.3.mlp.fc1.bias
495
+
496
+ 2024-07-25 08:50:00,645 [INFO ] module.info_bottleneck.entropy_bottleneck._factor1
497
+
498
+ 2024-07-25 08:50:00,645 [INFO ]
499
+
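Every name in the "Parameters to be updated" list above lives under module.info_bottleneck, so in this stage the backbone, projection heads, and teacher are frozen and only the bottleneck trains. A sketch of that selection — `model` is assumed to be the MetaArch instance built earlier, and the string match follows the printed module tree rather than the repo's exact code:

```python
# Freeze everything except the information-bottleneck module (sketch).
for name, param in model.named_parameters():
    param.requires_grad = "info_bottleneck" in name

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} tensors to be updated")  # mirrors the list logged above
```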
500
+ 2024-07-25 08:50:06,715 [INFO ] Epoch: [0] [ 0/4448] eta: 7:29:47 lr: 0.000001 loss: 8.3333 (8.3333) task_loss: 5.5091 (5.5091) bpp_loss: 28.2426 (28.2426) patch_loss: 2.4117 (2.4117) token_loss: 0.6234 (0.6234) fea_loss: 2.4741 (2.4741) time: 6.0673 data: 2.4582 max mem: 10534
501
+ 2024-07-25 08:50:16,045 [INFO ] Epoch: [0] [ 10/4448] eta: 1:43:31 lr: 0.000001 loss: 8.3085 (8.3065) task_loss: 5.4842 (5.4822) bpp_loss: 28.2425 (28.2426) patch_loss: 2.4007 (2.4001) token_loss: 0.6673 (0.6727) fea_loss: 2.4212 (2.4095) time: 1.3996 data: 0.2236 max mem: 11190
502
+ 2024-07-25 08:50:25,369 [INFO ] Epoch: [0] [ 20/4448] eta: 1:26:52 lr: 0.000001 loss: 8.2112 (8.2280) task_loss: 5.3869 (5.4038) bpp_loss: 28.2425 (28.2427) patch_loss: 2.3943 (2.3977) token_loss: 0.6668 (0.6682) fea_loss: 2.3140 (2.3378) time: 0.9326 data: 0.0001 max mem: 11191
503
+ 2024-07-25 08:50:34,728 [INFO ] Epoch: [0] [ 30/4448] eta: 1:20:56 lr: 0.000001 loss: 8.0968 (8.1642) task_loss: 5.2725 (5.3399) bpp_loss: 28.2428 (28.2427) patch_loss: 2.3875 (2.3914) token_loss: 0.6681 (0.6704) fea_loss: 2.2053 (2.2782) time: 0.9341 data: 0.0001 max mem: 11191
504
+ 2024-07-25 08:50:44,122 [INFO ] Epoch: [0] [ 40/4448] eta: 1:17:53 lr: 0.000001 loss: 7.9714 (8.1093) task_loss: 5.1472 (5.2851) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3728 (2.3870) token_loss: 0.6831 (0.6721) fea_loss: 2.0939 (2.2259) time: 0.9375 data: 0.0001 max mem: 11191
505
+ 2024-07-25 08:50:53,546 [INFO ] Epoch: [0] [ 50/4448] eta: 1:16:01 lr: 0.000001 loss: 7.8988 (8.0590) task_loss: 5.0745 (5.2347) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3717 (2.3832) token_loss: 0.6707 (0.6729) fea_loss: 2.0184 (2.1786) time: 0.9408 data: 0.0001 max mem: 11192
506
+ 2024-07-25 08:51:02,993 [INFO ] Epoch: [0] [ 60/4448] eta: 1:14:44 lr: 0.000001 loss: 7.8305 (8.0147) task_loss: 5.0062 (5.1904) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3574 (2.3787) token_loss: 0.6676 (0.6735) fea_loss: 1.9486 (2.1383) time: 0.9435 data: 0.0001 max mem: 11192
507
+ 2024-07-25 08:51:12,484 [INFO ] Epoch: [0] [ 70/4448] eta: 1:13:49 lr: 0.000001 loss: 7.7545 (7.9697) task_loss: 4.9303 (5.1454) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3489 (2.3734) token_loss: 0.6659 (0.6675) fea_loss: 1.9142 (2.1045) time: 0.9469 data: 0.0001 max mem: 11192
508
+ 2024-07-25 08:51:21,991 [INFO ] Epoch: [0] [ 80/4448] eta: 1:13:06 lr: 0.000001 loss: 7.6890 (7.9328) task_loss: 4.8647 (5.1085) bpp_loss: 28.2428 (28.2427) patch_loss: 2.3404 (2.3694) token_loss: 0.6583 (0.6661) fea_loss: 1.8710 (2.0731) time: 0.9498 data: 0.0001 max mem: 11194
509
+ 2024-07-25 08:51:31,508 [INFO ] Epoch: [0] [ 90/4448] eta: 1:12:30 lr: 0.000001 loss: 7.6696 (7.9028) task_loss: 4.8453 (5.0785) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3398 (2.3655) token_loss: 0.6613 (0.6665) fea_loss: 1.8399 (2.0465) time: 0.9511 data: 0.0001 max mem: 11194
510
+ 2024-07-25 08:51:41,113 [INFO ] Epoch: [0] [ 100/4448] eta: 1:12:04 lr: 0.000001 loss: 7.6273 (7.8742) task_loss: 4.8031 (5.0499) bpp_loss: 28.2425 (28.2427) patch_loss: 2.3329 (2.3620) token_loss: 0.6699 (0.6671) fea_loss: 1.8089 (2.0209) time: 0.9560 data: 0.0001 max mem: 11194
511
+ 2024-07-25 08:51:50,804 [INFO ] Epoch: [0] [ 110/4448] eta: 1:11:44 lr: 0.000001 loss: 7.5991 (7.8487) task_loss: 4.7749 (5.0244) bpp_loss: 28.2424 (28.2427) patch_loss: 2.3215 (2.3584) token_loss: 0.6780 (0.6676) fea_loss: 1.7759 (1.9984) time: 0.9647 data: 0.0001 max mem: 11194
512
+ 2024-07-25 08:52:00,570 [INFO ] Epoch: [0] [ 120/4448] eta: 1:11:29 lr: 0.000001 loss: 7.5780 (7.8257) task_loss: 4.7538 (5.0014) bpp_loss: 28.2425 (28.2427) patch_loss: 2.3218 (2.3558) token_loss: 0.6774 (0.6677) fea_loss: 1.7559 (1.9780) time: 0.9728 data: 0.0001 max mem: 11194
513
+ 2024-07-25 08:52:10,528 [INFO ] Epoch: [0] [ 130/4448] eta: 1:11:20 lr: 0.000001 loss: 7.5670 (7.8045) task_loss: 4.7427 (4.9802) bpp_loss: 28.2426 (28.2427) patch_loss: 2.3239 (2.3528) token_loss: 0.6669 (0.6682) fea_loss: 1.7497 (1.9592) time: 0.9861 data: 0.0001 max mem: 11194
514
+ 2024-07-25 08:52:20,550 [INFO ] Epoch: [0] [ 140/4448] eta: 1:11:14 lr: 0.000001 loss: 7.5367 (7.7836) task_loss: 4.7125 (4.9594) bpp_loss: 28.2423 (28.2427) patch_loss: 2.3120 (2.3499) token_loss: 0.6791 (0.6681) fea_loss: 1.7283 (1.9414) time: 0.9989 data: 0.0001 max mem: 11194
515
+ 2024-07-25 08:52:30,565 [INFO ] Epoch: [0] [ 150/4448] eta: 1:11:06 lr: 0.000001 loss: 7.4978 (7.7638) task_loss: 4.6735 (4.9396) bpp_loss: 28.2423 (28.2426) patch_loss: 2.3131 (2.3477) token_loss: 0.6641 (0.6675) fea_loss: 1.6953 (1.9244) time: 1.0018 data: 0.0001 max mem: 11194
516
+ 2024-07-25 08:52:40,399 [INFO ] Epoch: [0] [ 160/4448] eta: 1:10:54 lr: 0.000001 loss: 7.4593 (7.7455) task_loss: 4.6350 (4.9213) bpp_loss: 28.2423 (28.2426) patch_loss: 2.3100 (2.3450) token_loss: 0.6513 (0.6671) fea_loss: 1.6766 (1.9092) time: 0.9924 data: 0.0001 max mem: 11194
517
+ 2024-07-25 08:52:50,157 [INFO ] Epoch: [0] [ 170/4448] eta: 1:10:40 lr: 0.000001 loss: 7.4549 (7.7282) task_loss: 4.6306 (4.9040) bpp_loss: 28.2425 (28.2426) patch_loss: 2.3005 (2.3425) token_loss: 0.6528 (0.6669) fea_loss: 1.6654 (1.8945) time: 0.9795 data: 0.0001 max mem: 11194
518
+ 2024-07-25 08:52:59,911 [INFO ] Epoch: [0] [ 180/4448] eta: 1:10:26 lr: 0.000001 loss: 7.4523 (7.7123) task_loss: 4.6281 (4.8880) bpp_loss: 28.2424 (28.2426) patch_loss: 2.3051 (2.3407) token_loss: 0.6546 (0.6663) fea_loss: 1.6590 (1.8810) time: 0.9755 data: 0.0001 max mem: 11194
519
+ 2024-07-25 08:53:09,637 [INFO ] Epoch: [0] [ 190/4448] eta: 1:10:12 lr: 0.000001 loss: 7.4077 (7.6963) task_loss: 4.5834 (4.8720) bpp_loss: 28.2424 (28.2426) patch_loss: 2.3067 (2.3386) token_loss: 0.6544 (0.6660) fea_loss: 1.6313 (1.8674) time: 0.9739 data: 0.0001 max mem: 11194
520
+ 2024-07-25 08:53:19,360 [INFO ] Epoch: [0] [ 200/4448] eta: 1:09:59 lr: 0.000001 loss: 7.4077 (7.6829) task_loss: 4.5834 (4.8586) bpp_loss: 28.2423 (28.2426) patch_loss: 2.3058 (2.3366) token_loss: 0.6651 (0.6670) fea_loss: 1.6217 (1.8550) time: 0.9723 data: 0.0001 max mem: 11194
521
+ 2024-07-25 08:53:29,073 [INFO ] Epoch: [0] [ 210/4448] eta: 1:09:45 lr: 0.000001 loss: 7.4075 (7.6695) task_loss: 4.5833 (4.8453) bpp_loss: 28.2422 (28.2425) patch_loss: 2.2973 (2.3348) token_loss: 0.6806 (0.6674) fea_loss: 1.6116 (1.8431) time: 0.9717 data: 0.0001 max mem: 11195
522
+ 2024-07-25 08:53:38,788 [INFO ] Epoch: [0] [ 220/4448] eta: 1:09:32 lr: 0.000001 loss: 7.3916 (7.6572) task_loss: 4.5674 (4.8329) bpp_loss: 28.2419 (28.2425) patch_loss: 2.2990 (2.3332) token_loss: 0.6721 (0.6680) fea_loss: 1.5945 (1.8317) time: 0.9713 data: 0.0001 max mem: 11195
523
+ 2024-07-25 08:53:48,511 [INFO ] Epoch: [0] [ 230/4448] eta: 1:09:20 lr: 0.000001 loss: 7.3755 (7.6451) task_loss: 4.5514 (4.8209) bpp_loss: 28.2417 (28.2425) patch_loss: 2.2998 (2.3316) token_loss: 0.6759 (0.6679) fea_loss: 1.5874 (1.8213) time: 0.9718 data: 0.0001 max mem: 11195
524
+ 2024-07-25 08:53:58,215 [INFO ] Epoch: [0] [ 240/4448] eta: 1:09:07 lr: 0.000001 loss: 7.3750 (7.6335) task_loss: 4.5508 (4.8092) bpp_loss: 28.2416 (28.2425) patch_loss: 2.3008 (2.3302) token_loss: 0.6651 (0.6679) fea_loss: 1.5759 (1.8111) time: 0.9713 data: 0.0001 max mem: 11195
525
+ 2024-07-25 08:54:07,932 [INFO ] Epoch: [0] [ 250/4448] eta: 1:08:55 lr: 0.000001 loss: 7.3393 (7.6212) task_loss: 4.5152 (4.7970) bpp_loss: 28.2414 (28.2424) patch_loss: 2.2951 (2.3289) token_loss: 0.6563 (0.6671) fea_loss: 1.5655 (1.8010) time: 0.9710 data: 0.0001 max mem: 11195
526
+ 2024-07-25 08:54:17,651 [INFO ] Epoch: [0] [ 260/4448] eta: 1:08:43 lr: 0.000001 loss: 7.3268 (7.6096) task_loss: 4.5027 (4.7854) bpp_loss: 28.2413 (28.2424) patch_loss: 2.2931 (2.3274) token_loss: 0.6561 (0.6665) fea_loss: 1.5615 (1.7915) time: 0.9717 data: 0.0001 max mem: 11195
527
+ 2024-07-25 08:54:27,368 [INFO ] Epoch: [0] [ 270/4448] eta: 1:08:31 lr: 0.000001 loss: 7.3396 (7.5996) task_loss: 4.5156 (4.7754) bpp_loss: 28.2411 (28.2423) patch_loss: 2.2837 (2.3257) token_loss: 0.6719 (0.6672) fea_loss: 1.5560 (1.7825) time: 0.9717 data: 0.0001 max mem: 11195
528
+ 2024-07-25 08:54:37,070 [INFO ] Epoch: [0] [ 280/4448] eta: 1:08:19 lr: 0.000001 loss: 7.3154 (7.5896) task_loss: 4.4914 (4.7654) bpp_loss: 28.2404 (28.2422) patch_loss: 2.2837 (2.3246) token_loss: 0.6619 (0.6668) fea_loss: 1.5447 (1.7740) time: 0.9709 data: 0.0001 max mem: 11195
529
+ 2024-07-25 08:54:46,786 [INFO ] Epoch: [0] [ 290/4448] eta: 1:08:08 lr: 0.000001 loss: 7.3109 (7.5802) task_loss: 4.4868 (4.7560) bpp_loss: 28.2396 (28.2421) patch_loss: 2.2924 (2.3236) token_loss: 0.6524 (0.6669) fea_loss: 1.5375 (1.7655) time: 0.9708 data: 0.0001 max mem: 11195
530
+ 2024-07-25 08:54:56,541 [INFO ] Epoch: [0] [ 300/4448] eta: 1:07:57 lr: 0.000001 loss: 7.3166 (7.5713) task_loss: 4.4926 (4.7471) bpp_loss: 28.2385 (28.2420) patch_loss: 2.2950 (2.3225) token_loss: 0.6756 (0.6672) fea_loss: 1.5210 (1.7574) time: 0.9735 data: 0.0001 max mem: 11195
531
+ 2024-07-25 08:55:06,276 [INFO ] Epoch: [0] [ 310/4448] eta: 1:07:46 lr: 0.000001 loss: 7.3026 (7.5628) task_loss: 4.4788 (4.7386) bpp_loss: 28.2378 (28.2418) patch_loss: 2.2952 (2.3217) token_loss: 0.6699 (0.6673) fea_loss: 1.5175 (1.7496) time: 0.9744 data: 0.0001 max mem: 11195
532
+ 2024-07-25 08:55:15,984 [INFO ] Epoch: [0] [ 320/4448] eta: 1:07:34 lr: 0.000001 loss: 7.2847 (7.5538) task_loss: 4.4616 (4.7296) bpp_loss: 28.2320 (28.2415) patch_loss: 2.2890 (2.3206) token_loss: 0.6612 (0.6672) fea_loss: 1.5070 (1.7418) time: 0.9720 data: 0.0001 max mem: 11195
533
+ 2024-07-25 08:55:25,715 [INFO ] Epoch: [0] [ 330/4448] eta: 1:07:23 lr: 0.000001 loss: 7.2884 (7.5463) task_loss: 4.4653 (4.7222) bpp_loss: 28.2311 (28.2412) patch_loss: 2.2923 (2.3199) token_loss: 0.6727 (0.6678) fea_loss: 1.4998 (1.7344) time: 0.9719 data: 0.0001 max mem: 11195
534
+ 2024-07-25 08:55:35,470 [INFO ] Epoch: [0] [ 340/4448] eta: 1:07:13 lr: 0.000001 loss: 7.2818 (7.5382) task_loss: 4.4587 (4.7141) bpp_loss: 28.2306 (28.2408) patch_loss: 2.2966 (2.3192) token_loss: 0.6657 (0.6677) fea_loss: 1.4959 (1.7273) time: 0.9742 data: 0.0001 max mem: 11195
535
+ 2024-07-25 08:55:45,224 [INFO ] Epoch: [0] [ 350/4448] eta: 1:07:02 lr: 0.000001 loss: 7.2619 (7.5306) task_loss: 4.4388 (4.7066) bpp_loss: 28.2131 (28.2400) patch_loss: 2.2874 (2.3183) token_loss: 0.6561 (0.6679) fea_loss: 1.4892 (1.7205) time: 0.9754 data: 0.0001 max mem: 11195
536
+ 2024-07-25 08:55:54,962 [INFO ] Epoch: [0] [ 360/4448] eta: 1:06:51 lr: 0.000001 loss: 7.2578 (7.5231) task_loss: 4.4367 (4.6991) bpp_loss: 28.2120 (28.2392) patch_loss: 2.2885 (2.3175) token_loss: 0.6579 (0.6678) fea_loss: 1.4836 (1.7138) time: 0.9745 data: 0.0001 max mem: 11195
537
+ 2024-07-25 08:56:04,713 [INFO ] Epoch: [0] [ 370/4448] eta: 1:06:41 lr: 0.000001 loss: 7.2578 (7.5154) task_loss: 4.4367 (4.6916) bpp_loss: 28.2115 (28.2385) patch_loss: 2.2920 (2.3168) token_loss: 0.6579 (0.6676) fea_loss: 1.4752 (1.7072) time: 0.9743 data: 0.0001 max mem: 11195
538
+ 2024-07-25 08:56:14,459 [INFO ] Epoch: [0] [ 380/4448] eta: 1:06:30 lr: 0.000001 loss: 7.2504 (7.5086) task_loss: 4.4293 (4.6848) bpp_loss: 28.2110 (28.2377) patch_loss: 2.2893 (2.3161) token_loss: 0.6630 (0.6677) fea_loss: 1.4696 (1.7010) time: 0.9748 data: 0.0001 max mem: 11195
539
+ 2024-07-25 08:56:24,203 [INFO ] Epoch: [0] [ 390/4448] eta: 1:06:20 lr: 0.000001 loss: 7.2421 (7.5016) task_loss: 4.4212 (4.6779) bpp_loss: 28.2100 (28.2370) patch_loss: 2.2893 (2.3154) token_loss: 0.6630 (0.6676) fea_loss: 1.4607 (1.6949) time: 0.9744 data: 0.0001 max mem: 11195
540
+ 2024-07-25 08:56:33,961 [INFO ] Epoch: [0] [ 400/4448] eta: 1:06:10 lr: 0.000001 loss: 7.2267 (7.4947) task_loss: 4.4057 (4.6711) bpp_loss: 28.2090 (28.2363) patch_loss: 2.2825 (2.3147) token_loss: 0.6591 (0.6674) fea_loss: 1.4549 (1.6890) time: 0.9750 data: 0.0001 max mem: 11195
541
+ 2024-07-25 08:56:43,733 [INFO ] Epoch: [0] [ 410/4448] eta: 1:05:59 lr: 0.000001 loss: 7.2241 (7.4887) task_loss: 4.4032 (4.6651) bpp_loss: 28.2080 (28.2356) patch_loss: 2.2854 (2.3141) token_loss: 0.6674 (0.6676) fea_loss: 1.4618 (1.6834) time: 0.9764 data: 0.0001 max mem: 11195
542
+ 2024-07-25 08:56:53,446 [INFO ] Epoch: [0] [ 420/4448] eta: 1:05:49 lr: 0.000001 loss: 7.2331 (7.4827) task_loss: 4.4122 (4.6592) bpp_loss: 28.2068 (28.2349) patch_loss: 2.2857 (2.3134) token_loss: 0.6711 (0.6679) fea_loss: 1.4545 (1.6779) time: 0.9742 data: 0.0001 max mem: 11195
543
+ 2024-07-25 08:57:03,201 [INFO ] Epoch: [0] [ 430/4448] eta: 1:05:38 lr: 0.000001 loss: 7.2254 (7.4765) task_loss: 4.4047 (4.6531) bpp_loss: 28.2055 (28.2342) patch_loss: 2.2863 (2.3127) token_loss: 0.6718 (0.6679) fea_loss: 1.4462 (1.6725) time: 0.9733 data: 0.0001 max mem: 11195
544
+ 2024-07-25 08:57:12,938 [INFO ] Epoch: [0] [ 440/4448] eta: 1:05:28 lr: 0.000001 loss: 7.2275 (7.4714) task_loss: 4.4071 (4.6480) bpp_loss: 28.2046 (28.2335) patch_loss: 2.2863 (2.3122) token_loss: 0.6750 (0.6683) fea_loss: 1.4481 (1.6675) time: 0.9745 data: 0.0001 max mem: 11195
545
+ 2024-07-25 08:57:22,692 [INFO ] Epoch: [0] [ 450/4448] eta: 1:05:18 lr: 0.000001 loss: 7.2270 (7.4658) task_loss: 4.4067 (4.6425) bpp_loss: 28.2033 (28.2329) patch_loss: 2.2936 (2.3118) token_loss: 0.6782 (0.6684) fea_loss: 1.4495 (1.6624) time: 0.9745 data: 0.0001 max mem: 11195
546
+ 2024-07-25 08:57:32,428 [INFO ] Epoch: [0] [ 460/4448] eta: 1:05:07 lr: 0.000001 loss: 7.2030 (7.4601) task_loss: 4.3827 (4.6368) bpp_loss: 28.2019 (28.2322) patch_loss: 2.2834 (2.3110) token_loss: 0.6762 (0.6684) fea_loss: 1.4385 (1.6574) time: 0.9744 data: 0.0001 max mem: 11195
547
+ 2024-07-25 08:57:42,172 [INFO ] Epoch: [0] [ 470/4448] eta: 1:04:57 lr: 0.000001 loss: 7.2220 (7.4547) task_loss: 4.4020 (4.6315) bpp_loss: 28.2006 (28.2315) patch_loss: 2.2808 (2.3105) token_loss: 0.6535 (0.6685) fea_loss: 1.4261 (1.6526) time: 0.9739 data: 0.0001 max mem: 11195
548
+ 2024-07-25 08:57:51,902 [INFO ] Epoch: [0] [ 480/4448] eta: 1:04:47 lr: 0.000001 loss: 7.2008 (7.4490) task_loss: 4.3808 (4.6259) bpp_loss: 28.1997 (28.2308) patch_loss: 2.2765 (2.3099) token_loss: 0.6534 (0.6682) fea_loss: 1.4257 (1.6478) time: 0.9736 data: 0.0001 max mem: 11195
549
+ 2024-07-25 08:58:01,653 [INFO ] Epoch: [0] [ 490/4448] eta: 1:04:37 lr: 0.000001 loss: 7.2050 (7.4444) task_loss: 4.3852 (4.6214) bpp_loss: 28.1985 (28.2301) patch_loss: 2.2831 (2.3095) token_loss: 0.6743 (0.6686) fea_loss: 1.4232 (1.6432) time: 0.9740 data: 0.0001 max mem: 11195
550
+ 2024-07-25 08:58:11,373 [INFO ] Epoch: [0] [ 500/4448] eta: 1:04:26 lr: 0.000001 loss: 7.2024 (7.4392) task_loss: 4.3827 (4.6162) bpp_loss: 28.1970 (28.2295) patch_loss: 2.2877 (2.3091) token_loss: 0.6743 (0.6684) fea_loss: 1.4212 (1.6387) time: 0.9735 data: 0.0001 max mem: 11195
551
+ 2024-07-25 08:58:21,089 [INFO ] Epoch: [0] [ 510/4448] eta: 1:04:16 lr: 0.000001 loss: 7.1806 (7.4341) task_loss: 4.3611 (4.6112) bpp_loss: 28.1960 (28.2288) patch_loss: 2.2865 (2.3086) token_loss: 0.6606 (0.6683) fea_loss: 1.4148 (1.6343) time: 0.9717 data: 0.0001 max mem: 11195
552
+ 2024-07-25 08:58:30,784 [INFO ] Epoch: [0] [ 520/4448] eta: 1:04:05 lr: 0.000001 loss: 7.1758 (7.4292) task_loss: 4.3564 (4.6064) bpp_loss: 28.1949 (28.2282) patch_loss: 2.2844 (2.3081) token_loss: 0.6636 (0.6683) fea_loss: 1.4082 (1.6299) time: 0.9705 data: 0.0001 max mem: 11195
553
+ 2024-07-25 08:58:40,522 [INFO ] Epoch: [0] [ 530/4448] eta: 1:03:55 lr: 0.000001 loss: 7.1598 (7.4244) task_loss: 4.3403 (4.6017) bpp_loss: 28.1940 (28.2275) patch_loss: 2.2851 (2.3077) token_loss: 0.6642 (0.6683) fea_loss: 1.4067 (1.6258) time: 0.9715 data: 0.0001 max mem: 11195
554
+ 2024-07-25 08:58:50,244 [INFO ] Epoch: [0] [ 540/4448] eta: 1:03:45 lr: 0.000001 loss: 7.1524 (7.4195) task_loss: 4.3331 (4.5969) bpp_loss: 28.1934 (28.2269) patch_loss: 2.2790 (2.3072) token_loss: 0.6654 (0.6681) fea_loss: 1.4014 (1.6216) time: 0.9729 data: 0.0001 max mem: 11195
555
+ 2024-07-25 08:59:00,201 [INFO ] Epoch: [0] [ 550/4448] eta: 1:03:36 lr: 0.000001 loss: 7.1595 (7.4148) task_loss: 4.3401 (4.5922) bpp_loss: 28.1927 (28.2262) patch_loss: 2.2758 (2.3066) token_loss: 0.6682 (0.6682) fea_loss: 1.3959 (1.6174) time: 0.9839 data: 0.0001 max mem: 11195
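A consistency check on the numbers in this log (a relation read off the logged values of this run, not taken from a configuration file): the task loss is the sum of the three distillation terms, and the total loss adds the rate term from the entropy bottleneck with weight 0.1,

\[
\text{task\_loss} = \text{patch\_loss} + \text{token\_loss} + \text{fea\_loss},
\qquad
\text{loss} = \text{task\_loss} + 0.1 \cdot \text{bpp\_loss}.
\]

At step 0, for example: 2.4117 + 0.6234 + 2.4741 = 5.5092 ≈ 5.5091, and 5.5091 + 0.1 × 28.2426 = 8.3334 ≈ 8.3333.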
1_feature_extractor/log/DINOv2_training/log/20240725_085916.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240726_110417.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240726_171814.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240728_153020.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240728_214526.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240729_102738.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240730_084148.log ADDED
@@ -0,0 +1,301 @@
+ 2024-07-30 08:41:48,409 [INFO ] Logging file is /data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log//20240730_084148.log
+ 2024-07-30 08:41:48,410 [INFO ] job dir: /data0/qiyp/Proteus-pytorch/pretrain
+ 2024-07-30 08:41:48,410 [INFO ] Namespace(batch_size=48,
+ epochs=200,
+ bce_loss=False,
+ unscale_lr=False,
+ model='models_proteus_dinov2',
+ target_model='vit_base',
+ input_size=224,
+ drop=0.0,
+ drop_path=0.1,
+ model_ema=True,
+ model_ema_decay=0.99996,
+ model_ema_force_cpu=False,
+ opt='adamw',
+ opt_eps=1e-08,
+ opt_betas=None,
+ clip_grad=None,
+ momentum=0.9,
+ weight_decay=0.05,
+ sched='cosine',
+ lr=0.0002,
+ lr_noise=None,
+ lr_noise_pct=0.67,
+ lr_noise_std=1.0,
+ warmup_lr=1e-06,
+ min_lr=1e-05,
+ decay_epochs=30,
+ warmup_epochs=5,
+ cooldown_epochs=10,
+ patience_epochs=10,
+ decay_rate=0.1,
+ color_jitter=0.3,
+ aa='rand-m9-mstd0.5-inc1',
+ smoothing=0.1,
+ train_interpolation='bicubic',
+ repeated_aug=True,
+ train_mode=True,
+ ThreeAugment=False,
+ src=False,
+ global_crops_size=224,
+ patch_size=14,
+ mask_ratio=[0.5],
+ mask_probability=0.5,
+ mask_first_n=True,
+ clone_batch=1,
+ reprob=0.25,
+ remode='pixel',
+ recount=1,
+ resplit=False,
+ mixup=0.8,
+ cutmix=1.0,
+ cutmix_minmax=None,
+ mixup_prob=1.0,
+ mixup_switch_prob=0.5,
+ mixup_mode='batch',
+ teacher_model='vit_large',
+ teacher_path='',
+ distillation_type='none',
+ distillation_alpha=0.5,
+ distillation_tau=1.0,
+ lambda_token=1.0,
+ lambda_fea=1.1,
+ lambda_patch=1.0,
+ cosub=False,
+ finetune='',
+ attn_only=False,
+ weight_inherit='',
+ data_path='/data1/datasets/imagenet_fold',
+ data_set='IMNET',
+ inat_category='name',
+ output_dir='log/DINOv2_training',
+ log_dir='/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log/',
+ device='cuda',
+ seed=0,
+ resume='/data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0080.pth',
+ start_epoch=0,
+ eval=False,
+ eval_crop_ratio=0.875,
+ dist_eval=False,
+ num_workers=10,
+ pin_mem=True,
+ distributed=True,
+ world_size=6,
+ dist_url='env://',
+ rank=0,
+ gpu=0,
+ dist_backend='nccl')
+
+ 2024-07-30 08:41:52,685 [INFO ] Dataset ImageFolder
+ Number of datapoints: 1281167
+ Root location: /data1/datasets/imagenet_fold/train
+ StandardTransform
+ Transform: Compose(
+ RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
+ RandomHorizontalFlip(p=0.5)
+ RandAugment(n=2, ops=
+ AugmentOp(name=AutoContrast, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=Equalize, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=Invert, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=Rotate, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=PosterizeIncreasing, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=SolarizeIncreasing, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=SolarizeAdd, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=ColorIncreasing, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=ContrastIncreasing, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=BrightnessIncreasing, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=SharpnessIncreasing, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=ShearX, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=ShearY, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=TranslateXRel, p=0.5, m=9, mstd=0.5)
+ AugmentOp(name=TranslateYRel, p=0.5, m=9, mstd=0.5))
+ ToTensor()
+ Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
+ RandomErasing(p=0.25, mode=pixel, count=(1, 1))
+ )
+ 2024-07-30 08:41:52,760 [INFO ] Sampler_train = <samplers.RASampler object at 0x7efd92aaf3a0>
+ 2024-07-30 08:41:52,792 [INFO ] Created a temporary directory at /tmp/tmp1r345n4h
+ 2024-07-30 08:41:52,793 [INFO ] Writing /tmp/tmp1r345n4h/_remote_module_non_scriptable.py
+ 2024-07-30 08:42:03,348 [INFO ] using MLP layer as FFN
+ 2024-07-30 08:42:05,496 [INFO ] using MLP layer as FFN
+ 2024-07-30 08:42:13,281 [INFO ] Model = MetaArch(
+ (student): ModuleDict(
+ (backbone): DinoVisionTransformer(
+ (patch_embed): PatchEmbed(
+ (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
+ (norm): Identity()
+ )
+ (blocks): ModuleList(
+ (0-11): 12 x Block(
+ (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
+ (attn): MemEffAttention(
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
+ (attn_drop): Dropout(p=0.0, inplace=False)
+ (proj): Linear(in_features=768, out_features=768, bias=True)
+ (proj_drop): Dropout(p=0.0, inplace=False)
+ )
+ (ls1): LayerScale()
+ (drop_path1): Identity()
+ (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
+ (mlp): Mlp(
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
+ (act): GELU(approximate='none')
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
+ (drop): Dropout(p=0.0, inplace=False)
+ )
+ (ls2): LayerScale()
+ (drop_path2): Identity()
+ )
+ )
+ (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
+ (head): Identity()
+ )
+ )
+ (teacher): ModuleDict(
+ (backbone): DinoVisionTransformer(
+ (patch_embed): PatchEmbed(
+ (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
+ (norm): Identity()
+ )
+ (blocks): ModuleList(
+ (0-23): 24 x NestedTensorBlock(
+ (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
+ (attn): MemEffAttention(
+ (qkv): Linear(in_features=1024, out_features=3072, bias=True)
+ (attn_drop): Dropout(p=0.0, inplace=False)
+ (proj): Linear(in_features=1024, out_features=1024, bias=True)
+ (proj_drop): Dropout(p=0.0, inplace=False)
+ )
+ (ls1): LayerScale()
+ (drop_path1): Identity()
+ (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
+ (mlp): Mlp(
+ (fc1): Linear(in_features=1024, out_features=4096, bias=True)
+ (act): GELU(approximate='none')
+ (fc2): Linear(in_features=4096, out_features=1024, bias=True)
+ (drop): Dropout(p=0.0, inplace=False)
+ )
+ (ls2): LayerScale()
+ (drop_path2): Identity()
+ )
+ )
+ (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
+ (head): Identity()
+ )
+ )
+ (ibot_head): Sequential(
+ (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (1): Linear(in_features=768, out_features=1024, bias=True)
+ )
+ (token_head): Sequential(
+ (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (1): Linear(in_features=768, out_features=1024, bias=True)
+ )
+ (fea_head): Sequential(
+ (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (1): Linear(in_features=768, out_features=1024, bias=True)
+ )
+ (soft_criterion): MSELoss()
+ (info_bottleneck): IF_Module(
+ (encoder_blocks): ModuleList(
+ (0-3): 4 x Block(
+ (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (attn): Attention(
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
+ (q_norm): Identity()
+ (k_norm): Identity()
+ (attn_drop): Dropout(p=0.0, inplace=False)
+ (proj): Linear(in_features=768, out_features=768, bias=True)
+ (proj_drop): Dropout(p=0.0, inplace=False)
+ )
+ (ls1): Identity()
+ (drop_path1): Identity()
+ (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (mlp): Mlp(
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
+ (act): GELU(approximate='none')
+ (drop1): Dropout(p=0.0, inplace=False)
+ (norm): Identity()
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
+ (drop2): Dropout(p=0.0, inplace=False)
+ )
+ (ls2): Identity()
+ (drop_path2): Identity()
+ )
+ )
+ (encoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (decoder_blocks): ModuleList(
+ (0-3): 4 x Block(
+ (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (attn): Attention(
+ (qkv): Linear(in_features=768, out_features=2304, bias=True)
+ (q_norm): Identity()
+ (k_norm): Identity()
+ (attn_drop): Dropout(p=0.0, inplace=False)
+ (proj): Linear(in_features=768, out_features=768, bias=True)
+ (proj_drop): Dropout(p=0.0, inplace=False)
+ )
+ (ls1): Identity()
+ (drop_path1): Identity()
+ (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (mlp): Mlp(
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
+ (act): GELU(approximate='none')
+ (drop1): Dropout(p=0.0, inplace=False)
+ (norm): Identity()
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
+ (drop2): Dropout(p=0.0, inplace=False)
+ )
+ (ls2): Identity()
+ (drop_path2): Identity()
+ )
+ )
+ (decoder_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+ (entropy_bottleneck): EntropyBottleneck(
+ (likelihood_lower_bound): LowerBound()
+ )
+ )
+ )
+ 2024-07-30 08:42:16,057 [INFO ] number of params: 144845568
+ 2024-07-30 08:42:16,057 [INFO ] base lr = 0.0002
+ 2024-07-30 08:42:16,058 [INFO ] actural lr = 0.00011250000000000001
+ 2024-07-30 08:42:35,018 [INFO ] Loaded state_dict_ema
+ 2024-07-30 08:42:35,076 [INFO ] Resuming from /data0/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/checkpoint0080.pth
+ 2024-07-30 08:42:35,076 [INFO ] Start training for 200 epochs
+ 2024-07-30 08:42:47,484 [INFO ] Epoch: [81] [ 0/4448] eta: 15:19:18 lr: 0.000076 loss: 2.8205 (2.8205) task_loss: 2.6200 (2.6200) bpp_loss: 0.2883 (0.2883) patch_loss: 0.9649 (0.9649) token_loss: 0.8029 (0.8029) fea_loss: 0.8522 (0.8522) time: 12.4007 data: 3.4410 max mem: 20527
+ 2024-07-30 08:43:06,985 [INFO ] Epoch: [81] [ 20/4448] eta: 1:52:06 lr: 0.000076 loss: 2.8058 (2.8015) task_loss: 2.5992 (2.5967) bpp_loss: 0.2986 (0.2988) patch_loss: 0.9937 (0.9960) token_loss: 0.7355 (0.7480) fea_loss: 0.8571 (0.8527) time: 0.9749 data: 0.0001 max mem: 21093
+ 2024-07-30 08:43:26,649 [INFO ] Epoch: [81] [ 40/4448] eta: 1:32:23 lr: 0.000076 loss: 2.7488 (2.7856) task_loss: 2.5421 (2.5804) bpp_loss: 0.3020 (0.3006) patch_loss: 0.9832 (0.9945) token_loss: 0.7257 (0.7364) fea_loss: 0.8415 (0.8495) time: 0.9831 data: 0.0001 max mem: 21095
+ 2024-07-30 08:43:46,729 [INFO ] Epoch: [81] [ 60/4448] eta: 1:25:53 lr: 0.000076 loss: 2.7729 (2.7857) task_loss: 2.5642 (2.5803) bpp_loss: 0.3040 (0.3017) patch_loss: 1.0076 (0.9998) token_loss: 0.7200 (0.7333) fea_loss: 0.8399 (0.8472) time: 1.0040 data: 0.0001 max mem: 21095
+ 2024-07-30 08:44:07,288 [INFO ] Epoch: [81] [ 80/4448] eta: 1:22:51 lr: 0.000076 loss: 2.8271 (2.7941) task_loss: 2.6186 (2.5881) bpp_loss: 0.3040 (0.3025) patch_loss: 1.0128 (1.0035) token_loss: 0.7397 (0.7346) fea_loss: 0.8573 (0.8500) time: 1.0279 data: 0.0001 max mem: 21095
+ 2024-07-30 08:44:27,381 [INFO ] Epoch: [81] [ 100/4448] eta: 1:20:34 lr: 0.000076 loss: 2.8005 (2.8008) task_loss: 2.5913 (2.5942) bpp_loss: 0.3074 (0.3034) patch_loss: 1.0188 (1.0078) token_loss: 0.7295 (0.7347) fea_loss: 0.8552 (0.8517) time: 1.0046 data: 0.0001 max mem: 21095
+ 2024-07-30 08:44:47,266 [INFO ] Epoch: [81] [ 120/4448] eta: 1:18:47 lr: 0.000076 loss: 2.7752 (2.7995) task_loss: 2.5660 (2.5928) bpp_loss: 0.3070 (0.3040) patch_loss: 0.9996 (1.0060) token_loss: 0.7348 (0.7359) fea_loss: 0.8485 (0.8509) time: 0.9942 data: 0.0001 max mem: 21095
+ 2024-07-30 08:45:07,249 [INFO ] Epoch: [81] [ 140/4448] eta: 1:17:28 lr: 0.000076 loss: 2.8033 (2.8006) task_loss: 2.5930 (2.5937) bpp_loss: 0.3064 (0.3044) patch_loss: 1.0073 (1.0067) token_loss: 0.7371 (0.7357) fea_loss: 0.8503 (0.8513) time: 0.9991 data: 0.0002 max mem: 21096
+ 2024-07-30 08:45:27,250 [INFO ] Epoch: [81] [ 160/4448] eta: 1:16:25 lr: 0.000076 loss: 2.7980 (2.8014) task_loss: 2.5895 (2.5944) bpp_loss: 0.3069 (0.3047) patch_loss: 1.0087 (1.0062) token_loss: 0.7382 (0.7361) fea_loss: 0.8631 (0.8520) time: 1.0000 data: 0.0001 max mem: 21096
+ 2024-07-30 08:45:47,340 [INFO ] Epoch: [81] [ 180/4448] eta: 1:15:33 lr: 0.000076 loss: 2.8292 (2.8018) task_loss: 2.6184 (2.5945) bpp_loss: 0.3091 (0.3052) patch_loss: 1.0158 (1.0076) token_loss: 0.7273 (0.7351) fea_loss: 0.8446 (0.8518) time: 1.0044 data: 0.0001 max mem: 21096
+ 2024-07-30 08:46:07,943 [INFO ] Epoch: [81] [ 200/4448] eta: 1:14:58 lr: 0.000076 loss: 2.7974 (2.7999) task_loss: 2.5896 (2.5926) bpp_loss: 0.3082 (0.3055) patch_loss: 0.9949 (1.0069) token_loss: 0.7182 (0.7340) fea_loss: 0.8525 (0.8516) time: 1.0301 data: 0.0001 max mem: 21096
+ 2024-07-30 08:46:28,477 [INFO ] Epoch: [81] [ 220/4448] eta: 1:14:24 lr: 0.000076 loss: 2.7864 (2.8008) task_loss: 2.5794 (2.5933) bpp_loss: 0.3078 (0.3057) patch_loss: 1.0100 (1.0072) token_loss: 0.7306 (0.7342) fea_loss: 0.8441 (0.8518) time: 1.0266 data: 0.0001 max mem: 21096
+ 2024-07-30 08:46:48,558 [INFO ] Epoch: [81] [ 240/4448] eta: 1:13:45 lr: 0.000076 loss: 2.8172 (2.8027) task_loss: 2.6068 (2.5950) bpp_loss: 0.3095 (0.3060) patch_loss: 1.0120 (1.0081) token_loss: 0.7279 (0.7344) fea_loss: 0.8583 (0.8525) time: 1.0040 data: 0.0001 max mem: 21096
+ 2024-07-30 08:47:08,665 [INFO ] Epoch: [81] [ 260/4448] eta: 1:13:09 lr: 0.000076 loss: 2.7650 (2.8013) task_loss: 2.5573 (2.5936) bpp_loss: 0.3086 (0.3063) patch_loss: 0.9966 (1.0075) token_loss: 0.7286 (0.7338) fea_loss: 0.8450 (0.8523) time: 1.0053 data: 0.0001 max mem: 21096
+ 2024-07-30 08:47:28,757 [INFO ] Epoch: [81] [ 280/4448] eta: 1:12:35 lr: 0.000076 loss: 2.7513 (2.7988) task_loss: 2.5452 (2.5910) bpp_loss: 0.3100 (0.3066) patch_loss: 0.9790 (1.0058) token_loss: 0.7214 (0.7333) fea_loss: 0.8448 (0.8518) time: 1.0045 data: 0.0001 max mem: 21096
+ 2024-07-30 08:47:48,814 [INFO ] Epoch: [81] [ 300/4448] eta: 1:12:03 lr: 0.000076 loss: 2.7910 (2.7984) task_loss: 2.5833 (2.5906) bpp_loss: 0.3097 (0.3068) patch_loss: 0.9979 (1.0056) token_loss: 0.7220 (0.7334) fea_loss: 0.8453 (0.8515) time: 1.0028 data: 0.0001 max mem: 21096
+ 2024-07-30 08:48:08,893 [INFO ] Epoch: [81] [ 320/4448] eta: 1:11:32 lr: 0.000076 loss: 2.7957 (2.7991) task_loss: 2.5868 (2.5911) bpp_loss: 0.3101 (0.3071) patch_loss: 1.0096 (1.0065) token_loss: 0.7125 (0.7326) fea_loss: 0.8559 (0.8520) time: 1.0039 data: 0.0001 max mem: 21096
+ 2024-07-30 08:48:29,072 [INFO ] Epoch: [81] [ 340/4448] eta: 1:11:04 lr: 0.000076 loss: 2.7660 (2.7990) task_loss: 2.5585 (2.5909) bpp_loss: 0.3108 (0.3072) patch_loss: 0.9918 (1.0063) token_loss: 0.7228 (0.7327) fea_loss: 0.8468 (0.8519) time: 1.0089 data: 0.0001 max mem: 21096
+ 2024-07-30 08:48:49,101 [INFO ] Epoch: [81] [ 360/4448] eta: 1:10:35 lr: 0.000076 loss: 2.8059 (2.7999) task_loss: 2.5933 (2.5917) bpp_loss: 0.3136 (0.3076) patch_loss: 1.0057 (1.0066) token_loss: 0.7353 (0.7330) fea_loss: 0.8496 (0.8521) time: 1.0014 data: 0.0001 max mem: 21096
+ 2024-07-30 08:49:09,150 [INFO ] Epoch: [81] [ 380/4448] eta: 1:10:07 lr: 0.000076 loss: 2.7807 (2.7979) task_loss: 2.5695 (2.5897) bpp_loss: 0.3108 (0.3078) patch_loss: 0.9834 (1.0055) token_loss: 0.7254 (0.7327) fea_loss: 0.8475 (0.8515) time: 1.0024 data: 0.0001 max mem: 21096
+ 2024-07-30 08:49:29,184 [INFO ] Epoch: [81] [ 400/4448] eta: 1:09:40 lr: 0.000076 loss: 2.7576 (2.7964) task_loss: 2.5457 (2.5881) bpp_loss: 0.3135 (0.3081) patch_loss: 0.9908 (1.0050) token_loss: 0.7106 (0.7318) fea_loss: 0.8434 (0.8513) time: 1.0016 data: 0.0001 max mem: 21096
+ 2024-07-30 08:49:49,309 [INFO ] Epoch: [81] [ 420/4448] eta: 1:09:14 lr: 0.000076 loss: 2.7548 (2.7950) task_loss: 2.5465 (2.5866) bpp_loss: 0.3124 (0.3083) patch_loss: 0.9950 (1.0045) token_loss: 0.7273 (0.7312) fea_loss: 0.8426 (0.8510) time: 1.0062 data: 0.0002 max mem: 21096
+ 2024-07-30 08:50:09,308 [INFO ] Epoch: [81] [ 440/4448] eta: 1:08:47 lr: 0.000076 loss: 2.7648 (2.7940) task_loss: 2.5540 (2.5855) bpp_loss: 0.3125 (0.3085) patch_loss: 0.9801 (1.0036) token_loss: 0.7217 (0.7311) fea_loss: 0.8493 (0.8508) time: 0.9999 data: 0.0002 max mem: 21096
+ 2024-07-30 08:50:29,354 [INFO ] Epoch: [81] [ 460/4448] eta: 1:08:22 lr: 0.000076 loss: 2.7843 (2.7932) task_loss: 2.5758 (2.5847) bpp_loss: 0.3126 (0.3087) patch_loss: 0.9979 (1.0033) token_loss: 0.7206 (0.7308) fea_loss: 0.8484 (0.8506) time: 1.0022 data: 0.0002 max mem: 21096
+ 2024-07-30 08:50:49,382 [INFO ] Epoch: [81] [ 480/4448] eta: 1:07:57 lr: 0.000076 loss: 2.7398 (2.7920) task_loss: 2.5274 (2.5834) bpp_loss: 0.3156 (0.3090) patch_loss: 0.9877 (1.0032) token_loss: 0.7010 (0.7300) fea_loss: 0.8394 (0.8502) time: 1.0013 data: 0.0002 max mem: 21096
+ 2024-07-30 08:51:09,296 [INFO ] Epoch: [81] [ 500/4448] eta: 1:07:31 lr: 0.000076 loss: 2.7482 (2.7908) task_loss: 2.5427 (2.5822) bpp_loss: 0.3129 (0.3091) patch_loss: 0.9996 (1.0030) token_loss: 0.6940 (0.7294) fea_loss: 0.8401 (0.8499) time: 0.9956 data: 0.0002 max mem: 21096
+ 2024-07-30 08:51:29,262 [INFO ] Epoch: [81] [ 520/4448] eta: 1:07:07 lr: 0.000076 loss: 2.7785 (2.7908) task_loss: 2.5668 (2.5821) bpp_loss: 0.3141 (0.3093) patch_loss: 1.0087 (1.0028) token_loss: 0.7261 (0.7294) fea_loss: 0.8467 (0.8498) time: 0.9982 data: 0.0002 max mem: 21096
+ 2024-07-30 08:51:49,227 [INFO ] Epoch: [81] [ 540/4448] eta: 1:06:42 lr: 0.000076 loss: 2.8121 (2.7921) task_loss: 2.5996 (2.5833) bpp_loss: 0.3148 (0.3096) patch_loss: 0.9975 (1.0033) token_loss: 0.7236 (0.7298) fea_loss: 0.8581 (0.8502) time: 0.9982 data: 0.0002 max mem: 21096
+ 2024-07-30 08:52:09,093 [INFO ] Epoch: [81] [ 560/4448] eta: 1:06:17 lr: 0.000076 loss: 2.7914 (2.7926) task_loss: 2.5781 (2.5837) bpp_loss: 0.3144 (0.3097) patch_loss: 1.0098 (1.0037) token_loss: 0.7028 (0.7295) fea_loss: 0.8557 (0.8505) time: 0.9932 data: 0.0002 max mem: 21096
+ 2024-07-30 08:52:29,024 [INFO ] Epoch: [81] [ 580/4448] eta: 1:05:53 lr: 0.000076 loss: 2.8253 (2.7942) task_loss: 2.6127 (2.5851) bpp_loss: 0.3147 (0.3099) patch_loss: 1.0152 (1.0044) token_loss: 0.7303 (0.7297) fea_loss: 0.8694 (0.8511) time: 0.9965 data: 0.0002 max mem: 21096
+ 2024-07-30 08:52:48,932 [INFO ] Epoch: [81] [ 600/4448] eta: 1:05:30 lr: 0.000076 loss: 2.8256 (2.7949) task_loss: 2.6144 (2.5858) bpp_loss: 0.3141 (0.3101) patch_loss: 0.9969 (1.0043) token_loss: 0.7540 (0.7303) fea_loss: 0.8563 (0.8512) time: 0.9953 data: 0.0002 max mem: 21096
+ 2024-07-30 08:53:08,892 [INFO ] Epoch: [81] [ 620/4448] eta: 1:05:06 lr: 0.000076 loss: 2.7807 (2.7948) task_loss: 2.5716 (2.5856) bpp_loss: 0.3145 (0.3102) patch_loss: 1.0056 (1.0042) token_loss: 0.7300 (0.7303) fea_loss: 0.8505 (0.8510) time: 0.9980 data: 0.0002 max mem: 21096
+ 2024-07-30 08:53:28,936 [INFO ] Epoch: [81] [ 640/4448] eta: 1:04:44 lr: 0.000076 loss: 2.8072 (2.7954) task_loss: 2.5994 (2.5861) bpp_loss: 0.3138 (0.3103) patch_loss: 1.0003 (1.0044) token_loss: 0.7397 (0.7304) fea_loss: 0.8647 (0.8514) time: 1.0021 data: 0.0002 max mem: 21096
+ 2024-07-30 08:53:48,900 [INFO ] Epoch: [81] [ 660/4448] eta: 1:04:21 lr: 0.000076 loss: 2.7934 (2.7958) task_loss: 2.5801 (2.5865) bpp_loss: 0.3144 (0.3104) patch_loss: 0.9936 (1.0043) token_loss: 0.7337 (0.7307) fea_loss: 0.8504 (0.8515) time: 0.9982 data: 0.0002 max mem: 21096
+ 2024-07-30 08:54:08,946 [INFO ] Epoch: [81] [ 680/4448] eta: 1:03:58 lr: 0.000076 loss: 2.7859 (2.7956) task_loss: 2.5747 (2.5862) bpp_loss: 0.3140 (0.3106) patch_loss: 0.9925 (1.0043) token_loss: 0.7216 (0.7306) fea_loss: 0.8422 (0.8514) time: 1.0022 data: 0.0002 max mem: 21096
+ 2024-07-30 08:54:28,905 [INFO ] Epoch: [81] [ 700/4448] eta: 1:03:36 lr: 0.000076 loss: 2.7273 (2.7944) task_loss: 2.5182 (2.5851) bpp_loss: 0.3142 (0.3107) patch_loss: 0.9854 (1.0038) token_loss: 0.7153 (0.7301) fea_loss: 0.8333 (0.8511) time: 0.9979 data: 0.0001 max mem: 21096
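The scaled learning rate logged above as "actural lr" is consistent with the linear scaling rule used in this codebase (args.lr * batch_size * world_size / 512, as in main.py further below). A quick sanity check using only values shown in the logged Namespace:

# Sanity check of the scaled learning rate reported in the log above.
base_lr = 2e-4     # lr=0.0002 from the logged Namespace
batch_size = 48    # per-process batch size
world_size = 6     # six distributed processes
print(base_lr * batch_size * world_size / 512.0)  # 0.00011250000000000001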
1_feature_extractor/log/DINOv2_training/log/20240730_085449.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240731_102940.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240801_091959.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240801_155326.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240803_163338.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240803_231933.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/log/DINOv2_training/log/20240804_144252.log ADDED
The diff for this file is too large to render. See raw diff
 
1_feature_extractor/losses_hint.py ADDED
@@ -0,0 +1,49 @@
+ # Copyright (c) 2015-present, Facebook, Inc.
+ # All rights reserved.
+ """
+ Implements the knowledge distillation loss
+ """
+ import torch
+ from torch.nn import functional as F
+
+
+ class DistillationLoss(torch.nn.Module):
+     """
+     This module wraps a standard criterion and adds an extra knowledge distillation loss by
+     taking a teacher model prediction and using it as additional supervision.
+     """
+     def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
+                  distillation_type: str, lambda_token: float, lambda_fea: float):
+         super().__init__()
+         self.base_criterion = base_criterion
+         self.teacher_model = teacher_model
+         assert distillation_type in ['none', 'soft', 'hard']
+         self.distillation_type = distillation_type
+         self.lambda_token = lambda_token
+         self.lambda_fea = lambda_fea
+         self.soft_criterion = torch.nn.MSELoss()
+
+     def forward(self, inputs, outputs, labels):
+         """
+         Args:
+             inputs: The original inputs that are fed to the teacher model
+             outputs: the outputs of the model to be trained. It is expected to be
+                 either a Tensor, or a Tuple[Tensor, Tensor], with the original output
+                 in the first position and the distillation predictions as the second output
+             labels: the labels for the base criterion
+         """
+         outputs_token, outputs_fea = outputs
+
+         # don't backprop through the teacher
+         with torch.no_grad():
+             teacher_outputs = self.teacher_model.backbone.get_intermediate_layers(inputs, n=4, return_class_token=True)
+             teacher_outputs_token = teacher_outputs[3][1]
+             teacher_outputs_fea = torch.cat((teacher_outputs_token.unsqueeze(1), teacher_outputs[3][0]), dim=1)
+
+         distillation_loss_token = self.soft_criterion(outputs_token, teacher_outputs_token)
+         distillation_loss_fea = self.soft_criterion(outputs_fea, teacher_outputs_fea)
+
+         token_loss = self.lambda_token * distillation_loss_token
+         fea_loss = self.lambda_fea * distillation_loss_fea
+
+         return token_loss, fea_loss
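A minimal usage sketch for DistillationLoss (the DummyTeacher below is hypothetical, standing in for a frozen DINOv2-style teacher that exposes backbone.get_intermediate_layers; shapes assume a ViT-L/14 teacher at 224 px, i.e. 256 patch tokens of width 1024, matching the teacher dump in the training logs):

import torch
from losses_hint import DistillationLoss  # assumes running inside 1_feature_extractor/

class DummyBackbone(torch.nn.Module):
    def get_intermediate_layers(self, x, n=4, return_class_token=True):
        b = x.shape[0]
        # one (patch_tokens, class_token) pair per requested layer
        return [(torch.randn(b, 256, 1024), torch.randn(b, 1024)) for _ in range(n)]

class DummyTeacher(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = DummyBackbone()

criterion = DistillationLoss(base_criterion=torch.nn.CrossEntropyLoss(),
                             teacher_model=DummyTeacher(),
                             distillation_type='none',
                             lambda_token=1.0, lambda_fea=1.1)
images = torch.randn(2, 3, 224, 224)
student_token = torch.randn(2, 1024)      # projected class token
student_fea = torch.randn(2, 257, 1024)   # class token + 256 patch tokens
token_loss, fea_loss = criterion(images, (student_token, student_fea), labels=None)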
1_feature_extractor/main.py ADDED
@@ -0,0 +1,520 @@
+ # Copyright (c) 2015-present, Facebook, Inc.
+ # All rights reserved.
+ import argparse
+ import datetime
+ import numpy as np
+ import time
+ import torch
+ import torch.backends.cudnn as cudnn
+ import json
+
+ from pathlib import Path
+
+ from timm.models import create_model
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+ from timm.scheduler import create_scheduler
+ from timm.optim import create_optimizer
+ from timm.utils import NativeScaler, get_state_dict, ModelEma
+ from augmentations import collate_data_and_cast_aug
+ from datasets import build_dataset
+
+ from losses_hint import DistillationLoss
+ from samplers import RASampler
+ from functools import partial
+
+ import importlib
+ import utils
+ import random
+ import math
+ from multiprocessing import Value
+ from abc import ABC
+
+ import sys
+ from typing import Iterable, Optional
+ from timm.data import Mixup
+ from timm.utils import accuracy, ModelEma
+ import utils
+
+
+ class MaskingGenerator(ABC):
+     def __init__(self, input_size):
+         if not isinstance(input_size, tuple):
+             input_size = (input_size,) * 2
+         self.height, self.width = input_size
+         self.num_patches = self.height * self.width
+
+     def __repr__(self):
+         raise NotImplementedError
+
+     def get_shape(self):
+         return self.height, self.width
+
+     def _mask(self, mask, max_mask_patches):
+         raise NotImplementedError
+
+     def get_none_mask(self):
+         return np.zeros(shape=self.get_shape(), dtype=bool)
+
+
+
+ class RandomMaskingGenerator(MaskingGenerator):
+     def __init__(
+         self,
+         input_size,
+     ):
+         """
+         Args:
+             input_size: the size of the token map, e.g., 14x14
+         """
+         super().__init__(input_size)
+
+     def __repr__(self):
+         repr_str = f"Random Generator({self.height}, {self.width})"
+         return repr_str
+
+     def _mask(self, mask, max_mask_patches):
+         return super()._mask(mask, max_mask_patches)
+
+     def __call__(self, num_masking_patches=0):
+         if num_masking_patches <= 0:
+             return np.zeros(shape=self.get_shape(), dtype=bool)
+
+         mask = np.hstack([np.ones(num_masking_patches, dtype=bool),
+                           np.zeros(self.num_patches - num_masking_patches, dtype=bool)])
+         np.random.shuffle(mask)
+         mask = mask.reshape(self.get_shape())
+         return mask
+
+
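+ # Usage note (illustrative): with the settings from the accompanying logs
+ # (global_crops_size=224, patch_size=14) the token map is 16x16, so
+ # RandomMaskingGenerator(16)(num_masking_patches=128) returns a boolean
+ # (16, 16) array with exactly 128 True entries at random positions.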
+ def get_args_parser():
+     parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+     parser.add_argument('--batch-size', default=64, type=int)
+     parser.add_argument('--epochs', default=300, type=int)
+     parser.add_argument('--bce-loss', action='store_true')
+     parser.add_argument('--unscale-lr', action='store_true')
+
+     # Model parameters
+     parser.add_argument('--model', default='deit_base_patch16_224', type=str)
+     parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
+     parser.add_argument('--input-size', default=224, type=int, help='images input size')
+
+     parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                         help='Dropout rate (default: 0.)')
+     parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
+                         help='Drop path rate (default: 0.1)')
+
+     parser.add_argument('--model-ema', action='store_true')
+     parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
+     parser.set_defaults(model_ema=True)
+     parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
+     parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
+
+     # Optimizer parameters
+     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                         help='Optimizer (default: "adamw")')
+     parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                         help='Optimizer Epsilon (default: 1e-8)')
+     parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+                         help='Optimizer Betas (default: None, use opt default)')
+     parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
+                         help='Clip gradient norm (default: None, no clipping)')
+     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                         help='SGD momentum (default: 0.9)')
+     parser.add_argument('--weight-decay', type=float, default=0.05,
+                         help='weight decay (default: 0.05)')
+     # Learning rate schedule parameters
+     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                         help='LR scheduler (default: "cosine")')
+     parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
+                         help='learning rate (default: 5e-4)')
+     parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                         help='learning rate noise on/off epoch percentages')
+     parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                         help='learning rate noise limit percent (default: 0.67)')
+     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                         help='learning rate noise std-dev (default: 1.0)')
+     parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+                         help='warmup learning rate (default: 1e-6)')
+     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+     parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                         help='epoch interval to decay LR')
+     parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+                         help='epochs to warmup LR, if scheduler supports')
+     parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+                         help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                         help='patience epochs for Plateau LR scheduler (default: 10)')
+     parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                         help='LR decay rate (default: 0.1)')
+
+     # Augmentation parameters
+     parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
+                         help='Color jitter factor (default: 0.3)')
+     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                         help='Use AutoAugment policy. "v0" or "original" '
+                              '(default: rand-m9-mstd0.5-inc1)')
+     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+     parser.add_argument('--train-interpolation', type=str, default='bicubic',
+                         help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+     parser.add_argument('--repeated-aug', action='store_true')
+     parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
+     parser.set_defaults(repeated_aug=True)
+
+     parser.add_argument('--train-mode', action='store_true')
+     parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
+     parser.set_defaults(train_mode=True)
+
+     parser.add_argument('--ThreeAugment', action='store_true')  # 3augment
+
+     parser.add_argument('--src', action='store_true')  # simple random crop
+
+     # add dataset parameters
+     parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
+                         help="this should be equal to image size")
+     parser.add_argument('--patch_size', default=16, type=int,
+                         help="patch size for vit patch embedding")
+
+     # add masking parameter
+     parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
+                         help="mask ratio can be either a value or a range")
+     parser.add_argument('--mask_probability', default=0., type=float,
+                         help="how many samples will be applied with masking")
+     parser.add_argument('--mask_first_n', action='store_true',
+                         help="mask the first n samples to avoid shuffling. Needed for MAE-style encoder")
+     parser.add_argument('--clone_batch', default=1, type=int,
+                         help="how many times to clone the batch for masking (default: 1, not cloning)")
+
+     # * Random Erase params
+     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                         help='Random erase prob (default: 0.25)')
+     parser.add_argument('--remode', type=str, default='pixel',
+                         help='Random erase mode (default: "pixel")')
+     parser.add_argument('--recount', type=int, default=1,
+                         help='Random erase count (default: 1)')
+     parser.add_argument('--resplit', action='store_true', default=False,
+                         help='Do not random erase first (clean) augmentation split')
+
+     # * Mixup params
+     parser.add_argument('--mixup', type=float, default=0.8,
+                         help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+     parser.add_argument('--cutmix', type=float, default=1.0,
+                         help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+     parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+                         help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+     parser.add_argument('--mixup-prob', type=float, default=1.0,
+                         help='Probability of performing mixup or cutmix when either/both is enabled')
+     parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+                         help='Probability of switching to cutmix when both mixup and cutmix enabled')
+     parser.add_argument('--mixup-mode', type=str, default='batch',
+                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+     # Distillation parameters
+     parser.add_argument('--teacher-model', default='base', type=str)
+     parser.add_argument('--teacher-path', type=str, default='')
+     parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+     parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+     parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
+     parser.add_argument('--lambda_token', type=float, default=1.0)
+     parser.add_argument('--lambda_fea', type=float, default=1.0)
+     parser.add_argument('--lambda_patch', type=float, default=1.0)
+
+     # * Cosub params
+     parser.add_argument('--cosub', action='store_true')
+
+     # * Finetuning params
+     parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+     parser.add_argument('--attn-only', action='store_true')
+     parser.add_argument('--weight_inherit', default='')
+
+     # Dataset parameters
+     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
+                         help='dataset path')
+     parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug', 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
+                         type=str, help='Image Net dataset path')
+     parser.add_argument('--inat-category', default='name',
+                         choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
+                         type=str, help='semantic granularity')
+
+     parser.add_argument('--output_dir', default='',
+                         help='path where to save, empty for no saving')
+     parser.add_argument('--device', default='cuda',
+                         help='device to use for training / testing')
+     parser.add_argument('--seed', default=0, type=int)
+     parser.add_argument('--resume', default='', help='resume from checkpoint')
+     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                         help='start epoch')
+     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+     parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
+     parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
+     parser.add_argument('--num_workers', default=10, type=int)
+     parser.add_argument('--pin-mem', action='store_true',
+                         help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+     parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
+                         help='')
+     parser.set_defaults(pin_mem=True)
+
+     # distributed training parameters
+     parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
+     parser.add_argument('--world_size', default=1, type=int,
+                         help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     return parser
+
+
+ def main(args):
+     utils.init_distributed_mode(args)
+
+     print(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     # random.seed(seed)
+
+     cudnn.benchmark = True
+
+     dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         if args.repeated_aug:
+             sampler_train = RASampler(
+                 dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+             )
+         else:
+             sampler_train = torch.utils.data.DistributedSampler(
+                 dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+             )
+     else:
+         sampler_train = torch.utils.data.RandomSampler(dataset_train)
+
+     n_tokens = (args.global_crops_size // args.patch_size) ** 2
+     mask_generator = RandomMaskingGenerator(
+         input_size=args.global_crops_size // args.patch_size,
+     )
+
+     collate_fn = partial(
+         collate_data_and_cast_aug,
+         mask_ratio=args.mask_ratio,
+         mask_probability=args.mask_probability,
+         dtype=torch.half,  # half precision
+         n_tokens=n_tokens,
+         mask_first_n=args.mask_first_n,
+         mask_generator=mask_generator,
+         clone_batch=args.clone_batch,
+     )
+
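+     # Note (illustrative): with the logged settings (global_crops_size=224,
+     # patch_size=14), n_tokens = (224 // 14) ** 2 = 256, i.e. a 16x16 token map.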
+     data_loader_train = torch.utils.data.DataLoader(
+         dataset_train, sampler=sampler_train,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         pin_memory=args.pin_mem,
+         drop_last=True,
+         collate_fn=collate_fn,
+     )
+
+     mixup_fn = None
+
+     print(f"Creating model: {args.model}")
+     meta_arch_module = importlib.import_module(args.model)
+     MetaArch = meta_arch_module.MetaArch
+
+     model = MetaArch(args)
+
+     if args.finetune:
+         checkpoint = torch.load(args.finetune, map_location='cpu')
+
+         if 'state_dict' in checkpoint:
+             pretrained_dict = checkpoint['state_dict']
+         elif 'model' in checkpoint:
+             pretrained_dict = checkpoint['model']
+         else:
+             pretrained_dict = checkpoint
+
+         missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+         print('missing_keys: ', missing_keys)
+         print('unexpected_keys: ', unexpected_keys)
+
+     if args.attn_only:
+         for name_p, p in model.named_parameters():
+             if '.attn.' in name_p:
+                 p.requires_grad = True
+             else:
+                 p.requires_grad = False
+         try:
+             model.head.weight.requires_grad = True
+             model.head.bias.requires_grad = True
+         except:
+             model.fc.weight.requires_grad = True
+             model.fc.bias.requires_grad = True
+         try:
+             model.pos_embed.requires_grad = True
+         except:
+             print('no position encoding')
+         try:
+             for p in model.patch_embed.parameters():
+                 p.requires_grad = False
+         except:
+             print('no patch embed')
+
+     model.to(device)
+
+     model_ema = None
+     if args.model_ema:
+         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+         model_ema = ModelEma(
+             model.student.backbone,
+             decay=args.model_ema_decay,
+             device='cpu' if args.model_ema_force_cpu else '',
+             resume='')
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+         model_without_ddp = model.module
+     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print('number of params:', n_parameters)
+
+     if not args.unscale_lr:
+         linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
+         args.lr = linear_scaled_lr
+
+     optimizer = create_optimizer(args, model_without_ddp)
+     loss_scaler = NativeScaler()
+
+     lr_scheduler, _ = create_scheduler(args, optimizer)
+
+     output_dir = Path(args.output_dir)
+     if args.resume:
+         if args.resume.startswith('https'):
+             checkpoint = torch.hub.load_state_dict_from_url(
+                 args.resume, map_location='cpu', check_hash=True)
+         else:
+             checkpoint = torch.load(args.resume, map_location='cpu')
+
+         model_without_ddp.load_state_dict(checkpoint['model'])
+         if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+             optimizer.load_state_dict(checkpoint['optimizer'])
+             lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+             args.start_epoch = checkpoint['epoch'] + 1
+             if args.model_ema:
+                 utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+             if 'scaler' in checkpoint:
+                 loss_scaler.load_state_dict(checkpoint['scaler'])
+         lr_scheduler.step(args.start_epoch)
+
+     print(f"Start training for {args.epochs} epochs")
+     start_time = time.time()
+     max_accuracy = 0.0
+     for epoch in range(args.start_epoch, args.epochs):
+         if args.distributed:
+             data_loader_train.sampler.set_epoch(epoch)
+
+         train_stats = train_one_epoch(
+             model, data_loader_train,
+             optimizer, device, epoch, loss_scaler,
+             args.clip_grad, model_ema, mixup_fn,
+             set_training_mode=args.train_mode,  # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
+             args=args,
+         )
+
+         lr_scheduler.step(epoch)
+         if args.output_dir:
+             checkpoint_paths = [output_dir / 'checkpoint.pth']
+             for checkpoint_path in checkpoint_paths:
+                 utils.save_on_master({
+                     'model': model_without_ddp.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'lr_scheduler': lr_scheduler.state_dict(),
+                     'epoch': epoch,
+                     'model_ema': get_state_dict(model_ema),
+                     'scaler': loss_scaler.state_dict(),
+                     'args': args,
+                 }, checkpoint_path)
+
+         log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                      'epoch': epoch,
+                      'n_parameters': n_parameters}
+
+         if args.output_dir and utils.is_main_process():
+             with (output_dir / "log.txt").open("a") as f:
+                 f.write(json.dumps(log_stats) + "\n")
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Training time {}'.format(total_time_str))
+
+
+ def train_one_epoch(model: torch.nn.Module,
+                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+                     model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                     set_training_mode=True, args=None):
+     model.train(set_training_mode)
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+     header = 'Epoch: [{}]'.format(epoch)
+     print_freq = 10
+
+     loader_len = len(data_loader)
+
+     for data_iter_step, inputs_dict in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+         for k, v in inputs_dict.items():
+             if isinstance(v, torch.Tensor):
+                 inputs_dict[k] = v.to(device, non_blocking=True)
+
+         with torch.cuda.amp.autocast():
+             loss_dict = model(inputs_dict)
476
+
477
+ loss = loss_dict["loss"]
478
+ patch_loss = loss_dict["patch_loss"]
479
+ fea_loss = loss_dict["fea_loss"]
480
+ token_loss = loss_dict["token_loss"]
481
+
482
+ patch_loss_value = patch_loss.item()
483
+ token_loss_value = token_loss.item()
484
+ fea_loss_value = fea_loss.item()
485
+ loss_value = loss.item()
486
+
487
+
488
+ if not math.isfinite(loss_value):
489
+ print("Loss is {}, stopping training".format(loss_value))
490
+ sys.exit(1)
491
+
492
+ optimizer.zero_grad()
493
+
494
+ # this attribute is added by timm on one optimizer (adahessian)
495
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
496
+ loss_scaler(loss, optimizer, clip_grad=max_norm,
497
+ parameters=model.parameters(), create_graph=is_second_order)
498
+
499
+ torch.cuda.synchronize()
500
+ if model_ema is not None:
501
+ model_ema.update(model.module.student.backbone)
502
+
503
+ metric_logger.update(loss=loss_value)
504
+ metric_logger.update(patch_loss=patch_loss_value)
505
+ metric_logger.update(token_loss=token_loss_value)
506
+ metric_logger.update(fea_loss=fea_loss_value)
507
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
508
+ # gather the stats from all processes
509
+ metric_logger.synchronize_between_processes()
510
+ print("Averaged stats:", metric_logger)
511
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
512
+
513
+
514
+
515
+ if __name__ == '__main__':
516
+ parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
517
+ args = parser.parse_args()
518
+ if args.output_dir:
519
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
520
+ main(args)
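
Note on the learning-rate handling in main() above: unless LR unscaling is requested, the script applies the standard linear scaling rule, lr = base_lr * global_batch / 512, where the global batch is the per-process batch size times the world size. A minimal sketch of the same arithmetic (the function name and values are illustrative, not part of the script):

# Minimal sketch of the linear LR scaling rule used in main(); values are illustrative.
def scale_lr(base_lr: float, batch_size_per_gpu: int, world_size: int) -> float:
    # effective (global) batch = per-GPU batch * number of processes
    return base_lr * batch_size_per_gpu * world_size / 512.0

print(scale_lr(5e-4, 128, 8))  # 5e-4 * 1024 / 512 = 1e-3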
1_feature_extractor/models_IB.py ADDED
@@ -0,0 +1,40 @@
+ import torch.nn as nn
+ from compressai.entropy_models import EntropyBottleneck
+ from timm.models.vision_transformer import Block
+
+ class IF_Module(nn.Module):
+     def __init__(self, embed_dim, num_heads, mlp_ratio, depth=4, norm_layer=nn.LayerNorm):
+         super(IF_Module, self).__init__()
+
+         self.encoder_blocks = nn.ModuleList([  # 4 layers, embed_dim=768, num_heads=12
+             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+             for i in range(depth)])
+         self.encoder_norm = norm_layer(embed_dim)
+
+         self.decoder_blocks = nn.ModuleList([  # 4 layers, embed_dim=768, num_heads=12
+             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+             for i in range(depth)])
+
+         self.decoder_norm = norm_layer(embed_dim)
+         self.entropy_bottleneck = EntropyBottleneck(embed_dim)
+
+     def forward(self, x, is_training=False):
+         # ViT analysis transform
+         for blk in self.encoder_blocks:
+             x = blk(x)
+         x = self.encoder_norm(x)
+
+         if is_training:
+             x = x.permute(0, 2, 1)
+             x_hat, x_likelihood = self.entropy_bottleneck(x)
+             x_hat = x_hat.permute(0, 2, 1)
+         else:
+             x_hat = x
+             x_likelihood = None
+
+         # ViT synthesis transform
+         for blk in self.decoder_blocks:
+             x_hat = blk(x_hat)
+         x_hat = self.decoder_norm(x_hat)
+
+         return x_hat, x_likelihood
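
For reference, IF_Module is a token-space information bottleneck: a small ViT encoder compresses the token sequence, CompressAI's EntropyBottleneck quantizes it and estimates per-element likelihoods during training, and a ViT decoder reconstructs the tokens. A minimal usage sketch, assuming compressai and timm are installed; the tensor shapes are illustrative only:

import math
import torch
from models_IB import IF_Module

ib = IF_Module(embed_dim=768, num_heads=12, mlp_ratio=4)  # matches a ViT-B token stream
tokens = torch.randn(2, 197, 768)                         # (batch, tokens, dim), illustrative

x_hat, likelihoods = ib(tokens, is_training=True)
# Estimated rate of the bottleneck representation, in bits per image
bits = torch.log(likelihoods).sum() / -math.log(2) / tokens.shape[0]
print(x_hat.shape, bits.item())

In an actual training loop, CompressAI's auxiliary loss (ib.entropy_bottleneck.loss()) would also need to be optimized so the learned CDFs track the latent distribution.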
1_feature_extractor/models_clip.py ADDED
@@ -0,0 +1,438 @@
+ from collections import OrderedDict
+ from typing import Tuple, Union, Callable
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn.init import trunc_normal_
+
+
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+     if not depth_first and include_root:
+         fn(module=module, name=name)
+     for child_name, child_module in module.named_children():
+         child_name = ".".join((name, child_name)) if name else child_name
+         named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+     if depth_first and include_root:
+         fn(module=module, name=name)
+     return module
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1):
+         super().__init__()
+
+         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu1 = nn.ReLU(inplace=True)
+
+         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.relu2 = nn.ReLU(inplace=True)
+
+         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu3 = nn.ReLU(inplace=True)
+
+         self.downsample = None
+         self.stride = stride
+
+         if stride > 1 or inplanes != planes * Bottleneck.expansion:
+             # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+             self.downsample = nn.Sequential(OrderedDict([
+                 ("-1", nn.AvgPool2d(stride)),
+                 ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                 ("1", nn.BatchNorm2d(planes * self.expansion))
+             ]))
+
+     def forward(self, x: torch.Tensor):
+         identity = x
+
+         out = self.relu1(self.bn1(self.conv1(x)))
+         out = self.relu2(self.bn2(self.conv2(out)))
+         out = self.avgpool(out)
+         out = self.bn3(self.conv3(out))
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu3(out)
+         return out
+
+
+ class AttentionPool2d(nn.Module):
+     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+         super().__init__()
+         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+         self.num_heads = num_heads
+
+     def forward(self, x):
+         x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
+         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+         x, _ = F.multi_head_attention_forward(
+             query=x[:1], key=x, value=x,
+             embed_dim_to_check=x.shape[-1],
+             num_heads=self.num_heads,
+             q_proj_weight=self.q_proj.weight,
+             k_proj_weight=self.k_proj.weight,
+             v_proj_weight=self.v_proj.weight,
+             in_proj_weight=None,
+             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+             bias_k=None,
+             bias_v=None,
+             add_zero_attn=False,
+             dropout_p=0,
+             out_proj_weight=self.c_proj.weight,
+             out_proj_bias=self.c_proj.bias,
+             use_separate_proj_weight=True,
+             training=self.training,
+             need_weights=False
+         )
+         return x.squeeze(0)
+
+
+ class ModifiedResNet(nn.Module):
+     """
+     A ResNet class that is similar to torchvision's but contains the following changes:
+     - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+     - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+     - The final pooling layer is a QKV attention instead of an average pool
+     """
+
+     def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+         super().__init__()
+         self.output_dim = output_dim
+         self.input_resolution = input_resolution
+
+         # the 3-layer stem
+         self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(width // 2)
+         self.relu1 = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(width // 2)
+         self.relu2 = nn.ReLU(inplace=True)
+         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(width)
+         self.relu3 = nn.ReLU(inplace=True)
+         self.avgpool = nn.AvgPool2d(2)
+
+         # residual layers
+         self._inplanes = width  # this is a *mutable* variable used during construction
+         self.layer1 = self._make_layer(width, layers[0])
+         self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+         self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+         embed_dim = width * 32  # the ResNet feature dimension
+         self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+     def _make_layer(self, planes, blocks, stride=1):
+         layers = [Bottleneck(self._inplanes, planes, stride)]
+
+         self._inplanes = planes * Bottleneck.expansion
+         for _ in range(1, blocks):
+             layers.append(Bottleneck(self._inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         def stem(x):
+             x = self.relu1(self.bn1(self.conv1(x)))
+             x = self.relu2(self.bn2(self.conv2(x)))
+             x = self.relu3(self.bn3(self.conv3(x)))
+             x = self.avgpool(x)
+             return x
+
+         x = x.type(self.conv1.weight.dtype)
+         x = stem(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.attnpool(x)
+
+         return x
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         orig_type = x.dtype
+         ret = super().forward(x.type(torch.float32))
+         return ret.type(orig_type)
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+         super().__init__()
+
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(OrderedDict([
+             ("c_fc", nn.Linear(d_model, d_model * 4)),
+             ("gelu", QuickGELU()),
+             ("c_proj", nn.Linear(d_model * 4, d_model))
+         ]))
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+
+     def attention(self, x: torch.Tensor):
+         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+     def forward(self, x: torch.Tensor):
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.resblocks(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         return x
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+         scale = width ** -0.5
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+         self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+         self.ln_pre = LayerNorm(width)
+
+         self.transformer = Transformer(width, layers, heads)
+
+         self.mask_token = nn.Parameter(torch.zeros(1, width))
+
+         self.ln_post = LayerNorm(width)
+
+         self.embed_dim = width
+         self.patch_size = patch_size
+
+         self.init_weights()
+
+     def init_weights(self):
+         trunc_normal_(self.positional_embedding, std=0.02)
+         nn.init.normal_(self.class_embedding, std=1e-6)
+         named_apply(init_weights_vit_timm, self)
+
+     def prepare_tokens_with_masks(self, x, masks=None):
+         x = self.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+         if masks is not None:
+             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+         x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+         x = x + self.positional_embedding.to(x.dtype)
+         x = self.ln_pre(x)
+
+         return x
+
+     def forward_features_list(self, x_list, masks_list):
+         x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+         all_x = [self.transformer(t) for t in x]
+
+         output = []
+         for x, masks in zip(all_x, masks_list):
+             output.append(
+                 {
+                     "x_norm_clstoken": self.ln_post(x[:, 0]),
+                     "x_norm_patchtokens": x[:, 1:],
+                     "x_prenorm": x,
+                     "masks": masks,
+                 }
+             )
+         return output
+
+     def forward(self, x: torch.Tensor, masks=None):
+         if isinstance(x, list):
+             return self.forward_features_list(x, masks)
+
+         x = self.prepare_tokens_with_masks(x, masks)
+
+         x = self.transformer(x)
+
+         return {
+             "x_norm_clstoken": self.ln_post(x[:, 0]),
+             "x_norm_patchtokens": x[:, 1:],
+             "x_prenorm": x,
+             "masks": masks,
+         }
+
+
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
+     """ViT weight initialization, original timm impl (for reproducibility)"""
+     if isinstance(module, nn.Linear):
+         trunc_normal_(module.weight, std=0.02)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+
+
+ def vit_small(patch_size=14, teacher_path=None):
+     model = VisionTransformer(
+         input_resolution=224,
+         patch_size=patch_size,
+         width=384,
+         layers=12,
+         heads=6
+     )
+
+     if teacher_path is not None:
+         checkpoint = torch.load(teacher_path, map_location='cpu')
+
+         if 'state_dict' in checkpoint:
+             pretrained_dict = checkpoint['state_dict']
+         elif 'model' in checkpoint:
+             pretrained_dict = checkpoint['model']
+         else:
+             pretrained_dict = checkpoint
+
+         missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+         print('missing_keys: ', missing_keys)
+         print('unexpected_keys: ', unexpected_keys)
+
+     return model
+
+
+ def vit_base(patch_size=14, teacher_path=None):
+     model = VisionTransformer(
+         input_resolution=224,
+         patch_size=patch_size,
+         width=768,
+         layers=12,
+         heads=12
+     )
+
+     if teacher_path is not None:
+         checkpoint = torch.load(teacher_path, map_location='cpu')
+
+         if 'state_dict' in checkpoint:
+             pretrained_dict = checkpoint['state_dict']
+         elif 'model' in checkpoint:
+             pretrained_dict = checkpoint['model']
+         else:
+             pretrained_dict = checkpoint
+
+         missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+         print('missing_keys: ', missing_keys)
+         print('unexpected_keys: ', unexpected_keys)
+
+     return model
+
+
+ def vit_large(patch_size=14, teacher_path=None):
+     model = VisionTransformer(
+         input_resolution=224,
+         patch_size=patch_size,
+         width=1024,
+         layers=24,
+         heads=16
+     )
+
+     if teacher_path is not None:
+         checkpoint = torch.load(teacher_path, map_location='cpu')
+
+         if 'state_dict' in checkpoint:
+             pretrained_dict = checkpoint['state_dict']
+         elif 'model' in checkpoint:
+             pretrained_dict = checkpoint['model']
+         else:
+             pretrained_dict = checkpoint
+
+         missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+         print('missing_keys: ', missing_keys)
+         print('unexpected_keys: ', unexpected_keys)
+
+     return model
+
+
+ if __name__ == "__main__":
+     import argparse
+     import clip
+     import open_clip
+     from fvcore.nn import FlopCountAnalysis, parameter_count_table
+     parser = argparse.ArgumentParser(description='PyTorch resnet Training')
+     args = parser.parse_args()
+
+     # with torch.no_grad():
+     #     print(clip.available_models())
+
+     #     device = "cuda" if torch.cuda.is_available() else "cpu"
+     #     model, preprocess = clip.load('ViT-L/14', device)
+     #     print(model.visual)
+
+     #     model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
+     #     model = model.to('cuda')
+
+     #     for k, v in model.visual.named_parameters():
+     #         print(k, v.shape)
+
+     #     self_model = VisionTransformer(
+     #         input_resolution=224,
+     #         patch_size=32,
+     #         width=768,
+     #         layers=12,
+     #         heads=12
+     #     )
+
+     #     print(self_model)
+
+     #     for k, v in self_model.named_parameters():
+     #         print(k, v.shape)
+
+     #     new_ckpt = OrderedDict()
+     #     for k, v in model.visual.named_parameters():
+     #         if 'proj' != k:
+     #             print(k)
+     #             new_ckpt[k] = v
+     #         new_ckpt[k] = v
+
+     #     torch.save(new_ckpt, '/home/qw/yitian/TA-KD/clip_model/clip_l_14_400m.pth')
+
+     #     model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion400m_e32')
+     #     print(model.visual)
+
+     # model = clip_base_32()
+     # model = clip_base_14()
+
+     # print(parameter_count_table(model))
+
+     # tensor = torch.rand(1, 3, 224, 224)
+     # flops = FlopCountAnalysis(model, tensor)
+     # print("FLOPs: ", flops.total()/1e9)
1_feature_extractor/models_dinov2.py ADDED
@@ -0,0 +1,907 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ from functools import partial
+ import math
+ import logging
+ from typing import Sequence, Tuple, Union, Callable, Dict, Optional, Any, List
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F  # needed by SwiGLUFFN below (import was missing)
+ import torch.utils.checkpoint
+ from torch.nn.init import trunc_normal_
+ import os
+ import warnings
+
+ logger = logging.getLogger("dinov2")
+
+
+ from xformers.ops import memory_efficient_attention, unbind, fmha, scaled_index_add, index_select_cat
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Attention)")
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+         attn = q @ k.transpose(-2, -1)
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class MemEffAttention(Attention):
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         if not XFORMERS_AVAILABLE:
+             if attn_bias is not None:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return super().forward(x)
+
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+         q, k, v = unbind(qkv, 2)
+
+         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+         x = x.reshape([B, N, C])
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+         self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x12 = self.w12(x)
+         x1, x2 = x12.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         return self.w3(hidden)
+
+
+ class SwiGLUFFNFused(SwiGLU):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+         super().__init__(
+             in_features=in_features,
+             hidden_features=hidden_features,
+             out_features=out_features,
+             bias=bias,
+         )
+
+
+ def make_2tuple(x):
+     if isinstance(x, tuple):
+         assert len(x) == 2
+         return x
+
+     assert isinstance(x, int)
+     return (x, x)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+     Args:
+         img_size: Image size.
+         patch_size: Patch token size.
+         in_chans: Number of input image channels.
+         embed_dim: Number of linear projection output channels.
+         norm_layer: Normalization layer.
+     """
+
+     def __init__(
+         self,
+         img_size: Union[int, Tuple[int, int]] = 224,
+         patch_size: Union[int, Tuple[int, int]] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten_embedding: bool = True,
+     ) -> None:
+         super().__init__()
+
+         image_HW = make_2tuple(img_size)
+         patch_HW = make_2tuple(patch_size)
+         patch_grid_size = (
+             image_HW[0] // patch_HW[0],
+             image_HW[1] // patch_HW[1],
+         )
+
+         self.img_size = image_HW
+         self.patch_size = patch_HW
+         self.patches_resolution = patch_grid_size
+         self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.flatten_embedding = flatten_embedding
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         patch_H, patch_W = self.patch_size
+
+         assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+         x = self.proj(x)  # B C H W
+         H, W = x.size(2), x.size(3)
+         x = x.flatten(2).transpose(1, 2)  # B HW C
+         x = self.norm(x)
+         if not self.flatten_embedding:
+             x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+         return x
+
+     def flops(self) -> float:
+         Ho, Wo = self.patches_resolution
+         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Basic_Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         ffn_bias: bool = True,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         init_values=None,
+         drop_path: float = 0.0,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+         attn_class: Callable[..., nn.Module] = Attention,
+         ffn_layer: Callable[..., nn.Module] = Mlp,
+     ) -> None:
+         super().__init__()
+         # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+         self.norm1 = norm_layer(dim)
+         self.attn = attn_class(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             proj_bias=proj_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+         )
+         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = ffn_layer(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+             bias=ffn_bias,
+         )
+         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.sample_drop_ratio = drop_path
+
+     def forward(self, x: Tensor) -> Tensor:
+         def attn_residual_func(x: Tensor) -> Tensor:
+             return self.ls1(self.attn(self.norm1(x)))
+
+         def ffn_residual_func(x: Tensor) -> Tensor:
+             return self.ls2(self.mlp(self.norm2(x)))
+
+         if self.training and self.sample_drop_ratio > 0.1:
+             # the overhead is compensated only for a drop path rate larger than 0.1
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+         elif self.training and self.sample_drop_ratio > 0.0:
+             x = x + self.drop_path1(attn_residual_func(x))
+             x = x + self.drop_path2(ffn_residual_func(x))
+         else:
+             x = x + attn_residual_func(x)
+             x = x + ffn_residual_func(x)
+         return x
+
+
+ def drop_add_residual_stochastic_depth(
+     x: Tensor,
+     residual_func: Callable[[Tensor], Tensor],
+     sample_drop_ratio: float = 0.0,
+ ) -> Tensor:
+     # 1) extract subset using permutation
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     x_subset = x[brange]
+
+     # 2) apply residual_func to get residual
+     residual = residual_func(x_subset)
+
+     x_flat = x.flatten(1)
+     residual = residual.flatten(1)
+
+     residual_scale_factor = b / sample_subset_size
+
+     # 3) add the residual
+     x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+     return x_plus_residual.view_as(x)
+
+
+ def get_branges_scales(x, sample_drop_ratio=0.0):
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     residual_scale_factor = b / sample_subset_size
+     return brange, residual_scale_factor
+
+
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+     if scaling_vector is None:
+         x_flat = x.flatten(1)
+         residual = residual.flatten(1)
+         x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+     else:
+         x_plus_residual = scaled_index_add(
+             x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+         )
+     return x_plus_residual
+
+
+ attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+ def get_attn_bias_and_cat(x_list, branges=None):
+     """
+     this will perform the index select, cat the tensors, and provide the attn_bias from cache
+     """
+     batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+     if all_shapes not in attn_bias_cache.keys():
+         seqlens = []
+         for b, x in zip(batch_sizes, x_list):
+             for _ in range(b):
+                 seqlens.append(x.shape[1])
+         attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+         attn_bias._batch_sizes = batch_sizes
+         attn_bias_cache[all_shapes] = attn_bias
+
+     if branges is not None:
+         cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+     else:
+         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+         cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+     return attn_bias_cache[all_shapes], cat_tensors
+
+
+ def drop_add_residual_stochastic_depth_list(
+     x_list: List[Tensor],
+     residual_func: Callable[[Tensor, Any], Tensor],
+     sample_drop_ratio: float = 0.0,
+     scaling_vector=None,
+ ) -> Tensor:
+     # 1) generate random set of indices for dropping samples in the batch
+     branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+     branges = [s[0] for s in branges_scales]
+     residual_scale_factors = [s[1] for s in branges_scales]
+
+     # 2) get attention bias and index+concat the tensors
+     attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+     # 3) apply residual_func to get residual, and split the result
+     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+     outputs = []
+     for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+         outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+     return outputs
+
+
+ class Block(Basic_Block):
+     def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+         """
+         x_list contains a list of tensors to nest together and run
+         """
+         assert isinstance(self.attn, MemEffAttention)
+
+         if self.training and self.sample_drop_ratio > 0.0:
+
+             def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+             def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.mlp(self.norm2(x))
+
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+             )
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+             )
+             return x_list
+         else:
+
+             def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+             def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.ls2(self.mlp(self.norm2(x)))
+
+             attn_bias, x = get_attn_bias_and_cat(x_list)
+             x = x + attn_residual_func(x, attn_bias=attn_bias)
+             x = x + ffn_residual_func(x)
+             return attn_bias.split(x)
+
+     def forward(self, x_or_x_list):
+         if isinstance(x_or_x_list, Tensor):
+             return super().forward(x_or_x_list)
+         elif isinstance(x_or_x_list, list):
+             if not XFORMERS_AVAILABLE:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return self.forward_nested(x_or_x_list)
+         else:
+             raise AssertionError
+
+
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+     if not depth_first and include_root:
+         fn(module=module, name=name)
+     for child_name, child_module in module.named_children():
+         child_name = ".".join((name, child_name)) if name else child_name
+         named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+     if depth_first and include_root:
+         fn(module=module, name=name)
+     return module
+
+
+ class BlockChunk(nn.ModuleList):
+     def forward(self, x):
+         for b in self:
+             x = b(x)
+         return x
+
+
+ class DinoVisionTransformer(nn.Module):
+     def __init__(
+         self,
+         img_size=224,
+         patch_size=16,
+         in_chans=3,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         ffn_bias=True,
+         proj_bias=True,
+         drop_path_rate=0.0,
+         drop_path_uniform=False,
+         init_values=None,  # for layerscale: None or 0 => no layerscale
+         embed_layer=PatchEmbed,
+         act_layer=nn.GELU,
+         block_fn=Block,
+         ffn_layer="mlp",
+         block_chunks=1,
+         num_register_tokens=0,
+         interpolate_antialias=False,
+         interpolate_offset=0.1,
+     ):
+         """
+         Args:
+             img_size (int, tuple): input image size
+             patch_size (int, tuple): patch size
+             in_chans (int): number of input channels
+             embed_dim (int): embedding dimension
+             depth (int): depth of transformer
+             num_heads (int): number of attention heads
+             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+             qkv_bias (bool): enable bias for qkv if True
+             proj_bias (bool): enable bias for proj in attn if True
+             ffn_bias (bool): enable bias for ffn if True
+             drop_path_rate (float): stochastic depth rate
+             drop_path_uniform (bool): apply uniform drop rate across blocks
+             weight_init (str): weight init scheme
+             init_values (float): layer-scale init values
+             embed_layer (nn.Module): patch embedding layer
+             act_layer (nn.Module): MLP activation layer
+             block_fn (nn.Module): transformer block class
+             ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+             block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+             num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+             interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+             interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+         """
+         super().__init__()
+         norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.num_tokens = 1
+         self.n_blocks = depth
+         self.num_heads = num_heads
+         self.patch_size = patch_size
+         self.num_register_tokens = num_register_tokens
+         self.interpolate_antialias = interpolate_antialias
+         self.interpolate_offset = interpolate_offset
+
+         self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+         assert num_register_tokens >= 0
+         self.register_tokens = (
+             nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+         )
+
+         if drop_path_uniform is True:
+             dpr = [drop_path_rate] * depth
+         else:
+             dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+         if ffn_layer == "mlp":
+             logger.info("using MLP layer as FFN")
+             ffn_layer = Mlp
+         elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+             logger.info("using SwiGLU layer as FFN")
+             ffn_layer = SwiGLUFFNFused
+         elif ffn_layer == "identity":
+             logger.info("using Identity layer as FFN")
+
+             def f(*args, **kwargs):
+                 return nn.Identity()
+
+             ffn_layer = f
+         else:
+             raise NotImplementedError
+
+         blocks_list = [
+             block_fn(
+                 dim=embed_dim,
+                 num_heads=num_heads,
+                 mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias,
+                 proj_bias=proj_bias,
+                 ffn_bias=ffn_bias,
+                 drop_path=dpr[i],
+                 norm_layer=norm_layer,
+                 act_layer=act_layer,
+                 ffn_layer=ffn_layer,
+                 init_values=init_values,
+             )
+             for i in range(depth)
+         ]
+         if block_chunks > 0:
+             self.chunked_blocks = True
+             chunked_blocks = []
+             chunksize = depth // block_chunks
+             for i in range(0, depth, chunksize):
+                 # this is to keep the block index consistent if we chunk the block list
+                 chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+             self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+         else:
+             self.chunked_blocks = False
+             self.blocks = nn.ModuleList(blocks_list)
+
+         self.norm = norm_layer(embed_dim)
+         self.head = nn.Identity()
+
+         self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+         self.init_weights()
+
+     def init_weights(self):
+         trunc_normal_(self.pos_embed, std=0.02)
+         nn.init.normal_(self.cls_token, std=1e-6)
+         if self.register_tokens is not None:
+             nn.init.normal_(self.register_tokens, std=1e-6)
+         named_apply(init_weights_vit_timm, self)
+
+     def interpolate_pos_encoding(self, x, w, h):
+         previous_dtype = x.dtype
+         npatch = x.shape[1] - 1
+         N = self.pos_embed.shape[1] - 1
+         if npatch == N and w == h:
+             return self.pos_embed
+         pos_embed = self.pos_embed.float()
+         class_pos_embed = pos_embed[:, 0]
+         patch_pos_embed = pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = w // self.patch_size
+         h0 = h // self.patch_size
+         # we add a small number to avoid floating point error in the interpolation
+         # see discussion at https://github.com/facebookresearch/dino/issues/8
+         w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+         sqrt_N = math.sqrt(N)
+         sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+             scale_factor=(sx, sy),
+             mode="bicubic",
+             antialias=self.interpolate_antialias,
+         )
+
+         assert int(w0) == patch_pos_embed.shape[-2]
+         assert int(h0) == patch_pos_embed.shape[-1]
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+     def prepare_tokens_with_masks(self, x, masks=None):
+         B, nc, w, h = x.shape
+         x = self.patch_embed(x)
+         if masks is not None:
+             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+         x = x + self.interpolate_pos_encoding(x, w, h)
+
+         if self.register_tokens is not None:
+             x = torch.cat(
+                 (
+                     x[:, :1],
+                     self.register_tokens.expand(x.shape[0], -1, -1),
+                     x[:, 1:],
+                 ),
+                 dim=1,
+             )
+
+         return x
+
+     def forward_features_list(self, x_list, masks_list):
+         x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+         for blk in self.blocks:
+             x = blk(x)
+
+         all_x = x
+         output = []
+         for x, masks in zip(all_x, masks_list):
+             x_norm = self.norm(x)
+             output.append(
+                 {
+                     "x_norm_clstoken": x_norm[:, 0],
+                     "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+                     "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+                     "x_prenorm": x,
+                     "masks": masks,
+                 }
+             )
+         return output
+
+     def forward_features(self, x, masks=None):
+         if isinstance(x, list):
+             return self.forward_features_list(x, masks)
+
+         x = self.prepare_tokens_with_masks(x, masks)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x_norm = self.norm(x)
+         return {
+             "x_norm_clstoken": x_norm[:, 0],
+             "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+             "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+             "x_prenorm": x,
+             "masks": masks,
+         }
+
+     def _get_intermediate_layers_not_chunked(self, x, n=1):
+         x = self.prepare_tokens_with_masks(x)
+         # If n is an int, take the n last blocks. If it's a list, take them
+         output, total_block_len = [], len(self.blocks)
+         blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+             if i in blocks_to_take:
+                 output.append(x)
+         assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+         return output
+
+     def _get_intermediate_layers_chunked(self, x, n=1):
+         x = self.prepare_tokens_with_masks(x)
+         output, i, total_block_len = [], 0, len(self.blocks[-1])
+         # If n is an int, take the n last blocks. If it's a list, take them
+         blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+         for block_chunk in self.blocks:
+             for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                 x = blk(x)
+                 if i in blocks_to_take:
+                     output.append(x)
+                 i += 1
+         assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+         return output
+
+     def get_intermediate_layers(
+         self,
+         x: torch.Tensor,
+         n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+         reshape: bool = False,
+         return_class_token: bool = False,
+         norm=True,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+         if self.chunked_blocks:
+             outputs = self._get_intermediate_layers_chunked(x, n)
+         else:
+             outputs = self._get_intermediate_layers_not_chunked(x, n)
+         if norm:
+             outputs = [self.norm(out) for out in outputs]
+         class_tokens = [out[:, 0] for out in outputs]
+         outputs = [out[:, 1:] for out in outputs]
+         if reshape:
+             B, _, w, h = x.shape
+             outputs = [
+                 out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                 for out in outputs
+             ]
+         if return_class_token:
+             return tuple(zip(outputs, class_tokens))
+         return tuple(outputs)
+
+     def forward(self, *args, is_training=False, **kwargs):
+         ret = self.forward_features(*args, **kwargs)
+         if is_training:
+             return ret
+         else:
+             return self.head(ret["x_norm_clstoken"])
+
+
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
+     """ViT weight initialization, original timm impl (for reproducibility)"""
+     if isinstance(module, nn.Linear):
+         trunc_normal_(module.weight, std=0.02)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+
+
+ def vit_tiny(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=192,
+         depth=12,
+         num_heads=3,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=384,
+         depth=12,
+         num_heads=6,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+     """
+     Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+     """
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=1536,
+         depth=40,
+         num_heads=24,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ if __name__ == "__main__":
+     import argparse
+     from fvcore.nn import FlopCountAnalysis, parameter_count_table
+
+     with torch.no_grad():
+         model = vit_base(img_size=224,
+                          patch_size=14,
+                          init_values=1.0,
+                          ffn_layer='mlp',
+                          block_chunks=0,
+                          num_register_tokens=0,
+                          interpolate_antialias=False,
+                          interpolate_offset=0.1)
+
+         for name, param in model.named_parameters():
+             print(name, param)
+
+         # print(parameter_count_table(model))
+
+         # tensor = torch.rand(1, 3, 224, 224)
+         # flops = FlopCountAnalysis(model, tensor)
+         # print("FLOPs: ", flops.total()/1e9)
1_feature_extractor/models_proteus_clip.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed.nn
+import torch.distributed as dist
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+import models_clip
+
+
+class MetaArch(nn.Module):
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+
+        student_model_dict = dict()
+        teacher_model_dict = dict()
+
+        import_student = getattr(models_clip, cfg.target_model)
+        student = import_student()
+
+        embed_dim = student.embed_dim
+
+        import_teacher = getattr(models_clip, cfg.teacher_model)
+        teacher_backbone = import_teacher(teacher_path=cfg.teacher_path)
+        teacher_backbone.eval()
+
+        student_model_dict['backbone'] = student
+        teacher_model_dict['backbone'] = teacher_backbone
+
+        self.embed_dim = embed_dim
+
+        # initialize parameters and checks
+        self.total_n_global_crops = cfg.batch_size
+
+        self.student = nn.ModuleDict(student_model_dict)
+        self.teacher = nn.ModuleDict(teacher_model_dict)
+
+        teacher_embed_dim = teacher_backbone.embed_dim
+
+        self.token_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.soft_criterion = torch.nn.MSELoss()
+
+        for param in self.teacher.backbone.parameters():
+            param.requires_grad = False
+
+    ## we explicitly remove the patch and feature learning objectives for CLIP training following the original design
+    def forward(self, inputs):
+        global_crops = inputs["collated_global_crops"]
+
+        # compute teacher output
+        # @torch.no_grad()
+        def compute_teacher_output():
+            with torch.no_grad():
+                teacher_backbone_output_dict = self.teacher.backbone(global_crops)
+                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
+
+            return teacher_cls_tokens
+
+        # get the teacher outputs
+        teacher_cls_tokens = compute_teacher_output()
+
+        student_backbone_output_dict_unmask = self.student.backbone(global_crops)
+
+        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
+
+        ## projection head
+        student_cls_token_unmask = self.token_head(student_cls_token_unmask)
+
+        ## token objective
+        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)
+
+        # coefficient
+        token_loss = self.cfg.lambda_token * distillation_loss_token
+
+        # compute the total loss
+        total_loss = token_loss
+
+        # return the final loss dict
+        loss_dict = {"patch_loss": torch.tensor(0.0), "fea_loss": torch.tensor(0.0), "token_loss": token_loss, "loss": total_loss}
+
+        return loss_dict
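The whole CLIP-distillation objective therefore reduces to matching CLS tokens through a small projection head. A standalone shape-level sketch (illustrative only; the tensors are random stand-ins and the widths are assumed values, not fixed by this repo):

```python
# CLS-token distillation in isolation: project the student CLS token to the
# teacher width, then take an MSE against the frozen teacher's CLS token.
import torch

B, d_student, d_teacher = 8, 384, 768    # assumed widths for illustration
student_cls = torch.randn(B, d_student)
teacher_cls = torch.randn(B, d_teacher)  # would come from the frozen teacher

token_head = torch.nn.Sequential(        # same layout as MetaArch.token_head
    torch.nn.LayerNorm(d_student),
    torch.nn.Linear(d_student, d_teacher))

token_loss = torch.nn.functional.mse_loss(token_head(student_cls), teacher_cls)
print(token_loss.item())
```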
1_feature_extractor/models_proteus_dinov2.py ADDED
@@ -0,0 +1,200 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed.nn
+import torch.distributed as dist
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+import models_dinov2
+from models_IB import IF_Module
+import math
+
+
+class MetaArch(nn.Module):
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+
+        student_model_dict = dict()
+        teacher_model_dict = dict()
+
+        import_student = getattr(models_dinov2, cfg.target_model)
+        student = import_student(img_size=224,
+                                 patch_size=cfg.patch_size,
+                                 init_values=1.0,
+                                 ffn_layer='mlp',
+                                 block_chunks=0,
+                                 num_register_tokens=0,
+                                 interpolate_antialias=False,
+                                 interpolate_offset=0.1)
+
+        embed_dim = student.embed_dim
+
+        if cfg.teacher_model == 'vit_base':
+            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
+        elif cfg.teacher_model == 'vit_small':
+            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
+        elif cfg.teacher_model == 'vit_large':
+            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
+        elif cfg.teacher_model == 'vit_giant':
+            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')
+        teacher_backbone.eval()
+
+        student_model_dict['backbone'] = student
+        teacher_model_dict['backbone'] = teacher_backbone.backbone
+
+        self.embed_dim = embed_dim
+
+        # initialize parameters and checks
+        self.total_n_global_crops = cfg.batch_size
+
+        self.student = nn.ModuleDict(student_model_dict)
+        self.teacher = nn.ModuleDict(teacher_model_dict)
+
+        teacher_embed_dim = teacher_backbone.backbone.embed_dim
+        self.ibot_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.token_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.fea_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.soft_criterion = torch.nn.MSELoss()
+
+        self.info_bottleneck = IF_Module(embed_dim=embed_dim, num_heads=12, mlp_ratio=4, depth=4)
+
+        for param in self.teacher.backbone.parameters():
+            param.requires_grad = False
+
+    def cal_bpp(self, image, unmask_likelihood, mask_likelihood):
+        b, _, h, w = image.size()
+        num_pixels = b * h * w
+        log_unmask_likelihoods = torch.log(unmask_likelihood)
+        log_mask_likelihoods = torch.log(mask_likelihood)
+        bpp = (log_unmask_likelihoods.sum() + log_mask_likelihoods.sum()) / (-math.log(2) * num_pixels * 1.5)
+        return bpp
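`cal_bpp` turns the entropy-model likelihoods into a bitrate estimate: `-log2` of a likelihood is its ideal code length in bits, summed over both the masked and unmasked passes and normalized by the pixel budget (the extra 1.5 in the denominator appears to jointly rescale the two passes; that reading is an assumption, not documented in the repo). A standalone numeric sketch under those assumptions:

```python
# Illustrative bpp computation with random stand-in likelihoods in (0, 1).
import math
import torch

b, h, w = 4, 224, 224
unmask_likelihood = torch.rand(b, 256).clamp(0.05, 0.95)
mask_likelihood = torch.rand(b, 256).clamp(0.05, 0.95)

num_pixels = b * h * w
total_bits = -(torch.log(unmask_likelihood).sum()
               + torch.log(mask_likelihood).sum()) / math.log(2)
bpp = total_bits / (num_pixels * 1.5)  # same value as cal_bpp above
print(bpp.item())
```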
+
+    def forward(self, inputs):
+        global_crops = inputs["collated_global_crops"]
+
+        masks = inputs["collated_masks"]
+        mask_indices_list = inputs["mask_indices_list"]
+        n_masked_patches = mask_indices_list.shape[0]
+        upperbound = inputs["upperbound"]
+
+        n_global_crops = 1
+
+        # compute teacher output
+        # @torch.no_grad()
+        def compute_teacher_output():
+            with torch.no_grad():
+                teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True)
+                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
+                teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
+                _dim = teacher_patch_tokens.shape[-1]
+
+                # mask teacher patch tokens
+                buffer_tensor_teacher = teacher_patch_tokens.new_zeros(upperbound, _dim)
+                torch.index_select(
+                    teacher_patch_tokens.flatten(0, 1),
+                    dim=0,
+                    index=mask_indices_list,
+                    out=buffer_tensor_teacher[:n_masked_patches],
+                )
+                teacher_patch_tokens_masked = buffer_tensor_teacher[:n_masked_patches]
+
+            return teacher_cls_tokens, teacher_patch_tokens, teacher_patch_tokens_masked
+
+        # get the teacher outputs
+        (
+            teacher_cls_tokens,
+            teacher_patch_tokens,
+            teacher_patch_tokens_masked
+        ) = compute_teacher_output()
+
+        cur_masks = masks if self.cfg.mask_probability > 0 else None
+
+        student_backbone_output_dict, student_backbone_output_dict_unmask = self.student.backbone(
+            [global_crops, global_crops], masks=[cur_masks, None], is_training=True
+        )
+
+        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
+        student_patch_tokens_unmask = student_backbone_output_dict_unmask["x_norm_patchtokens"]
+        student_patch_tokens = student_backbone_output_dict["x_norm_patchtokens"]
+
+        # calculate bitrate
+        student_patch_tokens_unmask, unmask_likelihood = self.info_bottleneck(student_patch_tokens_unmask, is_training=True)
+        student_patch_tokens, mask_likelihood = self.info_bottleneck(student_patch_tokens, is_training=True)
+        bpp = self.cal_bpp(global_crops, unmask_likelihood, mask_likelihood)
+
+        # mask student patch tokens
+        _dim = student_patch_tokens.shape[-1]
+
+        buffer_tensor_student = student_patch_tokens.new_zeros(upperbound, _dim)
+        buffer_tensor_student[:n_masked_patches].copy_(
+            torch.index_select(student_patch_tokens.flatten(0, 1),
+                               dim=0,
+                               index=mask_indices_list)
+        )
+
+        ## projection head
+        student_patch_tokens_unmask = self.fea_head(student_patch_tokens_unmask)
+
+        student_cls_token_unmask = self.token_head(student_cls_token_unmask)
+
+        tokens_after_head = self.ibot_head(buffer_tensor_student)
+        student_patch_tokens_masked = tokens_after_head[:n_masked_patches]
+
+        ## token objective
+        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)
+
+        ## fea objective
+        student_whole_fea = torch.cat((student_cls_token_unmask.unsqueeze(1), student_patch_tokens_unmask), dim=1)
+        teacher_whole_fea = torch.cat((teacher_cls_tokens.unsqueeze(1), teacher_patch_tokens), dim=1)
+        distillation_loss_fea = self.soft_criterion(student_whole_fea, teacher_whole_fea)
+
+        ## patch objective
+        patch_loss = self.soft_criterion(student_patch_tokens_masked, teacher_patch_tokens_masked)
+
+        # coefficient
+        token_loss = self.cfg.lambda_token * distillation_loss_token
+        fea_loss = self.cfg.lambda_fea * distillation_loss_fea
+        patch_loss_weighted = self.cfg.lambda_patch * patch_loss
+
+        # compute the total loss (0.48 weights the rate term against the task terms)
+        total_loss = patch_loss_weighted + fea_loss + token_loss + 0.48 * bpp
+        # task_loss = patch_loss + fea_loss + token_loss
+        task_loss = patch_loss + distillation_loss_fea + distillation_loss_token
+
+        # return the final loss dict
+        # (token_loss is logged after weighting; patch_loss and fea_loss are logged unweighted)
+        loss_dict = {"bpp_loss": bpp,
+                     "patch_loss": patch_loss,
+                     "fea_loss": distillation_loss_fea,
+                     "token_loss": token_loss,
+                     "loss": total_loss,
+                     "task_loss": task_loss,
+                     }
+
+        return loss_dict
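Teacher and student both pick out masked patch tokens the same way: flatten the (batch, tokens) dimensions, `index_select` the masked positions into a preallocated `upperbound`-row buffer, and slice off the first `n_masked_patches` rows. A toy version of that gather (illustrative shapes only):

```python
# Masked-token gathering with a fixed-size buffer, as in the forward pass above.
import torch

B, N, D = 2, 16, 8                                # toy batch / tokens / width
patch_tokens = torch.randn(B, N, D)
mask_indices_list = torch.tensor([3, 5, 20, 31])  # flat indices into B*N tokens
upperbound = 12                                   # static allocation >= n_masked
n_masked_patches = mask_indices_list.shape[0]

buffer_tensor = patch_tokens.new_zeros(upperbound, D)
torch.index_select(patch_tokens.flatten(0, 1), dim=0,
                   index=mask_indices_list,
                   out=buffer_tensor[:n_masked_patches])
masked_tokens = buffer_tensor[:n_masked_patches]  # (4, 8)
print(masked_tokens.shape)
```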
1_feature_extractor/models_proteus_synclr.py ADDED
@@ -0,0 +1,161 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed.nn
+import torch.distributed as dist
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+import models_synclr
+
+
+class MetaArch(nn.Module):
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+
+        student_model_dict = dict()
+        teacher_model_dict = dict()
+
+        import_student = getattr(models_synclr, cfg.target_model)
+        student = import_student(patch_size=cfg.patch_size, num_classes=0, mask_style='ibot')
+
+        embed_dim = student.embed_dim
+
+        import_teacher = getattr(models_synclr, cfg.teacher_model)
+        teacher_backbone = import_teacher(patch_size=cfg.patch_size, teacher_path=cfg.teacher_path, num_classes=0, mask_style='ibot')
+        teacher_backbone.eval()
+
+        student_model_dict['backbone'] = student
+        teacher_model_dict['backbone'] = teacher_backbone
+
+        self.embed_dim = embed_dim
+
+        # initialize parameters and checks
+        self.total_n_global_crops = cfg.batch_size
+
+        self.student = nn.ModuleDict(student_model_dict)
+        self.teacher = nn.ModuleDict(teacher_model_dict)
+
+        teacher_embed_dim = teacher_backbone.embed_dim
+        self.patch_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.token_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.fea_head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, teacher_embed_dim))
+
+        self.soft_criterion = torch.nn.MSELoss()
+
+        for param in self.teacher.backbone.parameters():
+            param.requires_grad = False
+
+    def forward(self, inputs):
+        global_crops = inputs["collated_global_crops"]
+
+        masks = inputs["collated_masks"]
+        mask_indices_list = inputs["mask_indices_list"]
+        n_masked_patches = mask_indices_list.shape[0]
+        upperbound = inputs["upperbound"]
+
+        n_global_crops = 1
+
+        # compute teacher output
+        # @torch.no_grad()
+        def compute_teacher_output():
+            with torch.no_grad():
+                teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True)
+                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
+                teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
+                _dim = teacher_patch_tokens.shape[-1]
+
+                # mask teacher patch tokens
+                buffer_tensor_teacher = teacher_patch_tokens.new_zeros(upperbound, _dim)
+                torch.index_select(
+                    teacher_patch_tokens.flatten(0, 1),
+                    dim=0,
+                    index=mask_indices_list,
+                    out=buffer_tensor_teacher[:n_masked_patches],
+                )
+                teacher_patch_tokens_masked = buffer_tensor_teacher[:n_masked_patches]
+
+            return teacher_cls_tokens, teacher_patch_tokens, teacher_patch_tokens_masked
+
+        # get the teacher outputs
+        (
+            teacher_cls_tokens,
+            teacher_patch_tokens,
+            teacher_patch_tokens_masked
+        ) = compute_teacher_output()
+
+        cur_masks = masks if self.cfg.mask_probability > 0 else None
+
+        student_backbone_output_dict, student_backbone_output_dict_unmask = self.student.backbone(
+            [global_crops, global_crops], masks=[cur_masks, None], is_training=True
+        )
+
+        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
+        student_patch_tokens_unmask = student_backbone_output_dict_unmask["x_norm_patchtokens"]
+        student_patch_tokens = student_backbone_output_dict["x_norm_patchtokens"]
+
+        # mask student patch tokens
+        _dim = student_patch_tokens.shape[-1]
+
+        buffer_tensor_student = student_patch_tokens.new_zeros(upperbound, _dim)
+        buffer_tensor_student[:n_masked_patches].copy_(
+            torch.index_select(student_patch_tokens.flatten(0, 1),
+                               dim=0,
+                               index=mask_indices_list)
+        )
+
+        ## projection head
+        student_patch_tokens_unmask = self.fea_head(student_patch_tokens_unmask)
+
+        student_cls_token_unmask = self.token_head(student_cls_token_unmask)
+
+        tokens_after_head = self.patch_head(buffer_tensor_student)
+        student_patch_tokens_masked = tokens_after_head[:n_masked_patches]
+
+        ## token objective
+        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)
+
+        ## fea objective
+        student_whole_fea = torch.cat((student_cls_token_unmask.unsqueeze(1), student_patch_tokens_unmask), dim=1)
+        teacher_whole_fea = torch.cat((teacher_cls_tokens.unsqueeze(1), teacher_patch_tokens), dim=1)
+        distillation_loss_fea = self.soft_criterion(student_whole_fea, teacher_whole_fea)
+
+        ## patch objective
+        patch_loss = self.soft_criterion(student_patch_tokens_masked, teacher_patch_tokens_masked)
+
+        # coefficient
+        token_loss = self.cfg.lambda_token * distillation_loss_token
+        fea_loss = self.cfg.lambda_fea * distillation_loss_fea
+        patch_loss = self.cfg.lambda_patch * patch_loss
+
+        # compute the total loss
+        total_loss = patch_loss + fea_loss + token_loss
+
+        # return the final loss dict
+        loss_dict = {"patch_loss": patch_loss, "fea_loss": fea_loss, "token_loss": token_loss, "loss": total_loss}
+
+        return loss_dict
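The `fea` objective compares the full token sequence: the CLS token is prepended to the patch tokens on both sides so a single MSE covers everything at once. A toy sketch with random stand-in tensors (illustrative shapes only):

```python
# The whole-sequence ("fea") objective in isolation.
import torch

B, N, D = 4, 256, 768
student_cls, teacher_cls = torch.randn(B, D), torch.randn(B, D)
student_patches, teacher_patches = torch.randn(B, N, D), torch.randn(B, N, D)

student_whole = torch.cat((student_cls.unsqueeze(1), student_patches), dim=1)  # (B, 1+N, D)
teacher_whole = torch.cat((teacher_cls.unsqueeze(1), teacher_patches), dim=1)
fea_loss = torch.nn.functional.mse_loss(student_whole, teacher_whole)
print(fea_loss.item())
```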
1_feature_extractor/models_synclr.py ADDED
@@ -0,0 +1,500 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import partial
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from timm.models.layers import trunc_normal_, lecun_normal_, to_2tuple
+from timm.models.vision_transformer import Attention
+from timm.models.layers import Mlp, DropPath
+from timm.models.helpers import named_apply
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ffn_targets=False,
+                 return_layer_targets=False):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        # specify the targets for feature regression
+        self.ffn_targets = ffn_targets
+        self.return_layer_targets = return_layer_targets
+
+    def forward(self, x):
+        if isinstance(x, tuple):
+            x = x[0]
+
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        ffn_out = self.mlp(self.norm2(x))
+        x = x + self.drop_path(ffn_out)
+
+        target = ffn_out if self.ffn_targets else x
+
+        if self.return_layer_targets:
+            return x, target
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ 2D Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.flatten = flatten
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        x = self.norm(x)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer
+
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+        - https://arxiv.org/abs/2010.11929
+
+    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+        - https://arxiv.org/abs/2012.12877
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+                 act_layer=None, weight_init='', ffn_targets=False, return_layer_targets=False):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            distilled (bool): model includes a distillation token and head as in DeiT models
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            embed_layer (nn.Module): patch embedding layer
+            norm_layer (nn.Module): normalization layer
+            weight_init (str): weight init scheme
+            ffn_targets (bool): whether we use ffn output or block end as the feature targets
+            return_layer_targets (bool): whether we return every layer targets
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 2 if distilled else 1
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
+
+        self.patch_embed = embed_layer(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        self.ffn_targets = ffn_targets
+        self.return_layer_targets = return_layer_targets
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.Sequential(*[
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
+                ffn_targets=ffn_targets, return_layer_targets=return_layer_targets,
+            )
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # Representation layer
+        if representation_size and not distilled:
+            self.num_features = representation_size
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(embed_dim, representation_size)),
+                ('act', nn.Tanh())
+            ]))
+        else:
+            self.pre_logits = nn.Identity()
+
+        # Classifier head(s)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = None
+        if distilled:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+        self.init_weights(weight_init)
+
+    def init_weights(self, mode=''):
+        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+        trunc_normal_(self.pos_embed, std=.02)
+        if self.dist_token is not None:
+            trunc_normal_(self.dist_token, std=.02)
+        if mode.startswith('jax'):
+            # leave cls token as zeros to match jax impl
+            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
+        else:
+            trunc_normal_(self.cls_token, std=.02)
+            self.apply(_init_vit_weights)
+
+    def _init_weights(self, m):
+        # this fn left here for compat with downstream users
+        _init_vit_weights(m)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token', 'dist_token'}
+
+    def get_classifier(self):
+        if self.dist_token is None:
+            return self.head
+        else:
+            return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        if self.num_tokens == 2:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.dist_token is None:
+            x = torch.cat((cls_token, x), dim=1)
+        else:
+            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.blocks(x)
+        x = self.norm(x)
+        if self.dist_token is None:
+            return self.pre_logits(x[:, 0])
+        else:
+            return x[:, 0], x[:, 1]
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.head_dist is not None:
+            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
+            if self.training and not torch.jit.is_scripting():
+                # during training, return both classifier predictions
+                return x, x_dist
+            else:
+                # during inference, return the average of both classifier predictions
+                return (x + x_dist) / 2
+        else:
+            x = self.head(x)
+        return x
+
+
+def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
+    """ ViT weight initialization
+    * When called without n, head_bias, jax_impl args it will behave exactly the same
+      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+    """
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        elif name.startswith('pre_logits'):
+            lecun_normal_(module.weight)
+            nn.init.zeros_(module.bias)
+        else:
+            if jax_impl:
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    if 'mlp' in name:
+                        nn.init.normal_(module.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(module.bias)
+            else:
+                trunc_normal_(module.weight, std=.02)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+    elif jax_impl and isinstance(module, nn.Conv2d):
+        # NOTE conv was left to pytorch default in my original init
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+        nn.init.zeros_(module.bias)
+        nn.init.ones_(module.weight)
+
+
+def compute_gather_ids(masks):
+    unmask_indices = masks.logical_not().nonzero(as_tuple=False)
+    ids_keep = unmask_indices[:, -1].reshape(masks.shape[0], -1)
+    return ids_keep
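`compute_gather_ids` relies on two conventions: masks use 1 = masked / 0 = visible, and every sample keeps the same number of visible tokens (otherwise the final `reshape` fails). A tiny worked example:

```python
# Worked example for compute_gather_ids (1 = masked, 0 = kept).
import torch

masks = torch.tensor([[1, 0, 0, 1],
                      [0, 1, 1, 0]], dtype=torch.bool)
unmask_indices = masks.logical_not().nonzero(as_tuple=False)  # (row, col) pairs
ids_keep = unmask_indices[:, -1].reshape(masks.shape[0], -1)
print(ids_keep)  # tensor([[1, 2], [0, 3]])
```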
+
+
+class MaskedTransformer(VisionTransformer):
+    """Inherit vision transformer from timm"""
+
+    def __init__(self, mask_style='ibot', **kwargs):
+        super().__init__(**kwargs)
+        assert mask_style in ["ibot", "mae", "none"], "mask_style must be `ibot`, `mae`, or `none`"
+
+        self.patch_size = self.patch_embed.patch_size
+        if isinstance(self.patch_size, tuple):
+            self.patch_size = self.patch_size[0]
+
+        nn.init.normal_(self.cls_token, std=1e-6)
+
+        self.mask_style = mask_style
+        if self.mask_style == "ibot":
+            self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+            torch.nn.init.normal_(self.mask_token, std=.02)
+
+    def interpolate_pos_encoding(self, x, w, h, npatch):
+        previous_dtype = x.dtype
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode="bicubic",
+        )
+
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+    def prepare_tokens_with_masks(self, x, masks=None):
+        """
+        Args:
+            x: data w/ shape [b, c, h, w]
+            masks: shape [b, n], n is the number of tokens, 1 means masked, 0 means unmasked
+        """
+        b, c, h, w = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            if self.mask_style == 'ibot':
+                x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype), x)
+            elif self.mask_style == 'mae':  # only gather unmasked patches
+                # add pos_embed before shuffle
+                pos_embed = self.interpolate_pos_encoding(x, w, h, npatch=x.shape[1])
+                x = x + pos_embed[:, 1:, :]
+                ids_keep = compute_gather_ids(masks)
+                x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
+                # x = x[masks.logical_not()]
+                # x = x.reshape(b, -1, x.size(-1))
+            else:
+                raise NotImplementedError(f"mask style {self.mask_style} is not supported")
+
+        if (masks is None) or (self.mask_style != "mae"):
+            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            x = x + self.interpolate_pos_encoding(x, w, h, npatch=x.shape[1]-1)
+        else:
+            # mae-style masking, only need to add cls tokens w/ pos embedding
+            cls_token = self.cls_token + self.pos_embed[:, :1, :]
+            x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+        return x
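`prepare_tokens_with_masks` implements the two masking conventions side by side: `ibot` keeps the sequence length and swaps masked patch embeddings for the learned `mask_token`, while `mae` drops the masked patches after adding positional embeddings. A toy contrast (illustrative shapes only):

```python
# ibot vs mae masking on a toy token sequence.
import torch

B, N, D = 2, 4, 8
x = torch.randn(B, N, D)                      # patch embeddings
masks = torch.tensor([[1, 0, 0, 1],
                      [0, 1, 1, 0]], dtype=torch.bool)
mask_token = torch.zeros(1, 1, D)             # stands in for the learned token

# ibot: same length, masked slots replaced by the mask token
x_ibot = torch.where(masks.unsqueeze(-1), mask_token, x)                      # (2, 4, 8)

# mae: gather only the visible tokens -> shorter sequence
ids_keep = masks.logical_not().nonzero(as_tuple=False)[:, -1].reshape(B, -1)
x_mae = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # (2, 2, 8)
print(x_ibot.shape, x_mae.shape)
```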
+
+    def forward_features_list(self, x_list, masks_list):
+        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+        num_data = len(x)
+        if self.return_layer_targets:
+            all_layer_results = [[] for _ in range(num_data)]
+            for i, blk in enumerate(self.blocks):
+                out = [blk(t) for t in x]
+                x = [o[0] for o in out]
+                # store layer targets
+                for j in range(num_data):
+                    all_layer_results[j].append(out[j][1])
+            all_x = x
+        else:
+            all_x = [self.blocks(t) for t in x]
+            all_layer_results = [None for _ in range(num_data)]
+
+        output = []
+        for x, masks, layer_results in zip(all_x, masks_list, all_layer_results):
+            x_norm = self.norm(x)
+            output.append(
+                {
+                    "x_norm": x_norm,
+                    "x_norm_clstoken": x_norm[:, 0],
+                    "x_norm_patchtokens": x_norm[:, 1:],
+                    "masks": masks,
+                    "layer_results": layer_results,
+                }
+            )
+        return output
+
+    def forward_features(self, x, masks=None):
+        if isinstance(x, list):
+            return self.forward_features_list(x, masks)
+
+        x = self.prepare_tokens_with_masks(x, masks)
+
+        if self.return_layer_targets:
+            layer_results = []
+            for i, blk in enumerate(self.blocks):
+                x, lr = blk(x)
+                layer_results.append(lr)
+        else:
+            x = self.blocks(x)
+            layer_results = None
+
+        x_norm = self.norm(x)
+        return {
+            "x_norm": x_norm,
+            "x_norm_clstoken": x_norm[:, 0],
+            "x_norm_patchtokens": x_norm[:, 1:],
+            "masks": masks,
+            "layer_results": layer_results,
+        }
+
+    def forward(self, *args, is_training=False, **kwargs):
+        ret = self.forward_features(*args, **kwargs)
+        if is_training:
+            return ret
+        else:
+            return ret["x_norm_clstoken"]
+
+
+def vit_small(patch_size=16, teacher_path=None, **kwargs):
+    model = MaskedTransformer(
+        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, **kwargs)
+
+    if teacher_path is not None:
+        checkpoint = torch.load(teacher_path, map_location='cpu')
+
+        if 'state_dict' in checkpoint:
+            pretrained_dict = checkpoint['state_dict']
+        elif 'model' in checkpoint:
+            pretrained_dict = checkpoint['model']
+        else:
+            pretrained_dict = checkpoint
+
+        pretrained_dict = {k.replace("module.visual.", ""): v for k, v in pretrained_dict.items()}
+
+        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+        print('missing_keys: ', missing_keys)
+        print('unexpected_keys: ', unexpected_keys)
+
+    return model
+
+
+def vit_base(patch_size=16, teacher_path=None, **kwargs):
+    model = MaskedTransformer(
+        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, **kwargs)
+
+    if teacher_path is not None:
+        checkpoint = torch.load(teacher_path, map_location='cpu')
+
+        if 'state_dict' in checkpoint:
+            pretrained_dict = checkpoint['state_dict']
+        elif 'model' in checkpoint:
+            pretrained_dict = checkpoint['model']
+        else:
+            pretrained_dict = checkpoint
+
+        pretrained_dict = {k.replace("module.visual.", ""): v for k, v in pretrained_dict.items()}
+
+        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+        print('missing_keys: ', missing_keys)
+        print('unexpected_keys: ', unexpected_keys)
+
+    return model
+
+
+def vit_large(patch_size=14, teacher_path=None, **kwargs):
+    model = MaskedTransformer(
+        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+
+    if teacher_path is not None:
+        checkpoint = torch.load(teacher_path, map_location='cpu')
+
+        if 'state_dict' in checkpoint:
+            pretrained_dict = checkpoint['state_dict']
+        elif 'model' in checkpoint:
+            pretrained_dict = checkpoint['model']
+        else:
+            pretrained_dict = checkpoint
+
+        pretrained_dict = {k.replace("module.visual.", ""): v for k, v in pretrained_dict.items()}
+
+        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
+        print('missing_keys: ', missing_keys)
+        print('unexpected_keys: ', unexpected_keys)
+
+    return model
+
+
+if __name__ == '__main__':
+    import argparse
+    from fvcore.nn import FlopCountAnalysis, parameter_count_table
+    parser = argparse.ArgumentParser(description='PyTorch resnet Training')
+    args = parser.parse_args()
+
+    with torch.no_grad():
+        model = vit_base(patch_size=14, num_classes=0, mask_style='ibot')
+
+        # x = torch.randn(1, 3, 224, 224)
+        # out = model(x)
+        # print(out.shape)
+
+        print(parameter_count_table(model))
+
+        tensor = torch.rand(1, 3, 224, 224)
+        flops = FlopCountAnalysis(model, tensor)
+        print("FLOPs: ", flops.total()/1e9)
1_feature_extractor/original_images.png ADDED
1_feature_extractor/requirements.txt ADDED
@@ -0,0 +1,134 @@
+absl-py==2.1.0
+accelerate==0.33.0
+aiohttp==3.9.5
+aiosignal==1.3.1
+antlr4-python3-runtime==4.9.3
+async-timeout==4.0.3
+attrs==23.2.0
+cachetools==5.4.0
+certifi==2024.7.4
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpickle==3.0.0
+cmake==3.30.1
+compressai==1.2.6
+contourpy==1.2.1
+cubinlinker-cu11==0.3.0.post2
+cuda-python==11.8.2
+cudf-cu11==24.6.1
+cuml-cu11==24.6.1
+cupy-cuda11x==13.2.0
+cycler==0.12.1
+Cython==3.0.10
+dask==2024.5.1
+dask-cuda==24.6.0
+dask-cudf-cu11==24.6.1
+dask-expr==1.1.1
+diffusers==0.30.0
+distributed==2024.5.1
+distributed-ucxx-cu11==0.38.0
+einops==0.8.0
+fastrlock==0.8.2
+filelock==3.15.4
+fonttools==4.53.1
+frozenlist==1.4.1
+fsspec==2024.6.1
+ftfy==6.2.0
+future==1.0.0
+fvcore==0.1.5.post20221221
+grpcio==1.65.2
+h5py==3.11.0
+huggingface-hub==0.24.5
+idna==3.7
+importlib_metadata==8.0.0
+importlib_resources==6.4.0
+iopath==0.1.10
+Jinja2==3.1.4
+joblib==1.4.2
+kiwisolver==1.4.5
+libucx-cu11==1.15.0.post1
+lightning-utilities==0.11.6
+lit==18.1.8
+llvmlite==0.43.0
+locket==1.0.0
+Markdown==3.6
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.1
+mdurl==0.1.2
+mpmath==1.3.0
+msgpack==1.0.8
+multidict==6.0.5
+mypy-extensions==1.0.0
+networkx==3.2.1
+numba==0.60.0
+numpy==1.26.4
+nvtx==0.2.10
+omegaconf==2.3.0
+open-clip-torch==2.0.2
+opencv-python==4.10.0.84
+packaging==24.1
+pandas==2.2.2
+partd==1.4.2
+pillow==10.4.0
+portalocker==2.10.1
+protobuf==4.25.4
+psutil==6.0.0
+ptxcompiler-cu11==0.8.1.post1
+pyarrow==16.1.0
+pycocotools==2.0.8
+pyDeprecate==0.3.1
+Pygments==2.18.0
+pylibraft-cu11==24.6.0
+pynvml==11.4.1
+pyparsing==3.1.2
+pyre-extensions==0.0.23
+python-dateutil==2.9.0.post0
+pytorch-lightning==1.5.0
+pytorch-msssim==1.0.0
+pytz==2024.1
+PyYAML==6.0.1
+raft-dask-cu11==24.6.0
+rapids-dask-dependency==24.6.0
+regex==2024.7.24
+requests==2.32.3
+rich==13.7.1
+rmm-cu11==24.6.0
+safetensors==0.4.3
+scikit-learn==1.5.1
+scipy==1.13.1
+six==1.16.0
+sortedcontainers==2.4.0
+submitit==1.5.1
+sympy==1.13.1
+tabulate==0.9.0
+tblib==3.0.0
+tensorboard==2.17.0
+tensorboard-data-server==0.7.2
+termcolor==2.4.0
+threadpoolctl==3.5.0
+timm==1.0.7
+tokenizers==0.13.3
+toolz==0.12.1
+torch==2.0.0+cu117
+torch_geometric==2.5.3
+torchmetrics==0.10.3
+torchvision==0.15.0+cu117
+tornado==6.4.1
+tqdm==4.66.4
+transformers==4.27.0
+treelite==4.1.2
+triton==2.0.0
+typing-inspect==0.9.0
+typing_extensions==4.12.2
+tzdata==2024.1
+ucx-py-cu11==0.38.0
+ucxx-cu11==0.38.0
+urllib3==2.2.2
+wcwidth==0.2.13
+Werkzeug==3.0.3
+xformers==0.0.18
+yacs==0.1.8
+yarl==1.9.4
+zict==3.0.0
+zipp==3.19.2