yangwang825 commited on
Commit
0b91f1b
·
verified ·
1 Parent(s): f40a8ec

Update modeling_wav2vec2_spkreg.py

Browse files
Files changed (1) hide show
  1. modeling_wav2vec2_spkreg.py +26 -9
modeling_wav2vec2_spkreg.py CHANGED
@@ -529,6 +529,8 @@ class AMSoftmaxLoss(nn.Module):
529
  num_labels: int,
530
  scale: float = 30.0,
531
  margin: float = 0.35,
 
 
532
  ):
533
  """
534
  Args:
@@ -540,13 +542,13 @@ class AMSoftmaxLoss(nn.Module):
540
  self.num_labels = num_labels
541
  self.scale = scale
542
  self.margin = margin
 
 
543
 
544
  def forward(
545
  self,
546
  inputs: torch.Tensor,
547
  targets: torch.Tensor,
548
- label_smoothing: float = 0.0,
549
- reduction: str = "mean"
550
  ):
551
  """
552
  Args:
@@ -562,7 +564,9 @@ class AMSoftmaxLoss(nn.Module):
562
  psi = cosine - self.margin
563
  one_hot = nn.functional.one_hot(targets, self.num_labels)
564
  outputs = self.scale * torch.where(one_hot.bool(), psi, cosine)
565
- loss = F.cross_entropy(outputs, targets, label_smoothing=label_smoothing, reduction=reduction)
 
 
566
  return loss
567
 
568
 
@@ -577,7 +581,9 @@ class AAMSoftmaxLoss(nn.Module):
577
  num_labels: int,
578
  scale: float = 30.0,
579
  margin: float = 0.35,
580
- easy_margin: bool = False
 
 
581
  ):
582
  """
583
  Args:
@@ -591,6 +597,8 @@ class AAMSoftmaxLoss(nn.Module):
591
  self.scale = scale
592
  self.margin = margin
593
  self.easy_margin = easy_margin
 
 
594
 
595
  def forward(
596
  self,
@@ -627,7 +635,9 @@ class AAMSoftmaxLoss(nn.Module):
627
  outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
628
  outputs = outputs * self.scale
629
 
630
- loss = F.cross_entropy(outputs, targets, label_smoothing=label_smoothing, reduction=reduction)
 
 
631
  return loss
632
 
633
 
@@ -739,17 +749,24 @@ class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):
739
  )
740
  elif self.config.loss_fct == 'additive_margin':
741
  loss_fct = AMSoftmaxLoss(
742
- self.config.num_labels, self.config.scale, self.config.margin
 
 
 
 
743
  )
744
  elif self.config.loss_fct == 'additive_angular_margin':
745
  loss_fct = AAMSoftmaxLoss(
746
- self.config.num_labels, self.config.scale, self.config.margin, self.config.easy_margin
 
 
 
 
 
747
  )
748
  loss = loss_fct(
749
  logits.view(-1, self.config.num_labels),
750
  labels.view(-1),
751
- label_smoothing=self.config.label_smoothing,
752
- reduction=self.config.reduction
753
  )
754
 
755
  if not return_dict:
 
529
  num_labels: int,
530
  scale: float = 30.0,
531
  margin: float = 0.35,
532
+ label_smoothing: float = 0.0,
533
+ reduction: str = "mean"
534
  ):
535
  """
536
  Args:
 
542
  self.num_labels = num_labels
543
  self.scale = scale
544
  self.margin = margin
545
+ self.label_smoothing = label_smoothing
546
+ self.reduction = reduction
547
 
548
  def forward(
549
  self,
550
  inputs: torch.Tensor,
551
  targets: torch.Tensor,
 
 
552
  ):
553
  """
554
  Args:
 
564
  psi = cosine - self.margin
565
  one_hot = nn.functional.one_hot(targets, self.num_labels)
566
  outputs = self.scale * torch.where(one_hot.bool(), psi, cosine)
567
+ loss = F.cross_entropy(
568
+ outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
569
+ )
570
  return loss
571
 
572
 
 
581
  num_labels: int,
582
  scale: float = 30.0,
583
  margin: float = 0.35,
584
+ easy_margin: bool = False,
585
+ label_smoothing: float = 0.0,
586
+ reduction: str = "mean"
587
  ):
588
  """
589
  Args:
 
597
  self.scale = scale
598
  self.margin = margin
599
  self.easy_margin = easy_margin
600
+ self.label_smoothing = label_smoothing
601
+ self.reduction = reduction
602
 
603
  def forward(
604
  self,
 
635
  outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
636
  outputs = outputs * self.scale
637
 
638
+ loss = F.cross_entropy(
639
+ outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
640
+ )
641
  return loss
642
 
643
 
 
749
  )
750
  elif self.config.loss_fct == 'additive_margin':
751
  loss_fct = AMSoftmaxLoss(
752
+ self.config.num_labels,
753
+ self.config.scale,
754
+ self.config.margin,
755
+ label_smoothing=self.config.label_smoothing,
756
+ reduction=self.config.reduction
757
  )
758
  elif self.config.loss_fct == 'additive_angular_margin':
759
  loss_fct = AAMSoftmaxLoss(
760
+ self.config.num_labels,
761
+ self.config.scale,
762
+ self.config.margin,
763
+ self.config.easy_margin,
764
+ label_smoothing=self.config.label_smoothing,
765
+ reduction=self.config.reduction
766
  )
767
  loss = loss_fct(
768
  logits.view(-1, self.config.num_labels),
769
  labels.view(-1),
 
 
770
  )
771
 
772
  if not return_dict: