myownskyW7 committed
Commit 8662ae5
Parent(s): 4ef8d94

Upload modeling_InternLM.py


Remove the dependency on flash-attention and rotary_emb
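The fused CUDA rotary kernels are replaced by a plain PyTorch implementation of the GPT-NeoX-style half rotation. As a rough sketch of what that rotation computes (function and variable names below are illustrative, not taken from the file):

import torch

def rotate_neox_style(q, cos, sin):
    # q: (batch, seqlen, nheads, rotary_dim); cos/sin: (seqlen, rotary_dim // 2)
    q1, q2 = q.chunk(2, dim=-1)                    # split into first / second half of the head dims
    cos, sin = cos[:, None, :], sin[:, None, :]    # broadcast over heads
    return torch.cat([q1 * cos - q2 * sin,         # rotated first half
                      q1 * sin + q2 * cos],        # rotated second half
                     dim=-1)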

Files changed (1)
  1. modeling_InternLM.py +43 -44
modeling_InternLM.py CHANGED
@@ -2,12 +2,10 @@ import math
 from typing import List, Union
 from typing import Optional, Tuple
 
-import rotary_emb
 import torch
 import torch.utils.checkpoint
 import torch.utils.checkpoint
 from einops import rearrange
-from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
@@ -23,51 +21,70 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "InternLMXComposerConfig"
 
 
-class ApplyRotaryEmbQKV_(torch.autograd.Function):
-    """
-    ApplyRotaryEmbQKV_
-    """
+def rotary_embed(x1, x2, cos, sin, conj):
+    x1, x2 = x1.float(), x2.float()
+    if conj:
+        x1, x2 = x1 * cos + x2 * sin, x1 * sin + x2 * cos
+    else:
+        x1, x2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
+    return x1, x2
+
+
+class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
+
     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         """
-        qkv: (total, 3, nheads, headdim)
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
-        _, three, _, headdim = qkv.shape
+        batch, seqlen, three, nheads, headdim = qkv.shape
         assert three == 3
         rotary_seqlen, rotary_dim = cos.shape
         rotary_dim *= 2
         assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
         cos_k = cos if cos_k is None else cos_k
         sin_k = sin if sin_k is None else sin_k
-        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
-                                                           rotary_dim // 2)
-        q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"),
-                                rearrange(sin, "s d -> s 1 d"), q1, q2, False)
-        k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k, "s d -> s 1 d"),
-                                rearrange(sin_k, "s d -> s 1 d"), k1, k2,
-                                False)
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
+        q_ro = qkv[:, :, 0, :, :rotary_dim]
+        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
+        # rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+        #                         rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
+        q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), False)
+        qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
+        k_ro = qkv[:, :, 1, :, :rotary_dim]
+        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
+        # rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+        #                         rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
+        qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        ctx.interleaved = interleaved
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
         cos, sin, cos_k, sin_k = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
-        dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos, "s d -> s 1 d"),
-                                rearrange(sin, "s d -> s 1 d"), dq1, dq2, True)
-        dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k, "s d -> s 1 d"),
-                                rearrange(sin_k, "s d -> s 1 d"), dk1, dk2,
-                                True)
-        return dqkv, None, None, None, None
+        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
+        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
+        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None, None, None, None
 
 
 class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
@@ -120,23 +137,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
         self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def forward(self,
-                qkv: torch.Tensor,
-                indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
-        if self.scale is None:
-            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes],
-                                         self._sin_cached[indexes]).to(
-                                             qkv.dtype)
-        else:
-            return apply_rotary_emb_qkv_(
-                qkv,
-                self._cos_cached[indexes],
-                self._sin_cached[indexes],
-                self._cos_k_cached[indexes],
-                self._sin_k_cached[indexes],
-            ).to(qkv.dtype)
-
     def eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
@@ -157,7 +157,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
         )
 
 
-apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
 legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
 
 
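Note: the rewritten backward above still calls rotary_emb.apply_rotary even though the import was removed, so only the forward / inference path is free of the extension; autograd through this op would raise a NameError. For reference, a hedged sketch of what a pure-PyTorch gradient for the half rotation could look like (illustrative only, not part of commit 8662ae5; the transposed rotation negates the sin term on the second half, whereas the committed conj=True branch of rotary_embed appears to keep it positive):

import torch
from einops import rearrange

def rotary_grad_half(dq_ro, cos, sin, seqlen):
    # Hypothetical gradient of the GPT-NeoX-style half rotation.
    # dq_ro: (batch, seqlen, nheads, rotary_dim); cos/sin: (rotary_seqlen, rotary_dim // 2)
    dq1, dq2 = dq_ro.chunk(2, dim=-1)
    cos_b = rearrange(cos[:seqlen], "s d -> s 1 d")
    sin_b = rearrange(sin[:seqlen], "s d -> s 1 d")
    # transposed (inverse) rotation: note the negated sin term on the second half
    return torch.cat([dq1 * cos_b + dq2 * sin_b,
                      -dq1 * sin_b + dq2 * cos_b], dim=-1)

Inference-only use is unaffected: legacy_apply_rotary_embed_qkv(qkv, cos, sin) now runs without flash-attn or the rotary_emb CUDA extension installed.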