zxdu20 committed
Commit d4832e8
1 Parent(s): a034f2a

Add support for float32

Files changed (1):
    modeling_chatglm.py  +2 -2
modeling_chatglm.py CHANGED

@@ -254,13 +254,13 @@ def attention_fn(
         if not (attention_mask == 0).all():
             # if auto-regressive, skip
             attention_scores.masked_fill_(attention_mask, -10000.0)
-
+        dtype = attention_scores.type()
         attention_scores = attention_scores.float()
         attention_scores = attention_scores * query_key_layer_scaling_coeff
 
         attention_probs = F.softmax(attention_scores, dim=-1)
 
-        attention_probs = attention_probs.half()
+        attention_probs = attention_probs.type(dtype)
 
         # =========================
         # Context layer. [sq, b, hp]
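
The edit replaces a hard-coded .half() cast after the softmax with a round-trip through the tensor's original type: the scores are promoted to float32 for a numerically stable softmax, and the probabilities are cast back to whatever dtype the model is running in, so float32 weights are no longer forced down to float16. A minimal standalone sketch of that pattern, assuming PyTorch; the stable_softmax helper name is illustrative and not part of modeling_chatglm.py:

import torch
import torch.nn.functional as F

def stable_softmax(attention_scores: torch.Tensor) -> torch.Tensor:
    # Remember the original type string, e.g. 'torch.HalfTensor' or 'torch.FloatTensor'.
    dtype = attention_scores.type()
    # Softmax in float32 to avoid float16 overflow/underflow in the exponentials.
    probs = F.softmax(attention_scores.float(), dim=-1)
    # Cast back; this is a no-op when the input was already float32.
    return probs.type(dtype)

# Half-precision input comes back as float16, as before the commit...
assert stable_softmax(torch.randn(2, 4).half()).dtype == torch.float16
# ...and a float32 input now stays float32 instead of being forced to half.
assert stable_softmax(torch.randn(2, 4)).dtype == torch.float32

Note that Tensor.type() with no argument returns the type string rather than a torch.dtype, and passing that string back to .type(...) performs the cast, which is what makes the round trip dtype-agnostic.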