|
""" |
|
Copyright 2023 LINE Corporation |
|
LINE Corporation licenses this file to you under the Apache License, |
|
version 2.0 (the "License"); you may not use this file except in compliance |
|
with the License. You may obtain a copy of the License at: |
|
https://www.apache.org/licenses/LICENSE-2.0 |
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
|
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
|
License for the specific language governing permissions and limitations |
|
under the License. |
|
""" |
|
|
|
from platform import mac_ver |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, reduce, repeat |
|
from torch.autograd import Variable |
|
|
|
|
|
def kl_loss_compute(pred, soft_targets, reduce=True):
    """Per-sample KL divergence KL(softmax(soft_targets) || softmax(pred)).

    Args:
        pred: logits, shape (N, C); gradients flow through this argument.
        soft_targets: logits for the target distribution, shape (N, C).
        reduce: if True, return the scalar mean over samples; otherwise
            return the per-sample KL vector of shape (N,).
            (Name kept for backward compatibility even though it shadows
            the module-level einops ``reduce`` import.)

    Returns:
        Scalar tensor when ``reduce`` is True, else a (N,) tensor.
    """
    # NOTE: the deprecated ``reduce=False`` kwarg of F.kl_div was replaced
    # by the equivalent ``reduction="none"`` (same element-wise output).
    kl = F.kl_div(
        F.log_softmax(pred, dim=1),
        F.softmax(soft_targets, dim=1),
        reduction="none",
    )
    if reduce:
        return torch.mean(torch.sum(kl, dim=1))
    else:
        return torch.sum(kl, 1)
|
|
|
|
|
def mvl_loss(y_1, y_2, rate=0.2, weight=0.1):
    """Mutual-view learning loss with small-loss sample selection.

    Computes a symmetric, weighted KL divergence between two views'
    predictions, then averages only over the ``rate`` fraction of samples
    with the smallest loss (small-loss trick for noisy supervision).

    Args:
        y_1: view-1 logits, shape (N, T, C).
        y_2: view-2 logits, shape (N, T, C).
        rate: fraction of lowest-loss samples to keep, in [0, 1].
        weight: scale applied to each KL direction.

    Returns:
        Scalar loss tensor connected to the autograd graph.
    """
    # Flatten (batch, time, class) to (batch*time, class) for per-sample KL.
    y_1 = rearrange(y_1, "n t c -> (n t) c")
    y_2 = rearrange(y_2, "n t c -> (n t) c")

    # Symmetric per-sample divergence between the two views.
    loss_pick = weight * kl_loss_compute(
        y_1, y_2, reduce=False
    ) + weight * kl_loss_compute(y_2, y_1, reduce=False)

    # Rank samples on a detached copy so the selection itself does not
    # participate in autograd. (The original detached ``loss_pick`` before
    # the final mean, which made the returned loss carry no gradient;
    # ``.data`` access was also deprecated.)
    ranking = loss_pick.detach()
    ind_sorted = torch.argsort(ranking)

    # Keep the ``rate`` fraction of smallest-loss samples; keep at least
    # one so the mean below is never taken over an empty tensor (NaN).
    num_remember = max(1, int(rate * ranking.numel()))
    ind_update = ind_sorted[:num_remember]

    # Average the graph-connected values so gradients flow to y_1/y_2.
    loss = torch.mean(loss_pick[ind_update])

    return loss
|
|
|
|
|
def cross_entropy_loss(outputs, soft_targets):
    """Soft-target cross entropy, skipping rows fully marked with -100.

    Args:
        outputs: logits, shape (N, C).
        soft_targets: target distributions, shape (N, C); a row whose
            entries are all -100 is treated as "ignore" and excluded.

    Returns:
        Scalar mean cross-entropy over the valid (non-ignored) rows.
    """
    # A row is valid when at least one entry differs from the -100 marker.
    valid = (soft_targets != -100).sum(1) > 0
    logits = outputs[valid]
    targets = soft_targets[valid]

    # Per-row cross entropy: -sum(target * log_softmax(logits)).
    per_row = torch.sum(F.log_softmax(logits, dim=1) * targets, dim=1)
    return -torch.mean(per_row)
|
|
|
|