magma / magma /adapters.py
stellaathena's picture
This should work
bb5cd12
from typing import Callable

import torch
import torch.nn as nn
from torchtyping import TensorType
class Adapter(nn.Module):
    """Bottleneck adapter with a residual connection.

    Computes ``adapter(x) + x``, where ``adapter`` is an optional
    ``nn.LayerNorm(dim)``, followed by ``Linear(dim, dim // downsample_factor)``,
    an activation, and ``Linear(dim // downsample_factor, dim)``.

    The linear layers are initialised with small values (N(0, 1e-3**2) clamped
    to +/-2e-3, see ``init_weights``), so the adapter branch starts out close to
    zero and the whole module starts out close to the identity map.

    Args:
        dim: Size of the last (feature) dimension of the input.
        downsample_factor: Ratio between ``dim`` and the bottleneck width.
        activation: Zero-argument factory called once to build the activation
            module (e.g. the class ``nn.ReLU`` itself, not an instance).
        add_layernorm: If True, prepend ``nn.LayerNorm(dim)`` to the adapter.

    Raises:
        ValueError: If the bottleneck width ``dim // downsample_factor`` is
            smaller than 1.
    """

    def __init__(
        self,
        dim: int,
        downsample_factor: int = 4,
        activation: Callable[[], nn.Module] = nn.ReLU,
        add_layernorm: bool = False,
    ):
        super().__init__()
        hidden_dim = dim // downsample_factor
        if hidden_dim < 1:
            # A zero-width bottleneck would silently yield an adapter whose
            # output is just the up-projection's bias.
            raise ValueError(
                f"dim // downsample_factor must be >= 1, got "
                f"{dim} // {downsample_factor} = {hidden_dim}"
            )
        layers = []
        if add_layernorm:
            layers.append(nn.LayerNorm(dim))
        layers.extend(
            [
                nn.Linear(dim, hidden_dim),
                activation(),
                nn.Linear(hidden_dim, dim),
            ]
        )
        self.adapter = nn.Sequential(*layers)
        self.adapter.apply(self.init_weights)

    def init_weights(self, m: nn.Module, std: float = 1e-3):
        """Initialise a single submodule; meant to be passed to ``nn.Module.apply``.

        ``nn.Linear``: weight and bias are drawn from N(0, std**2), then clamped
        to ``[-2 * std, 2 * std]``. ``nn.LayerNorm``: reset to the identity
        affine transform (weight 1, bias 0). Parameters that are ``None``
        (e.g. ``bias=False`` or ``elementwise_affine=False``) are skipped.
        Any other module type is left untouched.

        Args:
            m: Submodule to initialise.
            std: Standard deviation of the normal initialisation for linear layers.
        """
        if isinstance(m, nn.Linear):
            bound = 2 * std
            with torch.no_grad():
                for param in (m.weight, m.bias):
                    if param is None:
                        continue
                    nn.init.normal_(param, std=std)
                    # Clamp in place rather than rebinding ``.data`` so the
                    # parameter keeps its original storage.
                    param.clamp_(min=-bound, max=bound)
        elif isinstance(m, nn.LayerNorm):
            with torch.no_grad():
                if m.bias is not None:
                    m.bias.zero_()
                if m.weight is not None:
                    m.weight.fill_(1.0)

    def forward(self, x: TensorType["b", "s", "d"]) -> TensorType["b", "s", "d"]:
        """Return ``adapter(x) + x``.

        Args:
            x: Input tensor whose last dimension has size ``dim``; annotated as
                ("b", "s", "d"), presumably (batch, seq, dim) — confirm with callers.
        """
        # Residual connection: the adapter learns a correction on top of x.
        return self.adapter(x) + x
class ParallelAdapter(Adapter):
    """Adapter branch that runs in parallel with a wrapped module.

    ``forward`` returns ``module(x, ...) + adapter(x) * adapter_scale``. Both the
    wrapped module and the adapter branch receive the same input ``x``. Unlike
    ``Adapter.forward``, no ``+ x`` residual is added here; the wrapped
    module's output takes that role.

    Args:
        module: Module whose output the adapter branch is added to.
        dim: Size of the last (feature) dimension of the input.
        downsample_factor: Ratio between ``dim`` and the bottleneck width.
        scaled: If True, the adapter output is multiplied by a learnable scalar
            (initialised to 1); otherwise by the constant 1.
        add_layernorm: If True, prepend ``nn.LayerNorm(dim)`` to the adapter.
        activation: Zero-argument factory called once to build the activation
            module (e.g. the class ``nn.ReLU``).
    """

    def __init__(
        self,
        module: nn.Module,
        dim: int,
        downsample_factor: int = 4,
        scaled: bool = False,
        add_layernorm: bool = False,
        activation: Callable[[], nn.Module] = nn.ReLU,
    ):
        super().__init__(
            dim, downsample_factor, add_layernorm=add_layernorm, activation=activation
        )
        self.module = module
        if scaled:
            # Learnable scalar gate on the adapter branch, initialised to 1.
            self.adapter_scale = nn.Parameter(torch.ones(1))
        else:
            # Plain number, so no extra trainable parameter is registered.
            self.adapter_scale = 1

    def forward(
        self, x: TensorType["b", "s", "d"], *module_args, **module_kwargs
    ):
        """Return ``module(x, *module_args, **module_kwargs) + adapter(x) * adapter_scale``.

        Extra positional and keyword arguments are forwarded unchanged to the
        wrapped module only; the adapter branch sees just ``x``.
        """
        module_out = self.module(x, *module_args, **module_kwargs)
        adapter_out = self.adapter(x)
        return module_out + (adapter_out * self.adapter_scale)
class ParallelAdapterWrapper(ParallelAdapter):
    """Parallel adapter for wrapped modules that return a tuple (e.g. an attention block).

    The wrapped module is called as ``module(x, *attn_args, **attn_kwargs)`` and
    must return a tuple or list. ``adapter(x) * adapter_scale`` is added to the
    first element only; the remaining elements are passed through unchanged. The
    result is always returned as a tuple.

    Args:
        module: Module to wrap; must return a tuple/list whose first element is
            the main output.
        dim: Size of the last (feature) dimension of the input.
        downsample_factor: Ratio between ``dim`` and the bottleneck width.
        scaled: If True, scale the adapter output by a learnable scalar.
        add_layernorm: If True, prepend ``nn.LayerNorm(dim)`` to the adapter.
        activation: Zero-argument factory called once to build the activation
            module (e.g. the class ``nn.ReLU``).
    """

    def __init__(
        self,
        module: nn.Module,
        dim: int,
        downsample_factor: int = 4,
        scaled: bool = False,
        add_layernorm: bool = False,
        activation: Callable[[], nn.Module] = nn.ReLU,
    ):
        # Keyword arguments so this call does not silently depend on the
        # parent's positional parameter order.
        super().__init__(
            module,
            dim,
            downsample_factor=downsample_factor,
            scaled=scaled,
            add_layernorm=add_layernorm,
            activation=activation,
        )

    def forward(self, x: TensorType["b", "s", "d"], *attn_args, **attn_kwargs):
        """Call the wrapped module and add the scaled adapter output to its first element.

        Raises:
            TypeError: If the wrapped module returns a bare Tensor instead of a
                tuple/list.
        """
        attn_outputs = self.module(x, *attn_args, **attn_kwargs)
        if isinstance(attn_outputs, torch.Tensor):
            # Indexing a Tensor here would slice the batch, not split outputs.
            raise TypeError(
                "wrapped module must return a tuple/list whose first element is "
                "the main output, got a bare Tensor"
            )
        # Per the original author's note the layout is (a, present, (attentions)):
        # presumably attention output, cached key/values, optional attention
        # weights — confirm against the wrapped module.
        attn_output, extra_outputs = attn_outputs[0], tuple(attn_outputs[1:])
        hidden_states = attn_output + (self.adapter(x) * self.adapter_scale)
        return (hidden_states,) + extra_outputs
class AdapterWrapper(Adapter):
    """Sequential adapter applied after a wrapped module that returns a tuple.

    ``attn_block(x, *attn_args, **attn_kwargs)`` is called first and must return
    a tuple or list. The residual adapter is then applied to its first element,
    ``adapter(out[0]) + out[0]``. The remaining elements are passed through
    unchanged, and the result is always returned as a tuple.

    Args:
        attn_block: Module to wrap; must return a tuple/list whose first element
            is the main output.
        dim: Size of the last (feature) dimension of that output.
        downsample_factor: Ratio between ``dim`` and the bottleneck width.
        activation: Zero-argument factory called once to build the activation
            module (e.g. the class ``nn.ReLU``).
        add_layernorm: If True, prepend ``nn.LayerNorm(dim)`` to the adapter.
    """

    def __init__(
        self,
        attn_block: nn.Module,
        dim: int,
        downsample_factor: int = 4,
        activation: Callable[[], nn.Module] = nn.ReLU,
        add_layernorm: bool = False,
    ):
        # Keyword arguments so this call does not silently depend on the
        # parent's positional parameter order.
        super().__init__(
            dim,
            downsample_factor=downsample_factor,
            activation=activation,
            add_layernorm=add_layernorm,
        )
        self.attn_block = attn_block

    def forward(self, x: TensorType["b", "s", "d"], *attn_args, **attn_kwargs):
        """Run the wrapped module, then apply the residual adapter to its first output.

        Raises:
            TypeError: If the wrapped module returns a bare Tensor instead of a
                tuple/list.
        """
        attn_outputs = self.attn_block(x, *attn_args, **attn_kwargs)
        if isinstance(attn_outputs, torch.Tensor):
            # Indexing a Tensor here would slice the batch, not split outputs.
            raise TypeError(
                "wrapped module must return a tuple/list whose first element is "
                "the main output, got a bare Tensor"
            )
        # Per the original author's note the layout is (a, present, (attentions)):
        # presumably attention output, cached key/values, optional attention
        # weights — confirm against the wrapped module.
        attn_output, extra_outputs = attn_outputs[0], tuple(attn_outputs[1:])
        hidden_states = self.adapter(attn_output) + attn_output
        return (hidden_states,) + extra_outputs