"""Wav2Vec2 with a CTC head plus four utterance-level regression heads.

Four learnable CLS-style tokens are prepended to the projected frame features
before the transformer encoder. The encoder outputs at those four positions
feed small regression heads (accuracy, fluency, total score, content); the
remaining positions feed the CTC language-model head.
"""

from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model
from transformers.modeling_outputs import CausalLMOutput
from typing import Optional, Tuple, Union
import warnings
import torch
import torch.nn as nn
import math

# Offset of the optional outputs in a `Wav2Vec2Model` output tuple. Kept for
# backward compatibility with code importing it; `forward` below consumes the
# *encoder* output directly and therefore uses `_ENCODER_EXTRAS_START_POSITION`.
_HIDDEN_STATES_START_POSITION = 2

# The encoder's tuple output is (last_hidden_state, [hidden_states], [attentions]),
# so the optional entries start right after the first element.
_ENCODER_EXTRAS_START_POSITION = 1

# Number of utterance-level tokens prepended to the frame sequence.
_NUM_UTTERANCE_TOKENS = 4


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place with samples from a normal distribution truncated to `[a, b]`.

    Cut & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same `tensor` object, after in-place modification.
    """

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower_cdf = norm_cdf((a - mean) / std)
        upper_cdf = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill `tensor` in place from a normal distribution truncated to `[a, b]` and return it.

    Thin public wrapper around `_no_grad_trunc_normal_`; see that function for
    the meaning of the arguments.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class Wav2Vec2ForWav2Vec2ForCTCAndUttranceRegression(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder with a CTC head and four utterance-level regression heads.

    Utterance-level heads, in token order: 1=accuracy, 2=fluency, 3=total score,
    4=content. Each head is a `LayerNorm` followed by a `Linear(hidden_size, 1)`
    applied to the encoder output at the position of its own learnable token.
    """

    def __init__(self, config, target_lang: Optional[str] = None):
        """Build the backbone, the CTC head and the utterance-level heads.

        Args:
            config: model configuration; must define `vocab_size`.
            target_lang: optional adapter language, consumed by `tie_weights`.

        Raises:
            ValueError: if `config.vocab_size` is None.
        """
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # NOTE(review): `forward` calls the encoder directly and never runs an
        # adapter, so with `add_adapter=True` and `output_hidden_size != hidden_size`
        # this head would not match the encoder output width - confirm intended use.
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # utterance level, 1=accuracy, 2=fluency, 3=total score, 4=content
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mlp_head_utt1 = nn.Sequential(nn.LayerNorm(config.hidden_size), nn.Linear(config.hidden_size, 1))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mlp_head_utt2 = nn.Sequential(nn.LayerNorm(config.hidden_size), nn.Linear(config.hidden_size, 1))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mlp_head_utt3 = nn.Sequential(nn.LayerNorm(config.hidden_size), nn.Linear(config.hidden_size, 1))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mlp_head_utt4 = nn.Sequential(nn.LayerNorm(config.hidden_size), nn.Linear(config.hidden_size, 1))

        # Initialize weights and apply final processing
        self.post_init()

        # initialize the cls tokens (each token uses its own std)
        trunc_normal_(self.cls_token1, std=.092)
        trunc_normal_(self.cls_token2, std=.01)
        trunc_normal_(self.cls_token3, std=.052)
        trunc_normal_(self.cls_token4, std=.02)

    def tie_weights(self):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """

        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
        # correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
        # [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
        # ok to repurpose this function here.
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            print("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters
        will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter
        will not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will
        not be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[dict] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        Run the backbone with four prepended utterance tokens and compute CTC logits plus utterance scores.

        input_values (`torch.Tensor` of shape `(batch_size, num_samples)`):
            Raw input; must be 2-D.
        attention_mask (`torch.Tensor` of shape `(batch_size, num_samples)`, *optional*):
            Sample-level mask. It is reduced to frame level for the encoder (with the four utterance tokens always
            attended) and used, unreduced, to derive the CTC input lengths.
        labels (`dict`, *optional*):
            Mapping with two entries:
            - `'labels'`: CTC targets of shape `(batch_size, target_length)`. Note that `target_length` has to be
              smaller or equal to the sequence length of the output logits. Indices are selected in
              `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is
              only computed for labels in `[0, ..., config.vocab_size - 1]`.
            - `'utt_label'`: utterance-level regression targets; only the first four columns are used, matched in
              order against the heads 1=accuracy, 2=fluency, 3=total score, 4=content.

        Returns:
            With `return_dict`: a `CausalLMOutput` whose `logits` field is a dict holding the CTC `'logits'` and the
            four utterance scores, and whose `loss` is `mse(utterance) + ctc` when `labels` is given. Otherwise a
            tuple `(loss?, logits, *encoder_extras)`; the utterance scores are not part of the tuple form.

        Raises:
            ValueError: if a CTC label is `>= config.vocab_size`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Unpacking enforces a 2-D `(batch, samples)` input.
        batch_size, _ = input_values.size()

        extract_features = self.wav2vec2.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        # Keep the caller's sample-level `attention_mask` untouched: the CTC
        # input lengths below must be derived from it, not from the reduced mask.
        frame_attention_mask = None
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            frame_attention_mask = self.wav2vec2._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.wav2vec2.feature_projection(extract_features)
        hidden_states = self.wav2vec2._mask_hidden_states(
            hidden_states, mask_time_indices=None, attention_mask=frame_attention_mask
        )

        cls_token1 = self.cls_token1.expand(batch_size, -1, -1)
        cls_token2 = self.cls_token2.expand(batch_size, -1, -1)
        cls_token3 = self.cls_token3.expand(batch_size, -1, -1)
        cls_token4 = self.cls_token4.expand(batch_size, -1, -1)
        hidden_states = torch.cat((cls_token1, cls_token2, cls_token3, cls_token4, hidden_states), dim=1)

        # The encoder mask must cover the prepended tokens too, otherwise its
        # length no longer matches the sequence length of `hidden_states`.
        encoder_attention_mask = None
        if frame_attention_mask is not None:
            cls_mask = frame_attention_mask.new_ones((batch_size, _NUM_UTTERANCE_TOKENS))
            encoder_attention_mask = torch.cat((cls_mask, frame_attention_mask), dim=1)

        outputs = self.wav2vec2.encoder(
            hidden_states,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        # the first 4 tokens are utterance-level cls tokens, i.e., accuracy, fluency, total scores, content
        u1 = self.mlp_head_utt1(hidden_states[:, 0])
        u2 = self.mlp_head_utt2(hidden_states[:, 1])
        u3 = self.mlp_head_utt3(hidden_states[:, 2])
        u4 = self.mlp_head_utt4(hidden_states[:, 3])

        # Everything after the utterance tokens is the frame sequence for CTC.
        logits = self.lm_head(hidden_states[:, _NUM_UTTERANCE_TOKENS:])

        loss = None
        if labels is not None:
            ctc_labels = labels['labels']
            utt_label = labels['utt_label'][:, :_NUM_UTTERANCE_TOKENS]

            if ctc_labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from the sample-level attention_mask
            input_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(input_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = ctc_labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = ctc_labels.masked_select(labels_mask)

            # ctc_loss expects (time, batch, vocab) log-probabilities in float32
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                # utterance level loss, also mse
                utt_preds = torch.cat((u1, u2, u3, u4), dim=1)
                loss_utt = nn.functional.mse_loss(utt_preds, utt_label)

                loss_ph = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

                loss = loss_utt + loss_ph

        if not return_dict:
            # `outputs` is the encoder tuple, whose optional entries start at index 1.
            output = (logits,) + outputs[_ENCODER_EXTRAS_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        # utterance level, 1=accuracy, 2=fluency, 3=total score, 4=content
        # (keys follow the same head order that the loss pairs with `utt_label`)
        return CausalLMOutput(
            loss=loss,
            logits={'logits': logits, 'accuracy': u1, 'fluency': u2, 'total score': u3, 'content': u4},
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )