nhatminh committed on
Commit
51e5059
verified
1 Parent(s): 4aed4b4

Upload modeling_xlm_roberta.py

Files changed (1)
  1. modeling_xlm_roberta.py +1119 -0
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,1119 @@
1
+ # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
+ from transformers.models.bert.modeling_bert import (
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ BertForPreTrainingOutput,
32
+ )
33
+
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
+ from .embedding import XLMRobertaEmbeddings
45
+ from .mha import MHA
46
+ from .mlp import FusedMLP, Mlp
47
+
48
+ try:
49
+ from flash_attn.ops.fused_dense import FusedDense
50
+ except ImportError:
51
+ FusedDense = None
52
+
53
+ try:
54
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
55
+ except ImportError:
56
+ layer_norm_fn = None
57
+
58
+
59
+ try:
60
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
61
+ except ImportError:
62
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
63
+
64
+ try:
65
+ from tqdm.autonotebook import trange
66
+ except ImportError:
67
+ trange = None
68
+
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
74
+ if not getattr(config, "use_flash_attn", False):
75
+ return False
76
+ if not torch.cuda.is_available():
77
+ return False
78
+ if importlib.util.find_spec("flash_attn") is None:
79
+ logger.warning(
80
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
81
+ )
82
+ return False
83
+ return True
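+ # Flash attention is enabled only when the config requests it, CUDA is available, and the
+ # flash_attn package is importable; otherwise the PyTorch attention path is used.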
84
+
85
+
86
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
87
+ use_flash_attn = get_use_flash_attn(config)
88
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
89
+
90
+ mixer_cls = partial(
91
+ MHA,
92
+ num_heads=config.num_attention_heads,
93
+ cross_attn=cross_attn,
94
+ dropout=config.attention_probs_dropout_prob,
95
+ causal=False,
96
+ fused_bias_fc=fused_bias_fc,
97
+ use_flash_attn=use_flash_attn,
98
+ return_residual=return_residual,
99
+ )
100
+ return mixer_cls
101
+
102
+
103
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
104
+ inner_dim = config.intermediate_size
105
+ fused_mlp = getattr(config, "fused_mlp", False)
106
+ if fused_mlp:
107
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
108
+ "fused_mlp only " "supports approximate gelu"
109
+ )
110
+ if not fused_mlp:
111
+ approximate = (
112
+ "tanh"
113
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
114
+ else "none"
115
+ )
116
+ mlp_cls = partial(
117
+ Mlp,
118
+ hidden_features=inner_dim,
119
+ activation=partial(F.gelu, approximate=approximate),
120
+ return_residual=return_residual,
121
+ )
122
+ else:
123
+ if FusedMLP is None:
124
+ raise ImportError("fused_dense is not installed")
125
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
126
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
127
+ if isinstance(mlp_checkpoint_lvl, Sequence):
128
+ assert layer_idx is not None
129
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
130
+ mlp_cls = partial(
131
+ FusedMLP,
132
+ hidden_features=inner_dim,
133
+ checkpoint_lvl=mlp_checkpoint_lvl,
134
+ return_residual=return_residual,
135
+ )
136
+ return mlp_cls
137
+
138
+
139
+ def create_block(config, layer_idx=None):
140
+ last_layer_subset = getattr(config, "last_layer_subset", False)
141
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
142
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
143
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
144
+ # one layer) so we just choose not to return residual in this case.
145
+ return_residual = not cross_attn
146
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
147
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
148
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
149
+ block = Block(
150
+ config.hidden_size,
151
+ mixer_cls,
152
+ mlp_cls,
153
+ norm_cls=norm_cls,
154
+ prenorm=False,
155
+ resid_dropout1=config.hidden_dropout_prob,
156
+ resid_dropout2=config.hidden_dropout_prob,
157
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
158
+ return_residual=return_residual,
159
+ )
160
+ return block
161
+
162
+
163
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
164
+ def _init_weights(module, initializer_range=0.02):
165
+ if isinstance(module, nn.Linear):
166
+ nn.init.normal_(module.weight, std=initializer_range)
167
+ if module.bias is not None:
168
+ nn.init.zeros_(module.bias)
169
+ elif isinstance(module, nn.Embedding):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.padding_idx is not None:
172
+ nn.init.zeros_(module.weight[module.padding_idx])
173
+
174
+
175
+ class XLMRobertaEncoder(nn.Module):
176
+ def __init__(self, config: XLMRobertaFlashConfig):
177
+ super().__init__()
178
+ self.use_flash_attn = get_use_flash_attn(config)
179
+ self.layers = nn.ModuleList(
180
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
181
+ )
182
+ self._grad_checkpointing = False
183
+
184
+ @property
185
+ def gradient_checkpointing(self):
186
+ return self._grad_checkpointing
187
+
188
+ @gradient_checkpointing.setter
189
+ def gradient_checkpointing(self, value):
190
+ self._grad_checkpointing = value
191
+
192
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
193
+ """If subset_mask is not None, we only want output for the subset of the sequence.
194
+ This means that we only compute the last layer output for these tokens.
195
+ subset_mask: (batch, seqlen), dtype=torch.bool
196
+ """
197
+ if key_padding_mask is None or not self.use_flash_attn:
198
+ mixer_kwargs = (
199
+ {"key_padding_mask": key_padding_mask.bool()}
200
+ if key_padding_mask is not None
201
+ else None
202
+ )
203
+ for layer in self.layers:
204
+ if self._grad_checkpointing:
205
+ hidden_states = torch.utils.checkpoint.checkpoint(
206
+ layer,
207
+ hidden_states,
208
+ use_reentrant=False,
209
+ mixer_kwargs=mixer_kwargs,
210
+ )
211
+ else:
212
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
213
+ if subset_mask is not None:
214
+ hidden_states = hidden_states[subset_mask]
215
+ else:
216
+ batch, seqlen = hidden_states.shape[:2]
217
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
218
+ hidden_states, key_padding_mask
219
+ )
220
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
221
+ if subset_mask is None:
222
+ for layer in self.layers:
223
+ if self._grad_checkpointing:
224
+ hidden_states = torch.utils.checkpoint.checkpoint(
225
+ layer,
226
+ hidden_states,
227
+ use_reentrant=False,
228
+ mixer_kwargs=mixer_kwargs,
229
+ )
230
+ else:
231
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
232
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
233
+ else:
234
+ for layer in self.layers[:-1]:
235
+ if self._grad_checkpointing:
236
+ hidden_states = torch.utils.checkpoint.checkpoint(
237
+ layer,
238
+ hidden_states,
239
+ use_reentrant=False,
240
+ mixer_kwargs=mixer_kwargs,
241
+ )
242
+ else:
243
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
244
+ if key_padding_mask is not None:
245
+ subset_idx = torch.nonzero(
246
+ subset_mask[key_padding_mask], as_tuple=False
247
+ ).flatten()
248
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
249
+ dim=-1, dtype=torch.int32
250
+ )
251
+ subset_cu_seqlens = F.pad(
252
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
253
+ (1, 0),
254
+ )
255
+ else:
256
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
257
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
258
+ subset_cu_seqlens = F.pad(
259
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
260
+ (1, 0),
261
+ )
262
+ hidden_states_subset, hidden_states = index_first_axis_residual(
263
+ hidden_states, subset_idx
264
+ )
265
+ # It's ok to set max_seqlen_q to be much larger
266
+ mixer_kwargs = {
267
+ "x_kv": hidden_states,
268
+ "cu_seqlens": subset_cu_seqlens,
269
+ "max_seqlen": max_seqlen_in_batch,
270
+ "cu_seqlens_k": cu_seqlens,
271
+ "max_seqlen_k": max_seqlen_in_batch,
272
+ }
273
+ if self._grad_checkpointing:
274
+ torch.utils.checkpoint.checkpoint(
275
+ self.layers[-1],
276
+ hidden_states_subset,
277
+ use_reentrant=False,
278
+ mixer_kwargs=mixer_kwargs,
279
+ )
280
+ else:
281
+ hidden_states = self.layers[-1](
282
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
283
+ )
284
+ return hidden_states
285
+
286
+
287
+ class XLMRobertaPooler(nn.Module):
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
291
+ if fused_bias_fc and FusedDense is None:
292
+ raise ImportError("fused_dense is not installed")
293
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
294
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
295
+ self.activation = nn.Tanh()
296
+
297
+ def forward(self, hidden_states, pool=True):
298
+ # We "pool" the model by simply taking the hidden state corresponding
299
+ # to the first token.
300
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
301
+ pooled_output = self.dense(first_token_tensor)
302
+ pooled_output = self.activation(pooled_output)
303
+ return pooled_output
304
+
305
+
306
+ class XLMRobertaPredictionHeadTransform(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
310
+ if fused_bias_fc and FusedDense is None:
311
+ raise ImportError("fused_dense is not installed")
312
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
313
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
314
+ raise ImportError("Triton is not installed")
315
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
316
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
317
+ approximate = (
318
+ "tanh"
319
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
320
+ else "none"
321
+ )
322
+ self.transform_act_fn = nn.GELU(approximate=approximate)
323
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
+
325
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.transform_act_fn(hidden_states)
328
+ if not self.fused_dropout_add_ln:
329
+ hidden_states = self.layer_norm(hidden_states)
330
+ else:
331
+ hidden_states = layer_norm_fn(
332
+ hidden_states,
333
+ self.layer_norm.weight,
334
+ self.layer_norm.bias,
335
+ eps=self.layer_norm.eps,
336
+ )
337
+ return hidden_states
338
+
339
+
340
+ class XLMRobertaLMPredictionHead(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
344
+ if fused_bias_fc and FusedDense is None:
345
+ raise ImportError("fused_dense is not installed")
346
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
347
+
348
+ self.transform = XLMRobertaPredictionHeadTransform(config)
349
+
350
+ # The output weights are the same as the input embeddings, but there is
351
+ # an output-only bias for each token.
352
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
353
+
354
+ def forward(self, hidden_states):
355
+ hidden_states = self.transform(hidden_states)
356
+ hidden_states = self.decoder(hidden_states)
357
+ return hidden_states
358
+
359
+
360
+ class XLMRobertaPreTrainingHeads(nn.Module):
361
+ def __init__(self, config):
362
+ super().__init__()
363
+ self.predictions = XLMRobertaLMPredictionHead(config)
364
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
+
366
+ def forward(self, sequence_output, pooled_output):
367
+ prediction_scores = self.predictions(sequence_output)
368
+ seq_relationship_score = self.seq_relationship(pooled_output)
369
+ return prediction_scores, seq_relationship_score
370
+
371
+
372
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
373
+ """An abstract class to handle weights initialization and
374
+ a simple interface for downloading and loading pretrained models.
375
+ """
376
+
377
+ config_class = XLMRobertaFlashConfig
378
+ base_model_prefix = "roberta"
379
+ supports_gradient_checkpointing = True
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, XLMRobertaEncoder):
383
+ module.gradient_checkpointing = value
384
+
385
+ @classmethod
386
+ def from_pretrained(
387
+ cls,
388
+ *args,
389
+ **kwargs,
390
+ ):
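+ # Default torch_dtype to 'auto' so weights are loaded in the dtype stored in the
+ # checkpoint, unless the caller overrides it explicitly.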
391
+ if 'torch_dtype' not in kwargs:
392
+ kwargs['torch_dtype'] = 'auto'
393
+ return super().from_pretrained(*args, **kwargs)
394
+
395
+
396
+
397
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
398
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
399
+ super().__init__(config)
400
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
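+ # Optionally round the vocabulary size up to a multiple of pad_vocab_size_multiple;
+ # the embedding matrix is then built with the padded size.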
401
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
402
+ config.vocab_size += self.pad_vocab_size_multiple - (
403
+ config.vocab_size % self.pad_vocab_size_multiple
404
+ )
405
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
406
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
407
+ raise ImportError("Triton is not installed")
408
+ assert config.hidden_act in [
409
+ "gelu",
410
+ "gelu_new",
411
+ "gelu_fast",
412
+ "gelu_pytorch_tanh",
413
+ ]
414
+
415
+ self.embeddings = XLMRobertaEmbeddings(
416
+ config.hidden_size,
417
+ config.vocab_size,
418
+ config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
419
+ config.type_vocab_size,
420
+ padding_idx=config.pad_token_id,
421
+ )
422
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
423
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.encoder = XLMRobertaEncoder(config)
425
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
426
+
427
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
428
+
429
+
430
+ @torch.inference_mode()
431
+ def encode(
432
+ self: 'XLMRobertaModel',
433
+ sentences: Union[str, List[str]],
434
+ batch_size: int = 32,
435
+ show_progress_bar: Optional[bool] = None,
436
+ output_value: str = 'sentence_embedding',
437
+ convert_to_numpy: bool = True,
438
+ convert_to_tensor: bool = False,
439
+ device: Optional[torch.device] = None,
440
+ normalize_embeddings: bool = False,
441
+ truncate_dim: Optional[int] = None,
442
+ **tokenizer_kwargs,
443
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
444
+ """
445
+ Computes sentence embeddings
446
+ Args:
447
+ sentences(`str` or `List[str]`):
448
+ Sentence or sentences to be encoded
449
+ batch_size(`int`, *optional*, defaults to 32):
450
+ Batch size for the computation
451
+ show_progress_bar(`bool`, *optional*, defaults to None):
452
+ Show a progress bar when encoding sentences.
453
+ If set to None, progress bar is only shown when
454
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
455
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
456
+ The default, 'sentence_embedding', returns sentence embeddings.
457
+ Can be set to token_embeddings to get wordpiece token embeddings.
458
+ Set to None to get all output values.
459
+ convert_to_numpy(`bool`, *optional*, defaults to True):
460
+ If true, the output is a list of numpy vectors.
461
+ Else, it is a list of pytorch tensors.
462
+ convert_to_tensor(`bool`, *optional*, defaults to False):
463
+ If true, you get one large tensor as return.
464
+ Overwrites any setting from convert_to_numpy
465
+ device(`torch.device`, *optional*, defaults to None):
466
+ Which torch.device to use for the computation
467
+ normalize_embeddings(`bool`, *optional*, defaults to False):
468
+ If set to true, returned vectors will have length 1. In that case, the
469
+ faster dot-product (util.dot_score) instead of cosine similarity can
470
+ be used.
471
+ truncate_dim(`int`, *optional*, defaults to None):
472
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
473
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
474
+ Keyword arguments for the tokenizer
475
+ Returns:
476
+ By default, a list of tensors is returned.
477
+ If convert_to_tensor, a stacked tensor is returned.
478
+ If convert_to_numpy, a numpy matrix is returned.
479
+ """
480
+ from transformers import AutoTokenizer
481
+
482
+ self.tokenizer = AutoTokenizer.from_pretrained(
483
+ self.name_or_path, trust_remote_code=True
484
+ )
485
+
486
+ is_training = self.training
487
+ self.eval()
488
+
489
+ if show_progress_bar is None:
490
+ show_progress_bar = (
491
+ logger.getEffectiveLevel() == logging.INFO
492
+ or logger.getEffectiveLevel() == logging.DEBUG
493
+ )
494
+
495
+ if convert_to_tensor:
496
+ convert_to_numpy = False
497
+
498
+ if output_value != 'sentence_embedding':
499
+ convert_to_tensor = False
500
+ convert_to_numpy = False
501
+
502
+ input_was_string = False
503
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
504
+ sentences = [sentences]
505
+ input_was_string = True
506
+
507
+ if device is not None:
508
+ self.to(device)
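+ # Sort sentences by decreasing length so each batch contains similarly sized inputs
+ # (less padding); the original order is restored later via inverse_permutation.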
509
+
510
+ permutation = np.argsort([-len(i) for i in sentences])
511
+ inverse_permutation = np.argsort(permutation)
512
+ sentences = [sentences[idx] for idx in permutation]
513
+
514
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
515
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
516
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
517
+ )
518
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
519
+
520
+ all_embeddings = []
521
+
522
+ if trange is not None:
523
+ range_iter = trange(
524
+ 0,
525
+ len(sentences),
526
+ batch_size,
527
+ desc="Encoding",
528
+ disable=not show_progress_bar,
529
+ )
530
+ else:
531
+ range_iter = range(0, len(sentences), batch_size)
532
+
533
+ for i in range_iter:
534
+ encoded_input = self.tokenizer(
535
+ sentences[i : i + batch_size],
536
+ return_tensors='pt',
537
+ **tokenizer_kwargs,
538
+ ).to(self.device)
539
+ token_embs = self.forward(**encoded_input)[0]
540
+
541
+ # Accumulate in fp32 to avoid overflow
542
+ token_embs = token_embs.float()
543
+
544
+ if output_value == 'token_embeddings':
545
+ raise NotImplementedError
546
+ elif output_value is None:
547
+ raise NotImplementedError
548
+ else:
549
+ if self.config.emb_pooler == 'cls':
550
+ embeddings = self.cls_pooling(
551
+ token_embs, encoded_input['attention_mask']
552
+ )
553
+ else:
554
+ embeddings = self.mean_pooling(
555
+ token_embs, encoded_input['attention_mask']
556
+ )
557
+
558
+ if normalize_embeddings:
559
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
560
+
561
+ if convert_to_numpy:
562
+ embeddings = embeddings.cpu()
563
+ all_embeddings.extend(embeddings)
564
+
565
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
566
+
567
+ truncate_dim = truncate_dim or self.config.truncate_dim
568
+ if truncate_dim:
569
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
570
+
571
+ if convert_to_tensor:
572
+ all_embeddings = torch.stack(all_embeddings)
573
+ elif convert_to_numpy:
574
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
575
+
576
+ if input_was_string:
577
+ all_embeddings = all_embeddings[0]
578
+
579
+ self.train(is_training)
580
+ return all_embeddings
581
+
582
+
583
+ def truncate_embeddings(self, embeddings, truncate_dim):
584
+ if not self.config.matryoshka_dimensions:
585
+ logger.warning(
586
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
587
+ )
588
+ return embeddings
589
+ elif truncate_dim in self.config.matryoshka_dimensions:
590
+ return [tensor[:truncate_dim] for tensor in embeddings]
591
+ else:
592
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
593
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
594
+
595
+ def mean_pooling(
596
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
597
+ ):
598
+ input_mask_expanded = (
599
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
600
+ )
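+ # Masked mean over the sequence dimension; the clamp guards against division by zero
+ # when an attention mask is all zeros.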
601
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
602
+ input_mask_expanded.sum(1), min=1e-9
603
+ )
604
+
605
+
606
+ def cls_pooling(
607
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
608
+ ):
609
+ return token_embeddings[:, 0]
610
+
611
+
612
+ def forward(
613
+ self,
614
+ input_ids,
615
+ position_ids=None,
616
+ token_type_ids=None,
617
+ attention_mask=None,
618
+ masked_tokens_mask=None,
619
+ return_dict=None,
620
+ **kwargs,
621
+ ):
622
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
623
+ we only want the output for the masked tokens. This means that we only compute the last
624
+ layer output for these tokens.
625
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
626
+ """
627
+
628
+ if kwargs:
629
+ for key, value in kwargs.items():
630
+ if value is not None:
631
+ logger.warning(
632
+ 'Flash attention implementation does not support kwargs: %s',
633
+ key,
634
+ )
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ hidden_states = self.embeddings(
641
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
642
+ )
643
+ # TD [2022-12-18]: Don't need to force residual in fp32
644
+ # BERT puts embedding LayerNorm before embedding dropout.
645
+ if not self.fused_dropout_add_ln:
646
+ hidden_states = self.emb_ln(hidden_states)
647
+ else:
648
+ hidden_states = layer_norm_fn(
649
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
650
+ )
651
+ hidden_states = self.emb_drop(hidden_states)
652
+
653
+ if masked_tokens_mask is not None:
654
+ batch_size, seqlen = input_ids.shape[:2]
655
+ # We also need the first column for the CLS token
656
+ first_col_mask = torch.zeros(
657
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
658
+ )
659
+ first_col_mask[:, 0] = True
660
+ subset_mask = masked_tokens_mask | first_col_mask
661
+ else:
662
+ subset_mask = None
663
+
664
+ sequence_output = self.encoder(
665
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
666
+ )
667
+
668
+ if masked_tokens_mask is None:
669
+ pooled_output = (
670
+ self.pooler(sequence_output) if self.pooler is not None else None
671
+ )
672
+ else:
673
+ # TD [2022-03-01]: the indexing here is very tricky.
674
+ if attention_mask is not None:
675
+ subset_idx = subset_mask[attention_mask]
676
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
677
+ sequence_output = sequence_output[
678
+ masked_tokens_mask[attention_mask][subset_idx]
679
+ ]
680
+ else:
681
+ pool_input = sequence_output[first_col_mask[subset_mask]]
682
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
683
+ pooled_output = (
684
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
685
+ )
686
+
687
+ if not return_dict:
688
+ return sequence_output, pooled_output
689
+
690
+ return BaseModelOutputWithPoolingAndCrossAttentions(
691
+ last_hidden_state=sequence_output,
692
+ pooler_output=pooled_output,
693
+ )
694
+
695
+
696
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
697
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
698
+
699
+ def __init__(self, config):
700
+ super().__init__(config)
701
+
702
+ if config.is_decoder:
703
+ logger.warning(
704
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
705
+ "bi-directional self-attention."
706
+ )
707
+
708
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
709
+ self.lm_head = XLMRobertaLMHead(config)
710
+
711
+ # Initialize weights and apply final processing
712
+ self.post_init()
713
+
714
+ def get_input_embeddings(self):
715
+ return self.roberta.embeddings.word_embeddings
716
+
717
+ def get_output_embeddings(self):
718
+ return self.lm_head.decoder
719
+
720
+ def set_output_embeddings(self, new_embeddings):
721
+ self.lm_head.decoder = new_embeddings
722
+
723
+ def forward(
724
+ self,
725
+ input_ids: Optional[torch.LongTensor] = None,
726
+ attention_mask: Optional[torch.FloatTensor] = None,
727
+ token_type_ids: Optional[torch.LongTensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ head_mask: Optional[torch.FloatTensor] = None,
730
+ inputs_embeds: Optional[torch.FloatTensor] = None,
731
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
732
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
733
+ labels: Optional[torch.LongTensor] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
738
+ r"""
739
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
740
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
741
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
742
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
743
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
744
+ Used to hide legacy arguments that have been deprecated.
745
+ """
746
+ return_dict = (
747
+ return_dict if return_dict is not None else self.config.use_return_dict
748
+ )
749
+
750
+ outputs = self.roberta(
751
+ input_ids,
752
+ attention_mask=attention_mask,
753
+ token_type_ids=token_type_ids,
754
+ position_ids=position_ids,
755
+ head_mask=head_mask,
756
+ inputs_embeds=inputs_embeds,
757
+ encoder_hidden_states=encoder_hidden_states,
758
+ encoder_attention_mask=encoder_attention_mask,
759
+ output_attentions=output_attentions,
760
+ output_hidden_states=output_hidden_states,
761
+ return_dict=return_dict,
762
+ )
763
+ sequence_output = outputs[0]
764
+ prediction_scores = self.lm_head(sequence_output)
765
+
766
+ masked_lm_loss = None
767
+ if labels is not None:
768
+ # move labels to correct device to enable model parallelism
769
+ labels = labels.to(prediction_scores.device)
770
+ loss_fct = CrossEntropyLoss()
771
+ masked_lm_loss = loss_fct(
772
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
773
+ )
774
+
775
+ if not return_dict:
776
+ output = (prediction_scores,) + outputs[2:]
777
+ return (
778
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
779
+ )
780
+
781
+ return MaskedLMOutput(
782
+ loss=masked_lm_loss,
783
+ logits=prediction_scores,
784
+ hidden_states=outputs.hidden_states,
785
+ attentions=outputs.attentions,
786
+ )
787
+
788
+
789
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
790
+ class XLMRobertaClassificationHead(nn.Module):
791
+ """Head for sentence-level classification tasks."""
792
+
793
+ def __init__(self, config):
794
+ super().__init__()
795
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
796
+ if fused_bias_fc and FusedDense is None:
797
+ raise ImportError("fused_dense is not installed")
798
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
799
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
800
+ classifier_dropout = (
801
+ config.classifier_dropout
802
+ if config.classifier_dropout is not None
803
+ else config.hidden_dropout_prob
804
+ )
805
+ self.dropout = nn.Dropout(classifier_dropout)
806
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
807
+
808
+ def forward(self, features, **kwargs):
809
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
810
+ x = self.dropout(x)
811
+ x = self.dense(x)
812
+ x = torch.tanh(x)
813
+ x = self.dropout(x)
814
+ x = self.out_proj(x)
815
+ return x
816
+
817
+
818
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
819
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
820
+ def __init__(self, config):
821
+ super().__init__(config)
822
+ self.num_labels = config.num_labels
823
+ self.config = config
824
+
825
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
826
+ self.classifier = XLMRobertaClassificationHead(config)
827
+
828
+ # Initialize weights and apply final processing
829
+ self.post_init()
830
+
831
+ def forward(
832
+ self,
833
+ input_ids: Optional[torch.LongTensor] = None,
834
+ attention_mask: Optional[torch.FloatTensor] = None,
835
+ token_type_ids: Optional[torch.LongTensor] = None,
836
+ position_ids: Optional[torch.LongTensor] = None,
837
+ head_mask: Optional[torch.FloatTensor] = None,
838
+ inputs_embeds: Optional[torch.FloatTensor] = None,
839
+ labels: Optional[torch.LongTensor] = None,
840
+ output_attentions: Optional[bool] = None,
841
+ output_hidden_states: Optional[bool] = None,
842
+ return_dict: Optional[bool] = None,
843
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
844
+ r"""
845
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
846
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
847
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
+ """
850
+ return_dict = (
851
+ return_dict if return_dict is not None else self.config.use_return_dict
852
+ )
853
+
854
+ outputs = self.roberta(
855
+ input_ids,
856
+ attention_mask=attention_mask,
857
+ token_type_ids=token_type_ids,
858
+ position_ids=position_ids,
859
+ head_mask=head_mask,
860
+ inputs_embeds=inputs_embeds,
861
+ output_attentions=output_attentions,
862
+ output_hidden_states=output_hidden_states,
863
+ return_dict=return_dict,
864
+ )
865
+ sequence_output = outputs[0]
866
+ logits = self.classifier(sequence_output)
867
+
868
+ loss = None
869
+ if labels is not None:
870
+ # move labels to correct device to enable model parallelism
871
+ labels = labels.to(logits.device)
872
+ if self.config.problem_type is None:
873
+ if self.num_labels == 1:
874
+ self.config.problem_type = "regression"
875
+ elif self.num_labels > 1 and (
876
+ labels.dtype == torch.long or labels.dtype == torch.int
877
+ ):
878
+ self.config.problem_type = "single_label_classification"
879
+ else:
880
+ self.config.problem_type = "multi_label_classification"
881
+
882
+ if self.config.problem_type == "regression":
883
+ loss_fct = MSELoss()
884
+ if self.num_labels == 1:
885
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
886
+ else:
887
+ loss = loss_fct(logits, labels)
888
+ elif self.config.problem_type == "single_label_classification":
889
+ loss_fct = CrossEntropyLoss()
890
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
891
+ elif self.config.problem_type == "multi_label_classification":
892
+ loss_fct = BCEWithLogitsLoss()
893
+ loss = loss_fct(logits, labels)
894
+
895
+ if not return_dict:
896
+ output = (logits,) + outputs[2:]
897
+ return ((loss,) + output) if loss is not None else output
898
+
899
+ return SequenceClassifierOutput(
900
+ loss=loss,
901
+ logits=logits,
902
+ hidden_states=outputs.hidden_states,
903
+ attentions=outputs.attentions,
904
+ )
905
+
906
+
907
+ @torch.inference_mode()
908
+ def compute_score(
909
+ self,
910
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
911
+ batch_size: int = 32,
912
+ max_length: Optional[int] = None,
913
+ ) -> List[float]:
914
+
915
+ if not hasattr(self, "_tokenizer"):
916
+ from transformers import AutoTokenizer
917
+
918
+ self._tokenizer = AutoTokenizer.from_pretrained(
919
+ self.name_or_path, trust_remote_code=True
920
+ )
921
+
922
+ assert isinstance(sentence_pairs, list)
923
+ if isinstance(sentence_pairs[0], str):
924
+ sentence_pairs = [sentence_pairs]
925
+
926
+ all_scores = []
927
+ for start_index in range(
928
+ 0, len(sentence_pairs), batch_size
929
+ ):
930
+ sentences_batch = sentence_pairs[
931
+ start_index : start_index + batch_size
932
+ ]
933
+ inputs = self._tokenizer(
934
+ sentences_batch,
935
+ padding=True,
936
+ truncation=True,
937
+ return_tensors='pt',
938
+ max_length=max_length,
939
+ ).to(self.device)
940
+ scores = (
941
+ self.forward(**inputs, return_dict=True)
942
+ .logits.view(
943
+ -1,
944
+ )
945
+ .float()
946
+ )
947
+ scores = torch.sigmoid(scores)
948
+ all_scores.extend(scores.cpu().numpy().tolist())
949
+
950
+ if len(all_scores) == 1:
951
+ return all_scores[0]
952
+ return all_scores
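+ # Illustrative usage (a sketch; pair order is (query, document)):
+ #   scores = model.compute_score([('query', 'document a'), ('query', 'document b')],
+ #                                 batch_size=16)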
953
+
954
+ def predict(
955
+ self,
956
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
957
+ batch_size: int = 32,
958
+ max_length: Optional[int] = None,
959
+ ) -> List[float]:
960
+ # used for beir evaluation
961
+ return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
962
+
963
+ def rerank(
964
+ self,
965
+ query: str,
966
+ documents: List[str],
967
+ batch_size: int = 32,
968
+ max_length: int = 1024,
969
+ max_query_length: int = 512,
970
+ overlap_tokens: int = 80,
971
+ top_n: Optional[int] = None,
972
+ **kwargs,
973
+ ):
974
+ assert max_length >= max_query_length * 2, (
975
+ f'max_length ({max_length}) must be greater than or equal to '
976
+ f'max_query_length ({max_query_length}) * 2'
977
+ )
978
+
979
+ if not hasattr(self, "_tokenizer"):
980
+ from transformers import AutoTokenizer
981
+
982
+ self._tokenizer = AutoTokenizer.from_pretrained(
983
+ self.name_or_path, trust_remote_code=True
984
+ )
985
+
986
+ # preproc of tokenization
987
+ sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
988
+ query,
989
+ documents,
990
+ tokenizer=self._tokenizer,
991
+ max_length=max_length,
992
+ max_query_length=max_query_length,
993
+ overlap_tokens=overlap_tokens,
994
+ )
995
+
996
+ tot_scores = []
997
+ with torch.no_grad():
998
+ for k in range(0, len(sentence_pairs), batch_size):
999
+ batch = self._tokenizer.pad(
1000
+ sentence_pairs[k : k + batch_size],
1001
+ padding=True,
1002
+ max_length=max_length,
1003
+ pad_to_multiple_of=None,
1004
+ return_tensors="pt",
1005
+ )
1006
+ batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
1007
+ scores = (
1008
+ self.forward(**batch_on_device, return_dict=True)
1009
+ .logits.view(
1010
+ -1,
1011
+ )
1012
+ .float()
1013
+ )
1014
+ scores = torch.sigmoid(scores)
1015
+ tot_scores.extend(scores.cpu().numpy().tolist())
1016
+
1017
+ # ranking
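+ # Each document's final score is the maximum sigmoid score over its (possibly overlapping) chunks.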
1018
+ merge_scores = [0 for _ in range(len(documents))]
1019
+ for pid, score in zip(sentence_pairs_pids, tot_scores):
1020
+ merge_scores[pid] = max(merge_scores[pid], score)
1021
+
1022
+ merge_scores_argsort = np.argsort(merge_scores)[::-1]
1023
+ sorted_documents = []
1024
+ sorted_scores = []
1025
+ for mid in merge_scores_argsort:
1026
+ sorted_scores.append(merge_scores[mid])
1027
+ sorted_documents.append(documents[mid])
1028
+
1029
+ top_n = min(top_n or len(sorted_documents), len(sorted_documents))
1030
+
1031
+ return [
1032
+ {
1033
+ 'document': sorted_documents[i],
1034
+ 'relevance_score': sorted_scores[i],
1035
+ 'index': merge_scores_argsort[i],
1036
+ }
1037
+ for i in range(top_n)
1038
+ ]
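+ # Illustrative usage (a sketch):
+ #   results = model.rerank('query text', ['document one', 'document two'], top_n=2)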
1039
+
1040
+
1041
+ def reranker_tokenize_preproc(
1042
+ query: str,
1043
+ passages: List[str],
1044
+ tokenizer=None,
1045
+ max_length: int = 1024,
1046
+ max_query_length: int = 512,
1047
+ overlap_tokens: int = 80,
1048
+ ):
1049
+ from copy import deepcopy
1050
+
1051
+ assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
1052
+ sep_id = tokenizer.sep_token_id
1053
+
1054
+ def _merge_inputs(chunk1_raw, chunk2):
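+ # Build a query/passage pair by appending a separator, the passage-chunk token ids, and a
+ # trailing separator to a copy of the query inputs, extending attention_mask (and
+ # token_type_ids, when present) to match.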
1055
+ chunk1 = deepcopy(chunk1_raw)
1056
+ chunk1['input_ids'].append(sep_id)
1057
+ chunk1['input_ids'].extend(chunk2['input_ids'])
1058
+ chunk1['input_ids'].append(sep_id)
1059
+ chunk1['attention_mask'].append(chunk2['attention_mask'][0])
1060
+ chunk1['attention_mask'].extend(chunk2['attention_mask'])
1061
+ chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
1062
+ if 'token_type_ids' in chunk1:
1063
+ token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
1064
+ chunk1['token_type_ids'].extend(token_type_ids)
1065
+ return chunk1
1066
+
1067
+ # Note: a long query will be truncated to max_query_length tokens
1068
+ query_inputs = tokenizer.encode_plus(
1069
+ query, truncation=True, padding=False, max_length=max_query_length
1070
+ )
1071
+
1072
+ max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
1073
+ # assert (
1074
+ # max_passage_inputs_length > 100
1075
+ # ), "Your query is too long! Please make sure your query less than 500 tokens!"
1076
+
1077
+ overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
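+ # Passages longer than max_passage_inputs_length are split below into windows of that
+ # size, with overlap_tokens_implt tokens shared between consecutive windows.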
1078
+
1079
+ res_merge_inputs = []
1080
+ res_merge_inputs_pids = []
1081
+ for pid, passage in enumerate(passages):
1082
+ passage_inputs = tokenizer.encode_plus(
1083
+ passage,
1084
+ truncation=False,
1085
+ padding=False,
1086
+ add_special_tokens=False,
1087
+ max_length=0,
1088
+ )
1089
+ passage_inputs_length = len(passage_inputs['input_ids'])
1090
+
1091
+ if passage_inputs_length <= max_passage_inputs_length:
1092
+ qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
1093
+ res_merge_inputs.append(qp_merge_inputs)
1094
+ res_merge_inputs_pids.append(pid)
1095
+ else:
1096
+ start_id = 0
1097
+ while start_id < passage_inputs_length:
1098
+ end_id = start_id + max_passage_inputs_length
1099
+ # make sure the length of the last chunk is `max_passage_inputs_length`
1100
+ if end_id >= passage_inputs_length:
1101
+ sub_passage_inputs = {
1102
+ k: v[-max_passage_inputs_length:]
1103
+ for k, v in passage_inputs.items()
1104
+ }
1105
+ else:
1106
+ sub_passage_inputs = {
1107
+ k: v[start_id:end_id] for k, v in passage_inputs.items()
1108
+ }
1109
+ start_id = (
1110
+ end_id - overlap_tokens_implt
1111
+ if end_id < passage_inputs_length
1112
+ else end_id
1113
+ )
1114
+
1115
+ qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
1116
+ res_merge_inputs.append(qp_merge_inputs)
1117
+ res_merge_inputs_pids.append(pid)
1118
+
1119
+ return res_merge_inputs, res_merge_inputs_pids