ccdv committed
Commit c6f7c88
Parent: 73a3f59

replace -1e4 masks

Files changed (2):
1. README.md (+1, -0)
2. modeling_lsg_bert.py (+11, -7)
README.md CHANGED
@@ -47,6 +47,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
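
The new `mask_first_token` option is a plain config flag. As a hedged illustration (the checkpoint name below is an assumption for the example, not part of this commit), it can be passed at load time; `trust_remote_code` is required because `modeling_lsg_bert.py` ships inside the model repo:

```python
from transformers import AutoModel

# Minimal sketch, assuming an LSG checkpoint name such as this one.
model = AutoModel.from_pretrained(
    "ccdv/lsg-bert-base-uncased-4096",  # assumed checkpoint, for illustration only
    trust_remote_code=True,             # loads modeling_lsg_bert.py from the repo
    mask_first_token=True,              # mask the first token (redundant with the first global token)
)
```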
modeling_lsg_bert.py CHANGED
@@ -183,7 +183,11 @@ class CausalAttentionProduct(nn.Module):
 
         # Add causal mask
         causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
-        causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * (-10000)
+        causal_mask = torch.tril(
+            torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
+            diagonal=-1
+        )
+        causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
         attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
 
         del attention_mask
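
A note on the change above: `-10000` is an arbitrary "large negative" constant, while `torch.finfo(dtype).min` is the most negative value representable in the working dtype, so the same masking code saturates correctly under fp32, fp16, and bf16 alike. A standalone sketch of the new construction, with `block_size` standing in for `self.block_size`:

```python
import torch

# The dtype floor that -10000 only approximates:
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, torch.finfo(dtype).min)
# torch.float16   -65504.0
# torch.bfloat16  -3.3895...e+38
# torch.float32   -3.4028...e+38

# New causal mask: strictly lower-triangular ones, transposed so the upper
# triangle (future positions) carries the dtype minimum instead of -10000.
block_size, dtype = 4, torch.float16
causal_mask = torch.tril(torch.ones(block_size, block_size, dtype=dtype), diagonal=-1)
causal_mask = causal_mask.T * torch.finfo(dtype).min
```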
@@ -301,7 +305,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -334,7 +338,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
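
The two `pad_value` hunks apply the same idea to block padding: when the tensor being padded is an additive attention mask, the fill value must be the dtype minimum so the positions created by padding stay masked after block reshaping. A toy sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# Additive mask, 0 = visible; pad length 10 -> 16 to complete a block.
additive_mask = torch.zeros(1, 1, 1, 10, dtype=torch.float16)
pad_value = torch.finfo(additive_mask.dtype).min  # was the hard-coded -10000
additive_mask = F.pad(additive_mask, (0, 6), value=pad_value)
print(additive_mask.shape)  # torch.Size([1, 1, 1, 16])
```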
@@ -525,7 +529,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = - (1. - mask.clamp(0, 1)) * 1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -590,7 +594,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = -10000 * (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
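
The two sparse-token hunks normalize one conversion that the old code wrote two ways (`- (...) * 1e4` and `-10000 * (...)`). At this point `mask` holds per-block token counts from the pooling step; `clamp(0, 1)` turns the counts into a {0, 1} visibility mask, and the scaling makes it additive: 0 where a block has tokens, dtype minimum where it is empty. A toy illustration:

```python
import torch

counts = torch.tensor([3., 0., 1.])  # tokens pooled into each sparse block
additive = (1. - counts.clamp(0, 1)) * torch.finfo(counts.dtype).min
print(additive)  # tensor([ 0.0000e+00, -3.4028e+38,  0.0000e+00])
```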
@@ -1042,7 +1046,7 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
         if token_type_ids is None:
@@ -1125,7 +1129,7 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         )
 
         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
 
         return extended_attention_mask
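The last hunk is the standard BERT-style mask extension: the {0, 1} mask is cast to the model dtype (the `fp16 compatibility` line above), inverted, and scaled by that dtype's minimum, so the constant tracks whatever precision the model runs in. A self-contained sketch:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :].to(torch.float16)
extended = (1.0 - extended) * torch.finfo(extended.dtype).min
print(extended)  # the padded position holds -65504., the fp16 minimum
```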
 