yangwang825
committed on
Update modeling_hubert_spkreg.py
Browse files- modeling_hubert_spkreg.py +18 -4
modeling_hubert_spkreg.py
CHANGED
@@ -432,7 +432,7 @@ class AAMSoftmaxLoss(nn.Module):
|
|
432 |
def __init__(
|
433 |
self,
|
434 |
scale: float = 30.0,
|
435 |
-
margin: float = 0.
|
436 |
easy_margin: bool = False,
|
437 |
label_smoothing: float = 0.0,
|
438 |
reduction: str = "mean"
|
@@ -465,9 +465,23 @@ class AAMSoftmaxLoss(nn.Module):
|
|
465 |
"""
|
466 |
_, num_labels = inputs.shape
|
467 |
# `inputs` are the outputs from AngularLinear()
|
468 |
-
|
469 |
-
theta = torch.acos(cos_theta)
|
470 |
-
psi = torch.cos(theta + self.margin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
472 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
473 |
loss = F.cross_entropy(
|
|
|
432 |
def __init__(
|
433 |
self,
|
434 |
scale: float = 30.0,
|
435 |
+
margin: float = 0.2,
|
436 |
easy_margin: bool = False,
|
437 |
label_smoothing: float = 0.0,
|
438 |
reduction: str = "mean"
|
|
|
465 |
"""
|
466 |
_, num_labels = inputs.shape
|
467 |
# `inputs` are the outputs from AngularLinear()
|
468 |
+
epsilon = 1e-6
|
469 |
+
# theta = torch.acos(cos_theta)
|
470 |
+
# psi = torch.cos(theta + self.margin)
|
471 |
+
cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
|
472 |
+
sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
|
473 |
+
sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
|
474 |
+
|
475 |
+
cos_m = math.cos(self.margin)
|
476 |
+
sin_m = math.sin(self.margin)
|
477 |
+
psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
|
478 |
+
|
479 |
+
if self.easy_margin:
|
480 |
+
psi = torch.where(cos_theta > 0, psi, cos_theta)
|
481 |
+
else:
|
482 |
+
# Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
|
483 |
+
psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
|
484 |
+
|
485 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
486 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
487 |
loss = F.cross_entropy(
|