Abdullah-Nazhat committed on
Commit 1ea09b4
1 Parent(s): a46aedd

Update approximator.py

Files changed (1)
  1. approximator.py +13 -13
approximator.py CHANGED
@@ -21,7 +21,7 @@ class FeedForward(nn.Module):
         return self.net(x)


-# helper functions
+

 def exists(val):
     return val is not None
@@ -43,7 +43,7 @@ def moore_penrose_iter_pinv(x, iters = 6):

     return z

-# main attention class
+

 class NystromAttention(nn.Module):
     def __init__(
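Note: the body of moore_penrose_iter_pinv is not shown in this diff. For context, a minimal sketch of the iterative Moore-Penrose pseudoinverse used by Nyströmformer-style implementations (an assumption that approximator.py follows the same scheme, not code copied from this file):

    import torch
    from einops import rearrange

    def moore_penrose_iter_pinv(x, iters = 6):
        # x: (..., n, n) batch of square matrices (here, the landmark attention attn2)
        abs_x = torch.abs(x)
        col = abs_x.sum(dim = -1)
        row = abs_x.sum(dim = -2)
        # standard initialization: x^T scaled by the product of the largest row and column sums
        z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))

        I = torch.eye(x.shape[-1], device = x.device)
        I = rearrange(I, 'i j -> () i j')

        for _ in range(iters):
            xz = x @ z
            # higher-order iterative update as given in the Nyströmformer paper
            z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))

        return z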
@@ -83,7 +83,7 @@ class NystromAttention(nn.Module):
     def forward(self, x, mask = None, return_attn = False):
         b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

-        # pad so that sequence can be evenly divided into m landmarks
+

         remainder = n % m
         if remainder > 0:
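Aside: the hunk above only checks divisibility; the left-padding of x itself happens just outside it. A small worked example of the intent (shapes and values chosen only for illustration):

    import torch
    import torch.nn.functional as F

    n, m = 10, 4                 # 10 tokens, 4 landmarks
    remainder = n % m            # 2
    padding = m - remainder      # pad 2 positions on the left of the sequence
    x = torch.zeros(1, n, 8)     # (batch, seq, dim) with an arbitrary dim of 8
    x = F.pad(x, (0, 0, padding, 0), value = 0)
    print(x.shape)               # torch.Size([1, 12, 8]); 12 is divisible by m = 4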
@@ -93,12 +93,12 @@ class NystromAttention(nn.Module):
         if exists(mask):
             mask = F.pad(mask, (padding, 0), value = False)

-        # derive query, keys, values
+

         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

-        # set masked positions to 0 in queries, keys, values
+

         if exists(mask):
             mask = rearrange(mask, 'b n -> b () n')
@@ -106,14 +106,14 @@ class NystromAttention(nn.Module):

         q = q * self.scale

-        # generate landmarks by sum reduction, and then calculate mean using the mask
+

         l = ceil(n / m)
         landmark_einops_eq = '... (n l) d -> ... n d'
         q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
         k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

-        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean
+

         divisor = l
         if exists(mask):
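The reduce pattern above sums every group of l consecutive tokens into one landmark, and the divisor computed next turns that sum into a (masked) mean. A standalone illustration of the grouping (toy shapes, unrelated to approximator.py):

    import torch
    from einops import reduce

    l = 2                                       # 8 tokens, 4 landmarks of l = 2 tokens each
    q = torch.arange(8.).reshape(1, 1, 8, 1)    # (batch, heads, seq, dim_head)
    q_landmarks = reduce(q, '... (n l) d -> ... n d', 'sum', l = l)
    print(q_landmarks.flatten())                # tensor([ 1.,  5.,  9., 13.])
    print((q_landmarks / l).flatten())          # per-landmark mean: tensor([0.5, 2.5, 4.5, 6.5])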
@@ -121,19 +121,19 @@ class NystromAttention(nn.Module):
             divisor = mask_landmarks_sum[..., None] + eps
             mask_landmarks = mask_landmarks_sum > 0

-        # masked mean (if mask exists)
+

         q_landmarks /= divisor
         k_landmarks /= divisor

-        # similarities
+

         einops_eq = '... i d, ... j d -> ... i j'
         sim1 = einsum(einops_eq, q, k_landmarks)
         sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
         sim3 = einsum(einops_eq, q_landmarks, k)

-        # masking
+

         if exists(mask):
             mask_value = -torch.finfo(q.dtype).max
@@ -141,19 +141,19 @@ class NystromAttention(nn.Module):
             sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
             sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

-        # eq (15) in the paper and aggregate values
+

         attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
         attn2_inv = moore_penrose_iter_pinv(attn2, iters)

         out = (attn1 @ attn2_inv) @ (attn3 @ v)

-        # add depth-wise conv residual of values
+

         if self.residual:
             out += self.res_conv(v)

-        # merge and combine heads
+

         out = rearrange(out, 'b h n d -> b n (h d)', h = h)
         out = self.to_out(out)
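For reference, the three softmax terms and the pseudoinverse combined in the last hunk implement the Nyström approximation of softmax attention, eq. (15) of the Nyströmformer paper (with the tilde marking landmark queries and keys, and the 1/sqrt(d) scale already folded into q):

    \mathrm{out} \;=\;
    \underbrace{\mathrm{softmax}\big(Q\tilde{K}^{\top}\big)}_{\texttt{attn1}}\,
    \underbrace{\mathrm{softmax}\big(\tilde{Q}\tilde{K}^{\top}\big)^{+}}_{\texttt{attn2\_inv}}\,
    \underbrace{\mathrm{softmax}\big(\tilde{Q}K^{\top}\big)}_{\texttt{attn3}}\,V

which is exactly out = (attn1 @ attn2_inv) @ (attn3 @ v), with the pseudoinverse computed by moore_penrose_iter_pinv.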
 