howard-hou committed on
Commit
0b9ef29
1 Parent(s): eb51fc8

Upload RankingPrompter

config.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "_name_or_path": "/content/ICAA-compressor",
+   "architectures": [
+     "RankingPrompter"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_rankingprompter.RankingPrompterConfig",
+     "AutoModel": "modeling_rankingprompter.RankingPrompter"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 1024,
+   "d_kv": 64,
+   "d_model": 512,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "gelu_new",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "gated-gelu",
+   "id2label": {
+     "0": "LABEL_0"
+   },
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": true,
+   "label2id": {
+     "LABEL_0": 0
+   },
+   "layer_norm_epsilon": 1e-06,
+   "max_new_tokens": 64,
+   "model_type": "umt5",
+   "num_answer_query": 128,
+   "num_decoder_layers": 8,
+   "num_heads": 6,
+   "num_layers": 8,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "scalable_attention": true,
+   "tie_word_embeddings": false,
+   "tokenizer_class": "T5Tokenizer",
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.2",
+   "use_cache": true,
+   "vocab_size": 256384
+ }
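
Because config.json maps AutoConfig and AutoModel to the custom classes through auto_map, the checkpoint can be loaded with the Auto API and trust_remote_code=True. A minimal sketch, assuming the files in this commit live in a repo or local directory (the path below is a placeholder, not confirmed by the commit):

from transformers import AutoConfig, AutoModel, AutoTokenizer

path = "ICAA-compressor"  # placeholder: repo id or local folder containing these files
config = AutoConfig.from_pretrained(path, trust_remote_code=True)   # resolves to RankingPrompterConfig
model = AutoModel.from_pretrained(path, trust_remote_code=True)     # resolves to RankingPrompter
tokenizer = AutoTokenizer.from_pretrained(path)                     # tokenizer_class is T5Tokenizer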
configuration_rankingprompter.py ADDED
@@ -0,0 +1,84 @@
+ from transformers import PretrainedConfig
+
+ class RankingPrompterConfig(PretrainedConfig):
+     model_type = "umt5"
+
+     def __init__(
+         self,
+         vocab_size=250112,
+         d_model=512,
+         d_kv=64,
+         d_ff=1024,
+         num_layers=8,
+         num_decoder_layers=None,
+         num_heads=6,
+         relative_attention_num_buckets=32,
+         relative_attention_max_distance=128,
+         num_labels=1,
+         dropout_rate=0.1,
+         layer_norm_epsilon=1e-6,
+         initializer_factor=1.0,
+         feed_forward_proj="gated-gelu",
+         is_encoder_decoder=True,
+         use_cache=True,
+         tokenizer_class="T5Tokenizer",
+         tie_word_embeddings=True,
+         pad_token_id=0,
+         eos_token_id=1,
+         decoder_start_token_id=2,
+         classifier_dropout=0.1,
+         **kwargs,
+     ):
+         super().__init__(
+             is_encoder_decoder=is_encoder_decoder,
+             tokenizer_class=tokenizer_class,
+             tie_word_embeddings=tie_word_embeddings,
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             decoder_start_token_id=decoder_start_token_id,
+             **kwargs,
+         )
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.d_kv = d_kv
+         self.d_ff = d_ff
+         self.num_layers = num_layers
+         self.num_decoder_layers = (
+             num_decoder_layers if num_decoder_layers is not None else self.num_layers
+         )  # default = symmetry
+         self.num_heads = num_heads
+         self.relative_attention_num_buckets = relative_attention_num_buckets
+         self.relative_attention_max_distance = relative_attention_max_distance
+         self.num_labels = num_labels
+         self.dropout_rate = dropout_rate
+         self.classifier_dropout = classifier_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_factor = initializer_factor
+         self.feed_forward_proj = feed_forward_proj
+         self.use_cache = use_cache
+
+         act_info = self.feed_forward_proj.split("-")
+         self.dense_act_fn = act_info[-1]
+         self.is_gated_act = act_info[0] == "gated"
+
+         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+             raise ValueError(
+                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
+                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                 "'gated-gelu' or 'relu'"
+             )
+
+         if feed_forward_proj == "gated-gelu":
+             self.dense_act_fn = "gelu_new"
+
+     @property
+     def hidden_size(self):
+         return self.d_model
+
+     @property
+     def num_attention_heads(self):
+         return self.num_heads
+
+     @property
+     def num_hidden_layers(self):
+         return self.num_layers
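
The constructor splits feed_forward_proj on "-" to derive dense_act_fn and is_gated_act, which is why the uploaded config.json ends up with "dense_act_fn": "gelu_new" and "is_gated_act": true for "gated-gelu". A small sketch of that behavior, assuming the file is importable as configuration_rankingprompter:

from configuration_rankingprompter import RankingPrompterConfig

cfg = RankingPrompterConfig(feed_forward_proj="gated-gelu", num_layers=8, num_decoder_layers=None)
print(cfg.is_gated_act)        # True -> the gated (wi_0 / wi_1) feed-forward variant
print(cfg.dense_act_fn)        # "gelu_new" (special-cased for "gated-gelu")
print(cfg.num_decoder_layers)  # 8, defaults to num_layers when not given
print(cfg.hidden_size)         # 512, aliased to d_model via the property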
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3aa6984ac3f9219b5fae903a5d57fbf14065d5e1c1b77304b10cbeb179e6ce78
+ size 701360012
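
This is a Git LFS pointer rather than the weights themselves; the ~701 MB model.safetensors is fetched on `git lfs pull` or via the Hub client. A hedged sketch for checking a downloaded copy against the pointer's oid and size:

import hashlib, os

path = "model.safetensors"  # local copy after the LFS download
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
print(os.path.getsize(path) == 701360012)
print(h.hexdigest() == "3aa6984ac3f9219b5fae903a5d57fbf14065d5e1c1b77304b10cbeb179e6ce78")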
modeling_rankingprompter.py ADDED
@@ -0,0 +1,1723 @@
1
+ """ modified PyTorch UMT5 model. add save attention weights function so that we can compute grad-cam."""
2
+
3
+ import copy
4
+ import math
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutput,
14
+ BaseModelOutputWithPastAndCrossAttentions,
15
+ Seq2SeqModelOutput,
16
+ )
17
+ from transformers import PreTrainedModel, UMT5Config
18
+ from transformers.utils import (
19
+ DUMMY_INPUTS,
20
+ DUMMY_MASK,
21
+ add_start_docstrings,
22
+ add_start_docstrings_to_model_forward,
23
+ is_torch_fx_proxy,
24
+ logging,
25
+ replace_return_docstrings,
26
+ )
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ _CONFIG_FOR_DOC = "UMT5Config"
32
+ _CHECKPOINT_FOR_DOC = "google/umt5-small"
33
+
34
+
35
+ # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->UMT5
36
+ class UMT5LayerNorm(nn.Module):
37
+ def __init__(self, hidden_size, eps=1e-6):
38
+ """
39
+ Construct a layernorm module in the UMT5 style. No bias and no subtraction of mean.
40
+ """
41
+ super().__init__()
42
+ self.weight = nn.Parameter(torch.ones(hidden_size))
43
+ self.variance_epsilon = eps
44
+
45
+ def forward(self, hidden_states):
46
+ # UMT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
47
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
48
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
49
+ # half-precision inputs is done in fp32
50
+
51
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
52
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
53
+
54
+ # convert into half-precision if necessary
55
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
56
+ hidden_states = hidden_states.to(self.weight.dtype)
57
+
58
+ return self.weight * hidden_states
59
+
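
The comment above describes an RMS-style layer norm: no mean subtraction, no bias, and the variance accumulated in fp32 even for half-precision inputs. A tiny standalone sketch of the same arithmetic (output values are approximate):

import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float16)
weight = torch.ones(4, dtype=torch.float16)
eps = 1e-6

variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)  # mean of squares, computed in fp32
normed = x * torch.rsqrt(variance + eps)                      # divide by the root mean square
normed = normed.to(weight.dtype)                              # cast back for half-precision weights
print(weight * normed)  # roughly [0.365, 0.730, 1.095, 1.461]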
60
+
61
+ # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->UMT5
62
+ class UMT5DenseActDense(nn.Module):
63
+ def __init__(self, config: UMT5Config):
64
+ super().__init__()
65
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
66
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
67
+ self.dropout = nn.Dropout(config.dropout_rate)
68
+ self.act = ACT2FN[config.dense_act_fn]
69
+
70
+ def forward(self, hidden_states):
71
+ hidden_states = self.wi(hidden_states)
72
+ hidden_states = self.act(hidden_states)
73
+ hidden_states = self.dropout(hidden_states)
74
+ if (
75
+ isinstance(self.wo.weight, torch.Tensor)
76
+ and hidden_states.dtype != self.wo.weight.dtype
77
+ and self.wo.weight.dtype != torch.int8
78
+ ):
79
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
80
+ hidden_states = self.wo(hidden_states)
81
+ return hidden_states
82
+
83
+
84
+ # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->UMT5
85
+ class UMT5DenseGatedActDense(nn.Module):
86
+ def __init__(self, config: UMT5Config):
87
+ super().__init__()
88
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
89
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
90
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
91
+ self.dropout = nn.Dropout(config.dropout_rate)
92
+ self.act = ACT2FN[config.dense_act_fn]
93
+
94
+ def forward(self, hidden_states):
95
+ hidden_gelu = self.act(self.wi_0(hidden_states))
96
+ hidden_linear = self.wi_1(hidden_states)
97
+ hidden_states = hidden_gelu * hidden_linear
98
+ hidden_states = self.dropout(hidden_states)
99
+
100
+ # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
101
+ # See https://github.com/huggingface/transformers/issues/20287
102
+ # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
103
+ if (
104
+ isinstance(self.wo.weight, torch.Tensor)
105
+ and hidden_states.dtype != self.wo.weight.dtype
106
+ and self.wo.weight.dtype != torch.int8
107
+ ):
108
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
109
+
110
+ hidden_states = self.wo(hidden_states)
111
+ return hidden_states
112
+
113
+
114
+ # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->UMT5
115
+ class UMT5LayerFF(nn.Module):
116
+ def __init__(self, config: UMT5Config):
117
+ super().__init__()
118
+ if config.is_gated_act:
119
+ self.DenseReluDense = UMT5DenseGatedActDense(config)
120
+ else:
121
+ self.DenseReluDense = UMT5DenseActDense(config)
122
+
123
+ self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
124
+ self.dropout = nn.Dropout(config.dropout_rate)
125
+
126
+ def forward(self, hidden_states):
127
+ forwarded_states = self.layer_norm(hidden_states)
128
+ forwarded_states = self.DenseReluDense(forwarded_states)
129
+ hidden_states = hidden_states + self.dropout(forwarded_states)
130
+ return hidden_states
131
+
132
+
133
+ class UMT5Attention(nn.Module):
134
+ """
135
+ T5's attention using relative_attention_bias.
136
+ """
137
+
138
+ def __init__(self, config, has_relative_attention_bias=False):
139
+ super().__init__()
140
+ self.is_decoder = config.is_decoder
141
+ self.has_relative_attention_bias = has_relative_attention_bias
142
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
143
+ self.relative_attention_max_distance = config.relative_attention_max_distance
144
+ self.d_model = config.d_model
145
+ self.key_value_proj_dim = config.d_kv
146
+ self.n_heads = config.num_heads
147
+ self.dropout = config.dropout_rate
148
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
149
+
150
+ # Mesh TensorFlow initialization to avoid scaling before softmax
151
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
152
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
153
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
154
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
155
+
156
+ if self.has_relative_attention_bias:
157
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
158
+ self.pruned_heads = set()
159
+
160
+ # save attention weights
161
+ self.save_attention = False
162
+ self.attn_gradients = None
163
+ self.attention_map = None
164
+
165
+ def save_attn_gradients(self, attn_gradients):
166
+ self.attn_gradients = attn_gradients
167
+
168
+ def get_attn_gradients(self):
169
+ return self.attn_gradients
170
+
171
+ def save_attention_map(self, attention_map):
172
+ self.attention_map = attention_map
173
+
174
+ def get_attention_map(self):
175
+ return self.attention_map
176
+
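
These helpers, together with the `if self.save_attention:` branch in forward() below, are what the module docstring means by saving attention weights for Grad-CAM: the map is stored on the forward pass and its gradient is captured by the registered hook on the backward pass. A hedged usage sketch (`model` and `batch` are assumptions, not defined in this file):

# assumed: `model` is a RankingPrompter* instance built from this file, `batch` its usual inputs
attn = model.decoder.block[-1].layer[1].EncDecAttention  # last decoder block, cross-attention module
attn.save_attention = True                               # forward() will now stash attn_weights

loss = model(**batch).loss
loss.backward()                                          # the registered hook fills attn_gradients

attention_map = attn.get_attention_map()                 # (batch, n_heads, tgt_len, src_len)
attn_gradients = attn.get_attn_gradients()               # same shape; combine e.g. via (attn_gradients * attention_map).clamp(min=0)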
177
+ def _shape(self, projection: torch.Tensor) -> torch.Tensor:
178
+ new_projection_shape = projection.size()[:-1] + (self.n_heads, self.key_value_proj_dim)
179
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
180
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
181
+ return new_projection
182
+
183
+ def _relative_position_bucket(self, relative_position):
184
+ """
185
+ Adapted from Mesh Tensorflow:
186
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
187
+
188
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
189
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
190
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
191
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
192
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
193
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
194
+
195
+ Args:
196
+ relative_position: an int32 Tensor
197
+ bidirectional: a boolean - whether the attention is bidirectional
198
+ num_buckets: an integer
199
+ max_distance: an integer
200
+
201
+ Returns:
202
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
203
+ """
204
+ relative_buckets = 0
205
+ num_buckets = self.relative_attention_num_buckets
206
+ max_distance = self.relative_attention_max_distance
207
+ if not self.is_decoder:
208
+ num_buckets //= 2
209
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
210
+ relative_position = torch.abs(relative_position)
211
+ else:
212
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
213
+ # now relative_position is in the range [0, inf)
214
+
215
+ # half of the buckets are for exact increments in positions
216
+ max_exact = num_buckets // 2
217
+ is_small = relative_position < max_exact
218
+
219
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
220
+ log_ratio = torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact)
221
+ log_ratio = log_ratio * (num_buckets - max_exact)
222
+ relative_position_if_large = max_exact + log_ratio.to(torch.long)
223
+ relative_position_if_large = torch.min(
224
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
225
+ )
226
+
227
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
228
+ return relative_buckets
229
+
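
As the docstring explains, offsets smaller than max_exact each get their own bucket, larger offsets are binned logarithmically up to max_distance, and on the (bidirectional) encoder side the sign of the offset selects one half of the bucket range. A hedged sketch that inspects the buckets for a short sequence (importing from this module under its uploaded filename is an assumption):

import torch
from transformers import UMT5Config
from modeling_rankingprompter import UMT5Attention  # the class defined in this file

config = UMT5Config()          # defaults: 32 buckets, max_distance 128, is_decoder False
attn = UMT5Attention(config, has_relative_attention_bias=True)

pos = torch.arange(6)
relative_position = pos[None, :] - pos[:, None]      # memory_position - query_position
print(attn._relative_position_bucket(relative_position))
# small |offsets| map one-to-one to buckets; positive offsets land in the upper half of the range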
230
+ def compute_bias(self, query_length, key_length, device=None):
231
+ """Compute binned relative position bias"""
232
+ if device is None:
233
+ device = self.relative_attention_bias.weight.device
234
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
235
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
236
+ relative_position = memory_position - context_position # shape (query_length, key_length)
237
+ relative_position_bucket = self._relative_position_bucket(relative_position)
238
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
239
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
240
+ return values
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ encoder_hidden_states: Optional[torch.Tensor] = None,
246
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
247
+ attention_mask: Optional[torch.Tensor] = None,
248
+ layer_head_mask: Optional[torch.Tensor] = None,
249
+ ):
250
+ is_cross_attention = encoder_hidden_states is not None
251
+ batch_size, seq_length = hidden_states.shape[:2]
252
+
253
+ # use encoder_hidden_states if cross attention
254
+ current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
255
+ # checking that the `sequence_length` of the `past_key_value` is the same as the provided
256
+ # `encoder_hidden_states` to support prefix tuning
257
+ if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
258
+ # reuse k,v, cross_attentions
259
+ key_states = past_key_value[0]
260
+ value_states = past_key_value[1]
261
+ else:
262
+ key_states = self._shape(self.k(current_states))
263
+ value_states = self._shape(self.v(current_states))
264
+ if past_key_value is not None and not is_cross_attention:
265
+ # reuse k, v, self_attention
266
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
267
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
268
+
269
+ query_states = self._shape(self.q(hidden_states))
270
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
271
+
272
+ # compute positional bias
273
+ if self.has_relative_attention_bias:
274
+ query_length = seq_length
275
+ if past_key_value is not None:
276
+ query_length += past_key_value[0].shape[2]
277
+ position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device)
278
+ else:
279
+ position_bias = torch.zeros(
280
+ (1, self.n_heads, seq_length, key_states.size(2)),
281
+ device=attention_scores.device,
282
+ dtype=attention_scores.dtype,
283
+ requires_grad=self.training,
284
+ )
285
+ if past_key_value is not None:
286
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
287
+ if attention_mask is not None:
288
+ position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length)
289
+
290
+ if self.is_decoder:
291
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
292
+ # Further calls to cross_attention layer can then reuse all cross-attention
293
+ # key/value_states (first "if" case)
294
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
295
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
296
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
297
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
298
+ past_key_value = (key_states, value_states)
299
+
300
+ attention_scores += position_bias
301
+ # (batch_size, n_heads, seq_length, key_length)
302
+ attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores)
303
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
304
+
305
+ # Mask heads if we want to
306
+ if layer_head_mask is not None:
307
+ attn_weights = attn_weights * layer_head_mask
308
+
309
+ # save attention weights
310
+ if self.save_attention:
311
+ self.save_attention_map(attn_weights)
312
+ attn_weights.register_hook(self.save_attn_gradients)
313
+
314
+ # attn_output = torch.bmm(attn_probs, value_states) ?
315
+ context_states = torch.matmul(attn_weights, value_states)
316
+ # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ?
317
+ context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
318
+ attn_output = self.o(context_states)
319
+ return attn_output, attn_weights, past_key_value
320
+
321
+
322
+ class UMT5LayerSelfAttention(nn.Module):
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True)
326
+ self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
327
+ self.dropout = nn.Dropout(config.dropout_rate)
328
+
329
+ def forward(
330
+ self,
331
+ hidden_states,
332
+ attention_mask=None,
333
+ layer_head_mask=None,
334
+ past_key_value=None,
335
+ ):
336
+ normed_hidden_states = self.layer_norm(hidden_states)
337
+ attention_output = self.SelfAttention(
338
+ normed_hidden_states,
339
+ attention_mask=attention_mask,
340
+ layer_head_mask=layer_head_mask,
341
+ past_key_value=past_key_value,
342
+ )
343
+ hidden_states = hidden_states + self.dropout(attention_output[0])
344
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
345
+ return outputs
346
+
347
+
348
+ class UMT5LayerCrossAttention(nn.Module):
349
+ def __init__(self, config):
350
+ super().__init__()
351
+ self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False)
352
+ self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
353
+ self.dropout = nn.Dropout(config.dropout_rate)
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states,
358
+ encoder_hidden_states=None,
359
+ attention_mask=None,
360
+ layer_head_mask=None,
361
+ past_key_value=None,
362
+ ):
363
+ normed_hidden_states = self.layer_norm(hidden_states)
364
+ attention_output = self.EncDecAttention(
365
+ normed_hidden_states,
366
+ encoder_hidden_states=encoder_hidden_states,
367
+ attention_mask=attention_mask,
368
+ layer_head_mask=layer_head_mask,
369
+ past_key_value=past_key_value,
370
+ )
371
+ layer_output = hidden_states + self.dropout(attention_output[0])
372
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
373
+ return outputs
374
+
375
+
376
+ class UMT5Block(nn.Module):
377
+ def __init__(self, config):
378
+ super().__init__()
379
+ self.is_decoder = config.is_decoder
380
+ self.layer = nn.ModuleList()
381
+ self.layer.append(UMT5LayerSelfAttention(config))
382
+ if self.is_decoder:
383
+ self.layer.append(UMT5LayerCrossAttention(config))
384
+
385
+ self.layer.append(UMT5LayerFF(config))
386
+
387
+ def forward(
388
+ self,
389
+ hidden_states,
390
+ attention_mask=None,
391
+ encoder_hidden_states=None,
392
+ encoder_attention_mask=None,
393
+ layer_head_mask=None,
394
+ cross_attn_layer_head_mask=None,
395
+ past_key_value=None,
396
+ use_cache=False,
397
+ output_attentions=False,
398
+ ):
399
+ # Self Attention
400
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
401
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
402
+
403
+ hidden_states, self_attn_weights, present_key_value = self.layer[0](
404
+ hidden_states,
405
+ attention_mask=attention_mask,
406
+ layer_head_mask=layer_head_mask,
407
+ past_key_value=self_attn_past_key_value,
408
+ )
409
+
410
+ # clamp inf values to enable fp16 training
411
+ if hidden_states.dtype == torch.float16:
412
+ max_dtype = torch.finfo(hidden_states.dtype).max
413
+ clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
414
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
415
+
416
+ # Cross-Attention Block
417
+ cross_attn_present_key_value = None
418
+ cross_attn_weights = None
419
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
420
+ if do_cross_attention:
421
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
422
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
423
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1](
424
+ hidden_states,
425
+ encoder_hidden_states=encoder_hidden_states,
426
+ attention_mask=encoder_attention_mask,
427
+ layer_head_mask=cross_attn_layer_head_mask,
428
+ past_key_value=cross_attn_past_key_value,
429
+ )
430
+ # clamp inf values to enable fp16 training
431
+ if hidden_states.dtype == torch.float16:
432
+ max_dtype = torch.finfo(hidden_states.dtype).max
433
+ clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
434
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
435
+
436
+ present_key_value += cross_attn_present_key_value
437
+
438
+ # Apply Feed Forward layer
439
+ hidden_states = self.layer[-1](hidden_states)
440
+
441
+ # clamp inf values to enable fp16 training
442
+ if hidden_states.dtype == torch.float16:
443
+ max_dtype = torch.finfo(hidden_states.dtype).max
444
+ clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
445
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
446
+
447
+ outputs = (
448
+ hidden_states,
449
+ present_key_value,
450
+ )
451
+
452
+ if output_attentions:
453
+ outputs += (self_attn_weights, cross_attn_weights)
454
+
455
+ return outputs
456
+
457
+
458
+ # Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->UMT5
459
+ class UMT5ClassificationHead(nn.Module):
460
+ """Head for sentence-level classification tasks."""
461
+
462
+ def __init__(self, config: UMT5Config):
463
+ super().__init__()
464
+ self.dense = nn.Linear(config.d_model, config.d_model)
465
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
466
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
467
+
468
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
469
+ hidden_states = self.dropout(hidden_states)
470
+ hidden_states = self.dense(hidden_states)
471
+ hidden_states = torch.tanh(hidden_states)
472
+ hidden_states = self.dropout(hidden_states)
473
+ hidden_states = self.out_proj(hidden_states)
474
+ return hidden_states
475
+
476
+
477
+ class UMT5PreTrainedModel(PreTrainedModel):
478
+ """
479
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
480
+ models.
481
+ """
482
+
483
+ config_class = UMT5Config
484
+ base_model_prefix = "transformer"
485
+ supports_gradient_checkpointing = True
486
+ _no_split_modules = ["UMT5Block"]
487
+ _keep_in_fp32_modules = ["wo"]
488
+
489
+ @property
490
+ def dummy_inputs(self):
491
+ input_ids = torch.tensor(DUMMY_INPUTS)
492
+ input_mask = torch.tensor(DUMMY_MASK)
493
+ dummy_inputs = {
494
+ "decoder_input_ids": input_ids,
495
+ "input_ids": input_ids,
496
+ "decoder_attention_mask": input_mask,
497
+ }
498
+ return dummy_inputs
499
+
500
+ def _init_weights(self, module):
501
+ """Initialize the weights"""
502
+ factor = self.config.initializer_factor # Used for testing weights initialization
503
+ if isinstance(module, UMT5LayerNorm):
504
+ module.weight.data.fill_(factor * 1.0)
505
+ elif isinstance(
506
+ module,
507
+ (
508
+ UMT5Model,
509
+ ),
510
+ ):
511
+ # Mesh TensorFlow embeddings initialization
512
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
513
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
514
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
515
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
516
+ if hasattr(module, "qa_outputs"):
517
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
518
+ module.qa_outputs.bias.data.zero_()
519
+ elif isinstance(module, UMT5ClassificationHead):
520
+ module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
521
+ if hasattr(module.dense, "bias") and module.dense.bias is not None:
522
+ module.dense.bias.data.zero_()
523
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
524
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
525
+ module.out_proj.bias.data.zero_()
526
+ elif isinstance(module, UMT5DenseActDense):
527
+ # Mesh TensorFlow FF initialization
528
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
529
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
530
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
531
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
532
+ module.wi.bias.data.zero_()
533
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
534
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
535
+ module.wo.bias.data.zero_()
536
+ elif isinstance(module, UMT5DenseGatedActDense):
537
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
538
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
539
+ module.wi_0.bias.data.zero_()
540
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
541
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
542
+ module.wi_1.bias.data.zero_()
543
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
544
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
545
+ module.wo.bias.data.zero_()
546
+ elif isinstance(module, UMT5Attention):
547
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
548
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
549
+ d_model = self.config.d_model
550
+ key_value_proj_dim = self.config.d_kv
551
+ n_heads = self.config.num_heads
552
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
553
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
554
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
555
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
556
+ if module.has_relative_attention_bias:
557
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
558
+
559
+ def _set_gradient_checkpointing(self, module, value=False):
560
+ if isinstance(module, (UMT5Attention, UMT5Stack)):
561
+ module.gradient_checkpointing = value
562
+
563
+ def _shift_right(self, input_ids):
564
+ decoder_start_token_id = self.config.decoder_start_token_id
565
+ pad_token_id = self.config.pad_token_id
566
+
567
+ if decoder_start_token_id is None:
568
+ raise ValueError(
569
+ "self.model.config.decoder_start_token_id has to be defined. In UMT5 it is usually set to the pad_token_id."
570
+ "See UMT5 docs for more information."
571
+ )
572
+
573
+ # shift inputs to the right
574
+ if is_torch_fx_proxy(input_ids):
575
+ # Item assignment is not supported natively for proxies.
576
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
577
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
578
+ else:
579
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
580
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
581
+ shifted_input_ids[..., 0] = decoder_start_token_id
582
+
583
+ if pad_token_id is None:
584
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
585
+ # replace possible -100 values in labels by `pad_token_id`
586
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
587
+
588
+ return shifted_input_ids
589
+
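
During training, _shift_right is what turns labels into decoder_input_ids: it prepends decoder_start_token_id and maps any -100 (ignored) label positions back to pad_token_id. A tiny worked example using the values from the uploaded config.json (decoder_start_token_id = 0, pad_token_id = 0):

import torch

labels = torch.tensor([[843, 17, 1, -100]])   # arbitrary token ids, eos = 1, -100 = ignore index
decoder_start_token_id, pad_token_id = 0, 0

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted)  # tensor([[  0, 843,  17,   1]])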
590
+
591
+ class UMT5Stack(UMT5PreTrainedModel):
592
+ def __init__(self, config, embed_tokens=None):
593
+ super().__init__(config)
594
+ self.embed_tokens = embed_tokens
595
+ self.is_decoder = config.is_decoder
596
+ self.block = nn.ModuleList([UMT5Block(config) for i in range(config.num_layers)])
597
+ self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
598
+ self.dropout = nn.Dropout(config.dropout_rate)
599
+
600
+ # Initialize weights and apply final processing
601
+ self.gradient_checkpointing = False
602
+ self.post_init()
603
+
604
+ def get_input_embeddings(self):
605
+ return self.embed_tokens
606
+
607
+ def set_input_embeddings(self, new_embeddings):
608
+ self.embed_tokens = new_embeddings
609
+
610
+ def forward(
611
+ self,
612
+ input_ids=None,
613
+ attention_mask=None,
614
+ encoder_hidden_states=None,
615
+ encoder_attention_mask=None,
616
+ inputs_embeds=None,
617
+ head_mask=None,
618
+ cross_attn_head_mask=None,
619
+ past_key_values=None,
620
+ use_cache=None,
621
+ output_attentions=None,
622
+ output_hidden_states=None,
623
+ return_dict=None,
624
+ ):
625
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
626
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
627
+ output_hidden_states = (
628
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
629
+ )
630
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
631
+
632
+ if input_ids is not None and inputs_embeds is not None:
633
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
634
+ raise ValueError(
635
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
636
+ )
637
+ elif input_ids is not None:
638
+ input_shape = input_ids.size()
639
+ input_ids = input_ids.view(-1, input_shape[-1])
640
+ elif inputs_embeds is not None:
641
+ input_shape = inputs_embeds.size()[:-1]
642
+ else:
643
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
644
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
645
+
646
+ if inputs_embeds is None:
647
+ if self.embed_tokens is None:
648
+ raise ValueError("You have to initialize the model with valid token embeddings")
649
+ inputs_embeds = self.embed_tokens(input_ids)
650
+
651
+ batch_size, seq_length = input_shape
652
+
653
+ # required mask seq length can be calculated via length of past
654
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
655
+
656
+ if use_cache is True:
657
+ if not self.is_decoder:
658
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
659
+
660
+ if attention_mask is None:
661
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
662
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
663
+ encoder_seq_length = encoder_hidden_states.shape[1]
664
+ encoder_attention_mask = torch.ones(
665
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
666
+ )
667
+
668
+ # initialize past_key_values with `None` if past does not exist
669
+ if past_key_values is None:
670
+ past_key_values = [None] * len(self.block)
671
+
672
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
673
+ # ourselves in which case we just need to make it broadcastable to all heads.
674
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
675
+
676
+ # If a 2D or 3D attention mask is provided for the cross-attention
677
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
678
+ if self.is_decoder and encoder_hidden_states is not None:
679
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
680
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
681
+ if encoder_attention_mask is None:
682
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
683
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
684
+ else:
685
+ encoder_extended_attention_mask = None
686
+
687
+ if self.gradient_checkpointing and self.training:
688
+ if use_cache:
689
+ logger.warning_once(
690
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
691
+ )
692
+ use_cache = False
693
+
694
+ # Prepare head mask if needed
695
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
696
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
697
+ present_key_value_states = () if use_cache else None
698
+ all_hidden_states = () if output_hidden_states else None
699
+ all_attentions = () if output_attentions else None
700
+ all_cross_attentions = () if output_attentions and self.is_decoder else None
701
+
702
+ hidden_states = self.dropout(inputs_embeds)
703
+
704
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
705
+ layer_head_mask = head_mask[i]
706
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
707
+
708
+ if output_hidden_states:
709
+ all_hidden_states = all_hidden_states + (hidden_states,)
710
+
711
+ if self.gradient_checkpointing and self.training:
712
+
713
+ def create_custom_forward(module):
714
+ def custom_forward(*inputs):
715
+ return tuple(module(*inputs, use_cache, output_attentions))
716
+
717
+ return custom_forward
718
+
719
+ layer_outputs = checkpoint(
720
+ create_custom_forward(layer_module),
721
+ hidden_states,
722
+ extended_attention_mask,
723
+ encoder_hidden_states,
724
+ encoder_extended_attention_mask,
725
+ layer_head_mask,
726
+ cross_attn_layer_head_mask,
727
+ None, # past_key_value is always None with gradient checkpointing
728
+ )
729
+ else:
730
+ layer_outputs = layer_module(
731
+ hidden_states,
732
+ attention_mask=extended_attention_mask,
733
+ encoder_hidden_states=encoder_hidden_states,
734
+ encoder_attention_mask=encoder_extended_attention_mask,
735
+ layer_head_mask=layer_head_mask,
736
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
737
+ past_key_value=past_key_value,
738
+ use_cache=use_cache,
739
+ output_attentions=output_attentions,
740
+ )
741
+
742
+ hidden_states = layer_outputs[0]
743
+
744
+ if use_cache:
745
+ present_key_value_states += (layer_outputs[1],)
746
+
747
+ if output_attentions:
748
+ all_attentions += (layer_outputs[2],)
749
+ if self.is_decoder:
750
+ all_cross_attentions += (layer_outputs[3],)
751
+
752
+ hidden_states = self.final_layer_norm(hidden_states)
753
+ hidden_states = self.dropout(hidden_states)
754
+
755
+ # Add last layer
756
+ if output_hidden_states:
757
+ all_hidden_states = all_hidden_states + (hidden_states,)
758
+
759
+ if not return_dict:
760
+ return tuple(
761
+ v
762
+ for v in [
763
+ hidden_states,
764
+ present_key_value_states,
765
+ all_hidden_states,
766
+ all_attentions,
767
+ all_cross_attentions,
768
+ ]
769
+ if v is not None
770
+ )
771
+ return BaseModelOutputWithPastAndCrossAttentions(
772
+ last_hidden_state=hidden_states,
773
+ past_key_values=present_key_value_states,
774
+ hidden_states=all_hidden_states,
775
+ attentions=all_attentions,
776
+ cross_attentions=all_cross_attentions,
777
+ )
778
+
779
+
780
+ UMT5_START_DOCSTRING = r"""
781
+
782
+ The UMT5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
783
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
784
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
785
+ text-to-text denoising generative setting.
786
+
787
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
788
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
789
+ etc.)
790
+
791
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
792
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
793
+ and behavior.
794
+
795
+ Parameters:
796
+ config ([`UMT5Config`]): Model configuration class with all the parameters of the model.
797
+ Initializing with a config file does not load the weights associated with the model, only the
798
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
799
+ """
800
+
801
+ UMT5_INPUTS_DOCSTRING = r"""
802
+ Args:
803
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
804
+ Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
805
+ you should be able to pad the inputs on both the right and the left.
806
+
807
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
808
+ [`PreTrainedTokenizer.__call__`] for detail.
809
+
810
+ [What are input IDs?](../glossary#input-ids)
811
+
812
+ To know more on how to prepare `input_ids` for pretraining, take a look at [UMT5 Training](./umt5#training).
813
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
814
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
815
+
816
+ - 1 for tokens that are **not masked**,
817
+ - 0 for tokens that are **masked**.
818
+
819
+ [What are attention masks?](../glossary#attention-mask)
820
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
821
+ Indices of decoder input sequence tokens in the vocabulary.
822
+
823
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
824
+ [`PreTrainedTokenizer.__call__`] for details.
825
+
826
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
827
+
828
+ UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
829
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
830
+
831
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
832
+ Training](./umt5#training).
833
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
834
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
835
+ be used by default.
836
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
837
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
838
+ 1]`:
839
+
840
+ - 1 indicates the head is **not masked**,
841
+ - 0 indicates the head is **masked**.
842
+
843
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
844
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
845
+ 1]`:
846
+
847
+ - 1 indicates the head is **not masked**,
848
+ - 0 indicates the head is **masked**.
849
+
850
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
851
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
852
+ `[0, 1]`:
853
+
854
+ - 1 indicates the head is **not masked**,
855
+ - 0 indicates the head is **masked**.
856
+
857
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
858
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
859
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
860
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
861
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
862
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
863
+
864
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
865
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
866
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
867
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
868
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
869
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
870
+ model's internal embedding lookup matrix.
871
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
872
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
873
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
874
+ input (see `past_key_values`). This is useful if you want more control over how to convert
875
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
876
+
877
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
878
+ of `inputs_embeds`.
879
+
880
+ use_cache (`bool`, *optional*):
881
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
882
+ `past_key_values`).
883
+
884
+ output_attentions (`bool`, *optional*):
885
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
886
+ tensors for more detail.
887
+ output_hidden_states (`bool`, *optional*):
888
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
889
+ more detail.
890
+ return_dict (`bool`, *optional*):
891
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
892
+ """
893
+
894
+ UMT5_ENCODER_INPUTS_DOCSTRING = r"""
895
+ Args:
896
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
897
+ Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
898
+ you should be able to pad the inputs on both the right and the left.
899
+
900
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
901
+ [`PreTrainedTokenizer.__call__`] for detail.
902
+
903
+ To know more on how to prepare `input_ids` for pretraining, take a look at [UMT5 Training](./umt5#training).
904
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
905
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
906
+
907
+ - 1 for tokens that are **not masked**,
908
+ - 0 for tokens that are **masked**.
909
+
910
+ [What are attention masks?](../glossary#attention-mask)
911
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
912
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
913
+
914
+ - 1 indicates the head is **not masked**,
915
+ - 0 indicates the head is **masked**.
916
+
917
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
918
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
919
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
920
+ model's internal embedding lookup matrix.
921
+ output_attentions (`bool`, *optional*):
922
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
923
+ tensors for more detail.
924
+ output_hidden_states (`bool`, *optional*):
925
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
926
+ more detail.
927
+ return_dict (`bool`, *optional*):
928
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
929
+ """
930
+
931
+
932
+ @add_start_docstrings(
933
+ "The bare UMT5 Model transformer outputting raw hidden-states without any specific head on top.",
934
+ UMT5_START_DOCSTRING,
935
+ )
936
+ class UMT5Model(UMT5PreTrainedModel):
937
+ r"""
938
+ Examples:
939
+
940
+ ```python
941
+ >>> from transformers import UMT5Model, AutoTokenizer
942
+
943
+ >>> model = UMT5Model.from_pretrained("google/umt5-small")
944
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
945
+ >>> noisy_text = "UN Offizier sagt, dass weiter <extra_id_0> werden muss in Syrien."
946
+ >>> label = "<extra_id_0> verhandelt"
947
+ >>> inputs = tokenizer(noisy_text, return_tensors="pt")
948
+ >>> labels = tokenizer(text_target=label, return_tensors="pt")
949
+
950
+ >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
951
+ >>> hidden_states = outputs.last_hidden_state
952
+ ```"""
953
+ model_type = "umt5"
954
+ config_class = UMT5Config
955
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
956
+
957
+ def __init__(self, config):
958
+ super().__init__(config)
959
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
960
+
961
+ encoder_config = copy.deepcopy(config)
962
+ encoder_config.is_decoder = False
963
+ encoder_config.use_cache = False
964
+ encoder_config.is_encoder_decoder = False
965
+ self.encoder = UMT5Stack(encoder_config, self.shared)
966
+
967
+ decoder_config = copy.deepcopy(config)
968
+ decoder_config.is_decoder = True
969
+ decoder_config.is_encoder_decoder = False
970
+ decoder_config.num_layers = config.num_decoder_layers
971
+ self.decoder = UMT5Stack(decoder_config, self.shared)
972
+
973
+ # Initialize weights and apply final processing
974
+ self.post_init()
975
+
976
+ # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
977
+ def get_input_embeddings(self):
978
+ return self.shared
979
+
980
+ # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
981
+ def set_input_embeddings(self, new_embeddings):
982
+ self.shared = new_embeddings
983
+ self.encoder.set_input_embeddings(new_embeddings)
984
+ self.decoder.set_input_embeddings(new_embeddings)
985
+
986
+ # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
987
+ def get_encoder(self):
988
+ return self.encoder
989
+
990
+ # Copied from transformers.models.t5.modeling_t5.T5Model.get_decoder
991
+ def get_decoder(self):
992
+ return self.decoder
993
+
994
+ # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads
995
+ def _prune_heads(self, heads_to_prune):
996
+ """
997
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
998
+ class PreTrainedModel
999
+ """
1000
+ for layer, heads in heads_to_prune.items():
1001
+ self.encoder.layer[layer].attention.prune_heads(heads)
1002
+
1003
+ @add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING)
1004
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1005
+ def forward(
1006
+ self,
1007
+ input_ids: Optional[torch.LongTensor] = None,
1008
+ attention_mask: Optional[torch.FloatTensor] = None,
1009
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1010
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1011
+ head_mask: Optional[torch.FloatTensor] = None,
1012
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1013
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1014
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1015
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1016
+ inputs_embeds: Optional[torch.Tensor] = None,
1017
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1018
+ use_cache: Optional[bool] = None,
1019
+ output_attentions: Optional[bool] = None,
1020
+ output_hidden_states: Optional[bool] = None,
1021
+ return_dict: Optional[bool] = None,
1022
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1023
+ r"""
1024
+ Returns:
1025
+
1026
+ Example:
1027
+
1028
+ ```python
1029
+ >>> from transformers import AutoTokenizer, UMT5Model
1030
+
1031
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
1032
+ >>> model = UMT5Model.from_pretrained("google/umt5-small")
1033
+
1034
+ >>> input_ids = tokenizer(
1035
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1036
+ ... ).input_ids # Batch size 1
1037
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1038
+
1039
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for UMT5Model.
1040
+ >>> # This is not needed for torch's UMT5ForConditionalGeneration as it does this internally using labels arg.
1041
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1042
+
1043
+ >>> # forward pass
1044
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1045
+ >>> last_hidden_states = outputs.last_hidden_state
1046
+ ```"""
1047
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1048
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1049
+
1050
+ # Encode if needed (training, first prediction pass)
1051
+ if encoder_outputs is None:
1052
+ encoder_outputs = self.encoder(
1053
+ input_ids=input_ids,
1054
+ attention_mask=attention_mask,
1055
+ inputs_embeds=inputs_embeds,
1056
+ head_mask=head_mask,
1057
+ output_attentions=output_attentions,
1058
+ output_hidden_states=output_hidden_states,
1059
+ return_dict=return_dict,
1060
+ )
1061
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1062
+ encoder_outputs = BaseModelOutput(
1063
+ last_hidden_state=encoder_outputs[0],
1064
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1065
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1066
+ )
1067
+
1068
+ hidden_states = encoder_outputs[0]
1069
+
1070
+ # Decode
1071
+ decoder_outputs = self.decoder(
1072
+ input_ids=decoder_input_ids,
1073
+ attention_mask=decoder_attention_mask,
1074
+ inputs_embeds=decoder_inputs_embeds,
1075
+ past_key_values=past_key_values,
1076
+ encoder_hidden_states=hidden_states,
1077
+ encoder_attention_mask=attention_mask,
1078
+ head_mask=decoder_head_mask,
1079
+ cross_attn_head_mask=cross_attn_head_mask,
1080
+ use_cache=use_cache,
1081
+ output_attentions=output_attentions,
1082
+ output_hidden_states=output_hidden_states,
1083
+ return_dict=return_dict,
1084
+ )
1085
+
1086
+ if not return_dict:
1087
+ return decoder_outputs + encoder_outputs
1088
+
1089
+ return Seq2SeqModelOutput(
1090
+ last_hidden_state=decoder_outputs.last_hidden_state,
1091
+ past_key_values=decoder_outputs.past_key_values,
1092
+ decoder_hidden_states=decoder_outputs.hidden_states,
1093
+ decoder_attentions=decoder_outputs.attentions,
1094
+ cross_attentions=decoder_outputs.cross_attentions,
1095
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1096
+ encoder_hidden_states=encoder_outputs.hidden_states,
1097
+ encoder_attentions=encoder_outputs.attentions,
1098
+ )
1099
+
1100
+
1101
+ # start of ranking prompter code
1102
+
1103
+
1104
+ from contextlib import nullcontext
1105
+ from dataclasses import dataclass
1106
+ from typing import Optional, Tuple, Union
1107
+
1108
+ import torch
1109
+ from torch import nn
1110
+ from torch.nn import CrossEntropyLoss
1111
+ from .configuration_rankingprompter import RankingPrompterConfig
1112
+
1113
+
1114
+ @dataclass
1115
+ class RankingPrompterForPreTrainingOutput:
1116
+ loss: torch.FloatTensor = None
1117
+ logits: torch.FloatTensor = None
1118
+
1119
+
1120
+ @dataclass
1121
+ class RankingPrompterOutput:
1122
+ loss: torch.FloatTensor = None
1123
+ logits: torch.FloatTensor = None
1124
+ lm_logits: torch.FloatTensor = None
1125
+ loss_lm: torch.FloatTensor = None
1126
+ loss_ranking: torch.FloatTensor = None
1127
+
1128
+
1129
+
1130
+ class RankingPrompterForPreTraining(UMT5Model):
1131
+ config_class = RankingPrompterConfig
1132
+
1133
+ _tied_weights_keys = [
1134
+ "encoder.embed_tokens.weight",
1135
+ "decoder.embed_tokens.weight",
1136
+ ]
1137
+
1138
+ def __init__(self, config):
1139
+ # encoder, decoder and shared are from UMT5Model
1140
+ super().__init__(config)
1141
+
1142
+ # add ranking head
1143
+ self.ranking_head = nn.Linear(config.d_model, 1)
1144
+
1145
+ # Initialize weights and apply final processing
1146
+ self.post_init()
1147
+
1148
+ # ctx for mixed precision training
1149
+ self.ctx = nullcontext()
1150
+
1151
+ def enable_amp_ctx(self, device_type="cuda", dtype=torch.bfloat16):
1152
+ self.ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
1153
+
1154
+ def disable_amp_ctx(self):
1155
+ self.ctx = nullcontext()
1156
+
1157
+ def forward(
1158
+ self,
1159
+ document_input_ids: Optional[torch.LongTensor] = None,
1160
+ document_attention_mask: Optional[torch.FloatTensor] = None,
1161
+ question_input_ids: Optional[torch.LongTensor] = None,
1162
+ question_attention_mask: Optional[torch.BoolTensor] = None,
1163
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1164
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1165
+ labels: Optional[torch.LongTensor] = None,
1166
+ use_cache: Optional[bool] = None,
1167
+ return_dict: Optional[bool] = None,
1168
+ ) -> Union[Tuple[torch.FloatTensor], RankingPrompterForPreTrainingOutput]:
1169
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Indices of the relevant document for each question, in `[0, ..., num_doc - 1]`. Labels set to `-100`
+ are ignored (masked); the ranking loss is a cross-entropy over the `num_doc` candidate documents.
+
+ Returns:
+ `RankingPrompterForPreTrainingOutput` with the ranking loss (if `labels` is given) and the
+ per-document ranking logits of shape `(batch_size, num_doc)`.
+ """
1178
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1179
+ return_dict = (
1180
+ return_dict if return_dict is not None else self.config.use_return_dict
1181
+ )
1182
+ # document_input_ids: [batch_size, num_doc, doc_seq_len]
+ batch_size, num_doc, doc_seq_len = document_input_ids.shape
+ # flatten documents to [batch_size * num_doc, doc_seq_len]
+ document_input_ids = document_input_ids.view(-1, doc_seq_len)
+ document_attention_mask = document_attention_mask.view(-1, doc_seq_len)
+
+ # encode all documents in one pass
1190
+ with self.ctx:
1191
+ encoder_outputs = self.encoder(
1192
+ input_ids=document_input_ids,
1193
+ attention_mask=document_attention_mask,
1194
+ return_dict=return_dict,
1195
+ )
1196
+
1197
+ document_embeds = encoder_outputs[0]
1198
+
1199
+ # repeat question inputs for each document
1200
+ # question_input_ids: [batch_size, question_seq_len]
1201
+ question_seq_len = question_input_ids.shape[1]
1202
+ question_input_ids_expand = (
1203
+ question_input_ids.unsqueeze(1)
1204
+ .expand(-1, num_doc, -1)
1205
+ .reshape(-1, question_seq_len)
1206
+ ) # [batch_size * num_doc, question_seq_len]
1207
+ question_attention_mask_expand = (
1208
+ question_attention_mask.unsqueeze(1)
1209
+ .expand(-1, num_doc, -1)
1210
+ .reshape(-1, question_seq_len)
1211
+ ) # [batch_size * num_doc, question_seq_len]
1212
+
1213
+ # Decode
1214
+ with self.ctx:
1215
+ decoder_outputs = self.decoder(
1216
+ input_ids=question_input_ids_expand,
1217
+ attention_mask=question_attention_mask_expand,
1218
+ past_key_values=past_key_values,
1219
+ encoder_hidden_states=document_embeds,
1220
+ encoder_attention_mask=document_attention_mask,
1221
+ use_cache=use_cache,
1222
+ return_dict=return_dict,
1223
+ )
1224
+ # decoder hidden states: [batch_size * num_doc, question_seq_len, hidden_size]
+ sequence_output = decoder_outputs[0]
+ question_seq_len = sequence_output.size(1)
+ # reshape to [batch_size, num_doc, question_seq_len, hidden_size]
+ soft_prompt_output = sequence_output.view(
+ batch_size, num_doc, question_seq_len, -1
+ )
+ question_attention_mask_expand = question_attention_mask_expand.view(
+ batch_size, num_doc, question_seq_len
+ )
+ # zero out padded question positions
+ soft_prompt_output = soft_prompt_output * question_attention_mask_expand.unsqueeze(-1)
+
+ # pool over question tokens and score each document: [batch_size, num_doc]
+ ranking_logits = self.ranking_head(soft_prompt_output.mean(dim=2)).view(batch_size, num_doc)
1240
+
1241
+ # rank loss
1242
+ loss = None
1243
+ if labels is not None:
1244
+ loss_fct = CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)
1245
+ loss = loss_fct(ranking_logits, labels)
1246
+
1247
+ if not return_dict:
1248
+ output = (ranking_logits,) + decoder_outputs[1:] + encoder_outputs
1249
+ return ((loss,) + output) if loss is not None else output
1250
+
1251
+ return RankingPrompterForPreTrainingOutput(
1252
+ loss=loss,
1253
+ logits=ranking_logits
1254
+ )
1255
+
1256
+
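+ # --- Editorial usage sketch (not part of the original model code) ---
+ # A minimal ranking pre-training step, shown only to illustrate the expected input
+ # shapes. The helper name and the example texts are assumptions for this sketch;
+ # `model` is a RankingPrompterForPreTraining and `tokenizer` its matching
+ # (UMT5-compatible) tokenizer.
+ def _example_ranking_pretraining_step(model, tokenizer):
+     # model.enable_amp_ctx("cuda", torch.bfloat16)  # optional mixed-precision context
+     question = "Who composed the opera Carmen?"
+     documents = ["Georges Bizet composed the opera Carmen.", "Paris is the capital of France."]
+     doc = tokenizer(documents, padding=True, truncation=True, return_tensors="pt")
+     qry = tokenizer([question], return_tensors="pt")
+
+     outputs = model(
+         document_input_ids=doc.input_ids.unsqueeze(0),            # [1, num_doc, doc_seq_len]
+         document_attention_mask=doc.attention_mask.unsqueeze(0),  # [1, num_doc, doc_seq_len]
+         question_input_ids=qry.input_ids,                         # [1, question_seq_len]
+         question_attention_mask=qry.attention_mask,
+         labels=torch.tensor([0]),  # index of the relevant document within `documents`
+     )
+     return outputs.loss, outputs.logits
+
+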
1257
+ class RankingPrompter(UMT5Model):
1258
+ config_class = RankingPrompterConfig
1259
+
1260
+ _tied_weights_keys = [
1261
+ "encoder.embed_tokens.weight",
1262
+ "decoder.embed_tokens.weight",
1263
+ ]
1264
+
1265
+ def __init__(self, config):
1266
+ # encoder, decoder and shared are from UMT5Model
1267
+ super().__init__(config)
1268
+
1269
+ # add ranking head
1270
+ self.ranking_head = nn.Linear(config.d_model, 1)
1271
+
1272
+ # Initialize weights and apply final processing
1273
+ self.post_init()
1274
+
1275
+ # ctx for mixed precision training
1276
+ self.ctx = nullcontext()
1277
+
1278
+ def enable_amp_ctx(self, device_type="cuda", dtype=torch.bfloat16):
1279
+ self.ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)
1280
+
1281
+ def disable_amp_ctx(self):
1282
+ self.ctx = nullcontext()
1283
+
1284
+ def encode_document(self, document_input_ids, document_attention_mask):
1285
+ # input shape: [batch_size * num_doc, doc_seq_len]
+ # encode the flattened documents
1287
+ with self.ctx:
1288
+ encoder_outputs = self.encoder(
1289
+ input_ids=document_input_ids,
1290
+ attention_mask=document_attention_mask,
1291
+ return_dict=False,
1292
+ )
1293
+ return encoder_outputs
1294
+
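+ # NOTE (editorial): `encode_document` returns a plain tuple (return_dict=False);
+ # element [0] is the encoder's last hidden state. Because the encoder never sees
+ # the question, this output can be computed once per document batch and reused
+ # across different questions if desired.
+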
1295
+ def decode_answer(
1296
+ self,
1297
+ question_input_ids,
1298
+ question_attention_mask,
1299
+ document_embeds,
1300
+ document_attention_mask,
1301
+ answer_input_ids=None,
1302
+ answer_attention_mask=None
1303
+ ):
1304
+ if answer_input_ids is not None and answer_attention_mask is not None:
1305
+ # append answer input ids to question input ids
1306
+ question_input_ids = torch.cat([question_input_ids, answer_input_ids], dim=1)
1307
+ question_attention_mask = torch.cat([question_attention_mask, answer_attention_mask], dim=1)
1308
+
1309
+ answer_outputs = self.decoder(
1310
+ input_ids=question_input_ids,
1311
+ attention_mask=question_attention_mask,
1312
+ encoder_hidden_states=document_embeds,
1313
+ encoder_attention_mask=document_attention_mask,
1314
+ return_dict=True,
1315
+ )
1316
+ return answer_outputs
1317
+
1318
+ def forward(
1319
+ self,
1320
+ document_input_ids: Optional[torch.LongTensor] = None,
1321
+ document_attention_mask: Optional[torch.FloatTensor] = None,
1322
+ question_input_ids: Optional[torch.LongTensor] = None,
1323
+ question_attention_mask: Optional[torch.BoolTensor] = None,
1324
+ answer_input_ids: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1325
+ answer_attention_mask: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1326
+ labels: Optional[torch.LongTensor] = None,
1327
+ use_cache: Optional[bool] = None,
1328
+ return_dict: Optional[bool] = None,
1329
+ ) -> Union[Tuple[torch.FloatTensor], RankingPrompterOutput]:
1330
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Indices of the relevant document for each question, in `[0, ..., num_doc - 1]`. Labels set to `-100`
+ are ignored (masked) by the ranking cross-entropy loss.
+
+ Returns:
+ `RankingPrompterOutput` with the combined loss (ranking loss plus answer LM loss, when the
+ corresponding inputs are given), the ranking logits and the answer LM logits.
+ """
1339
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1340
+ return_dict = (
1341
+ return_dict if return_dict is not None else self.config.use_return_dict
1342
+ )
1343
+ if len(document_input_ids.shape) == 2:
1344
+ # make [batch_size, doc_seq_len] -> [batch_size, 1, doc_seq_len]
1345
+ document_input_ids = document_input_ids.unsqueeze(1)
1346
+ document_attention_mask = document_attention_mask.unsqueeze(1)
1347
+ # document_input_ids: [batch_size, num_doc, doc_seq_len]
1348
+ batch_size, num_doc, doc_seq_len = document_input_ids.shape
1349
+ document_input_ids = document_input_ids.view(-1, doc_seq_len)
1350
+ # to [batch_size * num_doc, doc_seq_len]
1351
+ document_attention_mask = document_attention_mask.view(-1, doc_seq_len)
1352
+
1353
+ encoder_outputs = self.encode_document(document_input_ids, document_attention_mask)
1354
+ document_embeds = encoder_outputs[0]
1355
+
1356
+ # repeat question inputs for each document
1357
+ # question_input_ids: [batch_size, question_seq_len]
1358
+ question_seq_len = question_input_ids.shape[1]
1359
+ question_input_ids_expand = (
1360
+ question_input_ids.unsqueeze(1)
1361
+ .expand(-1, num_doc, -1)
1362
+ .reshape(-1, question_seq_len)
1363
+ ) # [batch_size * num_doc, question_seq_len]
1364
+ question_attention_mask_expand = (
1365
+ question_attention_mask.unsqueeze(1)
1366
+ .expand(-1, num_doc, -1)
1367
+ .reshape(-1, question_seq_len)
1368
+ ) # [batch_size * num_doc, question_seq_len]
1369
+
1370
+ # Decode
1371
+ with self.ctx:
1372
+ decoder_outputs = self.decoder(
1373
+ input_ids=question_input_ids_expand,
1374
+ attention_mask=question_attention_mask_expand,
1375
+ encoder_hidden_states=document_embeds,
1376
+ encoder_attention_mask=document_attention_mask,
1377
+ use_cache=False,
1378
+ return_dict=True,
1379
+ )
1380
+ # decoder hidden states: [batch_size * num_doc, question_seq_len, hidden_size]
+ sequence_output = decoder_outputs.last_hidden_state
+ question_seq_len = sequence_output.size(1)
+ # reshape to [batch_size, num_doc, question_seq_len, hidden_size]
+ soft_prompt_output = sequence_output.view(
+ batch_size, num_doc, question_seq_len, -1
+ )
+ question_attention_mask_expand = question_attention_mask_expand.view(
+ batch_size, num_doc, question_seq_len
+ )
+ # zero out padded question positions
+ soft_prompt_output = soft_prompt_output * question_attention_mask_expand.unsqueeze(-1)
+ # masked mean over the real (non-padded) question length
+ soft_prompt_output_mean = soft_prompt_output.sum(dim=2) / question_attention_mask_expand.sum(dim=2, keepdim=True)
+ # score each document: [batch_size, num_doc]
+ ranking_logits = self.ranking_head(soft_prompt_output_mean).view(batch_size, num_doc)
1397
+
1398
+ # rank loss
1399
+ loss_ranking = None
1400
+ if labels is not None:
1401
+ loss_fct = CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)
1402
+ loss_ranking = loss_fct(ranking_logits, labels)
1403
+ # append bos token id to question input ids
1404
+ question_input_ids = torch.cat(
1405
+ [question_input_ids, torch.ones_like(question_input_ids[:, :1]).fill_(self.config.decoder_start_token_id)], dim=1)
1406
+ question_attention_mask = torch.cat(
1407
+ [question_attention_mask, torch.ones_like(question_attention_mask[:, :1])], dim=1)
1408
+ # only take the first document for answer generation training
1409
+ answer_outputs = self.decode_answer(question_input_ids,
1410
+ question_attention_mask,
1411
+ document_embeds[::num_doc],
1412
+ document_attention_mask[::num_doc],
1413
+ answer_input_ids,
1414
+ answer_attention_mask)
1415
+ # lm loss
1416
+ loss_lm = None
1417
+ lm_logits = None
1418
+ if answer_input_ids is not None:
1419
+ # fill in question_input_ids with -100
1420
+ question_input_mask = torch.zeros_like(question_input_ids).fill_(-100)
1421
+ # mask padding token in answer_input_ids with -100
1422
+ answer_input_ids = answer_input_ids.masked_fill(answer_input_ids == self.config.pad_token_id, -100)
1423
+ # shift labels so that position t predicts token t + 1: [batch_size, question_seq_len + answer_seq_len - 1]
1424
+ lm_labels = torch.cat([question_input_mask, answer_input_ids], dim=1)[:, 1:].contiguous()
1425
+ lm_logits = (answer_outputs.last_hidden_state @ self.decoder.embed_tokens.weight.t())[:, :-1, :].contiguous()
1426
+ loss_fct = CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)
1427
+ loss_lm = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
1428
+
1429
+ if loss_ranking is not None and loss_lm is not None:
1430
+ loss = loss_ranking + loss_lm
1431
+ elif loss_ranking is not None:
1432
+ loss = loss_ranking
1433
+ elif loss_lm is not None:
1434
+ loss = loss_lm
1435
+ else:
1436
+ loss = None
1437
+
1438
+ if not return_dict:
1439
+ output = (ranking_logits,) + decoder_outputs[1:] + encoder_outputs
1440
+ return ((loss,) + output) if loss is not None else output
1441
+
1442
+ return RankingPrompterOutput(
1443
+ loss=loss,
1444
+ logits=ranking_logits,
1445
+ lm_logits=lm_logits,
1446
+ loss_lm=loss_lm,
1447
+ loss_ranking=loss_ranking,
1448
+ )
1449
+
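+ # NOTE (editorial): when both `labels` and `answer_input_ids` / `answer_attention_mask`
+ # are supplied, `loss` above is the sum of the document-ranking loss and the answer
+ # language-modeling loss; passing only one of them trains a single objective.
+ # Sketch (argument names as in `forward`):
+ #   out = model(document_input_ids=docs, document_attention_mask=doc_mask,
+ #               question_input_ids=qry, question_attention_mask=qry_mask,
+ #               answer_input_ids=ans, answer_attention_mask=ans_mask,
+ #               labels=gold_doc_index)
+ #   out.loss.backward()
+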
1450
+ def generate_answer(
1451
+ self,
1452
+ document_input_ids: Optional[torch.LongTensor] = None,
1453
+ document_attention_mask: Optional[torch.FloatTensor] = None,
1454
+ question_input_ids: Optional[torch.LongTensor] = None,
1455
+ question_attention_mask: Optional[torch.BoolTensor] = None
1456
+ ):
1457
+ if len(document_input_ids.shape) == 2:
1458
+ # make [batch_size, doc_seq_len] -> [batch_size, 1, doc_seq_len]
1459
+ document_input_ids = document_input_ids.unsqueeze(1)
1460
+ document_attention_mask = document_attention_mask.unsqueeze(1)
1461
+ # document_input_ids: [batch_size, num_doc, doc_seq_len]
1462
+ batch_size, num_doc, doc_seq_len = document_input_ids.shape
1463
+ document_input_ids = document_input_ids.view(-1, doc_seq_len)
1464
+ # to [batch_size * num_doc, doc_seq_len]
1465
+ document_attention_mask = document_attention_mask.view(-1, doc_seq_len)
1466
+ document_embeds = self.encode_document(document_input_ids, document_attention_mask)[0]
1467
+ # append bos token id to question input ids
1468
+ question_input_ids = torch.cat(
1469
+ [question_input_ids, torch.ones_like(question_input_ids[:, :1]).fill_(self.config.decoder_start_token_id)], dim=1)
1470
+ question_attention_mask = torch.cat(
1471
+ [question_attention_mask, torch.ones_like(question_attention_mask[:, :1])], dim=1)
1472
+ answer_outputs = self.decode_answer(question_input_ids,
1473
+ question_attention_mask,
1474
+ document_embeds[::num_doc],
1475
+ document_attention_mask[::num_doc])  # stride by num_doc to match document_embeds[::num_doc]
1476
+ lm_logits = answer_outputs.last_hidden_state @ self.decoder.embed_tokens.weight.t()
1477
+ return lm_logits[:, -1:, :]
1478
+
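+ # NOTE (editorial): `generate_answer` returns only the logits of the single next
+ # token (the decoder start token is appended to the question inside it). For
+ # multi-token greedy decoding, the loop used by `compute_lm_grad_cam` below can be
+ # followed without the Grad-CAM part, e.g. (illustrative sketch, names assumed):
+ #   doc_embeds = model.encode_document(doc_ids, doc_mask)[0]
+ #   qry_ids = torch.cat([qry_ids, torch.full_like(qry_ids[:, :1], model.config.decoder_start_token_id)], dim=1)
+ #   qry_mask = torch.cat([qry_mask, torch.ones_like(qry_mask[:, :1])], dim=1)
+ #   for _ in range(max_new_tokens):
+ #       out = model.decode_answer(qry_ids, qry_mask, doc_embeds, doc_mask)
+ #       next_id = (out.last_hidden_state @ model.decoder.embed_tokens.weight.t())[:, -1, :].argmax(-1, keepdim=True)
+ #       qry_ids = torch.cat([qry_ids, next_id], dim=1)
+ #       qry_mask = torch.cat([qry_mask, torch.ones_like(next_id)], dim=1)
+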
1479
+
1480
+ def compute_ranking_grad_cam(
1481
+ self,
1482
+ document_input_ids,
1483
+ document_attention_mask,
1484
+ question_input_ids,
1485
+ question_attention_mask,
1486
+ block_num=-1,
1487
+ reduction="sum"):
1488
+ # switch to evaluation mode and enable saving of the cross-attention maps
+ self.eval()
+ attention_layer = self.decoder.block[block_num].layer[-2].EncDecAttention
+ attention_layer.save_attention = True
+
+ # forward pass through the encoder to obtain the document feature maps
1494
+ encoder_outputs = self.encode_document(document_input_ids, document_attention_mask)
1495
+ document_embeds = encoder_outputs[0]
1496
+
1497
+ # forward pass through the decoder to obtain the Grad-CAM inputs
1498
+ decoder_outputs = self.decoder(
1499
+ input_ids=question_input_ids,
1500
+ attention_mask=question_attention_mask,
1501
+ encoder_hidden_states=document_embeds,
1502
+ encoder_attention_mask=document_attention_mask,
1503
+ use_cache=False,
1504
+ return_dict=True,
1505
+ )
1506
+
1507
+ # get grads
1508
+ soft_prompt_output = decoder_outputs.last_hidden_state * question_attention_mask.unsqueeze(-1)
1509
+ ranking_logits = self.ranking_head(soft_prompt_output.mean(dim=1)).view(-1)
1510
+ loss = ranking_logits.sum()
1511
+ self.zero_grad()
1512
+ loss.backward()
1513
+
1514
+ # compute grad cam
1515
+ with torch.no_grad():
1516
+ # grads and cams [bsz, num_head, ques_len, doc_len]
1517
+ grads = attention_layer.get_attn_gradients()
1518
+ cams = attention_layer.get_attention_map()
1519
+ gradcams = cams * grads
1520
+ # average over heads -> [bsz, ques_len, doc_len]
1521
+ gradcams = gradcams.mean(dim=1)
1522
+ # apply relu
1523
+ gradcams = gradcams.relu()
1524
+ # apply question attention mask
1525
+ gradcams = gradcams * question_attention_mask.unsqueeze(-1)
1526
+ if reduction == "sum":
1527
+ gradcams = gradcams.sum(dim=1)
1528
+ elif reduction == "mean":
1529
+ gradcams = gradcams.mean(dim=1)
1530
+ return gradcams
1531
+
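+ # NOTE (editorial): the Grad-CAM helpers above and below assume that the decoder
+ # cross-attention modules (`EncDecAttention`) expose `save_attention`,
+ # `get_attention_map()` and `get_attn_gradients()` hooks, as expected from the
+ # attention implementation used in this file; they are not part of the stock
+ # transformers UMT5 attention.
+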
1532
+
1533
+ def compute_lm_grad_cam(
1534
+ self,
1535
+ document_input_ids,
1536
+ document_attention_mask,
1537
+ question_input_ids,
1538
+ question_attention_mask,
1539
+ max_new_tokens=10,
1540
+ block_num=-1,
1541
+ reduction="sum"):
1542
+ # switch to evaluation mode and enable saving of the cross-attention maps
+ self.eval()
+ attention_layer = self.decoder.block[block_num].layer[-2].EncDecAttention
+ attention_layer.save_attention = True
+
+ # forward pass through the encoder to obtain the document feature maps
1548
+ encoder_outputs = self.encode_document(document_input_ids, document_attention_mask)
1549
+ document_embeds = encoder_outputs[0]
1550
+
1551
+ # append bos token id to question input ids
1552
+ question_input_ids = torch.cat(
1553
+ [question_input_ids, torch.ones_like(question_input_ids[:, :1]).fill_(self.config.decoder_start_token_id)], dim=1)
1554
+ question_attention_mask = torch.cat(
1555
+ [question_attention_mask, torch.ones_like(question_attention_mask[:, :1])], dim=1)
1556
+
1557
+
1558
+ gradcams_output = []
1559
+ tokens_output = []
1560
+ for _ in range(max_new_tokens):
1561
+ # forward pass through the decoder to obtain the Grad-CAM inputs
1562
+ decoder_outputs = self.decoder(
1563
+ input_ids=question_input_ids,
1564
+ attention_mask=question_attention_mask,
1565
+ encoder_hidden_states=document_embeds,
1566
+ encoder_attention_mask=document_attention_mask,
1567
+ use_cache=False,
1568
+ return_dict=True,
1569
+ )
1570
+ # get grads
1571
+ lm_logits = (decoder_outputs.last_hidden_state @ self.decoder.embed_tokens.weight.t())[:, -1:, :].contiguous()
1572
+ max_logits, max_indices = lm_logits.max(dim=-1)
1573
+ loss = max_logits.sum()
1574
+ question_input_ids = torch.cat([question_input_ids, max_indices], dim=-1)
1575
+ question_attention_mask = torch.cat([question_attention_mask, torch.ones_like(question_attention_mask[:, :1])], dim=1)
1576
+ tokens_output.append(max_indices)
1577
+
1578
+ self.zero_grad()
1579
+ loss.backward(retain_graph=True)
1580
+
1581
+ # compute grad cam
1582
+ with torch.no_grad():
1583
+ # grads and cams [bsz, num_head, ques_len, doc_len]
1584
+ grads = attention_layer.get_attn_gradients()
1585
+ cams = attention_layer.get_attention_map()
1586
+ gradcams = cams[:, :, -1:, :] * grads[:, :, -1:, :]
1587
+ # average over heads -> [bsz, 1, doc_len]
1588
+ gradcams = gradcams.mean(dim=1)
1589
+ # apply relu
1590
+ gradcams = gradcams.relu()
1591
+ gradcams_output.append(gradcams)
1592
+ # concat to [bsz, max_new_tokens, doc_len]
1593
+ gradcams_output = torch.cat(gradcams_output, dim=1)
1594
+ # concat to [bsz, max_new_tokens]
1595
+ tokens_output = torch.cat(tokens_output, dim=1)
1596
+ # mask eos token gradcam
1597
+ gradcams_output = gradcams_output * (tokens_output != self.config.eos_token_id).unsqueeze(-1)
1598
+ if reduction == "sum":
1599
+ gradcams_output = gradcams_output.sum(dim=1)
1600
+ elif reduction == "mean":
1601
+ gradcams_output = gradcams_output.mean(dim=1)
1602
+ return tokens_output, gradcams_output
1603
+
1604
+
1605
+ def split_context_by_token_id(
1606
+ self,
1607
+ document_input_ids,
1608
+ gradcams,
1609
+ split_token_id = 310,
1610
+ ):
1611
+ bsz = document_input_ids.shape[0]
1612
+ batch_doc_splits = []
1613
+ for i in range(bsz):
1614
+ one_doc = document_input_ids[i]
1615
+ grad_cam = gradcams[i]
1616
+ # find the split token index
1617
+ split_idx = (one_doc == split_token_id).nonzero(as_tuple=True)[0]
1618
+ # split the document input ids
1619
+ num_split = len(split_idx)
1620
+ if num_split > 0:
1621
+ one_doc_splits = []
1622
+ activation_splits = []
1623
+ for j in range(num_split):
+ if j == 0:
+ # first split
+ one_doc_splits.append(one_doc[:split_idx[j]])
+ activation = grad_cam[:split_idx[j]].mean()
+ activation_splits.append(activation)
+ else:
+ one_doc_splits.append(one_doc[split_idx[j-1]+1:split_idx[j]])
+ activation = grad_cam[split_idx[j-1]+1:split_idx[j]].mean()
+ activation_splits.append(activation)
1633
+ # append the last split
1634
+ one_doc_splits.append(one_doc[split_idx[-1]+1:])
1635
+ activation = grad_cam[split_idx[-1]+1:].mean()
1636
+ activation_splits.append(activation)
1637
+ else:
1638
+ # no split token in the document
1639
+ one_doc_splits = [one_doc]
1640
+ activation_splits = [grad_cam.mean()]
1641
+ # collect the splits and their activations for this document
1642
+ batch_doc_splits.append((one_doc_splits, activation_splits))
1643
+ return batch_doc_splits
1644
+
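+ # NOTE (editorial): the default `split_token_id = 310` above is assumed to be the
+ # tokenizer id of the separator between context chunks (e.g. a newline-like piece
+ # in this checkpoint's vocabulary); verify it against the tokenizer actually
+ # paired with the model before relying on the splits.
+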
1645
+
1646
+ def drop_context_by_activation(
1647
+ self,
1648
+ batch_doc_splits,
1649
+ keep_ratio=0.5,
1650
+ ):
1651
+ # keep_ratio must be strictly positive
+ if keep_ratio <= 0:
+ raise ValueError("keep_ratio should not be zero or negative")
+ batch_doc_splits_drop = []
+ for one_doc_splits, activation_splits in batch_doc_splits:
+ sorted_idx = sorted(range(len(activation_splits)), key=lambda k: activation_splits[k], reverse=True)
+ # keep at least one context split
+ num_keep = max(int(len(sorted_idx) * keep_ratio), 1)
+ # restore the original document order of the kept splits
+ sorted_idx = sorted(sorted_idx[:num_keep])
1661
+ one_doc_splits_drop = [one_doc_splits[i] for i in sorted_idx]
1662
+ batch_doc_splits_drop.append(one_doc_splits_drop)
1663
+ return batch_doc_splits_drop
1664
+
1665
+ def drop_context_by_avg_rank(
1666
+ self,
1667
+ batch_doc_splits_ranking,
1668
+ batch_doc_splits_lm,
1669
+ keep_ratio=0.5,
1670
+ ):
1671
+ # keep_ratio must be strictly positive
+ if keep_ratio <= 0:
+ raise ValueError("keep_ratio should not be zero or negative")
1674
+ batch_doc_splits_drop = []
1675
+ bsz = len(batch_doc_splits_ranking)
1676
+ for i in range(bsz):
1677
+ one_doc_splits_ranking, activation_splits_ranking = batch_doc_splits_ranking[i]
1678
+ one_doc_splits_lm, activation_splits_lm = batch_doc_splits_lm[i]
1679
+ # sort by ranking activation
1680
+ ranking_sorted_idx = sorted(range(len(activation_splits_ranking)), key=lambda k: activation_splits_ranking[k], reverse=True)
1681
+ lm_sorted_idx = sorted(range(len(activation_splits_lm)), key=lambda k: activation_splits_lm[k], reverse=True)
1682
+ # sort by average rank of ranking and lm
1683
+ avg_rank = [(ranking_sorted_idx.index(i) + lm_sorted_idx.index(i)) / 2 for i in range(len(ranking_sorted_idx))]
1684
+ sorted_idx = sorted(range(len(avg_rank)), key=lambda k: avg_rank[k])
1685
+ # keep at least one context split
+ num_keep = max(int(len(sorted_idx) * keep_ratio), 1)
+ # restore the original document order of the kept splits
+ sorted_idx = sorted(sorted_idx[:num_keep])
1689
+ one_doc_splits_drop = [one_doc_splits_ranking[i] for i in sorted_idx]
1690
+ batch_doc_splits_drop.append(one_doc_splits_drop)
1691
+ return batch_doc_splits_drop
1692
+
1693
+
1694
+ def compress_context_by_activation(
1695
+ self,
1696
+ document_input_ids,
1697
+ gradcams_output,
1698
+ keep_ratio=0.5,
1699
+ ):
1700
+ # split context by split token id
1701
+ batch_doc_splits = self.split_context_by_token_id(document_input_ids, gradcams_output)
1702
+ # drop context by activation
1703
+ batch_doc_splits_drop = self.drop_context_by_activation(batch_doc_splits, keep_ratio)
1704
+ return batch_doc_splits_drop
1705
+
1706
+
1707
+ def compress_context(
1708
+ self,
1709
+ document_input_ids,
1710
+ ranking_gradcams,
1711
+ lm_gradcams,
1712
+ keep_ratio=0.5,
1713
+ ):
1714
+ # split context by split token id
1715
+ batch_doc_splits_ranking = self.split_context_by_token_id(document_input_ids, ranking_gradcams)
1716
+ batch_doc_splits_lm = self.split_context_by_token_id(document_input_ids, lm_gradcams)
1717
+ # drop context by activation
1718
+ batch_doc_splits_drop = self.drop_context_by_avg_rank(
1719
+ batch_doc_splits_ranking, batch_doc_splits_lm, keep_ratio)
1720
+ return batch_doc_splits_drop
1721
+
1722
+
1723
+
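+ # --- Editorial usage sketch (not part of the original model code) ---
+ # End-to-end context compression with the helpers above. `model` is a loaded
+ # RankingPrompter and `tokenizer` its matching tokenizer; argument names and the
+ # helper name are assumptions for this sketch.
+ def _example_compress_context(model, tokenizer, document, question, keep_ratio=0.5):
+     doc = tokenizer([document], return_tensors="pt")
+     qry = tokenizer([question], return_tensors="pt")
+
+     # token-level relevance from the ranking head and from answer generation
+     ranking_gradcams = model.compute_ranking_grad_cam(
+         doc.input_ids, doc.attention_mask, qry.input_ids, qry.attention_mask
+     )
+     _, lm_gradcams = model.compute_lm_grad_cam(
+         doc.input_ids, doc.attention_mask, qry.input_ids, qry.attention_mask
+     )
+
+     # keep the chunks with the best average rank across the two relevance signals
+     kept = model.compress_context(doc.input_ids, ranking_gradcams, lm_gradcams, keep_ratio)
+     # kept[0] is the list of token-id tensors for the first (and only) batch item
+     return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in kept[0]]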