jupyterjazz committed
Commit 841b70f
2 Parent(s): 6060bad 0bb73e5

feat: merge stuff


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

config.json CHANGED
@@ -3,8 +3,12 @@
     "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
     "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
     "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
-    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM"
+    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
+    "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
   },
+  "architectures": [
+    "XLMRobertaModel"
+  ],
   "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 0,
   "eos_token_id": 2,
configuration_xlm_roberta.py CHANGED
@@ -1,4 +1,5 @@
  from transformers import PretrainedConfig
+ import torch

  class XLMRobertaFlashConfig(PretrainedConfig):
      def __init__(
@@ -23,6 +24,9 @@ class XLMRobertaFlashConfig(PretrainedConfig):
          classifier_dropout=None,
          num_loras=1,
          load_trained_adapters=False,
+         use_flash_attn=True,
+         torch_dtype=None,
+         emb_pooler=None,
          **kwargs,
      ):
          super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -45,3 +49,9 @@ class XLMRobertaFlashConfig(PretrainedConfig):
          self.classifier_dropout = classifier_dropout
          self.num_loras = num_loras
          self.load_trained_adapters = load_trained_adapters
+         self.use_flash_attn = use_flash_attn
+         self.emb_pooler = emb_pooler
+         if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+             self.torch_dtype = getattr(torch, torch_dtype)
+         else:
+             self.torch_dtype = torch_dtype
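
Note: a quick sketch of the dtype handling added above. When torch_dtype is passed as a string that names a real torch dtype, the config stores the resolved torch.dtype object; otherwise the value is kept as-is. Assumes the remaining constructor arguments keep their defaults and the file is importable locally:

import torch
from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    use_flash_attn=True,
    emb_pooler='mean',
    torch_dtype='bfloat16',  # resolved via getattr(torch, 'bfloat16')
)
assert config.torch_dtype is torch.bfloat16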
convert_roberta_weights_to_flash.py CHANGED
@@ -1,10 +1,11 @@
  import re
  from collections import OrderedDict
  from transformers import PretrainedConfig
- from transformers import XLMRobertaForMaskedLM
+ from transformers import XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification

  from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
- from .modeling_xlm_roberta import XLMRobertaForMaskedLM as BertModel
+ from .modeling_xlm_roberta import XLMRobertaForMaskedLM as FlashXLMRobertaForMaskedLM
+ from .modeling_xlm_roberta import XLMRobertaForSequenceClassification as FlashXLMRobertaForSequenceClassification
  import torch

  import click
@@ -137,14 +138,23 @@ def remap_state_dict(state_dict, config: PretrainedConfig):

  @click.command()
  @click.option('--model_name', default='FacebookAI/xlm-roberta-base', help='model name')
+ @click.option('--revision', default='main', help='revision')
+ @click.option('--task', default='masked_lm', help='task')
  @click.option('--output', default='converted_roberta_weights.bin', help='model name')
- def main(model_name, output):
-     roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name)
+ def main(model_name, revision, task, output):
+
+     if task == 'masked_lm':
+         roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name, revision=revision)
+     elif task == 'sequence_classification':
+         roberta_model = XLMRobertaForSequenceClassification.from_pretrained(model_name, revision=revision, num_labels=1)
      config = BertConfig.from_dict(roberta_model.config.to_dict())
      state_dict = roberta_model.state_dict()
      new_state_dict = remap_state_dict(state_dict, config)
-
-     flash_model = BertModel(config)
+
+     if task == 'masked_lm':
+         flash_model = FlashXLMRobertaForMaskedLM(config)
+     elif task == 'sequence_classification':
+         flash_model = FlashXLMRobertaForSequenceClassification(config)

      for k, v in flash_model.state_dict().items():
          if k not in new_state_dict:
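
Note: a hedged sketch of driving the updated converter for the new sequence-classification path. Because the script uses relative imports, it has to be imported (or run with -m) as part of a package; the package name xlm_roberta_flash below is hypothetical:

from click.testing import CliRunner
from xlm_roberta_flash.convert_roberta_weights_to_flash import main  # hypothetical package path

result = CliRunner().invoke(
    main,
    [
        '--model_name', 'FacebookAI/xlm-roberta-base',
        '--revision', 'main',
        '--task', 'sequence_classification',
        '--output', 'converted_roberta_weights.bin',
    ],
)
print(result.output)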
mha.py CHANGED
@@ -10,8 +10,6 @@ import torch
  import torch.nn as nn
  from einops import rearrange, repeat

- from flash_attn.utils.distributed import get_dim_for_local_rank
-
  try:
      from flash_attn import (
          flash_attn_kvpacked_func,
modeling_xlm_roberta.py CHANGED
@@ -1,6 +1,5 @@
  # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
  # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
-
  # Copyright (c) 2022, Tri Dao.
  # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
@@ -8,20 +7,23 @@

  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

+ import importlib.util
  import logging
  import re
  from collections import OrderedDict
  from collections.abc import Sequence
  from functools import partial
+ import numpy as np

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torch.utils.checkpoint
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  from einops import rearrange
  from transformers import PretrainedConfig
  from transformers.modeling_utils import PreTrainedModel
- from transformers.modeling_outputs import MaskedLMOutput
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
  from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead

  from transformers.models.bert.modeling_bert import (
@@ -29,7 +31,7 @@ from transformers.models.bert.modeling_bert import (
      BertForPreTrainingOutput,
  )

- from typing import Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union

  from .xlm_padding import (
      index_first_axis,
@@ -61,12 +63,30 @@ try:
  except ImportError:
      CrossEntropyLoss = None

+ try:
+     from tqdm.autonotebook import trange
+ except ImportError:
+     trange = None
+

  logger = logging.getLogger(__name__)


+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
+     if not getattr(config, "use_flash_attn", False):
+         return False
+     if not torch.cuda.is_available():
+         return False
+     if importlib.util.find_spec("flash_attn") is None:
+         logger.warning(
+             'flash_attn is not installed. Using PyTorch native attention implementation.'
+         )
+         return False
+     return True
+
+
  def create_mixer_cls(config, cross_attn=False, return_residual=False):
-     use_flash_attn = getattr(config, "use_flash_attn", False)
+     use_flash_attn = get_use_flash_attn(config)
      fused_bias_fc = getattr(config, "fused_bias_fc", False)
      rotary_kwargs = {}
      if config.position_embedding_type == "rotary":
@@ -169,7 +189,7 @@ def _init_weights(module, initializer_range=0.02):
  class XLMRobertaEncoder(nn.Module):
      def __init__(self, config: XLMRobertaFlashConfig):
          super().__init__()
-         self.use_flash_attn = getattr(config, "use_flash_attn", False)
+         self.use_flash_attn = get_use_flash_attn(config)
          self.layers = nn.ModuleList(
              [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
          )
@@ -376,6 +396,17 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
          if isinstance(module, XLMRobertaEncoder):
              module.gradient_checkpointing = value

+     @classmethod
+     def from_pretrained(
+         cls,
+         *args,
+         **kwargs,
+     ):
+         if 'torch_dtype' not in kwargs:
+             kwargs['torch_dtype'] = 'auto'
+         return super().from_pretrained(*args, **kwargs)
+
+

  class XLMRobertaModel(XLMRobertaPreTrainedModel):
      def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
@@ -409,6 +440,169 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):

          self.apply(partial(_init_weights, initializer_range=config.initializer_range))

+     @torch.inference_mode()
+     def encode(
+         self: 'XLMRobertaModel',
+         sentences: Union[str, List[str]],
+         batch_size: int = 32,
+         show_progress_bar: Optional[bool] = None,
+         output_value: str = 'sentence_embedding',
+         convert_to_numpy: bool = True,
+         convert_to_tensor: bool = False,
+         device: Optional[torch.device] = None,
+         normalize_embeddings: bool = False,
+         **tokenizer_kwargs,
+     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+         """
+         Computes sentence embeddings.
+
+         Args:
+             sentences (`str` or `List[str]`):
+                 Sentence or sentences to be encoded.
+             batch_size (`int`, *optional*, defaults to 32):
+                 Batch size for the computation.
+             show_progress_bar (`bool`, *optional*, defaults to None):
+                 Show a progress bar when encoding sentences. If set to None, the progress
+                 bar is only shown when `logger.level == logging.INFO` or
+                 `logger.level == logging.DEBUG`.
+             output_value (`str`, *optional*, defaults to 'sentence_embedding'):
+                 Default is 'sentence_embedding' to get sentence embeddings. Can be set to
+                 'token_embeddings' to get wordpiece token embeddings. Set to None to get
+                 all output values.
+             convert_to_numpy (`bool`, *optional*, defaults to True):
+                 If true, the output is a list of numpy vectors. Else, it is a list of
+                 pytorch tensors.
+             convert_to_tensor (`bool`, *optional*, defaults to False):
+                 If true, you get one large tensor as return. Overwrites any setting from
+                 convert_to_numpy.
+             device (`torch.device`, *optional*, defaults to None):
+                 Which torch.device to use for the computation.
+             normalize_embeddings (`bool`, *optional*, defaults to False):
+                 If set to true, returned vectors will have length 1. In that case, the
+                 faster dot-product (util.dot_score) instead of cosine similarity can be used.
+             tokenizer_kwargs (`Dict[str, Any]`, *optional*, defaults to {}):
+                 Keyword arguments for the tokenizer.
+         Returns:
+             By default, a list of tensors is returned. If convert_to_tensor, a stacked
+             tensor is returned. If convert_to_numpy, a numpy matrix is returned.
+         """
+         from transformers import AutoTokenizer
+
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.name_or_path, trust_remote_code=True
+         )
+
+         is_training = self.training
+         self.eval()
+
+         if show_progress_bar is None:
+             show_progress_bar = (
+                 logger.getEffectiveLevel() == logging.INFO
+                 or logger.getEffectiveLevel() == logging.DEBUG
+             )
+
+         if convert_to_tensor:
+             convert_to_numpy = False
+
+         if output_value != 'sentence_embedding':
+             convert_to_tensor = False
+             convert_to_numpy = False
+
+         input_was_string = False
+         if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
+             sentences = [sentences]
+             input_was_string = True
+
+         if device is not None:
+             self.to(device)
+
+         permutation = np.argsort([-len(i) for i in sentences])
+         inverse_permutation = np.argsort(permutation)
+         sentences = [sentences[idx] for idx in permutation]
+
+         tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
+         tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
+             'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
+         )
+         tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
+
+         all_embeddings = []
+
+         if trange is not None:
+             range_iter = trange(
+                 0,
+                 len(sentences),
+                 batch_size,
+                 desc="Encoding",
+                 disable=not show_progress_bar,
+             )
+         else:
+             range_iter = range(0, len(sentences), batch_size)
+
+         for i in range_iter:
+             encoded_input = self.tokenizer(
+                 sentences[i : i + batch_size],
+                 return_tensors='pt',
+                 **tokenizer_kwargs,
+             ).to(self.device)
+             token_embs = self.forward(**encoded_input)[0]
+
+             # Accumulate in fp32 to avoid overflow
+             token_embs = token_embs.float()
+
+             if output_value == 'token_embeddings':
+                 raise NotImplementedError
+             elif output_value is None:
+                 raise NotImplementedError
+             else:
+                 if self.config.emb_pooler == 'cls':
+                     embeddings = self.cls_pooling(
+                         token_embs, encoded_input['attention_mask']
+                     )
+                 else:
+                     embeddings = self.mean_pooling(
+                         token_embs, encoded_input['attention_mask']
+                     )
+
+                 if normalize_embeddings:
+                     embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+                 if convert_to_numpy:
+                     embeddings = embeddings.cpu()
+             all_embeddings.extend(embeddings)
+
+         all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+
+         if convert_to_tensor:
+             all_embeddings = torch.stack(all_embeddings)
+         elif convert_to_numpy:
+             all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+         if input_was_string:
+             all_embeddings = all_embeddings[0]
+
+         self.train(is_training)
+         return all_embeddings
+
+     def mean_pooling(
+         self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+     ):
+         input_mask_expanded = (
+             attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         )
+         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+             input_mask_expanded.sum(1), min=1e-9
+         )
+
+     def cls_pooling(
+         self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+     ):
+         return token_embeddings[:, 0]
+
      def forward(
          self,
          input_ids,
@@ -946,3 +1140,117 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
      )

      return state_dict
+
+
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
+ class XLMRobertaClassificationHead(nn.Module):
+     """Head for sentence-level classification tasks."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         classifier_dropout = (
+             config.classifier_dropout
+             if config.classifier_dropout is not None
+             else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+     def forward(self, features, **kwargs):
+         x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+         x = self.dropout(x)
+         x = self.dense(x)
+         x = torch.tanh(x)
+         x = self.dropout(x)
+         x = self.out_proj(x)
+         return x
+
+
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+         self.classifier = XLMRobertaClassificationHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = outputs[0]
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             # move labels to correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (
+                     labels.dtype == torch.long or labels.dtype == torch.int
+                 ):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
modeling_xlm_roberta_for_glue.py ADDED
@@ -0,0 +1,109 @@
+ from typing import Optional, Union, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+ from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
+
+ from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
+
+
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
+     def __init__(self, config: XLMRobertaFlashConfig):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.roberta = XLMRobertaModel(config)
+         classifier_dropout = (
+             config.classifier_dropout
+             if config.classifier_dropout is not None
+             else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         assert head_mask is None
+         assert inputs_embeds is None
+         assert output_attentions is None
+         assert output_hidden_states is None
+         assert return_dict
+         outputs = self.roberta(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (
+                     labels.dtype == torch.long or labels.dtype == torch.int
+                 ):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )