import math
import torch
import torch.nn as nn
import torch.nn.functional as F
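
# LoHa (Hadamard-product low-rank adaptation): the trained weight is
#     W = W_orig + (w1a @ w1b) * (w2a @ w2b) * scale
# i.e. the delta is the element-wise product of two low-rank matrices.
# The custom autograd Functions below save only the small factors for backward
# and recompute the intermediate products there, instead of letting autograd
# keep the full-size intermediates alive.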
class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        diff_weight = ((w1a@w1b) * (w2a@w2b)) * scale
        return orig_weight.reshape(diff_weight.shape) + diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        temp = grad_out * (w2a@w2b)
        grad_w1a = temp @ w1b.T
        grad_w1b = w1a.T @ temp

        temp = grad_out * (w1a@w1b)
        grad_w2a = temp @ w2b.T
        grad_w2b = w2a.T @ temp

        del temp
        return grad_out, grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
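
# CP/Tucker-style variant used for Conv2d kernels larger than 1x1: each
# Hadamard factor is rebuilt from a small core tensor t of shape
# (lora_dim, lora_dim, kh, kw) contracted with two rank-projection matrices,
#     rebuild = einsum('i j k l, j r, i p -> p r k l', t, w_b, w_a)
# giving a full (out_dim, in_dim, kh, kw) kernel from far fewer parameters.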
class HadaWeightCP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(t1, w1a, w1b, t2, w2a, w2b, scale)

        rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', t1, w1b, w1a)
        rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', t2, w2b, w2a)

        return orig_weight + rebuild1*rebuild2*scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1a, w1b, t2, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        temp = torch.einsum('i j k l, j r -> i r k l', t2, w2b)
        rebuild = torch.einsum('i j k l, i r -> r j k l', temp, w2a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1a = torch.einsum('r j k l, i j k l -> r i', temp, grad_w)
        grad_temp = torch.einsum('i j k l, i r -> r j k l', grad_w, w1a.T)
        del grad_w, temp

        grad_w1b = torch.einsum('i r k l, i j k l -> r j', t1, grad_temp)
        grad_t1 = torch.einsum('i j k l, j r -> i r k l', grad_temp, w1b.T)
        del grad_temp

        temp = torch.einsum('i j k l, j r -> i r k l', t1, w1b)
        rebuild = torch.einsum('i j k l, i r -> r j k l', temp, w1a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2a = torch.einsum('r j k l, i j k l -> r i', temp, grad_w)
        grad_temp = torch.einsum('i j k l, i r -> r j k l', grad_w, w2a.T)
        del grad_w, temp

        grad_w2b = torch.einsum('i r k l, i j k l -> r j', t2, grad_temp)
        grad_t2 = torch.einsum('i j k l, j r -> i r k l', grad_temp, w2b.T)
        del grad_temp
        return grad_out, grad_t1, grad_w1a, grad_w1b, grad_t2, grad_w2a, grad_w2b, None

def make_weight(orig_weight, w1a, w1b, w2a, w2b, scale):
    return HadaWeight.apply(orig_weight, w1a, w1b, w2a, w2b, scale)


def make_weight_cp(orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale):
    return HadaWeightCP.apply(orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale)
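
# LohaModule wraps a single Linear or Conv2d layer. On the plain path it learns
# two pairs of low-rank matrices sized (out_dim, lora_dim) x (lora_dim, in_dim)
# over the (flattened) weight; for >1x1 convolutions with use_cp=True it also
# learns the core tensors hada_t1/hada_t2 and stores the rank matrices
# transposed (lora_dim first).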
class LohaModule(nn.Module):
    """
    Hadamard product Implementation for Low Rank Adaptation
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
        use_cp=True,
    ):
        """ If alpha == 0 or None, alpha is set to the rank (no scaling). """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels
            self.cp = use_cp and k_size != (1, 1)
            if self.cp:
                shape = (out_dim, in_dim, *k_size)
            else:
                shape = (out_dim, in_dim*k_size[0]*k_size[1])
            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups
            }
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            shape = (out_dim, in_dim)
            self.op = F.linear
            self.extra_args = {}

        if self.cp:
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
            self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, shape[0]))  # out_dim, 1-mode
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))  # in_dim , 2-mode

            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
            self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0]))  # out_dim, 1-mode
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))  # in_dim , 2-mode
        else:
            self.hada_w1_a = nn.Parameter(torch.empty(shape[0], lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))

            self.hada_w2_a = nn.Parameter(torch.empty(shape[0], lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant

        # Need more experiments on the init method
        if self.cp:
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
        torch.nn.init.normal_(self.hada_w1_b, std=1)
        torch.nn.init.normal_(self.hada_w2_b, std=0.01)
        torch.nn.init.normal_(self.hada_w1_a, std=1)
        torch.nn.init.constant_(self.hada_w2_a, 0)  # zero init, so the initial delta weight is zero

        self.multiplier = multiplier
        self.org_module = [org_module]  # remove in applying
        self.grad_ckpt = False

    def apply_to(self):
        self.org_module[0].forward = self.forward

    def get_weight(self):
        # returns only the raw (unscaled) delta weight for the non-CP factor layout
        d_weight = self.hada_w1_a @ self.hada_w1_b
        d_weight *= self.hada_w2_a @ self.hada_w2_b
        return (d_weight).reshape(self.shape)

    @torch.enable_grad()
    def forward(self, x):
        if self.cp:
            weight = make_weight_cp(
                self.org_module[0].weight.data,
                self.hada_t1, self.hada_w1_a, self.hada_w1_b,
                self.hada_t2, self.hada_w2_a, self.hada_w2_b,
                scale=torch.tensor(self.scale*self.multiplier),
            )
        else:
            weight = make_weight(
                self.org_module[0].weight.data,
                self.hada_w1_a, self.hada_w1_b,
                self.hada_w2_a, self.hada_w2_b,
                scale=torch.tensor(self.scale*self.multiplier),
            )

        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
        return self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
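
# Illustrative usage sketch (not part of the original module); layer sizes and
# names below are arbitrary examples. With the default init, hada_w2_a is zero,
# so the freshly created adapter is a no-op and the patched layer should
# reproduce the original output exactly.
if __name__ == "__main__":
    # Linear layer -> plain Hadamard path (cp=False)
    base = nn.Linear(16, 8)
    x = torch.randn(2, 16)
    ref = base(x)
    loha = LohaModule("loha_example_linear", base, multiplier=1.0, lora_dim=4, alpha=1)
    loha.apply_to()  # base.forward now routes through the LoHa-composed weight
    print(torch.allclose(ref, base(x)))  # True: zero-initialized delta

    # 3x3 Conv2d -> CP/Tucker path (cp=True)
    conv = nn.Conv2d(3, 8, 3, padding=1)
    xc = torch.randn(1, 3, 16, 16)
    ref_c = conv(xc)
    loha_c = LohaModule("loha_example_conv", conv, multiplier=1.0, lora_dim=4, alpha=1)
    loha_c.apply_to()
    print(torch.allclose(ref_c, conv(xc)))  # True: zero-initialized delta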