RevCol / training /loss.py
LarryTsai's picture
Training Code:cls/det
b9425fd
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
from dis import dis
import torch
from torch import nn
import torch.distributed as dist
from torch.functional import Tensor
import torch.nn.functional as F
def compound_loss(coe, output_feature, image:Tensor, output_label, targets, criterion_bce, criterion_ce, epoch):
f_coe, c_coe = coe
image.clamp_(0.01, 0.99)
multi_loss = []
for i, feature in enumerate(output_feature):
ratio_f = 1 - i / len(output_feature)
ratio_c = (i+1) / (len(output_label))
ihx = criterion_bce(feature, image) * ratio_f * f_coe
ihy = criterion_ce(output_label[i], targets) * ratio_c * c_coe
# if dist.get_rank() == 0:
# print(f'ihx: {ihx}, ihy: {ihy}')
multi_loss.append(ihx + ihy)
# feature_loss.append(torch.dist(output_feature[i], teacher_feature) * feature_coe)
multi_loss.append(criterion_ce(output_label[-1], targets))
# print(feature_loss)
loss = torch.sum(torch.stack(multi_loss), dim=0)
# +torch.mean(torch.stack(feature_loss), dim=0)
return loss, multi_loss