BucketOfFish committed
Commit df388cc
Parent: c572a14

Corrected rotary embedding

Files changed (1): attention.py (+36 -16)
attention.py CHANGED
@@ -28,7 +28,7 @@ class RotaryEmbedding(nn.Module):
         d_rotary: int,
         rotary_base: float = 10000.0,
         initial_cos_sin_cache_len: int = 2048,
-        device: torch.device = "cuda",
+        device: torch.device | None = None,
     ) -> None:
         super().__init__()
         self.d_rotary = d_rotary
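
A side note on the signature change above (my illustration, not part of the commit): device=None lets the cache tensors be created on the current default device, while the old hard-coded "cuda" string breaks on CPU-only machines and is not actually a torch.device.

import torch

# device=None in torch factory functions means "use the current default device",
# so the cache no longer insists on CUDA at construction time.
print(torch.arange(3, device=None).device)  # cpu, unless torch.set_default_device was used
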
@@ -37,31 +37,37 @@ class RotaryEmbedding(nn.Module):
         self.dtype = torch.float32
         self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)
 
-    def _update_cos_sin_cache(self, seqlen: int) -> None:
+    def _update_cos_sin_cache(
+        self,
+        seqlen: int,
+        device: str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
         # only call this function when seqlen is larger than _max_seqlen
         self._max_seqlen = seqlen
 
         # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
         m = torch.arange(
             seqlen,
-            device=self.device,
-            dtype=self.dtype,
+            device=device,
+            dtype=torch.float32,
         )
         theta_i = 1.0 / (
             self.rotary_base ** (
                 torch.arange(
                     start=0,
                     end=self.d_rotary,
-                    device=self.device,
-                    dtype=self.dtype,
+                    step=2,
+                    device=device,
+                    dtype=torch.float32,
                 ) / self.d_rotary
             )
         )
         # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
         # TODO: does this matter if I'm disabling torch.autocast?
         m_theta_i = torch.outer(m, theta_i)
-        self._cos_cached = torch.cos(m_theta_i).to(self.dtype).to(self.device)
-        self._sin_cached = torch.sin(m_theta_i).to(self.dtype).to(self.device)
+        self._cos_cached = torch.cos(m_theta_i).to(dtype)
+        self._sin_cached = torch.sin(m_theta_i).to(dtype)
 
     # TODO: scale_base caching is labelled as not yet done in Phi2
     """
@@ -90,14 +96,17 @@ class RotaryEmbedding(nn.Module):
         sin: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
     ) -> torch.FloatTensor:
         seqlen = x.shape[1]
-        x1, x2 = x.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head/2)
+        x_to_rotate = x[..., :self.d_rotary]
+        x_to_keep_unrotated = x[..., self.d_rotary:]
+        x1, x2 = x_to_rotate.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_rotary/2)
         broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
         c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
         x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32
-        return cast(
+        x_rotated = cast(
             torch.FloatTensor,
             torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
         )
+        return torch.cat([x_rotated, x_to_keep_unrotated], axis=-1)
 
     def forward(
         self,
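
The hunk above also changes what gets rotated: only the first d_rotary channels of each head are rotated and the remaining channels pass through untouched (partial rotary application), instead of chunking the full head dimension. A self-contained toy version with made-up sizes and identity cos/sin, just to show the split-rotate-concat round trip:

import torch

# Rotate only the first d_rotary channels; the tail is passed through unchanged.
batch, seqlen, n_heads, d_head, d_rotary = 1, 4, 2, 8, 4
x = torch.randn(batch, seqlen, n_heads, d_head)

x_to_rotate, x_unrotated = x[..., :d_rotary], x[..., d_rotary:]
x1, x2 = x_to_rotate.chunk(2, dim=-1)                       # (..., d_rotary / 2) each
c = torch.ones(seqlen, d_rotary // 2).view(seqlen, 1, -1)   # stand-in cos, broadcast over heads
s = torch.zeros(seqlen, d_rotary // 2).view(seqlen, 1, -1)  # stand-in sin
x_rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
out = torch.cat([x_rotated, x_unrotated], dim=-1)

assert out.shape == x.shape
assert torch.allclose(out, x)  # cos = 1, sin = 0 is the identity rotation
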
@@ -107,9 +116,11 @@ class RotaryEmbedding(nn.Module):
         if (
             not self._max_seqlen
             or self._max_seqlen < x.shape[1] + seqlen_offset
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
+            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype)
         return self._apply_rotary_emb_qkv(
             x,
             cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
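
The extra device/dtype conditions are presumably needed because the cos/sin caches are plain tensor attributes rather than registered buffers, so Module.to() never moves them; a small illustration of that PyTorch behaviour (hypothetical Toy class, not from this repo):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._cos_cached = torch.ones(4)  # plain attribute, like the cache above

toy = Toy().to(torch.float16)
print(toy._cos_cached.dtype)  # torch.float32 -- left behind by .to(), hence the rebuild in forward()
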
@@ -269,7 +280,8 @@ class MHA(nn.Module):
             else RotaryEmbedding
         )
         self.rotary_emb = rotary_cls(
-            d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
+            # d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
+            d_rotary=32, # TODO: figure out why Phi2 uses this
             initial_cos_sin_cache_len=initial_cos_sin_cache_len,
         )
 
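
On the TODO above: if I recall Phi-2's released config correctly (hidden size 2560, 32 attention heads, partial rotary factor 0.4 -- these numbers come from memory, not this repo, so treat them as an assumption), the 32 is a fixed fraction of the head dimension rather than half of it:

# Assumed Phi-2 numbers, for illustration only.
d_embedding, n_attn_heads, partial_rotary_factor = 2560, 32, 0.4
d_head = d_embedding // n_attn_heads            # 80
d_rotary = int(d_head * partial_rotary_factor)  # 32 of the 80 head channels get rotated
print(d_head, d_rotary)  # 80 32
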
 
@@ -378,12 +390,20 @@ class MHA(nn.Module):
         kv_cache: KVCache,
         key_padding_mask: torch.BoolTensor | None,
     ) -> torch.FloatTensor:
-        q = qkv[:, :, 0, :, :]
-        q = self.rotary_emb(
-            q,
+        qk = qkv[:, :, :2, :, :]
+        qk = self.rotary_emb(
+            qk,
             seqlen_offset = 0 if kv_cache is None else kv_cache.seqlen_offset,
         )
-        kv = cast(torch.FloatTensor, qkv[:, :, 1:, :, :])
+        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
+        q = qk[:, :, 0, :, :]
+        kv = torch.cat(
+            [
+                qk[:, :, 1, :, :].unsqueeze(2),
+                v.unsqueeze(2),
+            ],
+            dim=2,
+        )
         self._update_kv_cache(kv, kv_cache, self.block_n)
         causal = False # turning off causal mask for cross attention
 
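
The last hunk is the other half of the correction: the rotary embedding is now applied to both q and k rather than to q alone, so the packed qkv tensor is sliced into a (q, k) pair plus v and then reassembled into kv for the cache update. A toy shape check with made-up sizes (the rotary call itself elided):

import torch

# Assumed qkv packing from the hunk above: dim 2 indexes (q, k, v).
batch, seqlen, n_heads, d_head = 1, 4, 2, 8
qkv = torch.randn(batch, seqlen, 3, n_heads, d_head)

qk = qkv[:, :, :2, :, :]   # q and k together, so both get rotated
v = qkv[:, :, 2, :, :]
# qk = rotary_emb(qk, ...) would be applied here
q = qk[:, :, 0, :, :]
kv = torch.cat([qk[:, :, 1, :, :].unsqueeze(2), v.unsqueeze(2)], dim=2)

assert q.shape == (batch, seqlen, n_heads, d_head)
assert kv.shape == (batch, seqlen, 2, n_heads, d_head)
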
 
 