File size: 20,779 Bytes
3c849be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
models and functions for building student and teacher networks for multi-granular losses.
"""
import torch
import torch.nn as nn

import src.vision_transformer as vits
from src.vision_transformer import trunc_normal_


class Instance_Superivsion_Head(nn.Module):
    """
    a class to implement Instance Superivsion Head
    --in_dim: input dimension of projection head
    --hidden_dim: hidden dimension of projection head
    --out_dim: ouput dimension of projection and prediction heads
    --pred_hidden_dim: hidden dimension of prediction head
    --nlayers: layer number of projection head. prediction head has nlayers-1 layer
    --proj_bn: whether we use batch normalization in projection head
    --pred_bn: whether we use batch normalization in prediction head
    --norm_before_pred:  whether we use normalization before prediction head
    """

    def __init__(
        self,
        in_dim,
        hidden_dim=2048,
        out_dim=256,
        pred_hidden_dim=4096,
        nlayers=3,
        proj_bn=False,
        pred_bn=False,
        norm_before_pred=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.norm_before_pred = norm_before_pred

        self.projector = self._build_mlp(
            nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
        )

        self.apply(self._init_weights)

        self.predictor = None
        if pred_hidden_dim > 0:  # teacher no, student yes
            self.predictor = self._build_mlp(
                nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
            )

    def _init_weights(self, m):
        """
        initilize the parameters in network
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
        """
        build a mlp
        """
        mlp = []
        for layer in range(num_layers):
            dim1 = input_dim if layer == 0 else hidden_dim
            dim2 = output_dim if layer == num_layers - 1 else hidden_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if layer < num_layers - 1:
                if use_bn:
                    mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.GELU())

        return nn.Sequential(*mlp)

    def forward(self, x, return_target=False):
        """
        forward the input through projection head for teacher and
        projection/prediction heads for student
        """
        feat = self.projector(x)

        if return_target:
            feat = nn.functional.normalize(feat, dim=-1, p=2)
            return feat
        ## return prediction
        if self.norm_before_pred:
            feat = nn.functional.normalize(feat, dim=-1, p=2)
        pred = self.predictor(feat)
        pred = nn.functional.normalize(pred, dim=-1, p=2)
        return pred


class Local_Group_Superivsion_Head(nn.Module):
    """
    a class to implement Local Group Superivsion Head which is the same as Instance Superivsion Head
    --in_dim: input dimension of projection head
    --hidden_dim: hidden dimension of projection head
    --out_dim: ouput dimension of projection and prediction heads
    --pred_hidden_dim: hidden dimension of prediction head
    --nlayers: layer number of projection head. prediction head has nlayers-1 layer
    --proj_bn: whether we use batch normalization in projection head
    --pred_bn: whether we use batch normalization in prediction head
    --norm_before_pred:  whether we use normalization before prediction head
    """

    def __init__(
        self,
        in_dim,
        hidden_dim=2048,
        out_dim=256,
        pred_hidden_dim=4096,
        nlayers=3,
        proj_bn=False,
        pred_bn=False,
        norm_before_pred=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.norm_before_pred = norm_before_pred

        self.projector = self._build_mlp(
            nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
        )

        self.apply(self._init_weights)

        self.predictor = None
        if pred_hidden_dim > 0:  # teacher no, student yes
            self.predictor = self._build_mlp(
                nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
            )

    def _init_weights(self, m):
        """
        initilize the parameters in network
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
        """
        build a mlp
        """
        mlp = []
        for layer in range(num_layers):
            dim1 = input_dim if layer == 0 else hidden_dim
            dim2 = output_dim if layer == num_layers - 1 else hidden_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if layer < num_layers - 1:
                if use_bn:
                    mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.GELU())

        return nn.Sequential(*mlp)

    def forward(self, x, return_target=False):
        """
        forward the input through projection head for teacher and
        projection/prediction heads for student
        """
        feat = self.projector(x)

        if return_target:
            feat = nn.functional.normalize(feat, dim=-1, p=2)
            return feat
        ## return prediction
        if self.norm_before_pred:
            feat = nn.functional.normalize(feat, dim=-1, p=2)
        pred = self.predictor(feat)
        pred = nn.functional.normalize(pred, dim=-1, p=2)
        return pred


class Group_Superivsion_Head(nn.Module):
    """
    a class to implement Local Group Superivsion Head which is the same as Instance Superivsion Head
    --in_dim: input dimension of projection head
    --hidden_dim: hidden dimension of projection head
    --out_dim: ouput dimension of projection and prediction heads
    --pred_hidden_dim: hidden dimension of prediction head
    --nlayers: layer number of projection head. prediction head has nlayers-1 layer
    --proj_bn: whether we use batch normalization in projection head
    --pred_bn: whether we use batch normalization in prediction head
    --norm_before_pred:  whether we use normalization before prediction head
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=2048,
        bottleneck_dim=256,
        nlayers=3,
        use_bn=False,
        norm_last_layer=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)

        self.projector = self._build_mlp(
            nlayers, in_dim, hidden_dim, bottleneck_dim, use_bn=use_bn
        )
        self.apply(self._init_weights)

        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _build_mlp(self, num_layers, in_dim, hidden_dim, output_dim, use_bn=False):
        """
        build a mlp
        """
        if num_layers == 1:
            mlp = nn.Linear(in_dim, output_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, output_dim))
            mlp = nn.Sequential(*layers)
        return mlp

    def _init_weights(self, m):
        """
        initilize the parameters in network
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        forward the input through the projection and last prediction layer
        """
        feat = self.projector(x)
        feat = nn.functional.normalize(feat, dim=-1, p=2)
        feat = self.last_layer(feat)
        return feat


class Block_mem(nn.Module):
    """
    a class to implement a memory block for local group supervision
    --dim: feature vector dimenstion in the memory
    --K: memory size
    --top_n: number for neighbors in local group supervision
    """

    def __init__(self, dim, K=2048, top_n=10):
        super().__init__()
        self.dim = dim
        self.K = K
        self.top_n = top_n
        # create the queue
        self.register_buffer("queue_q", torch.randn(K, dim))
        self.register_buffer("queue_k", torch.randn(K, dim))
        self.register_buffer("queue_v", torch.randn(K, dim))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, query, weak_aug_flags):
        """
        update memory queue
        """
        # import pdb
        # pdb.set_trace()
        len_weak = 0
        query = concat_all_gather(query)
        if weak_aug_flags is not None:
            weak_aug_flags = weak_aug_flags.cuda()
            weak_aug_flags = concat_all_gather(weak_aug_flags)
            idx_weak = torch.nonzero(weak_aug_flags)
            len_weak = len(idx_weak)
            if len_weak > 0:
                idx_weak = idx_weak.squeeze(-1) 
                query = query[idx_weak]
            else:
                return len_weak

        all_size = query.shape[0]
        ptr = int(self.queue_ptr)
        remaining_size = ptr + all_size - self.K
        if remaining_size <= 0:
            self.queue_q[ptr : ptr + all_size, :] = query
            self.queue_k[ptr : ptr + all_size, :] = query
            self.queue_v[ptr : ptr + all_size, :] = query
            ptr = ptr + all_size
            self.queue_ptr[0] = (ptr + all_size) % self.K
        else:
            self.queue_q[ptr : self.K, :] = query[0 : self.K - ptr, :]
            self.queue_k[ptr : self.K, :] = query[0 : self.K - ptr, :]
            self.queue_v[ptr : self.K, :] = query[0 : self.K - ptr, :]

            self.queue_q[0:remaining_size, :] = query[self.K - ptr :, :]
            self.queue_k[0:remaining_size, :] = query[self.K - ptr :, :]
            self.queue_v[0:remaining_size, :] = query[self.K - ptr :, :]
            self.queue_ptr[0] = remaining_size
        return len_weak

    @torch.no_grad()
    def _get_similarity_index(self, x):
        """
        compute the index of the top-n neighbors (key-value pair) in memory
        """
        x = nn.functional.normalize(x, dim=-1)
        queue_q = nn.functional.normalize(self.queue_q, dim=-1)

        cosine = x @ queue_q.T
        _, index = torch.topk(cosine, self.top_n, dim=-1)
        return index

    @torch.no_grad()
    def _get_similarity_samples(self, query, index=None):
        """
        compute top-n neighbors (key-value pair) in memory
        """
        if index is None:
            index = self._get_similarity_index(query)
        get_k = self.queue_k[index.view(-1)]
        get_v = self.queue_v[index.view(-1)]
        B, tn = index.shape
        get_k = get_k.view(B, tn, self.dim)
        get_v = get_v.view(B, tn, self.dim)
        return get_k, get_v

    def forward(self, query):
        """
        forward to find the top-n neighbors (key-value pair) in memory
        """
        get_k, get_v = self._get_similarity_samples(query)
        return get_k, get_v


class vit_mem(nn.Module):
    """
    a class to implement a memory for local group supervision
    --dim: feature vector dimenstion in the memory
    --K: memory size
    --top_n: number for neighbors in local group supervision
    """

    def __init__(self, dim, K=2048, top_n=10):
        super().__init__()
        self.block = Block_mem(dim, K, top_n)

    def _dequeue_and_enqueue(self, query, weak_aug_flags):
        """
        update memory queue
        """
        query = query.float()
        weak_num = self.block._dequeue_and_enqueue(query, weak_aug_flags)
        return weak_num

    def forward(self, query):
        """
        forward to find the top-n neighbors (key-value pair) in memory
        """
        query = query.float()
        get_k, get_v = self.block(query)
        return get_k, get_v


class Mugs_Wrapper(nn.Module):
    """
    a class to implement a student or teacher wrapper for mugs
    --backbone: the backnone of student/teacher, e.g. ViT-small
    --instance_head: head, including projection/prediction heads, for instance supervision
    --local_group_head: head, including projection/prediction heads, for local group supervision
    --group_head: projection head for group supervision
    """

    def __init__(self, backbone, instance_head, local_group_head, group_head):
        super(Mugs_Wrapper, self).__init__()
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.instance_head = instance_head
        self.local_group_head = local_group_head
        self.group_head = group_head

    def forward(self, x, return_target=False, local_group_memory_inputs=None):
        """
        forward input to get instance/local-group/group targets or predictions
        """
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(
            torch.unique_consecutive(
                torch.tensor([inp.shape[-1] for inp in x]),
                return_counts=True,
            )[1],
            0,
        )

        start_idx = 0
        class_tokens = torch.empty(0).to(x[0].device)
        mean_patch_tokens = torch.empty(0).to(x[0].device)
        memory_class_tokens = torch.empty(0).to(x[0].device)
        for _, end_idx in enumerate(idx_crops):
            input = torch.cat(x[start_idx:end_idx])
            token_feat, memory_class_token_feat = self.backbone(
                input,
                return_all=True,
                local_group_memory_inputs=local_group_memory_inputs,
            )  # [[16, 197, 384], [16, 384]] teacher
            # [[16, 197, 384], [16, 384]] student  [[48, 37, 384], [48, 384]]

            class_token_feat = token_feat[
                :, 0
            ]  # class tokens in ViT, [16, 384] teacher [16, 384] student  [48, 384]
            class_tokens = torch.cat((class_tokens, class_token_feat))

            start_idx = end_idx

            if self.local_group_head is not None:
                memory_class_tokens = torch.cat(
                    (memory_class_tokens, memory_class_token_feat)
                )
                if input.shape[-1] == 224:
                    mean_patch_tokens = torch.cat(
                        (mean_patch_tokens, token_feat[:, 1:].mean(dim=1))
                    )

        ## target [16, 256] for teacher,  [64, 256] for student,
        instance_feat = (
            self.instance_head(class_tokens, return_target)
            if self.instance_head is not None
            else None
        )

        ## target [16, 256] for teacher,  [64, 256] for student
        local_group_feat = (
            self.local_group_head(memory_class_tokens, return_target)
            if self.local_group_head is not None
            else None
        )

        # target [16, 65536] for teacher, [64, 65536] for student
        group_feat = (
            self.group_head(class_tokens) if self.group_head is not None else None
        )
        return instance_feat, local_group_feat, group_feat, mean_patch_tokens.detach()


def get_model(args):
    """
    build a student or teacher for mugs, includeing backbone, instance/local-group/group heads,
    and memory buffer
    """
    ## backbone
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            num_relation_blocks=1,
            drop_path_rate=args.drop_path_rate,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](
            patch_size=args.patch_size, num_relation_blocks=1
        )
        embed_dim = student.embed_dim
    else:
        assert f"Unknow architecture: {args.arch}"

    ## memory buffer for local-group loss
    student_mem = vit_mem(
        embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
    )
    teacher_mem = vit_mem(
        embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
    )

    ## multi-crop wrapper handles forward with inputs of different resolutions
    student_instance_head, student_local_group_head, student_group_head = (
        None,
        None,
        None,
    )
    teacher_instance_head, teacher_local_group_head, teacher_group_head = (
        None,
        None,
        None,
    )

    # instance head
    if args.loss_weights[0] > 0:
        student_instance_head = Instance_Superivsion_Head(
            in_dim=embed_dim,
            hidden_dim=2048,
            out_dim=args.instance_out_dim,
            pred_hidden_dim=4096,
            nlayers=3,
            proj_bn=args.use_bn_in_head,
            pred_bn=False,
            norm_before_pred=args.norm_before_pred,
        )
        teacher_instance_head = Instance_Superivsion_Head(
            in_dim=embed_dim,
            hidden_dim=2048,
            out_dim=args.instance_out_dim,
            pred_hidden_dim=0,
            nlayers=3,
            proj_bn=args.use_bn_in_head,
            pred_bn=False,
            norm_before_pred=args.norm_before_pred,
        )

    # local group head
    if args.loss_weights[1] > 0:
        student_local_group_head = Local_Group_Superivsion_Head(
            in_dim=embed_dim,
            hidden_dim=2048,
            out_dim=args.local_group_out_dim,
            pred_hidden_dim=4096,
            nlayers=3,
            proj_bn=args.use_bn_in_head,
            pred_bn=False,
            norm_before_pred=args.norm_before_pred,
        )
        teacher_local_group_head = Local_Group_Superivsion_Head(
            in_dim=embed_dim,
            hidden_dim=2048,
            out_dim=args.local_group_out_dim,
            pred_hidden_dim=0,
            nlayers=3,
            proj_bn=args.use_bn_in_head,
            pred_bn=False,
            norm_before_pred=args.norm_before_pred,
        )

    # group head
    if args.loss_weights[2] > 0:
        student_group_head = Group_Superivsion_Head(
            in_dim=embed_dim,
            out_dim=args.group_out_dim,
            hidden_dim=2048,
            bottleneck_dim=args.group_bottleneck_dim,
            nlayers=3,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        )
        teacher_group_head = Group_Superivsion_Head(
            in_dim=embed_dim,
            out_dim=args.group_out_dim,
            hidden_dim=2048,
            bottleneck_dim=args.group_bottleneck_dim,
            nlayers=3,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        )

    # multi-crop wrapper
    student = Mugs_Wrapper(
        student, student_instance_head, student_local_group_head, student_group_head
    )

    teacher = Mugs_Wrapper(
        teacher, teacher_instance_head, teacher_local_group_head, teacher_group_head
    )

    return student, teacher, student_mem, teacher_mem


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output