ccdv commited on
Commit
3364dbb
1 Parent(s): 09d5cb2

small fix with torch.finfo

Browse files
Files changed (1) hide show
  1. modeling_lsg_bart.py +4 -2
modeling_lsg_bart.py CHANGED
@@ -435,7 +435,8 @@ class LSGBartEncoderAttention(BaseSelfAttention):
435
  keys = keys.sum(dim=-2) / (mask + 1e-6)
436
  values = values.sum(dim=-2) / (mask + 1e-6)
437
 
438
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
439
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
440
 
441
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -500,7 +501,8 @@ class LSGBartEncoderAttention(BaseSelfAttention):
500
  keys /= mask + 1e-8
501
  values /= mask + 1e-8
502
 
503
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
504
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
505
 
506
  def lsh_round(self, keys, values, mask, output_size):
435
  keys = keys.sum(dim=-2) / (mask + 1e-6)
436
  values = values.sum(dim=-2) / (mask + 1e-6)
437
 
438
+ mask = (1. - mask.clamp(0, 1))
439
+ mask *= torch.finfo(mask.dtype).min
440
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
441
 
442
  def get_sparse_tokens_with_stride(self, keys, values, mask):
501
  keys /= mask + 1e-8
502
  values /= mask + 1e-8
503
 
504
+ mask = (1. - mask.clamp(0, 1))
505
+ mask *= torch.finfo(mask.dtype).min
506
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
507
 
508
  def lsh_round(self, keys, values, mask, output_size):