aframson committed on
Commit 8d90985 · 1 Parent(s): 1a401cc
Files changed (1)
  1. modelLM.py +12 -8
modelLM.py CHANGED
@@ -29,16 +29,20 @@ class OBILanguageModel(PreTrainedModel):
 
 
 
-    def forward(self, idx, attention_mask=None, targets=None):
+    def forward(self, idx, targets=None):
         tok_emb = self.token_embedding_table(idx)
-        pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
-        x = tok_emb + pos_emb
-
-        # Create an attention mask for padding tokens
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(x.device)
+        pos_emb = None  # Initialize pos_emb to None
+        try:
+            pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
+        except IndexError as e:
+            # Handle the IndexError by initializing pos_emb with zeros
+            print(f"IndexError: {e}")
+            print(f"idx.size(1): {idx.size(1)}")
+            print(f"Positional embedding table shape: {self.position_embedding_table.weight.shape}")
+            pos_emb = torch.zeros((idx.size(1), self.config.hidden_size), device=device)
 
-        x = self.transformer(x, attn_mask=attention_mask)  # Pass attention_mask to the transformer
+        x = tok_emb + pos_emb
+        x = self.transformer(x, x)
         x = self.ln1(x)
         x = self.ln2(x)
         logits = self.lm_head(x)
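
For reference, below is a minimal, self-contained sketch of the patched forward path, assuming that `self.transformer` is a `torch.nn.Transformer`-style module taking `(src, tgt)` (which the new `self.transformer(x, x)` call suggests) and that the `device` name used in the zero-fallback exists at module scope in the real file. The sketch derives the device from `idx` instead of hardcoding `'cpu'` or relying on that global, so the fallback stays consistent on CPU and GPU; the `targets`/loss handling outside this hunk is omitted, and the `ForwardSketch` class name and all sizes are hypothetical, not taken from the repository.

import torch
import torch.nn as nn


class ForwardSketch(nn.Module):
    """Illustrative stand-in for the patched forward path (targets/loss handling omitted)."""

    def __init__(self, vocab_size: int, block_size: int, hidden_size: int):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding_table = nn.Embedding(block_size, hidden_size)
        # Assumption: the real self.transformer is a Transformer called as transformer(src, tgt).
        self.transformer = nn.Transformer(
            d_model=hidden_size, nhead=4,
            num_encoder_layers=1, num_decoder_layers=1,
            batch_first=True,
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        tok_emb = self.token_embedding_table(idx)                  # (B, T, H)
        positions = torch.arange(idx.size(1), device=idx.device)   # follow idx instead of hardcoding 'cpu'
        try:
            pos_emb = self.position_embedding_table(positions)     # (T, H)
        except IndexError:
            # Position index exceeded the embedding table (e.g. T > block_size):
            # fall back to zeros, as in the commit.
            pos_emb = torch.zeros(idx.size(1), self.hidden_size, device=idx.device)
        x = tok_emb + pos_emb          # pos_emb broadcasts over the batch dimension
        x = self.transformer(x, x)     # src and tgt are both x, mirroring the new transformer(x, x) call
        x = self.ln1(x)
        x = self.ln2(x)
        return self.lm_head(x)         # (B, T, vocab_size)


# Quick shape check with hypothetical sizes:
model = ForwardSketch(vocab_size=100, block_size=32, hidden_size=64)
logits = model(torch.randint(0, 100, (2, 16)))
print(logits.shape)  # torch.Size([2, 16, 100])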