# model-editing / nn.py
# Commit e56055d (Charles Lin): "Add algorithms from efk codebase"
import torch
import torch.nn as nn
import logging
import time
from utils import factorization
LOG = logging.getLogger(__name__)
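# FixableDropout: dropout whose mask is held fixed across forward passes.
# A mask is sampled once per input shape (from self.seed) and cached, so
# repeated calls reuse the same dropout pattern until resample() clears the
# cache; this differs from nn.Dropout, which draws a fresh mask on every call.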
class FixableDropout(nn.Module):
def __init__(self, p: float):
super().__init__()
self.p = p
self.mask_cache = {}
self.seed = 0
def resample(self, seed=None):
if seed is None:
seed = int(time.time() * 1e6)
self.mask_cache = {}
self.seed = seed
    def forward(self, x):
        if self.training:
            # Sample a mask once per input shape and cache it so repeated
            # forward passes reuse the same dropout pattern until resample().
            if x.shape not in self.mask_cache:
                generator = torch.Generator(x.device).manual_seed(self.seed)
                self.mask_cache[x.shape] = torch.bernoulli(
                    torch.full_like(x, 1 - self.p), generator=generator
                ).bool()

            # Apply the cached mask with the usual inverted-dropout scaling.
            x = (self.mask_cache[x.shape] * x) / (1 - self.p)

        return x
def extra_repr(self) -> str:
return f"p={self.p}"
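# ActMLP: a tiny learned activation function. It applies a scalar-to-scalar,
# identity-initialized MLP elementwise (the input is flattened to shape
# (-1, 1) and reshaped back), and is used by MLP when act="learned".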
class ActMLP(nn.Module):
def __init__(self, hidden_dim, n_hidden):
super().__init__()
self.mlp = MLP(1, 1, hidden_dim, n_hidden, init="id")
def forward(self, x):
return self.mlp(x.view(-1, 1)).view(x.shape)
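# LightIDMLP: a two-layer low-rank residual block, x + W2 relu(W1 x), with W2
# zero-initialized and its bias removed, so the module is exactly the identity
# at initialization. The outdim, hidden_dim, n_hidden, init, and act arguments
# are accepted for interface compatibility but unused.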
class LightIDMLP(nn.Module):
def __init__(
self,
indim: int,
outdim: int,
hidden_dim: int,
n_hidden: int,
init: str = None,
act: str = None,
rank: int = None,
):
super().__init__()
LOG.info(f"Building LightIDMLP {[indim] + [rank] + [indim]}")
self.layer1 = nn.Linear(indim, rank)
self.layer2 = nn.Linear(rank, indim)
self.layer2.weight.data[:] = 0
self.layer2.bias = None
def forward(self, x):
h = self.layer1(x).relu()
return x + self.layer2(h)
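# IDMLP: a stack of (n_hidden + 1) LRLinear blocks of constant width indim.
# With init="id" each block starts as the identity (see LRLinear below), so
# the whole network is the identity map at initialization; the optional mode
# index selects a per-mode scale and shift inside each block.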
class IDMLP(nn.Module):
def __init__(
self,
indim: int,
outdim: int,
hidden_dim: int,
n_hidden: int,
init: str = None,
act: str = None,
rank: int = None,
n_modes: int = None
):
super().__init__()
LOG.info(f"Building IDMLP ({init}) {[indim] * (n_hidden + 2)}")
self.layers = nn.ModuleList(
[
LRLinear(indim, indim, rank=rank, relu=idx < n_hidden, init=init, n_modes=n_modes)
for idx in range(n_hidden + 1)
]
)
def forward(self, x, mode=None):
for layer in self.layers:
x = layer(x, mode=mode)
return x
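# LatentIDMLP: a bottleneck MLP (indim -> rank -> ... -> rank -> outdim) with
# ReLU hidden layers and Xavier-initialized hidden weights. With init="id" the
# final layer is zero-initialized and the input is added back, so the module
# starts as the identity.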
class LatentIDMLP(nn.Module):
def __init__(
self,
indim: int,
outdim: int,
hidden_dim: int,
n_hidden: int,
init: str = None,
act: str = None,
rank: int = None,
):
super().__init__()
LOG.info(f"Building Latent IDMLP ({init}) {[indim] * (n_hidden + 2)}")
self.layers = nn.ModuleList()
self.layers.append(nn.Linear(indim, rank))
for _ in range(n_hidden - 1):
self.layers.append(nn.Linear(rank, rank))
self.layers.append(nn.Linear(rank, outdim))
for layer in self.layers[:-1]:
nn.init.xavier_normal_(layer.weight.data)
if init == "id":
self.layers[-1].weight.data.zero_()
self.layers[-1].bias.data.zero_()
self.init = init
def forward(self, x):
out = x
for layer in self.layers[:-1]:
out = layer(out).relu()
out = self.layers[-1](out)
if self.init == "id":
return out + x
else:
return out
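# KLinear: a Kronecker-factored linear layer. The weight is a sum of Kronecker
# products a_k ⊗ b_k, with factor shapes taken from utils.factorization
# (assumed to return a list of (f1, f2) pairs, as the zip below expects).
# Factor pairs are added until adding another would exceed a fraction pfrac of
# the dense (outf x inf) parameter count (at least one pair is always kept).
# With zero_init=True every b_k starts at zero, so the layer initially outputs
# only its zero-initialized bias.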
class KLinear(nn.Module):
def __init__(self, inf, outf, pfrac=0.05, symmetric=True, zero_init: bool = True):
super().__init__()
self.inf = inf
in_fact = factorization(inf)
out_fact = factorization(outf)
total_params = 0
self.a, self.b = nn.ParameterList(), nn.ParameterList()
for (i1, i2), (o1, o2) in zip(reversed(in_fact), reversed(out_fact)):
new_params = (o1 * i1 + o2 * i2) * (2 if symmetric else 1)
if (total_params + new_params) / (inf * outf) > pfrac and len(self.a) > 0:
break
total_params += new_params
self.a.append(nn.Parameter(torch.empty(o1, i1)))
if symmetric:
self.a.append(nn.Parameter(torch.empty(o2, i2)))
self.b.append(nn.Parameter(torch.empty(o2, i2)))
if symmetric:
self.b.append(nn.Parameter(torch.empty(o1, i1)))
assert self.a[-1].kron(self.b[-1]).shape == (outf, inf)
for factor in self.a:
nn.init.kaiming_normal_(factor.data)
for factor in self.b:
if zero_init:
factor.data.zero_()
else:
nn.init.kaiming_normal_(factor.data)
        LOG.info(
            f"Created KLinear (symmetric={symmetric}) using "
            f"{total_params / (outf * inf):.3f} of the dense parameter count, "
            f"{len(self.a)} Kronecker components"
        )
self.bias = nn.Parameter(torch.zeros(outf))
def forward(self, x):
assert x.shape[-1] == self.inf, f"Expected input with {self.inf} dimensions, got {x.shape}"
        # Reconstruct the full weight from the Kronecker factors, then apply it
        # as a standard linear map: x is (..., inf), output is (..., outf).
        w = sum([a.kron(b) for a, b in zip(self.a, self.b)]) / (2 * len(self.a) ** 0.5)
        y = x @ w.T
        if self.bias is not None:
            y = y + self.bias
        return y
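# LRLinear: a low-rank linear block computing clamp(u @ v @ x + bias, min=0)
# (clamp rather than relu; see the comment in forward), with an optional
# per-mode scale and shift on the pre-activations. With init="id", u starts at
# zero and the input is added to the output, so the block is the identity at
# initialization. The relu argument is currently unused; forward always
# applies the clamp.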
class LRLinear(nn.Module):
def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None):
super().__init__()
mid_dim = min(rank, inf)
if init == "id":
self.u = nn.Parameter(torch.zeros(outf, mid_dim))
self.v = nn.Parameter(torch.randn(mid_dim, inf))
elif init == "xavier":
self.u = nn.Parameter(torch.empty(outf, mid_dim))
self.v = nn.Parameter(torch.empty(mid_dim, inf))
nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu"))
nn.init.xavier_uniform_(self.v.data, gain=1.0)
else:
raise ValueError(f"Unrecognized initialization {init}")
if n_modes is not None:
self.mode_shift = nn.Embedding(n_modes, outf)
self.mode_shift.weight.data.zero_()
self.mode_scale = nn.Embedding(n_modes, outf)
self.mode_scale.weight.data.fill_(1)
self.n_modes = n_modes
self.bias = nn.Parameter(torch.zeros(outf))
self.inf = inf
self.init = init
def forward(self, x, mode=None):
if mode is not None:
assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it"
assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}"
assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})"
pre_act = (self.u @ (self.v @ x.T)).T
if self.bias is not None:
pre_act += self.bias
if mode is not None:
if not isinstance(mode, torch.Tensor):
mode = torch.tensor(mode).to(x.device)
scale, shift = self.mode_scale(mode), self.mode_shift(mode)
pre_act = pre_act * scale + shift
# need clamp instead of relu so gradient at 0 isn't 0
acts = pre_act.clamp(min=0)
if self.init == "id":
return acts + x
else:
return acts
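# MLP: a feed-forward network with configurable initialization. init="id"
# overwrites outdim to indim, pads hidden_dim to a multiple of indim, and
# builds an exact identity map from stacked positive/negative identity blocks;
# init="id_alpha" instead gates the whole MLP with a zero-initialized alpha so
# the output starts as x; rank != None swaps nn.Linear for LRLinear.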
class MLP(nn.Module):
def __init__(
self,
indim: int,
outdim: int,
hidden_dim: int,
n_hidden: int,
init: str = "xavier_uniform",
act: str = "relu",
rank: int = None,
):
super().__init__()
self.init = init
if act == "relu":
self.act = nn.ReLU()
elif act == "learned":
self.act = ActMLP(10, 1)
else:
raise ValueError(f"Unrecognized activation function '{act}'")
if hidden_dim is None:
hidden_dim = outdim * 2
if init.startswith("id") and outdim != indim:
LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})")
outdim = indim
if init == "id":
old_hidden_dim = hidden_dim
if hidden_dim < indim * 2:
hidden_dim = indim * 2
            if hidden_dim % indim != 0:
                # Round up to the nearest multiple of indim so the identity
                # blocks below tile the first and last weights exactly.
                hidden_dim += indim - (hidden_dim % indim)
if old_hidden_dim != hidden_dim:
LOG.info(
f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}"
)
if init == "id_alpha":
self.alpha = nn.Parameter(torch.zeros(1, outdim))
dims = [indim] + [hidden_dim] * n_hidden + [outdim]
LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})")
layers = []
for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])):
if rank is None:
layers.append(nn.Linear(ind, outd))
else:
layers.append(LRLinear(ind, outd, rank=rank))
if idx < n_hidden:
layers.append(self.act)
if rank is None:
if init == "id":
if n_hidden > 0:
layers[0].weight.data = torch.eye(indim).repeat(
hidden_dim // indim, 1
)
layers[0].weight.data[hidden_dim // 2:] *= -1
layers[-1].weight.data = torch.eye(outdim).repeat(
1, hidden_dim // outdim
)
layers[-1].weight.data[:, hidden_dim // 2:] *= -1
layers[-1].weight.data /= (hidden_dim // indim) / 2.0
for layer in layers:
if isinstance(layer, nn.Linear):
if init == "ortho":
nn.init.orthogonal_(layer.weight)
elif init == "id":
if layer.weight.shape[0] == layer.weight.shape[1]:
layer.weight.data = torch.eye(hidden_dim)
else:
gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0
nn.init.xavier_uniform_(layer.weight, gain=gain)
layer.bias.data[:] = 0
layers[-1].bias = None
self.mlp = nn.Sequential(*layers)
def forward(self, x):
if self.init == "id_alpha":
return x + self.alpha * self.mlp(x)
else:
return self.mlp(x)
if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
)
m0 = MLP(1000, 1000, 1500, 3)
m1 = MLP(1000, 1000, 1500, 3, init="id")
m2 = MLP(1000, 1000, 1500, 3, init="id_alpha")
m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned")
x = 0.01 * torch.randn(999, 1000)
y0 = m0(x)
y1 = m1(x)
y2 = m2(x)
y3 = m3(x)
print("y0", (y0 - x).abs().max())
print("y1", (y1 - x).abs().max())
print("y2", (y2 - x).abs().max())
print("y3", (y3 - x).abs().max())
assert not torch.allclose(y0, x)
assert torch.allclose(y1, x)
assert torch.allclose(y2, x)
assert not torch.allclose(y3, x)
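    # Extra illustrative check (not part of the original tests): an IDMLP with
    # init="id" should also be an exact identity at initialization, since each
    # LRLinear starts with u = 0 and adds its input back. The rank here is an
    # arbitrary value chosen only for this sketch.
    m4 = IDMLP(1000, 1000, None, 3, init="id", rank=64)
    y4 = m4(x)
    print("y4", (y4 - x).abs().max())
    assert torch.allclose(y4, x)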
import pdb; pdb.set_trace() # fmt: skip