import torch
import torch.nn.functional as F
from torch import nn
def flatten_label(target):
    """Flatten a batch of zero-terminated label rows into one 1-D tensor.

    Each row of ``target`` is kept up to and including its first ``0``
    (the terminator).  If a row contains no ``0``, ``list.index`` raises
    ``ValueError`` — same as the original behavior.

    Args:
        target: integer tensor, presumably (batch, max_len) — each row a
            padded label sequence terminated by 0.

    Returns:
        tuple: ``(label_flatten, label_length)`` where ``label_flatten`` is
        a 1-D ``LongTensor`` of the concatenated prefixes and
        ``label_length`` is an ``IntTensor`` of per-row prefix lengths.
    """
    label_flatten = []
    label_length = []
    for row in target:
        cur_label = row.tolist()
        # Find the terminator once; the original scanned the row twice.
        end = cur_label.index(0) + 1
        label_flatten += cur_label[:end]
        label_length.append(end)
    return (torch.LongTensor(label_flatten), torch.IntTensor(label_length))
def _flatten(sources, lengths): | |
return torch.cat([t[:l] for t, l in zip(sources, lengths)]) | |
class VisionLANLoss(nn.Module):
    """Cross-entropy training loss for VisionLAN text recognition.

    In step ``'LF_1'`` only the primary prediction is supervised; in
    ``'LF_2'``/``'LA'`` the remaining-text and masked-character branches are
    added, weighted by ``ratio_res`` and ``ratio_sub`` respectively.
    """

    def __init__(self,
                 training_step='LA',
                 ratio_res=0.5,
                 ratio_sub=0.5,
                 **kwargs):
        super().__init__()
        assert training_step in ('LF_1', 'LF_2', 'LA')
        self.training_step = training_step
        self.ratio_res = ratio_res
        self.ratio_sub = ratio_sub
        self.loss_func = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, pred, batch):
        """Return ``{'loss': ...}`` for a batch of predictions.

        Args:
            pred: tuple ``(text_pre, text_rem, text_mas, _)`` of per-branch
                logit sequences.
            batch: indexable collection; ``batch[1]`` is the main label,
                ``batch[2]``/``batch[3]`` the residual/sub labels (only read
                outside the 'LF_1' step).
        """
        text_pre, text_rem, text_mas, _ = pred
        gt_main = batch[1].to(dtype=torch.int64)
        flat_main, len_main = flatten_label(gt_main)
        text_pre = _flatten(text_pre, len_main)
        loss_main = self.loss_func(text_pre, flat_main.to(text_pre.device))
        if self.training_step == 'LF_1':
            return {'loss': loss_main}
        # LF_2 / LA: supervise the auxiliary branches as well.
        gt_rem = batch[2].to(dtype=torch.int64)
        gt_sub = batch[3].to(dtype=torch.int64)
        flat_rem, len_rem = flatten_label(gt_rem)
        flat_sub, len_sub = flatten_label(gt_sub)
        text_rem = _flatten(text_rem, len_rem)
        text_mas = _flatten(text_mas, len_sub)
        loss_rem = self.loss_func(text_rem, flat_rem.to(text_rem.device))
        loss_sub = self.loss_func(text_mas, flat_sub.to(text_mas.device))
        total = loss_main + loss_rem * self.ratio_res + loss_sub * self.ratio_sub
        return {'loss': total}