fix case where sparsity_factor == block_size
Browse files- modeling_lsg_pegasus.py +14 -4
modeling_lsg_pegasus.py
CHANGED
@@ -207,7 +207,9 @@ class LSGAttentionProduct(nn.Module):
|
|
207 |
|
208 |
# Shape of blocks
|
209 |
self.local_shapes = (self.block_size*3, self.block_size)
|
210 |
-
if self.sparsity_factor > 0:
|
|
|
|
|
211 |
self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
|
212 |
|
213 |
self.attention = BaseAttentionProduct(config)
|
@@ -306,9 +308,12 @@ class LSGAttentionProduct(nn.Module):
|
|
306 |
|
307 |
size, step = self.sparse_shapes
|
308 |
|
|
|
|
|
|
|
309 |
# n, h, t, d*2 + 1
|
310 |
size = size*2
|
311 |
-
s = (size - step) // 2
|
312 |
|
313 |
# Pad before block reshaping
|
314 |
if is_attn_mask:
|
@@ -326,11 +331,16 @@ class LSGAttentionProduct(nn.Module):
|
|
326 |
# Make blocks
|
327 |
hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
|
328 |
|
|
|
|
|
|
|
|
|
329 |
# Indexes for selection
|
330 |
-
u = (size - self.block_size * 3 // self.sparsity_factor) // 2
|
331 |
s = self.sparse_block_size
|
332 |
|
333 |
-
|
|
|
334 |
|
335 |
def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
|
336 |
|
207 |
|
208 |
# Shape of blocks
|
209 |
self.local_shapes = (self.block_size*3, self.block_size)
|
210 |
+
if self.sparse_block_size and self.sparsity_factor > 0:
|
211 |
+
assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
|
212 |
+
assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
|
213 |
self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
|
214 |
|
215 |
self.attention = BaseAttentionProduct(config)
|
308 |
|
309 |
size, step = self.sparse_shapes
|
310 |
|
311 |
+
# Handle the case where the step is odd (step % 2 == 1)
|
312 |
+
odd_offset = (step % 2)
|
313 |
+
|
314 |
# n, h, t, d*2 + 1
|
315 |
size = size*2
|
316 |
+
s = (size - step) // 2 + odd_offset
|
317 |
|
318 |
# Pad before block reshaping
|
319 |
if is_attn_mask:
|
331 |
# Make blocks
|
332 |
hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
|
333 |
|
334 |
+
# Fix case where block_size == sparsity_factor
|
335 |
+
if odd_offset:
|
336 |
+
hidden_states = hidden_states[..., :-1, :, :]
|
337 |
+
|
338 |
# Indexes for selection
|
339 |
+
u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
|
340 |
s = self.sparse_block_size
|
341 |
|
342 |
+
u_ = u + odd_offset
|
343 |
+
return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
|
344 |
|
345 |
def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
|
346 |
|