bapatra mrm196 commited on
Commit
f196467
1 Parent(s): e6adf2a

Ensure the query_states and key_states remain in bf16 (#21)

Browse files

- Ensure the query_states and key_states remain in bf16 (55b8e963ff0aa4a4190cff537165f08c378f62ff)


Co-authored-by: Mohammadreza Mohseni <mrm196@users.noreply.huggingface.co>

Files changed (1) hide show
  1. positional_embedding.py +2 -2
positional_embedding.py CHANGED
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
269
  return (
270
  apply_rotary_pos_emb(
271
  q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
272
- ),
273
  apply_rotary_pos_emb(
274
  k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
275
- ),
276
  )
277
 
278
  @classmethod
 
269
  return (
270
  apply_rotary_pos_emb(
271
  q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
272
+ ).to(q.dtype),
273
  apply_rotary_pos_emb(
274
  k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
275
+ ).to(k.dtype),
276
  )
277
 
278
  @classmethod