yangwang825
committed on
Commit
•
c007f7f
1
Parent(s):
6a1d27a
Upload model
Browse files
- config.json +5 -3
- modeling_wav2vec2_spkreg.py +19 -40
config.json
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
{
|
|
|
2 |
"activation_dropout": 0.0,
|
3 |
"adapter_attn_dim": null,
|
4 |
"adapter_kernel_size": 3,
|
@@ -6,11 +7,12 @@
|
|
6 |
"add_adapter": false,
|
7 |
"apply_spec_augment": true,
|
8 |
"architectures": [
|
9 |
-
"
|
10 |
],
|
11 |
"attention_dropout": 0.1,
|
12 |
"auto_map": {
|
13 |
-
"AutoConfig": "configuration_wav2vec2_spkreg.Wav2Vec2SpkRegConfig"
|
|
|
14 |
},
|
15 |
"bos_token_id": 1,
|
16 |
"classifier_proj_size": 256,
|
@@ -56,7 +58,6 @@
|
|
56 |
"feat_quantizer_dropout": 0.0,
|
57 |
"final_dropout": 0.0,
|
58 |
"freeze_feat_extract_train": true,
|
59 |
-
"gradient_checkpointing": true,
|
60 |
"hidden_act": "gelu",
|
61 |
"hidden_dropout": 0.1,
|
62 |
"hidden_size": 768,
|
@@ -120,6 +121,7 @@
|
|
120 |
1,
|
121 |
1
|
122 |
],
|
|
|
123 |
"transformers_version": "4.46.2",
|
124 |
"use_weighted_layer_sum": false,
|
125 |
"vocab_size": 32,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "facebook/wav2vec2-base",
|
3 |
"activation_dropout": 0.0,
|
4 |
"adapter_attn_dim": null,
|
5 |
"adapter_kernel_size": 3,
|
|
|
7 |
"add_adapter": false,
|
8 |
"apply_spec_augment": true,
|
9 |
"architectures": [
|
10 |
+
"Wav2Vec2SpkRegModel"
|
11 |
],
|
12 |
"attention_dropout": 0.1,
|
13 |
"auto_map": {
|
14 |
+
"AutoConfig": "configuration_wav2vec2_spkreg.Wav2Vec2SpkRegConfig",
|
15 |
+
"AutoModel": "modeling_wav2vec2_spkreg.Wav2Vec2SpkRegModel"
|
16 |
},
|
17 |
"bos_token_id": 1,
|
18 |
"classifier_proj_size": 256,
|
|
|
58 |
"feat_quantizer_dropout": 0.0,
|
59 |
"final_dropout": 0.0,
|
60 |
"freeze_feat_extract_train": true,
|
|
|
61 |
"hidden_act": "gelu",
|
62 |
"hidden_dropout": 0.1,
|
63 |
"hidden_size": 768,
|
|
|
121 |
1,
|
122 |
1
|
123 |
],
|
124 |
+
"torch_dtype": "float32",
|
125 |
"transformers_version": "4.46.2",
|
126 |
"use_weighted_layer_sum": false,
|
127 |
"vocab_size": 32,
|
modeling_wav2vec2_spkreg.py
CHANGED
@@ -519,14 +519,13 @@ class AngularLinear(nn.Module):
|
|
519 |
|
520 |
|
521 |
class AMSoftmaxLoss(nn.Module):
|
522 |
-
"""Additive Margin Softmax
|
523 |
|
524 |
Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
|
525 |
IEEE Signal Processing Letters 25.7 (2018): 926-930.
|
526 |
"""
|
527 |
def __init__(
|
528 |
self,
|
529 |
-
num_labels: int,
|
530 |
scale: float = 30.0,
|
531 |
margin: float = 0.35,
|
532 |
label_smoothing: float = 0.0,
|
@@ -539,7 +538,6 @@ class AMSoftmaxLoss(nn.Module):
|
|
539 |
margin: Angular margin (default: 0.35)
|
540 |
"""
|
541 |
super(AMSoftmaxLoss, self).__init__()
|
542 |
-
self.num_labels = num_labels
|
543 |
self.scale = scale
|
544 |
self.margin = margin
|
545 |
self.label_smoothing = label_smoothing
|
@@ -559,11 +557,12 @@ class AMSoftmaxLoss(nn.Module):
|
|
559 |
Returns:
|
560 |
Loss value
|
561 |
"""
|
|
|
562 |
# `inputs` are the outputs from AngularLinear()
|
563 |
-
|
564 |
-
psi =
|
565 |
-
one_hot = nn.functional.one_hot(targets,
|
566 |
-
outputs = self.scale * torch.where(one_hot.bool(), psi,
|
567 |
loss = F.cross_entropy(
|
568 |
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
|
569 |
)
|
@@ -571,14 +570,13 @@ class AMSoftmaxLoss(nn.Module):
|
|
571 |
|
572 |
|
573 |
class AAMSoftmaxLoss(nn.Module):
|
574 |
-
"""Additive Angular Margin Softmax.
|
575 |
|
576 |
Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
|
577 |
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
|
578 |
"""
|
579 |
def __init__(
|
580 |
self,
|
581 |
-
num_labels: int,
|
582 |
scale: float = 30.0,
|
583 |
margin: float = 0.35,
|
584 |
easy_margin: bool = False,
|
@@ -593,7 +591,6 @@ class AAMSoftmaxLoss(nn.Module):
|
|
593 |
easy_margin: Use the easy margin loss (default: False)
|
594 |
"""
|
595 |
super(AAMSoftmaxLoss, self).__init__()
|
596 |
-
self.num_labels = num_labels
|
597 |
self.scale = scale
|
598 |
self.margin = margin
|
599 |
self.easy_margin = easy_margin
|
@@ -604,37 +601,21 @@ class AAMSoftmaxLoss(nn.Module):
|
|
604 |
self,
|
605 |
inputs: torch.Tensor,
|
606 |
targets: torch.Tensor,
|
607 |
-
label_smoothing: float = 0.0,
|
608 |
-
reduction: str = "mean"
|
609 |
):
|
610 |
"""
|
611 |
Args:
|
612 |
inputs: Input features of shape (batch_size, num_labels)
|
613 |
targets: Ground truth labels of shape (batch_size)
|
614 |
-
label_smoothing: Label smoothing factor (default: 0.0)
|
615 |
-
reduction: Reduction method (default: "mean")
|
616 |
Returns:
|
617 |
Loss value
|
618 |
"""
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
mm = math.sin(math.pi - self.margin) * self.margin
|
627 |
-
|
628 |
-
if self.easy_margin:
|
629 |
-
phi = torch.where(cosine > 0, phi, cosine)
|
630 |
-
else:
|
631 |
-
phi = torch.where((cosine - th) > 0, phi, cosine - mm)
|
632 |
-
|
633 |
-
one_hot = torch.zeros_like(cosine)
|
634 |
-
one_hot.scatter_(1, targets.view(-1, 1), 1)
|
635 |
-
outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
|
636 |
-
outputs = outputs * self.scale
|
637 |
-
|
638 |
loss = F.cross_entropy(
|
639 |
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
|
640 |
)
|
@@ -749,18 +730,16 @@ class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):
|
|
749 |
)
|
750 |
elif self.config.loss_fct == 'additive_margin':
|
751 |
loss_fct = AMSoftmaxLoss(
|
752 |
-
self.config.
|
753 |
-
self.config.
|
754 |
-
self.config.margin,
|
755 |
label_smoothing=self.config.label_smoothing,
|
756 |
reduction=self.config.reduction
|
757 |
)
|
758 |
elif self.config.loss_fct == 'additive_angular_margin':
|
759 |
loss_fct = AAMSoftmaxLoss(
|
760 |
-
self.config.
|
761 |
-
self.config.
|
762 |
-
self.config.
|
763 |
-
self.config.easy_margin,
|
764 |
label_smoothing=self.config.label_smoothing,
|
765 |
reduction=self.config.reduction
|
766 |
)
|
|
|
519 |
|
520 |
|
521 |
class AMSoftmaxLoss(nn.Module):
|
522 |
+
"""Additive Margin Softmax (CosFace).
|
523 |
|
524 |
Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
|
525 |
IEEE Signal Processing Letters 25.7 (2018): 926-930.
|
526 |
"""
|
527 |
def __init__(
|
528 |
self,
|
|
|
529 |
scale: float = 30.0,
|
530 |
margin: float = 0.35,
|
531 |
label_smoothing: float = 0.0,
|
|
|
538 |
margin: Angular margin (default: 0.35)
|
539 |
"""
|
540 |
super(AMSoftmaxLoss, self).__init__()
|
|
|
541 |
self.scale = scale
|
542 |
self.margin = margin
|
543 |
self.label_smoothing = label_smoothing
|
|
|
557 |
Returns:
|
558 |
Loss value
|
559 |
"""
|
560 |
+
_, num_labels = inputs.shape
|
561 |
# `inputs` are the outputs from AngularLinear()
|
562 |
+
cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
|
563 |
+
psi = cos_theta - self.margin
|
564 |
+
one_hot = nn.functional.one_hot(targets, num_labels)
|
565 |
+
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
566 |
loss = F.cross_entropy(
|
567 |
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
|
568 |
)
|
|
|
570 |
|
571 |
|
572 |
class AAMSoftmaxLoss(nn.Module):
|
573 |
+
"""Additive Angular Margin Softmax (ArcFace).
|
574 |
|
575 |
Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
|
576 |
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
|
577 |
"""
|
578 |
def __init__(
|
579 |
self,
|
|
|
580 |
scale: float = 30.0,
|
581 |
margin: float = 0.35,
|
582 |
easy_margin: bool = False,
|
|
|
591 |
easy_margin: Use the easy margin loss (default: False)
|
592 |
"""
|
593 |
super(AAMSoftmaxLoss, self).__init__()
|
|
|
594 |
self.scale = scale
|
595 |
self.margin = margin
|
596 |
self.easy_margin = easy_margin
|
|
|
601 |
self,
|
602 |
inputs: torch.Tensor,
|
603 |
targets: torch.Tensor,
|
|
|
|
|
604 |
):
|
605 |
"""
|
606 |
Args:
|
607 |
inputs: Input features of shape (batch_size, num_labels)
|
608 |
targets: Ground truth labels of shape (batch_size)
|
|
|
|
|
609 |
Returns:
|
610 |
Loss value
|
611 |
"""
|
612 |
+
_, num_labels = inputs.shape
|
613 |
+
# `inputs` are the outputs from AngularLinear()
|
614 |
+
cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
|
615 |
+
theta = torch.acos(cos_theta)
|
616 |
+
psi = torch.cos(theta + self.margin)
|
617 |
+
one_hot = nn.functional.one_hot(targets, num_labels)
|
618 |
+
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
loss = F.cross_entropy(
|
620 |
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
|
621 |
)
|
|
|
730 |
)
|
731 |
elif self.config.loss_fct == 'additive_margin':
|
732 |
loss_fct = AMSoftmaxLoss(
|
733 |
+
scale=self.config.scale,
|
734 |
+
margin=self.config.margin,
|
|
|
735 |
label_smoothing=self.config.label_smoothing,
|
736 |
reduction=self.config.reduction
|
737 |
)
|
738 |
elif self.config.loss_fct == 'additive_angular_margin':
|
739 |
loss_fct = AAMSoftmaxLoss(
|
740 |
+
scale=self.config.scale,
|
741 |
+
margin=self.config.margin,
|
742 |
+
easy_margin=self.config.easy_margin,
|
|
|
743 |
label_smoothing=self.config.label_smoothing,
|
744 |
reduction=self.config.reduction
|
745 |
)
|