from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput

from .torch_utils import linear_interpolation


# The implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# We initialize our encoder with the pre-trained wav2vec 2.0 weights and
# resample the extracted audio features to a caller-supplied sequence length
# (e.g. the number of video frames) via linear interpolation.
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # CNN front end: (batch, channels, frames) -> (batch, frames, channels),
        # then resample the frame axis to seq_len so the audio features align
        # with the target (e.g. video) frame rate.
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        # First half of forward(): CNN feature extraction plus resampling,
        # exposed separately so features can be precomputed and cached.
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Second half of forward(): project, mask, and run the transformer
        # encoder on features produced by feature_extract().
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
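

# --- usage sketch (added for illustration; not part of the borrowed code) ---
# A minimal example of driving this wrapper, runnable as
# `python -m <package>.<this_module>` so the relative import above resolves.
# The checkpoint name "facebook/wav2vec2-base-960h", the one-second 16 kHz
# input, and the 25-frame target length are all illustrative assumptions.
if __name__ == "__main__":
    import torch

    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
    model.eval()  # disable SpecAugment-style masking in _mask_hidden_states

    waveform = torch.randn(1, 16000)  # assumed: 1 s of 16 kHz mono audio

    with torch.no_grad():
        # One-shot path: extract, resample to 25 frames, and encode.
        out = model(waveform, seq_len=25)
        print(out.last_hidden_state.shape)  # (1, 25, hidden_size)

        # Two-stage path: cache the resampled CNN features, encode later.
        feats = model.feature_extract(waveform, seq_len=25)
        out2 = model.encode(feats)
        print(out2.last_hidden_state.shape)  # (1, 25, hidden_size)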