yangwang825 commited on
Commit
a2d1048
·
verified ·
1 Parent(s): d79a689

Update modeling_hubert_spkreg.py

Browse files
Files changed (1) hide show
  1. modeling_hubert_spkreg.py +18 -4
modeling_hubert_spkreg.py CHANGED
@@ -432,7 +432,7 @@ class AAMSoftmaxLoss(nn.Module):
432
  def __init__(
433
  self,
434
  scale: float = 30.0,
435
- margin: float = 0.35,
436
  easy_margin: bool = False,
437
  label_smoothing: float = 0.0,
438
  reduction: str = "mean"
@@ -465,9 +465,23 @@ class AAMSoftmaxLoss(nn.Module):
465
  """
466
  _, num_labels = inputs.shape
467
  # `inputs` are the outputs from AngularLinear()
468
- cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
469
- theta = torch.acos(cos_theta)
470
- psi = torch.cos(theta + self.margin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  one_hot = nn.functional.one_hot(targets, num_labels)
472
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
473
  loss = F.cross_entropy(
 
432
  def __init__(
433
  self,
434
  scale: float = 30.0,
435
+ margin: float = 0.2,
436
  easy_margin: bool = False,
437
  label_smoothing: float = 0.0,
438
  reduction: str = "mean"
 
465
  """
466
  _, num_labels = inputs.shape
467
  # `inputs` are the outputs from AngularLinear()
468
+ epsilon = 1e-6
469
+ # theta = torch.acos(cos_theta)
470
+ # psi = torch.cos(theta + self.margin)
471
+ cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
472
+ sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
473
+ sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
474
+
475
+ cos_m = math.cos(self.margin)
476
+ sin_m = math.sin(self.margin)
477
+ psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
478
+
479
+ if self.easy_margin:
480
+ psi = torch.where(cos_theta > 0, psi, cos_theta)
481
+ else:
482
+ # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
483
+ psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
484
+
485
  one_hot = nn.functional.one_hot(targets, num_labels)
486
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
487
  loss = F.cross_entropy(