replace -1e4 masks
- README.md (+1 -0)
- modeling_lsg_bert.py (+11 -7)
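Note on the replacement value: torch.finfo(dtype).min is the most negative finite value of the score dtype, so the additive mask now scales with the precision actually in use instead of the fixed -1e4. A quick standalone check (not part of the diff) of what it resolves to:

    import torch

    # Most negative finite value per dtype, i.e. the new masking constant.
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        print(dtype, torch.finfo(dtype).min)
    # torch.float32  ~ -3.4028e+38
    # torch.float16  -65504.0
    # torch.bfloat16 ~ -3.3895e+38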
README.md
CHANGED
@@ -47,6 +47,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
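Usage sketch only (the checkpoint name below is a placeholder and the keyword arguments simply mirror the parameters listed above; LSG models ship custom code, hence trust_remote_code=True):

    from transformers import AutoModel

    # Placeholder checkpoint; the kwargs override the defaults stored in config.json.
    model = AutoModel.from_pretrained(
        "your-org/lsg-bert-base-4096",
        trust_remote_code=True,
        block_size=128,
        sparse_block_size=128,
        sparsity_factor=2,
        mask_first_token=True,  # new option: first token is redundant with the first global token
    )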
modeling_lsg_bert.py
CHANGED
@@ -183,7 +183,11 @@ class CausalAttentionProduct(nn.Module):
 
         # Add causal mask
         causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
-        causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * -1e4
+        causal_mask = torch.tril(
+            torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
+            diagonal=-1
+            )
+        causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
         attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
 
         del attention_mask
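A standalone sketch of what the rewritten block produces (toy block size, names local to this example): the transposed strictly-lower-triangular matrix marks future positions, which then carry the dtype minimum as an additive penalty:

    import torch

    block_size = 4
    dtype = torch.float32

    # Same construction as the new lines above, on a toy block.
    causal_mask = torch.tril(torch.ones(block_size, block_size, dtype=dtype), diagonal=-1)
    causal_mask = causal_mask.T * torch.finfo(dtype).min

    print(causal_mask)  # row i: 0 for columns j <= i, dtype-min for columns j > i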
@@ -301,7 +305,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -1e4
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -334,7 +338,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -1e4
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
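Small illustration (toy tensors, not the repository's code path) of why the pad value for an additive attention mask must be the dtype minimum rather than 0: the positions appended before block reshaping have to stay masked:

    import torch
    import torch.nn.functional as F

    dtype = torch.float16
    # Additive mask for 6 positions, the last one already masked.
    attn_mask = torch.tensor([0., 0., 0., 0., 0., torch.finfo(dtype).min], dtype=dtype)

    # Pad up to a multiple of the block size (here 4 -> length 8) with the dtype minimum.
    padded = F.pad(attn_mask, (0, 2), value=torch.finfo(dtype).min)
    print(padded)  # padded slots carry the minimum, so they never receive attention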
@@ -525,7 +529,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = (1. - mask.clamp(0, 1)) * -1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -590,7 +594,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = (1. - mask.clamp(0, 1)) * -1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
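In these two hunks, mask is the sum of the attention mask over the pooled positions (it is used to average keys and values just above), so clamp(0, 1) binarizes it before the conversion to an additive bias. Toy version:

    import torch

    # Summed mask per pooled block: 0 means the block is entirely padded.
    mask = torch.tensor([[3.], [1.], [0.]])

    additive = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
    print(additive)  # 0 for non-empty blocks, dtype-min for empty ones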
@@ -1042,7 +1046,7 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
         if token_type_ids is None:
@@ -1125,7 +1129,7 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         )
 
         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -1e4
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
 
         return extended_attention_mask
 
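End-to-end sketch of the mask path touched by the last two hunks (toy sizes, outside the model class): a {0, 1} padding mask, with the first token zeroed out as mask_first_token does, becomes an additive bias through (1.0 - mask) * torch.finfo(dtype).min:

    import torch

    n, t = 2, 6
    dtype = torch.float32

    attention_mask = torch.ones(n, t, dtype=dtype)
    attention_mask[:, 0] = 0   # effect of mask_first_token
    attention_mask[1, 4:] = 0  # pretend the second sequence is padded

    # Broadcastable [n, 1, 1, t] additive mask, as the model's extended-mask helper returns.
    extended = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(dtype).min
    print(extended[1, 0, 0])  # 0 where attention is allowed, dtype-min elsewhere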