kq-chen committed
Commit f3440bb
1 Parent(s): b7a6ea7

update to v1.1

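For reference, a minimal sketch of loading this updated checkpoint with transformers. The repo id, tokenizer source, and dtype below are assumptions rather than part of this commit; `trust_remote_code=True` is what picks up the revised modeling_cogvlm.py shown further down.

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

# Assumed identifiers -- substitute the actual repo id and tokenizer for this model.
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-grounding-generalist-hf",  # hypothetical repo id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # required so the custom modeling_cogvlm.py in this repo is used
    low_cpu_mem_usage=True,
).to("cuda").eval()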
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "cogvlm-grounding-generalist",
+  "_name_or_path": "cogvlm-grounding-generalist-v1-1",
   "architectures": [
     "CogVLMForCausalLM"
   ],
model-00001-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ab245c9b171545099d652eefcd863aa20f95b7ec2c18dea754ffa661ddcdeebd
+oid sha256:0ff07f55a4068d8d553593122343209b30b895760fcd83924cf9001c09c683c2
 size 4938885184
model-00002-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c4c34bd99e8317404ef58c253a01648c120b356b3b8f95933da46dd02fbb73ba
+oid sha256:6ab2d1ce61f81a53be36e9af77f122e74aca0ff875b58006436effa50884c005
 size 4947290688
model-00003-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0affb830c00b94b9fc5ecfa41d7ae0d62fa42c73300df2537cd7c4496f947014
+oid sha256:48edfcbc0ee4a397ff8dc5e972e02f0299ba779d5929367c9caae0cc751cf892
 size 4947307592
model-00004-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ed1d44e236c4af9263f8a32406769d4e8dcf2fefad205a17c2c2191d948c0e07
+oid sha256:ed00da067fca6bdbc1b65f0c1482d716e0ee4ee648170f0588dc6cb0d23588cb
 size 4991331080
model-00005-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9dc0499401747dd885bd9e212cec57296fad9b0c2d59c8c1984063f9c9d22eeb
+oid sha256:b4a6c6a257805bd07721ebcb5884b1048248d732aa979cd4375761183b806d46
 size 4991331088
model-00006-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7196fa96d96992a65ae041f2a2921afef149003e059bec884520d386c814ce0a
+oid sha256:37afe090f7597fa338287fb12f501cb74828005bb0f7a0d51fa6c410711de7ee
 size 4970162920
model-00007-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:340afeec8059f2d6fb2b890041611a448bb54440ae185d8a5433bf39e2711d9b
+oid sha256:63a91f88f413ec65b0345f30e394cdfcf9a24fcd820825ecbcd35de78b469fa9
 size 4960543792
model-00008-of-00008.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9122674b65c1b3c9c8d498ff684591c106081a906f0c85e9fcb8e1a2cc0a270e
+oid sha256:f97723a9e977d4416d2a70739c56aa8f8f50e55b86ddc71d7a4c9262e5167bbd
 size 532677104
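Only the LFS object ids change in the shards above; the sizes are identical, so this is a weight refresh rather than a reshard. A small hedged helper for checking that locally downloaded shards match the new pointers (the local directory name is hypothetical; compare the printed digests with the `oid sha256:` values above):

import hashlib
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 without loading it fully into memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

# Hypothetical local download directory.
for shard in sorted(Path("cogvlm-grounding-generalist-v1-1").glob("model-*.safetensors")):
    print(shard.name, sha256_of(shard))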
modeling_cogvlm.py CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, A
 import math
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
@@ -15,7 +16,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 from .configuration_cogvlm import CogVLMConfig
-from .util import FastRotaryEmbedding
 from .visual import EVA2CLIPModel
 
 if TYPE_CHECKING:
@@ -144,6 +144,57 @@ def attention_fn(
     return context_layer
 
 
+class RotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq)
+        self.max_seq_len_cached = 0
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+def rotate_half(x):
+    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)
+
+
+def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
+    # batch_size, num_head, seq_len, hidden_size
+    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
+               F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
+    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+    return q, k
+
+
 class VisionExpertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -153,8 +204,7 @@ class VisionExpertAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings
 
-        # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
-        self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
+        self.rotary_emb = RotaryEmbedding(self.head_dim)
         self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
         self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
@@ -193,8 +243,8 @@ class VisionExpertAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
+        cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
+        query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -278,7 +328,7 @@ class CogVLMPreTrainedModel(PreTrainedModel):
     config_class = CogVLMConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = False
-    _no_split_modules = ["CogVLMDecoderLayer"]
+    _no_split_modules = ["CogVLMDecoderLayer", "TransformerLayer"]
     _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
@@ -538,25 +588,23 @@ class CogVLMModel(CogVLMPreTrainedModel):
         return combined_attention_mask
 
 
-def chat_history_to_prompt(history, query):
-    prompt = " [INST] "
-    for i, (old_query, response) in enumerate(history):
-        prompt += old_query + " [/INST] " + response + " [INST] "
-    prompt += query + " [/INST] "
-    return prompt
-
+def _history_to_prompt(signal_type, history, query):
+    if signal_type == 'base':
+        return query
+    elif signal_type == 'vqa':
+        answer_format = 'Short answer:'
+    elif signal_type == 'chat':
+        answer_format = 'Answer:'
+    else:
+        assert False, f"Unknown signal type {signal_type}"
 
-def base_history_to_prompt(history, query):
-    prompt = query
+    prompt = ''
+    for i, (old_query, response) in enumerate(history):
+        prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
+    prompt += 'Question: {} {}'.format(query, answer_format)
     return prompt
 
 
-_history_to_prompt = {
-    "base": base_history_to_prompt,
-    "chat": chat_history_to_prompt
-}
-
-
 class CogVLMForCausalLM(CogVLMPreTrainedModel):
     _auto_class = "AutoModelForCausalLM"
 
@@ -708,7 +756,8 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
-            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
 
         if not is_encoder_decoder:
@@ -744,14 +793,14 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
         images: Optional[List["PIL.Image"]] = None,
-        template_version: Optional[Literal["base", "chat"]] = None,
+        template_version: Optional[Literal["base", "chat", "vqa"]] = None,
     ):
         image_size: int = self.config.vision_config['image_size']
         patch_size: int = self.config.vision_config['patch_size']
         template_version = template_version or self.config.template_version
         assert images is None or len(images) <= 1, f"not support multi images by now."
         history = history or []
-        text = _history_to_prompt[template_version](history, query)
+        text = _history_to_prompt(template_version, history, query)
 
         input_ids = [tokenizer.bos_token_id]
         token_type_ids = [LANGUAGE_TOKEN_TYPE]
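The prompt-template rewrite above replaces the Llama-chat style "[INST] ... [/INST]" markup with plain "Question:" / "Answer:" turns and adds a "vqa" template that asks for a short answer. A small usage sketch of the changed entry point, build_conversation_input_ids; it assumes the model and tokenizer objects from the loading sketch near the top and a local image path, all hypothetical:

from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # hypothetical image file
inputs = model.build_conversation_input_ids(
    tokenizer,
    query="What color is the dog?",
    history=[("Where is the dog?", "The dog is lying on the sofa.")],
    images=[image],
    template_version="vqa",  # new in v1.1; "base" and "chat" remain available
)
# With the "chat" template the prompt text becomes:
#   "Question: Where is the dog? Answer: The dog is lying on the sofa.\nQuestion: What color is the dog? Answer:"
# The "vqa" template uses "Short answer:" in place of "Answer:".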
util.py DELETED
@@ -1,483 +0,0 @@
-from typing import Optional, Tuple, Union
-
-import torch
-from einops import rearrange, repeat
-import torch.nn.functional as F
-
-import triton
-import triton.language as tl
-
-
-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_M": 2}),
-#         triton.Config({"BLOCK_M": 4}),
-#         triton.Config({"BLOCK_M": 8}),
-#         triton.Config({"BLOCK_M": 16}),
-#     ],
-#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
-# )
-@triton.jit
-def rotary_kernel(
-    OUT,  # Pointers to matrices
-    X,
-    COS,
-    SIN,
-    CU_SEQLENS,
-    SEQLEN_OFFSETS,  # this could be int or a pointer
-    # Matrix dimensions
-    seqlen,
-    nheads,
-    rotary_dim,
-    seqlen_ro,
-    CACHE_KEY_SEQLEN,
-    # strides
-    stride_out_batch,
-    stride_out_nheads,
-    stride_out_seqlen,
-    stride_out_headdim,
-    stride_x_batch,
-    stride_x_nheads,
-    stride_x_seqlen,
-    stride_x_headdim,
-    # Meta-parameters
-    BLOCK_K: tl.constexpr,
-    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
-    IS_VARLEN: tl.constexpr,
-    INTERLEAVED: tl.constexpr,
-    CONJUGATE: tl.constexpr,
-    BLOCK_M: tl.constexpr,
-):
-    pid_m = tl.program_id(axis=0)
-    pid_batch = tl.program_id(axis=1)
-    pid_head = tl.program_id(axis=2)
-    rotary_dim_half = rotary_dim // 2
-
-    if not IS_VARLEN:
-        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
-        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
-        COS = COS + pid_batch * seqlen_ro * rotary_dim_half
-        SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
-    else:
-        start_idx = tl.load(CU_SEQLENS + pid_batch)
-        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
-        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
-        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
-
-    if pid_m * BLOCK_M >= seqlen:
-        return
-    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    if not IS_SEQLEN_OFFSETS_TENSOR:
-        rm_cs = rm + SEQLEN_OFFSETS
-    else:
-        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
-    rk = tl.arange(0, BLOCK_K)
-    rk_half = tl.arange(0, BLOCK_K // 2)
-
-    if not INTERLEAVED:
-        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
-        X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
-        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
-        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
-        cos = tl.load(
-            COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
-        )
-        sin = tl.load(
-            SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
-        )
-        x0 = tl.load(
-            X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
-        )
-        x1 = tl.load(
-            X + rotary_dim_half * stride_x_headdim,
-            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
-            other=0.0,
-        )
-        if CONJUGATE:
-            sin = -sin
-        o0 = x0 * cos - x1 * sin
-        o1 = x0 * sin + x1 * cos
-        # write back result
-        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
-        tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
-        tl.store(
-            OUT + rotary_dim_half * stride_out_headdim,
-            o1,
-            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
-        )
-    else:
-        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
-        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
-        # Loading x0 will be fast but x1 will be slow.
-        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
-        # Then we do the calculation and use tl.where to pick put the right outputs for the even
-        # and for the odd indices.
-        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
-        rk_repeat = tl.arange(0, BLOCK_K) // 2
-        X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
-        X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
-        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
-        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
-        cos = tl.load(
-            COS,
-            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
-            other=1.0,
-        ).to(tl.float32)
-        sin = tl.load(
-            SIN,
-            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
-            other=0.0,
-        ).to(tl.float32)
-        x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
-            tl.float32
-        )
-        x1 = tl.load(
-            X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
-        ).to(tl.float32)
-        if CONJUGATE:
-            sin = -sin
-        x0_cos = x0 * cos
-        x1_sin = x1 * sin
-        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
-        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
-        tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
-
-
-def apply_rotary(
-    x: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    seqlen_offsets: Union[int, torch.Tensor] = 0,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[int] = None,
-    interleaved=False,
-    inplace=False,
-    conjugate=False,
-) -> torch.Tensor:
-    """
-    Arguments:
-        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
-            else (total_seqlen, nheads, headdim).
-        cos: (seqlen_ro, rotary_dim / 2)
-        sin: (seqlen_ro, rotary_dim / 2)
-        seqlen_offsets: integer or integer tensor of size (batch,)
-        cu_seqlens: (batch + 1,) or None
-        max_seqlen: int
-    Returns:
-        y: (batch, seqlen, nheads, headdim)
-    """
-
-    batch, nheads, seqlen, headdim = x.shape
-
-    batch_ro, seqlen_ro, rotary_dim = cos.shape
-
-    assert batch == batch_ro
-    assert sin.shape == cos.shape
-    rotary_dim *= 2
-    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
-    assert headdim <= 256, "Only support headdim <= 256"
-
-    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
-
-    assert (
-        cos.dtype == sin.dtype
-    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
-    assert (
-        x.dtype == cos.dtype
-    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
-
-    cos, sin = cos.contiguous(), sin.contiguous()
-    if isinstance(seqlen_offsets, torch.Tensor):
-        assert seqlen_offsets.shape == (batch,)
-        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
-        seqlen_offsets = seqlen_offsets.contiguous()
-    else:
-        assert seqlen_offsets + seqlen <= seqlen_ro
-
-    output = torch.empty_like(x) if not inplace else x
-    if rotary_dim < headdim and not inplace:
-        output[..., rotary_dim:].copy_(x[..., rotary_dim:])
-
-    BLOCK_K = (
-        32
-        if rotary_dim <= 32
-        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
-    )
-    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
-    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
-
-    # Need this, otherwise Triton tries to launch from cuda:0 and we get
-    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    with torch.cuda.device(x.device.index):
-        rotary_kernel[grid](
-            output,  # data ptrs
-            x,
-            cos,
-            sin,
-            cu_seqlens,
-            seqlen_offsets,
-            seqlen,  # shapes
-            nheads,
-            rotary_dim,
-            seqlen_ro,
-            seqlen // 128,  # key for triton cache (limit number of compilations)
-            output.stride(0),  # batch_strides
-            output.stride(-3),  # nheads_stride
-            output.stride(-2),  # seqlen_stride
-            output.stride(-1),  # headdim_stride
-            x.stride(0),  # batch_strides
-            x.stride(-3),  # nheads stride
-            x.stride(-2),  # seqlen stride
-            x.stride(-1),  # headdim stride
-            BLOCK_K,
-            isinstance(seqlen_offsets, torch.Tensor),
-            False,
-            interleaved,
-            conjugate,
-            BLOCK_M,
-        )
-    return output
-
-
-class ApplyRotaryEmb(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        x,
-        cos,
-        sin,
-        interleaved=False,
-        inplace=False,
-        seqlen_offsets: Union[int, torch.Tensor] = 0,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-    ):
-        out = apply_rotary(
-            x,
-            cos,
-            sin,
-            seqlen_offsets=seqlen_offsets,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            interleaved=interleaved,
-            inplace=inplace,
-        )
-        if isinstance(seqlen_offsets, int):
-            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
-            ctx.seqlen_offsets = seqlen_offsets
-        else:
-            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
-            ctx.seqlen_offsets = None
-        ctx.interleaved = interleaved
-        ctx.inplace = inplace
-        ctx.max_seqlen = max_seqlen
-        return out if not inplace else x
-
-    @staticmethod
-    def backward(ctx, do):
-        seqlen_offsets = ctx.seqlen_offsets
-        if seqlen_offsets is None:
-            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
-        else:
-            cos, sin, cu_seqlens = ctx.saved_tensors
-        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
-        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
-        if not ctx.interleaved and not ctx.inplace:
-            do = do.clone()
-        dx = apply_rotary(
-            do,
-            cos,
-            sin,
-            seqlen_offsets=seqlen_offsets,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=ctx.max_seqlen,
-            interleaved=ctx.interleaved,
-            inplace=ctx.inplace,
-            conjugate=True,
-        )
-        return dx, None, None, None, None, None, None, None
-
-
-def apply_rotary_emb(
-    x,
-    cos,
-    sin,
-    interleaved=False,
-    inplace=False,
-    seqlen_offsets: Union[int, torch.Tensor] = 0,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[int] = None,
-):
-    """
-    Arguments:
-        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
-            else (total_seqlen, nheads, headdim)
-        cos, sin: (seqlen_rotary, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        inplace: if True, apply rotary embedding in-place.
-        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
-            Most commonly used in inference when we have KV cache.
-        cu_seqlens: (batch + 1,) or None
-        max_seqlen: int
-    Return:
-        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
-            else (total_seqlen, nheads, headdim)
-    rotary_dim must be <= headdim
-    Apply rotary embedding to the first rotary_dim of x.
-    """
-    return ApplyRotaryEmb.apply(
-        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
-    )
-
-
-# For backward compatibility
-apply_rotary_emb_func = apply_rotary_emb
-
-
-class FastRotaryEmbedding(torch.nn.Module):
-    """
-    The rotary position embeddings from RoFormer_ (Su et. al).
-    A crucial insight from the method is that the query and keys are
-    transformed by rotation matrices which depend on the relative positions.
-
-    Other implementations are available in the Rotary Transformer repo_ and in
-    GPT-NeoX_, GPT-NeoX was an inspiration
-
-    .. _RoFormer: https://arxiv.org/abs/2104.09864
-    .. _repo: https://github.com/ZhuiyiTechnology/roformer
-    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
-
-    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
-    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
-    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        base=10000,
-        interleaved=False,
-        scale_base=None,
-        pos_idx_in_fp32=True,
-        device=None,
-    ):
-        """
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
-            otherwise they might be in lower precision.
-            This option was added because previously (before 2023-07-02), when we construct
-            the position indices, we use the dtype of self.inv_freq. In most cases this would
-            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
-            self.inv_freq would be bf16, and the position indices are also in bf16.
-            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
-            embeddings for some positions will coincide.
-            To maintain compatibility with models previously trained in pure bf16,
-            we add this option.
-        """
-        super().__init__()
-        self.dim = dim
-        self.base = base
-        self.pos_idx_in_fp32 = pos_idx_in_fp32
-        # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq)
-        self.interleaved = interleaved
-        self.scale_base = scale_base
-        scale = (
-            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-            if scale_base is not None
-            else None
-        )
-        self.register_buffer("scale", scale, persistent=False)
-
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
-        self.cos = None
-        self.sin = None
-
-    def _compute_inv_freq(self, device=None):
-        return 1.0 / (
-            self.base
-            ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
-            # ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
-        )
-
-    def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
-
-        if (
-            seqlen > self._seq_len_cached
-        ):
-            self._seq_len_cached = seqlen
-            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
-            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
-            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
-                # will be large. Having it in bf16 will lose a lot of precision and cause the
-                # cos & sin output to change significantly.
-                # We want to recompute self.inv_freq if it was not loaded in fp32
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self._compute_inv_freq(device=device)
-                else:
-                    inv_freq = self.inv_freq
-            else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                inv_freq = self.inv_freq
-            freqs = torch.einsum("i,j->ij", t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-
-            else:
-                power = (
-                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
-                    - seqlen // 2
-                ) / self.scale_base
-                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-                # We want the multiplication by scale to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        position_ids: torch.Tensor,
-        max_seqlen,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        q: (batch, nheads, seqlen, headdim)
-        k: (batch, nheads, seqlen, headdim)
-        position_id: (batch, seqlen)
-        max_seqlen: int
-        layer_id: int
-            only if layer_id == 0, then update cons and sin
-        Apply rotary embedding *inplace* to q k.
-        """
-
-        self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
-        cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
-
-        q = apply_rotary_emb_func(
-            q,
-            cos,
-            sin,
-            interleaved=self.interleaved,
-            inplace=True
-        )
-        k = apply_rotary_emb_func(
-            k,
-            cos,
-            sin,
-            interleaved=self.interleaved,
-            inplace=True
-        )
-        return q, k
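util.py held the Triton-based FastRotaryEmbedding kernel; v1.1 removes it (and with it the implicit triton dependency) in favor of the pure-PyTorch RotaryEmbedding and apply_rotary_pos_emb_index_bhs helpers added to modeling_cogvlm.py above. A standalone sketch of that replacement math on toy shapes, not tied to the model's real dimensions:

import torch
import torch.nn.functional as F

def rotate_half(x):
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

# Toy sizes: batch=1, heads=2, seq_len=8, head_dim=16.
b, h, s, d = 1, 2, 8, 16
q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)

# Same cos/sin cache construction as the new RotaryEmbedding module.
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
freqs = torch.einsum("i,j->ij", torch.arange(s).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                    # (seq_len, head_dim)
cos_cached, sin_cached = emb.cos(), emb.sin()

# Gather per-position cos/sin with F.embedding, as apply_rotary_pos_emb_index_bhs does.
position_ids = torch.arange(s)[None, :]                    # (batch, seq_len)
cos = F.embedding(position_ids, cos_cached).unsqueeze(1)   # (batch, 1, seq_len, head_dim)
sin = F.embedding(position_ids, sin_cached).unsqueeze(1)

q_rot = q * cos + rotate_half(q) * sin
k_rot = k * cos + rotate_half(k) * sin
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 8, 16]) twice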