File size: 2,342 Bytes
b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def log_dur_loss(dur_pred_log, dur_target, mask, loss_type="l1"):
    """Masked L1/L2 loss between predicted log-durations and ``log(1 + dur_target)``.

    Args:
        dur_pred_log: predicted durations in ``log(1 + d)`` space, shape (B, N).
        dur_target: target durations in linear space, shape (B, N).
        mask: (B, N). Nonzero entries are included (and act as weights);
            zero entries are excluded from the loss.
        loss_type: ``"l1"`` for absolute error, ``"l2"`` for squared error.

    Returns:
        Scalar tensor: sum of the masked per-element errors divided by the sum
        of ``mask``. Returns 0 (instead of NaN) when ``mask`` is all zeros.

    Raises:
        NotImplementedError: if ``loss_type`` is not ``"l1"`` or ``"l2"``.
    """
    # Accumulate the mask in at least float32: a float16 mask summed over a
    # large batch can overflow to inf and silently zero the loss.
    weight = mask.to(torch.promote_types(dur_target.dtype, torch.float32))
    # Neutralise excluded targets (log(1 + 0) = 0) so padding values such as -1
    # cannot produce -inf/NaN, which would survive multiplication by a zero mask.
    dur_target = torch.where(mask != 0, dur_target, torch.zeros_like(dur_target))
    dur_target_log = torch.log(1 + dur_target)
    if loss_type == "l1":
        loss = F.l1_loss(dur_pred_log, dur_target_log, reduction="none")
    elif loss_type == "l2":
        loss = F.mse_loss(dur_pred_log, dur_target_log, reduction="none")
    else:
        raise NotImplementedError()
    loss = loss.float() * weight
    # The epsilon avoids 0/0 when no position is valid.
    return loss.sum() / (weight.sum() + 1e-8)


def log_pitch_loss(pitch_pred_log, pitch_target, mask, loss_type="l1"):
    """Masked L1/L2 loss between predicted log-pitch and ``log(pitch_target)``.

    Args:
        pitch_pred_log: predicted pitch in log space; same shape as
            ``pitch_target``.
        pitch_target: target pitch in linear space.
        mask: same shape as ``pitch_target`` (or broadcastable to it).
            Nonzero entries are included (and act as weights); zero entries
            are excluded from the loss.
        loss_type: ``"l1"`` for absolute error, ``"l2"`` for squared error.

    Returns:
        Scalar tensor: sum of the masked per-element errors divided by the sum
        of ``mask``. Returns 0 when ``mask`` is all zeros.

    Raises:
        NotImplementedError: if ``loss_type`` is not ``"l1"`` or ``"l2"``.
    """
    # Accumulate the mask in at least float32 so half-precision sums cannot
    # overflow.
    weight = mask.to(torch.promote_types(pitch_target.dtype, torch.float32))
    # Excluded positions holding 0 would give log(0) = -inf (negative values
    # would give NaN). Because -inf * 0 is NaN, multiplying by the mask does not
    # remove them, and they would poison both the loss and its gradient.
    # Substituting 1 (log 1 = 0) at excluded positions keeps everything finite.
    safe_target = torch.where(mask != 0, pitch_target, torch.ones_like(pitch_target))
    pitch_target_log = torch.log(safe_target)
    if loss_type == "l1":
        loss = F.l1_loss(pitch_pred_log, pitch_target_log, reduction="none")
    elif loss_type == "l2":
        loss = F.mse_loss(pitch_pred_log, pitch_target_log, reduction="none")
    else:
        raise NotImplementedError()
    loss = loss.float() * weight
    return loss.sum() / (weight.sum() + 1e-8)


def diff_loss(pred, target, mask, loss_type="l1"):
    """Frame-masked L1/L2 loss between two feature sequences.

    Args:
        pred: (B, d, T) predicted features.
        target: (B, d, T) target features.
        mask: (B, T). Nonzero frames are included (and act as weights);
            zero frames are excluded.
        loss_type: ``"l1"`` for absolute error, ``"l2"`` for squared error.

    Returns:
        Scalar tensor: the per-frame error averaged over the feature dim ``d``,
        summed over frames, divided by the sum of ``mask``. Returns 0 when
        ``mask`` is all zeros.

    Raises:
        NotImplementedError: if ``loss_type`` is not ``"l1"`` or ``"l2"``.
    """
    # (B, 1, T) so it broadcasts over the feature dim. It is kept in at least
    # float32 because a float16 mask summed over more than 65504 valid frames
    # overflows to inf.
    weight = mask.to(torch.promote_types(pred.dtype, torch.float32)).unsqueeze(1)
    if loss_type == "l1":
        loss = F.l1_loss(pred, target, reduction="none")
    elif loss_type == "l2":
        loss = F.mse_loss(pred, target, reduction="none")
    else:
        raise NotImplementedError()
    loss = loss.float() * weight  # (B, d, T)
    # Average over features, then normalise by the number of valid frames.
    # The epsilon avoids 0/0 for an empty mask.
    return loss.mean(dim=1).sum() / (weight.sum() + 1e-8)


def diff_ce_loss(pred_dist, gt_indices, mask):
    """Frame-masked cross-entropy over ``nq`` parallel index streams.

    Args:
        pred_dist: (nq, B, T, V) unnormalised logits over V classes (the
            original comment gives V = 1024).
        gt_indices: (nq, B, T) target class indices.
        mask: (B, T). Nonzero frames are included (and act as weights);
            zero frames are excluded.

    Returns:
        Scalar tensor: the per-frame cross-entropy averaged over ``nq``,
        summed over frames, divided by the sum of ``mask``. Returns 0 when
        ``mask`` is all zeros.
    """
    # F.cross_entropy expects the class dimension at index 1.
    logits = pred_dist.permute(1, 3, 0, 2)  # (B, V, nq, T)
    targets = gt_indices.permute(1, 0, 2).long()  # (B, nq, T)
    # Out-of-range padding indices (e.g. -1) at excluded frames would make
    # cross_entropy raise even though those frames are multiplied by 0, so
    # point them at class 0.
    valid = (mask != 0).unsqueeze(1)  # (B, 1, T)
    targets = torch.where(valid, targets, torch.zeros_like(targets))
    loss = F.cross_entropy(logits, targets, reduction="none").float()  # (B, nq, T)
    weight = mask.to(loss.dtype)
    loss = loss * weight.unsqueeze(1)
    # The epsilon avoids 0/0 for an empty mask.
    return loss.mean(dim=1).sum() / (weight.sum() + 1e-8)