|
import math |
|
import warnings |
|
from typing import Union, Tuple, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 |
|
from transformers.modeling_outputs import SequenceClassifierOutput, Wav2Vec2BaseModelOutput |
|
from transformers.models.wav2vec2.modeling_wav2vec2 import ( |
|
Wav2Vec2ForPreTraining, |
|
Wav2Vec2GumbelVectorQuantizer, |
|
Wav2Vec2PositionalConvEmbedding, |
|
Wav2Vec2FeatureProjection, |
|
Wav2Vec2AttnAdapterLayer, |
|
Wav2Vec2ForCTC, |
|
Wav2Vec2FeatureEncoder, |
|
Wav2Vec2EncoderStableLayerNorm, |
|
Wav2Vec2Encoder, |
|
Wav2Vec2Adapter, |
|
safe_load_file, |
|
_compute_mask_indices, |
|
_HIDDEN_STATES_START_POSITION, |
|
WAV2VEC2_ADAPTER_SAFE_FILE, |
|
WAV2VEC2_ADAPTER_PT_FILE |
|
) |
|
from transformers.utils import ( |
|
cached_file, |
|
is_safetensors_available, |
|
logging, |
|
) |
|
|
|
from .configuration_wav2vec2_spkreg import Wav2Vec2SpkRegConfig |
|
|
|
# Module-level logger from `transformers.utils.logging`; used below (e.g. in `load_adapter`) to emit warnings.
logger = logging.get_logger(__name__)
|
|
|
|
|
class Wav2Vec2SpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2SpkRegConfig
    base_model_prefix = "wav2vec2"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place.

        Only the module types matched below are touched; any other module (including plain containers) is left
        unchanged, so initializing a whole subtree requires applying this method to every submodule.
        """
        if isinstance(module, Wav2Vec2ForPreTraining):
            module.project_hid.reset_parameters()
            module.project_q.reset_parameters()
            module.project_hid._is_hf_initialized = True
            module.project_q._is_hf_initialized = True
        elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, Wav2Vec2FeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, AngularLinear):
            # The cosine classifier is not an `nn.Linear`, so no other branch would ever (re-)initialize it.
            # Mirror the Xavier init from its constructor so it is initialized whenever the model initializes
            # weights (e.g. when the classifier is missing from a checkpoint).
            nn.init.xavier_normal_(module.weight, gain=1)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolution output length without padding/dilation: floor((L - kernel) / stride) + 1
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """
        Convert a sample-level attention mask into a boolean frame-level mask of length `feature_vector_length`.

        Every frame up to and including the last frame produced from non-padded samples is marked `True`.
        """
        # Last value of the cumulative sum == number of non-padded samples per row (for a 0/1 mask).
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # Mark the last valid frame of each row, then flip/cumsum/flip so every frame before it is marked too.
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def _get_adapters(self):
        """Return a name -> parameter dict of all attention-adapter weights (plus `lm_head` for CTC heads)."""
        if self.config.adapter_attn_dim is None:
            raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")

        adapter_weights = {}
        for name, module in self.named_modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                for param_name, param in module.named_parameters():
                    adapter_weights[".".join([name, param_name])] = param

        if isinstance(self, Wav2Vec2ForCTC):
            for name, param in self.lm_head.named_parameters():
                adapter_weights[".".join(["lm_head", name])] = param

        return adapter_weights

    def init_adapter_layers(self):
        """
        (Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
        """
        # `_init_weights` only matches leaf module types (Linear, LayerNorm, ...); calling it on the adapter layer
        # container itself is a no-op, so apply it to every submodule of each adapter layer instead.
        for module in self.modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                module.apply(self._init_weights)

        if isinstance(self, Wav2Vec2ForCTC):
            self._init_weights(self.lm_head)

    def load_adapter(self, target_lang: str, force_load=True, **kwargs):
        r"""
        Load a language adapter model from a pre-trained adapter model.

        Parameters:
            target_lang (`str`):
                Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
                adapter.<lang>.safetensors or adapter.<lang>.bin
            force_load (`bool`, defaults to `True`):
                Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

                <Tip>

                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

                </Tip>

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        Raises:
            ValueError: if `config.adapter_attn_dim` is undefined, or the adapter file has unexpected/missing keys.
            EnvironmentError: if no adapter file could be found or loaded.

        Examples:

        ```python
        >>> from transformers import Wav2Vec2ForCTC, AutoProcessor

        >>> ckpt = "facebook/mms-1b-all"
        >>> processor = AutoProcessor.from_pretrained(ckpt)
        >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng")
        >>> # set specific language
        >>> processor.tokenizer.set_target_lang("spa")
        >>> model.load_adapter("spa")
        ```
        """
        if self.config.adapter_attn_dim is None:
            raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")

        # `target_lang` is only assigned at the end of a successful load, so it may not exist yet.
        if target_lang == getattr(self, "target_lang", None) and not force_load:
            logger.warning(f"Adapter weights are already set to {target_lang}.")
            return

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        token = kwargs.pop("token", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        model_path_or_id = self.config._name_or_path
        state_dict = None

        # 1. Try the safetensors adapter file first (unless explicitly disabled).
        if use_safetensors is not False:
            filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)

            try:
                weight_path = cached_file(
                    model_path_or_id,
                    filename=filepath,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    cache_dir=cache_dir,
                )

                state_dict = safe_load_file(weight_path)

            except EnvironmentError:
                # Only fatal when safetensors was explicitly requested; otherwise fall back to the .bin file.
                if use_safetensors:
                    raise

            except Exception:
                if use_safetensors:
                    raise EnvironmentError(
                        f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                        f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
                        f" directory containing a file named {filepath}."
                    )

        # 2. Fall back to the PyTorch pickle adapter file.
        if state_dict is None:
            filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang)

            try:
                weight_path = cached_file(
                    model_path_or_id,
                    filename=filepath,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    cache_dir=cache_dir,
                )

                # Restrict unpickling to tensors where the installed torch supports it.
                weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
                state_dict = torch.load(
                    weight_path,
                    map_location="cpu",
                    **weights_only_kwarg,
                )

            except EnvironmentError:
                raise

            except Exception:
                raise EnvironmentError(
                    f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
                    f" directory containing a file named {filepath}."
                )

        adapter_weights = self._get_adapters()
        unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
        missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())

        if len(unexpected_keys) > 0:
            raise ValueError(f"The adapter weights {weight_path} has unexpected keys: {', '.join(unexpected_keys)}.")
        elif len(missing_keys) > 0:
            raise ValueError(f"The adapter weights {weight_path} has missing keys: {', '.join(missing_keys)}.")

        # Only heads that own an `lm_head` (see `_get_adapters`) carry vocabulary-specific weights; resize the head
        # when the adapter's vocabulary differs from the current one.
        if "lm_head.weight" in state_dict:
            target_vocab_size = state_dict["lm_head.weight"].shape[0]
            if target_vocab_size != self.config.vocab_size:
                self.lm_head = nn.Linear(
                    self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype
                )
                self.config.vocab_size = target_vocab_size

        # Match each loaded tensor's dtype/device to the parameter it replaces.
        state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=False)

        self.target_lang = target_lang
|
|
|
|
|
class Wav2Vec2SpkRegModel(Wav2Vec2SpkRegPreTrainedModel):
    """
    Bare Wav2Vec2 backbone used by the speaker-recognition heads.

    Processing order: convolutional feature encoder -> feature projection -> optional time/feature masking ->
    transformer encoder -> optional adapter.
    """

    def __init__(self, config: Wav2Vec2SpkRegConfig):
        super().__init__(config)
        self.config = config

        # Raw waveform -> convolutional features -> projected hidden states.
        self.feature_extractor = Wav2Vec2FeatureEncoder(config)
        self.feature_projection = Wav2Vec2FeatureProjection(config)

        # The trainable mask embedding only exists when some form of masking is configured.
        uses_masking = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if uses_masking:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        encoder_cls = Wav2Vec2EncoderStableLayerNorm if config.do_stable_layer_norm else Wav2Vec2Encoder
        self.encoder = encoder_cls(config)

        if config.add_adapter:
            self.adapter = Wav2Vec2Adapter(config)
        else:
            self.adapter = None

        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`; emits a `FutureWarning` and then delegates to it.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Stop gradient updates for the convolutional feature encoder (delegates to its `_freeze_parameters`).
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Apply SpecAugment-style masking ([SpecAugment](https://arxiv.org/abs/1904.08779)) in place.

        Time masking uses the given `mask_time_indices` if provided, otherwise sampled indices during training when
        `config.mask_time_prob > 0`; masked frames are replaced by `masked_spec_embed`. Feature masking (training
        only, `config.mask_feature_prob > 0`) zeroes whole feature channels across all frames.
        """
        # Masking can be switched off entirely through the config.
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        cfg = self.config
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is None and cfg.mask_time_prob > 0 and self.training:
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=cfg.mask_time_prob,
                mask_length=cfg.mask_time_length,
                attention_mask=attention_mask,
                min_masks=cfg.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)

        if mask_time_indices is not None:
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if cfg.mask_feature_prob > 0 and self.training:
            sampled_feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=cfg.mask_feature_prob,
                mask_length=cfg.mask_feature_length,
                min_masks=cfg.mask_feature_min_masks,
            )
            feature_mask = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
            # Broadcast the per-channel mask over the time axis.
            hidden_states[feature_mask[:, None].expand(-1, sequence_length, -1)] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        """
        Encode raw audio.

        Returns a `Wav2Vec2BaseModelOutput`, or when `return_dict` is false the tuple
        `(last_hidden_state, extract_features) + encoder_outputs[1:]`.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Feature encoder output is channel-first; move channels last.
        features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # Down-sample the sample-level mask to the feature frame rate.
            attention_mask = self._get_feature_vector_attention_mask(
                features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, features = self.feature_projection(features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden = encoder_outputs[0]
        if self.adapter is not None:
            last_hidden = self.adapter(last_hidden)

        if return_dict:
            return Wav2Vec2BaseModelOutput(
                last_hidden_state=last_hidden,
                extract_features=features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden, features) + encoder_outputs[1:]
|
|
|
|
|
class AngularLinear(nn.Module):
    """Cosine-similarity classification layer (no bias).

    Both the inputs and the rows of ``weight`` are L2-normalized along the feature (last) dimension before the
    linear map, so each output is the cosine between an input vector and a class weight vector. It is used as the
    classifier for the margin-based losses (`AMSoftmaxLoss`, `AAMSoftmaxLoss`), which expect cosine inputs.

    Args:
        in_features: Size of each input sample (last dimension of the input).
        out_features: Number of output classes.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32), requires_grad=True)
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return cosine similarities of shape ``(..., out_features)`` for inputs of shape ``(..., in_features)``."""
        # Normalize along the last dimension (the one `F.linear` contracts) so inputs of any rank work; for 2-D
        # inputs this is identical to the previous dim=1 behavior.
        return F.linear(F.normalize(inputs, dim=-1), F.normalize(self.weight, dim=-1))

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}"
|
|
|
|
|
class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.

    The inputs are expected to be cosine similarities (e.g. the output of `AngularLinear`). ``margin`` is
    subtracted from the target-class cosine only; all logits are then multiplied by ``scale`` and fed to
    cross-entropy.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the margin-adjusted cosines (default: 30.0).
            margin: Additive cosine margin subtracted from the target-class cosine (default: 0.35).
            label_smoothing: Label smoothing passed to `F.cross_entropy` (default: 0.0).
            reduction: Reduction passed to `F.cross_entropy` (default: "mean").
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels).
            targets: Integer class indices of shape (batch_size,).
        Returns:
            The cross-entropy loss of the scaled, margin-adjusted logits, reduced according to ``self.reduction``.
        """
        _, num_labels = inputs.shape

        # Keep the cosines strictly inside (-1, 1).
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin
        one_hot = F.one_hot(targets, num_labels)
        # Apply the margin to the target class only.
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        return F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
|
|
|
|
|
class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.

    The inputs are expected to be cosine similarities cos(theta) (e.g. the output of `AngularLinear`). The
    target-class logit is replaced by cos(theta + margin), all logits are multiplied by ``scale`` and fed to
    cross-entropy.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0).
            margin: Additive angular margin in radians (default: 0.35).
            easy_margin: If True, only apply the margin where cos(theta) > 0 and leave other target cosines
                unchanged. If False (default), apply the standard ArcFace fallback cos(theta) - sin(pi - m) * m
                where theta + m would exceed pi (default: False).
            label_smoothing: Label smoothing passed to `F.cross_entropy` (default: 0.0).
            reduction: Reduction passed to `F.cross_entropy` (default: "mean").
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # cos(theta + m) is only monotonically decreasing in theta while theta + m <= pi, i.e. while
        # cos(theta) > cos(pi - m); past that threshold the penalty falls back to cos(theta) - sin(pi - m) * m.
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels).
            targets: Integer class indices of shape (batch_size,).
        Returns:
            The cross-entropy loss of the scaled, margin-adjusted logits, reduced according to ``self.reduction``.
        """
        _, num_labels = inputs.shape

        # Stay strictly inside (-1, 1): at |cos| == 1 the derivative of sqrt(1 - cos^2) is infinite, which would
        # make the gradients non-finite.
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        psi = cos_theta * self.cos_m - sin_theta * self.sin_m
        if self.easy_margin:
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            psi = torch.where(cos_theta > self.th, psi, cos_theta - self.mm)

        one_hot = F.one_hot(targets, num_labels)
        # Apply the margin to the target class only.
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        return F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
|
|
|
|
|
class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):
    """
    Wav2Vec2 backbone with an utterance-level speaker classification head.

    Encoder states (optionally a learned softmax-weighted sum over all hidden states) are projected to
    `config.classifier_proj_size`, mean-pooled over non-padded frames and classified. The head and loss depend on
    `config.loss_fct`:

    - `'cross_entropy'`: an `nn.Linear` classifier with `nn.CrossEntropyLoss`.
    - `'additive_margin'`: an `AngularLinear` classifier with `AMSoftmaxLoss`.
    - `'additive_angular_margin'`: an `AngularLinear` classifier with `AAMSoftmaxLoss`.
    """

    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2SpkRegModel(config)
        # One weight per stacked hidden state (presumably the embedding output plus each encoder layer).
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        if self.config.loss_fct == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct in ('additive_margin', 'additive_angular_margin'):
            # Margin-based losses expect cosine similarities, not unconstrained logits.
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def _build_loss_fct(self):
        """Return the loss module selected by `config.loss_fct`; raise `ValueError` for unknown names."""
        loss_name = self.config.loss_fct
        if loss_name == 'cross_entropy':
            return nn.CrossEntropyLoss(
                label_smoothing=self.config.label_smoothing,
                reduction=self.config.reduction
            )
        if loss_name == 'additive_margin':
            return AMSoftmaxLoss(
                scale=self.config.scale,
                margin=self.config.margin,
                label_smoothing=self.config.label_smoothing,
                reduction=self.config.reduction
            )
        if loss_name == 'additive_angular_margin':
            return AAMSoftmaxLoss(
                scale=self.config.scale,
                margin=self.config.margin,
                easy_margin=self.config.easy_margin,
                label_smoothing=self.config.label_smoothing,
                reduction=self.config.reduction
            )
        # Guards against `config.loss_fct` being changed after construction.
        raise ValueError(f"Unsupported loss function: {loss_name}")

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Speaker labels for computing the classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is selected by `config.loss_fct`: `'cross_entropy'` (Cross-Entropy),
            `'additive_margin'` (AM-Softmax / CosFace) or `'additive_angular_margin'` (AAM-Softmax / ArcFace).
            When a margin loss is used, the returned `logits` are the unscaled cosine similarities.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every hidden state from the encoder.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Mean over non-padded frames only.
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = self._build_loss_fct()
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )