oweller2 committed
Commit 59c0d24 · 1 parent: e4b57a3

reshape logits

Files changed (2):
  1. config.json +1 -1
  2. modeling_flexbert.py +2 -5
config.json CHANGED
@@ -69,7 +69,7 @@
   "num_attention_heads": 12,
   "num_hidden_layers": 22,
   "num_initial_layers": 1,
-  "pad_logits": true,
+  "pad_logits": false,
   "pad_token_id": 0,
   "padding": "unpadded",
   "pooling_type": "cls",
modeling_flexbert.py CHANGED
@@ -1702,12 +1702,9 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
                 shift_labels.view(-1)
             )
 
-        # if self.unpad_embeddings:
-        #     # reshape to batch size
-        #     logits = logits.view(-1, self.vocab_size)
-        #     # NOTE: error from here above
+        if self.unpad_embeddings:
+            logits = logits.view(batch_size, -1, self.vocab_size)
 
-        breakpoint()
         if self.pad_logits:
             # print(f"Padding logits: {logits.shape}")
             new_logits = self.pad_inputs(logits, indices, batch_size, seq_len-1)[0]
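
As a rough illustration of the new branch — toy shapes and names, not the repository's code — when the flat unpadded logits happen to contain batch_size * n rows, the added view restores a batch dimension without re-padding:

import torch

# Toy shapes for illustration only.
batch_size, seq_len, vocab_size = 2, 4, 10

# Assume (for this sketch) that every sequence contributes seq_len tokens,
# so the flat unpadded logits have exactly batch_size * seq_len rows.
flat_logits = torch.randn(batch_size * seq_len, vocab_size)  # (8, 10)

# The reshape added in this commit: view the flat rows as
# (batch_size, tokens_per_sequence, vocab_size).
logits = flat_logits.view(batch_size, -1, vocab_size)
print(logits.shape)  # torch.Size([2, 4, 10])
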