aframson committed
Commit 42fb9e5 · 1 Parent(s): 0628687
Files changed (1):
  1. modelLM.py +37 -13
modelLM.py CHANGED
@@ -1,13 +1,18 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel
+from transformers.modeling_utils import PreTrainedModel
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+# Define your custom language model class
 class OBILanguageModel(PreTrainedModel):
     def __init__(self, config):
-        super(OBILanguageModel, self).__init__(config)
-        self.token_embedding_table = nn.Embedding(config.vocab_size, config.hidden_size)
+        super(OBILanguageModel,self).__init__(config)
+        self.token_embedding_table = nn.Embedding(config.vocab_size, config.hidden_size) # Use length of SentencePiece vocab
         self.position_embedding_table = nn.Embedding(config.block_size, config.hidden_size)
+
 
         self.transformer = nn.Transformer(
             d_model=config.hidden_size,
@@ -20,21 +25,40 @@ class OBILanguageModel(PreTrainedModel):
         )
         self.ln1 = nn.LayerNorm(config.hidden_size)
         self.ln2 = nn.LayerNorm(config.hidden_size)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) # Use length of SentencePiece vocab
 
-    def forward(self, input_ids, attention_mask=None, labels=None):
-        tok_emb = self.token_embedding_table(input_ids)
-        pos_emb = self.position_embedding_table(torch.arange(input_ids.size(1), device='cpu'))
+
+
+    def forward(self, idx, targets=None):
+        tok_emb = self.token_embedding_table(idx)
+        pos_emb = None # Initialize pos_emb to None
+        try:
+            pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
+        except IndexError as e:
+            # Handle the IndexError by initializing pos_emb with zeros
+            print(f"IndexError: {e}")
+            print(f"idx.size(1): {idx.size(1)}")
+            print(f"Positional embedding table shape: {self.position_embedding_table.weight.shape}")
+            pos_emb = torch.zeros((idx.size(1), self.config.hidden_size), device=device)
 
         x = tok_emb + pos_emb
-        x = self.transformer(x)
+        x = self.transformer(x, x)
         x = self.ln1(x)
         x = self.ln2(x)
         logits = self.lm_head(x)
 
-        # Always compute the loss if labels are provided
-        loss = None
-        if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1))
+        # Always compute the loss, and set it to None if targets are not provided
+        loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), targets.view(-1)) if targets is not None else None
+
+        return (logits, loss)
+
 
-        return {"logits": logits, "loss": loss}
+    def generate(self, idx, max_new_tokens):
+        for _ in range(max_new_tokens):
+            idx_cond = idx[:, -self.config.block_size:]
+            logits, loss = self(idx_cond)
+            logits = logits[:, -1, :]
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat((idx, idx_next), dim=1)
+        return idx
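
Usage note (not part of the commit): a minimal sketch of how the new interface might be exercised after this change, where forward takes (idx, targets) and returns a (logits, loss) tuple and generate extends a batch of token indices autoregressively. OBIConfig below is a hypothetical stand-in for the model's real config; it assumes the config carries vocab_size, block_size and hidden_size, plus whatever fields the elided nn.Transformer arguments read.

# Hypothetical usage sketch; OBIConfig and its default values are illustrative only.
import torch
from transformers import PretrainedConfig

class OBIConfig(PretrainedConfig):
    # Assumed config: only the fields visible in the diff are listed here;
    # extra kwargs are stored as attributes by PretrainedConfig.
    def __init__(self, vocab_size=8000, block_size=128, hidden_size=256, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.hidden_size = hidden_size

config = OBIConfig()
model = OBILanguageModel(config)  # kept on CPU: forward builds pos_emb with device='cpu'

idx = torch.randint(0, config.vocab_size, (2, 16))      # dummy token ids (batch=2, seq=16)
targets = torch.randint(0, config.vocab_size, (2, 16))  # dummy next-token targets

logits, loss = model(idx, targets)   # forward now returns a (logits, loss) tuple
print(logits.shape, loss.item())

generated = model.generate(idx[:, :4], max_new_tokens=8)  # autoregressive sampling
print(generated.shape)               # expected (2, 12): 4 prompt tokens + 8 sampled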