"""Speech-encoder-only SeamlessM4T-v2 model with an audio classification head.

Both classes are registered with the Auto* factories at the bottom of this
module so they can be loaded through AutoModel, AutoModelForAudioClassification,
and AutoModelForSequenceClassification.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto import (
    AutoModel,
    AutoModelForAudioClassification,
    AutoModelForSequenceClassification,
)
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
    _compute_new_attention_mask,
)
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
    SeamlessM4Tv2PreTrainedModel,
    SeamlessM4Tv2SpeechEncoder as HFSeamlessM4Tv2SpeechEncoder,
)

from .configuration_seamless_m4t_v2_speech_encoder import (
    MODEL_TYPE,
    SeamlessM4Tv2EncoderConfig,
)


class SeamlessM4Tv2SpeechEncoder(HFSeamlessM4Tv2SpeechEncoder):
    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Average hidden states over time, counting only unmasked frames."""
        # Broadcast the (batch, seq_len) mask to the hidden-state shape
        # (batch, seq_len, hidden_size) so padded frames contribute zeros.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden_states = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # The clamp guards against division by zero for fully padded sequences.
        return sum_hidden_states / torch.clamp(sum_mask, min=1e-9)
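

# A quick numerical sketch of the masked mean above (shapes and values are
# illustrative only, not tied to any checkpoint): with three valid frames out
# of five, only those frames contribute to the average.
#
#     hidden = torch.ones(1, 5, 2)
#     mask = torch.tensor([[1, 1, 1, 0, 0]])
#     SeamlessM4Tv2SpeechEncoder.mean_pooling(hidden, mask)
#     # -> tensor([[1., 1.]]), the mean over the three unmasked frames only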


class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = SeamlessM4Tv2SpeechEncoder(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ):
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        outputs = self.model(
            input_features,
            attention_mask,
            *args,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        if attention_mask is not None:
            # The encoder subsamples the time axis, so the input-level mask
            # must be recomputed for the shorter output sequence before pooling.
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(
                attention_mask
            ).to(hidden_states.device)
            attention_mask = _compute_new_attention_mask(
                hidden_states=hidden_states, seq_lens=sub_sampled_lengths
            )
        else:
            # Without a mask, every frame participates in the mean.
            attention_mask = torch.ones(
                hidden_states.shape[:2], device=hidden_states.device
            )
        hidden_states = self.model.mean_pooling(hidden_states, attention_mask)
        logits = self.score(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once, following the usual transformers
            # convention for sequence classification heads.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )


AutoModel.register(SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2SpeechEncoder)
AutoModelForAudioClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
AutoModelForSequenceClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
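
# Minimal usage sketch, assuming a fine-tuned checkpoint saved with this
# configuration ("path/to/checkpoint" is a placeholder, and `features`/`mask`
# are assumed to come from the SeamlessM4T feature extractor). This module
# must be imported first so the Auto* registrations above have run.
#
#     from transformers import AutoModelForAudioClassification
#
#     model = AutoModelForAudioClassification.from_pretrained("path/to/checkpoint")
#     out = model(input_features=features, attention_mask=mask)
#     predicted_ids = out.logits.argmax(dim=-1)  # SequenceClassifierOutput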