# part of the code was referenced from SUPERB: https://github.com/s3prl/s3prl
# and https://github.com/wngh1187/IPET/blob/main/Speechcommands_V2/W2V2/models/W2V2.py
import os
import pdb
import copy
import torch
import argparse
import numpy as np
import loralib as lora
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
import transformers.models.wavlm.modeling_wavlm as wavlm

from functools import lru_cache
from torchaudio.compliance import kaldi
from torch import nn
from adapter import Adapter
from collections import OrderedDict
from typing import Optional, Callable
from torch.nn import functional as F
from torch.nn.functional import normalize
from transformers import WavLMModel

class WavLMEncoderLayer(nn.Module):
    def __init__(self, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
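        # Parameter-efficient fine-tuning variants: soft prompt embeddings,
        # LoRA on the feed-forward projections, and/or a bottleneck adapter,
        # selected via config.finetune_method.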
if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined": | |
self.embed_prompt = nn.Parameter(torch.randn([1, self.config.embedding_prompt_dim, 768])) | |
nn.init.xavier_uniform_(self.embed_prompt) | |
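            # Note: the prompt width 768 matches hidden_size of wavlm-base-plus;
            # other checkpoint sizes would presumably need config.hidden_size here.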
if self.config.finetune_method == "lora" or self.config.finetune_method == "combined": | |
self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank) | |
self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank) | |
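            # lora.Linear (loralib) adds low-rank update matrices (lora_A, lora_B)
            # of rank config.lora_rank alongside the original dense weight, which
            # is reloaded from the pretrained checkpoint in WavLMWrapper below.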
if self.config.finetune_method == "adapter" or self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined": | |
self.adapter = Adapter( | |
config, | |
dropout=0.1, | |
bottleneck=config.adapter_hidden_dim, | |
adapter_scalar=0.1 | |
) | |
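            # Adapter (from adapter.py) is presumably a down-/up-projection
            # bottleneck of width config.adapter_hidden_dim; its output is added
            # residually to the hidden states in forward() below.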

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
            hidden_states = torch.cat((self.embed_prompt.repeat(hidden_states.size(0), 1, 1), hidden_states), dim=1)
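            # The learned prompt tokens are prepended along the time axis and
            # stripped again after the feed-forward block below.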
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        # Adapter
        if self.config.finetune_method == "adapter":
            adapt_h = self.adapter(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        if self.config.finetune_method == "adapter":
            hidden_states = hidden_states + adapt_h
        if self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
            hidden_states = hidden_states + self.adapter(hidden_states)
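        # "adapter" computes the adapter branch from the pre-LayerNorm hidden
        # states and adds it after the feed-forward block (parallel placement),
        # while "adapter_l"/"combined" apply the adapter to the post-feed-forward
        # hidden states (sequential placement).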
        hidden_states = self.final_layer_norm(hidden_states)
        if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
            hidden_states = hidden_states[:, self.config.embedding_prompt_dim:, :]
        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class WavLMWrapper(nn.Module):
    def __init__(
        self,
        args,
        hidden_dim=256,
        output_class_num=7
    ):
        super(WavLMWrapper, self).__init__()
        # 1. Load the pretrained model with its weights first
        self.args = args
        self.backbone_model = WavLMModel.from_pretrained(
            "microsoft/wavlm-base-plus",
            output_hidden_states=True
        )
        state_dict = self.backbone_model.state_dict()
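        # Keep a copy of the pretrained weights so they can be reloaded after
        # the encoder layers are swapped for the customized ones below.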
        # 2. Read the model config
        self.model_config = self.backbone_model.config
        self.model_config.finetune_method = args.finetune_method
        self.model_config.adapter_hidden_dim = args.adapter_hidden_dim
        self.model_config.embedding_prompt_dim = args.embedding_prompt_dim
        self.model_config.lora_rank = args.lora_rank
        # 3. Config encoder layers with adapter or embedding prompt
        # pdb.set_trace()
        self.backbone_model.encoder.layers = nn.ModuleList(
            [WavLMEncoderLayer(self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
        )
        # 4. Load the weights back
        msg = self.backbone_model.load_state_dict(state_dict, strict=False)
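        # strict=False: the newly added adapter/LoRA/prompt parameters are not in
        # the pretrained state_dict, so they end up in msg.missing_keys and are
        # the only backbone parameters left trainable by the freezing loop below.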
        # 5. Freeze the weights
        if self.args.finetune_method in ("adapter", "adapter_l", "embedding_prompt", "finetune", "lora", "combined"):
            for name, p in self.backbone_model.named_parameters():
                if name in msg.missing_keys:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
        self.finetune_method = self.args.finetune_method
        # 6. Downstream models
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
        )
        self.weights = nn.Parameter(torch.zeros(self.model_config.num_hidden_layers))
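        # Learnable per-layer scalars; a softmax over them produces the weights
        # for the SUPERB-style weighted sum of encoder hidden states in forward().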
        # self.out_layer = nn.Sequential(
        #     nn.Linear(hidden_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, output_class_num),
        # )
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Sigmoid()
        )
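        # Note: output_class_num is unused by the active head, which emits two
        # sigmoid outputs; the commented-out head above is the generic
        # classification variant.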

    def forward(self, x, length=None):
        # 1. feature extraction and projections
        with torch.no_grad():
            x = self.backbone_model.feature_extractor(x)
            x = x.transpose(1, 2)  # New version of huggingface
            x, _ = self.backbone_model.feature_projection(x)  # New version of huggingface
        # 2. get length and mask
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())
            length = length.to(x.device)  # follow the feature device instead of assuming CUDA
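        # length is converted from waveform samples to feature frames so that the
        # mean-pooling mask in step 7 matches the encoder's time resolution.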
        # 3. transformer encoding features
        x = self.backbone_model.encoder(
            x, output_hidden_states=True
        ).hidden_states
        # 4. stacked feature
        stacked_feature = torch.stack(x, dim=0)[1:]
        # 5. Weighted sum
        _, *origin_shape = stacked_feature.shape
        # Return transformer enc outputs [num_enc_layers, B, T, D]
        stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        # Perform weighted average
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)
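        # hidden_states has num_hidden_layers + 1 entries (the features before the
        # first transformer layer plus one output per layer); [1:] keeps only the
        # per-layer outputs, which are flattened, combined with the softmax
        # weights, and reshaped back to (B, T, D).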
        # 6. Pass the weighted average to point-wise 1D Conv
        # B x T x D
        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)
        # 7. Pooling
        if length is not None:
            masks = torch.arange(features.size(1), device=features.device).expand(length.size(0), -1) < length.unsqueeze(1)
            masks = masks.float()
            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
        else:
            features = torch.mean(features, dim=1)
        # 8. Output predictions
        # B x D
        predicted = self.out_layer(features)
        return predicted

    # From huggingface
    def get_feat_extract_output_lengths(self, input_length):
        """
        Computes the output length of the convolutional layers
        """
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length


def prepare_mask(length, shape, dtype):
    # Modified from huggingface
    mask = torch.zeros(
        shape, dtype=dtype
    )
    # these two operations make sure that all values
    # before the output lengths indices are attended to
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
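    # Flipping, taking a cumulative sum, and flipping back turns the single 1 at
    # index (length - 1) into 1s at every position before the sequence length.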
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='emo2vec finetune experiments')
    parser.add_argument(
        '--finetune_method',
        default='none',
        type=str,
        help='finetune method: adapter, adapter_l, embedding_prompt, lora, combined, or finetune'
    )
    parser.add_argument(
        '--adapter_hidden_dim',
        default=128,
        type=int,
        help='adapter bottleneck dimension'
    )
    parser.add_argument(
        '--embedding_prompt_dim',
        default=5,
        type=int,
        help='number of embedding prompt tokens'
    )
    parser.add_argument(
        '--lora_rank',
        default=8,
        type=int,
        help='LoRA rank (the wrapper reads args.lora_rank, so this flag is required; the default here is an assumption)'
    )
    args = parser.parse_args()

    model = WavLMWrapper(args)
    data = torch.zeros([1, 16000])
    output = model(data)
    print(output.shape)
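    # With the defaults above this prints torch.Size([1, 2]): two sigmoid scores
    # from the final head for a single 1-second, 16 kHz input.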