# part of the code was referenced from SUPERB: https://github.com/s3prl/s3prl
# and https://github.com/wngh1187/IPET/blob/main/Speechcommands_V2/W2V2/models/W2V2.py
import os
import pdb
import copy
import torch
import argparse
import numpy as np
import loralib as lora
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
import transformers.models.wavlm.modeling_wavlm as wavlm

from functools import lru_cache
from torchaudio.compliance import kaldi
from torch import nn
from adapter import Adapter
from collections import OrderedDict
from typing import Optional, Callable
from torch.nn import functional as F
from torch.nn.functional import normalize
from transformers import WavLMModel


class WavLMEncoderLayer(nn.Module):
    def __init__(self, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

        if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
            self.embed_prompt = nn.Parameter(torch.randn([1, self.config.embedding_prompt_dim, 768]))
            nn.init.xavier_uniform_(self.embed_prompt)
        if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
            self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
            self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
        if self.config.finetune_method == "adapter" or self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
            self.adapter = Adapter(
                config,
                dropout=0.1,
                bottleneck=config.adapter_hidden_dim,
                adapter_scalar=0.1
            )

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
            hidden_states = torch.cat((self.embed_prompt.repeat(hidden_states.size(0), 1, 1), hidden_states), dim=1)
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states

        # Adapter
        if self.config.finetune_method == "adapter":
            adapt_h = self.adapter(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        if self.config.finetune_method == "adapter":
            hidden_states = hidden_states + adapt_h
        if self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
            hidden_states = hidden_states + self.adapter(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
            hidden_states = hidden_states[:, self.config.embedding_prompt_dim:, :]

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

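# NOTE: `Adapter` is imported from the project-local adapter.py, which is not shown in
# this file. The class below is only an illustrative sketch of the bottleneck interface
# this wrapper assumes (Adapter(config, dropout, bottleneck, adapter_scalar) returning a
# tensor with the same hidden size); the actual adapter.py may differ, and this sketch is
# not used anywhere in this module.
class _BottleneckAdapterSketch(nn.Module):
    def __init__(self, config, dropout=0.1, bottleneck=128, adapter_scalar=0.1):
        super().__init__()
        self.down_proj = nn.Linear(config.hidden_size, bottleneck)
        self.activation = nn.ReLU()
        self.up_proj = nn.Linear(bottleneck, config.hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.scale = adapter_scalar

    def forward(self, hidden_states):
        # Down-project, apply non-linearity and dropout, up-project, then scale the
        # residual branch before it is added back to the encoder hidden states.
        h = self.down_proj(hidden_states)
        h = self.activation(h)
        h = self.dropout(h)
        h = self.up_proj(h)
        return h * self.scale
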
class WavLMWrapper(nn.Module):
    def __init__(
        self,
        args,
        hidden_dim=256,
        output_class_num=2
    ):
        super(WavLMWrapper, self).__init__()
        # 1. We load the model first with weights
        self.args = args
        self.backbone_model = WavLMModel.from_pretrained(
            "microsoft/wavlm-base-plus",
            output_hidden_states=True
        )
        state_dict = self.backbone_model.state_dict()
        # 2. Read the model config
        self.model_config = self.backbone_model.config
        self.model_config.finetune_method = args.finetune_method
        self.model_config.adapter_hidden_dim = args.adapter_hidden_dim
        self.model_config.embedding_prompt_dim = args.embedding_prompt_dim
        self.model_config.lora_rank = args.lora_rank
        # 3. Config encoder layers with adapter or embedding prompt
        # pdb.set_trace()
        self.backbone_model.encoder.layers = nn.ModuleList(
            [WavLMEncoderLayer(self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
        )
        # 4. Load the weights back
        msg = self.backbone_model.load_state_dict(state_dict, strict=False)
        # 5. Freeze the weights
        if self.args.finetune_method == "adapter" or self.args.finetune_method == "adapter_l" or self.args.finetune_method == "embedding_prompt" or self.args.finetune_method == "finetune" or self.args.finetune_method == "lora" or self.args.finetune_method == "combined":
            for name, p in self.backbone_model.named_parameters():
                if name in msg.missing_keys:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
        self.finetune_method = self.args.finetune_method
        # 6. Downstream models
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
        )
        self.weights = nn.Parameter(torch.zeros(self.model_config.num_hidden_layers))
        # self.out_layer = nn.Sequential(
        #     nn.Linear(hidden_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, output_class_num),
        # )
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_class_num),
            nn.Sigmoid()
        )

    def forward(self, x, length=None):
        # 1. feature extraction and projections
        with torch.no_grad():
            x = self.backbone_model.feature_extractor(x)
            x = x.transpose(1, 2)  # New version of huggingface
            x, _ = self.backbone_model.feature_projection(x)  # New version of huggingface
        # 2. get length and mask
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())
            length = length.cuda()
        # 3. transformer encoding features
        x = self.backbone_model.encoder(
            x, output_hidden_states=True
        ).hidden_states
        # 4. stacked feature
        stacked_feature = torch.stack(x, dim=0)[1:]
        # 5. Weighted sum
        _, *origin_shape = stacked_feature.shape
        # Return transformer enc outputs [num_enc_layers, B, T, D]
        stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        # Perform weighted average
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)
        # 6. Pass the weighted average to point-wise 1D Conv
        # B x T x D
        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)
        # 7. Pooling
        if length is not None:
            masks = torch.arange(features.size(1)).expand(length.size(0), -1).cuda() < length.unsqueeze(1)
            masks = masks.float()
            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
        else:
            features = torch.mean(features, dim=1)
        # 8. Output predictions
        # B x D
        predicted = self.out_layer(features)
        return predicted

    # From huggingface
    def get_feat_extract_output_lengths(self, input_length):
        """
        Computes the output length of the convolutional layers
        """
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1
        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length


def prepare_mask(length, shape, dtype):
    # Modified from huggingface
    mask = torch.zeros(
        shape, dtype=dtype
    )
    # these two operations make sure that all values
    # before the output length indices are attended to
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='emo2vec finetune experiments')
    parser.add_argument(
        '--finetune_method',
        default='none',
        type=str,
        help='finetune method: adapter, adapter_l, embedding_prompt, lora, combined, finetune'
    )
    parser.add_argument(
        '--adapter_hidden_dim',
        default=128,
        type=int,
        help='adapter bottleneck dimension'
    )
    parser.add_argument(
        '--embedding_prompt_dim',
        default=5,
        type=int,
        help='embedding prompt dimension'
    )
    # args.lora_rank is read in WavLMWrapper.__init__, so it needs a CLI flag as well
    parser.add_argument(
        '--lora_rank',
        default=16,
        type=int,
        help='LoRA rank'
    )
    args = parser.parse_args()
    model = WavLMWrapper(args)
    data = torch.zeros([1, 16000])
    output = model(data)
    print(output.shape)
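
    # Optional sanity check (illustrative addition, not part of the original script):
    # report how many parameters would actually be updated for the chosen finetune
    # method. With the parameter-efficient methods, only the adapter/prompt/LoRA
    # weights and the downstream head should remain trainable.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"trainable parameters: {trainable_params} / {total_params}")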