Error with long text

#4
by hoan - opened

Hi,

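I tried this model and it fails on very long input texts; the problem is probably in the NeoBERT modeling code. The assertion in `reshape_for_broadcast` (`rotary.py`, line 31) fails, apparently because the rotary-embedding table `freqs_cis` does not match the sequence length of longer inputs.
A quick workaround is to cap the input length with `model.max_seq_length = 4096`.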

File ~/.cache/huggingface/modules/transformers_modules/chandar-lab/NeoBERT/a4fbc49a61db10ff2db66140ae59c09d96c027f9/model.py:271, in NeoBERT.forward(self, input_ids, position_ids, max_seqlen, cu_seqlens, attention_mask, output_hidden_states, output_attentions, **kwargs)
    269 # Transformer encoder                                                                                                                                             
    270 for layer in self.transformer_encoder:                                                                                                                            
--> 271     x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)                                                                      
    272     if output_hidden_states:                                                                                                                                      
    273         hidden_states.append(x)                                                                                                                                   
                                                                                                                                                                          
File ~/miniconda3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)                     
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                                                        
   1738 else:                                                                                                                                                             
-> 1739     return self._call_impl(*args, **kwargs)                                                                                                                       
                                                                                                                                                                          
File ~/miniconda3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)                             
   1745 # If we don't have any hooks, we want to skip the rest of the logic in                                                                                            
   1746 # this function, and just call forward.                                                                                                                           
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                                                        
   1748         or _global_backward_pre_hooks or _global_backward_hooks                                                                                                   
   1749         or _global_forward_hooks or _global_forward_pre_hooks):                                                                                                   
-> 1750     return forward_call(*args, **kwargs)                                                                                                                          
   1752 result = None                                                                                                                                                     
   1753 called_always_called_hooks = set()                                                                                                                                
                                                                                                                                                                          
File ~/.cache/huggingface/modules/transformers_modules/chandar-lab/NeoBERT/a4fbc49a61db10ff2db66140ae59c09d96c027f9/model.py:136, in EncoderBlock.forward(self, x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
    126 def forward(                                                                                                                                                      
    127     self,                                                                                                                                                         
    128     x: torch.Tensor,                                                                                                                                              
   (...)                                                                             
    134 ):                                                                           
    135     # Attention                                                              
--> 136     attn_output, attn_weights = self._att_block(                             
    137         self.attention_norm(x), attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens                                                              
    138     )                                                                        
    140     # Residual                                                               
    141     x = x + attn_output                                                      

File ~/.cache/huggingface/modules/transformers_modules/chandar-lab/NeoBERT/a4fbc49a61db10ff2db66140ae59c09d96c027f9/model.py:161, in EncoderBlock._att_block(self, x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
    157 batch_size, seq_len, _ = x.shape                                             
    159 xq, xk, xv = self.qkv(x).view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3).chunk(3, axis=-1)                                   
--> 161 xq, xk = apply_rotary_emb(xq, xk, freqs_cis)                                 
    163 # Attn block                                                                 
    164 attn_weights = None                                                          

File ~/.cache/huggingface/modules/transformers_modules/chandar-lab/NeoBERT/a4fbc49a61db10ff2db66140ae59c09d96c027f9/rotary.py:58, in apply_rotary_emb(xq, xk, freqs_cis)  
     56 xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))                                                                                            
     57 xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))                                                                                            
---> 58 freqs_cis = reshape_for_broadcast(freqs_cis, xq_)                            
     59 xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)                                                                                                           
     60 xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)                                                                                                           

File ~/.cache/huggingface/modules/transformers_modules/chandar-lab/NeoBERT/a4fbc49a61db10ff2db66140ae59c09d96c027f9/rotary.py:31, in reshape_for_broadcast(freqs_cis, x)  
     30 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):                                                                                              
---> 31     assert freqs_cis.shape[1:] == (x.shape[1], x.shape[-1])                                                                                                       
     32     return freqs_cis.contiguous().unsqueeze(2)                               

AssertionError:  

Sign up or log in to comment