Spaces:
Running
on
T4
Running
on
T4
import torch | |
import torch.nn as nn | |
from .base import BaseLosses | |
class CommitLoss(nn.Module):
    """Pass-through "loss" module.

    ``forward`` hands back its first argument untouched, so a value that was
    already computed elsewhere can be plugged in wherever a loss module
    taking two arguments is expected.
    """

    def __init__(self, **_ignored):
        # Keyword arguments are accepted and discarded; none of them is
        # forwarded to ``nn.Module``.
        super().__init__()

    def forward(self, commit, commit2, **_ignored):
        """Return ``commit`` as-is; ``commit2`` and any keywords are unused."""
        return commit
class GPTLosses(BaseLosses):
    """Stage-dependent set of losses built on top of ``BaseLosses``.

    Depending on ``stage`` the constructor selects the loss names, reads the
    matching values from ``cfg.LOSS`` and picks a loss class for each name,
    then forwards all of that to ``BaseLosses.__init__``.

    * ``"vae"``: ``recons_feature``, ``recons_velocity`` and ``vq_commit``.
    * ``"lm_pretrain"`` / ``"lm_instruct"``: ``gpt_loss``.
    * any other stage: no losses are selected.
    """

    def __init__(self, cfg, stage, num_joints, **kwargs):
        """Select losses for ``stage`` and initialise the base class.

        Args:
            cfg: Configuration object; ``cfg.LOSS.ABLATION.RECONS_LOSS`` and
                the ``cfg.LOSS.LAMBDA_*`` entries of the active stage are read.
            stage: Stage name, stored on the instance as ``self.stage``.
            num_joints: Forwarded unchanged to ``BaseLosses.__init__``.
            **kwargs: Forwarded unchanged to ``BaseLosses.__init__``.

        Raises:
            NotImplementedError: If a selected loss name is not recognised, or
                if a ``recons_*`` loss is selected while
                ``cfg.LOSS.ABLATION.RECONS_LOSS`` is not one of ``"l1"``,
                ``"l2"`` or ``"l1_smooth"``.
        """
        # Save parameters
        self.stage = stage
        recons_loss = cfg.LOSS.ABLATION.RECONS_LOSS

        # Define losses and the per-loss values taken from the config
        losses = []
        params = {}
        if stage == "vae":
            losses.append("recons_feature")
            params['recons_feature'] = cfg.LOSS.LAMBDA_FEATURE

            losses.append("recons_velocity")
            params['recons_velocity'] = cfg.LOSS.LAMBDA_VELOCITY

            losses.append("vq_commit")
            params['vq_commit'] = cfg.LOSS.LAMBDA_COMMIT
        elif stage in ["lm_pretrain", "lm_instruct"]:
            losses.append("gpt_loss")
            params['gpt_loss'] = cfg.LOSS.LAMBDA_CLS

        # Reconstruction loss classes keyed by the configured ablation value
        recons_classes = {
            "l1": nn.L1Loss,
            "l2": nn.MSELoss,
            "l1_smooth": nn.SmoothL1Loss,
        }

        # Define loss functions (classes, not instances, are handed over)
        losses_func = {}
        for loss in losses:
            parts = loss.split('_')
            if parts[0] == 'recons':
                # Fail here instead of silently leaving this loss without a
                # function when the configured type is unknown.
                if recons_loss not in recons_classes:
                    raise NotImplementedError(
                        f"Reconstruction loss {recons_loss} not implemented.")
                losses_func[loss] = recons_classes[recons_loss]
            elif parts[1] in ['commit', 'loss', 'gpt', 'm2t2m', 't2m2t']:
                # These values are already computed elsewhere; CommitLoss
                # just passes them through.
                losses_func[loss] = CommitLoss
            elif parts[1] in ['cls', 'lm']:
                losses_func[loss] = nn.CrossEntropyLoss
            else:
                raise NotImplementedError(f"Loss {loss} not implemented.")

        super().__init__(cfg, losses, params, losses_func, num_joints,
                         **kwargs)

    def update(self, rs_set):
        """Update the losses for one result set and return their sum.

        Args:
            rs_set: Mapping with the results of one step. The ``"vae"`` stage
                reads ``'m_rst'``, ``'m_ref'`` and ``'loss_commit'``; the
                language-model stages read ``rs_set['outputs'].loss``.

        Returns:
            The sum of the values returned by ``self._update_loss`` for the
            active stage (presumably the weighted losses - confirm in
            ``BaseLosses``), or the float ``0.0`` when ``self.stage`` matches
            no known stage.

        Raises:
            NotImplementedError: In the ``"vae"`` stage, when the last
                dimension of ``rs_set['m_rst']`` is neither 263 nor 135 + 263
                while ``self._params['recons_velocity']`` is non-zero.
        """
        total = 0.0

        if self.stage == "vae":
            total += self._update_loss("recons_feature", rs_set['m_rst'],
                                       rs_set['m_ref'])
            # total += self._update_loss("recons_joints", rs_set['joints_rst'], rs_set['joints_ref'])

            nfeats = rs_set['m_rst'].shape[-1]
            # First index of the slice compared by the velocity loss, keyed
            # by feature width.
            vel_offsets = {263: 4, 135 + 263: 135 + 4}
            vel_start = vel_offsets.get(nfeats)
            if vel_start is not None:
                vel_end = (self.num_joints - 1) * 3 + vel_start
                total += self._update_loss(
                    "recons_velocity",
                    rs_set['m_rst'][..., vel_start:vel_end],
                    rs_set['m_ref'][..., vel_start:vel_end])
            elif self._params['recons_velocity'] != 0.0:
                # An unknown width is only tolerated when the velocity term
                # is switched off.
                raise NotImplementedError(
                    f"Velocity not implemented for nfeats = {nfeats}")

            # The same value is passed twice; CommitLoss returns the first.
            total += self._update_loss("vq_commit", rs_set['loss_commit'],
                                       rs_set['loss_commit'])

        if self.stage in ["lm_pretrain", "lm_instruct"]:
            total += self._update_loss("gpt_loss", rs_set['outputs'].loss,
                                       rs_set['outputs'].loss)

        # Update the total loss. `total` is still the plain float 0.0 when no
        # stage branch ran, and floats have no `.detach()`.
        self.total += total.detach() if hasattr(total, "detach") else total
        self.count += 1

        return total