|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .general import FUNC_LIST |
|
|
|
|
|
class HadaWeight(torch.autograd.Function):
    """Hadamard product of two low-rank matrix products with a custom backward.

    Forward computes ``((w1u @ w1d) * (w2u @ w2d)) * scale``. The backward
    pass recomputes each low-rank product from the saved factors rather than
    keeping the full-size products alive between forward and backward.
    """

    @staticmethod
    def forward(ctx, w1d, w1u, w2d, w2u, scale=torch.tensor(1)):
        """Compute the scaled Hadamard-product weight.

        Args:
            w1d, w1u: factors of the first branch; ``w1u @ w1d`` must be valid.
            w2d, w2u: factors of the second branch; ``w2u @ w2d`` must have
                the same shape as ``w1u @ w1d``.
            scale: multiplier applied to the product. May be a tensor or a
                plain Python number.

        Returns:
            torch.Tensor: ``((w1u @ w1d) * (w2u @ w2d)) * scale``.
        """
        # ``save_for_backward`` only accepts tensors (or None); a plain
        # float/int scale would raise TypeError as soon as any input
        # requires grad, so promote it to a 0-dim tensor first.
        if not torch.is_tensor(scale):
            scale = torch.as_tensor(scale, dtype=w1d.dtype, device=w1d.device)
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(w1d, w1u, w2d, w2u, scale)``.

        ``scale`` receives no gradient (``None`` is returned in its slot).
        """
        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Gradient flowing into (w1u @ w1d) is grad_out * (w2u @ w2d).
        temp = grad_out * (w2u @ w2d)
        grad_w1u = temp @ w1d.T
        grad_w1d = w1u.T @ temp

        # Gradient flowing into (w2u @ w2d) is grad_out * (w1u @ w1d).
        temp = grad_out * (w1u @ w1d)
        grad_w2u = temp @ w2d.T
        grad_w2d = w2u.T @ temp

        del temp
        return grad_w1d, grad_w1u, grad_w2d, grad_w2u, None
|
|
|
|
|
class HadaWeightTucker(torch.autograd.Function):
    """Hadamard product of two Tucker-style rebuilt weights with custom backward.

    Each branch rebuilds a weight from a core tensor ``t`` and two matrices:
    ``rebuild[p, r, ...] = sum_{i, j} t[i, j, ...] * wd[j, r] * wu[i, p]``.
    The forward result is ``rebuild1 * rebuild2 * scale``.
    """

    @staticmethod
    def forward(ctx, t1, w1d, w1u, t2, w2d, w2u, scale=torch.tensor(1)):
        """Compute the scaled Hadamard product of the two rebuilt weights.

        Args:
            t1, t2: core tensors indexed ``[i, j, ...]``.
            w1d, w2d: matrices indexed ``[j, r]``.
            w1u, w2u: matrices indexed ``[i, p]``.
            scale: multiplier applied to the product. May be a tensor or a
                plain Python number.

        Returns:
            torch.Tensor: tensor indexed ``[p, r, ...]``.
        """
        # ``save_for_backward`` only accepts tensors (or None); a plain
        # float/int scale would raise TypeError as soon as any input
        # requires grad, so promote it to a 0-dim tensor first.
        if not torch.is_tensor(scale):
            scale = torch.as_tensor(scale, dtype=t1.dtype, device=t1.device)
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(t1, w1d, w1u, t2, w2d, w2u, scale)``.

        ``scale`` receives no gradient (``None`` is returned in its slot).
        """
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Partial rebuilds: each core contracted with its own "down" matrix.
        # Both are needed by both branches: a branch's own partial product
        # gives the gradient of its "up" matrix, the other branch's partial
        # product gives the Hadamard partner.
        temp1 = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        temp2 = torch.einsum("i j ..., j r -> i r ...", t2, w2d)

        # ---- Branch 1: gradient w.r.t. rebuild1 is grad_out * rebuild2 ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp2, w2u)
        grad_w = rebuild * grad_out
        del rebuild

        # d rebuild1 / d w1u contracts with branch 1's own partial product.
        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp1, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        # ---- Branch 2: gradient w.r.t. rebuild2 is grad_out * rebuild1 ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp1, w1u)
        grad_w = rebuild * grad_out
        del rebuild, temp1

        # d rebuild2 / d w2u contracts with branch 2's own partial product.
        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp2, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp2

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp

        return grad_t1, grad_w1d, grad_w1u, grad_t2, grad_w2d, grad_w2u, None
|
|
|
|
|
def make_weight(w1d, w1u, w2d, w2u, scale):
    """Run the four low-rank factors and ``scale`` through ``HadaWeight``.

    Returns whatever ``HadaWeight.apply`` produces for these arguments.
    """
    factors = (w1d, w1u, w2d, w2u)
    hada = HadaWeight.apply(*factors, scale)
    return hada
|
|
|
|
|
def make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, scale):
    """Run both Tucker branches and ``scale`` through ``HadaWeightTucker``.

    Returns whatever ``HadaWeightTucker.apply`` produces for these arguments.
    """
    first_branch = (t1, w1d, w1u)
    second_branch = (t2, w2d, w2u)
    hada = HadaWeightTucker.apply(*first_branch, *second_branch, scale)
    return hada
|
|
|
|
|
def weight_gen(org_weight, rank, tucker=True):
    """### weight_gen

    Create and initialise the low-rank factors for a weight of the same
    shape as ``org_weight``.

    Args:
        org_weight (torch.Tensor): the weight tensor; its shape is read as
            ``(out_dim, in_dim, *kernel_dims)``
        rank (int): low rank
        tucker (bool, optional): when True and ``org_weight`` has kernel
            dims, also create the core tensors ``t1``/``t2``

    Returns:
        tuple: ``(w1d, w1u, w2d, w2u, t1, t2)``; ``t1`` and ``t2`` are
        ``None`` when the Tucker form is not used. Tensors are created with
        ``torch.empty`` defaults (no dtype/device is taken from
        ``org_weight``). ``w1u`` is zero-initialised.
    """
    out_dim, in_dim, *k = org_weight.shape
    if k and tucker:
        w1d = torch.empty(rank, in_dim)
        w1u = torch.empty(rank, out_dim)
        t1 = torch.empty(rank, rank, *k)
        w2d = torch.empty(rank, in_dim)
        w2u = torch.empty(rank, out_dim)
        t2 = torch.empty(rank, rank, *k)
        nn.init.normal_(t1, std=0.1)
        nn.init.normal_(t2, std=0.1)
    else:
        # Keep the kernel dims on the "down" factors: diff_weight reads them
        # from ``w1d.shape`` to reshape the result back to
        # ``(out_dim, in_dim, *k)``, and bypass_forward_diff selects the
        # forward op from ``w1d.dim()``. Without them a conv weight would
        # yield a 2-D delta. For linear weights ``k`` is empty and the
        # shapes are simply ``(rank, in_dim)``.
        w1d = torch.empty(rank, in_dim, *k)
        w1u = torch.empty(out_dim, rank)
        w2d = torch.empty(rank, in_dim, *k)
        w2u = torch.empty(out_dim, rank)
        t1 = t2 = None
    nn.init.normal_(w1d, std=1)
    nn.init.constant_(w1u, 0)
    nn.init.normal_(w2d, std=1)
    nn.init.normal_(w2u, std=0.1)
    return w1d, w1u, w2d, w2u, t1, t2
|
|
|
|
|
def diff_weight(*weights, gamma=1.0):
    """### diff_weight

    Get ΔW as the scaled Hadamard product of two low-rank rebuilt weights.

    Args:
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u, t1, t2); pass
            ``None`` for ``t1``/``t2`` to use the plain (non-Tucker) form
        gamma (float | torch.Tensor, optional): scale factor, normally
            alpha/rank here

    Returns:
        torch.Tensor: ΔW reshaped to ``(O, I, *k)``, where ``O``/``I`` are
        read from ``w1u``/``w1d`` and ``k`` from ``t1`` (Tucker form) or
        from the trailing dims of ``w1d`` (plain form)
    """
    w1d, w1u, w2d, w2u, t1, t2 = weights

    # The autograd functions behind make_weight / make_weight_tucker store
    # the scale with ``ctx.save_for_backward``, which only accepts tensors.
    # Convert plain numbers here so a float gamma works when grads are on.
    if not torch.is_tensor(gamma):
        gamma = torch.as_tensor(gamma, dtype=w1d.dtype, device=w1d.device)

    if t1 is not None and t2 is not None:
        # Tucker form: w1d is (R, I), w1u is (R, O), t1 is (R, R, *k).
        _, I = w1d.shape
        _, O = w1u.shape
        _, _, *k = t1.shape
        result = make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, gamma)
    else:
        # Plain form: w1d is (R, I, *k), w1u is (O, R, ...). Flatten every
        # factor to 2-D so the matrix products in make_weight are valid.
        _, I, *k = w1d.shape
        O, *_ = w1u.shape
        w1d = w1d.reshape(w1d.size(0), -1)
        w1u = w1u.reshape(-1, w1u.size(1))
        w2d = w2d.reshape(w2d.size(0), -1)
        w2u = w2u.reshape(-1, w2u.size(1))
        result = make_weight(w1d, w1u, w2d, w2u, gamma)

    result = result.reshape(O, I, *k)
    return result
|
|
|
|
|
def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args=None):
    """### bypass_forward_diff

    Apply the ΔW built from ``weights`` to ``x`` using the forward function
    selected from ``FUNC_LIST``.

    Args:
        x (torch.Tensor): input tensor
        org_out (torch.Tensor): accepted for interface compatibility; not
            used by this function
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u, t1, t2)
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra keyword args for the forward
            func, e.g. padding, stride for Conv1/2/3d. ``None`` means no
            extra arguments.

    Returns:
        torch.Tensor: output tensor
    """
    if extra_args is None:
        extra_args = {}
    w1d, w1u, w2d, w2u, t1, t2 = weights
    # gamma must go by keyword: diff_weight collects every positional
    # argument into ``*weights`` and unpacks exactly six of them.
    diff_w = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma=gamma)
    # FUNC_LIST is indexed by tensor rank (of t1 in Tucker form, else w1d).
    # NOTE(review): presumably it maps rank to F.linear / F.conv{1,2,3}d --
    # confirm against .general.
    op_rank = w1d.dim() if t1 is None else t1.dim()
    return FUNC_LIST[op_rank](x, diff_w, **extra_args)
|
|