Martijn van Beers
try to make it work quick and dirty
2e1a3f8
raw
history blame
No virus
1.91 kB
--- modeling_roberta.py 2022-06-28 11:59:19.974278244 +0200
+++ roberta2.py 2022-06-28 14:13:05.765050058 +0200
@@ -23,14 +23,14 @@
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from ...activations import ACT2FN, gelu
-from ...file_utils import (
+from transformers.activations import ACT2FN, gelu
+from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
-from ...modeling_outputs import (
+from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -40,14 +40,14 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
-from ...modeling_utils import (
+from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
-from ...utils import logging
-from .configuration_roberta import RobertaConfig
+from transformers.utils import logging
+from transformers.models.roberta.configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
@@ -183,6 +183,24 @@
self.is_decoder = config.is_decoder
+ def get_attn(self):
+ return self.attn
+
+ def save_attn(self, attn):
+ self.attn = attn
+
+ def save_attn_cam(self, cam):
+ self.attn_cam = cam
+
+ def get_attn_cam(self):
+ return self.attn_cam
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)