DanielHesslow committed
Commit ae25659
1 Parent(s): e2f0366
Files changed (4)
  1. config.json +5 -3
  2. pytorch_model.bin +2 -2
  3. rita_configuration.py +3 -1
  4. rita_modeling.py +217 -15
config.json CHANGED
@@ -1,17 +1,19 @@
 {
-  "_name_or_path": "Seledorn/RITA_m",
+  "_name_or_path": "Seledorn/RITA_m_2",
   "architectures": [
-    "RITAModel"
+    "RITAModelForCausalLM"
   ],
   "auto_map": {
     "AutoConfig": "rita_configuration.RITAConfig",
     "AutoModel": "rita_modeling.RITAModel",
-    "AutoModelForCausalLM": "rita_modeling.RITAModel"
+    "AutoModelForCausalLM": "rita_modeling.RITAModelForCausalLM",
+    "AutoModelForSequenceClassification": "rita_modeling.RITAModelForSequenceClassification"
   },
   "d_feedforward": 4096,
   "d_model": 1024,
   "dropout": 0.0,
   "eos_token_id": 2,
+  "initializer_range": 0.02,
   "max_seq_len": 1024,
   "model_type": "rita",
   "num_heads": 16,
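With the expanded auto_map, the Hugging Face Auto classes resolve to the custom classes shipped in this repository when loading with trust_remote_code=True. A minimal loading sketch, assuming the repository id recorded in "_name_or_path" above ("Seledorn/RITA_m_2") is the one being loaded from:

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification

repo_id = "Seledorn/RITA_m_2"  # assumption: taken from "_name_or_path"; substitute the actual repo id

# auto_map routes each Auto class to the code files in the repo,
# so trust_remote_code=True is required.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)                        # -> RITAConfig
backbone = AutoModel.from_pretrained(repo_id, trust_remote_code=True)                       # -> RITAModel
lm = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)                  # -> RITAModelForCausalLM
clf = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)   # -> RITAModelForSequenceClassification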
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d65d9e6f6e9d88059d230f690d3d56daa3c1d88da3282f9e5ac1cbf0d6d6f18c
-size 604802635
+oid sha256:f27acd85e3cdaf1f803995d7ae653e9b00dac37cca4dd6048c5839c92df93548
+size 604861001
rita_configuration.py CHANGED
@@ -16,6 +16,7 @@ class RITAConfig(PretrainedConfig):
         dropout=0.,
         ff_ratio=4,
         eos_token_id=2,
+        initializer_range=0.02,
         **kwargs,
     ):
         super().__init__(eos_token_id=eos_token_id, **kwargs)
@@ -26,4 +27,5 @@ class RITAConfig(PretrainedConfig):
         self.num_layers = num_layers
         self.max_seq_len=max_seq_len
         self.dropout = dropout
-        self.eos_token_id=eos_token_id
+        self.eos_token_id=eos_token_id
+        self.initializer_range=0.02
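The configuration now exposes initializer_range, which the new _init_weights hooks in rita_modeling.py read. A minimal sketch, assuming rita_configuration.py is importable locally (e.g. downloaded from this repo) and that the remaining constructor arguments keep their defaults:

from rita_configuration import RITAConfig

config = RITAConfig()
print(config.initializer_range)  # 0.02
# Note: as committed, __init__ assigns the literal 0.02, so passing a different
# initializer_range argument does not change the stored attribute.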
rita_modeling.py CHANGED
@@ -6,14 +6,12 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
 
 from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    CausalLMOutputWithPast,
+    BaseModelOutput,
     CausalLMOutput,
+    SequenceClassifierOutput
 )
 
 from transformers.modeling_utils import PreTrainedModel
@@ -210,9 +208,12 @@ class DecoderLayer(nn.Module):
         y = self.mlp(y)
         x = x + self.mlp_dropout(y)
         return x
-
+
 class RITAModel(PreTrainedModel):
     config_class = RITAConfig
+    base_model_prefix = "transformer"
+    is_parallelizable = False
+
     def __init__(
         self,
         config
@@ -221,7 +222,6 @@ class RITAModel(PreTrainedModel):
         self.embedding = nn.Embedding(config.vocab_size, config.d_model)
         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
         self.final_norm = nn.LayerNorm(config.d_model)
-        self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
 
     def forward(
         self,
@@ -251,7 +251,78 @@ class RITAModel(PreTrainedModel):
             x = layer(x, attn_mask=attention_mask)
         x = self.final_norm(x) # N x L x D
 
-        logits = self.projector(x)
+        return BaseModelOutput(
+            hidden_states=x,
+        )
+
+    #Some common HF functions.
+    def get_input_embeddings(self):
+        return self.embedding
+
+    def set_input_embeddings(self, new_embeddings):
+        self.embedding = new_embeddings
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class RITAModelForCausalLM(PreTrainedModel):
+    config_class = RITAConfig
+    base_model_prefix = "transformer"
+    is_parallelizable = False
+
+    def __init__(
+        self,
+        config
+    ):
+        super().__init__(config)
+        self.transformer = RITAModel(config)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None, # NOT USED
+        attention_mask=None,
+        token_type_ids=None, # NOT USED
+        position_ids=None, # NOT USED
+        head_mask=None, # NOT USED
+        inputs_embeds=None,
+        encoder_hidden_states=None, # NOT USED
+        encoder_attention_mask=None, # NOT USED
+        labels=None,
+        use_cache=None, # NOT USED
+        output_attentions=None, # NOT USED
+        output_hidden_states=None, # NOT USED
+        return_dict=None # NOT USED
+    ) -> torch.FloatTensor:
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(transformer_outputs.hidden_states)
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
@@ -264,19 +335,150 @@ class RITAModel(PreTrainedModel):
         return CausalLMOutput(
             loss=loss,
             logits=logits,
-            hidden_states=x,
+            hidden_states=transformer_outputs.hidden_states,
         )
 
-
     #Some common HF functions.
     def get_input_embeddings(self):
-        return self.embedding
+        return self.transformer.embedding
 
     def set_input_embeddings(self, new_embeddings):
-        self.embedding = new_embeddings
+        self.transformer.embedding = new_embeddings
 
     def get_output_embeddings(self):
-        return self.projector
+        return self.lm_head
+
+    def set_output_embeddings(self, lm_head):
+        self.lm_head = lm_head
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class RITAModelForSequenceClassification(PreTrainedModel):
+    config_class = RITAConfig
+    base_model_prefix = "transformer"
+    is_parallelizable = False
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = RITAModel(config)
+        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    def set_output_embeddings(self, new_projector):
-        self.projector = new_projector
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=pooled_logits,
+        )
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
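The refactor splits the previous single RITAModel (which owned the output projection) into a bare RITAModel backbone that returns a BaseModelOutput, plus two task heads: RITAModelForCausalLM (lm_head over the vocabulary) and RITAModelForSequenceClassification (score head pooled at the last non-padding token). A hedged usage sketch of the two heads through the Auto API; the repo id and the dummy inputs are assumptions, not part of the commit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification

repo_id = "Seledorn/RITA_m_2"  # assumption: substitute the actual hub repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# Dummy token ids, to avoid assuming anything about the tokenizer files in the repo.
input_ids = torch.randint(0, config.vocab_size, (1, 16))

# Causal LM head: RITAModel backbone under `.transformer`, vocabulary logits from `.lm_head`.
lm = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
lm_out = lm(input_ids=input_ids, labels=input_ids)
print(lm_out.logits.shape, lm_out.loss)    # (1, 16, vocab_size), shifted next-token loss

# Sequence-classification head: pools the logit at the last non-padding position,
# so config.pad_token_id must be set for batch sizes greater than 1.
clf = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True, num_labels=2)
clf_out = clf(input_ids=input_ids, labels=torch.tensor([1]))
print(clf_out.logits.shape, clf_out.loss)  # (1, 2), cross-entropy loss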