replace 1e4 mask
Browse files- modeling_lsg_pegasus.py +5 -6
modeling_lsg_pegasus.py
CHANGED
@@ -3,7 +3,6 @@ import torch
|
|
3 |
from transformers.models.pegasus.modeling_pegasus import *
|
4 |
from transformers.models.pegasus.modeling_pegasus import _expand_mask
|
5 |
import torch.nn as nn
|
6 |
-
from torch.nn import BCEWithLogitsLoss
|
7 |
import sys
|
8 |
|
9 |
AUTO_MAP = {
|
@@ -265,7 +264,7 @@ class LSGAttentionProduct(nn.Module):
|
|
265 |
|
266 |
# Pad before block reshaping
|
267 |
if is_attn_mask:
|
268 |
-
pad_value = -1e4
|
269 |
hidden_states = hidden_states.transpose(-1, -2)
|
270 |
else:
|
271 |
pad_value = 0
|
@@ -294,7 +293,7 @@ class LSGAttentionProduct(nn.Module):
|
|
294 |
|
295 |
# Pad before block reshaping
|
296 |
if is_attn_mask:
|
297 |
-
pad_value = -1e4
|
298 |
hidden_states = hidden_states.transpose(-1, -2)
|
299 |
else:
|
300 |
pad_value = 0
|
@@ -423,7 +422,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
|
|
423 |
keys = keys.sum(dim=-2) / (mask + 1e-6)
|
424 |
values = values.sum(dim=-2) / (mask + 1e-6)
|
425 |
|
426 |
-
mask = (1. - mask.clamp(0, 1)) * -1e4
|
427 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
428 |
|
429 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
@@ -488,7 +487,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
|
|
488 |
keys /= mask + 1e-8
|
489 |
values /= mask + 1e-8
|
490 |
|
491 |
-
mask = (1. - mask.clamp(0, 1)) * -1e4
|
492 |
|
493 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
494 |
|
@@ -772,7 +771,7 @@ class LSGPegasusEncoder(LSGPegasusPreTrainedModel, PegasusEncoder):
|
|
772 |
n, t = inputs_.size()[:2]
|
773 |
|
774 |
if attention_mask is None:
|
775 |
-
attention_mask = torch.ones(n, t, device=inputs_.device)
|
776 |
if self.mask_first_token:
|
777 |
attention_mask[:,0] = 0
|
778 |
|
|
|
3 |
from transformers.models.pegasus.modeling_pegasus import *
|
4 |
from transformers.models.pegasus.modeling_pegasus import _expand_mask
|
5 |
import torch.nn as nn
|
|
|
6 |
import sys
|
7 |
|
8 |
AUTO_MAP = {
|
|
|
264 |
|
265 |
# Pad before block reshaping
|
266 |
if is_attn_mask:
|
267 |
+
pad_value = torch.finfo(hidden_states.dtype).min
|
268 |
hidden_states = hidden_states.transpose(-1, -2)
|
269 |
else:
|
270 |
pad_value = 0
|
|
|
293 |
|
294 |
# Pad before block reshaping
|
295 |
if is_attn_mask:
|
296 |
+
pad_value = torch.finfo(hidden_states.dtype).min
|
297 |
hidden_states = hidden_states.transpose(-1, -2)
|
298 |
else:
|
299 |
pad_value = 0
|
|
|
422 |
keys = keys.sum(dim=-2) / (mask + 1e-6)
|
423 |
values = values.sum(dim=-2) / (mask + 1e-6)
|
424 |
|
425 |
+
mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
|
426 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
427 |
|
428 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
|
|
487 |
keys /= mask + 1e-8
|
488 |
values /= mask + 1e-8
|
489 |
|
490 |
+
mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
|
491 |
|
492 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
493 |
|
|
|
771 |
n, t = inputs_.size()[:2]
|
772 |
|
773 |
if attention_mask is None:
|
774 |
+
attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
|
775 |
if self.mask_first_token:
|
776 |
attention_mask[:,0] = 0
|
777 |
|