oweller2 committed
Commit b44e834 · 1 Parent(s): 3574e72
Files changed (1): modeling_flexbert.py +8 -1
modeling_flexbert.py CHANGED
@@ -1643,7 +1643,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
-            batch_size, seq_len = input_ids.shape[:2]
+            if input_ids.dim() == 2:
+                batch_size, seq_len = input_ids.shape[:2]
+            elif input_ids.dim() >= 3:
+                batch_size, seq_len = input_ids.shape[:2]
+            else:  # dim is 1
+                batch_size, seq_len = input_ids.shape[0], 1
+                input_ids = input_ids.unsqueeze(1)
+
             if attention_mask is None:  # Create causal mask (lower triangular)
                 attention_mask = torch.tril(torch.ones(batch_size, seq_len, device=input_ids.device), diagonal=0)
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
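For context: the added branch normalizes input_ids to a 2-D (batch_size, seq_len) shape before the causal mask is built, so a 1-D tensor of N token ids is treated as N sequences of length 1. A minimal standalone sketch of that logic, assuming only torch; build_causal_mask is a hypothetical helper name for illustration, not part of the repo:

    import torch

    def build_causal_mask(input_ids: torch.Tensor) -> torch.Tensor:
        """Sketch of the shape handling this commit adds before unpadding."""
        if input_ids.dim() == 2:
            batch_size, seq_len = input_ids.shape[:2]  # standard (batch, seq) input
        elif input_ids.dim() >= 3:
            batch_size, seq_len = input_ids.shape[:2]  # first two dims taken as batch/seq
        else:  # dim is 1: N ids become N length-1 sequences
            batch_size, seq_len = input_ids.shape[0], 1
            input_ids = input_ids.unsqueeze(1)
        # lower-triangular causal mask, as in the diff
        return torch.tril(torch.ones(batch_size, seq_len, device=input_ids.device), diagonal=0)

    # example: 1-D input of 4 token ids -> mask of shape (4, 1)
    print(build_causal_mask(torch.tensor([1, 2, 3, 4])).shape)  # torch.Size([4, 1])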