Text Generation
Transformers
PyTorch
Safetensors
English
gpt_refact
code
custom_code
Eval Results
svakhreev commited on
Commit
38cebfc
1 Parent(s): 82cd6f7

Update modeling_gpt_refact.py

Browse files
Files changed (1) hide show
  1. modeling_gpt_refact.py +1 -1
modeling_gpt_refact.py CHANGED
@@ -151,7 +151,7 @@ class Attention(nn.Module):
151
  upcast = dtype != softmax_dtype
152
  unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
153
 
154
- attn_weights = alibi + torch.matmul(query * self.scale, key)
155
 
156
  if upcast:
157
  if attention_mask is None:
 
151
  upcast = dtype != softmax_dtype
152
  unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
153
 
154
+ attn_weights = (alibi + torch.matmul(query * self.scale, key)).to(query.dtype)
155
 
156
  if upcast:
157
  if attention_mask is None: