# wavlm-bert-tiny2-s-emotion-russian-resd / audio_text_multimodal.py
# Last change by Ar4ikov: "Update audio_text_multimodal.py" (commit 3fb246b)
from dataclasses import dataclass
from typing import Union, Type
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
PreTrainedModel,
PretrainedConfig,
WavLMConfig,
BertConfig,
WavLMModel,
BertModel,
Wav2Vec2Config,
Wav2Vec2Model
)
from transformers.models.wavlm.modeling_wavlm import (
WavLMEncoder,
WavLMEncoderStableLayerNorm,
WavLMFeatureEncoder
)
from transformers.models.bert.modeling_bert import BertEncoder
class MultiModalConfig(PretrainedConfig):
    """Common base class for the multimodal model configs in this file.

    It defines no fields of its own: every keyword argument is handed
    straight through to ``PretrainedConfig``.
    """

    def __init__(self, **kwargs):
        super(MultiModalConfig, self).__init__(**kwargs)
class WavLMBertConfig(MultiModalConfig):
    """Config for the WavLM (audio) + BERT (text) classifier.

    Adds no behaviour beyond ``MultiModalConfig``. The model built from it
    reads the ``WavLMModel`` and ``BertModel`` entries (dicts passed to
    ``WavLMConfig.from_dict`` / ``BertConfig.from_dict``), ``pooling_mode``,
    ``num_labels`` and, optionally, ``gradient_checkpointing``.
    """
class BaseClassificationModel(PreTrainedModel):
    """Shared loss and pooling helpers for the sequence classifiers in this file.

    Subclasses are expected to set ``self.num_labels`` and to carry a config
    exposing ``problem_type`` (``PretrainedConfig`` provides that attribute).
    """

    config: Union[PretrainedConfig, None] = None

    def compute_loss(self, logits, labels):
        """Compute the loss selected by ``self.config.problem_type``.

        If ``problem_type`` is unset, it is inferred and written back to the
        config, following the Hugging Face ``*ForSequenceClassification``
        convention:

        * ``num_labels == 1`` -> ``"regression"`` (``MSELoss``)
        * ``num_labels > 1`` and labels of dtype ``torch.long`` / ``torch.int``
          -> ``"single_label_classification"`` (``CrossEntropyLoss``)
        * ``num_labels > 1`` and any other label dtype (e.g. float multi-hot
          targets) -> ``"multi_label_classification"`` (``BCEWithLogitsLoss``)

        Args:
            logits (torch.FloatTensor): classifier logits; viewed as
                ``(-1, num_labels)`` for the classification losses and
                flattened for regression.
            labels (torch.Tensor): class indices (single-label), multi-hot
                targets (multi-label) or regression targets.

        Returns:
            torch.FloatTensor: loss computed with each loss's default
            (``"mean"``) reduction.

        Raises:
            ValueError: ``num_labels`` is not positive, or ``problem_type`` is
                not one of the supported values.
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1:
                # Integer targets are class indices. Anything else (typically
                # float multi-hot vectors) cannot be fed to CrossEntropyLoss as
                # the flattened index target used below, so treat it as
                # multi-label.
                if labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            else:
                raise ValueError("Invalid number of labels: {}".format(self.num_labels))

        problem_type = self.config.problem_type
        if problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif problem_type == "multi_label_classification":
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
        elif problem_type == "regression":
            loss_fct = torch.nn.MSELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1))
        else:
            raise ValueError("Problem_type {} not supported".format(problem_type))
        return loss

    @staticmethod
    def merged_strategy(
        hidden_states,
        mode="mean"
    ):
        """Pool ``hidden_states`` by reducing over dimension 1.

        Dimension 1 is presumably the sequence/time axis of a
        ``(batch, seq, hidden)`` tensor — TODO confirm for every caller.

        Args:
            hidden_states (torch.FloatTensor): tensor to pool over ``dim=1``.
            mode (str, optional): one of ``"mean"``, ``"sum"`` or ``"max"``.
                Defaults to ``"mean"``.

        Returns:
            torch.FloatTensor: pooled tensor with dimension 1 removed.

        Raises:
            ValueError: ``mode`` is not a supported pooling mode. This is a
                subclass of ``Exception``, so existing ``except Exception``
                handlers still match.
        """
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            # torch.max with dim returns (values, indices); keep the values.
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
        return outputs
class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
    """Skeleton for audio + text sequence classifiers.

    This class only declares the sub-model slots (all ``None`` here) and
    implements the fused forward pass. Subclasses must populate
    ``audio_config``, ``text_config``, ``audio_model``, ``text_model`` and
    ``classifier``.
    """

    config_class = MultiModalConfig

    def __init__(self, config):
        """
        Args:
            config (MultiModalConfig): config

        Attributes:
            config (MultiModalConfig): config
            num_labels (int): number of labels, copied from ``config.num_labels``
            audio_config (Union[PretrainedConfig, None]): audio config
            text_config (Union[PretrainedConfig, None]): text config
            audio_model (Union[PreTrainedModel, None]): audio model
            text_model (Union[PreTrainedModel, None]): text model
            classifier (Union[torch.nn.Linear, None]): classifier
        """
        super().__init__(config)
        self.config = config
        self.num_labels = self.config.num_labels
        self.audio_config: Union[PretrainedConfig, None] = None
        self.text_config: Union[PretrainedConfig, None] = None
        self.audio_model: Union[PreTrainedModel, None] = None
        self.text_model: Union[PreTrainedModel, None] = None
        self.classifier: Union[torch.nn.Linear, None] = None

    def forward(
        self,
        input_ids=None,
        input_values=None,
        text_attention_mask=None,
        audio_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        """Forward pass of the multimodal (audio + text) sequence classifier.

        The audio hidden states are pooled with ``merged_strategy`` using
        ``config.pooling_mode``. The result is concatenated with the text
        model's ``pooler_output`` and fed to ``classifier``.

        Args:
            input_ids (torch.LongTensor, optional): input ids. Defaults to None.
            input_values (torch.FloatTensor, optional): input values. Defaults to None.
            text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None.
            audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None.
            token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None.
            position_ids (torch.LongTensor, optional): position ids. Defaults to None.
            head_mask (torch.FloatTensor, optional): head mask. Defaults to None.
            inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None.
            labels (torch.Tensor, optional): targets; when given, ``loss`` is
                computed with ``compute_loss``. Defaults to None.
            output_attentions (bool, optional): forwarded to both sub-models.
                Defaults to None.
            output_hidden_states (bool, optional): forwarded to both sub-models.
                Defaults to None.
            return_dict (bool, optional): if falsy (but not None), return a
                plain tuple ``(loss, logits)`` or ``(logits,)`` instead of a
                ``SequenceClassifierOutput``. ``None`` behaves like True.
                Defaults to True.

        Returns:
            SequenceClassifierOutput: ``loss`` (None without labels) and
            ``logits``; or the tuple described under ``return_dict``. The
            sub-models' hidden states / attentions are not surfaced.
        """
        # None keeps the historical behaviour of always returning a
        # SequenceClassifierOutput.
        if return_dict is None:
            return_dict = True

        # The fusion below reads ``last_hidden_state`` / ``pooler_output`` by
        # attribute, so the sub-models must always return ModelOutput objects.
        # Passing the caller's ``return_dict=False`` through made them return
        # tuples and crashed here.
        audio_output = self.audio_model(
            input_values=input_values,
            attention_mask=audio_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        text_output = self.text_model(
            input_ids=input_ids,
            attention_mask=text_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # NOTE(review): pooling here ignores ``audio_attention_mask``, so padded
        # frames contribute to the mean/sum/max. This is left unchanged because
        # masking would change outputs of already-trained weights — confirm intent.
        audio_pooled = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode)
        pooled_output = torch.cat(
            (audio_pooled, text_output.pooler_output), dim=1
        )
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.compute_loss(logits, labels)

        if not return_dict:
            # Hugging Face tuple convention: loss first, only when present.
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits
        )
class WavLMBertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
    """Audio + text sequence classifier built from WavLM and BERT encoders.

    The pooled WavLM output and the BERT pooler output are concatenated and
    passed through a single linear layer (see the base class ``forward``).

    Args:
        config (WavLMBertConfig): multimodal config whose ``WavLMModel`` and
            ``BertModel`` entries are dicts describing the two sub-models.

    Attributes:
        config (WavLMBertConfig): config
        audio_config (WavLMConfig): WavLM config
        text_config (BertConfig): BERT config
        audio_model (WavLMModel): WavLM model
        text_model (BertModel): BERT model
        classifier (torch.nn.Linear): maps the concatenated
            (audio hidden size + text hidden size) features to ``num_labels``
    """

    def __init__(self, config):
        super().__init__(config)
        self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True)

        cfg = self.config
        self.audio_config = WavLMConfig.from_dict(cfg.WavLMModel)
        self.text_config = BertConfig.from_dict(cfg.BertModel)

        # Build order (audio -> text -> classifier) is kept as-is: it fixes
        # submodule registration order and the sequence of random initialisations.
        self.audio_model = WavLMModel(self.audio_config)
        self.text_model = BertModel(self.text_config)

        fused_width = self.audio_config.hidden_size + self.text_config.hidden_size
        self.classifier = torch.nn.Linear(fused_width, self.num_labels)

        self.init_weights()

    @staticmethod
    def _set_gradient_checkpointing(module, value=False):
        # Only the WavLM / BERT encoder modules carry a ``gradient_checkpointing`` flag.
        checkpointable = (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder, BertEncoder)
        if not isinstance(module, checkpointable):
            return
        module.gradient_checkpointing = value