small fix with torch.finfo
Browse files — modeling_lsg_bart.py (changed): +4 −2
@@ -435,7 +435,8 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -500,7 +501,8 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

     def lsh_round(self, keys, values, mask, output_size):