# SER_AUDIO/adapter.py
# --------------------------------------------------------
# References:
# https://github.com/jxhe/unify-parameter-efficient-tuning
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
class Adapter(nn.Module):
    """Bottleneck adapter: down-project -> ReLU -> up-project, with an
    optional LayerNorm and a fixed or learnable output scale."""

    def __init__(
        self,
        config=None,
        d_model=768,
        bottleneck=None,
        dropout=0.0,
        init_option="lora",
        adapter_scalar="1.0",
        adapter_layernorm_option="none",
    ):
        super().__init__()
        # Fall back to the config only when a dimension is not given explicitly.
        self.n_embd = config.d_model if d_model is None else d_model
        self.down_size = config.attn_bn if bottleneck is None else bottleneck

        # Optional LayerNorm, applied before ("in") or after ("out") the adapter.
        self.adapter_layernorm_option = adapter_layernorm_option
        self.adapter_layer_norm_before = None
        if adapter_layernorm_option in ("in", "out"):
            self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)

        # Output scale: a learnable scalar parameter or a fixed float.
        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout
        if init_option == "bert":
            raise NotImplementedError
        elif init_option == "lora":
            # LoRA-style init: Kaiming-uniform down projection, zeroed up
            # projection, so the adapter initially outputs zero and the
            # residual path dominates.
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
                nn.init.zeros_(self.up_proj.weight)
                nn.init.zeros_(self.down_proj.bias)
                nn.init.zeros_(self.up_proj.bias)
    def forward(self, x, add_residual=True, residual=None):
        # By default the adapter wraps x with a residual connection.
        residual = x if residual is None else residual
        if self.adapter_layernorm_option == "in":
            x = self.adapter_layer_norm_before(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)
        down = nn.functional.dropout(down, p=self.dropout, training=self.training)
        up = self.up_proj(down)
        up = up * self.scale
        if self.adapter_layernorm_option == "out":
            up = self.adapter_layer_norm_before(up)

        if add_residual:
            output = up + residual
        else:
            output = up
        return output
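

if __name__ == "__main__":
    # Minimal usage sketch: the dimensions and hyperparameters below are
    # illustrative assumptions, not values taken from the SER_AUDIO project.
    adapter = Adapter(d_model=768, bottleneck=64, dropout=0.1,
                      adapter_scalar="0.1", adapter_layernorm_option="in")
    x = torch.randn(4, 100, 768)  # (batch, seq_len, d_model)
    out = adapter(x)  # residual connection added by default

    print(out.shape)  # torch.Size([4, 100, 768])
    # With the LoRA-style init, up_proj starts at zero, so the adapter is a
    # pure pass-through until training updates it:
    print(torch.allclose(out, x))  # True at initialization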