valhalla's picture
add glide repo
891b88f
raw history blame
No virus
3.35 kB
import math
from typing import Callable, Optional
import attr
import torch
import torch.nn as nn
import torch.nn.functional as F
# Signature of a tensor -> tensor hook; Affine applies one to its bias
# (after any dtype cast) before the linear projection.
FilterFn = Callable[[torch.Tensor], torch.Tensor]
class ZeroKeyBiasGrad(torch.autograd.Function):
    """Identity in the forward pass; masks part of the gradient in backward.

    The incoming gradient is split into three chunks along dim 0 and the
    middle chunk is zeroed. Presumably the input is a fused query/key/value
    bias, so the middle chunk is the key bias (inferred from the name --
    confirm against the caller).
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` unchanged."""
        return x

    @staticmethod
    def backward(ctx, output_grad):
        """Return a copy of ``output_grad`` with its second dim-0 third zeroed."""
        # Work on a copy so the gradient tensor handed in by autograd is
        # never modified in place.
        masked_grad = output_grad.clone()
        # ``chunk`` returns views, so zeroing the middle one edits the copy.
        thirds = masked_grad.chunk(3)
        thirds[1].zero_()
        return masked_grad
def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
    """Pass ``x`` through ``ZeroKeyBiasGrad``.

    Functional wrapper around ``ZeroKeyBiasGrad.apply`` so it can be used
    wherever a plain tensor -> tensor callable is expected.
    """
    filtered = ZeroKeyBiasGrad.apply(x)
    return filtered
@attr.s(eq=False, repr=False)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, computed in float32.

    The gain ``g`` (initialised to ones) and bias ``b`` (initialised to zeros)
    are float32 parameters of shape ``(n_state,)``. The input is converted to
    float32 before normalisation and the result is returned in float32.
    """

    # Size of the normalised (trailing) dimension.
    n_state: int = attr.ib()
    # Epsilon forwarded to ``F.layer_norm``.
    eps: float = attr.ib(default=1e-6)
    # Device on which the parameters are allocated.
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        """Initialise the ``nn.Module`` machinery and create ``g`` and ``b``."""
        super().__init__()
        param_shape = (self.n_state,)
        gain = torch.ones(param_shape, dtype=torch.float32, device=self.device)
        shift = torch.zeros(param_shape, dtype=torch.float32, device=self.device)
        self.g = nn.Parameter(gain)
        self.b = nn.Parameter(shift)
        # Tag both parameters; presumably read by optimizer setup elsewhere
        # to exclude them from weight decay -- confirm against the trainer.
        for param in (self.g, self.b):
            param.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` layer-normalised over its last ``n_state`` features."""
        x_fp32 = x.type(torch.float32)
        normalized_shape = torch.Size((self.n_state,))
        return F.layer_norm(x_fp32, normalized_shape, weight=self.g, bias=self.b, eps=self.eps)
@attr.s(eq=False, repr=False)
class Affine(nn.Module):
    """Linear projection ``F.linear(x, w, b)`` with float32 parameters.

    ``w`` has shape ``(n_out, n_in)``; the optional bias ``b`` has shape
    ``(n_out,)``. At call time both are cast to the input's dtype when it
    differs, and the bias is passed through ``bias_filter_fn`` first.

    NOTE(review): ``w`` is allocated with ``torch.empty`` in both init modes
    and ``std`` is only stored, never applied, in the code visible here.
    Presumably real values are loaded from a checkpoint -- confirm.
    """

    # Number of input features.
    n_in: int = attr.ib()
    # Number of output features.
    n_out: int = attr.ib()
    # Whether a bias parameter ``b`` is created and used.
    use_bias: bool = attr.ib(default=True)
    # Selects the alternative init path (incompatible with extra_init_scale).
    use_admnet_init: bool = attr.ib(default=False)
    # Init std; defaults to sqrt(2 / (n_in + n_out)) when not using admnet init.
    std: Optional[float] = attr.ib(default=None)
    # Optional multiplier applied to ``std`` (non-admnet init only).
    extra_init_scale: Optional[float] = attr.ib(default=None)
    # Hook applied to the (dtype-matched) bias on every forward call.
    bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
    # Device on which the parameters are allocated.
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        """Initialise the ``nn.Module`` machinery and allocate ``w`` / ``b``.

        Raises:
            ValueError: if ``use_admnet_init`` is set together with
                ``extra_init_scale``.
        """
        super().__init__()

        if self.use_admnet_init:
            if self.extra_init_scale is not None:
                raise ValueError("extra_init_scale incompatible with admnet init")
            # The bias is left uninitialised on this path.
            make_bias = torch.empty
        else:
            resolved_std = self.std
            if resolved_std is None:
                resolved_std = math.sqrt(2 / (self.n_in + self.n_out))
            if self.extra_init_scale is not None:
                resolved_std = resolved_std * self.extra_init_scale
            self.std = resolved_std
            # The bias starts at zero on this path.
            make_bias = torch.zeros

        weight = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
        self.w = nn.Parameter(weight)

        if self.use_bias:
            bias = make_bias((self.n_out,), dtype=torch.float32, device=self.device)
            self.b = nn.Parameter(bias)
            self.b.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projection to ``x`` using parameters in ``x``'s dtype."""
        target_dtype = x.dtype

        def _as_input_dtype(param: torch.Tensor) -> torch.Tensor:
            # Skip the cast entirely when the dtype already matches.
            if param.dtype == target_dtype:
                return param
            return param.to(target_dtype)

        weight = _as_input_dtype(self.w)
        bias = None
        if self.use_bias:
            bias = self.bias_filter_fn(_as_input_dtype(self.b))
        return F.linear(x, weight, bias)