|
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

from transformers import HubertModel
from transformers.modeling_outputs import BaseModelOutput

_CONFIG_FOR_DOC = 'HubertConfig'


def linear_interpolation(features, seq_len):
    """
    Linearly interpolate features along the time axis to a target length.

    The features are transposed to (batch, channels, time) for ``F.interpolate``
    and transposed back afterwards.

    Args:
        features (torch.Tensor): Features of shape (batch, time, channels) to be interpolated.
        seq_len (int): The target sequence length.

    Returns:
        torch.Tensor: The interpolated features of shape (batch, seq_len, channels).
    """
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
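# A minimal usage sketch (the shapes below are illustrative assumptions): HuBERT's
# convolutional front end produces roughly 50 feature frames per second of 16 kHz
# audio, so aligning a ~2-second clip to, e.g., 50 video frames at 25 fps might be:
#
#     feats = torch.randn(1, 99, 512)            # (batch, time, channels)
#     aligned = linear_interpolation(feats, 50)  # -> (1, 50, 512)
#
# Only the time axis is resampled; the channel dimension is left untouched.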
|
|
|
|
|
class HubertModel_(HubertModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        seq_len: Optional[int],
        sample_strategy: Optional[str] = "presample",
        attention_mask: Optional[torch.LongTensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Forward pass of the HuBERT model with feature-length alignment.

        Args:
            input_values: The input waveform values.
            seq_len: The target sequence length to which the features are interpolated.
            sample_strategy: When to align features to seq_len; one of ['presample', 'postsample'].
                'presample' interpolates the convolutional features before the encoder,
                'postsample' interpolates the encoder outputs afterwards.
            attention_mask: Attention mask for the input values.
            mask_time_indices: Indices of the time steps to mask.
            output_attentions: If set to True, returns the attention weights.
            output_hidden_states: If set to True, returns the hidden states of all layers.
            return_dict: If set to True, returns a BaseModelOutput instead of a tuple.

        Returns:
            The output of the HuBERT model, aligned to seq_len along the time axis.
        """
        assert sample_strategy in ["presample", "postsample"], \
            f"sample_strategy must be in ['presample', 'postsample'], got {sample_strategy!r}"
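        # Rough shape flow (dimension names are illustrative, e.g. for a base-sized model):
        #   input_values        (batch, num_samples)
        #   -> conv features    (batch, frames, conv_dim)
        #   -> interpolated     (batch, seq_len, conv_dim)   ('presample')
        #   -> encoder output   (batch, seq_len, hidden_size)
        # With 'postsample', the interpolation is applied to the encoder output instead.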
|
        # Default to returning attention weights; an explicit output_attentions argument
        # still takes precedence below.
        self.config.output_attentions = True
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Convolutional feature extraction: (batch, channels, frames) -> (batch, frames, channels).
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if sample_strategy == "presample":
            # Align the convolutional features to the target length before they enter the encoder.
            extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
        if attention_mask is not None:
            # Reduce the waveform-level attention mask to the feature-frame resolution.
            # Note: the valid lengths are derived from the convolutional strides, so after
            # 'presample' interpolation the mask is only an approximation of the new time axis.
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states,
            mask_time_indices=mask_time_indices,
            attention_mask=attention_mask,
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
|
|
|
        hidden_states = encoder_outputs[0]

        if sample_strategy == "postsample":
            # Align the encoder outputs to the target length after the encoder.
            hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
            if return_dict and encoder_outputs.hidden_states is not None:
                # The per-layer hidden states are stored as a tuple, so build a new
                # tuple rather than assigning to its elements in place.
                encoder_outputs.hidden_states = tuple(
                    linear_interpolation(layer_states, seq_len=seq_len)
                    for layer_states in encoder_outputs.hidden_states
                )

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
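# Minimal usage sketch; the checkpoint name, audio length, and target length below
# are illustrative assumptions, not requirements of this module.
if __name__ == "__main__":
    model = HubertModel_.from_pretrained("facebook/hubert-base-ls960")
    model.eval()
    waveform = torch.randn(1, 32000)  # ~2 s of 16 kHz audio
    with torch.no_grad():
        outputs = model(waveform, seq_len=50, sample_strategy="presample")
    # For the base model, this prints torch.Size([1, 50, 768]).
    print(outputs.last_hidden_state.shape)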
|
|