# SER_AUDIO / wavlm_plus.py
# part of the code was referenced from SUPERB: https://github.com/s3prl/s3prl
# and https://github.com/wngh1187/IPET/blob/main/Speechcommands_V2/W2V2/models/W2V2.py
import os
import pdb
import copy
import torch
import argparse
import numpy as np
import loralib as lora
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
import transformers.models.wavlm.modeling_wavlm as wavlm
from functools import lru_cache
from torchaudio.compliance import kaldi
from torch import nn
from adapter import Adapter
from collections import OrderedDict
from typing import Optional, Callable
from torch.nn import functional as F
from torch.nn.functional import normalize
from transformers import WavLMModel
class WavLMEncoderLayer(nn.Module):
def __init__(self, config, has_relative_position_bias: bool = True):
super().__init__()
self.attention = wavlm.WavLMAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
num_buckets=config.num_buckets,
max_distance=config.max_bucket_distance,
has_relative_position_bias=has_relative_position_bias,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = wavlm.WavLMFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.config = config
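        # Attach parameter-efficient finetuning (PEFT) modules according to config.finetune_method:
        # learnable prompt embeddings, LoRA on the feed-forward projections, or bottleneck adapters.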
if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
            self.embed_prompt = nn.Parameter(torch.randn([1, self.config.embedding_prompt_dim, self.config.hidden_size]))
nn.init.xavier_uniform_(self.embed_prompt)
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
if self.config.finetune_method == "adapter" or self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
self.adapter = Adapter(
config,
dropout=0.1,
bottleneck=config.adapter_hidden_dim,
adapter_scalar=0.1
)
def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
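        # Prompt tuning: prepend the learnable prompt embeddings along the time dimension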
if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
hidden_states = torch.cat((self.embed_prompt.repeat(hidden_states.size(0), 1, 1), hidden_states), dim=1)
attn_residual = hidden_states
hidden_states, attn_weights, position_bias = self.attention(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
index=index,
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
# Adapter
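        # For "adapter", the bottleneck module runs in parallel with the feed-forward block
        # (its output is added back after the feed-forward residual below);
        # "adapter_l" and "combined" instead apply the adapter sequentially after the feed-forward residual.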
if self.config.finetune_method == "adapter":
adapt_h = self.adapter(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states + self.feed_forward(hidden_states)
if self.config.finetune_method == "adapter":
hidden_states = hidden_states + adapt_h
if self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
hidden_states = hidden_states + self.adapter(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
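        # Drop the prepended prompt tokens so the output length matches the input sequence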
if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
hidden_states = hidden_states[:, self.config.embedding_prompt_dim:, :]
outputs = (hidden_states, position_bias)
if output_attentions:
outputs += (attn_weights,)
return outputs
class WavLMWrapper(nn.Module):
def __init__(
self,
args,
hidden_dim=256,
output_class_num=7
):
super(WavLMWrapper, self).__init__()
        # 1. Load the pretrained WavLM backbone with its weights
self.args = args
self.backbone_model = WavLMModel.from_pretrained(
"microsoft/wavlm-base-plus",
output_hidden_states=True
)
state_dict = self.backbone_model.state_dict()
# 2. Read the model config
self.model_config = self.backbone_model.config
self.model_config.finetune_method = args.finetune_method
self.model_config.adapter_hidden_dim = args.adapter_hidden_dim
self.model_config.embedding_prompt_dim = args.embedding_prompt_dim
self.model_config.lora_rank = args.lora_rank
# 3. Config encoder layers with adapter or embedding prompt
# pdb.set_trace()
self.backbone_model.encoder.layers = nn.ModuleList(
[WavLMEncoderLayer(self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
)
# 4. Load the weights back
msg = self.backbone_model.load_state_dict(state_dict, strict=False)
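        # Parameters introduced by the PEFT modules (adapters, LoRA, prompts) are not in the
        # pretrained state dict, so they appear in msg.missing_keys and stay trainable below.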
# 5. Freeze the weights
        if self.args.finetune_method in ("adapter", "adapter_l", "embedding_prompt", "finetune", "lora", "combined"):
for name, p in self.backbone_model.named_parameters():
if name in msg.missing_keys: p.requires_grad = True
else: p.requires_grad = False
self.finetune_method = self.args.finetune_method
# 6. Downstream models
self.model_seq = nn.Sequential(
nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
nn.ReLU(),
nn.Dropout(p=0.1),
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
nn.ReLU(),
nn.Dropout(p=0.1),
nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
)
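        # Learnable scalar weights for a softmax-weighted sum over the encoder layer outputs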
self.weights = nn.Parameter(torch.zeros(self.model_config.num_hidden_layers))
# self.out_layer = nn.Sequential(
# nn.Linear(hidden_dim, hidden_dim),
# nn.ReLU(),
# nn.Linear(hidden_dim, output_class_num),
# )
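        # NOTE: output_class_num is not used by the active head below, which outputs
        # 2 sigmoid-activated values instead.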
self.out_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2),
nn.Sigmoid()
)
def forward(self, x, length=None):
# 1. feature extraction and projections
with torch.no_grad():
x = self.backbone_model.feature_extractor(x)
            x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C), as expected by newer versions of transformers
            x, _ = self.backbone_model.feature_projection(x)  # feature_projection returns a tuple; keep the projected features
# 2. get length and mask
if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())
            length = length.to(x.device)  # keep lengths on the same device as the features
# 3. transformer encoding features
x = self.backbone_model.encoder(
x, output_hidden_states=True
).hidden_states
# 4. stacked feature
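        # hidden_states[0] is the encoder input (projected CNN features); [1:] keeps only the transformer layer outputs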
stacked_feature = torch.stack(x, dim=0)[1:]
# 5. Weighted sum
_, *origin_shape = stacked_feature.shape
        # stacked_feature: [num_enc_layers, B, T, D]; flatten everything but the layer dim for the weighted sum
stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
norm_weights = F.softmax(self.weights, dim=-1)
# Perform weighted average
weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
features = weighted_feature.view(*origin_shape)
# 6. Pass the weighted average to point-wise 1D Conv
# B x T x D
features = features.transpose(1, 2)
features = self.model_seq(features)
features = features.transpose(1, 2)
# 7. Pooling
if length is not None:
            masks = torch.arange(features.size(1), device=features.device).expand(length.size(0), -1) < length.unsqueeze(1)
masks = masks.float()
features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
else:
features = torch.mean(features, dim=1)
# 8. Output predictions
# B x D
predicted = self.out_layer(features)
return predicted
# From huggingface
def get_feat_extract_output_lengths(self, input_length):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
input_length = _conv_out_length(input_length, kernel_size, stride)
return input_length
def prepare_mask(length, shape, dtype):
# Modified from huggingface
mask = torch.zeros(
shape, dtype=dtype
)
    # these two operations make sure that all values
    # before the output length indices are attended to
mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return mask
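# NOTE: prepare_mask is not called elsewhere in this file; it builds a boolean attention mask
# from utterance lengths, marking every frame up to each length as attended.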
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='emo2vec finetune experiments')
parser.add_argument(
'--finetune_method',
default='none',
type=str,
        help='finetune method: adapter, adapter_l, embedding_prompt, lora, combined, finetune, or none'
)
parser.add_argument(
'--adapter_hidden_dim',
default=128,
type=int,
help='adapter dimension'
)
    parser.add_argument(
        '--embedding_prompt_dim',
        default=5,
        type=int,
        help='embedding prompt dimension'
    )
    parser.add_argument(
        '--lora_rank',    # required: WavLMWrapper reads args.lora_rank; the default here is arbitrary
        default=16,
        type=int,
        help='LoRA rank'
    )
    args = parser.parse_args()
model = WavLMWrapper(args)
data = torch.zeros([1, 16000])
output = model(data)
print(output.shape)
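    # With the 1 s dummy input above, this should print torch.Size([1, 2]).
    # Hypothetical command lines for the parameter-efficient variants
    # (assuming the local `adapter` module and loralib are installed):
    #   python wavlm_plus.py --finetune_method adapter --adapter_hidden_dim 128
    #   python wavlm_plus.py --finetune_method lora --lora_rank 8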