ccdv committed
Commit e23e338
1 Parent(s): 27d4330

fix case where sparsity == block_size

Files changed (1):
  1. modeling_lsg_pegasus.py +14 -4
modeling_lsg_pegasus.py CHANGED
@@ -207,7 +207,9 @@ class LSGAttentionProduct(nn.Module):
 
         # Shape of blocks
         self.local_shapes = (self.block_size*3, self.block_size)
-        if self.sparsity_factor > 0:
+        if self.sparse_block_size and self.sparsity_factor > 0:
+            assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
+            assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
             self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
 
         self.attention = BaseAttentionProduct(config)
@@ -306,9 +308,12 @@
 
         size, step = self.sparse_shapes
 
+        # Handle an odd step (e.g. step == 1 when block_size == sparsity_factor)
+        odd_offset = (step % 2)
+
         # n, h, t, d*2 + 1
         size = size*2
-        s = (size - step) // 2
+        s = (size - step) // 2 + odd_offset
 
         # Pad before block reshaping
         if is_attn_mask:
@@ -326,11 +331,16 @@
         # Make blocks
         hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
 
+        # Fix case where block_size == sparsity_factor
+        if odd_offset:
+            hidden_states = hidden_states[..., :-1, :, :]
+
         # Indexes for selection
-        u = (size - self.block_size * 3 // self.sparsity_factor) // 2
+        u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
         s = self.sparse_block_size
 
-        return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u:-u+s, :]], dim=-2)
+        u_ = u + odd_offset
+        return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
 
     def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
 
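For context, a minimal sketch of the windowing arithmetic this commit fixes; the shapes and values below are illustrative assumptions, not the repository's code. When block_size == sparsity_factor, the step (block_size // sparsity_factor) is 1, so size - step is odd and the old symmetric pad (size - step) // 2 rounded down, leaving unfold one window short of the expected t // step blocks. Rounding the pad up by odd_offset and trimming the extra trailing window restores the count:

import torch

# Toy values for the odd case this commit targets (block_size == sparsity_factor).
sparse_block_size, block_size, sparsity_factor = 4, 4, 4

size = sparse_block_size * 3           # a window spans 3 sparse blocks -> 12
step = block_size // sparsity_factor   # == 1, i.e. odd
odd_offset = step % 2                  # == 1: the correction the commit adds

size = size * 2                        # mirrors size = size*2 in the diff -> 24
s = (size - step) // 2 + odd_offset    # (24 - 1) // 2 + 1 == 12: pad rounded up

t = 16                                                      # toy sequence length
x = torch.arange(t, dtype=torch.float).reshape(1, 1, t, 1)  # (n, h, t, d)
x = torch.nn.functional.pad(x.transpose(-1, -2), pad=(s, s), value=0.).transpose(-1, -2)

# Same unfold as in the diff: shape (n, h, num_blocks, size, d) after the transpose
blocks = x.unfold(-2, size=size, step=step).transpose(-1, -2)

# The rounded-up pad creates one extra trailing window; trimming it restores
# the t // step == 16 blocks that the even-step path produces.
if odd_offset:
    blocks = blocks[..., :-1, :, :]

assert blocks.shape[-3] == t // step

Without odd_offset, the same setup pads with s = 11 and unfold yields 15 windows instead of 16: the off-by-one that motivated this commit.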