# multimodalart's picture
# Upload 247 files
# 7758cff verified
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import HubertModel
from transformers.modeling_outputs import BaseModelOutput
# Config class name used by transformers' doc decorators upstream; appears
# unused within this chunk — presumably kept for parity with the original
# HubertModel source (TODO confirm).
_CONFIG_FOR_DOC = 'HubertConfig'
def linear_interpolation(features, seq_len):
    """Linearly resample ``features`` along the time axis to length ``seq_len``.

    Args:
        features (torch.Tensor): Input of shape (N, L, C) — batch, time, channels.
        seq_len (int): Target length for the time dimension.

    Returns:
        torch.Tensor: Resampled tensor of shape (N, seq_len, C).
    """
    # F.interpolate with mode='linear' operates on (N, C, L), so swap the
    # time/channel axes around the call and swap them back afterwards.
    channels_first = features.transpose(1, 2)
    resampled = F.interpolate(channels_first, size=seq_len, align_corners=True, mode='linear')
    return resampled.transpose(1, 2)
class HubertModel_(HubertModel):
    """HubertModel variant that resamples features to a target sequence length.

    Adds a ``seq_len`` argument and a ``sample_strategy`` switch controlling
    whether the CNN features are interpolated before the transformer encoder
    ("presample") or the encoder outputs are interpolated after it
    ("postsample").
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        seq_len: Optional[int],
        sample_strategy: Optional[str] = "presample",
        attention_mask: Optional[torch.LongTensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the HuBERT model with time-axis resampling.

        Args:
            input_values: Raw waveform input to the model.
            seq_len: Target sequence length the features are interpolated to.
            sample_strategy: One of ['presample', 'postsample'] — interpolate
                the CNN features before the encoder, or the encoder outputs
                (and hidden states) after it.
            attention_mask: Attention mask for the model.
            mask_time_indices: Mask indices used by ``_mask_hidden_states``.
            output_attentions: If True, return attention weights.
            output_hidden_states: If True, return all hidden states.
            return_dict: If True, return a BaseModelOutput instead of a tuple.

        Returns:
            BaseModelOutput (or a tuple when ``return_dict`` is False) whose
            ``last_hidden_state`` has length ``seq_len`` along the time axis.

        Raises:
            AssertionError: If ``sample_strategy`` is not a supported value.
        """
        assert sample_strategy in ["presample", "postsample"], \
            "sample_strategy must be in ['presample', 'postsample']"
        # NOTE(review): mutating self.config on every forward call forces
        # attention outputs on globally for this instance; preserved from the
        # original implementation — confirm this is intended.
        self.config.output_attentions = True
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # CNN feature extractor yields (N, C, L); transpose to (N, L, C).
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        if sample_strategy == "presample":
            extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # Reduce the waveform-level mask to the feature-frame resolution.
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states,
            mask_time_indices=mask_time_indices,
            attention_mask=attention_mask,
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]

        # getattr guard: when return_dict is False the encoder returns a plain
        # tuple (no .hidden_states); when output_hidden_states is False the
        # attribute is None. Either way there is nothing to resample.
        all_hidden_states = getattr(encoder_outputs, "hidden_states", None)
        if sample_strategy == "postsample":
            hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
            if all_hidden_states is not None:
                # BUGFIX: hidden_states is an immutable tuple; the original
                # item-assigned into it (TypeError). Rebuild a new tuple with
                # every layer's states resampled to seq_len instead.
                all_hidden_states = tuple(
                    linear_interpolation(h, seq_len=seq_len) for h in all_hidden_states
                )

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=encoder_outputs.attentions,
        )