oualidlamrini commited on
Commit
f1d58be
1 Parent(s): 742d190

Delete Doc-classification-model

Browse files
Doc-classification-model/config.json DELETED
@@ -1,65 +0,0 @@
1
- {
2
- "_name_or_path": "ccdv/lsg-xlm-roberta-base-4096",
3
- "adaptive": true,
4
- "architectures": [
5
- "LSGXLMRobertaForSequenceClassification"
6
- ],
7
- "attention_probs_dropout_prob": 0.1,
8
- "auto_map": {
9
- "AutoConfig": "modeling_lsg_xlm_roberta.LSGXLMRobertaConfig",
10
- "AutoModel": "modeling_lsg_xlm_roberta.LSGXLMRobertaModel",
11
- "AutoModelForCausalLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForCausalLM",
12
- "AutoModelForMaskedLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMaskedLM",
13
- "AutoModelForMultipleChoice": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMultipleChoice",
14
- "AutoModelForQuestionAnswering": "modeling_lsg_xlm_roberta.LSGXLMRobertaForQuestionAnswering",
15
- "AutoModelForSequenceClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForSequenceClassification",
16
- "AutoModelForTokenClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForTokenClassification"
17
- },
18
- "base_model_prefix": "lsg",
19
- "block_size": 256,
20
- "bos_token_id": 0,
21
- "classifier_dropout": null,
22
- "eos_token_id": 2,
23
- "hidden_act": "gelu",
24
- "hidden_dropout_prob": 0.1,
25
- "hidden_size": 768,
26
- "id2label": {
27
- "0": "PRETS IMMOBILIERS, CONSOMMATIONS ET DETTES (REPRIS ET CONSERVES)",
28
- "1": "RELEVE DE COMPTE ET RIB",
29
- "2": "AUTRES",
30
- "3": "HABITATION",
31
- "4": "REVENUS ET IMPOTS",
32
- "5": "IDENTITE ET SITUATION PERSONNELLE"
33
- },
34
- "initializer_range": 0.02,
35
- "intermediate_size": 3072,
36
- "label2id": {
37
- "AUTRES": 2,
38
- "HABITATION": 3,
39
- "IDENTITE ET SITUATION PERSONNELLE": 5,
40
- "PRETS IMMOBILIERS, CONSOMMATIONS ET DETTES (REPRIS ET CONSERVES)": 0,
41
- "RELEVE DE COMPTE ET RIB": 1,
42
- "REVENUS ET IMPOTS": 4
43
- },
44
- "layer_norm_eps": 1e-05,
45
- "lsh_num_pre_rounds": 1,
46
- "mask_first_token": true,
47
- "max_position_embeddings": 4098,
48
- "model_type": "xlm-roberta",
49
- "num_attention_heads": 12,
50
- "num_global_tokens": 1,
51
- "num_hidden_layers": 12,
52
- "output_past": true,
53
- "pad_token_id": 1,
54
- "pool_with_global": true,
55
- "position_embedding_type": "absolute",
56
- "problem_type": "single_label_classification",
57
- "sparse_block_size": 0,
58
- "sparsity_factor": 4,
59
- "sparsity_type": "none",
60
- "torch_dtype": "float32",
61
- "transformers_version": "4.40.2",
62
- "type_vocab_size": 1,
63
- "use_cache": true,
64
- "vocab_size": 250002
65
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Doc-classification-model/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4abaf3f3d2f0ba394c6e0b8f5605d5d49b6db5595e14a14aad9244bb0c68bc03
3
- size 1124800360
 
 
 
 
Doc-classification-model/modeling_lsg_xlm_roberta.py DELETED
@@ -1,1207 +0,0 @@
1
- from logging import warn
2
- from transformers.models.xlm_roberta.modeling_xlm_roberta import *
3
- import torch
4
- import torch.nn as nn
5
- from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
6
- import sys
7
-
8
- AUTO_MAP = {
9
- "AutoModel": "modeling_lsg_xlm_roberta.LSGXLMRobertaModel",
10
- "AutoModelForCausalLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForCausalLM",
11
- "AutoModelForMaskedLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMaskedLM",
12
- "AutoModelForMultipleChoice": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMultipleChoice",
13
- "AutoModelForQuestionAnswering": "modeling_lsg_xlm_roberta.LSGXLMRobertaForQuestionAnswering",
14
- "AutoModelForSequenceClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForSequenceClassification",
15
- "AutoModelForTokenClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForTokenClassification"
16
- }
17
-
18
- class LSGXLMRobertaConfig(XLMRobertaConfig):
19
- """
20
- This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
21
- documentation alongside usage examples.
22
- """
23
-
24
- base_model_prefix = "lsg"
25
- model_type = "xlm-roberta"
26
-
27
- def __init__(
28
- self,
29
- adaptive=True,
30
- base_model_prefix="lsg",
31
- block_size=128,
32
- lsh_num_pre_rounds=1,
33
- mask_first_token=False,
34
- num_global_tokens=1,
35
- pool_with_global=True,
36
- sparse_block_size=128,
37
- sparsity_factor=2,
38
- sparsity_type="norm",
39
- **kwargs
40
- ):
41
- """Constructs LSGXLMRobertaConfig."""
42
- super().__init__(**kwargs)
43
-
44
- self.adaptive = adaptive
45
- self.auto_map = AUTO_MAP
46
- self.base_model_prefix = base_model_prefix
47
- self.block_size = block_size
48
- self.lsh_num_pre_rounds = lsh_num_pre_rounds
49
- self.mask_first_token = mask_first_token
50
- self.num_global_tokens = num_global_tokens
51
- self.pool_with_global = pool_with_global
52
- self.sparse_block_size = sparse_block_size
53
- self.sparsity_factor = sparsity_factor
54
- self.sparsity_type = sparsity_type
55
-
56
- if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
57
- logger.warning(
58
- "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
59
- setting sparsity_type=None, computation will skip sparse attention")
60
- self.sparsity_type = None
61
-
62
- if self.sparsity_type in ["stride", "block_stride"]:
63
- if self.sparsity_factor > self.num_attention_heads:
64
- logger.warning(
65
- "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
66
- )
67
-
68
- if self.num_global_tokens < 1:
69
- logger.warning(
70
- "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
71
- )
72
- self.num_global_tokens = 1
73
- elif self.num_global_tokens > 512:
74
- logger.warning(
75
- "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
76
- )
77
- self.num_global_tokens = 512
78
-
79
- if self.sparsity_factor > 0:
80
- assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
81
- assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
82
-
83
- if self.mask_first_token and not pool_with_global:
84
- logger.warning(
85
- "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
86
- self.pool_with_global = True
87
-
88
- if hasattr(self, "position_embedding_type"):
89
- if self.position_embedding_type != "absolute":
90
- logger.warning(
91
- "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
92
-
93
-
94
- class BaseSelfAttention(nn.Module):
95
-
96
- def init_modules(self, config):
97
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
98
- config, "embedding_size"
99
- ):
100
- raise ValueError(
101
- "The hidden size (%d) is not a multiple of the number of attention "
102
- "heads (%d)" % (config.hidden_size, config.num_attention_heads)
103
- )
104
-
105
- self.num_attention_heads = config.num_attention_heads
106
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
107
- self.all_head_size = self.num_attention_heads * self.attention_head_size
108
-
109
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
110
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
111
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
112
-
113
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
114
-
115
- def transpose_for_scores(self, x):
116
- new_x_shape = x.size()[:-1] + (
117
- self.num_attention_heads,
118
- self.attention_head_size,
119
- )
120
- x = x.view(*new_x_shape)
121
- return x.permute(0, 2, 1, 3)
122
-
123
- def reshape_output(self, context_layer):
124
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
125
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
126
- return context_layer.view(*new_context_layer_shape)
127
-
128
- def project_QKV(self, hidden_states):
129
-
130
- query_layer = self.transpose_for_scores(self.query(hidden_states))
131
- key_layer = self.transpose_for_scores(self.key(hidden_states))
132
- value_layer = self.transpose_for_scores(self.value(hidden_states))
133
- return query_layer, key_layer, value_layer
134
-
135
-
136
- class BaseAttentionProduct(nn.Module):
137
-
138
- def __init__(self, config):
139
- """
140
- Compute attention: softmax(Q @ K.T) @ V
141
- """
142
- super().__init__()
143
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
144
-
145
- def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
146
-
147
- d = query_layer.shape[-1]
148
-
149
- # Take the dot product between "query" and "key" to get the raw attention scores.
150
- attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
151
-
152
- del query_layer
153
- del key_layer
154
-
155
- if attention_mask is not None:
156
- # Apply the attention mask is (precomputed for all layers in XLMRobertaModel forward() function)
157
- attention_scores = attention_scores + attention_mask
158
- del attention_mask
159
-
160
- # Normalize the attention scores to probabilities.
161
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
162
-
163
- # This is actually dropping out entire tokens to attend to, which might
164
- # seem a bit unusual, but is taken from the original Transformer paper.
165
- context_layer = self.dropout(attention_probs) @ value_layer
166
-
167
- return context_layer
168
-
169
-
170
- class CausalAttentionProduct(nn.Module):
171
-
172
- def __init__(self, config):
173
- """
174
- Compute attention: softmax(Q @ K.T) @ V
175
- """
176
- super().__init__()
177
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
178
- self.block_size = config.block_size
179
-
180
- def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
181
-
182
- d = query_layer.shape[-1]
183
-
184
- # Take the dot product between "query" and "key" to get the raw attention scores.
185
- attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
186
-
187
- del query_layer
188
- del key_layer
189
-
190
- if attention_mask is not None:
191
- # Add causal mask
192
- causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
193
- causal_mask = torch.tril(
194
- torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
195
- diagonal=-1
196
- )
197
-
198
- # Min value
199
- dtype_min = torch.tensor(
200
- torch.finfo(attention_scores.dtype).min, device=attention_scores.device, dtype=attention_scores.dtype
201
- )
202
-
203
- # Build causal + attention_mask
204
- causal_mask = torch.nn.functional.pad(causal_mask.T * dtype_min, (attention_mask.size()[-1] - self.block_size, 0), value=0)
205
- attention_mask = torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype_min)
206
-
207
- attention_scores = attention_scores + attention_mask
208
- del attention_mask
209
- del causal_mask
210
-
211
- # Normalize the attention scores to probabilities.
212
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
213
-
214
- # This is actually dropping out entire tokens to attend to, which might
215
- # seem a bit unusual, but is taken from the original Transformer paper.
216
- context_layer = self.dropout(attention_probs) @ value_layer
217
-
218
- return context_layer
219
-
220
-
221
- class LSGAttentionProduct(nn.Module):
222
-
223
- def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
224
- """
225
- Compute block or overlapping blocks attention products
226
- """
227
- super().__init__()
228
-
229
- self.block_size = block_size
230
- self.sparse_block_size = sparse_block_size
231
- self.sparsity_factor = sparsity_factor
232
- self.is_causal = is_causal
233
-
234
- if self.block_size is None:
235
- self.block_size = config.block_size
236
-
237
- if self.sparse_block_size is None:
238
- self.sparse_block_size = config.sparse_block_size
239
-
240
- # Shape of blocks
241
- self.local_shapes = (self.block_size*3, self.block_size)
242
- if self.sparse_block_size and self.sparsity_factor > 0:
243
- self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
244
-
245
- if is_causal:
246
- self.attention = CausalAttentionProduct(config)
247
- else:
248
- self.attention = BaseAttentionProduct(config)
249
-
250
- def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
251
-
252
- # Build local tokens
253
- local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
254
- del hidden_states
255
-
256
- # Build sparse tokens
257
- if sparse_hidden_states is not None:
258
- sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
259
-
260
- return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
261
-
262
- def forward(
263
- self,
264
- query_layer,
265
- key_layer,
266
- value_layer,
267
- attention_mask=None,
268
- sparse_key=None,
269
- sparse_value=None,
270
- sparse_mask=None,
271
- global_key=None,
272
- global_value=None,
273
- global_mask=None
274
- ):
275
-
276
- # Input batch, heads, length, hidden_size
277
- n, h, t, d = query_layer.size()
278
- n_blocks = t // self.block_size
279
- assert t % self.block_size == 0
280
-
281
- key_layer = self.build_lsg_inputs(
282
- key_layer,
283
- sparse_key,
284
- global_key
285
- )
286
- del sparse_key
287
- del global_key
288
-
289
- value_layer = self.build_lsg_inputs(
290
- value_layer,
291
- sparse_value,
292
- global_value
293
- )
294
- del sparse_value
295
- del global_value
296
-
297
- attention_mask = self.build_lsg_inputs(
298
- attention_mask,
299
- sparse_mask,
300
- global_mask.transpose(-1, -2),
301
- is_attn_mask=True
302
- ).transpose(-1, -2)
303
- del sparse_mask
304
- del global_mask
305
-
306
- # expect (..., t, d) shape
307
- # Compute attention
308
- context_layer = self.attention(
309
- query_layer=self.chunk(query_layer, n_blocks),
310
- key_layer=key_layer,
311
- value_layer=value_layer,
312
- attention_mask=attention_mask
313
- )
314
-
315
- return context_layer.reshape(n, h, -1, d)
316
-
317
- def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
318
-
319
- size, step = self.local_shapes
320
- s = (size - step) // 2
321
-
322
- # Pad before block reshaping
323
- if is_attn_mask:
324
- pad_value = torch.finfo(hidden_states.dtype).min
325
- hidden_states = hidden_states.transpose(-1, -2)
326
- else:
327
- pad_value = 0
328
-
329
- hidden_states = torch.nn.functional.pad(
330
- hidden_states.transpose(-1, -2),
331
- pad=(s, s),
332
- value=pad_value
333
- ).transpose(-1, -2)
334
-
335
- # Make blocks
336
- hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
337
-
338
- # Skip third block if causal
339
- if self.is_causal:
340
- return hidden_states[..., :size*2//3, :]
341
-
342
- return hidden_states
343
-
344
- def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
345
-
346
- size, step = self.sparse_shapes
347
-
348
- # In case of odd case
349
- odd_offset = (step % 2)
350
-
351
- # n, h, t, d*2 + 1
352
- size = size*2
353
- s = (size - step) // 2 + odd_offset
354
-
355
- # Pad before block reshaping
356
- if is_attn_mask:
357
- pad_value = torch.finfo(hidden_states.dtype).min
358
- hidden_states = hidden_states.transpose(-1, -2)
359
- else:
360
- pad_value = 0
361
-
362
- hidden_states = torch.nn.functional.pad(
363
- hidden_states.transpose(-1, -2),
364
- pad=(s, s),
365
- value=pad_value
366
- ).transpose(-1, -2)
367
-
368
- # Make blocks
369
- hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
370
-
371
- # Fix case where block_size == sparsify_factor
372
- if odd_offset:
373
- hidden_states = hidden_states[..., :-1, :, :]
374
-
375
- # Indexes for selection
376
- u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
377
- s = self.sparse_block_size
378
-
379
- # Skip right block if causal
380
- if self.is_causal:
381
- return hidden_states[..., u-s:u, :]
382
-
383
- u_ = u + odd_offset
384
- return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
385
-
386
- def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
387
-
388
- n, h, b, t, d = x_local.size()
389
- x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
390
- if x_sparse is not None:
391
- return torch.cat([x_global, x_sparse, x_local], dim=dim)
392
- return torch.cat([x_global, x_local], dim=dim)
393
-
394
- def chunk(self, x, n_blocks):
395
-
396
- t, d = x.size()[-2:]
397
- return x.reshape(*x.size()[:-2], n_blocks, -1, d)
398
-
399
-
400
- class LSGXLMRobertaEmbeddings(XLMRobertaEmbeddings):
401
-
402
- def __init__(self, config):
403
- super().__init__(config)
404
-
405
- self.num_global_tokens = config.num_global_tokens
406
-
407
- # Hardcoded but partially trained
408
- self.global_embeddings = nn.Embedding(512, embedding_dim=config.hidden_size, )
409
-
410
- self.block_size = config.block_size
411
-
412
- def forward(
413
- self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
414
- ):
415
- if position_ids is None:
416
- if input_ids is not None:
417
- # Create the position ids from the input token ids. Any padded tokens remain padded.
418
- position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
419
- else:
420
- position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
421
-
422
- if input_ids is not None:
423
- input_shape = input_ids.size()
424
- else:
425
- input_shape = inputs_embeds.size()[:-1]
426
-
427
- seq_length = input_shape[1]
428
-
429
- # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
430
- # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
431
- # issue #5664
432
- if token_type_ids is None:
433
- if hasattr(self, "token_type_ids"):
434
- buffered_token_type_ids = self.token_type_ids[:, :seq_length]
435
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
436
- token_type_ids = buffered_token_type_ids_expanded
437
- else:
438
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
439
-
440
- if inputs_embeds is None:
441
- inputs_embeds = self.word_embeddings(input_ids)
442
- token_type_embeddings = self.token_type_embeddings(token_type_ids[:, :seq_length])
443
-
444
- embeddings = inputs_embeds + token_type_embeddings
445
- if self.position_embedding_type == "absolute":
446
- position_embeddings = self.position_embeddings(position_ids[:, :seq_length])
447
- embeddings += position_embeddings
448
-
449
- #if self.num_global_tokens < 0:
450
- n, t, d = embeddings.size()
451
-
452
- # Add global_tokens
453
- indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
454
- global_embeddings = self.global_embeddings(indexes)
455
- embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
456
-
457
- embeddings = self.LayerNorm(embeddings)
458
- embeddings = self.dropout(embeddings)
459
- return embeddings
460
-
461
-
462
- class LSGAttention(XLMRobertaAttention):
463
-
464
- def __init__(self, config):
465
-
466
- super().__init__(config)
467
-
468
- self.self = LSGSelfAttention(config)
469
-
470
-
471
- class LSGSelfAttention(BaseSelfAttention):
472
- '''
473
- Compute local attention with overlapping blocs
474
- Use global attention for tokens with highest norm
475
- '''
476
- def __init__(self, config):
477
- super().__init__()
478
-
479
- self.init_modules(config)
480
-
481
- self.block_size = config.block_size
482
- self.sparse_block_size = config.sparse_block_size
483
- self.num_global_tokens = config.num_global_tokens
484
- self.sparsity_factor = config.sparsity_factor
485
- self.is_causal = config.is_decoder
486
- self.is_decoder = config.is_decoder
487
-
488
- self.attention = LSGAttentionProduct(
489
- config,
490
- block_size=config.block_size,
491
- sparse_block_size=config.sparse_block_size,
492
- sparsity_factor=self.sparsity_factor,
493
- is_causal=self.is_causal
494
- )
495
-
496
- if self.is_causal:
497
- self.causal_attention = CausalAttentionProduct(config)
498
- self.full_attention = BaseAttentionProduct(config)
499
-
500
- sparse_functions = {
501
- "norm": self.get_sparse_tokens_with_norm,
502
- "pooling": self.get_sparse_tokens_with_pooling,
503
- "lsh": self.get_sparse_tokens_with_lsh,
504
- "stride": self.get_sparse_tokens_with_stride,
505
- "block_stride": self.get_sparse_tokens_with_block_stride,
506
- "bos_pooling": self.get_sparse_tokens_with_bos_pooling
507
- }
508
-
509
- self.sparsity_type = config.sparsity_type
510
- self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
511
-
512
- if config.sparsity_type == "lsh":
513
- self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
514
-
515
- def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
516
-
517
- if self.sparsity_factor == 1:
518
- return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
519
-
520
- with torch.no_grad():
521
-
522
- block_size = min(self.block_size, self.sparse_block_size)
523
- key_norm = keys.detach().norm(dim=-1, keepdim=True)
524
- key_norm = key_norm * ~mask.transpose(-1, -2).bool()
525
- key_norm = self.chunk(key_norm, block_size)
526
-
527
- n, h, b, t, d = key_norm.size()
528
-
529
- idx = key_norm.argsort(dim=-2)
530
- del key_norm
531
- idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
532
-
533
- split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
534
- sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
535
-
536
- d = keys.size()[-1]
537
- keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
538
- values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
539
- mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
540
-
541
- return keys, values, mask
542
-
543
- def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
544
-
545
- if self.sparsity_factor == 1:
546
- return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
547
-
548
- keys = self.chunk(keys, self.sparsity_factor)
549
- values = self.chunk(values, self.sparsity_factor)
550
-
551
- n, h, b, t, d = keys.size()
552
- mask = mask.reshape(n, 1, b, 1, t)
553
- mask = ~mask.transpose(-1, -2).bool()
554
-
555
- keys = keys * mask
556
- values = values * mask
557
-
558
- mask = mask.sum(dim=-2)
559
- keys = keys.sum(dim=-2) / (mask + 1e-6)
560
- values = values.sum(dim=-2) / (mask + 1e-6)
561
-
562
- mask = (1. - mask.clamp(0, 1))
563
- mask *= torch.finfo(mask.dtype).min
564
- return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
565
-
566
- def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
567
-
568
- if self.sparsity_factor == 1:
569
- return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
570
-
571
- n, h, t, d = keys.size()
572
- sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
573
- sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
574
- sparse_idx = sparse_idx.expand(n, h, -1, 1)
575
-
576
- keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
577
- values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
578
- mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
579
-
580
- return keys, values, mask
581
-
582
- def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
583
-
584
- if self.sparsity_factor == 1:
585
- return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
586
-
587
- n, h, t, d = keys.size()
588
-
589
- t, b = self.block_size, t // self.block_size
590
- sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
591
- sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
592
- sparse_idx = (sparse_idx % t)
593
- sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
594
- sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
595
-
596
- keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
597
- values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
598
- mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
599
-
600
- return keys, values, mask
601
-
602
- def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
603
-
604
- if self.sparsity_factor == 1:
605
- return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
606
-
607
- if self.sparsity_factor == self.sparse_block_size:
608
- return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
609
-
610
- block_size = min(self.block_size, self.sparse_block_size)
611
- keys = self.chunk(keys, block_size)
612
- values = self.chunk(values, block_size)
613
-
614
- n, h, b, t, d = keys.size()
615
- mask = mask.reshape(n, 1, b, 1, t)
616
- mask = ~mask.transpose(-1, -2).bool()
617
-
618
- keys = keys * mask
619
- values = values * mask
620
- mask = mask.expand(-1, h, -1, -1, -1).float()
621
-
622
- extra_factor = 1
623
-
624
- for _ in range(self.lsh_num_pre_rounds):
625
- keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
626
-
627
- keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
628
- keys /= mask + 1e-8
629
- values /= mask + 1e-8
630
-
631
- mask = (1. - mask.clamp(0, 1))
632
- mask *= torch.finfo(mask.dtype).min
633
-
634
- return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
635
-
636
- def lsh_round(self, keys, values, mask, output_size):
637
-
638
- with torch.no_grad():
639
-
640
- n_hashes = output_size // 2
641
- n, h, b, t, d = keys.size()
642
- binary_mask = mask.clamp(0, 1)
643
-
644
- indexes = (torch.nn.functional.normalize(keys, dim=-1) * binary_mask) @ torch.randn(1, h, 1, d, n_hashes, device=keys.device)
645
- indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)
646
-
647
- n, h, b, t, d = keys.size()
648
-
649
- x_ = torch.zeros(n, h, b, output_size, d, device=keys.device)
650
- mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device)
651
- keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
652
- values = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
653
- mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)
654
-
655
- return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
656
-
657
- def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
658
-
659
- if self.sparsity_factor == 1:
660
- return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
661
-
662
- queries = queries.unsqueeze(-3)
663
- mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
664
- keys = self.chunk(keys, self.sparsity_factor)
665
- values = self.chunk(values, self.sparsity_factor)
666
-
667
- n, h, b, t, d = keys.size()
668
- scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
669
- if mask is not None:
670
- scores = scores + mask
671
-
672
- scores = torch.softmax(scores, dim=-1)
673
- keys = scores @ keys
674
- values = scores @ values
675
- mask = mask.mean(dim=-1)
676
- mask[mask != torch.finfo(mask.dtype).min] = 0
677
-
678
- return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
679
-
680
- def forward(
681
- self,
682
- hidden_states,
683
- attention_mask=None,
684
- head_mask=None,
685
- encoder_hidden_states=None,
686
- encoder_attention_mask=None,
687
- past_key_value=None,
688
- output_attentions=False,
689
- ):
690
-
691
- query_layer = self.query(hidden_states)
692
-
693
- # If this is instantiated as a cross-attention module, the keys
694
- # and values come from an encoder; the attention mask needs to be
695
- # such that the encoder's padding tokens are not attended to.
696
- is_cross_attention = encoder_hidden_states is not None
697
-
698
- if is_cross_attention and past_key_value is not None:
699
- # reuse k,v, cross_attentions
700
- key_layer = past_key_value[0]
701
- value_layer = past_key_value[1]
702
- attention_mask = encoder_attention_mask
703
- elif is_cross_attention:
704
- key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
705
- value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
706
- attention_mask = encoder_attention_mask
707
- elif past_key_value is not None:
708
- key_layer = self.transpose_for_scores(self.key(hidden_states))
709
- value_layer = self.transpose_for_scores(self.value(hidden_states))
710
- key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
711
- value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
712
- else:
713
- key_layer = self.transpose_for_scores(self.key(hidden_states))
714
- value_layer = self.transpose_for_scores(self.value(hidden_states))
715
-
716
- query_layer = self.transpose_for_scores(query_layer)
717
-
718
- if self.is_decoder:
719
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
720
- # Further calls to cross_attention layer can then reuse all cross-attention
721
- # key/value_states (first "if" case)
722
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
723
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
724
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
725
- # if encoder bi-directional self-attention `past_key_value` is always `None`
726
- past_key_value = (key_layer, value_layer)
727
-
728
- if is_cross_attention:
729
- outputs = self.cross_attention_forward(
730
- query_layer=query_layer,
731
- key_layer=key_layer,
732
- value_layer=value_layer,
733
- attention_mask=attention_mask,
734
- output_attentions=output_attentions
735
- )
736
- else:
737
- outputs = self.causal_forward(
738
- query_layer,
739
- key_layer,
740
- value_layer,
741
- attention_mask=attention_mask,
742
- output_attentions=output_attentions,
743
- )
744
-
745
- outputs = outputs + ((key_layer, value_layer),)
746
-
747
- else:
748
- outputs = self.not_causal_forward(
749
- query_layer,
750
- key_layer,
751
- value_layer,
752
- attention_mask=attention_mask,
753
- output_attentions=output_attentions
754
- )
755
-
756
- return outputs
757
-
758
- def causal_forward(
759
- self,
760
- query_layer,
761
- key_layer,
762
- value_layer,
763
- attention_mask=None,
764
- output_attentions=False,
765
- ):
766
-
767
- n, h, t, d = key_layer.size()
768
-
769
- # Cat global mask
770
- attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
771
-
772
- # Split input into global tokens and other tokens
773
- split = (self.num_global_tokens, t - self.num_global_tokens)
774
- global_query, query_layer = query_layer.split(split, dim=-2)
775
-
776
- # Use normal causal attention if local attention covers every tokens
777
- if t <= 2 * self.block_size + self.num_global_tokens:
778
- context_layer = self.causal_attention(
779
- query_layer=query_layer,
780
- key_layer=key_layer,
781
- value_layer=value_layer,
782
- attention_mask=attention_mask,
783
- causal_shape=(t - self.num_global_tokens, t - self.num_global_tokens)
784
- )
785
-
786
- context_layer = torch.cat([global_query, context_layer], dim=-2)
787
- return (self.reshape_output(context_layer), )
788
-
789
- # Split K Q M on global and non global
790
- global_key, key_layer = key_layer.split(split, dim=-2)
791
- global_value, value_layer = value_layer.split(split, dim=-2)
792
- global_mask, attention_mask = attention_mask.split(split, dim=-1)
793
-
794
- n, h, t, d = key_layer.size()
795
-
796
- # Get sparse idx
797
- sparse_key, sparse_value, sparse_mask = (None, None, None)
798
- if self.sparse_block_size and self.sparsity_factor > 0:
799
- sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
800
-
801
- # Expand masks on heads
802
- attention_mask = attention_mask.expand(-1, h, -1, -1)
803
- global_mask = global_mask.expand(-1, h, -1, -1)
804
-
805
- # Compute dot product attention
806
- context_layer = self.attention(
807
- query_layer,
808
- key_layer,
809
- value_layer,
810
- attention_mask,
811
- sparse_key=sparse_key,
812
- sparse_value=sparse_value,
813
- sparse_mask=sparse_mask,
814
- global_key=global_key,
815
- global_value=global_value,
816
- global_mask=global_mask
817
- )
818
-
819
- # Merge pseudo global (causal) and local-sparse tokens
820
- context_layer = torch.cat([global_query, context_layer], dim=-2)
821
- context_layer = self.reshape_output(context_layer)
822
-
823
- return (context_layer,)
824
-
825
- def not_causal_forward(
826
- self,
827
- query_layer,
828
- key_layer,
829
- value_layer,
830
- attention_mask=None,
831
- output_attentions=False,
832
- ):
833
-
834
- n, h, t, d = query_layer.size()
835
-
836
- # Cat global mask
837
- attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
838
-
839
- # Use normal attention if local attention covers every tokens
840
- if t <= 2 * self.block_size + self.num_global_tokens:
841
- context_layer = self.full_attention(
842
- query_layer=query_layer,
843
- key_layer=key_layer,
844
- value_layer=value_layer,
845
- attention_mask=attention_mask
846
- )
847
- return (self.reshape_output(context_layer), )
848
-
849
- # Split input into global tokens and other tokens
850
- split = (self.num_global_tokens, t - self.num_global_tokens)
851
- global_query, query_layer = query_layer.split(split, dim=-2)
852
-
853
- # Get global_attention
854
- bos = self.full_attention(
855
- query_layer=global_query,
856
- key_layer=key_layer,
857
- value_layer=value_layer,
858
- attention_mask=attention_mask
859
- )
860
-
861
- # Split K Q M on global and non global
862
- global_key, key_layer = key_layer.split(split, dim=-2)
863
- global_value, value_layer = value_layer.split(split, dim=-2)
864
- global_mask, attention_mask = attention_mask.split(split, dim=-1)
865
-
866
- n, h, t, d = key_layer.size()
867
-
868
- # Get sparse idx
869
- sparse_key, sparse_value, sparse_mask = (None, None, None)
870
-
871
- if self.sparse_block_size and self.sparsity_factor > 0:
872
- sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
873
-
874
- # Expand masks on heads
875
- attention_mask = attention_mask.expand(-1, h, -1, -1)
876
- global_mask = global_mask.expand(-1, h, -1, -1)
877
-
878
- # Compute dot product attention
879
- context_layer = self.attention(
880
- query_layer,
881
- key_layer,
882
- value_layer,
883
- attention_mask,
884
- sparse_key=sparse_key,
885
- sparse_value=sparse_value,
886
- sparse_mask=sparse_mask,
887
- global_key=global_key,
888
- global_value=global_value,
889
- global_mask=global_mask
890
- )
891
-
892
- # Merge global and local-sparse tokens
893
- context_layer = torch.cat([bos, context_layer], dim=-2)
894
- context_layer = self.reshape_output(context_layer)
895
-
896
- return (context_layer,)
897
-
898
- def cross_attention_forward(
899
- self,
900
- query_layer,
901
- key_layer,
902
- value_layer,
903
- attention_mask=None,
904
- output_attentions=False,
905
- ):
906
-
907
- context_layer = self.full_attention(
908
- query_layer=query_layer,
909
- key_layer=key_layer,
910
- value_layer=value_layer,
911
- attention_mask=attention_mask
912
- )
913
- return (self.reshape_output(context_layer), )
914
-
915
- def chunk(self, x, chunk_size):
916
-
917
- n, h, t, d = x.size()
918
- return x.reshape(n, h, -1, chunk_size, d)
919
-
920
-
921
- class LSGXLMRobertaLayer(XLMRobertaLayer):
922
-
923
- def __init__(self, config):
924
-
925
- super().__init__(config)
926
-
927
- self.attention = LSGAttention(config)
928
- if self.add_cross_attention:
929
- assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
930
- self.crossattention = LSGAttention(config)
931
-
932
-
933
- class LSGXLMRobertaEncoder(XLMRobertaEncoder):
934
-
935
- def __init__(self, config):
936
-
937
- super().__init__(config)
938
- self.layer = nn.ModuleList([LSGXLMRobertaLayer(config) for _ in range(config.num_hidden_layers)])
939
-
940
- assert hasattr(config, "num_global_tokens")
941
- self.num_global_tokens = config.num_global_tokens
942
- self.pad_idx = config.pad_token_id
943
-
944
- assert hasattr(config, "block_size") and hasattr(config, "adaptive")
945
- self.block_size = config.block_size
946
- self.adaptive = config.adaptive
947
- self.mask_first_token = config.mask_first_token
948
- self.pool_with_global = config.pool_with_global
949
-
950
- def forward(
951
- self,
952
- hidden_states: torch.Tensor,
953
- attention_mask: Optional[torch.FloatTensor] = None,
954
- head_mask: Optional[torch.FloatTensor] = None,
955
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
956
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
957
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
958
- use_cache: Optional[bool] = None,
959
- output_attentions: Optional[bool] = False,
960
- output_hidden_states: Optional[bool] = False,
961
- return_dict: Optional[bool] = True,
962
- ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
963
-
964
- mask_value = torch.finfo(attention_mask.dtype).min
965
- n, _, __, t = attention_mask.size()
966
-
967
- if not (self.config.is_decoder and encoder_hidden_states is not None):
968
- b = self.block_size * 2
969
- pad = t % self.block_size
970
-
971
- # Check if t is multiple of block_size and pad
972
- if self.adaptive and t > b and pad > 0:
973
- pad_length = self.block_size - pad
974
- hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
975
- attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
976
-
977
- if self.mask_first_token:
978
- attention_mask[..., 0] = mask_value
979
-
980
- encoder_outputs = super().forward(
981
- hidden_states=hidden_states,
982
- attention_mask=attention_mask,
983
- head_mask=head_mask,
984
- encoder_hidden_states=encoder_hidden_states,
985
- encoder_attention_mask=encoder_attention_mask,
986
- past_key_values=past_key_values,
987
- use_cache=use_cache,
988
- output_attentions=output_attentions,
989
- output_hidden_states=output_hidden_states,
990
- return_dict=return_dict
991
- )
992
-
993
- sequence_output = encoder_outputs[0]
994
- if self.pool_with_global:
995
- sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
996
-
997
- # Adapt sequence to initial shape
998
- sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
999
-
1000
- if not return_dict:
1001
- return (sequence_output, ) + encoder_outputs[1:]
1002
-
1003
- encoder_outputs.last_hidden_state = sequence_output
1004
- return encoder_outputs
1005
-
1006
-
1007
- class LSGXLMRobertaPreTrainedModel(XLMRobertaPreTrainedModel):
1008
- """
1009
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1010
- models.
1011
- """
1012
-
1013
- config_class = LSGXLMRobertaConfig
1014
- base_model_prefix = "roberta"
1015
- supports_gradient_checkpointing = True
1016
- _no_split_modules = []
1017
-
1018
- def _set_gradient_checkpointing(self, module, value=False):
1019
- if isinstance(module, (XLMRobertaEncoder, LSGXLMRobertaEncoder)):
1020
- module.gradient_checkpointing = value
1021
-
1022
-
1023
- class LSGXLMRobertaModel(LSGXLMRobertaPreTrainedModel, XLMRobertaModel):
1024
- """
1025
- This class overrides :class:`~transformers.RobertaModel`. Please check the superclass for the appropriate
1026
- documentation alongside usage examples.
1027
- """
1028
-
1029
- def __init__(self, config, add_pooling_layer=True):
1030
-
1031
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1032
-
1033
- self.embeddings = LSGXLMRobertaEmbeddings(config)
1034
- self.encoder = LSGXLMRobertaEncoder(config)
1035
- self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
1036
-
1037
- if config.add_cross_attention:
1038
- logger.warning(
1039
- "Cross attention is computed using full attention since it is not LSG compatible."
1040
- )
1041
-
1042
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1043
- if self._use_flash_attention_2:
1044
- logger.warning(
1045
- "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
1046
- )
1047
-
1048
- # Initialize weights and apply final processing
1049
- self.post_init()
1050
-
1051
- def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1052
-
1053
- # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
1054
- if attention_mask.dim() == 3:
1055
- extended_attention_mask = attention_mask[:, None, :, :]
1056
- elif attention_mask.dim() == 2:
1057
- extended_attention_mask = attention_mask[:, None, None, :]
1058
- else:
1059
- raise ValueError(
1060
- f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
1061
- )
1062
-
1063
- extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1064
- extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
1065
-
1066
- return extended_attention_mask
1067
-
1068
-
1069
- class LSGXLMRobertaForCausalLM(LSGXLMRobertaPreTrainedModel, XLMRobertaForCausalLM):
1070
-
1071
- _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
1072
-
1073
- def __init__(self, config):
1074
-
1075
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1076
-
1077
- if not config.is_decoder:
1078
- logger.warning("If you want to use `LSGXLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
1079
-
1080
- self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1081
- self.lm_head = XLMRobertaLMHead(config)
1082
-
1083
- # Initialize weights and apply final processing
1084
- self.post_init()
1085
-
1086
-
1087
- class LSGXLMRobertaForMaskedLM(LSGXLMRobertaPreTrainedModel, XLMRobertaForMaskedLM):
1088
- """
1089
- This class overrides :class:`~transformers.RobertaForMaskedLM`. Please check the superclass for the appropriate
1090
- documentation alongside usage examples.
1091
- """
1092
- config_class = LSGXLMRobertaConfig
1093
- _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1094
- _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1095
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
1096
- _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
1097
-
1098
- def __init__(self, config):
1099
-
1100
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1101
-
1102
- if config.is_decoder:
1103
- logger.warning(
1104
- "If you want to use `LSGXLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
1105
- "bi-directional self-attention."
1106
- )
1107
-
1108
- self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1109
- self.lm_head = XLMRobertaLMHead(config)
1110
-
1111
- # Initialize weights and apply final processing
1112
- self.post_init()
1113
-
1114
-
1115
- class LSGXLMRobertaForSequenceClassification(LSGXLMRobertaPreTrainedModel, XLMRobertaForSequenceClassification):
1116
- """
1117
- This class overrides :class:`~transformers.RobertaForSequenceClassification`. Please check the superclass for the
1118
- appropriate documentation alongside usage examples.
1119
- """
1120
-
1121
- def __init__(self, config):
1122
-
1123
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1124
-
1125
- self.num_labels = config.num_labels
1126
- self.config = config
1127
-
1128
- self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1129
- self.classifier = XLMRobertaClassificationHead(config)
1130
-
1131
- # Initialize weights and apply final processing
1132
- self.post_init()
1133
-
1134
-
1135
- class LSGXLMRobertaForMultipleChoice(LSGXLMRobertaPreTrainedModel, XLMRobertaForMultipleChoice):
1136
- """
1137
- This class overrides :class:`~transformers.RobertaForMultipleChoice`. Please check the superclass for the
1138
- appropriate documentation alongside usage examples.
1139
- """
1140
- config_class = LSGXLMRobertaConfig
1141
- _keys_to_ignore_on_load_missing = [r"position_ids"]
1142
-
1143
- def __init__(self, config):
1144
-
1145
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1146
-
1147
- self.roberta = LSGXLMRobertaModel(config)
1148
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
1149
- self.classifier = nn.Linear(config.hidden_size, 1)
1150
-
1151
- # Initialize weights and apply final processing
1152
- self.post_init()
1153
-
1154
-
1155
- class LSGXLMRobertaForTokenClassification(LSGXLMRobertaPreTrainedModel, XLMRobertaForTokenClassification):
1156
- """
1157
- This class overrides :class:`~transformers.RobertaForTokenClassification`. Please check the superclass for the
1158
- appropriate documentation alongside usage examples.
1159
- """
1160
-
1161
- def __init__(self, config):
1162
-
1163
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1164
-
1165
- self.num_labels = config.num_labels
1166
-
1167
- self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1168
- classifier_dropout = (
1169
- config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1170
- )
1171
- self.dropout = nn.Dropout(classifier_dropout)
1172
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1173
-
1174
- # Initialize weights and apply final processing
1175
- self.post_init()
1176
-
1177
-
1178
- class LSGXLMRobertaForQuestionAnswering(LSGXLMRobertaPreTrainedModel, XLMRobertaForQuestionAnswering):
1179
- """
1180
- This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the superclass for the
1181
- appropriate documentation alongside usage examples.
1182
- """
1183
-
1184
- def __init__(self, config):
1185
-
1186
- LSGXLMRobertaPreTrainedModel.__init__(self, config)
1187
-
1188
- self.num_labels = config.num_labels
1189
-
1190
- self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1191
- self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1192
-
1193
- # Initialize weights and apply final processing
1194
- self.post_init()
1195
-
1196
-
1197
- def str_to_class(classname):
1198
- return getattr(sys.modules[__name__], classname)
1199
-
1200
- # Register model in Auto API
1201
- try:
1202
- LSGXLMRobertaConfig.register_for_auto_class()
1203
- for key, value in AUTO_MAP.items():
1204
- str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1205
- except:
1206
- warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1207
- warn("Update to transformers >= 4.36.1 to fix.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Doc-classification-model/sentencepiece.bpe.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
3
- size 5069051
 
 
 
 
Doc-classification-model/special_tokens_map.json DELETED
@@ -1,51 +0,0 @@
1
- {
2
- "bos_token": {
3
- "content": "<s>",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "cls_token": {
10
- "content": "<s>",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- },
16
- "eos_token": {
17
- "content": "</s>",
18
- "lstrip": false,
19
- "normalized": false,
20
- "rstrip": false,
21
- "single_word": false
22
- },
23
- "mask_token": {
24
- "content": "<mask>",
25
- "lstrip": true,
26
- "normalized": false,
27
- "rstrip": false,
28
- "single_word": false
29
- },
30
- "pad_token": {
31
- "content": "<pad>",
32
- "lstrip": false,
33
- "normalized": false,
34
- "rstrip": false,
35
- "single_word": false
36
- },
37
- "sep_token": {
38
- "content": "</s>",
39
- "lstrip": false,
40
- "normalized": false,
41
- "rstrip": false,
42
- "single_word": false
43
- },
44
- "unk_token": {
45
- "content": "<unk>",
46
- "lstrip": false,
47
- "normalized": false,
48
- "rstrip": false,
49
- "single_word": false
50
- }
51
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Doc-classification-model/tokenizer.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3a56def25aa40facc030ea8b0b87f3688e4b3c39eb8b45d5702b3a1300fe2a20
3
- size 17082734
 
 
 
 
Doc-classification-model/tokenizer_config.json DELETED
@@ -1,54 +0,0 @@
1
- {
2
- "added_tokens_decoder": {
3
- "0": {
4
- "content": "<s>",
5
- "lstrip": false,
6
- "normalized": false,
7
- "rstrip": false,
8
- "single_word": false,
9
- "special": true
10
- },
11
- "1": {
12
- "content": "<pad>",
13
- "lstrip": false,
14
- "normalized": false,
15
- "rstrip": false,
16
- "single_word": false,
17
- "special": true
18
- },
19
- "2": {
20
- "content": "</s>",
21
- "lstrip": false,
22
- "normalized": false,
23
- "rstrip": false,
24
- "single_word": false,
25
- "special": true
26
- },
27
- "3": {
28
- "content": "<unk>",
29
- "lstrip": false,
30
- "normalized": false,
31
- "rstrip": false,
32
- "single_word": false,
33
- "special": true
34
- },
35
- "250001": {
36
- "content": "<mask>",
37
- "lstrip": true,
38
- "normalized": false,
39
- "rstrip": false,
40
- "single_word": false,
41
- "special": true
42
- }
43
- },
44
- "bos_token": "<s>",
45
- "clean_up_tokenization_spaces": true,
46
- "cls_token": "<s>",
47
- "eos_token": "</s>",
48
- "mask_token": "<mask>",
49
- "model_max_length": 4096,
50
- "pad_token": "<pad>",
51
- "sep_token": "</s>",
52
- "tokenizer_class": "XLMRobertaTokenizer",
53
- "unk_token": "<unk>"
54
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Doc-classification-model/training_args.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:58a4a157c7952a6065c6e08217ab8419a638d8688669a5fad0c94e80bb92f5b2
3
- size 4920