import math
from functools import partial, reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from .kernel import exported_tdp
class TimeDependentParameter(nn.Module):
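    """A parameter whose effective value depends on the diffusion time.

    Two learnable endpoint tensors (`param0`, `param1`) and two modulation tensors
    (`nodecay_weight`, `nodecay_bias`, presumably meant to be excluded from weight
    decay by the optimizer setup) are combined by `exported_tdp` into `curr_weight`
    when `seed_time` is called. `forward()` simply returns that materialised tensor,
    which downstream modules treat as having a leading time dimension of size 1
    (shared time step) or the batch size. `seed_time` must be called before the
    model's forward pass, otherwise `curr_weight` is None or stale.
    """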
def __init__(self, shape, init_fn):
super().__init__()
self.shape = shape
w = torch.empty(*shape)
init_fn(w)
self.param0 = nn.Parameter(w.clone().detach())
self.param1 = nn.Parameter(w.clone().detach())
self.nodecay_weight = nn.Parameter(torch.zeros(*shape))
self.nodecay_bias = nn.Parameter(torch.zeros(*shape))
self.curr_weight = None
def forward(self):
weight = self.curr_weight
# self.curr_weight = None
return weight
def __repr__(self):
return f"TimeDependentParameter(shape={self.shape})"
    @staticmethod
    def seed_time(model, log_snr):
        assert log_snr.dim() == 1
        # If every sample shares the same log-SNR, collapse to a single time step;
        # downstream modules treat a leading dimension of 1 as "shared across the batch".
        if torch.all(log_snr == log_snr[0]):
            log_snr = log_snr[0][None]
        time_condition = log_snr / 4.0
        tdp_list = [module for module in model.modules() if isinstance(module, TimeDependentParameter)]
        for tdp in tdp_list:
            tdp.curr_weight = exported_tdp(tdp.param0, tdp.param1, tdp.nodecay_weight + 1, tdp.nodecay_bias, time_condition, custom = False)
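# Usage sketch (illustrative only; `model`, `x`, and the log-SNR values are assumed
# to come from the surrounding diffusion training or sampling loop, and
# `get_log_snr` is a hypothetical schedule helper):
#
#     log_snr = get_log_snr(batch_size)                 # shape (B,)
#     TimeDependentParameter.seed_time(model, log_snr)  # materialise curr_weight on every TDP
#     out = model(x)                                    # modules read the seeded weights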
class LayerNorm(nn.Module):
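    """LayerNorm whose affine weight and bias are time-dependent.

    The materialised weight and bias carry a leading time dimension of size 1
    (one shared time step) or the batch size (per-sample time steps); `_forward`
    handles the two cases separately. `num_groups` is stored but not used by this
    implementation.
    """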
def __init__(self, dim, num_groups = 1, eps = 1e-05):
super().__init__()
self.eps = eps
self.dim = dim
self.num_groups = num_groups
self.weight = TimeDependentParameter((dim, ), nn.init.ones_)
self.bias = TimeDependentParameter((dim, ), nn.init.zeros_)
def _forward(self, x):
weight, bias = self.weight(), self.bias()
assert weight.shape[0] == bias.shape[0]
assert x.shape[-1] == self.dim
if weight.shape[0] == 1:
x = F.layer_norm(x, (self.dim, ), weight = weight[0], bias = bias[0], eps = self.eps)
else:
assert x.shape[0] == weight.shape[0]
x = F.layer_norm(x, (self.dim, ), eps = self.eps)
x = torch.addcmul(bias[:, None, :], weight[:, None, :], x)
return x
def forward(self, x):
original_shape = x.shape
batch_size = x.shape[0]
assert self.dim == x.shape[-1]
x = x.reshape(batch_size, -1, self.dim)
x = self._forward(x)
x = x.reshape(*original_shape)
return x
class Linear(nn.Module):
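    """Linear layer with a time-dependent weight and optional bias.

    Inputs of shape (B, ..., din) are flattened to (B, L, din). The materialised
    weight has shape (T, din, dout) with T equal to 1 or the batch size; the two
    cases are handled with a plain matmul or a batched matmul, respectively.
    """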
def __init__(self, din, dout, bias = True, weight_init_fn = partial(trunc_normal_, std = 0.02)):
super().__init__()
self.din = din
self.dout = dout
self.weight = TimeDependentParameter((din, dout), weight_init_fn)
if bias:
self.bias = TimeDependentParameter((dout, ), nn.init.zeros_)
else:
self.bias = None
    def _forward(self, x):
        weight = self.weight()
        bias = self.bias() if self.bias is not None else None
        if weight.shape[0] == 1:
            # Single shared time step: collapse to an ordinary (B * L, D) matmul.
            B, L, D = x.shape
            if bias is not None:
                assert weight.shape[0] == bias.shape[0]
                x = torch.addmm(bias, x.reshape(B * L, D), weight[0])
            else:
                x = torch.matmul(x.reshape(B * L, D), weight[0])
            x = x.reshape(B, L, -1)
        else:
            # Per-sample time steps: batched matmul with one weight per sample.
            if bias is not None:
                x = torch.baddbmm(bias[:, None, :], x, weight)
            else:
                x = torch.bmm(x, weight)
        return x
def forward(self, x):
original_shape = x.shape
batch_size = x.shape[0]
x = x.reshape(batch_size, -1, self.din)
x = self._forward(x)
x = x.reshape(*(list(original_shape[:-1]) + [self.dout]))
return x
class RMSNorm(nn.Module):
def __init__(self, d, p=-1., eps=1e-8, bias=False):
"""
Root Mean Square Layer Normalization
:param d: model size
:param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
:param eps: epsilon value, default 1e-8
:param bias: whether use bias term for RMSNorm, disabled by
default because RMSNorm doesn't enforce re-centering invariance.
"""
super(RMSNorm, self).__init__()
self.eps = eps
self.d = d
self.p = p
self.bias = bias
self.scale = nn.Parameter(torch.ones(d))
self.register_parameter("scale", self.scale)
if self.bias:
self.offset = nn.Parameter(torch.zeros(d))
self.register_parameter("offset", self.offset)
def forward(self, x):
if self.p < 0. or self.p > 1.:
norm_x = x.norm(2, dim=-1, keepdim=True)
d_x = self.d
else:
partial_size = int(self.d * self.p)
partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
norm_x = partial_x.norm(2, dim=-1, keepdim=True)
d_x = partial_size
rms_x = norm_x * d_x ** (-1. / 2)
x_normed = x / (rms_x + self.eps)
if self.bias:
return self.scale * x_normed + self.offset
return self.scale * x_normed
class TDRMSNorm(nn.Module):
def __init__(self, d, p=-1., eps=1e-8, bias=False):
"""
Root Mean Square Layer Normalization
:param d: model size
:param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
:param eps: epsilon value, default 1e-8
:param bias: whether use bias term for RMSNorm, disabled by
default because RMSNorm doesn't enforce re-centering invariance.
"""
super(TDRMSNorm, self).__init__()
self.eps = eps
self.d = d
self.p = p
self.bias = bias
        self.scale = TimeDependentParameter((d, ), nn.init.ones_)
        if self.bias:
            self.offset = TimeDependentParameter((d, ), nn.init.zeros_)
def forward(self, x):
if self.p < 0. or self.p > 1.:
norm_x = x.norm(2, dim=-1, keepdim=True)
d_x = self.d
else:
partial_size = int(self.d * self.p)
partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
norm_x = partial_x.norm(2, dim=-1, keepdim=True)
d_x = partial_size
rms_x = norm_x * d_x ** (-1. / 2)
x_normed = x / (rms_x + self.eps)
_scale = self.scale()
if self.bias:
            # Time-dependent equivalent of: self.scale * x_normed + self.offset
_offset = self.offset()
if _scale.shape[0] == 1:
return _scale[0] * x_normed + _offset[0]
elif x_normed.dim() == 3:
return torch.addcmul(_offset[:, None, :], _scale[:, None, :], x_normed)
elif x_normed.dim() == 4:
return torch.addcmul(_offset[:, None, None, :], _scale[:, None, None, :], x_normed)
else:
raise NotImplementedError
        # Time-dependent equivalent of: self.scale * x_normed
if _scale.shape[0] == 1:
return _scale[0] * x_normed
elif x_normed.dim() == 3:
return _scale[:, None, :] * x_normed
elif x_normed.dim() == 4:
return _scale[:, None, None, :] * x_normed
else:
raise NotImplementedError
def zero_init(layer):
nn.init.zeros_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
return layer
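# rms_norm(x, scale, eps) computes x * scale / sqrt(mean(x**2, dim=-1) + eps),
# promoting to at least float32 for the reduction and casting back to x.dtype.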
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
class AdaRMSNorm(nn.Module):
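    """RMS normalisation whose per-channel scale is predicted from a conditioning
    vector. The projection is zero-initialised, so the scale starts at exactly 1
    and the layer behaves as a plain RMSNorm at initialisation."""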
def __init__(self, features, cond_features, eps=1e-6):
super().__init__()
self.eps = eps
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
def extra_repr(self):
return f"eps={self.eps},"
def forward(self, x, cond):
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
class QKNorm(nn.Module):
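    """Per-head normalisation for attention queries/keys with a learnable
    log-scale, clamped at log(max_scale). The input is RMS-normalised and
    multiplied by exp(scale / 2) / d_head ** 0.25, so when both queries and keys
    pass through such a norm their dot product picks up a factor of roughly
    exp(scale) / sqrt(d_head)."""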
def __init__(self, n_heads, eps=1e-6, max_scale=100.0):
super().__init__()
self.eps = eps
self.max_scale = math.log(max_scale)
self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0)))
self.proj_()
def extra_repr(self):
return f"n_heads={self.scale.shape[0]}, eps={self.eps}"
@torch.no_grad()
def proj_(self):
"""Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped
to the max value."""
self.scale.clamp_(max=self.max_scale)
def forward(self, x):
self.proj_()
scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1]))
return rms_norm(x, scale[:, None, None], self.eps)
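# Usage sketch for QKNorm (illustrative only; the attention module providing
# q and k of shape (batch, n_heads, seq_len, d_head) is assumed):
#
#     qk_norm = QKNorm(n_heads)
#     q, k = qk_norm(q), qk_norm(k)
#     attn_logits = q @ k.transpose(-2, -1)   # the 1 / sqrt(d_head) factor is
#                                             # already folded into the norm's scale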