import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .general import FUNC_LIST


class HadaWeight(torch.autograd.Function):
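    """Autograd Function for the LoHa (Hadamard product) delta weight
    ΔW = (w1u @ w1d) * (w2u @ w2d) * scale.

    Only the low rank factors are saved for backward; the full-size products
    are recomputed there rather than kept alive in the graph.
    """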
@staticmethod
def forward(ctx, w1d, w1u, w2d, w2u, scale=torch.tensor(1)):
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
return diff_weight

    @staticmethod
def backward(ctx, grad_out):
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
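        # Product rule on the Hadamard product: with B1 = w1u @ w1d and
        # B2 = w2u @ w2d, the grad w.r.t. B1 is grad_out * B2 (and vice
        # versa); the usual matmul gradients then land on each factor.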
temp = grad_out * (w2u @ w2d)
grad_w1u = temp @ w1d.T
grad_w1d = w1u.T @ temp
temp = grad_out * (w1u @ w1d)
grad_w2u = temp @ w2d.T
grad_w2d = w2u.T @ temp
del temp
return grad_w1d, grad_w1u, grad_w2d, grad_w2u, None


class HadaWeightTucker(torch.autograd.Function):
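    """Same Hadamard-product delta as HadaWeight, but each branch is a Tucker
    decomposition (core t, down w_d, up w_u) for kernels with extra dims.
    """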
@staticmethod
def forward(ctx, t1, w1d, w1u, t2, w2d, w2u, scale=torch.tensor(1)):
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
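        # Rebuild each branch's full (out_dim, in_dim, *k) kernel from its
        # Tucker factors: core t is (rank, rank, *k), w_d is (rank, in_dim),
        # w_u is (rank, out_dim).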
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
return rebuild1 * rebuild2 * scale

    @staticmethod
def backward(ctx, grad_out):
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
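        # Branch 1 gradients: rebuild branch 2's kernel, which multiplies
        # grad_out under the product rule, then pull the result back through
        # t1, w1d and w1u.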
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
grad_w = rebuild * grad_out
del rebuild
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
del grad_w, temp
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
del grad_temp
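
        # Branch 2 gradients: the mirror image, driven by branch 1's rebuilt kernel.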
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
grad_w = rebuild * grad_out
del rebuild
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
del grad_w, temp
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
del grad_temp
return grad_t1, grad_w1d, grad_w1u, grad_t2, grad_w2d, grad_w2u, None
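

# Convenience wrappers so callers do not need to invoke the autograd.Function
# classes directly.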
def make_weight(w1d, w1u, w2d, w2u, scale):
return HadaWeight.apply(w1d, w1u, w2d, w2u, scale)


def make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, scale):
return HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, scale)


def weight_gen(org_weight, rank, tucker=True):
"""### weight_gen
Args:
org_weight (torch.Tensor): the weight tensor
rank (int): low rank
Returns:
torch.Tensor: w1d, w2d, w1u, w2u[, t1, t2]
"""
out_dim, in_dim, *k = org_weight.shape
if k and tucker:
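        # Weight has extra (e.g. convolution kernel) dims: factor each branch
        # as a Tucker decomposition with a (rank, rank, *k) core.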
w1d = torch.empty(rank, in_dim)
w1u = torch.empty(rank, out_dim)
t1 = torch.empty(rank, rank, *k)
w2d = torch.empty(rank, in_dim)
w2u = torch.empty(rank, out_dim)
t2 = torch.empty(rank, rank, *k)
nn.init.normal_(t1, std=0.1)
nn.init.normal_(t2, std=0.1)
else:
w1d = torch.empty(rank, in_dim)
w1u = torch.empty(out_dim, rank)
w2d = torch.empty(rank, in_dim)
w2u = torch.empty(out_dim, rank)
t1 = t2 = None
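    # w1u starts at zero so the initial ΔW is exactly zero; the other factors
    # receive normal noise so gradients are nonzero from the first step.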
nn.init.normal_(w1d, std=1)
nn.init.constant_(w1u, 0)
nn.init.normal_(w2d, std=1)
nn.init.normal_(w2u, std=0.1)
return w1d, w1u, w2d, w2u, t1, t2


def diff_weight(*weights, gamma=1.0):
    """### diff_weight

    Get ΔW = (w1u @ w1d) * (w2u @ w2d) * gamma, the Hadamard product of two
    low rank decompositions

    Args:
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u[, t1, t2])
        gamma (float, optional): scale factor, normally alpha/rank here

    Returns:
        torch.Tensor: ΔW
    """
w1d, w1u, w2d, w2u, t1, t2 = weights
if t1 is not None and t2 is not None:
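        # Recover out/in/kernel dims from the factor shapes for the final reshape.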
R, I = w1d.shape
R, O = w1u.shape
R, R, *k = t1.shape
result = make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, gamma)
else:
R, I, *k = w1d.shape
O, R, *_ = w1u.shape
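        # Collapse any trailing dims so both branches reduce to plain 2D
        # matmuls inside HadaWeight.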
w1d = w1d.reshape(w1d.size(0), -1)
w1u = w1u.reshape(-1, w1u.size(1))
w2d = w2d.reshape(w2d.size(0), -1)
w2u = w2u.reshape(-1, w2u.size(1))
result = make_weight(w1d, w1u, w2d, w2u, gamma)
result = result.reshape(O, I, *k)
return result


def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args=None):
    """### bypass_forward_diff

    Args:
        x (torch.Tensor): input tensor
        org_out (torch.Tensor): original module output (unused here)
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u[, t1, t2])
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra args for forward func, \
            e.g. padding, stride for Conv1/2/3d

    Returns:
        torch.Tensor: output tensor
    """
    w1d, w1u, w2d, w2u, t1, t2 = weights
    if extra_args is None:
        extra_args = {}
    # gamma must go by keyword: diff_weight collects positionals into *weights.
    diff_w = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma=gamma)
    return FUNC_LIST[w1d.dim() if t1 is None else t1.dim()](x, diff_w, **extra_args)
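

# Usage sketch (illustrative only, not part of the module API; `alpha` is the
# usual LoRA-style scaling hyperparameter, and in training the factors would
# be wrapped in nn.Parameter):
#
#     w1d, w1u, w2d, w2u, t1, t2 = weight_gen(linear.weight, rank=4, tucker=False)
#     delta = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma=alpha / 4)
#     merged_weight = linear.weight + delta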