Abdullah-Nazhat committed
Commit 1ea09b4
Parent(s): a46aedd

Update approximator.py

approximator.py CHANGED (+13 -13; every changed line is blank on both sides of the diff, so the edit appears to be whitespace-only)
@@ -21,7 +21,7 @@ class FeedForward(nn.Module):
         return self.net(x)


-
+

 def exists(val):
     return val is not None
@@ -43,7 +43,7 @@ def moore_penrose_iter_pinv(x, iters = 6):

     return z

-
+

 class NystromAttention(nn.Module):
     def __init__(
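Only the return statement of moore_penrose_iter_pinv is visible in this hunk. For context, the file appears to follow the upstream nystrom-attention implementation, which approximates the pseudoinverse with the cubic Newton-Schulz-style iteration from the Nystromformer paper (Xiong et al., 2021). The sketch below illustrates that iteration; the function name iterative_pinv_sketch and the exact initialization are assumptions, not the literal body of this file's function.

import torch

def iterative_pinv_sketch(x, iters = 6):
    # Approximate the Moore-Penrose pseudoinverse of a (batched) square matrix x
    # with the cubic Newton-Schulz-style iteration used by the Nystromformer paper.
    # Z0 = x^T / (||x||_1 * ||x||_inf) keeps the iteration in its convergence region.
    abs_x = x.abs()
    z = x.transpose(-1, -2) / (abs_x.sum(dim = -1).max() * abs_x.sum(dim = -2).max())
    I = torch.eye(x.shape[-1], device = x.device, dtype = x.dtype)
    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * I - xz @ (15 * I - xz @ (7 * I - xz)))
    return z

a = torch.softmax(torch.randn(2, 4, 4), dim = -1)
z = iterative_pinv_sketch(a, iters = 12)
print((a @ z - torch.eye(4)).abs().max())   # near zero when a is well conditioned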
@@ -83,7 +83,7 @@ class NystromAttention(nn.Module):
     def forward(self, x, mask = None, return_attn = False):
         b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

-
+

         remainder = n % m
         if remainder > 0:
@@ -93,12 +93,12 @@ class NystromAttention(nn.Module):
         if exists(mask):
             mask = F.pad(mask, (padding, 0), value = False)

-
+

         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

-
+

         if exists(mask):
             mask = rearrange(mask, 'b n -> b () n')
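This hunk projects x to queries, keys and values with a single linear layer and then splits each into attention heads. A minimal sketch of that reshape, with hypothetical sizes and a stand-in nn.Linear for self.to_qkv:

import torch
from torch import nn
from einops import rearrange

# Hypothetical sizes: batch 1, n = 8 tokens, model dim 16, h = 2 heads (head dim 8).
x = torch.randn(1, 8, 16)
to_qkv = nn.Linear(16, 16 * 3, bias = False)    # stand-in for self.to_qkv
q, k, v = to_qkv(x).chunk(3, dim = -1)          # three (1, 8, 16) tensors
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 2), (q, k, v))
print(q.shape)                                  # torch.Size([1, 2, 8, 8])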
@@ -106,14 +106,14 @@ class NystromAttention(nn.Module):

         q = q * self.scale

-
+

         l = ceil(n / m)
         landmark_einops_eq = '... (n l) d -> ... n d'
         q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
         k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

-
+

         divisor = l
         if exists(mask):
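Here the landmarks are built by summing q and k over consecutive groups of l = ceil(n / m) tokens; the later division by divisor turns that sum into a (possibly masked) mean. A small sketch of the unmasked path, with hypothetical sizes:

import torch
from math import ceil
from einops import reduce

# Hypothetical sizes: batch 1, 1 head, n = 8 tokens, head dim 4, m = 4 landmarks.
q = torch.randn(1, 1, 8, 4)
n, m = 8, 4
l = ceil(n / m)                                                       # tokens pooled per landmark
q_landmarks = reduce(q, '... (n l) d -> ... n d', 'sum', l = l) / l   # segment mean (unmasked path)
print(q_landmarks.shape)                                              # torch.Size([1, 1, 4, 4])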
@@ -121,19 +121,19 @@ class NystromAttention(nn.Module):
             divisor = mask_landmarks_sum[..., None] + eps
             mask_landmarks = mask_landmarks_sum > 0

-
+

         q_landmarks /= divisor
         k_landmarks /= divisor

-
+

         einops_eq = '... i d, ... j d -> ... i j'
         sim1 = einsum(einops_eq, q, k_landmarks)
         sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
         sim3 = einsum(einops_eq, q_landmarks, k)

-
+

         if exists(mask):
             mask_value = -torch.finfo(q.dtype).max
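The masking in this and the next hunk fills similarity entries with a large negative value wherever the query token or the landmark comes from padding, so the subsequent softmax gives those entries effectively zero weight. A toy illustration of the same masked_fill_ pattern, with hypothetical sizes:

import torch

# Hypothetical sizes: batch 1, 1 head, n = 4 tokens (the last one is padding), m = 2 landmarks.
mask = torch.tensor([[True, True, True, False]])[:, None, :]     # 'b n -> b () n'
mask_landmarks = torch.tensor([[True, True]])[:, None, :]        # both landmarks valid
sim1 = torch.randn(1, 1, 4, 2)
mask_value = -torch.finfo(sim1.dtype).max
# True where either the query token or the landmark is padding; those entries get mask_value
sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
print(sim1[0, 0])   # the padded query's row is filled with the large negative value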
@@ -141,19 +141,19 @@ class NystromAttention(nn.Module):
             sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
             sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

-
+

         attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
         attn2_inv = moore_penrose_iter_pinv(attn2, iters)

         out = (attn1 @ attn2_inv) @ (attn3 @ v)

-
+

         if self.residual:
             out += self.res_conv(v)

-
+

         out = rearrange(out, 'b h n d -> b n (h d)', h = h)
         out = self.to_out(out)
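The aggregation out = (attn1 @ attn2_inv) @ (attn3 @ v) is the Nystrom approximation of softmax attention, eq. (15) in the Nystromformer paper: softmax(Q K_L^T) pinv(softmax(Q_L K_L^T)) softmax(Q_L K^T) V, where Q_L and K_L are the landmark queries and keys. A shape-level sketch with hypothetical sizes, using torch.linalg.pinv as a stand-in for moore_penrose_iter_pinv:

import torch

# Hypothetical sizes: batch 1, 1 head, n = 8 tokens, m = 4 landmarks, head dim 4.
b, h, n, m, d = 1, 1, 8, 4, 4
attn1 = torch.softmax(torch.randn(b, h, n, m), dim = -1)   # softmax(q @ k_landmarks^T)
attn2 = torch.softmax(torch.randn(b, h, m, m), dim = -1)   # softmax(q_landmarks @ k_landmarks^T)
attn3 = torch.softmax(torch.randn(b, h, m, n), dim = -1)   # softmax(q_landmarks @ k^T)
v = torch.randn(b, h, n, d)

attn2_inv = torch.linalg.pinv(attn2)      # exact pinv, standing in for moore_penrose_iter_pinv
out = (attn1 @ attn2_inv) @ (attn3 @ v)   # (b, h, n, d); cost scales with n*m, not n^2
print(out.shape)                          # torch.Size([1, 1, 8, 4])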