ccdv committed on
Commit
090395f
1 Parent(s): 6990f99

replace 1e4 mask

Browse files
Files changed (1) hide show
  1. modeling_lsg_pegasus.py +5 -6
modeling_lsg_pegasus.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  from transformers.models.pegasus.modeling_pegasus import *
4
  from transformers.models.pegasus.modeling_pegasus import _expand_mask
5
  import torch.nn as nn
6
- from torch.nn import BCEWithLogitsLoss
7
  import sys
8
 
9
  AUTO_MAP = {
@@ -265,7 +264,7 @@ class LSGAttentionProduct(nn.Module):
265
 
266
  # Pad before block reshaping
267
  if is_attn_mask:
268
- pad_value = -10000
269
  hidden_states = hidden_states.transpose(-1, -2)
270
  else:
271
  pad_value = 0
@@ -294,7 +293,7 @@ class LSGAttentionProduct(nn.Module):
294
 
295
  # Pad before block reshaping
296
  if is_attn_mask:
297
- pad_value = -10000
298
  hidden_states = hidden_states.transpose(-1, -2)
299
  else:
300
  pad_value = 0
@@ -423,7 +422,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
423
  keys = keys.sum(dim=-2) / (mask + 1e-6)
424
  values = values.sum(dim=-2) / (mask + 1e-6)
425
 
426
- mask = - (1. - mask.clamp(0, 1)) * 1e4
427
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
428
 
429
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -488,7 +487,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
488
  keys /= mask + 1e-8
489
  values /= mask + 1e-8
490
 
491
- mask = -10000 * (1. - mask.clamp(0, 1))
492
 
493
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
494
 
@@ -772,7 +771,7 @@ class LSGPegasusEncoder(LSGPegasusPreTrainedModel, PegasusEncoder):
772
  n, t = inputs_.size()[:2]
773
 
774
  if attention_mask is None:
775
- attention_mask = torch.ones(n, t, device=inputs_.device)
776
  if self.mask_first_token:
777
  attention_mask[:,0] = 0
778
 
 
3
  from transformers.models.pegasus.modeling_pegasus import *
4
  from transformers.models.pegasus.modeling_pegasus import _expand_mask
5
  import torch.nn as nn
 
6
  import sys
7
 
8
  AUTO_MAP = {
 
264
 
265
  # Pad before block reshaping
266
  if is_attn_mask:
267
+ pad_value = torch.finfo(hidden_states.dtype).min
268
  hidden_states = hidden_states.transpose(-1, -2)
269
  else:
270
  pad_value = 0
 
293
 
294
  # Pad before block reshaping
295
  if is_attn_mask:
296
+ pad_value = torch.finfo(hidden_states.dtype).min
297
  hidden_states = hidden_states.transpose(-1, -2)
298
  else:
299
  pad_value = 0
 
422
  keys = keys.sum(dim=-2) / (mask + 1e-6)
423
  values = values.sum(dim=-2) / (mask + 1e-6)
424
 
425
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
426
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
427
 
428
  def get_sparse_tokens_with_stride(self, keys, values, mask):
 
487
  keys /= mask + 1e-8
488
  values /= mask + 1e-8
489
 
490
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
491
 
492
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
493
 
 
771
  n, t = inputs_.size()[:2]
772
 
773
  if attention_mask is None:
774
+ attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
775
  if self.mask_first_token:
776
  attention_mask[:,0] = 0
777