Spaces:
Runtime error
Runtime error
--- modeling_roberta.py 2022-06-28 11:59:19.974278244 +0200 | |
+++ roberta2.py 2022-06-28 14:13:05.765050058 +0200 | |
from torch import nn | |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
-from ...activations import ACT2FN, gelu | |
-from ...file_utils import ( | |
+from transformers.activations import ACT2FN, gelu | |
+from transformers.file_utils import ( | |
add_code_sample_docstrings, | |
add_start_docstrings, | |
add_start_docstrings_to_model_forward, | |
replace_return_docstrings, | |
) | |
-from ...modeling_outputs import ( | |
+from transformers.modeling_outputs import ( | |
BaseModelOutputWithPastAndCrossAttentions, | |
BaseModelOutputWithPoolingAndCrossAttentions, | |
CausalLMOutputWithCrossAttentions, | |
SequenceClassifierOutput, | |
TokenClassifierOutput, | |
) | |
-from ...modeling_utils import ( | |
+from transformers.modeling_utils import ( | |
PreTrainedModel, | |
apply_chunking_to_forward, | |
find_pruneable_heads_and_indices, | |
prune_linear_layer, | |
) | |
-from ...utils import logging | |
-from .configuration_roberta import RobertaConfig | |
+from transformers.utils import logging | |
+from transformers.models.roberta.configuration_roberta import RobertaConfig | |
logger = logging.get_logger(__name__) | |
self.is_decoder = config.is_decoder | |
+ def get_attn(self): | |
+ """Return the value most recently stored by save_attn (self.attn).""" | |
+ # NOTE(review): raises AttributeError if save_attn was never called and __init__ does not set self.attn -- confirm | |
+ return self.attn | |
+ | |
+ def save_attn(self, attn): | |
+ """Store attn on self.attn so it can be read back with get_attn.""" | |
+ # presumably the attention probabilities computed in forward -- TODO confirm (forward not visible in this hunk) | |
+ self.attn = attn | |
+ | |
+ def save_attn_cam(self, cam): | |
+ """Store cam on self.attn_cam so it can be read back with get_attn_cam.""" | |
+ self.attn_cam = cam | |
+ | |
+ def get_attn_cam(self): | |
+ """Return the value most recently stored by save_attn_cam (self.attn_cam).""" | |
+ # NOTE(review): raises AttributeError if save_attn_cam was never called and __init__ does not set self.attn_cam -- confirm | |
+ return self.attn_cam | |
+ | |
+ def save_attn_gradients(self, attn_gradients): | |
+ """Store attn_gradients on self.attn_gradients so it can be read back with get_attn_gradients.""" | |
+ # presumably registered as a backward hook on the attention tensor -- TODO confirm (caller not visible here) | |
+ self.attn_gradients = attn_gradients | |
+ | |
+ def get_attn_gradients(self): | |
+ """Return the value most recently stored by save_attn_gradients (self.attn_gradients).""" | |
+ # NOTE(review): raises AttributeError if save_attn_gradients was never called and __init__ does not set it -- confirm | |
+ return self.attn_gradients | |
+ | |
def transpose_for_scores(self, x): | |
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |
x = x.view(*new_x_shape) | |