FalconLLM commited on
Commit
357dc3f
1 Parent(s): 38377d0

Update modelling_RW.py

Browse files
Files changed (1) hide show
  1. modelling_RW.py +1 -1
modelling_RW.py CHANGED
@@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module):
89
 
90
  def forward(self, q, k):
91
  batch, seq_len, head_dim = q.shape
92
- cos, sin = self.cos_sin(seq_len, q.device)
93
  return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
 
89
 
90
  def forward(self, q, k):
91
  batch, seq_len, head_dim = q.shape
92
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
  return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95