nielsr HF staff commited on
Commit
0dc1d68
1 Parent(s): 8f61efb

Add print statements

Browse files
Files changed (1) hide show
  1. modeling_cogvlm.py +7 -2
modeling_cogvlm.py CHANGED
@@ -117,7 +117,8 @@ def attention_fn(
117
  attention_mask: "torch.tensor(B, H, L, HD)",
118
  *,
119
  scaling_attention_score: bool = True,
120
- attention_dropout: nn.Module = None
 
121
  ):
122
  attention_mask_bool = (attention_mask == 0)
123
  is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
@@ -126,6 +127,10 @@ def attention_fn(
126
  warnings.warn("It's recommended to use torch2.0 or higher.")
127
  if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
128
  dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
 
 
 
 
129
  return torch.nn.functional.scaled_dot_product_attention(
130
  query_layer, key_layer, value_layer,
131
  attn_mask=None,
@@ -302,7 +307,7 @@ class VisionExpertAttention(nn.Module):
302
 
303
  context_layer = attention_fn(
304
  query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
305
- scaling_attention_score=True, attention_dropout=None)
306
 
307
  if print_values:
308
  print("Shape of context_layer:", context_layer.shape)
 
117
  attention_mask: "torch.tensor(B, H, L, HD)",
118
  *,
119
  scaling_attention_score: bool = True,
120
+ attention_dropout: nn.Module = None,
121
+ print_values: bool = False,
122
  ):
123
  attention_mask_bool = (attention_mask == 0)
124
  is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
 
127
  warnings.warn("It's recommended to use torch2.0 or higher.")
128
  if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
129
  dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
130
+
131
+ if print_values:
132
+ print("Is_causal:", not is_full)
133
+
134
  return torch.nn.functional.scaled_dot_product_attention(
135
  query_layer, key_layer, value_layer,
136
  attn_mask=None,
 
307
 
308
  context_layer = attention_fn(
309
  query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
310
+ scaling_attention_score=True, attention_dropout=None, print_values=print_values)
311
 
312
  if print_values:
313
  print("Shape of context_layer:", context_layer.shape)