suayptalha committed on
Commit
b27e0c7
·
verified ·
1 Parent(s): 2a6d57e

Create modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +78 -0
modeling_minGRULM.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import CausalLMOutputWithPast
5
+ from torch.nn import CrossEntropyLoss
6
+ from typing import Optional
7
+ from .configuration_minGRULM import MinGRULMConfig
8
+ from minGRU_pytorch.minGRULM import minGRULM
9
+
10
+
11
class MinGRULMPreTrainedModel(PreTrainedModel):
    """Base class binding the minGRULM models to ``MinGRULMConfig``.

    Provides the weight-initialisation hook used for ``nn.Linear`` and
    ``nn.Embedding`` submodules.
    """

    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from N(0, 0.02). Linear biases
        (when present) and the embedding row at ``padding_idx`` (when set) are
        zeroed. Any other module type is left untouched.
        """
        init_std = 0.02

        # Only these two layer types are handled here.
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
            return

        # Embedding: keep the padding vector at zero.
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
25
+
26
+
27
class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
    """Causal language-model wrapper around the ``minGRULM`` backbone.

    The backbone is built from ``config`` and called directly on token ids;
    its output is treated as next-token logits. When ``labels`` are supplied,
    a shifted cross-entropy loss is computed.
    """

    def __init__(self, config: MinGRULMConfig):
        super().__init__(config)

        # Backbone from the minGRU_pytorch package, sized from the config.
        self.model = minGRULM(
            num_tokens=config.vocab_size,
            dim=config.d_model,
            depth=config.n_layer,
            ff_mult=config.ff_mult,
            min_gru_expansion=config.expand,
            enable_conv=config.enable_conv,
        )

        # NOTE(review): forward() never calls this head -- logits come straight
        # from self.model. It is kept so the module's parameter set (and
        # get_output_embeddings) stay unchanged; confirm whether it can be
        # dropped or should replace the backbone's own projection.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.token_emb

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.token_emb = value

    def get_output_embeddings(self):
        """Return the ``lm_head`` linear layer."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the ``lm_head`` linear layer with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the backbone and optionally compute the language-modelling loss.

        Args:
            input_ids: Token ids fed to the backbone.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted by one position internally, so callers pass the same
                sequence as the input. Positions set to ``-100`` are skipped
                by ``CrossEntropyLoss``'s default ``ignore_index``.
            return_dict: When true, return a ``CausalLMOutputWithPast``;
                otherwise a tuple. ``None`` falls back to the config default.
            attention_mask: Accepted so that tokenizer batches can be passed
                with ``model(**batch)``; it is not used by this model.
            **kwargs: Extra keyword arguments are accepted and ignored for the
                same reason.

        Returns:
            ``CausalLMOutputWithPast(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is false.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        # The backbone output is used as logits as-is; the loss below assumes
        # the last dimension indexes the vocabulary.
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )