oweller2 committed · Commit b44e834 · Parent(s): 3574e72

fix

Files changed: modeling_flexbert.py (+8 -1)
modeling_flexbert.py  CHANGED

@@ -1643,7 +1643,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
-
+            if input_ids.dim() == 2:
+                batch_size, seq_len = input_ids.shape[:2]
+            elif input_ids.dim() >= 3:
+                batch_size, seq_len = input_ids.shape[:2]
+            else:  # dim is 1
+                batch_size, seq_len = input_ids.shape[0], 1
+                input_ids = input_ids.unsqueeze(1)
+
             if attention_mask is None:  # Create causal mask (lower triangular)
                 attention_mask = torch.tril(torch.ones(batch_size, seq_len, device=input_ids.device), diagonal=0)
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
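A minimal sketch of what the added branch does, taken out of the model for illustration. The helper name normalize_input_ids is hypothetical (not part of the repository); only torch is assumed. It mirrors the commit's dimension handling: 2-D and higher-dimensional input_ids keep their leading (batch, seq) dims, while a 1-D tensor is treated as a batch of single-token sequences and unsqueezed to 2-D before the lower-triangular causal mask is built.

import torch

def normalize_input_ids(input_ids: torch.Tensor):
    # Hypothetical standalone sketch of the dim-handling added in this commit.
    if input_ids.dim() == 2:
        batch_size, seq_len = input_ids.shape[:2]
    elif input_ids.dim() >= 3:
        batch_size, seq_len = input_ids.shape[:2]
    else:  # dim is 1: treat as a batch of single tokens
        batch_size, seq_len = input_ids.shape[0], 1
        input_ids = input_ids.unsqueeze(1)

    # Same causal (lower triangular) mask the surrounding code builds
    # when attention_mask is None.
    attention_mask = torch.tril(
        torch.ones(batch_size, seq_len, device=input_ids.device), diagonal=0
    )
    return input_ids, attention_mask

# Example: a 1-D batch of token ids becomes shape (3, 1), with a (3, 1) mask.
ids, mask = normalize_input_ids(torch.tensor([5, 17, 42]))
print(ids.shape, mask.shape)

This is why the commit message is just "fix": without the added branch, a 1-D input_ids would reach the mask construction and unpad_inputs call without well-defined batch_size and seq_len values.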
