aframson committed · Commit 1a401cc · 1 Parent(s): 97d2dbc
Files changed (1)
  1. modelLM.py +9 -14
modelLM.py CHANGED
@@ -29,20 +29,16 @@ class OBILanguageModel(PreTrainedModel):
 
 
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, attention_mask=None, targets=None):
         tok_emb = self.token_embedding_table(idx)
-        pos_emb = None  # Initialize pos_emb to None
-        try:
-            pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
-        except IndexError as e:
-            # Handle the IndexError by initializing pos_emb with zeros
-            print(f"IndexError: {e}")
-            print(f"idx.size(1): {idx.size(1)}")
-            print(f"Positional embedding table shape: {self.position_embedding_table.weight.shape}")
-            pos_emb = torch.zeros((idx.size(1), self.config.hidden_size), device=device)
-
+        pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
         x = tok_emb + pos_emb
-        x = self.transformer(x, x)
+
+        # Create an attention mask for padding tokens
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(x.device)
+
+        x = self.transformer(x, attn_mask=attention_mask)  # Pass attention_mask to the transformer
         x = self.ln1(x)
         x = self.ln2(x)
         logits = self.lm_head(x)
@@ -52,8 +48,7 @@ class OBILanguageModel(PreTrainedModel):
         else:
             loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), targets.view(-1))
 
-        return (logits, loss)  # Return as a tuple
-
+        return logits, loss
 
 
     def generate(self, idx, max_new_tokens):
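
For context, a minimal usage sketch of the revised forward signature. This assumes the class name from this file (OBILanguageModel), a config object constructed elsewhere in the repo, and a pad id of 0; the import path and tensor values are made-up illustrations, not part of the commit.

    import torch
    from modelLM import OBILanguageModel   # assumed import path within this repo

    # Hypothetical padded batch; normally input_ids / attention_mask come from a tokenizer.
    input_ids = torch.tensor([[12, 87,  3,  0,  0],
                              [45,  9, 22, 67,  3]])   # 0 assumed to be the pad id
    attention_mask = (input_ids != 0).long()           # 1 for real tokens, 0 for padding

    model = OBILanguageModel(config)                   # config built elsewhere in the repo

    # Training-style call: returns (logits, loss) as a plain tuple after this commit.
    # Passing targets=input_ids is only illustrative.
    logits, loss = model(input_ids, attention_mask=attention_mask, targets=input_ids)

    # Inference-style call: no targets, so loss is expected to be None.
    logits, loss = model(input_ids, attention_mask=attention_mask)

Note that the new forward only moves attention_mask onto x's device before handing it to self.transformer(x, attn_mask=...), so whether padding positions are actually ignored depends on the attention module wired up as self.transformer.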