nielsr committed
Commit b8d892d
Parent: 5e46e8b

Add print statements

Files changed (1):
  modeling_cogvlm.py  +6 -2
modeling_cogvlm.py CHANGED
@@ -290,12 +290,14 @@ class CogVLMDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        print_values = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
 
-        print("Hidden states before self attention:", hidden_states[0,:3,:3])
+        if print_values:
+            print("Hidden states before self attention:", hidden_states[0,:3,:3])
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -308,7 +310,8 @@ class CogVLMDecoderLayer(nn.Module):
             use_cache=use_cache,
         )
 
-        print("Hidden states after self attention:", hidden_states[0,:3,:3])
+        if print_values:
+            print("Hidden states after self attention:", hidden_states[0,:3,:3])
 
         hidden_states = residual + hidden_states
 
@@ -539,6 +542,7 @@ class CogVLMModel(CogVLMPreTrainedModel):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                print_values=idx in [0, 1, 2],
             )
             hidden_states = layer_outputs[0]
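
For context, and not part of the commit: a minimal, self-contained sketch of the pattern this change introduces. A print_values flag on the decoder layer gates the debug prints, and the caller enables it only for the first few layers (print_values=idx in [0, 1, 2]). The TinyLayer module below is illustrative and stands in for the real CogVLMDecoderLayer.

# Minimal sketch (illustrative, not the actual CogVLM code) of the gated
# debug-print pattern added in this commit.
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=2, batch_first=True)

    def forward(self, hidden_states: torch.Tensor, print_values: bool = False) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        if print_values:
            print("Hidden states before self attention:", hidden_states[0, :3, :3])
        # Self attention (query = key = value, as in a decoder's self-attention)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        if print_values:
            print("Hidden states after self attention:", hidden_states[0, :3, :3])
        return residual + hidden_states

layers = nn.ModuleList(TinyLayer() for _ in range(6))
hidden_states = torch.randn(1, 4, 8)  # (batch, seq_len, hidden_size)
for idx, layer in enumerate(layers):
    # Mirrors the call site in CogVLMModel: only layers 0, 1 and 2 print.
    hidden_states = layer(hidden_states, print_values=idx in [0, 1, 2])

Gating on the layer index rather than printing unconditionally keeps the output to a few lines per forward pass instead of two per layer, which is handy when comparing activations against another implementation.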