File size: 1,324 Bytes
b9425fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

from dis import dis
import torch
from torch import nn
import torch.distributed as dist
from torch.functional import Tensor
import torch.nn.functional as F



def compound_loss(coe, output_feature, image:Tensor, output_label, targets, criterion_bce, criterion_ce, epoch):
    f_coe, c_coe = coe
    image.clamp_(0.01, 0.99)
    multi_loss = []
    for i, feature in enumerate(output_feature):
        ratio_f = 1 - i / len(output_feature)
        ratio_c = (i+1) / (len(output_label))

        ihx = criterion_bce(feature, image) * ratio_f * f_coe 
        ihy = criterion_ce(output_label[i], targets) * ratio_c * c_coe
        # if dist.get_rank() == 0:
        #     print(f'ihx: {ihx}, ihy: {ihy}')
        multi_loss.append(ihx + ihy)
        # feature_loss.append(torch.dist(output_feature[i], teacher_feature) *  feature_coe)
    multi_loss.append(criterion_ce(output_label[-1], targets))
    # print(feature_loss)
    loss = torch.sum(torch.stack(multi_loss), dim=0)
    # +torch.mean(torch.stack(feature_loss), dim=0)
    return loss, multi_loss