qingsonglv committed
Commit
0751d1b
1 Parent(s): fde4399

upload readme

config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "_name_or_path": "cogagent",
+   "architectures": [
+     "CogAgentForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_cogagent.CogAgentConfig",
+     "AutoModelForCausalLM": "modeling_cogagent.CogAgentForCausalLM"
+   },
+   "bos_token_id": 1,
+   "cross_compute_hidden_size": 1024,
+   "cross_hidden_size": 1024,
+   "cross_image_size": 1120,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "max_position_embeddings": 2048,
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "pad_token_id": 0,
+   "rms_norm_eps": 1e-05,
+   "template_version": "chat",
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.36.0.dev0",
+   "use_cache": true,
+   "vision_config": {
+     "dropout_prob": 0.0,
+     "hidden_act": "gelu",
+     "hidden_size": 1792,
+     "image_size": 224,
+     "in_channels": 3,
+     "intermediate_size": 15360,
+     "layer_norm_eps": 1e-06,
+     "num_heads": 16,
+     "num_hidden_layers": 63,
+     "num_positions": 257,
+     "patch_size": 14
+   },
+   "vocab_size": 32000
+ }
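
A usage sketch only (not part of this commit): because "auto_map" points AutoConfig/AutoModelForCausalLM at the bundled configuration_cogagent.py and modeling_cogagent.py, the checkpoint is loaded with trust_remote_code=True. The repository id below is a placeholder assumption, not taken from this commit.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_PATH = "path/or/repo-id-of-this-checkpoint"  # placeholder

# auto_map in config.json resolves these Auto* classes to the bundled modules,
# so loading requires trusting the remote code shipped with the checkpoint.
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" above
    trust_remote_code=True,
).eval()
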
configuration_cogagent.py ADDED
@@ -0,0 +1,51 @@
+ from typing import Literal
+ from transformers import PretrainedConfig
+
+
+ class CogAgentConfig(PretrainedConfig):
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=4096,
+         cross_hidden_size=1024,
+         cross_compute_hidden_size=1024,
+         cross_image_size=1120,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         hidden_act='silu',
+         max_position_embeddings=2048,
+         initializer_range=0.02,
+         rms_norm_eps=1e-06,
+         template_version: Literal["base", "chat"] = "chat",
+
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         use_cache=True,
+         **kwargs,
+     ):
+         self.hidden_size = hidden_size
+         self.cross_hidden_size = cross_hidden_size
+         self.cross_compute_hidden_size = cross_compute_hidden_size
+         self.cross_image_size = cross_image_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.rms_norm_eps = rms_norm_eps
+         self.initializer_range = initializer_range
+         self.vocab_size = vocab_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hidden_act = hidden_act
+         self.template_version = template_version
+         self.use_cache = use_cache
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
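
A minimal illustration (assuming the module above is importable as written): CogAgentConfig behaves like any PretrainedConfig, so defaults can be overridden at construction and the token/embedding kwargs are forwarded to the base class.

from configuration_cogagent import CogAgentConfig

# Override one field; pad/bos/eos ids and tie_word_embeddings are handled by
# PretrainedConfig.__init__ via the super().__init__ call above.
cfg = CogAgentConfig(template_version="base", use_cache=False)
assert cfg.hidden_size == 4096 and cfg.cross_image_size == 1120
print(cfg.to_diff_dict())  # only the fields that differ from the defaults
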
cross_visual.py ADDED
@@ -0,0 +1,797 @@
+ from math import pi
+ import torch
+ from torch import nn
+ from einops import rearrange, repeat
+ import logging
+
+ def broadcat(tensors, dim = -1):
+     num_tensors = len(tensors)
+     shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+     assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
+     shape_len = list(shape_lens)[0]
+     dim = (dim + shape_len) if dim < 0 else dim
+     dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+     expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+     assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
+     max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+     expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+     expanded_dims.insert(dim, (dim, dims[dim]))
+     expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+     tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+     return torch.cat(tensors, dim = dim)
+
+ def rotate_half(x):
+     x = rearrange(x, '... (d r) -> ... d r', r = 2)
+     x1, x2 = x.unbind(dim = -1)
+     x = torch.stack((-x2, x1), dim = -1)
+     return rearrange(x, '... d r -> ... (d r)')
+
+ class VisionRotaryEmbeddingFast(nn.Module):
+     def __init__(
+         self,
+         dim,
+         pt_seq_len,
+         ft_seq_len=None,
+         custom_freqs = None,
+         freqs_for = 'lang',
+         theta = 10000,
+         max_freq = 10,
+         num_freqs = 1,
+         patch_dropout = 0.
+     ):
+         super().__init__()
+         if custom_freqs:
+             freqs = custom_freqs
+         elif freqs_for == 'lang':
+             freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
+         elif freqs_for == 'pixel':
+             freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
+         elif freqs_for == 'constant':
+             freqs = torch.ones(num_freqs).float()
+         else:
+             raise ValueError(f'unknown modality {freqs_for}')
+
+         if ft_seq_len is None: ft_seq_len = pt_seq_len
+         t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+         freqs = torch.einsum('..., f -> ... f', t, freqs)
+         freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+         freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
+
+         freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+         freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+         self.patch_dropout = patch_dropout
+
+         self.register_buffer("freqs_cos", freqs_cos)
+         self.register_buffer("freqs_sin", freqs_sin)
+
+         logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
+
+     def forward(self, t, patch_indices_keep=None):
+         if patch_indices_keep is not None:
+             batch = t.size()[0]
+             batch_indices = torch.arange(batch)
+             batch_indices = batch_indices[..., None]
+
+             freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
+             freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
+
+             freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
+             freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
+             freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
+             freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
+
+             return t * freqs_cos + rotate_half(t) * freqs_sin
+
+         return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
+
+ import torch.nn as nn
+ import os
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ # --------------------------------------------------------
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
+ # --------------------------------------------------------
+ import math
+ import os
+ from functools import partial
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import logging
+ try:
+     from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+ except ImportError:
+     from timm.layers import drop_path, to_2tuple, trunc_normal_
+
+ class PatchDropout(nn.Module):
+     """
+     https://arxiv.org/abs/2212.00794
+     """
+
+     def __init__(self, prob, exclude_first_token=True):
+         super().__init__()
+         assert 0 <= prob < 1.
+         self.prob = prob
+         self.exclude_first_token = exclude_first_token  # exclude CLS token
+         logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
+
+     def forward(self, x):
+         if not self.training or self.prob == 0.:
+             return x
+
+         if self.exclude_first_token:
+             cls_tokens, x = x[:, :1], x[:, 1:]
+         else:
+             cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+         batch = x.size()[0]
+         num_tokens = x.size()[1]
+
+         batch_indices = torch.arange(batch)
+         batch_indices = batch_indices[..., None]
+
+         keep_prob = 1 - self.prob
+         num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+         rand = torch.randn(batch, num_tokens)
+         patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+         x = x[batch_indices, patch_indices_keep]
+
+         if self.exclude_first_token:
+             x = torch.cat((cls_tokens, x), dim=1)
+
+         if self.training and os.getenv('RoPE') == '1':
+             return x, patch_indices_keep
+
+         return x
+
+ if os.getenv('ENV_TYPE') == 'deepspeed':
+     try:
+         from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
+     except ImportError:
+         from torch.utils.checkpoint import checkpoint
+ else:
+     from torch.utils.checkpoint import checkpoint
+
+ import xformers.ops as xops
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+     def extra_repr(self) -> str:
+         return 'p={}'.format(self.drop_prob)
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=nn.GELU,
+         norm_layer=nn.LayerNorm,
+         drop=0.,
+         subln=False,
+
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+
+         self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
+
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         # x = self.drop(x)
+         # commented out to match the original BERT implementation
+         x = self.ffn_ln(x)
+
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+ class SwiGLU(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
+                  norm_layer=nn.LayerNorm, subln=False):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+
+         self.w1 = nn.Linear(in_features, hidden_features)
+         self.w2 = nn.Linear(in_features, hidden_features)
+
+         self.act = act_layer()
+         self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
+         self.w3 = nn.Linear(hidden_features, out_features)
+
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x1 = self.w1(x)
+         x2 = self.w2(x)
+         hidden = self.act(x1) * x2
+         x = self.ffn_ln(hidden)
+         x = self.w3(x)
+         x = self.drop(x)
+         return x
+
+ class Attention(nn.Module):
+     def __init__(
+             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+             proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         if attn_head_dim is not None:
+             head_dim = attn_head_dim
+         all_head_dim = head_dim * self.num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.subln = subln
+         if self.subln:
+             self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
+             self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
+             self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
+         else:
+             self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+
+         if qkv_bias:
+             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+         else:
+             self.q_bias = None
+             self.v_bias = None
+
+         if window_size:
+             self.window_size = window_size
+             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+             self.relative_position_bias_table = nn.Parameter(
+                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+             # cls to token & token to cls & cls to cls
+
+             # get pair-wise relative position index for each token inside the window
+             coords_h = torch.arange(window_size[0])
+             coords_w = torch.arange(window_size[1])
+             coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+             coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+             relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+             relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+             relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+             relative_coords[:, :, 1] += window_size[1] - 1
+             relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+             relative_position_index = \
+                 torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+             relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+             relative_position_index[0, 0:] = self.num_relative_distance - 3
+             relative_position_index[0:, 0] = self.num_relative_distance - 2
+             relative_position_index[0, 0] = self.num_relative_distance - 1
+
+             self.register_buffer("relative_position_index", relative_position_index)
+         else:
+             self.window_size = None
+             self.relative_position_bias_table = None
+             self.relative_position_index = None
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
+         # self.proj = nn.Linear(all_head_dim, all_head_dim)
+         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.xattn = xattn
+         self.xattn_drop = attn_drop
+
+         self.rope = rope
+
+     def forward(self, x, rel_pos_bias=None, attn_mask=None):
+         B, N, C = x.shape
+         if self.subln:
+             if self.q_proj.weight.dtype == torch.uint8:
+                 import bitsandbytes as bnb
+                 q = bnb.matmul_4bit(x, self.q_proj.weight.t(), bias=self.q_bias, quant_state=self.q_proj.weight.quant_state)
+                 k = bnb.matmul_4bit(x, self.k_proj.weight.t(), bias=None, quant_state=self.k_proj.weight.quant_state)
+                 v = bnb.matmul_4bit(x, self.v_proj.weight.t(), bias=self.v_bias, quant_state=self.v_proj.weight.quant_state)
+             else:
+                 q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
+                 k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
+                 v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
+
+             q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
+             k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+             v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+         else:
+
+             qkv_bias = None
+             if self.q_bias is not None:
+                 qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+
+             qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+             qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
+             q, k, v = qkv[0], qkv[1], qkv[2]
+
+         if self.rope:
+             # slightly faster impl
+             q_t = q[:, :, 1:, :]
+             ro_q_t = self.rope(q_t)
+             q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
+
+             k_t = k[:, :, 1:, :]
+             ro_k_t = self.rope(k_t)
+             k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
+
+         if self.xattn:
+             q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
+             k = k.permute(0, 2, 1, 3)
+             v = v.permute(0, 2, 1, 3)
+
+             x = xops.memory_efficient_attention(
+                 q, k, v,
+                 p=self.xattn_drop,
+                 scale=self.scale,
+             )
+             x = x.reshape(B, N, -1)
+             x = self.inner_attn_ln(x)
+             x = self.proj(x)
+             x = self.proj_drop(x)
+         else:
+             q = q * self.scale
+             attn = (q @ k.transpose(-2, -1))
+
+             if self.relative_position_bias_table is not None:
+                 relative_position_bias = \
+                     self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                         self.window_size[0] * self.window_size[1] + 1,
+                         self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                 relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                 attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
+
+             if rel_pos_bias is not None:
+                 attn = attn + rel_pos_bias.type_as(attn)
+
+             if attn_mask is not None:
+                 attn_mask = attn_mask.bool()
+                 attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+
+             x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+             x = self.inner_attn_ln(x)
+             x = self.proj(x)
+             x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                  window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
+                  subln=False, naiveswiglu=False):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
+             xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+
+         if naiveswiglu:
+             self.mlp = SwiGLU(
+                 in_features=dim,
+                 hidden_features=mlp_hidden_dim,
+                 subln=subln,
+                 norm_layer=norm_layer,
+             )
+         else:
+             self.mlp = Mlp(
+                 in_features=dim,
+                 hidden_features=mlp_hidden_dim,
+                 act_layer=act_layer,
+                 subln=subln,
+                 drop=drop
+             )
+
+         if init_values is not None and init_values > 0:
+             self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+             self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+         else:
+             self.gamma_1, self.gamma_2 = None, None
+
+         self.postnorm = postnorm
+
+     def forward(self, x, rel_pos_bias=None, attn_mask=None):
+         if self.gamma_1 is None:
+             if self.postnorm:
+                 x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
+                 x = x + self.drop_path(self.norm2(self.mlp(x)))
+             else:
+                 x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
+                 x = x + self.drop_path(self.mlp(self.norm2(x)))
+         else:
+             if self.postnorm:
+                 x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
+                 x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
+             else:
+                 x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
+                 x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x, **kwargs):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class RelativePositionBias(nn.Module):
+
+     def __init__(self, window_size, num_heads):
+         super().__init__()
+         self.window_size = window_size
+         self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+         self.relative_position_bias_table = nn.Parameter(
+             torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+         # cls to token & token to cls & cls to cls
+
+         # get pair-wise relative position index for each token inside the window
+         coords_h = torch.arange(window_size[0])
+         coords_w = torch.arange(window_size[1])
+         coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+         relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+         relative_coords[:, :, 1] += window_size[1] - 1
+         relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+         relative_position_index = \
+             torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+         relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+         relative_position_index[0, 0:] = self.num_relative_distance - 3
+         relative_position_index[0:, 0] = self.num_relative_distance - 2
+         relative_position_index[0, 0] = self.num_relative_distance - 1
+
+         self.register_buffer("relative_position_index", relative_position_index)
+
+     def forward(self):
+         relative_position_bias = \
+             self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                 self.window_size[0] * self.window_size[1] + 1,
+                 self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+         return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+ class EVAVisionTransformer(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
+                  use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
+                  use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
+                  pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
+         super().__init__()
+         self.image_size = img_size
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         if use_abs_pos_emb:
+             self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         else:
+             self.pos_embed = None
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         if use_shared_rel_pos_bias:
+             self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+         else:
+             self.rel_pos_bias = None
+
+         if rope:
+             half_head_dim = embed_dim // num_heads // 2
+             hw_seq_len = img_size // patch_size
+             self.rope = VisionRotaryEmbeddingFast(
+                 dim=half_head_dim,
+                 pt_seq_len=pt_hw_seq_len,
+                 ft_seq_len=hw_seq_len if intp_freq else None,
+                 # patch_dropout=patch_dropout
+             )
+         else:
+             self.rope = None
+
+         self.naiveswiglu = naiveswiglu
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.use_rel_pos_bias = use_rel_pos_bias
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                 init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
+                 xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
+             for i in range(depth)])
+         self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+         self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         if self.pos_embed is not None:
+             trunc_normal_(self.pos_embed, std=.02)
+
+         trunc_normal_(self.cls_token, std=.02)
+         # trunc_normal_(self.mask_token, std=.02)
+
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+         if isinstance(self.head, nn.Linear):
+             trunc_normal_(self.head.weight, std=.02)
+             self.head.weight.data.mul_(init_scale)
+             self.head.bias.data.mul_(init_scale)
+
+         # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+         self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
+
+         self.grad_checkpointing = grad_checkpointing
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             if self.naiveswiglu:
+                 rescale(layer.mlp.w3.weight.data, layer_id + 1)
+             else:
+                 rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def get_cast_dtype(self) -> torch.dtype:
+         return self.blocks[0].mlp.fc2.weight.dtype
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def get_num_layers(self):
+         return len(self.blocks)
+
+     def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+         assert unlocked_groups == 0, 'partial locking not currently supported for this model'
+         for param in self.parameters():
+             param.requires_grad = False
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.grad_checkpointing = enable
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def get_classifier(self):
+         return self.head
+
+     def reset_classifier(self, num_classes, global_pool=''):
+         self.num_classes = num_classes
+         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+     def forward_features(self, x, return_all_features=False):
+
+         x = self.patch_embed(x)
+         batch_size, seq_len, _ = x.size()
+
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         if self.pos_embed is not None:
+             x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+         if os.getenv('RoPE') == '1':
+             if self.training and not isinstance(self.patch_dropout, nn.Identity):
+                 x, patch_indices_keep = self.patch_dropout(x)
+                 self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
+             else:
+                 self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
+                 x = self.patch_dropout(x)
+         else:
+             x = self.patch_dropout(x)
+
+         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+         for i, blk in enumerate(self.blocks):
+             if i == len(self.blocks) - 1:
+                 continue
+             if self.grad_checkpointing:
+                 x = checkpoint(blk, x, (rel_pos_bias,))
+             else:
+                 x = blk(x, rel_pos_bias=rel_pos_bias)
+
+         if not return_all_features:
+             x = self.norm(x)
+             if self.fc_norm is not None:
+                 return self.fc_norm(x.mean(1))
+             else:
+                 return x[:, 0]
+         return x
+
+     def forward(self, x, return_all_features=False):
+         if return_all_features:
+             return self.forward_features(x, return_all_features)
+         x = self.forward_features(x)
+         x = self.head(x)
+         return x
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+     def forward(self, x: torch.Tensor):
+         orig_type = x.dtype
+         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         return x.to(orig_type)
+
+ try:
+     from apex.normalization import FusedLayerNorm
+ except ImportError:
+     FusedLayerNorm = LayerNorm
+     print("Please 'pip install apex'")
+
+
+ @dataclass
+ class CLIPVisionCfg:
+     layers: Union[Tuple[int, int, int, int], int] = 12
+     width: int = 768
+     head_width: int = 64
+     mlp_ratio: float = 4.0
+     patch_size: int = 16
+     image_size: Union[Tuple[int, int], int] = 224
+     ls_init_value: Optional[float] = None  # layer scale initial value
+     patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+     global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+     drop_path_rate: Optional[float] = None  # drop path rate
+     timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
+     timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
+     timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+     timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
+     timm_proj_bias: bool = False  # enable bias final projection
+     eva_model_name: str = None  # a valid eva model name overrides layers, width, patch_size
+     qkv_bias: bool = True
+     fusedLN: bool = False
+     xattn: bool = False
+     postnorm: bool = False
+     rope: bool = False
+     pt_hw_seq_len: int = 16  # 224/14
+     intp_freq: bool = False
+     naiveswiglu: bool = False
+     subln: bool = False
+
+
+ def _build_vision_tower(
+         embed_dim: int,
+         vision_cfg: CLIPVisionCfg
+ ):
+     if isinstance(vision_cfg, dict):
+         vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+     if vision_cfg.eva_model_name:
+         vision_heads = vision_cfg.width // vision_cfg.head_width
+         norm_layer = LayerNorm
+         visual = EVAVisionTransformer(
+             img_size=vision_cfg.image_size,
+             patch_size=vision_cfg.patch_size,
+             num_classes=embed_dim,
+             use_mean_pooling=vision_cfg.global_average_pool,  # False
+             init_values=vision_cfg.ls_init_value,
+             patch_dropout=vision_cfg.patch_dropout,
+             embed_dim=vision_cfg.width,
+             depth=vision_cfg.layers,
+             num_heads=vision_heads,
+             mlp_ratio=vision_cfg.mlp_ratio,
+             qkv_bias=vision_cfg.qkv_bias,
+             drop_path_rate=vision_cfg.drop_path_rate,
+             norm_layer=partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
+             xattn=vision_cfg.xattn,
+             rope=vision_cfg.rope,
+             postnorm=vision_cfg.postnorm,
+             pt_hw_seq_len=vision_cfg.pt_hw_seq_len,  # 224/14
+             intp_freq=vision_cfg.intp_freq,
+             naiveswiglu=vision_cfg.naiveswiglu,
+             subln=vision_cfg.subln
+         )
+
+     return visual
+
+ class Eva2LargeEncoder(nn.Module):
+     def __init__(self, image_size=224):
+         super(Eva2LargeEncoder, self).__init__()
+         self.config = {
+             "embed_dim": 768,
+             "vision_cfg": {
+                 "image_size": 336,
+                 "layers": 24,
+                 "width": 1024,
+                 "drop_path_rate": 0,
+                 "head_width": 64,
+                 "mlp_ratio": 2.6667,
+                 "patch_size": 14,
+                 "eva_model_name": "eva-clip-l-14-336",
+                 "xattn": True,
+                 "fusedLN": True,
+                 "rope": True,
+                 "pt_hw_seq_len": 16,
+                 "intp_freq": True,
+                 "naiveswiglu": True,
+                 "subln": True
+             }
+         }
+         self.config['vision_cfg']['image_size'] = image_size
+
+         import os
+         os.environ['delRoPE'] = '1'  # to avoid an error in the RoPE params when changing image size
+         self.model = _build_vision_tower(**self.config)
+
+
+     def forward(self, images):
+         encode = self.model(images, return_all_features=True)[:, 1:, :]
+         return encode
+
+ class CrossVisionModel(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.vit = Eva2LargeEncoder(image_size=config.cross_image_size)
+         self.pos_embed = nn.Parameter(torch.zeros((self.vit.config['vision_cfg']['image_size'] // self.vit.config['vision_cfg']['patch_size']) ** 2, self.vit.config['vision_cfg']['width']))
+
+     def forward(self, images):
+         enc = self.vit(images)
+         return enc + self.pos_embed.unsqueeze(0)
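
To illustrate the 2D rotary embedding used by VisionRotaryEmbeddingFast above, a small self-contained sketch (shapes chosen arbitrarily, and assuming cross_visual.py and its dependencies such as einops, timm and xformers are importable): forward applies the standard identity t * cos + rotate_half(t) * sin to the patch tokens, with frequencies precomputed for a pt_seq_len x pt_seq_len grid of patches.

import torch
from cross_visual import VisionRotaryEmbeddingFast

# dim is half of the per-head size (e.g. width 1024 / 16 heads = 64, so dim = 32);
# pt_seq_len = 16 corresponds to a 224 / 14 grid of patches per side.
rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)
print(rope.freqs_cos.shape)   # torch.Size([256, 64]) -> one row per patch position

# Dummy query tensor: (batch, heads, num_patches, head_dim)
q = torch.randn(2, 16, 256, 64)
q_rot = rope(q)               # == q * cos + rotate_half(q) * sin
assert q_rot.shape == q.shape
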
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "pad_token_id": 0,
+   "transformers_version": "4.36.0.dev0"
+ }
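
For reference only: these are the defaults that model.generate picks up when no explicit GenerationConfig is passed; they can be inspected or overridden as sketched below (the repository id is a placeholder, not from this commit).

from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("path/or/repo-id-of-this-checkpoint")
print(gen_config.bos_token_id, gen_config.eos_token_id, gen_config.pad_token_id)  # 1 2 0

# Per-call arguments take precedence over the stored defaults, e.g.:
# model.generate(**inputs, generation_config=gen_config, max_new_tokens=256)
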
model-00001-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15c2451e4e0dd5caae61e91de67ea0dac7b554c4e2b39d54e67ffbe232460063
+ size 4974581824
model-00002-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:351e987e5f1f28124a8c839085aeb3111e8f956393d3092439e07c7a285a5d90
+ size 4982995648
model-00003-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff93f061698e4a8bbc042d9f243efd9647b19889f0e5f87188cd6bfc0400f839
+ size 4982995728
model-00004-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54c8f6eca491986fb61914029014d80a7813c1538807a55296530c6337fe2829
+ size 4982995728
model-00005-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4218cb8e5bda353d4491a70054b2e8f384f00c55bcc453807309982760fc48fd
+ size 4982995728
model-00006-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23f2f3123c49dc4fe16fa44f6ee3dd388b18490fa05966f9dd21e806e3cbd22d
+ size 4950060832
model-00007-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d6eec2f2b61a1eb39dfb7703231639e16394bba2348108f1ed970268abac86e
+ size 4945866712
model-00008-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6df8c2486c4d21b20c5b2a72a699b47fd8ac62199ba9ae69db32054ef4fab1a2
+ size 1783098344
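
The eight entries above are Git LFS pointer files; the oid sha256 line is the checksum of the actual shard. A small sketch for verifying a downloaded shard locally (file names as listed above):

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so multi-GB shards do not need to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare against the oid recorded in the pointer, e.g. for the last shard:
expected = "6df8c2486c4d21b20c5b2a72a699b47fd8ac62199ba9ae69db32054ef4fab1a2"
assert sha256_of("model-00008-of-00008.safetensors") == expected
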
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_cogagent.py ADDED
@@ -0,0 +1,910 @@
+ """largely copied from llama and adapted for CogAgent"""
+ import warnings
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
+
+ import math
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from torchvision import transforms
+ from einops import rearrange
+
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+ from transformers.utils.logging import get_logger
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+ from .configuration_cogagent import CogAgentConfig
+ from .util import FastRotaryEmbedding
+ from .visual import EVA2CLIPModel
+ from .cross_visual import CrossVisionModel
+
+ if TYPE_CHECKING:
+     from transformers.utils import ModelOutput
+
+ logger = get_logger(__name__)
+
+ LANGUAGE_TOKEN_TYPE = 0
+ VISION_TOKEN_TYPE = 1
+
+
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
+ def _make_causal_mask(
+         input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+ ):
+     """
+     Make causal mask used for bi-directional self-attention.
+     """
+     bsz, tgt_len = input_ids_shape
+     mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+     mask_cond = torch.arange(mask.size(-1), device=device)
+     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+     mask = mask.to(dtype)
+
+     if past_key_values_length > 0:
+         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+     """
+     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+     """
+     bsz, src_len = mask.size()
+     tgt_len = tgt_len if tgt_len is not None else src_len
+
+     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+     inverted_mask = 1.0 - expanded_mask
+
+     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return (self.weight * hidden_states).to(input_dtype)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
+     vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
+     vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
+     language_token_mask = ~vision_token_mask
+     return vision_token_mask, language_token_mask
+
+
+ class VisionExpertMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.language_mlp = MLP(config)
+         self.vision_mlp = MLP(config)
+
+     def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
+         output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
+         vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
+         output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
+         output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
+         return output
+
+
+ def attention_fn(
+         query_layer: "torch.tensor(B, H, L, HD)",
+         key_layer: "torch.tensor(B, H, L, HD)",
+         value_layer: "torch.tensor(B, H, L, HD)",
+         attention_mask: "torch.tensor(B, H, L, HD)",
+         *,
+         scaling_attention_score: bool = True,
+         attention_dropout: nn.Module = None
+ ):
+     attention_mask_bool = (attention_mask == 0)
+     is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
+     is_full = (attention_mask_bool > 0).all()
+     if not (int(torch.__version__.split('.')[0]) >= 2):
+         warnings.warn("It's recommended to use torch 2.0 or higher.")
+     if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
+         dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
+         return torch.nn.functional.scaled_dot_product_attention(
+             query_layer, key_layer, value_layer,
+             attn_mask=None,
+             dropout_p=dropout_p,
+             is_causal=not is_full
+         )
+     else:
+         if scaling_attention_score:
+             query_layer = query_layer / math.sqrt(query_layer.shape[-1])
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores + attention_mask
+         attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
+         if attention_dropout is not None:
+             attention_scores = attention_dropout(attention_scores)
+         context_layer = torch.matmul(attention_scores, value_layer)
+         return context_layer
+
+
+ class VisionExpertAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.max_position_embeddings = config.max_position_embeddings
+
+         # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
+         self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
+         self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
+         self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+         self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
+         self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+     def _transpose_for_scores(self, tensor):
+         """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B, H, L, HD]."""
+         new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
+         tensor = tensor.view(*new_tensor_shape)
+         return tensor.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         token_type_ids: torch.LongTensor,
+         position_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+         vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
+
+         shape = list(hidden_states.shape)
+         shape[-1] = shape[-1] * 3
+         mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
+         mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
+         mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
+
+         query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
+         query_states = self._transpose_for_scores(query_states)  # B, H, L, HD
+         key_states = self._transpose_for_scores(key_states)  # B, H, L, HD
+         value_states = self._transpose_for_scores(value_states)  # B, H, L, HD
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value[0].shape[-2]
+
+         query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
+
+         if past_key_value is not None:
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         context_layer = attention_fn(
+             query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
+             scaling_attention_score=True, attention_dropout=None)
+         if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                 f" {context_layer.size()}"
+             )
+         context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
+
+         attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
+         attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
+         attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
+
+         if output_attentions:
+             warnings.warn("output_attentions is not implemented.")
+
+         return attn_output, None, past_key_value
+
225
+ class CrossAttention(nn.Module):
226
+ def __init__(self, config):
227
+ super().__init__()
228
+ self.config = config
229
+ self.hidden_size = config.hidden_size
230
+ self.cross_hidden_size = config.cross_hidden_size
231
+ self.cross_compute_hidden_size = config.cross_compute_hidden_size
232
+ self.num_heads = config.num_attention_heads
233
+ self.head_dim = self.hidden_size // self.num_heads
234
+ self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
235
+ self.max_position_embeddings = config.max_position_embeddings
236
+
237
+ # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
238
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
239
+ self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
240
+ self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
241
+ self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
242
+
243
+ def _transpose_for_scores(self, tensor):
244
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
245
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
246
+ tensor = tensor.view(*new_tensor_shape)
247
+ return tensor.permute(0, 2, 1, 3)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states: torch.Tensor,
252
+ encoder_outputs: torch.LongTensor,
253
+ attention_mask: Optional[torch.Tensor] = None,
254
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
+ output_attentions: bool = False,
256
+ use_cache: bool = False,
257
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
258
+ bsz, q_len, _ = hidden_states.size()
259
+
260
+ shape = list(hidden_states.shape)
261
+ shape[-1] = shape[-1] * 3
262
+
263
+ mixed_query_layer = self.query(hidden_states)
264
+ if past_key_value is None:
265
+ mixed_x_layer = self.key_value(encoder_outputs)
266
+ mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
267
+ key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
268
+ value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
269
+ else:
270
+ key_states, value_states = past_key_value
271
+
272
+ query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
273
+
274
+ past_key_value = (key_states, value_states) if use_cache else None
275
+
276
+ context_layer = attention_fn(
277
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
278
+ scaling_attention_score=True, attention_dropout=None)
279
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
280
+ raise ValueError(
281
+ f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
282
+ f" {context_layer.size()}"
283
+ )
284
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
285
+
286
+ attn_output = self.dense(context_layer)
287
+
288
+ if output_attentions:
289
+ warnings.warn("output_attentions is not implemented.")
290
+
291
+ return attn_output, None, past_key_value
292
+
293
+ class CogAgentDecoderLayer(nn.Module):
294
+ def __init__(self, config):
295
+ super().__init__()
296
+ self.hidden_size = config.hidden_size
297
+ self.self_attn = VisionExpertAttention(config=config)
298
+ self.cross_attn = CrossAttention(config=config)
299
+ self.mlp = VisionExpertMLP(config)
300
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302
+ self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ encoder_outputs: torch.Tensor,
308
+ token_type_ids: torch.LongTensor,
309
+ position_ids: torch.LongTensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ cross_attention_mask: Optional[torch.Tensor] = None,
312
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
313
+ output_attentions: Optional[bool] = False,
314
+ use_cache: Optional[bool] = False,
315
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
316
+ residual = hidden_states
317
+
318
+ hidden_states = self.input_layernorm(hidden_states)
319
+
320
+ # Self Attention
321
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
322
+ hidden_states=hidden_states,
323
+ token_type_ids=token_type_ids,
324
+ position_ids=position_ids,
325
+ attention_mask=attention_mask,
326
+ past_key_value=past_key_value[:2] if past_key_value is not None else None,
327
+ output_attentions=output_attentions,
328
+ use_cache=use_cache,
329
+ )
330
+ hidden_states = residual + hidden_states
331
+
332
+ cross_input = self.post_cross_attention_layernorm(hidden_states)
333
+ # Fully Connected
334
+ attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
335
+ hidden_states=cross_input,
336
+ encoder_outputs=encoder_outputs,
337
+ attention_mask=cross_attention_mask,
338
+ past_key_value=past_key_value[-2:] if past_key_value is not None else None,
339
+ output_attentions=output_attentions,
340
+ use_cache=use_cache,
341
+ )
342
+ hidden_states = hidden_states + attention_output
343
+ mlp_input = self.post_attention_layernorm(hidden_states)
344
+ mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
345
+ hidden_states = mlp_output + hidden_states
346
+
347
+ outputs = (hidden_states,)
348
+
349
+ if output_attentions:
350
+ outputs += (self_attn_weights,)
351
+
352
+ if use_cache:
353
+ outputs += (present_key_value+present_cross_key_value,)
354
+
355
+ return outputs # type: ignore
356
+
357
+
358
+ class CogAgentPreTrainedModel(PreTrainedModel):
359
+ config_class = CogAgentConfig
360
+ base_model_prefix = "model"
361
+ supports_gradient_checkpointing = False
362
+ _no_split_modules = ["CogAgentDecoderLayer"]
363
+ _skip_keys_device_placement = "past_key_values"
364
+
365
+ def _init_weights(self, module):
366
+ std = self.config.initializer_range
367
+ if isinstance(module, nn.Linear):
368
+ module.weight.data.normal_(mean=0.0, std=std)
369
+ if module.bias is not None:
370
+ module.bias.data.zero_()
371
+ elif isinstance(module, nn.Embedding):
372
+ module.weight.data.normal_(mean=0.0, std=std)
373
+ if module.padding_idx is not None:
374
+ module.weight.data[module.padding_idx].zero_()
375
+
376
+
377
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
378
+ if images_list is None or len(images_list) == 0:
379
+ return True
380
+ for image_list in images_list:
381
+ if len(image_list):
382
+ return False
383
+ return True
384
+
385
+
386
+ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
387
+ if attention_mask is not None:
388
+ tmp = x.clone()
389
+ tmp[~(attention_mask.bool())] = -1
390
+ else:
391
+ tmp = x.clone()
392
+ # treat image boi/eoi boundary tokens as LANGUAGE_TOKEN_TYPE
393
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
394
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
395
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
396
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
397
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
398
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
399
+ # final position ids
400
+ y = torch.zeros_like(x, dtype=torch.long)
401
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
402
+ y = y.cumsum(dim=-1)
403
+ return y
404
+
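To illustrate `build_position_ids`, here is a small sketch (assuming `LANGUAGE_TOKEN_TYPE = 0` and `VISION_TOKEN_TYPE = 1`; the exact values are an assumption here): consecutive image patch tokens collapse onto one shared position, while the boi/eoi boundary tokens keep positions of their own.

import torch

token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])  # LANG, 4x VIS, LANG, LANG
# build_position_ids(token_type_ids)
# -> tensor([[0, 1, 2, 2, 3, 4, 5]])
# boi gets 1, the inner patch tokens share 2, eoi gets 3, text continues from 4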
405
+
406
+ class CogAgentModel(CogAgentPreTrainedModel):
407
+ def __init__(self, config):
408
+ super().__init__(config)
409
+ self.padding_idx = config.pad_token_id
410
+ self.vocab_size = config.vocab_size
411
+
412
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
413
+ self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
414
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
415
+
416
+ self.vision = EVA2CLIPModel(config)
417
+ self.cross_vision = CrossVisionModel(config)
418
+
419
+ self.gradient_checkpointing = False
420
+ # Initialize weights and apply final processing
421
+ self.post_init()
422
+
423
+ def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
424
+ images_list = images
425
+
426
+ images = []
427
+ for image_list in images_list:
428
+ for image in image_list:
429
+ images.append(image)
430
+
431
+ images = torch.stack(images)
432
+ images_features = self.vision(images)
433
+ return images_features
434
+
435
+ def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
436
+ images_list = images
437
+
438
+ images = []
439
+ for image_list in images_list:
440
+ for image in image_list:
441
+ images.append(image)
442
+
443
+ images = torch.stack(images)
444
+ encoder_outputs = self.cross_vision(images)
445
+ return encoder_outputs
446
+
447
+ def forward(
448
+ self,
449
+ input_ids: torch.LongTensor = None,
450
+ images: List[List[torch.Tensor]] = None,
451
+ cross_images: List[List[torch.Tensor]] = None,
452
+ token_type_ids: Optional[torch.LongTensor] = None,
453
+ attention_mask: Optional[torch.Tensor] = None,
454
+ cross_attention_mask: Optional[torch.Tensor] = None,
455
+ position_ids: Optional[torch.LongTensor] = None,
456
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
457
+ inputs_embeds: Optional[torch.FloatTensor] = None,
458
+ use_cache: Optional[bool] = None,
459
+ output_attentions: Optional[bool] = None,
460
+ output_hidden_states: Optional[bool] = None,
461
+ return_dict: Optional[bool] = None,
462
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
463
+ """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
464
+
465
+ if past_key_values is not None:
466
+ encoder_outputs = None
467
+ # generation step with past_key_values: the image features were already injected in the first step
468
+ else:
469
+ # inputs_embeds is not allowed here, because the image features must be injected into the token embeddings
470
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
471
+ if not is_empty(images): # multi-modality
472
+ assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
473
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
474
+ inputs_embeds = self.embed_tokens(input_ids)
475
+ images_features = self.encode_images(images)
476
+ encoder_outputs = self.encode_cross_images(cross_images)
477
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
478
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
479
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
480
+ else: # single-modality
481
+ if token_type_ids is None:
482
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
483
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
484
+ inputs_embeds = self.embed_tokens(input_ids)
485
+ encoder_outputs = None
486
+
487
+ if position_ids is None:
488
+ position_ids = build_position_ids(token_type_ids, attention_mask)
489
+ input_ids = None
490
+
491
+ return self.llm_forward(
492
+ input_ids=input_ids,
493
+ encoder_outputs=encoder_outputs,
494
+ token_type_ids=token_type_ids,
495
+ attention_mask=attention_mask,
496
+ cross_attention_mask=cross_attention_mask,
497
+ position_ids=position_ids,
498
+ past_key_values=past_key_values,
499
+ inputs_embeds=inputs_embeds,
500
+ use_cache=use_cache,
501
+ output_attentions=output_attentions,
502
+ output_hidden_states=output_hidden_states,
503
+ return_dict=return_dict,
504
+ )
505
+
506
+ def llm_forward(
507
+ self,
508
+ input_ids: torch.LongTensor = None,
509
+ encoder_outputs: torch.LongTensor = None,
510
+ token_type_ids: torch.LongTensor = None,
511
+ attention_mask: Optional[torch.Tensor] = None,
512
+ cross_attention_mask: Optional[torch.Tensor] = None,
513
+ position_ids: Optional[torch.LongTensor] = None,
514
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
515
+ inputs_embeds: Optional[torch.FloatTensor] = None,
516
+ use_cache: Optional[bool] = None,
517
+ output_attentions: Optional[bool] = None,
518
+ output_hidden_states: Optional[bool] = None,
519
+ return_dict: Optional[bool] = None,
520
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
521
+ """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
522
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
523
+ output_hidden_states = (
524
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
525
+ )
526
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
527
+
528
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
529
+
530
+ # retrieve input_ids and inputs_embeds
531
+ if input_ids is not None and inputs_embeds is not None:
532
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
533
+ elif input_ids is not None:
534
+ batch_size, seq_length = input_ids.shape
535
+ elif inputs_embeds is not None:
536
+ batch_size, seq_length, _ = inputs_embeds.shape
537
+ else:
538
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
539
+
540
+ seq_length_with_past = seq_length
541
+ past_key_values_length = 0
542
+
543
+ if past_key_values is not None:
544
+ past_key_values_length = past_key_values[0][0].shape[2]
545
+ seq_length_with_past = seq_length_with_past + past_key_values_length
546
+
547
+ if position_ids is None:
548
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
549
+ position_ids = torch.arange(
550
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
551
+ )
552
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
553
+ else:
554
+ position_ids = position_ids.view(-1, seq_length).long()
555
+
556
+ if inputs_embeds is None:
557
+ inputs_embeds = self.embed_tokens(input_ids)
558
+ # embed positions
559
+ if attention_mask is None:
560
+ attention_mask = torch.ones(
561
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
562
+ )
563
+ if cross_attention_mask is None:
564
+ cross_attention_mask = torch.ones(
565
+ (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
566
+ )
567
+ attention_mask = self._prepare_decoder_attention_mask(
568
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
569
+ )
570
+
571
+ hidden_states = inputs_embeds
572
+
573
+ # decoder layers
574
+ all_hidden_states = () if output_hidden_states else None
575
+ all_self_attns = () if output_attentions else None
576
+ next_decoder_cache = () if use_cache else None
577
+
578
+ for idx, decoder_layer in enumerate(self.layers):
579
+ if output_hidden_states:
580
+ all_hidden_states += (hidden_states,)
581
+
582
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
583
+ layer_outputs = decoder_layer(
584
+ hidden_states,
585
+ encoder_outputs=encoder_outputs,
586
+ token_type_ids=token_type_ids,
587
+ attention_mask=attention_mask,
588
+ cross_attention_mask=cross_attention_mask,
589
+ position_ids=position_ids,
590
+ past_key_value=past_key_value,
591
+ output_attentions=output_attentions,
592
+ use_cache=use_cache,
593
+ )
594
+ hidden_states = layer_outputs[0]
595
+
596
+ if use_cache:
597
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
598
+
599
+ if output_attentions:
600
+ all_self_attns += (layer_outputs[1],)
601
+
602
+ hidden_states = self.norm(hidden_states)
603
+
604
+ # add hidden states from the last decoder layer
605
+ if output_hidden_states:
606
+ all_hidden_states += (hidden_states,)
607
+
608
+ next_cache = next_decoder_cache if use_cache else None
609
+ if not return_dict:
610
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
611
+ return BaseModelOutputWithPast(
612
+ last_hidden_state=hidden_states,
613
+ past_key_values=next_cache,
614
+ hidden_states=all_hidden_states,
615
+ attentions=all_self_attns,
616
+ )
617
+
618
+ def get_input_embeddings(self):
619
+ return self.embed_tokens
620
+
621
+ def set_input_embeddings(self, value):
622
+ self.embed_tokens = value
623
+
624
+ # noinspection PyMethodMayBeStatic
625
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
626
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
627
+ # create causal mask
628
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
629
+ combined_attention_mask = None
630
+ if input_shape[-1] > 1:
631
+ combined_attention_mask = _make_causal_mask(
632
+ input_shape,
633
+ inputs_embeds.dtype,
634
+ device=inputs_embeds.device,
635
+ past_key_values_length=past_key_values_length,
636
+ )
637
+
638
+ if attention_mask is not None:
639
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
640
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
641
+ inputs_embeds.device
642
+ )
643
+ combined_attention_mask = (
644
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
645
+ )
646
+
647
+ return combined_attention_mask
648
+
649
+
650
+ def chat_history_to_prompt(history, query):
651
+ prompt = " [INST] "
652
+ for i, (old_query, response) in enumerate(history):
653
+ prompt += old_query + " [/INST] " + response + " [INST] "
654
+ prompt += query + " [/INST] "
655
+ return prompt
656
+
657
+
658
+ def base_history_to_prompt(history, query):
659
+ prompt = query
660
+ return prompt
661
+
662
+
663
+ _history_to_prompt = {
664
+ "base": base_history_to_prompt,
665
+ "chat": chat_history_to_prompt
666
+ }
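For reference, the "chat" template above yields a Llama-2-style `[INST]` prompt; for example:

history = [("Hi", "Hello!")]
query = "Describe the image."
# chat_history_to_prompt(history, query)
# -> " [INST] Hi [/INST] Hello! [INST] Describe the image. [/INST] "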
667
+
668
+
669
+ class CogAgentForCausalLM(CogAgentPreTrainedModel):
670
+ _auto_class = "AutoModelForCausalLM"
671
+
672
+ def __init__(self, config):
673
+ super().__init__(config)
674
+ self.model = CogAgentModel(config)
675
+ self.vocab_size = config.vocab_size
676
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
677
+
678
+ # Initialize weights and apply final processing
679
+ self.post_init()
680
+
681
+ def get_input_embeddings(self):
682
+ return self.model.embed_tokens
683
+
684
+ def set_input_embeddings(self, value):
685
+ self.model.embed_tokens = value
686
+
687
+ def get_output_embeddings(self):
688
+ return self.lm_head
689
+
690
+ def set_output_embeddings(self, new_embeddings):
691
+ self.lm_head = new_embeddings
692
+
693
+ def set_decoder(self, decoder):
694
+ self.model = decoder
695
+
696
+ def get_decoder(self):
697
+ return self.model
698
+
699
+ def forward(
700
+ self,
701
+ input_ids: torch.LongTensor = None,
702
+ images: List[List[torch.Tensor]] = None,
703
+ cross_images: List[List[torch.Tensor]] = None,
704
+ token_type_ids: Optional[torch.LongTensor] = None,
705
+ attention_mask: Optional[torch.Tensor] = None,
706
+ position_ids: Optional[torch.LongTensor] = None,
707
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
708
+ inputs_embeds: Optional[torch.FloatTensor] = None,
709
+ use_cache: Optional[bool] = None,
710
+ output_attentions: Optional[bool] = None,
711
+ output_hidden_states: Optional[bool] = None,
712
+ return_dict: Optional[bool] = None,
713
+ labels: Optional[torch.LongTensor] = None,
714
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
720
+
721
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
722
+ outputs = self.model(
723
+ input_ids=input_ids,
724
+ images=images,
725
+ cross_images=cross_images,
726
+ token_type_ids=token_type_ids,
727
+ attention_mask=attention_mask,
728
+ position_ids=position_ids,
729
+ past_key_values=past_key_values,
730
+ inputs_embeds=inputs_embeds,
731
+ use_cache=use_cache,
732
+ output_attentions=output_attentions,
733
+ output_hidden_states=output_hidden_states,
734
+ return_dict=return_dict,
735
+ )
736
+
737
+ hidden_states = outputs[0]
738
+ logits = self.lm_head(hidden_states)
739
+ logits = logits.float()
740
+
741
+ loss = None
742
+ if labels is not None:
743
+ # Shift so that tokens < n predict n
744
+ shift_logits = logits[..., :-1, :].contiguous()
745
+ shift_labels = labels[..., 1:].contiguous()
746
+ # Flatten the tokens
747
+ loss_fct = CrossEntropyLoss()
748
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
749
+ shift_labels = shift_labels.view(-1)
750
+ # Enable model parallelism
751
+ shift_labels = shift_labels.to(shift_logits.device)
752
+ loss = loss_fct(shift_logits, shift_labels)
753
+
754
+ if not return_dict:
755
+ output = (logits,) + outputs[1:]
756
+ return (loss,) + output if loss is not None else output
757
+
758
+ return CausalLMOutputWithPast(
759
+ loss=loss,
760
+ logits=logits,
761
+ past_key_values=outputs.past_key_values,
762
+ hidden_states=outputs.hidden_states,
763
+ attentions=outputs.attentions,
764
+ )
765
+
766
+ def _prepare_attention_mask_for_generation(
767
+ self,
768
+ inputs: torch.Tensor,
769
+ pad_token_id: Optional[int],
770
+ eos_token_id: Optional[Union[int, List[int]]],
771
+ ) -> torch.LongTensor:
772
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
773
+
774
+ def prepare_inputs_for_generation(
775
+ self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
776
+ ):
777
+ # build position_ids if needed
778
+ position_ids = kwargs.get("position_ids", None)
779
+ if position_ids is None:
780
+ position_ids = build_position_ids(token_type_ids, attention_mask)
781
+
782
+ if past_key_values:
783
+ input_ids = input_ids[:, -1:]
784
+ token_type_ids = token_type_ids[:, -1:]
785
+ position_ids = position_ids[:, -1:]
786
+
787
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
788
+ if inputs_embeds is not None and past_key_values is None:
789
+ model_inputs = {"inputs_embeds": inputs_embeds}
790
+ else:
791
+ model_inputs = {"input_ids": input_ids}
792
+
793
+ model_inputs.update(
794
+ {
795
+ "token_type_ids": token_type_ids,
796
+ "images": images,
797
+ "cross_images": cross_images,
798
+ "position_ids": position_ids,
799
+ "past_key_values": past_key_values,
800
+ "use_cache": kwargs.get("use_cache"),
801
+ "attention_mask": attention_mask,
802
+ }
803
+ )
804
+ return model_inputs
805
+
806
+ def _update_model_kwargs_for_generation(
807
+ self,
808
+ outputs: "ModelOutput",
809
+ model_kwargs: Dict[str, Any],
810
+ is_encoder_decoder: bool = False,
811
+ standardize_cache_format: bool = False,
812
+ ) -> Dict[str, Any]:
813
+ # update past_key_values
814
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
815
+ outputs, standardize_cache_format=standardize_cache_format
816
+ )
817
+ if getattr(outputs, "state", None) is not None:
818
+ model_kwargs["state"] = outputs.state
819
+
820
+ # update token_type_ids with last value
821
+ if "token_type_ids" in model_kwargs:
822
+ token_type_ids = model_kwargs["token_type_ids"]
823
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
824
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
825
+
826
+ if not is_encoder_decoder:
827
+ # update attention mask
828
+ if "attention_mask" in model_kwargs:
829
+ attention_mask = model_kwargs["attention_mask"]
830
+ model_kwargs["attention_mask"] = torch.cat(
831
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
832
+ )
833
+ else:
834
+ # update decoder attention mask
835
+ if "decoder_attention_mask" in model_kwargs:
836
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
837
+ model_kwargs["decoder_attention_mask"] = torch.cat(
838
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
839
+ dim=-1,
840
+ )
841
+
842
+ return model_kwargs
843
+
844
+ def _reorder_cache(self, past_key_values, beam_idx):
845
+ reordered_past = ()
846
+ for layer_past in past_key_values:
847
+ reordered_past += (
848
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
849
+ )
850
+ return reordered_past
851
+
852
+ def build_conversation_input_ids(
853
+ self,
854
+ tokenizer: "PreTrainedTokenizer",
855
+ *,
856
+ query: str,
857
+ history: Optional[List[Tuple[str, str]]] = None,
858
+ images: Optional[List["PIL.Image"]] = None,
859
+ template_version: Optional[Literal["base", "chat", "vqa"]] = None,
860
+ ):
861
+ image_size: int = self.config.vision_config['image_size']
862
+ cross_image_size: int = self.config.cross_image_size
863
+ patch_size: int = self.config.vision_config['patch_size']
864
+ template_version = template_version or self.config.template_version
865
+ assert images is None or len(images) <= 1, "multiple images are not supported yet."
866
+ history = history or []
867
+ text = _history_to_prompt[template_version](history, query)
868
+
869
+ input_ids = [tokenizer.bos_token_id]
870
+ token_type_ids = [LANGUAGE_TOKEN_TYPE]
871
+ if images is not None and len(images) == 1:
872
+ ori = images
873
+ # vision
874
+ transform = transforms.Compose(
875
+ [
876
+ transforms.Resize(
877
+ (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
878
+ ),
879
+ transforms.ToTensor(),
880
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
881
+ ]
882
+ )
883
+ images = [transform(ori[0])]
884
+ cross_transform = transforms.Compose(
885
+ [
886
+ transforms.Resize(
887
+ (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
888
+ ),
889
+ transforms.ToTensor(),
890
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
891
+ ]
892
+ )
893
+ cross_images = [cross_transform(ori[0])]
894
+ # language
895
+ vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
896
+ input_ids += [tokenizer.pad_token_id] * vision_token_num
897
+ token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
898
+ text_ids = tokenizer.encode(text, add_special_tokens=False)
899
+
900
+ input_ids += text_ids
901
+ token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
902
+ attention_mask = [1] * len(input_ids)
903
+
904
+ return {
905
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
906
+ 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
907
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
908
+ 'images': images,
909
+ 'cross_images': cross_images
910
+ }
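A minimal end-to-end sketch of how the method above is typically consumed (hedged: the repository id, tokenizer checkpoint, device, and dtype below are assumptions for illustration, not something this file defines):

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

# assumed checkpoints; substitute whatever tokenizer/model pair you actually use
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogagent-chat-hf", torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda").eval()

image = Image.open("screenshot.png").convert("RGB")
inputs = model.build_conversation_input_ids(
    tokenizer, query="What does this page show?", history=[], images=[image]
)
# the model forward expects batched tensors and List[List[Tensor]] image inputs
inputs = {
    "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
    "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
    "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
    "images": [[inputs["images"][0].to("cuda", torch.bfloat16)]],
    "cross_images": [[inputs["cross_images"][0].to("cuda", torch.bfloat16)]],
}
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(out[0, inputs["input_ids"].shape[1]:]))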
util.py ADDED
@@ -0,0 +1,483 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+ import torch.nn.functional as F
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
+ # @triton.autotune(
12
+ # configs=[
13
+ # triton.Config({"BLOCK_M": 2}),
14
+ # triton.Config({"BLOCK_M": 4}),
15
+ # triton.Config({"BLOCK_M": 8}),
16
+ # triton.Config({"BLOCK_M": 16}),
17
+ # ],
18
+ # key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
19
+ # )
20
+ @triton.jit
21
+ def rotary_kernel(
22
+ OUT, # Pointers to matrices
23
+ X,
24
+ COS,
25
+ SIN,
26
+ CU_SEQLENS,
27
+ SEQLEN_OFFSETS, # this could be int or a pointer
28
+ # Matrix dimensions
29
+ seqlen,
30
+ nheads,
31
+ rotary_dim,
32
+ seqlen_ro,
33
+ CACHE_KEY_SEQLEN,
34
+ # strides
35
+ stride_out_batch,
36
+ stride_out_nheads,
37
+ stride_out_seqlen,
38
+ stride_out_headdim,
39
+ stride_x_batch,
40
+ stride_x_nheads,
41
+ stride_x_seqlen,
42
+ stride_x_headdim,
43
+ # Meta-parameters
44
+ BLOCK_K: tl.constexpr,
45
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
46
+ IS_VARLEN: tl.constexpr,
47
+ INTERLEAVED: tl.constexpr,
48
+ CONJUGATE: tl.constexpr,
49
+ BLOCK_M: tl.constexpr,
50
+ ):
51
+ pid_m = tl.program_id(axis=0)
52
+ pid_batch = tl.program_id(axis=1)
53
+ pid_head = tl.program_id(axis=2)
54
+ rotary_dim_half = rotary_dim // 2
55
+
56
+ if not IS_VARLEN:
57
+ X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
58
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
59
+ COS = COS + pid_batch * seqlen_ro * rotary_dim_half
60
+ SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
61
+ else:
62
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
63
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
64
+ X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
65
+ OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
66
+
67
+ if pid_m * BLOCK_M >= seqlen:
68
+ return
69
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
70
+ if not IS_SEQLEN_OFFSETS_TENSOR:
71
+ rm_cs = rm + SEQLEN_OFFSETS
72
+ else:
73
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
74
+ rk = tl.arange(0, BLOCK_K)
75
+ rk_half = tl.arange(0, BLOCK_K // 2)
76
+
77
+ if not INTERLEAVED:
78
+ # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
79
+ X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
80
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
81
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
82
+ cos = tl.load(
83
+ COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
84
+ )
85
+ sin = tl.load(
86
+ SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
87
+ )
88
+ x0 = tl.load(
89
+ X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
90
+ )
91
+ x1 = tl.load(
92
+ X + rotary_dim_half * stride_x_headdim,
93
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
94
+ other=0.0,
95
+ )
96
+ if CONJUGATE:
97
+ sin = -sin
98
+ o0 = x0 * cos - x1 * sin
99
+ o1 = x0 * sin + x1 * cos
100
+ # write back result
101
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
102
+ tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
103
+ tl.store(
104
+ OUT + rotary_dim_half * stride_out_headdim,
105
+ o1,
106
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
107
+ )
108
+ else:
109
+ # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
110
+ # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
111
+ # Loading x0 will be fast but x1 will be slow.
112
+ # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
113
+ # Then we do the calculation and use tl.where to pick out the right outputs for the even
114
+ # and for the odd indices.
115
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
116
+ rk_repeat = tl.arange(0, BLOCK_K) // 2
117
+ X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
118
+ X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
119
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
120
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
121
+ cos = tl.load(
122
+ COS,
123
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
124
+ other=1.0,
125
+ ).to(tl.float32)
126
+ sin = tl.load(
127
+ SIN,
128
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
129
+ other=0.0,
130
+ ).to(tl.float32)
131
+ x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
132
+ tl.float32
133
+ )
134
+ x1 = tl.load(
135
+ X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
136
+ ).to(tl.float32)
137
+ if CONJUGATE:
138
+ sin = -sin
139
+ x0_cos = x0 * cos
140
+ x1_sin = x1 * sin
141
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
142
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
143
+ tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
144
+
145
+
146
+ def apply_rotary(
147
+ x: torch.Tensor,
148
+ cos: torch.Tensor,
149
+ sin: torch.Tensor,
150
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
151
+ cu_seqlens: Optional[torch.Tensor] = None,
152
+ max_seqlen: Optional[int] = None,
153
+ interleaved=False,
154
+ inplace=False,
155
+ conjugate=False,
156
+ ) -> torch.Tensor:
157
+ """
158
+ Arguments:
159
+ x: (batch, nheads, seqlen, headdim) if cu_seqlens is None
160
+ else (total_seqlen, nheads, headdim).
161
+ cos: (batch, seqlen_ro, rotary_dim / 2)
162
+ sin: (batch, seqlen_ro, rotary_dim / 2)
163
+ seqlen_offsets: integer or integer tensor of size (batch,)
164
+ cu_seqlens: (batch + 1,) or None
165
+ max_seqlen: int
166
+ Returns:
167
+ y: (batch, nheads, seqlen, headdim)
168
+ """
169
+
170
+ batch, nheads, seqlen, headdim = x.shape
171
+
172
+ batch_ro, seqlen_ro, rotary_dim = cos.shape
173
+
174
+ assert batch == batch_ro
175
+ assert sin.shape == cos.shape
176
+ rotary_dim *= 2
177
+ assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
178
+ assert headdim <= 256, "Only support headdim <= 256"
179
+
180
+ assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
181
+
182
+ assert (
183
+ cos.dtype == sin.dtype
184
+ ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
185
+ assert (
186
+ x.dtype == cos.dtype
187
+ ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
188
+
189
+ cos, sin = cos.contiguous(), sin.contiguous()
190
+ if isinstance(seqlen_offsets, torch.Tensor):
191
+ assert seqlen_offsets.shape == (batch,)
192
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
193
+ seqlen_offsets = seqlen_offsets.contiguous()
194
+ else:
195
+ assert seqlen_offsets + seqlen <= seqlen_ro
196
+
197
+ output = torch.empty_like(x) if not inplace else x
198
+ if rotary_dim < headdim and not inplace:
199
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
200
+
201
+ BLOCK_K = (
202
+ 32
203
+ if rotary_dim <= 32
204
+ else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
205
+ )
206
+ grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
207
+ BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
208
+
209
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
210
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
211
+ with torch.cuda.device(x.device.index):
212
+ rotary_kernel[grid](
213
+ output, # data ptrs
214
+ x,
215
+ cos,
216
+ sin,
217
+ cu_seqlens,
218
+ seqlen_offsets,
219
+ seqlen, # shapes
220
+ nheads,
221
+ rotary_dim,
222
+ seqlen_ro,
223
+ seqlen // 128, # key for triton cache (limit number of compilations)
224
+ output.stride(0), # batch_strides
225
+ output.stride(-3), # nheads_stride
226
+ output.stride(-2), # seqlen_stride
227
+ output.stride(-1), # headdim_stride
228
+ x.stride(0), # batch_strides
229
+ x.stride(-3), # nheads stride
230
+ x.stride(-2), # seqlen stride
231
+ x.stride(-1), # headdim stride
232
+ BLOCK_K,
233
+ isinstance(seqlen_offsets, torch.Tensor),
234
+ False,
235
+ interleaved,
236
+ conjugate,
237
+ BLOCK_M,
238
+ )
239
+ return output
240
+
241
+
242
+ class ApplyRotaryEmb(torch.autograd.Function):
243
+ @staticmethod
244
+ def forward(
245
+ ctx,
246
+ x,
247
+ cos,
248
+ sin,
249
+ interleaved=False,
250
+ inplace=False,
251
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
252
+ cu_seqlens: Optional[torch.Tensor] = None,
253
+ max_seqlen: Optional[int] = None,
254
+ ):
255
+ out = apply_rotary(
256
+ x,
257
+ cos,
258
+ sin,
259
+ seqlen_offsets=seqlen_offsets,
260
+ cu_seqlens=cu_seqlens,
261
+ max_seqlen=max_seqlen,
262
+ interleaved=interleaved,
263
+ inplace=inplace,
264
+ )
265
+ if isinstance(seqlen_offsets, int):
266
+ ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
267
+ ctx.seqlen_offsets = seqlen_offsets
268
+ else:
269
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
270
+ ctx.seqlen_offsets = None
271
+ ctx.interleaved = interleaved
272
+ ctx.inplace = inplace
273
+ ctx.max_seqlen = max_seqlen
274
+ return out if not inplace else x
275
+
276
+ @staticmethod
277
+ def backward(ctx, do):
278
+ seqlen_offsets = ctx.seqlen_offsets
279
+ if seqlen_offsets is None:
280
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
281
+ else:
282
+ cos, sin, cu_seqlens = ctx.saved_tensors
283
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
284
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
285
+ if not ctx.interleaved and not ctx.inplace:
286
+ do = do.clone()
287
+ dx = apply_rotary(
288
+ do,
289
+ cos,
290
+ sin,
291
+ seqlen_offsets=seqlen_offsets,
292
+ cu_seqlens=cu_seqlens,
293
+ max_seqlen=ctx.max_seqlen,
294
+ interleaved=ctx.interleaved,
295
+ inplace=ctx.inplace,
296
+ conjugate=True,
297
+ )
298
+ return dx, None, None, None, None, None, None, None
299
+
300
+
301
+ def apply_rotary_emb(
302
+ x,
303
+ cos,
304
+ sin,
305
+ interleaved=False,
306
+ inplace=False,
307
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
308
+ cu_seqlens: Optional[torch.Tensor] = None,
309
+ max_seqlen: Optional[int] = None,
310
+ ):
311
+ """
312
+ Arguments:
313
+ x: (batch_size, nheads, seqlen, headdim) if cu_seqlens is None
315
+ else (total_seqlen, nheads, headdim)
316
+ cos, sin: (batch_size, seqlen_rotary, rotary_dim / 2)
316
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
317
+ of 1st half and 2nd half (GPT-NeoX style).
318
+ inplace: if True, apply rotary embedding in-place.
319
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
320
+ Most commonly used in inference when we have KV cache.
321
+ cu_seqlens: (batch + 1,) or None
322
+ max_seqlen: int
323
+ Return:
324
+ out: (batch_size, nheads, seqlen, headdim) if cu_seqlens is None
325
+ else (total_seqlen, nheads, headdim)
326
+ rotary_dim must be <= headdim
327
+ Apply rotary embedding to the first rotary_dim of x.
328
+ """
329
+ return ApplyRotaryEmb.apply(
330
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
331
+ )
332
+
333
+
334
+ # For backward compatibility
335
+ apply_rotary_emb_func = apply_rotary_emb
336
+
337
+
338
+ class FastRotaryEmbedding(torch.nn.Module):
339
+ """
340
+ The rotary position embeddings from RoFormer_ (Su et. al).
341
+ A crucial insight from the method is that the query and keys are
342
+ transformed by rotation matrices which depend on the relative positions.
343
+
344
+ Other implementations are available in the Rotary Transformer repo_ and in
345
+ GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.
346
+
347
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
348
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
349
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
350
+
351
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
352
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
353
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
354
+ """
355
+
356
+ def __init__(
357
+ self,
358
+ dim: int,
359
+ base=10000,
360
+ interleaved=False,
361
+ scale_base=None,
362
+ pos_idx_in_fp32=True,
363
+ device=None,
364
+ ):
365
+ """
366
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
367
+ of 1st half and 2nd half (GPT-NeoX style).
368
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
369
+ otherwise they might be in lower precision.
370
+ This option was added because previously (before 2023-07-02), when we construct
371
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
372
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
373
+ self.inv_freq would be bf16, and the position indices are also in bf16.
374
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
375
+ embeddings for some positions will coincide.
376
+ To maintain compatibility with models previously trained in pure bf16,
377
+ we add this option.
378
+ """
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.base = base
382
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
383
+ # Generate and save the inverse frequency buffer (non trainable)
384
+ inv_freq = self._compute_inv_freq(device)
385
+ self.register_buffer("inv_freq", inv_freq)
386
+ self.interleaved = interleaved
387
+ self.scale_base = scale_base
388
+ scale = (
389
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
390
+ if scale_base is not None
391
+ else None
392
+ )
393
+ self.register_buffer("scale", scale, persistent=False)
394
+
395
+ self._seq_len_cached = 0
396
+ self._cos_cached = None
397
+ self._sin_cached = None
398
+ self._cos_k_cached = None
399
+ self._sin_k_cached = None
400
+ self.cos = None
401
+ self.sin = None
402
+
403
+ def _compute_inv_freq(self, device=None):
404
+ return 1.0 / (
405
+ self.base
406
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
407
+ # ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
408
+ )
409
+
410
+ def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
411
+
412
+ if (
413
+ seqlen > self._seq_len_cached
414
+ ):
415
+ self._seq_len_cached = seqlen
416
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
417
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
418
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
419
+ if self.pos_idx_in_fp32:
420
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
421
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
422
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
423
+ # cos & sin output to change significantly.
424
+ # We want to recompute self.inv_freq if it was not loaded in fp32
425
+ if self.inv_freq.dtype != torch.float32:
426
+ inv_freq = self._compute_inv_freq(device=device)
427
+ else:
428
+ inv_freq = self.inv_freq
429
+ else:
430
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
431
+ inv_freq = self.inv_freq
432
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
433
+ if self.scale is None:
434
+ self._cos_cached = torch.cos(freqs).to(dtype)
435
+ self._sin_cached = torch.sin(freqs).to(dtype)
436
+
437
+ else:
438
+ power = (
439
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
440
+ - seqlen // 2
441
+ ) / self.scale_base
442
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
443
+ # We want the multiplication by scale to happen in fp32
444
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
445
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
446
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
447
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
448
+
449
+ def forward(
450
+ self,
451
+ q: torch.Tensor,
452
+ k: torch.Tensor,
453
+ position_ids: torch.Tensor,
454
+ max_seqlen,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ """
457
+ q: (batch, nheads, seqlen, headdim)
458
+ k: (batch, nheads, seqlen, headdim)
459
+ position_ids: (batch, seqlen)
460
+ max_seqlen: int
461
+ The cos/sin cache is rebuilt whenever max_seqlen exceeds the cached length;
462
+ otherwise the cached values are reused.
463
+ Apply rotary embedding *inplace* to q and k.
464
+ """
465
+
466
+ self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
467
+ cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
468
+
469
+ q = apply_rotary_emb_func(
470
+ q,
471
+ cos,
472
+ sin,
473
+ interleaved=self.interleaved,
474
+ inplace=True
475
+ )
476
+ k = apply_rotary_emb_func(
477
+ k,
478
+ cos,
479
+ sin,
480
+ interleaved=self.interleaved,
481
+ inplace=True
482
+ )
483
+ return q, k
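A shape-level usage sketch for `FastRotaryEmbedding` (the sizes are made up; a CUDA device is required because the Triton kernel above launches on the GPU):

import torch

# q/k follow the (batch, nheads, seqlen, headdim) layout that apply_rotary above unpacks
rotary = FastRotaryEmbedding(dim=64).cuda()  # dim == headdim, so the whole head is rotated
q = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
position_ids = torch.arange(128, device="cuda").unsqueeze(0)  # (batch, seqlen)
q, k = rotary(q, k, position_ids, max_seqlen=128)  # rotated in place, shapes unchanged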
visual.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ from torch import nn
3
+ from argparse import Namespace
4
+ import xformers.ops as xops
5
+ from transformers.activations import ACT2FN
6
+
7
+
8
+ class PatchEmbedding(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
12
+ self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
13
+ self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
14
+
15
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
16
+ x = self.proj(images)
17
+ x = x.flatten(2).transpose(1, 2)
18
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
19
+ x = torch.cat((cls_token, x), dim=1)
20
+ x += self.position_embedding.weight.unsqueeze(0)
21
+ return x
22
+
23
+
24
+ class Attention(nn.Module):
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.num_heads = config.num_heads
28
+ head_dim = config.hidden_size // config.num_heads
29
+ self.scale = head_dim ** -0.5
30
+ self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
31
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
32
+ self.output_dropout = torch.nn.Dropout(config.dropout_prob)
33
+
34
+ def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
35
+ B, L, _ = x.shape
36
+ qkv = self.query_key_value(x)
37
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
38
+ q, k, v = qkv[0], qkv[1], qkv[2]
39
+
40
+ out = xops.memory_efficient_attention(
41
+ q, k, v, scale=self.scale,
42
+ )
43
+ output = self.dense(out.view(B, L, -1))
44
+ output = self.output_dropout(output)
45
+ return output
46
+
47
+ def attention(self, q, k, v):
48
+ attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
49
+ attn_weights = attn_weights.softmax(dim=-1)
50
+ output = torch.matmul(attn_weights, v)
51
+ return output
52
+
53
+
54
+ class MLP(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.config = config
58
+ self.activation_fn = ACT2FN[config.hidden_act]
59
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
60
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ x = self.fc1(x)
64
+ x = self.activation_fn(x)
65
+ x = self.fc2(x)
66
+ return x
67
+
68
+
69
+ class TransformerLayer(nn.Module):
70
+ def __init__(self, config):
71
+ super().__init__()
72
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
73
+ self.attention = Attention(config)
74
+ self.mlp = MLP(config)
75
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
76
+
77
+ def forward(self, hidden_states):
78
+ attention_input = hidden_states
79
+ attention_output = self.input_layernorm(self.attention(attention_input))
80
+ hidden_states = attention_input + attention_output
81
+ mlp_input = hidden_states
82
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
83
+ output = mlp_input + mlp_output
84
+ return output
85
+
86
+
87
+ class Transformer(nn.Module):
88
+ def __init__(self, config):
89
+ super().__init__()
90
+ self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
91
+
92
+ def forward(self, hidden_states):
93
+ for layer_module in self.layers:
94
+ hidden_states = layer_module(hidden_states)
95
+ return hidden_states
96
+
97
+
98
+ class GLU(nn.Module):
99
+ def __init__(self, config, in_features):
100
+ super().__init__()
101
+ self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
102
+ self.norm1 = nn.LayerNorm(config.hidden_size)
103
+ self.act1 = nn.GELU()
104
+ self.act2 = nn.functional.silu
105
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
106
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
107
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
108
+
109
+ def forward(self, x):
110
+ x = self.linear_proj(x)
111
+ x = self.act1(self.norm1(x))
112
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
113
+ x = self.dense_4h_to_h(x)
114
+ return x
115
+
116
+
117
+ class EVA2CLIPModel(nn.Module):
118
+ def __init__(self, config):
119
+ super().__init__()
120
+ vision_config = Namespace(**config.vision_config)
121
+ self.patch_embedding = PatchEmbedding(vision_config)
122
+ self.transformer = Transformer(vision_config)
123
+ self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
124
+ self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
125
+ self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
126
+ self.pos_embed = nn.Parameter(torch.zeros((vision_config.image_size // vision_config.patch_size) ** 2, vision_config.hidden_size))
127
+
128
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
129
+ x = self.patch_embedding(images)
130
+ x = self.transformer(x)
131
+ x = x[:, 1:]
132
+ x = self.linear_proj(x + self.pos_embed.unsqueeze(0))
133
+ boi = self.boi.expand(x.shape[0], -1, -1)
134
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
135
+ x = torch.cat((boi, x, eoi), dim=1)
136
+ return x
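As a cross-check between the two files: the encoder above emits one token per patch (the cls token is dropped) plus the boi and eoi tokens, which is exactly the `vision_token_num` that `build_conversation_input_ids` reserves on the language side. A tiny helper-style sketch (the function name is ours, not part of the upload):

def expected_vision_tokens(image_size: int, patch_size: int) -> int:
    # (image_size // patch_size) ** 2 patch tokens, plus the boi and eoi tokens
    # prepended/appended by EVA2CLIPModel.forward
    return (image_size // patch_size) ** 2 + 2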