import torch
import torch.nn as nn

from typing import Type

from .svd_layers import SVDLinear
from .SALT_layers_3 import SALTLinear, SALTConv2d
from .lora_layers import LoRAConv2D, LoRALinear


class MLPBlock(nn.Module):
    """Two-layer MLP whose linear layers are parameter-efficient adapters:
    LoRA when ``use_lora`` is set, SALT otherwise."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
        mlp_transform: bool = False,  # note: currently unused in this block
        use_lora: bool = False,
    ) -> None:
        super().__init__()
        if use_lora:
            self.lin1 = LoRALinear(embedding_dim, mlp_dim)
            self.lin2 = LoRALinear(mlp_dim, embedding_dim)
        else:
            # Rank hyperparameter passed to the SALT layers.
            rank_value = 500
            self.lin1 = SALTLinear(
                embedding_dim, mlp_dim,
                rank=rank_value, r_lora=256, rsLora=False, alpha=1,
            )
            self.lin2 = SALTLinear(
                mlp_dim, embedding_dim,
                rank=rank_value, r_lora=256, rsLora=False, alpha=1,
            )
        self.act = act()

    # Returns either a tensor or an (output, loss) tuple, so no single
    # return annotation fits; the original ``-> torch.Tensor`` was inaccurate.
    def forward(self, x: torch.Tensor, output_loss: bool = True):
        # Both adapter types are expected to return (output, regularization loss).
        out, reg_loss1 = self.lin1(x)
        out, reg_loss2 = self.lin2(self.act(out))
        if output_loss:
            return out, (reg_loss1 + reg_loss2)
        return out
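
# Usage sketch (hedged): assumes, as the forward pass above does, that the
# adapter layers return an (output, reg_loss) tuple. Shapes are illustrative.
#
#   block = MLPBlock(embedding_dim=768, mlp_dim=3072)
#   x = torch.randn(2, 196, 768)        # hypothetical (batch, tokens, dim) input
#   out, reg_loss = block(x)            # out has the same shape as x
#   loss = task_loss + reg_loss         # fold the adapter penalty into training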


class MLPBlock2(nn.Module):
    """Plain two-layer MLP built from standard ``nn.Linear`` layers (no adapters)."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.lin1(x)
        out = self.lin2(self.act(out))
        return out
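
# Usage sketch: unlike MLPBlock, MLPBlock2 returns a single tensor and no
# regularization loss. Sizes below are hypothetical.
#
#   mlp = MLPBlock2(embedding_dim=768, mlp_dim=3072)
#   y = mlp(torch.randn(2, 196, 768))   # y: (2, 196, 768)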


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-position mean and (biased) variance across channels (dim 1).
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Broadcast the per-channel affine parameters over H and W.
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
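
# Usage sketch: LayerNorm2d normalizes across channels at each spatial
# location, which nn.LayerNorm cannot do on NCHW tensors without permuting
# dimensions. Shapes below are hypothetical.
#
#   norm = LayerNorm2d(num_channels=256)
#   feats = torch.randn(1, 256, 64, 64)   # (B, C, H, W) feature map
#   out = norm(feats)                     # same shape, normalized per pixel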