yeelou committed
Commit 4937272 · verified · 1 Parent(s): ecb9834

Upload 4 files

Files changed (4)
  1. configuration_cogagent.py +21 -21
  2. cross_visual.py +460 -191
  3. modeling_cogagent.py +434 -210
  4. visual.py +41 -13
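
Since the four files above are the custom remote-code modules for this checkpoint (the config class sets _auto_class = "AutoConfig"), here is a minimal loading sketch assuming standard transformers Auto-class usage; the repository ID and dtype are placeholders, not taken from this commit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-namespace/cogagent-checkpoint"  # placeholder repo ID

# trust_remote_code=True pulls configuration_cogagent.py / modeling_cogagent.py
# in from the repository; dtype and .eval() are illustrative only.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()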
configuration_cogagent.py CHANGED
@@ -1,4 +1,5 @@
 from typing import Literal
+
 from transformers import PretrainedConfig
 
 
@@ -6,27 +7,26 @@ class CogAgentConfig(PretrainedConfig):
     _auto_class = "AutoConfig"
 
     def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        cross_hidden_size=1024,
-        cross_compute_hidden_size=1024,
-        cross_image_size=1120,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        hidden_act='silu',
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-06,
-        template_version: Literal["base", "chat"] = "chat",
-
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        tie_word_embeddings=False,
-        use_cache=True,
-        **kwargs,
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        cross_hidden_size=1024,
+        cross_compute_hidden_size=1024,
+        cross_image_size=1120,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-06,
+        template_version: Literal["base", "chat"] = "chat",
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        use_cache=True,
+        **kwargs,
     ):
         self.hidden_size = hidden_size
         self.cross_hidden_size = cross_hidden_size
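
For orientation (not part of the commit), a minimal sketch of instantiating the configuration class shown above, assuming the file is importable locally; all values mirror the defaults visible in the diff:

from configuration_cogagent import CogAgentConfig

# cross_image_size is the resolution fed to the cross-attention vision tower
# (CrossVisionModel in cross_visual.py); other values are the diff's defaults.
config = CogAgentConfig(
    vocab_size=32000,
    hidden_size=4096,
    cross_hidden_size=1024,
    cross_image_size=1120,
    template_version="chat",
)
print(config.hidden_size, config.cross_hidden_size)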
cross_visual.py CHANGED
@@ -1,30 +1,36 @@
 
1
  from math import pi
 
2
  import torch
3
- from torch import nn
4
  from einops import rearrange, repeat
5
- import logging
 
6
 
7
- def broadcat(tensors, dim = -1):
8
  num_tensors = len(tensors)
9
  shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
- assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
  shape_len = list(shape_lens)[0]
12
  dim = (dim + shape_len) if dim < 0 else dim
13
  dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
  expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
- assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
 
 
16
  max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
  expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
  expanded_dims.insert(dim, (dim, dims[dim]))
19
  expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
  tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
- return torch.cat(tensors, dim = dim)
 
22
 
23
  def rotate_half(x):
24
- x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
- x1, x2 = x.unbind(dim = -1)
26
- x = torch.stack((-x2, x1), dim = -1)
27
- return rearrange(x, '... d r -> ... (d r)')
 
28
 
29
  class VisionRotaryEmbeddingFast(nn.Module):
30
  def __init__(
@@ -32,31 +38,34 @@ class VisionRotaryEmbeddingFast(nn.Module):
32
  dim,
33
  pt_seq_len,
34
  ft_seq_len=None,
35
- custom_freqs = None,
36
- freqs_for = 'lang',
37
- theta = 10000,
38
- max_freq = 10,
39
- num_freqs = 1,
40
- patch_dropout = 0.
41
  ):
42
  super().__init__()
43
  if custom_freqs:
44
  freqs = custom_freqs
45
- elif freqs_for == 'lang':
46
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
- elif freqs_for == 'pixel':
48
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
- elif freqs_for == 'constant':
 
 
50
  freqs = torch.ones(num_freqs).float()
51
  else:
52
- raise ValueError(f'unknown modality {freqs_for}')
53
 
54
- if ft_seq_len is None: ft_seq_len = pt_seq_len
 
55
  t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
 
57
- freqs = torch.einsum('..., f -> ... f', t, freqs)
58
- freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
59
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
60
 
61
  freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
62
  freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
@@ -66,7 +75,7 @@ class VisionRotaryEmbeddingFast(nn.Module):
66
  self.register_buffer("freqs_cos", freqs_cos)
67
  self.register_buffer("freqs_sin", freqs_sin)
68
 
69
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
70
 
71
  def forward(self, t, patch_indices_keep=None):
72
  if patch_indices_keep is not None:
@@ -74,44 +83,46 @@ class VisionRotaryEmbeddingFast(nn.Module):
74
  batch_indices = torch.arange(batch)
75
  batch_indices = batch_indices[..., None]
76
 
77
- freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
78
- freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
 
 
 
 
79
 
80
  freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
81
- freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
82
  freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
83
- freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
84
 
85
- return t * freqs_cos + rotate_half(t) * freqs_sin
86
 
87
- return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
88
 
89
- import torch.nn as nn
90
- import os
91
- from dataclasses import dataclass
92
- from typing import Optional, Tuple, Union
93
- from functools import partial
94
 
95
- import numpy as np
96
- import torch
97
- import torch.nn.functional as F
98
- from torch import nn
99
 
100
  # --------------------------------------------------------
101
  # Adapted from https://github.com/microsoft/unilm/tree/master/beit
102
  # --------------------------------------------------------
103
  import math
104
  import os
 
105
  from functools import partial
 
 
 
106
  import torch
107
  import torch.nn as nn
108
  import torch.nn.functional as F
109
- import logging
 
110
  try:
111
  from timm.models.layers import drop_path, to_2tuple, trunc_normal_
112
  except:
113
  from timm.layers import drop_path, to_2tuple, trunc_normal_
114
-
 
115
  class PatchDropout(nn.Module):
116
  """
117
  https://arxiv.org/abs/2212.00794
@@ -119,13 +130,13 @@ class PatchDropout(nn.Module):
119
 
120
  def __init__(self, prob, exclude_first_token=True):
121
  super().__init__()
122
- assert 0 <= prob < 1.
123
  self.prob = prob
124
  self.exclude_first_token = exclude_first_token # exclude CLS token
125
  logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
126
 
127
  def forward(self, x):
128
- if not self.training or self.prob == 0.:
129
  return x
130
 
131
  if self.exclude_first_token:
@@ -150,12 +161,13 @@ class PatchDropout(nn.Module):
150
  if self.exclude_first_token:
151
  x = torch.cat((cls_tokens, x), dim=1)
152
 
153
- if self.training and os.getenv('RoPE') == '1':
154
  return x, patch_indices_keep
155
 
156
  return x
157
 
158
- if os.getenv('ENV_TYPE') == 'deepspeed':
 
159
  try:
160
  from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
161
  except:
@@ -165,32 +177,32 @@ else:
165
 
166
  import xformers.ops as xops
167
 
 
168
  class DropPath(nn.Module):
169
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
170
- """
171
  def __init__(self, drop_prob=None):
172
  super(DropPath, self).__init__()
173
  self.drop_prob = drop_prob
174
 
175
  def forward(self, x):
176
  return drop_path(x, self.drop_prob, self.training)
177
-
178
  def extra_repr(self) -> str:
179
- return 'p={}'.format(self.drop_prob)
180
 
181
 
182
  class Mlp(nn.Module):
183
  def __init__(
184
- self,
185
- in_features,
186
- hidden_features=None,
187
- out_features=None,
188
- act_layer=nn.GELU,
189
- norm_layer=nn.LayerNorm,
190
- drop=0.,
191
  subln=False,
192
-
193
- ):
194
  super().__init__()
195
  out_features = out_features or in_features
196
  hidden_features = hidden_features or in_features
@@ -206,16 +218,25 @@ class Mlp(nn.Module):
206
  x = self.fc1(x)
207
  x = self.act(x)
208
  # x = self.drop(x)
209
- # commit this for the orignal BERT implement
210
  x = self.ffn_ln(x)
211
 
212
  x = self.fc2(x)
213
  x = self.drop(x)
214
  return x
215
 
 
216
  class SwiGLU(nn.Module):
217
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
218
- norm_layer=nn.LayerNorm, subln=False):
 
 
 
 
 
 
 
 
219
  super().__init__()
220
  out_features = out_features or in_features
221
  hidden_features = hidden_features or in_features
@@ -226,7 +247,7 @@ class SwiGLU(nn.Module):
226
  self.act = act_layer()
227
  self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
228
  self.w3 = nn.Linear(hidden_features, out_features)
229
-
230
  self.drop = nn.Dropout(drop)
231
 
232
  def forward(self, x):
@@ -238,17 +259,30 @@ class SwiGLU(nn.Module):
238
  x = self.drop(x)
239
  return x
240
 
 
241
  class Attention(nn.Module):
242
  def __init__(
243
- self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
244
- proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
 
 
 
 
 
 
 
 
 
 
 
 
245
  super().__init__()
246
  self.num_heads = num_heads
247
  head_dim = dim // num_heads
248
  if attn_head_dim is not None:
249
  head_dim = attn_head_dim
250
  all_head_dim = head_dim * self.num_heads
251
- self.scale = qk_scale or head_dim ** -0.5
252
 
253
  self.subln = subln
254
  if self.subln:
@@ -267,9 +301,12 @@ class Attention(nn.Module):
267
 
268
  if window_size:
269
  self.window_size = window_size
270
- self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
 
 
271
  self.relative_position_bias_table = nn.Parameter(
272
- torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
 
273
  # cls to token & token 2 cls & cls to cls
274
 
275
  # get pair-wise relative position index for each token inside the window
@@ -277,13 +314,19 @@ class Attention(nn.Module):
277
  coords_w = torch.arange(window_size[1])
278
  coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
279
  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
280
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
281
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
 
 
 
 
282
  relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
283
  relative_coords[:, :, 1] += window_size[1] - 1
284
  relative_coords[:, :, 0] *= 2 * window_size[1] - 1
285
- relative_position_index = \
286
- torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
 
 
287
  relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
288
  relative_position_index[0, 0:] = self.num_relative_distance - 3
289
  relative_position_index[0:, 0] = self.num_relative_distance - 2
@@ -307,28 +350,88 @@ class Attention(nn.Module):
307
 
308
  def forward(self, x, rel_pos_bias=None, attn_mask=None):
309
  B, N, C = x.shape
310
- if self.subln:
 
311
  if self.q_proj.weight.dtype == torch.uint8:
312
  import bitsandbytes as bnb
313
- q = bnb.matmul_4bit(x, self.q_proj.weight.t(), bias=self.q_bias, quant_state=self.q_proj.weight.quant_state)
314
- k = bnb.matmul_4bit(x, self.k_proj.weight.t(), bias=None, quant_state=self.k_proj.weight.quant_state)
315
- v = bnb.matmul_4bit(x, self.v_proj.weight.t(), bias=self.v_bias, quant_state=self.v_proj.weight.quant_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  else:
317
  q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
318
  k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
319
  v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
320
 
321
- q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
322
- k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
323
- v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
324
- else:
 
 
325
 
326
  qkv_bias = None
327
  if self.q_bias is not None:
328
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
329
-
 
 
 
 
 
 
330
  qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
331
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
 
 
332
  q, k, v = qkv[0], qkv[1], qkv[2]
333
 
334
  if self.rope:
@@ -342,29 +445,36 @@ class Attention(nn.Module):
342
  k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
343
 
344
  if self.xattn:
345
- q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
346
  k = k.permute(0, 2, 1, 3)
347
  v = v.permute(0, 2, 1, 3)
348
 
349
  x = xops.memory_efficient_attention(
350
- q, k, v,
 
 
351
  p=self.xattn_drop,
352
  scale=self.scale,
353
- )
354
  x = x.reshape(B, N, -1)
355
  x = self.inner_attn_ln(x)
356
  x = self.proj(x)
357
  x = self.proj_drop(x)
358
  else:
359
  q = q * self.scale
360
- attn = (q @ k.transpose(-2, -1))
361
 
362
  if self.relative_position_bias_table is not None:
363
- relative_position_bias = \
364
- self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
365
- self.window_size[0] * self.window_size[1] + 1,
366
- self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
367
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
 
 
 
 
 
368
  attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
369
 
370
  if rel_pos_bias is not None:
@@ -373,7 +483,7 @@ class Attention(nn.Module):
373
  if attn_mask is not None:
374
  attn_mask = attn_mask.bool()
375
  attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
376
-
377
  attn = attn.softmax(dim=-1)
378
  attn = self.attn_drop(attn)
379
 
@@ -386,40 +496,71 @@ class Attention(nn.Module):
386
 
387
  class Block(nn.Module):
388
 
389
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
390
- drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
391
- window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
392
- subln=False, naiveswiglu=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  super().__init__()
394
  self.norm1 = norm_layer(dim)
395
  self.attn = Attention(
396
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
397
- attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
398
- xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
 
 
 
 
 
 
 
 
 
 
399
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
400
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
401
  self.norm2 = norm_layer(dim)
402
  mlp_hidden_dim = int(dim * mlp_ratio)
403
 
404
  if naiveswiglu:
405
  self.mlp = SwiGLU(
406
- in_features=dim,
407
- hidden_features=mlp_hidden_dim,
408
  subln=subln,
409
  norm_layer=norm_layer,
410
  )
411
  else:
412
  self.mlp = Mlp(
413
- in_features=dim,
414
- hidden_features=mlp_hidden_dim,
415
  act_layer=act_layer,
416
  subln=subln,
417
- drop=drop
418
  )
419
 
420
  if init_values is not None and init_values > 0:
421
- self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
422
- self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
 
 
 
 
423
  else:
424
  self.gamma_1, self.gamma_2 = None, None
425
 
@@ -428,24 +569,42 @@ class Block(nn.Module):
428
  def forward(self, x, rel_pos_bias=None, attn_mask=None):
429
  if self.gamma_1 is None:
430
  if self.postnorm:
431
- x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
 
 
 
 
432
  x = x + self.drop_path(self.norm2(self.mlp(x)))
433
  else:
434
- x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
 
 
 
 
435
  x = x + self.drop_path(self.mlp(self.norm2(x)))
436
  else:
437
  if self.postnorm:
438
- x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
 
 
 
 
 
439
  x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
440
  else:
441
- x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
 
 
 
 
 
442
  x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
443
  return x
444
 
445
 
446
  class PatchEmbed(nn.Module):
447
- """ Image to Patch Embedding
448
- """
449
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
450
  super().__init__()
451
  img_size = to_2tuple(img_size)
@@ -456,13 +615,16 @@ class PatchEmbed(nn.Module):
456
  self.patch_size = patch_size
457
  self.num_patches = num_patches
458
 
459
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
 
 
460
 
461
  def forward(self, x, **kwargs):
462
  B, C, H, W = x.shape
463
  # FIXME look at relaxing size constraints
464
- assert H == self.img_size[0] and W == self.img_size[1], \
465
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
 
466
  x = self.proj(x).flatten(2).transpose(1, 2)
467
  return x
468
 
@@ -472,9 +634,12 @@ class RelativePositionBias(nn.Module):
472
  def __init__(self, window_size, num_heads):
473
  super().__init__()
474
  self.window_size = window_size
475
- self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
 
 
476
  self.relative_position_bias_table = nn.Parameter(
477
- torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
 
478
  # cls to token & token 2 cls & cls to cls
479
 
480
  # get pair-wise relative position index for each token inside the window
@@ -482,13 +647,18 @@ class RelativePositionBias(nn.Module):
482
  coords_w = torch.arange(window_size[1])
483
  coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
484
  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
485
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
486
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
 
 
 
 
487
  relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
488
  relative_coords[:, :, 1] += window_size[1] - 1
489
  relative_coords[:, :, 0] *= 2 * window_size[1] - 1
490
- relative_position_index = \
491
- torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
 
492
  relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
493
  relative_position_index[0, 0:] = self.num_relative_distance - 3
494
  relative_position_index[0:, 0] = self.num_relative_distance - 2
@@ -497,29 +667,64 @@ class RelativePositionBias(nn.Module):
497
  self.register_buffer("relative_position_index", relative_position_index)
498
 
499
  def forward(self):
500
- relative_position_bias = \
501
- self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
502
- self.window_size[0] * self.window_size[1] + 1,
503
- self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
 
 
 
504
  return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
505
 
506
 
507
  class EVAVisionTransformer(nn.Module):
508
- """ Vision Transformer with support for patch or hybrid CNN input stage
509
- """
510
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
511
- num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
512
- drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
513
- use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
514
- use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
515
- pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  super().__init__()
517
  self.image_size = img_size
518
  self.num_classes = num_classes
519
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
 
 
520
 
521
  self.patch_embed = PatchEmbed(
522
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
 
 
 
 
523
  num_patches = self.patch_embed.num_patches
524
 
525
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -531,7 +736,9 @@ class EVAVisionTransformer(nn.Module):
531
  self.pos_drop = nn.Dropout(p=drop_rate)
532
 
533
  if use_shared_rel_pos_bias:
534
- self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
 
 
535
  else:
536
  self.rel_pos_bias = None
537
 
@@ -544,40 +751,64 @@ class EVAVisionTransformer(nn.Module):
544
  ft_seq_len=hw_seq_len if intp_freq else None,
545
  # patch_dropout=patch_dropout
546
  )
547
- else:
548
  self.rope = None
549
 
550
  self.naiveswiglu = naiveswiglu
551
 
552
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
 
 
553
  self.use_rel_pos_bias = use_rel_pos_bias
554
- self.blocks = nn.ModuleList([
555
- Block(
556
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
557
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
558
- init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
559
- xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
560
- for i in range(depth)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
562
  self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
563
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
 
564
 
565
  if self.pos_embed is not None:
566
- trunc_normal_(self.pos_embed, std=.02)
567
 
568
- trunc_normal_(self.cls_token, std=.02)
569
  # trunc_normal_(self.mask_token, std=.02)
570
 
571
  self.apply(self._init_weights)
572
  self.fix_init_weight()
573
 
574
  if isinstance(self.head, nn.Linear):
575
- trunc_normal_(self.head.weight, std=.02)
576
  self.head.weight.data.mul_(init_scale)
577
  self.head.bias.data.mul_(init_scale)
578
 
579
  # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
580
- self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
 
 
581
 
582
  self.grad_checkpointing = grad_checkpointing
583
 
@@ -597,7 +828,7 @@ class EVAVisionTransformer(nn.Module):
597
 
598
  def _init_weights(self, m):
599
  if isinstance(m, nn.Linear):
600
- trunc_normal_(m.weight, std=.02)
601
  if m.bias is not None:
602
  nn.init.constant_(m.bias, 0)
603
  elif isinstance(m, nn.LayerNorm):
@@ -606,9 +837,11 @@ class EVAVisionTransformer(nn.Module):
606
 
607
  def get_num_layers(self):
608
  return len(self.blocks)
609
-
610
  def lock(self, unlocked_groups=0, freeze_bn_stats=False):
611
- assert unlocked_groups == 0, 'partial locking not currently supported for this model'
 
 
612
  for param in self.parameters():
613
  param.requires_grad = False
614
 
@@ -618,31 +851,37 @@ class EVAVisionTransformer(nn.Module):
618
 
619
  @torch.jit.ignore
620
  def no_weight_decay(self):
621
- return {'pos_embed', 'cls_token'}
622
 
623
  def get_classifier(self):
624
  return self.head
625
 
626
- def reset_classifier(self, num_classes, global_pool=''):
627
  self.num_classes = num_classes
628
- self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
 
629
 
630
  def forward_features(self, x, return_all_features=False):
631
-
632
  x = self.patch_embed(x)
633
  batch_size, seq_len, _ = x.size()
634
 
635
- cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
 
 
636
  x = torch.cat((cls_tokens, x), dim=1)
637
  if self.pos_embed is not None:
638
  x = x + self.pos_embed
639
  x = self.pos_drop(x)
640
 
641
  # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
642
- if os.getenv('RoPE') == '1':
643
  if self.training and not isinstance(self.patch_dropout, nn.Identity):
644
  x, patch_indices_keep = self.patch_dropout(x)
645
- self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
 
 
646
  else:
647
  self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
648
  x = self.patch_dropout(x)
@@ -651,7 +890,7 @@ class EVAVisionTransformer(nn.Module):
651
 
652
  rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
653
  for i, blk in enumerate(self.blocks):
654
- if i == len(self.blocks)-1:
655
  continue
656
  if self.grad_checkpointing:
657
  x = checkpoint(blk, x, (rel_pos_bias,))
@@ -673,6 +912,7 @@ class EVAVisionTransformer(nn.Module):
673
  x = self.head(x)
674
  return x
675
 
 
676
  class LayerNorm(nn.LayerNorm):
677
  """Subclass torch's LayerNorm (with cast back to input dtype)."""
678
 
@@ -681,6 +921,7 @@ class LayerNorm(nn.LayerNorm):
681
  x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
682
  return x.to(orig_type)
683
 
 
684
  try:
685
  from apex.normalization import FusedLayerNorm
686
  except:
@@ -697,30 +938,41 @@ class CLIPVisionCfg:
697
  patch_size: int = 16
698
  image_size: Union[Tuple[int, int], int] = 224
699
  ls_init_value: Optional[float] = None # layer scale initial value
700
- patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
701
- global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
 
 
 
 
702
  drop_path_rate: Optional[float] = None # drop path rate
703
- timm_model_name: str = None # a valid model name overrides layers, width, patch_size
704
- timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
705
- timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
706
- timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
 
 
 
 
 
 
 
 
707
  timm_proj_bias: bool = False # enable bias final projection
708
- eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
 
 
709
  qkv_bias: bool = True
710
  fusedLN: bool = False
711
  xattn: bool = False
712
  postnorm: bool = False
713
  rope: bool = False
714
- pt_hw_seq_len: int = 16 # 224/14
715
  intp_freq: bool = False
716
  naiveswiglu: bool = False
717
  subln: bool = False
718
 
719
 
720
- def _build_vision_tower(
721
- embed_dim: int,
722
- vision_cfg: CLIPVisionCfg
723
- ):
724
  if isinstance(vision_cfg, dict):
725
  vision_cfg = CLIPVisionCfg(**vision_cfg)
726
 
@@ -731,7 +983,7 @@ def _build_vision_tower(
731
  img_size=vision_cfg.image_size,
732
  patch_size=vision_cfg.patch_size,
733
  num_classes=embed_dim,
734
- use_mean_pooling=vision_cfg.global_average_pool, #False
735
  init_values=vision_cfg.ls_init_value,
736
  patch_dropout=vision_cfg.patch_dropout,
737
  embed_dim=vision_cfg.width,
@@ -740,18 +992,23 @@ def _build_vision_tower(
740
  mlp_ratio=vision_cfg.mlp_ratio,
741
  qkv_bias=vision_cfg.qkv_bias,
742
  drop_path_rate=vision_cfg.drop_path_rate,
743
- norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
 
 
 
 
744
  xattn=vision_cfg.xattn,
745
  rope=vision_cfg.rope,
746
  postnorm=vision_cfg.postnorm,
747
- pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
748
- intp_freq= vision_cfg.intp_freq,
749
- naiveswiglu= vision_cfg.naiveswiglu,
750
- subln= vision_cfg.subln
751
  )
752
 
753
  return visual
754
 
 
755
  class Eva2LargeEncoder(nn.Module):
756
  def __init__(self, image_size=224):
757
  super(Eva2LargeEncoder, self).__init__()
@@ -772,26 +1029,38 @@ class Eva2LargeEncoder(nn.Module):
772
  "pt_hw_seq_len": 16,
773
  "intp_freq": True,
774
  "naiveswiglu": True,
775
- "subln": True
776
- }
777
  }
778
- self.config['vision_cfg']['image_size'] = image_size
779
-
780
  import os
781
- os.environ['delRoPE'] = '1' # to avoid error in rope params when changing image size
782
- self.model = _build_vision_tower(**self.config)
783
 
 
 
 
 
784
 
785
  def forward(self, images):
786
  encode = self.model(images, return_all_features=True)[:, 1:, :]
787
  return encode
788
 
 
789
  class CrossVisionModel(nn.Module):
790
  def __init__(self, config):
791
  super().__init__()
792
  self.vit = Eva2LargeEncoder(image_size=config.cross_image_size)
793
- self.pos_embed = nn.Parameter(torch.zeros((self.vit.config['vision_cfg']['image_size'] // self.vit.config['vision_cfg']['patch_size']) ** 2, self.vit.config['vision_cfg']['width']))
 
 
 
 
 
 
 
 
 
794
 
795
  def forward(self, images):
796
  enc = self.vit(images)
797
- return enc + self.pos_embed.to(enc.device).unsqueeze(0)
 
1
+ import logging
2
  from math import pi
3
+
4
  import torch
 
5
  from einops import rearrange, repeat
6
+ from torch import nn
7
+
8
 
9
+ def broadcat(tensors, dim=-1):
10
  num_tensors = len(tensors)
11
  shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
12
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
13
  shape_len = list(shape_lens)[0]
14
  dim = (dim + shape_len) if dim < 0 else dim
15
  dims = list(zip(*map(lambda t: list(t.shape), tensors)))
16
  expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
17
+ assert all(
18
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
19
+ ), "invalid dimensions for broadcastable concatentation"
20
  max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
21
  expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
22
  expanded_dims.insert(dim, (dim, dims[dim]))
23
  expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
24
  tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
25
+ return torch.cat(tensors, dim=dim)
26
+
27
 
28
  def rotate_half(x):
29
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
30
+ x1, x2 = x.unbind(dim=-1)
31
+ x = torch.stack((-x2, x1), dim=-1)
32
+ return rearrange(x, "... d r -> ... (d r)")
33
+
34
 
35
  class VisionRotaryEmbeddingFast(nn.Module):
36
  def __init__(
 
38
  dim,
39
  pt_seq_len,
40
  ft_seq_len=None,
41
+ custom_freqs=None,
42
+ freqs_for="lang",
43
+ theta=10000,
44
+ max_freq=10,
45
+ num_freqs=1,
46
+ patch_dropout=0.0,
47
  ):
48
  super().__init__()
49
  if custom_freqs:
50
  freqs = custom_freqs
51
+ elif freqs_for == "lang":
52
+ freqs = 1.0 / (
53
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
54
+ )
55
+ elif freqs_for == "pixel":
56
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
57
+ elif freqs_for == "constant":
58
  freqs = torch.ones(num_freqs).float()
59
  else:
60
+ raise ValueError(f"unknown modality {freqs_for}")
61
 
62
+ if ft_seq_len is None:
63
+ ft_seq_len = pt_seq_len
64
  t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
65
 
66
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
67
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
68
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
69
 
70
  freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
71
  freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
 
75
  self.register_buffer("freqs_cos", freqs_cos)
76
  self.register_buffer("freqs_sin", freqs_sin)
77
 
78
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
79
 
80
  def forward(self, t, patch_indices_keep=None):
81
  if patch_indices_keep is not None:
 
83
  batch_indices = torch.arange(batch)
84
  batch_indices = batch_indices[..., None]
85
 
86
+ freqs_cos = repeat(
87
+ self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]
88
+ )
89
+ freqs_sin = repeat(
90
+ self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]
91
+ )
92
 
93
  freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
94
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
95
  freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
96
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
97
 
98
+ return t * freqs_cos + rotate_half(t) * freqs_sin
99
 
100
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
101
 
 
 
 
 
 
102
 
103
+ import logging
 
 
 
104
 
105
  # --------------------------------------------------------
106
  # Adapted from https://github.com/microsoft/unilm/tree/master/beit
107
  # --------------------------------------------------------
108
  import math
109
  import os
110
+ from dataclasses import dataclass
111
  from functools import partial
112
+ from typing import Optional, Tuple, Union
113
+
114
+ import numpy as np
115
  import torch
116
  import torch.nn as nn
117
  import torch.nn.functional as F
118
+ from torch import nn
119
+
120
  try:
121
  from timm.models.layers import drop_path, to_2tuple, trunc_normal_
122
  except:
123
  from timm.layers import drop_path, to_2tuple, trunc_normal_
124
+
125
+
126
  class PatchDropout(nn.Module):
127
  """
128
  https://arxiv.org/abs/2212.00794
 
130
 
131
  def __init__(self, prob, exclude_first_token=True):
132
  super().__init__()
133
+ assert 0 <= prob < 1.0
134
  self.prob = prob
135
  self.exclude_first_token = exclude_first_token # exclude CLS token
136
  logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
137
 
138
  def forward(self, x):
139
+ if not self.training or self.prob == 0.0:
140
  return x
141
 
142
  if self.exclude_first_token:
 
161
  if self.exclude_first_token:
162
  x = torch.cat((cls_tokens, x), dim=1)
163
 
164
+ if self.training and os.getenv("RoPE") == "1":
165
  return x, patch_indices_keep
166
 
167
  return x
168
 
169
+
170
+ if os.getenv("ENV_TYPE") == "deepspeed":
171
  try:
172
  from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
173
  except:
 
177
 
178
  import xformers.ops as xops
179
 
180
+
181
  class DropPath(nn.Module):
182
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
183
+
184
  def __init__(self, drop_prob=None):
185
  super(DropPath, self).__init__()
186
  self.drop_prob = drop_prob
187
 
188
  def forward(self, x):
189
  return drop_path(x, self.drop_prob, self.training)
190
+
191
  def extra_repr(self) -> str:
192
+ return "p={}".format(self.drop_prob)
193
 
194
 
195
  class Mlp(nn.Module):
196
  def __init__(
197
+ self,
198
+ in_features,
199
+ hidden_features=None,
200
+ out_features=None,
201
+ act_layer=nn.GELU,
202
+ norm_layer=nn.LayerNorm,
203
+ drop=0.0,
204
  subln=False,
205
+ ):
 
206
  super().__init__()
207
  out_features = out_features or in_features
208
  hidden_features = hidden_features or in_features
 
218
  x = self.fc1(x)
219
  x = self.act(x)
220
  # x = self.drop(x)
221
+ # commit this for the orignal BERT implement
222
  x = self.ffn_ln(x)
223
 
224
  x = self.fc2(x)
225
  x = self.drop(x)
226
  return x
227
 
228
+
229
  class SwiGLU(nn.Module):
230
+ def __init__(
231
+ self,
232
+ in_features,
233
+ hidden_features=None,
234
+ out_features=None,
235
+ act_layer=nn.SiLU,
236
+ drop=0.0,
237
+ norm_layer=nn.LayerNorm,
238
+ subln=False,
239
+ ):
240
  super().__init__()
241
  out_features = out_features or in_features
242
  hidden_features = hidden_features or in_features
 
247
  self.act = act_layer()
248
  self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
249
  self.w3 = nn.Linear(hidden_features, out_features)
250
+
251
  self.drop = nn.Dropout(drop)
252
 
253
  def forward(self, x):
 
259
  x = self.drop(x)
260
  return x
261
 
262
+
263
  class Attention(nn.Module):
264
  def __init__(
265
+ self,
266
+ dim,
267
+ num_heads=8,
268
+ qkv_bias=False,
269
+ qk_scale=None,
270
+ attn_drop=0.0,
271
+ proj_drop=0.0,
272
+ window_size=None,
273
+ attn_head_dim=None,
274
+ xattn=False,
275
+ rope=None,
276
+ subln=False,
277
+ norm_layer=nn.LayerNorm,
278
+ ):
279
  super().__init__()
280
  self.num_heads = num_heads
281
  head_dim = dim // num_heads
282
  if attn_head_dim is not None:
283
  head_dim = attn_head_dim
284
  all_head_dim = head_dim * self.num_heads
285
+ self.scale = qk_scale or head_dim**-0.5
286
 
287
  self.subln = subln
288
  if self.subln:
 
301
 
302
  if window_size:
303
  self.window_size = window_size
304
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
305
+ 2 * window_size[1] - 1
306
+ ) + 3
307
  self.relative_position_bias_table = nn.Parameter(
308
+ torch.zeros(self.num_relative_distance, num_heads)
309
+ ) # 2*Wh-1 * 2*Ww-1, nH
310
  # cls to token & token 2 cls & cls to cls
311
 
312
  # get pair-wise relative position index for each token inside the window
 
314
  coords_w = torch.arange(window_size[1])
315
  coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
316
  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
317
+ relative_coords = (
318
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
319
+ ) # 2, Wh*Ww, Wh*Ww
320
+ relative_coords = relative_coords.permute(
321
+ 1, 2, 0
322
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
323
  relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
324
  relative_coords[:, :, 1] += window_size[1] - 1
325
  relative_coords[:, :, 0] *= 2 * window_size[1] - 1
326
+ relative_position_index = torch.zeros(
327
+ size=(window_size[0] * window_size[1] + 1,) * 2,
328
+ dtype=relative_coords.dtype,
329
+ )
330
  relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
331
  relative_position_index[0, 0:] = self.num_relative_distance - 3
332
  relative_position_index[0:, 0] = self.num_relative_distance - 2
 
350
 
351
  def forward(self, x, rel_pos_bias=None, attn_mask=None):
352
  B, N, C = x.shape
353
+ if self.subln:
354
+ print("self.q_proj.weight.dtype=", self.q_proj.weight.dtype)
355
  if self.q_proj.weight.dtype == torch.uint8:
356
  import bitsandbytes as bnb
357
+
358
+ # print("self.q_proj.weight.quant_state=", self.q_proj.weight.quant_state)
359
+ # print("self.k_proj.weight.quant_state=", self.q_proj.weight.quant_state)
360
+ # print("self.v_proj.weight.quant_state=", self.q_proj.weight.quant_state)
361
+ q = bnb.matmul_4bit(
362
+ x,
363
+ self.q_proj.weight.t(),
364
+ bias=self.q_bias,
365
+ quant_state=self.q_proj.weight.quant_state,
366
+ )
367
+ k = bnb.matmul_4bit(
368
+ x,
369
+ self.k_proj.weight.t(),
370
+ bias=None,
371
+ quant_state=self.k_proj.weight.quant_state,
372
+ )
373
+ v = bnb.matmul_4bit(
374
+ x,
375
+ self.v_proj.weight.t(),
376
+ bias=self.v_bias,
377
+ quant_state=self.v_proj.weight.quant_state,
378
+ )
379
+ elif self.q_proj.weight.dtype == torch.int8:
380
+ import bitsandbytes as bnb
381
+
382
+ def make_state(weight_v):
383
+ state = bnb.MatmulLtState()
384
+ state.threshold = 0
385
+ state.has_fp16_weights = weight_v.has_fp16_weights
386
+ state.memory_efficient_backward = False
387
+ state.CB = weight_v.CB
388
+ state.SCB = weight_v.SCB
389
+ return state
390
+
391
+ q = bnb.matmul(
392
+ x,
393
+ self.q_proj.weight,
394
+ bias=self.q_bias,
395
+ state=make_state(self.q_proj.weight),
396
+ )
397
+ k = bnb.matmul(
398
+ x,
399
+ self.k_proj.weight,
400
+ bias=None,
401
+ state=make_state(self.k_proj.weight),
402
+ )
403
+ v = bnb.matmul(
404
+ x,
405
+ self.v_proj.weight,
406
+ bias=self.v_bias,
407
+ state=make_state(self.v_proj.weight),
408
+ )
409
  else:
410
  q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
411
  k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
412
  v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
413
 
414
+ q = q.reshape(B, N, self.num_heads, -1).permute(
415
+ 0, 2, 1, 3
416
+ ) # B, num_heads, N, C
417
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
418
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
419
+ else:
420
 
421
  qkv_bias = None
422
  if self.q_bias is not None:
423
+ qkv_bias = torch.cat(
424
+ (
425
+ self.q_bias,
426
+ torch.zeros_like(self.v_bias, requires_grad=False),
427
+ self.v_bias,
428
+ )
429
+ )
430
+
431
  qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
432
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
433
+ 2, 0, 3, 1, 4
434
+ ) # 3, B, num_heads, N, C
435
  q, k, v = qkv[0], qkv[1], qkv[2]
436
 
437
  if self.rope:
 
445
  k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
446
 
447
  if self.xattn:
448
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
449
  k = k.permute(0, 2, 1, 3)
450
  v = v.permute(0, 2, 1, 3)
451
 
452
  x = xops.memory_efficient_attention(
453
+ q,
454
+ k,
455
+ v,
456
  p=self.xattn_drop,
457
  scale=self.scale,
458
+ )
459
  x = x.reshape(B, N, -1)
460
  x = self.inner_attn_ln(x)
461
  x = self.proj(x)
462
  x = self.proj_drop(x)
463
  else:
464
  q = q * self.scale
465
+ attn = q @ k.transpose(-2, -1)
466
 
467
  if self.relative_position_bias_table is not None:
468
+ relative_position_bias = self.relative_position_bias_table[
469
+ self.relative_position_index.view(-1)
470
+ ].view(
471
+ self.window_size[0] * self.window_size[1] + 1,
472
+ self.window_size[0] * self.window_size[1] + 1,
473
+ -1,
474
+ ) # Wh*Ww,Wh*Ww,nH
475
+ relative_position_bias = relative_position_bias.permute(
476
+ 2, 0, 1
477
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
478
  attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
479
 
480
  if rel_pos_bias is not None:
 
483
  if attn_mask is not None:
484
  attn_mask = attn_mask.bool()
485
  attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
486
+
487
  attn = attn.softmax(dim=-1)
488
  attn = self.attn_drop(attn)
489
 
 
496
 
497
  class Block(nn.Module):
498
 
499
+ def __init__(
500
+ self,
501
+ dim,
502
+ num_heads,
503
+ mlp_ratio=4.0,
504
+ qkv_bias=False,
505
+ qk_scale=None,
506
+ drop=0.0,
507
+ attn_drop=0.0,
508
+ drop_path=0.0,
509
+ init_values=None,
510
+ act_layer=nn.GELU,
511
+ norm_layer=nn.LayerNorm,
512
+ window_size=None,
513
+ attn_head_dim=None,
514
+ xattn=False,
515
+ rope=None,
516
+ postnorm=False,
517
+ subln=False,
518
+ naiveswiglu=False,
519
+ ):
520
  super().__init__()
521
  self.norm1 = norm_layer(dim)
522
  self.attn = Attention(
523
+ dim,
524
+ num_heads=num_heads,
525
+ qkv_bias=qkv_bias,
526
+ qk_scale=qk_scale,
527
+ attn_drop=attn_drop,
528
+ proj_drop=drop,
529
+ window_size=window_size,
530
+ attn_head_dim=attn_head_dim,
531
+ xattn=xattn,
532
+ rope=rope,
533
+ subln=subln,
534
+ norm_layer=norm_layer,
535
+ )
536
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
537
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
538
  self.norm2 = norm_layer(dim)
539
  mlp_hidden_dim = int(dim * mlp_ratio)
540
 
541
  if naiveswiglu:
542
  self.mlp = SwiGLU(
543
+ in_features=dim,
544
+ hidden_features=mlp_hidden_dim,
545
  subln=subln,
546
  norm_layer=norm_layer,
547
  )
548
  else:
549
  self.mlp = Mlp(
550
+ in_features=dim,
551
+ hidden_features=mlp_hidden_dim,
552
  act_layer=act_layer,
553
  subln=subln,
554
+ drop=drop,
555
  )
556
 
557
  if init_values is not None and init_values > 0:
558
+ self.gamma_1 = nn.Parameter(
559
+ init_values * torch.ones((dim)), requires_grad=True
560
+ )
561
+ self.gamma_2 = nn.Parameter(
562
+ init_values * torch.ones((dim)), requires_grad=True
563
+ )
564
  else:
565
  self.gamma_1, self.gamma_2 = None, None
566
 
 
569
  def forward(self, x, rel_pos_bias=None, attn_mask=None):
570
  if self.gamma_1 is None:
571
  if self.postnorm:
572
+ x = x + self.drop_path(
573
+ self.norm1(
574
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
575
+ )
576
+ )
577
  x = x + self.drop_path(self.norm2(self.mlp(x)))
578
  else:
579
+ x = x + self.drop_path(
580
+ self.attn(
581
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
582
+ )
583
+ )
584
  x = x + self.drop_path(self.mlp(self.norm2(x)))
585
  else:
586
  if self.postnorm:
587
+ x = x + self.drop_path(
588
+ self.gamma_1
589
+ * self.norm1(
590
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
591
+ )
592
+ )
593
  x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
594
  else:
595
+ x = x + self.drop_path(
596
+ self.gamma_1
597
+ * self.attn(
598
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
599
+ )
600
+ )
601
  x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
602
  return x
603
 
604
 
605
  class PatchEmbed(nn.Module):
606
+ """Image to Patch Embedding"""
607
+
608
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
609
  super().__init__()
610
  img_size = to_2tuple(img_size)
 
615
  self.patch_size = patch_size
616
  self.num_patches = num_patches
617
 
618
+ self.proj = nn.Conv2d(
619
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
620
+ )
621
 
622
  def forward(self, x, **kwargs):
623
  B, C, H, W = x.shape
624
  # FIXME look at relaxing size constraints
625
+ assert (
626
+ H == self.img_size[0] and W == self.img_size[1]
627
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
628
  x = self.proj(x).flatten(2).transpose(1, 2)
629
  return x
630
 
 
634
  def __init__(self, window_size, num_heads):
635
  super().__init__()
636
  self.window_size = window_size
637
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
638
+ 2 * window_size[1] - 1
639
+ ) + 3
640
  self.relative_position_bias_table = nn.Parameter(
641
+ torch.zeros(self.num_relative_distance, num_heads)
642
+ ) # 2*Wh-1 * 2*Ww-1, nH
643
  # cls to token & token 2 cls & cls to cls
644
 
645
  # get pair-wise relative position index for each token inside the window
 
647
  coords_w = torch.arange(window_size[1])
648
  coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
649
  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
650
+ relative_coords = (
651
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
652
+ ) # 2, Wh*Ww, Wh*Ww
653
+ relative_coords = relative_coords.permute(
654
+ 1, 2, 0
655
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
656
  relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
657
  relative_coords[:, :, 1] += window_size[1] - 1
658
  relative_coords[:, :, 0] *= 2 * window_size[1] - 1
659
+ relative_position_index = torch.zeros(
660
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
661
+ )
662
  relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
663
  relative_position_index[0, 0:] = self.num_relative_distance - 3
664
  relative_position_index[0:, 0] = self.num_relative_distance - 2
 
667
  self.register_buffer("relative_position_index", relative_position_index)
668
 
669
  def forward(self):
670
+ relative_position_bias = self.relative_position_bias_table[
671
+ self.relative_position_index.view(-1)
672
+ ].view(
673
+ self.window_size[0] * self.window_size[1] + 1,
674
+ self.window_size[0] * self.window_size[1] + 1,
675
+ -1,
676
+ ) # Wh*Ww,Wh*Ww,nH
677
  return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
678
 
679
 
680
  class EVAVisionTransformer(nn.Module):
681
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
682
+
683
+ def __init__(
684
+ self,
685
+ img_size=224,
686
+ patch_size=16,
687
+ in_chans=3,
688
+ num_classes=1000,
689
+ embed_dim=768,
690
+ depth=12,
691
+ num_heads=12,
692
+ mlp_ratio=4.0,
693
+ qkv_bias=False,
694
+ qk_scale=None,
695
+ drop_rate=0.0,
696
+ attn_drop_rate=0.0,
697
+ drop_path_rate=0.0,
698
+ norm_layer=nn.LayerNorm,
699
+ init_values=None,
700
+ patch_dropout=0.0,
701
+ use_abs_pos_emb=True,
702
+ use_rel_pos_bias=False,
703
+ use_shared_rel_pos_bias=False,
704
+ rope=False,
705
+ use_mean_pooling=True,
706
+ init_scale=0.001,
707
+ grad_checkpointing=False,
708
+ xattn=False,
709
+ postnorm=False,
710
+ pt_hw_seq_len=16,
711
+ intp_freq=False,
712
+ naiveswiglu=False,
713
+ subln=False,
714
+ ):
715
  super().__init__()
716
  self.image_size = img_size
717
  self.num_classes = num_classes
718
+ self.num_features = self.embed_dim = (
719
+ embed_dim # num_features for consistency with other models
720
+ )
721
 
722
  self.patch_embed = PatchEmbed(
723
+ img_size=img_size,
724
+ patch_size=patch_size,
725
+ in_chans=in_chans,
726
+ embed_dim=embed_dim,
727
+ )
728
  num_patches = self.patch_embed.num_patches
729
 
730
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
 
736
  self.pos_drop = nn.Dropout(p=drop_rate)
737
 
738
  if use_shared_rel_pos_bias:
739
+ self.rel_pos_bias = RelativePositionBias(
740
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
741
+ )
742
  else:
743
  self.rel_pos_bias = None
744
 
 
751
  ft_seq_len=hw_seq_len if intp_freq else None,
752
  # patch_dropout=patch_dropout
753
  )
754
+ else:
755
  self.rope = None
756
 
757
  self.naiveswiglu = naiveswiglu
758
 
759
+ dpr = [
760
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
761
+ ] # stochastic depth decay rule
762
  self.use_rel_pos_bias = use_rel_pos_bias
763
+ self.blocks = nn.ModuleList(
764
+ [
765
+ Block(
766
+ dim=embed_dim,
767
+ num_heads=num_heads,
768
+ mlp_ratio=mlp_ratio,
769
+ qkv_bias=qkv_bias,
770
+ qk_scale=qk_scale,
771
+ drop=drop_rate,
772
+ attn_drop=attn_drop_rate,
773
+ drop_path=dpr[i],
774
+ norm_layer=norm_layer,
775
+ init_values=init_values,
776
+ window_size=(
777
+ self.patch_embed.patch_shape if use_rel_pos_bias else None
778
+ ),
779
+ xattn=xattn,
780
+ rope=self.rope,
781
+ postnorm=postnorm,
782
+ subln=subln,
783
+ naiveswiglu=naiveswiglu,
784
+ )
785
+ for i in range(depth)
786
+ ]
787
+ )
788
  self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
789
  self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
790
+ self.head = (
791
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
792
+ )
793
 
794
  if self.pos_embed is not None:
795
+ trunc_normal_(self.pos_embed, std=0.02)
796
 
797
+ trunc_normal_(self.cls_token, std=0.02)
798
  # trunc_normal_(self.mask_token, std=.02)
799
 
800
  self.apply(self._init_weights)
801
  self.fix_init_weight()
802
 
803
  if isinstance(self.head, nn.Linear):
804
+ trunc_normal_(self.head.weight, std=0.02)
805
  self.head.weight.data.mul_(init_scale)
806
  self.head.bias.data.mul_(init_scale)
807
 
808
  # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
809
+ self.patch_dropout = (
810
+ PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
811
+ )
812
 
813
  self.grad_checkpointing = grad_checkpointing
814
 
 
828
 
829
  def _init_weights(self, m):
830
  if isinstance(m, nn.Linear):
831
+ trunc_normal_(m.weight, std=0.02)
832
  if m.bias is not None:
833
  nn.init.constant_(m.bias, 0)
834
  elif isinstance(m, nn.LayerNorm):
 
837
 
838
  def get_num_layers(self):
839
  return len(self.blocks)
840
+
841
  def lock(self, unlocked_groups=0, freeze_bn_stats=False):
842
+ assert (
843
+ unlocked_groups == 0
844
+ ), "partial locking not currently supported for this model"
845
  for param in self.parameters():
846
  param.requires_grad = False
847
 
 
851
 
852
  @torch.jit.ignore
853
  def no_weight_decay(self):
854
+ return {"pos_embed", "cls_token"}
855
 
856
  def get_classifier(self):
857
  return self.head
858
 
859
+ def reset_classifier(self, num_classes, global_pool=""):
860
  self.num_classes = num_classes
861
+ self.head = (
862
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
863
+ )
864
 
865
  def forward_features(self, x, return_all_features=False):
866
+
867
  x = self.patch_embed(x)
868
  batch_size, seq_len, _ = x.size()
869
 
870
+ cls_tokens = self.cls_token.expand(
871
+ batch_size, -1, -1
872
+ ) # stole cls_tokens impl from Phil Wang, thanks
873
  x = torch.cat((cls_tokens, x), dim=1)
874
  if self.pos_embed is not None:
875
  x = x + self.pos_embed
876
  x = self.pos_drop(x)
877
 
878
  # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
879
+ if os.getenv("RoPE") == "1":
880
  if self.training and not isinstance(self.patch_dropout, nn.Identity):
881
  x, patch_indices_keep = self.patch_dropout(x)
882
+ self.rope.forward = partial(
883
+ self.rope.forward, patch_indices_keep=patch_indices_keep
884
+ )
885
  else:
886
  self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
887
  x = self.patch_dropout(x)
 
890
 
891
  rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
892
  for i, blk in enumerate(self.blocks):
893
+ if i == len(self.blocks) - 1:
894
  continue
895
  if self.grad_checkpointing:
896
  x = checkpoint(blk, x, (rel_pos_bias,))
 
912
  x = self.head(x)
913
  return x
914
 
915
+
916
  class LayerNorm(nn.LayerNorm):
917
  """Subclass torch's LayerNorm (with cast back to input dtype)."""
918
 
 
921
  x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
922
  return x.to(orig_type)
923
 
924
+
925
  try:
926
  from apex.normalization import FusedLayerNorm
927
  except:
 
938
  patch_size: int = 16
939
  image_size: Union[Tuple[int, int], int] = 224
940
  ls_init_value: Optional[float] = None # layer scale initial value
941
+ patch_dropout: float = (
942
+ 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
943
+ )
944
+ global_average_pool: bool = (
945
+ False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
946
+ )
947
  drop_path_rate: Optional[float] = None # drop path rate
948
+ timm_model_name: str = (
949
+ None # a valid model name overrides layers, width, patch_size
950
+ )
951
+ timm_model_pretrained: bool = (
952
+ False # use (imagenet) pretrained weights for named model
953
+ )
954
+ timm_pool: str = (
955
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
956
+ )
957
+ timm_proj: str = (
958
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
959
+ )
960
  timm_proj_bias: bool = False # enable bias final projection
961
+ eva_model_name: str = (
962
+ None # a valid eva model name overrides layers, width, patch_size
963
+ )
964
  qkv_bias: bool = True
965
  fusedLN: bool = False
966
  xattn: bool = False
967
  postnorm: bool = False
968
  rope: bool = False
969
+ pt_hw_seq_len: int = 16 # 224/14
970
  intp_freq: bool = False
971
  naiveswiglu: bool = False
972
  subln: bool = False
973
 
974
 
975
+ def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg):
 
 
 
976
  if isinstance(vision_cfg, dict):
977
  vision_cfg = CLIPVisionCfg(**vision_cfg)
978
 
 
983
  img_size=vision_cfg.image_size,
984
  patch_size=vision_cfg.patch_size,
985
  num_classes=embed_dim,
986
+ use_mean_pooling=vision_cfg.global_average_pool, # False
987
  init_values=vision_cfg.ls_init_value,
988
  patch_dropout=vision_cfg.patch_dropout,
989
  embed_dim=vision_cfg.width,
 
992
  mlp_ratio=vision_cfg.mlp_ratio,
993
  qkv_bias=vision_cfg.qkv_bias,
994
  drop_path_rate=vision_cfg.drop_path_rate,
995
+ norm_layer=(
996
+ partial(FusedLayerNorm, eps=1e-6)
997
+ if vision_cfg.fusedLN
998
+ else partial(norm_layer, eps=1e-6)
999
+ ),
1000
  xattn=vision_cfg.xattn,
1001
  rope=vision_cfg.rope,
1002
  postnorm=vision_cfg.postnorm,
1003
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
1004
+ intp_freq=vision_cfg.intp_freq,
1005
+ naiveswiglu=vision_cfg.naiveswiglu,
1006
+ subln=vision_cfg.subln,
1007
  )
1008
 
1009
  return visual
1010
 
1011
+
1012
  class Eva2LargeEncoder(nn.Module):
1013
  def __init__(self, image_size=224):
1014
  super(Eva2LargeEncoder, self).__init__()
 
1029
  "pt_hw_seq_len": 16,
1030
  "intp_freq": True,
1031
  "naiveswiglu": True,
1032
+ "subln": True,
1033
+ },
1034
  }
1035
+ self.config["vision_cfg"]["image_size"] = image_size
1036
+
1037
  import os
 
 
1038
 
1039
+ os.environ["delRoPE"] = (
1040
+ "1" # to avoid error in rope params when changing image size
1041
+ )
1042
+ self.model = _build_vision_tower(**self.config)
1043
 
1044
  def forward(self, images):
1045
  encode = self.model(images, return_all_features=True)[:, 1:, :]
1046
  return encode
1047
 
1048
+
1049
  class CrossVisionModel(nn.Module):
1050
  def __init__(self, config):
1051
  super().__init__()
1052
  self.vit = Eva2LargeEncoder(image_size=config.cross_image_size)
1053
+ self.pos_embed = nn.Parameter(
1054
+ torch.zeros(
1055
+ (
1056
+ self.vit.config["vision_cfg"]["image_size"]
1057
+ // self.vit.config["vision_cfg"]["patch_size"]
1058
+ )
1059
+ ** 2,
1060
+ self.vit.config["vision_cfg"]["width"],
1061
+ )
1062
+ )
1063
 
1064
  def forward(self, images):
1065
  enc = self.vit(images)
1066
+ return enc + self.pos_embed.to(enc.device).unsqueeze(0)
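
Taken together, the cross-vision path above resizes the screenshot to cross_image_size, encodes it with the EVA2 backbone (dropping the CLS token), and adds a learned positional embedding; the result is consumed by CrossAttention in modeling_cogagent.py below. A hedged usage sketch, with a SimpleNamespace stub standing in for CogAgentConfig:

import torch
from types import SimpleNamespace

config = SimpleNamespace(cross_image_size=1120)   # stub; the real value comes from CogAgentConfig
cross_vision = CrossVisionModel(config)
hi_res = torch.randn(1, 3, 1120, 1120)            # one pre-processed high-resolution image
encoder_outputs = cross_vision(hi_res)            # roughly [1, (1120 // patch_size) ** 2, width]
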
modeling_cogagent.py CHANGED
@@ -1,24 +1,28 @@
1
  """largely copy from llama and adapt for CogAgent"""
2
- import warnings
3
- from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
 
5
  import math
 
 
 
6
  import torch
 
7
  from torch import nn
 
 
8
  from torch.nn import CrossEntropyLoss
 
9
  from torchvision import transforms
10
- from einops import rearrange
11
-
12
  from transformers import PreTrainedModel, PreTrainedTokenizer
13
- from transformers.utils.logging import get_logger
14
  from transformers.activations import ACT2FN
15
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
 
 
16
 
17
  from .configuration_cogagent import CogAgentConfig
18
- # from .util import FastRotaryEmbedding
19
- from torch.nn import functional as F
20
- from .visual import EVA2CLIPModel
21
  from .cross_visual import CrossVisionModel
 
22
 
23
  if TYPE_CHECKING:
24
  from transformers.utils import ModelOutput
@@ -31,7 +35,10 @@ VISION_TOKEN_TYPE = 1
31
 
32
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
33
  def _make_causal_mask(
34
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
 
 
35
  ):
36
  """
37
  Make causal mask used for bi-directional self-attention.
@@ -43,8 +50,18 @@ def _make_causal_mask(
43
  mask = mask.to(dtype)
44
 
45
  if past_key_values_length > 0:
46
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
47
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  # Copied from transformers.models.bart.modeling_bart._expand_mask
@@ -59,7 +76,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
59
 
60
  inverted_mask = 1.0 - expanded_mask
61
 
62
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
63
 
64
 
65
  class RMSNorm(nn.Module):
@@ -91,9 +110,13 @@ class MLP(nn.Module):
91
  return down_proj
92
 
93
 
94
- def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
 
 
95
  vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
96
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
 
 
97
  language_token_mask = ~vision_token_mask
98
  return vision_token_mask, language_token_mask
99
 
@@ -104,47 +127,71 @@ class VisionExpertMLP(nn.Module):
104
  self.language_mlp = MLP(config)
105
  self.vision_mlp = MLP(config)
106
 
107
- def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
108
- output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
 
 
 
 
 
 
109
  vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
110
  output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
111
- output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
 
 
112
  return output
113
 
114
 
115
  def attention_fn(
116
- query_layer: "torch.tensor(B, H, L, HD)",
117
- key_layer: "torch.tensor(B, H, L, HD)",
118
- value_layer: "torch.tensor(B, H, L, HD)",
119
- attention_mask: "torch.tensor(B, H, L, HD)",
120
- *,
121
- scaling_attention_score: bool = True,
122
- attention_dropout: nn.Module = None
123
  ):
124
- attention_mask_bool = (attention_mask == 0)
125
- is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
 
 
 
126
  is_full = (attention_mask_bool > 0).all()
127
- if not (int(torch.__version__.split('.')[0]) >= 2):
128
  warnings.warn("It's recommended to use torch2.0 or higher.")
129
- if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
130
- dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
 
 
 
 
 
 
 
 
131
  return torch.nn.functional.scaled_dot_product_attention(
132
- query_layer, key_layer, value_layer,
 
 
133
  attn_mask=None,
134
  dropout_p=dropout_p,
135
- is_causal=not is_full
136
  )
137
  else:
138
  if scaling_attention_score:
139
  query_layer = query_layer / math.sqrt(query_layer.shape[-1])
140
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
141
  attention_scores = attention_scores + attention_mask
142
- attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
 
 
143
  if attention_dropout is not None:
144
  attention_scores = attention_dropout(attention_scores)
145
  context_layer = torch.matmul(attention_scores, value_layer)
146
  return context_layer
147
 
 
148
  class RotaryEmbedding(torch.nn.Module):
149
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
150
  super().__init__()
@@ -158,19 +205,24 @@ class RotaryEmbedding(torch.nn.Module):
158
 
159
  def _compute_inv_freq(self, device=None):
160
  return 1.0 / (
161
- self.base
162
- ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
163
  )
164
 
165
  def _set_cos_sin_cache(self, seq_len, device, dtype):
166
  self.max_seq_len_cached = seq_len
167
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
168
 
169
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
170
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
171
  emb = torch.cat((freqs, freqs), dim=-1)
172
- self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
173
- self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
 
 
 
 
174
 
175
  def forward(self, x, seq_len):
176
  # x: [bs, num_attention_heads, seq_len, head_size]
@@ -184,17 +236,19 @@ class RotaryEmbedding(torch.nn.Module):
184
 
185
 
186
  def rotate_half(x):
187
- x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
188
  return torch.cat((-x2, x1), dim=x1.ndim - 1)
189
 
190
 
191
  def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
192
  # batch_size, num_head, seq_len, hidden_size
193
- cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
194
- F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
 
195
  q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
196
  return q, k
197
 
 
198
  class VisionExpertAttention(nn.Module):
199
  def __init__(self, config):
200
  super().__init__()
@@ -205,10 +259,18 @@ class VisionExpertAttention(nn.Module):
205
  self.max_position_embeddings = config.max_position_embeddings
206
 
207
  self.rotary_emb = RotaryEmbedding(self.head_dim)
208
- self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
209
- self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
210
- self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
211
- self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 
 
 
 
 
 
 
 
212
 
213
  def _transpose_for_scores(self, tensor):
214
  """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
@@ -217,25 +279,33 @@ class VisionExpertAttention(nn.Module):
217
  return tensor.permute(0, 2, 1, 3)
218
 
219
  def forward(
220
- self,
221
- hidden_states: torch.Tensor,
222
- token_type_ids: torch.LongTensor,
223
- position_ids: torch.LongTensor,
224
- attention_mask: Optional[torch.Tensor] = None,
225
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
226
- output_attentions: bool = False,
227
- use_cache: bool = False,
228
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
229
  bsz, q_len, _ = hidden_states.size()
230
  vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
231
 
232
  shape = list(hidden_states.shape)
233
  shape[-1] = shape[-1] * 3
234
- mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
235
- mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
236
- mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
 
 
 
 
 
 
237
 
238
- query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
 
 
239
  query_states = self._transpose_for_scores(query_states) # B, H, L, HD
240
  key_states = self._transpose_for_scores(key_states) # B, H, L, HD
241
  value_states = self._transpose_for_scores(value_states) # B, H, L, HD
@@ -245,7 +315,9 @@ class VisionExpertAttention(nn.Module):
245
  kv_seq_len += past_key_value[0].shape[-2]
246
 
247
  cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
248
- query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
 
 
249
 
250
  if past_key_value is not None:
251
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -254,24 +326,40 @@ class VisionExpertAttention(nn.Module):
254
  past_key_value = (key_states, value_states) if use_cache else None
255
 
256
  context_layer = attention_fn(
257
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
258
- scaling_attention_score=True, attention_dropout=None)
 
 
 
 
 
259
  if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
260
  raise ValueError(
261
  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
262
  f" {context_layer.size()}"
263
  )
264
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
 
 
 
 
265
 
266
- attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
267
- attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
268
- attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
 
 
 
 
 
 
269
 
270
  if output_attentions:
271
  warnings.warn("output_attentions is not implemented.")
272
 
273
  return attn_output, None, past_key_value
274
 
 
275
  class CrossAttention(nn.Module):
276
  def __init__(self, config):
277
  super().__init__()
@@ -284,9 +372,15 @@ class CrossAttention(nn.Module):
284
  self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
285
  self.max_position_embeddings = config.max_position_embeddings
286
 
287
- self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
288
- self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
289
- self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
 
 
 
 
 
 
290
 
291
  def _transpose_for_scores(self, tensor):
292
  """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
@@ -295,13 +389,13 @@ class CrossAttention(nn.Module):
295
  return tensor.permute(0, 2, 1, 3)
296
 
297
  def forward(
298
- self,
299
- hidden_states: torch.Tensor,
300
- encoder_outputs: torch.LongTensor,
301
- attention_mask: Optional[torch.Tensor] = None,
302
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
303
- output_attentions: bool = False,
304
- use_cache: bool = False,
305
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
306
  bsz, q_len, _ = hidden_states.size()
307
 
@@ -311,7 +405,9 @@ class CrossAttention(nn.Module):
311
  mixed_query_layer = self.query(hidden_states)
312
  if past_key_value is None:
313
  mixed_x_layer = self.key_value(encoder_outputs)
314
- mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
 
 
315
  key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
316
  value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
317
  else:
@@ -322,14 +418,23 @@ class CrossAttention(nn.Module):
322
  past_key_value = (key_states, value_states) if use_cache else None
323
 
324
  context_layer = attention_fn(
325
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
326
- scaling_attention_score=True, attention_dropout=None)
 
 
 
 
 
327
  if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
328
  raise ValueError(
329
  f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
330
  f" {context_layer.size()}"
331
  )
332
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
 
 
 
 
333
 
334
  attn_output = self.dense(context_layer)
335
 
@@ -338,6 +443,7 @@ class CrossAttention(nn.Module):
338
 
339
  return attn_output, None, past_key_value
340
 
 
341
  class CogAgentDecoderLayer(nn.Module):
342
  def __init__(self, config):
343
  super().__init__()
@@ -346,21 +452,27 @@ class CogAgentDecoderLayer(nn.Module):
346
  self.cross_attn = CrossAttention(config=config)
347
  self.mlp = VisionExpertMLP(config)
348
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
349
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
350
- self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
 
 
351
 
352
  def forward(
353
- self,
354
- hidden_states: torch.Tensor,
355
- encoder_outputs: torch.Tensor,
356
- token_type_ids: torch.LongTensor,
357
- position_ids: torch.LongTensor,
358
- attention_mask: Optional[torch.Tensor] = None,
359
- cross_attention_mask: Optional[torch.Tensor] = None,
360
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
361
- output_attentions: Optional[bool] = False,
362
- use_cache: Optional[bool] = False,
363
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
364
  residual = hidden_states
365
 
366
  hidden_states = self.input_layernorm(hidden_states)
@@ -379,14 +491,18 @@ class CogAgentDecoderLayer(nn.Module):
379
 
380
  cross_input = self.post_cross_attention_layernorm(hidden_states)
381
  # Fully Connected
382
- attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
383
- hidden_states=cross_input,
384
- encoder_outputs=encoder_outputs,
385
- attention_mask=cross_attention_mask,
386
- past_key_value=past_key_value[-2:] if past_key_value is not None else None,
387
- output_attentions=output_attentions,
388
- use_cache=use_cache,
 
 
 
389
  )
 
390
  hidden_states = hidden_states + attention_output
391
  mlp_input = self.post_attention_layernorm(hidden_states)
392
  mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
@@ -398,7 +514,7 @@ class CogAgentDecoderLayer(nn.Module):
398
  outputs += (self_attn_weights,)
399
 
400
  if use_cache:
401
- outputs += (present_key_value+present_cross_key_value,)
402
 
403
  return outputs # type: ignore
404
 
@@ -407,7 +523,7 @@ class CogAgentPreTrainedModel(PreTrainedModel):
407
  config_class = CogAgentConfig
408
  base_model_prefix = "model"
409
  supports_gradient_checkpointing = False
410
- _no_split_modules = ["CogAgentDecoderLayer", 'TransformerLayer', 'Block']
411
  _skip_keys_device_placement = "past_key_values"
412
 
413
  def _init_weights(self, module):
@@ -431,7 +547,10 @@ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
431
  return True
432
 
433
 
434
- def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
 
 
 
435
  if attention_mask is not None:
436
  tmp = x.clone()
437
  tmp[~(attention_mask.bool())] = -1
@@ -439,14 +558,20 @@ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["to
439
  tmp = x.clone()
440
  # image boi eoi token as LANGUAGE_TOKEN_TYPE
441
  is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
442
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
443
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
444
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
445
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
 
 
 
 
446
  tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
447
  # final position ids
448
  y = torch.zeros_like(x, dtype=torch.long)
449
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
 
 
450
  y = y.cumsum(dim=-1)
451
  return y
452
 
@@ -457,8 +582,12 @@ class CogAgentModel(CogAgentPreTrainedModel):
457
  self.padding_idx = config.pad_token_id
458
  self.vocab_size = config.vocab_size
459
 
460
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
461
- self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
 
 
462
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
463
 
464
  self.vision = EVA2CLIPModel(config)
@@ -491,22 +620,22 @@ class CogAgentModel(CogAgentPreTrainedModel):
491
  images = torch.stack(images)
492
  encoder_outputs = self.cross_vision(images)
493
  return encoder_outputs
494
-
495
  def forward(
496
- self,
497
- input_ids: torch.LongTensor = None,
498
- images: List[List[torch.Tensor]] = None,
499
- cross_images: List[List[torch.Tensor]] = None,
500
- token_type_ids: Optional[torch.LongTensor] = None,
501
- attention_mask: Optional[torch.Tensor] = None,
502
- cross_attention_mask: Optional[torch.Tensor] = None,
503
- position_ids: Optional[torch.LongTensor] = None,
504
- past_key_values: Optional[List[torch.FloatTensor]] = None,
505
- inputs_embeds: Optional[torch.FloatTensor] = None,
506
- use_cache: Optional[bool] = None,
507
- output_attentions: Optional[bool] = None,
508
- output_hidden_states: Optional[bool] = None,
509
- return_dict: Optional[bool] = None,
510
  ) -> Union[Tuple, BaseModelOutputWithPast]:
511
  """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
512
 
@@ -515,20 +644,35 @@ class CogAgentModel(CogAgentPreTrainedModel):
515
  # generate mode with past_key_values. the image features are already mapped
516
  else:
517
  # not allow for inputs_embeds, because we want to process image feature
518
- assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
 
 
519
  if not is_empty(images): # multi-modality
520
- assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
 
 
521
  assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
522
  inputs_embeds = self.embed_tokens(input_ids)
523
  images_features = self.encode_images(images)
524
  encoder_outputs = self.encode_cross_images(cross_images)
525
- images_features = rearrange(images_features, 'b n d -> (b n) d')
526
- images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
527
- inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
 
 
 
 
528
  else: # single-modality
529
  if token_type_ids is None:
530
- token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
531
- assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
 
 
 
 
 
 
 
532
  inputs_embeds = self.embed_tokens(input_ids)
533
  encoder_outputs = None
534
 
@@ -552,38 +696,50 @@ class CogAgentModel(CogAgentPreTrainedModel):
552
  )
553
 
554
  def llm_forward(
555
- self,
556
- input_ids: torch.LongTensor = None,
557
- encoder_outputs: torch.LongTensor = None,
558
- token_type_ids: torch.LongTensor = None,
559
- attention_mask: Optional[torch.Tensor] = None,
560
- cross_attention_mask: Optional[torch.Tensor] = None,
561
- position_ids: Optional[torch.LongTensor] = None,
562
- past_key_values: Optional[List[torch.FloatTensor]] = None,
563
- inputs_embeds: Optional[torch.FloatTensor] = None,
564
- use_cache: Optional[bool] = None,
565
- output_attentions: Optional[bool] = None,
566
- output_hidden_states: Optional[bool] = None,
567
- return_dict: Optional[bool] = None,
568
  ) -> Union[Tuple, BaseModelOutputWithPast]:
569
  """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
570
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
571
  output_hidden_states = (
572
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
573
  )
574
  use_cache = use_cache if use_cache is not None else self.config.use_cache
575
 
576
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
577
 
578
  # retrieve input_ids and inputs_embeds
579
  if input_ids is not None and inputs_embeds is not None:
580
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
 
 
581
  elif input_ids is not None:
582
  batch_size, seq_length = input_ids.shape
583
  elif inputs_embeds is not None:
584
  batch_size, seq_length, _ = inputs_embeds.shape
585
  else:
586
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
 
587
 
588
  seq_length_with_past = seq_length
589
  past_key_values_length = 0
@@ -595,7 +751,10 @@ class CogAgentModel(CogAgentPreTrainedModel):
595
  if position_ids is None:
596
  device = input_ids.device if input_ids is not None else inputs_embeds.device
597
  position_ids = torch.arange(
598
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
 
 
 
599
  )
600
  position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
601
  else:
@@ -606,14 +765,19 @@ class CogAgentModel(CogAgentPreTrainedModel):
606
  # embed positions
607
  if attention_mask is None:
608
  attention_mask = torch.ones(
609
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
610
  )
611
  if cross_attention_mask is None:
612
  cross_attention_mask = torch.ones(
613
  (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
614
  )
615
  attention_mask = self._prepare_decoder_attention_mask(
616
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
 
 
 
617
  )
618
 
619
  hidden_states = inputs_embeds
@@ -627,7 +791,9 @@ class CogAgentModel(CogAgentPreTrainedModel):
627
  if output_hidden_states:
628
  all_hidden_states += (hidden_states,)
629
 
630
- past_key_value = past_key_values[idx] if past_key_values is not None else None
 
 
631
  layer_outputs = decoder_layer(
632
  hidden_states,
633
  encoder_outputs=encoder_outputs,
@@ -655,7 +821,11 @@ class CogAgentModel(CogAgentPreTrainedModel):
655
 
656
  next_cache = next_decoder_cache if use_cache else None
657
  if not return_dict:
658
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
659
  return BaseModelOutputWithPast(
660
  last_hidden_state=hidden_states,
661
  past_key_values=next_cache,
@@ -671,7 +841,9 @@ class CogAgentModel(CogAgentPreTrainedModel):
671
 
672
  # noinspection PyMethodMayBeStatic
673
  # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
674
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
 
 
675
  # create causal mask
676
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
677
  combined_attention_mask = None
@@ -685,15 +857,18 @@ class CogAgentModel(CogAgentPreTrainedModel):
685
 
686
  if attention_mask is not None:
687
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
688
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
689
- inputs_embeds.device
690
- )
691
  combined_attention_mask = (
692
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
 
 
693
  )
694
 
695
  return combined_attention_mask
696
 
 
697
  def vqa_history_to_prompt(history, query):
698
  # Only support single round chat in vqa mode
699
  prompt = "<EOI>Question: "
@@ -702,6 +877,7 @@ def vqa_history_to_prompt(history, query):
702
  prompt += query + " Short answer:"
703
  return prompt
704
 
 
705
  def chat_old_history_to_prompt(history, query):
706
  prompt = "<EOI>Question: "
707
  for i, (old_query, response) in enumerate(history):
@@ -709,6 +885,7 @@ def chat_old_history_to_prompt(history, query):
709
  prompt += query + " Answer:"
710
  return prompt
711
 
 
712
  def chat_history_to_prompt(history, query):
713
  prompt = " [INST] "
714
  for i, (old_query, response) in enumerate(history):
@@ -726,7 +903,7 @@ _history_to_prompt = {
726
  "base": base_history_to_prompt,
727
  "chat": chat_history_to_prompt,
728
  "chat_old": chat_old_history_to_prompt,
729
- "vqa": vqa_history_to_prompt
730
  }
731
 
732
 
@@ -761,26 +938,34 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
761
  return self.model
762
 
763
  def forward(
764
- self,
765
- input_ids: torch.LongTensor = None,
766
- images: List[List[torch.Tensor]] = None,
767
- cross_images: List[List[torch.Tensor]] = None,
768
- token_type_ids: Optional[torch.LongTensor] = None,
769
- attention_mask: Optional[torch.Tensor] = None,
770
- position_ids: Optional[torch.LongTensor] = None,
771
- past_key_values: Optional[List[torch.FloatTensor]] = None,
772
- inputs_embeds: Optional[torch.FloatTensor] = None,
773
- use_cache: Optional[bool] = None,
774
- output_attentions: Optional[bool] = None,
775
- output_hidden_states: Optional[bool] = None,
776
- return_dict: Optional[bool] = None,
777
- labels: Optional[torch.LongTensor] = None,
778
  ) -> Union[Tuple, CausalLMOutputWithPast]:
779
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
780
  output_hidden_states = (
781
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
782
  )
783
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
784
 
785
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
786
  outputs = self.model(
@@ -828,15 +1013,23 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
828
  )
829
 
830
  def _prepare_attention_mask_for_generation(
831
- self,
832
- inputs: torch.Tensor,
833
- pad_token_id: Optional[int],
834
- eos_token_id: Optional[Union[int, List[int]]],
835
  ) -> torch.LongTensor:
836
  return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
837
 
838
  def prepare_inputs_for_generation(
839
- self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
 
 
 
840
  ):
841
  # build position_ids if needed
842
  position_ids = kwargs.get("position_ids", None)
@@ -868,11 +1061,11 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
868
  return model_inputs
869
 
870
  def _update_model_kwargs_for_generation(
871
- self,
872
- outputs: "ModelOutput",
873
- model_kwargs: Dict[str, Any],
874
- is_encoder_decoder: bool = False,
875
- standardize_cache_format: bool = False,
876
  ) -> Dict[str, Any]:
877
  # update past_key_values
878
  model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -884,22 +1077,40 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
884
  # update token_type_ids with last value
885
  if "token_type_ids" in model_kwargs:
886
  token_type_ids = model_kwargs["token_type_ids"]
887
- new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
888
- model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
 
 
 
 
 
 
 
 
 
889
 
890
  if not is_encoder_decoder:
891
  # update attention mask
892
  if "attention_mask" in model_kwargs:
893
  attention_mask = model_kwargs["attention_mask"]
894
  model_kwargs["attention_mask"] = torch.cat(
895
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
 
 
 
 
896
  )
897
  else:
898
  # update decoder attention mask
899
  if "decoder_attention_mask" in model_kwargs:
900
  decoder_attention_mask = model_kwargs["decoder_attention_mask"]
901
  model_kwargs["decoder_attention_mask"] = torch.cat(
902
- [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
 
 
 
 
 
903
  dim=-1,
904
  )
905
 
@@ -909,22 +1120,25 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
909
  reordered_past = ()
910
  for layer_past in past_key_values:
911
  reordered_past += (
912
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
 
 
 
913
  )
914
  return reordered_past
915
 
916
  def build_conversation_input_ids(
917
- self,
918
- tokenizer: "PreTrainedTokenizer",
919
- *,
920
- query: str,
921
- history: Optional[List[Tuple[str, str]]] = None,
922
- images: Optional[List["PIL.Image"]] = None,
923
- template_version: Optional[Literal["base", "chat", "vqa"]] = None,
924
  ):
925
- image_size: int = self.config.vision_config['image_size']
926
  cross_image_size: int = self.config.cross_image_size
927
- patch_size: int = self.config.vision_config['patch_size']
928
  template_version = template_version or self.config.template_version
929
  assert images is None or len(images) <= 1, "multiple images are not supported yet."
930
  history = history or []
@@ -938,25 +1152,35 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
938
  transform = transforms.Compose(
939
  [
940
  transforms.Resize(
941
- (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
 
942
  ),
943
  transforms.ToTensor(),
944
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
 
 
 
945
  ]
946
  )
947
  images = [transform(ori[0])]
948
  cross_transform = transforms.Compose(
949
  [
950
  transforms.Resize(
951
- (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
 
952
  ),
953
  transforms.ToTensor(),
954
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
 
 
 
955
  ]
956
  )
957
  cross_images = [cross_transform(ori[0])]
958
  # language
959
- vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
 
 
960
  input_ids += [tokenizer.pad_token_id] * vision_token_num
961
  token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
962
  text_ids = tokenizer.encode(text, add_special_tokens=False)
@@ -966,9 +1190,9 @@ class CogAgentForCausalLM(CogAgentPreTrainedModel):
966
  attention_mask = [1] * len(input_ids)
967
 
968
  return {
969
- 'input_ids': torch.tensor(input_ids, dtype=torch.long),
970
- 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
971
- 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
972
- 'images': images,
973
- 'cross_images': cross_images
974
  }
 
1
  """largely copy from llama and adapt for CogAgent"""
 
 
2
 
3
  import math
4
+ import warnings
5
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
6
+
7
  import torch
8
+ from einops import rearrange
9
  from torch import nn
10
+
11
+ # from .util import FastRotaryEmbedding
12
  from torch.nn import CrossEntropyLoss
13
+ from torch.nn import functional as F
14
  from torchvision import transforms
 
 
15
  from transformers import PreTrainedModel, PreTrainedTokenizer
 
16
  from transformers.activations import ACT2FN
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ )
21
+ from transformers.utils.logging import get_logger
22
 
23
  from .configuration_cogagent import CogAgentConfig
 
 
 
24
  from .cross_visual import CrossVisionModel
25
+ from .visual import EVA2CLIPModel
26
 
27
  if TYPE_CHECKING:
28
  from transformers.utils import ModelOutput
 
35
 
36
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
37
  def _make_causal_mask(
38
+ input_ids_shape: torch.Size,
39
+ dtype: torch.dtype,
40
+ device: torch.device,
41
+ past_key_values_length: int = 0,
42
  ):
43
  """
44
  Make causal mask used for bi-directional self-attention.
 
50
  mask = mask.to(dtype)
51
 
52
  if past_key_values_length > 0:
53
+ mask = torch.cat(
54
+ [
55
+ torch.zeros(
56
+ tgt_len, past_key_values_length, dtype=dtype, device=device
57
+ ),
58
+ mask,
59
+ ],
60
+ dim=-1,
61
+ )
62
+ return mask[None, None, :, :].expand(
63
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
64
+ )
65
 
66
 
67
  # Copied from transformers.models.bart.modeling_bart._expand_mask
 
76
 
77
  inverted_mask = 1.0 - expanded_mask
78
 
79
+ return inverted_mask.masked_fill(
80
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
81
+ )
82
 
83
 
84
  class RMSNorm(nn.Module):
 
110
  return down_proj
111
 
112
 
113
+ def get_expert_mask(
114
+ token_type_ids: "torch.LongTensor(B, L)",
115
+ ) -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
116
  vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
117
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
118
+ token_type_ids[:, 1:] == VISION_TOKEN_TYPE
119
+ )
120
  language_token_mask = ~vision_token_mask
121
  return vision_token_mask, language_token_mask
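
A toy illustration of the routing rule above (illustrative only, not part of the upload): a position counts as a vision token only when the next position is also a vision token, so the last vision position of each image block is handed to the language expert, much like the BOI/EOI handling in build_position_ids below.

import torch

# 1 = VISION_TOKEN_TYPE per the constant at the top of the file; 0 is assumed
# to be LANGUAGE_TOKEN_TYPE.
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0]])
vision_mask, language_mask = get_expert_mask(token_type_ids)
# vision_mask   -> [[False,  True,  True, False, False, False]]
# language_mask -> [[ True, False, False,  True,  True,  True]]
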
122
 
 
127
  self.language_mlp = MLP(config)
128
  self.vision_mlp = MLP(config)
129
 
130
+ def forward(
131
+ self,
132
+ hidden_states: "torch.Tensor(B, L, D)",
133
+ token_type_ids: "torch.LongTensor(B, L)",
134
+ ):
135
+ output = torch.empty(
136
+ hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
137
+ )
138
  vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
139
  output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
140
+ output[language_token_mask] = self.language_mlp(
141
+ hidden_states[language_token_mask]
142
+ )
143
  return output
144
 
145
 
146
  def attention_fn(
147
+ query_layer: "torch.tensor(B, H, L, HD)",
148
+ key_layer: "torch.tensor(B, H, L, HD)",
149
+ value_layer: "torch.tensor(B, H, L, HD)",
150
+ attention_mask: "torch.tensor(B, H, L, HD)",
151
+ *,
152
+ scaling_attention_score: bool = True,
153
+ attention_dropout: nn.Module = None,
154
  ):
155
+ attention_mask_bool = attention_mask == 0
156
+ is_low_triangle = (
157
+ attention_mask_bool
158
+ == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()
159
+ ).all()
160
  is_full = (attention_mask_bool > 0).all()
161
+ if not (int(torch.__version__.split(".")[0]) >= 2):
162
  warnings.warn("It's recommended to use torch2.0 or higher.")
163
+ if (
164
+ int(torch.__version__.split(".")[0]) >= 2
165
+ and scaling_attention_score
166
+ and (is_full or is_low_triangle)
167
+ ):
168
+ dropout_p = (
169
+ 0.0
170
+ if attention_dropout is None or not attention_dropout.training
171
+ else attention_dropout.p
172
+ )
173
  return torch.nn.functional.scaled_dot_product_attention(
174
+ query_layer,
175
+ key_layer,
176
+ value_layer,
177
  attn_mask=None,
178
  dropout_p=dropout_p,
179
+ is_causal=not is_full,
180
  )
181
  else:
182
  if scaling_attention_score:
183
  query_layer = query_layer / math.sqrt(query_layer.shape[-1])
184
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
185
  attention_scores = attention_scores + attention_mask
186
+ attention_scores = nn.functional.softmax(
187
+ attention_scores, dim=-1, dtype=torch.float32
188
+ ).to(query_layer.dtype)
189
  if attention_dropout is not None:
190
  attention_scores = attention_dropout(attention_scores)
191
  context_layer = torch.matmul(attention_scores, value_layer)
192
  return context_layer
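
attention_fn dispatches to scaled_dot_product_attention only on torch >= 2.0, when query scaling is enabled and the additive mask is either all-pass or lower-triangular; otherwise it falls back to the explicit softmax path. A small equivalence check of the two paths, illustrative only:

import math
import torch

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))      # B, H, L, HD

# Additive causal mask in the format attention_fn expects: 0 keeps a position,
# a very large negative value drops it.
neg = torch.finfo(torch.float32).min
causal = torch.zeros(4, 4)
causal[torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)] = neg

fast = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
scores = (q / math.sqrt(q.shape[-1])) @ k.transpose(-1, -2) + causal
slow = torch.softmax(scores, dim=-1) @ v
assert torch.allclose(fast, slow, atol=1e-5)
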
193
 
194
+
195
  class RotaryEmbedding(torch.nn.Module):
196
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
197
  super().__init__()
 
205
 
206
  def _compute_inv_freq(self, device=None):
207
  return 1.0 / (
208
+ self.base ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
 
209
  )
210
 
211
  def _set_cos_sin_cache(self, seq_len, device, dtype):
212
  self.max_seq_len_cached = seq_len
213
+ t = torch.arange(
214
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
215
+ )
216
 
217
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
218
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
219
  emb = torch.cat((freqs, freqs), dim=-1)
220
+ self.register_buffer(
221
+ "cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False
222
+ )
223
+ self.register_buffer(
224
+ "sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False
225
+ )
226
 
227
  def forward(self, x, seq_len):
228
  # x: [bs, num_attention_heads, seq_len, head_size]
 
236
 
237
 
238
  def rotate_half(x):
239
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
240
  return torch.cat((-x2, x1), dim=x1.ndim - 1)
241
 
242
 
243
  def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
244
  # batch_size, num_head, seq_len, hidden_size
245
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), F.embedding(
246
+ position_id, sin.squeeze(1)
247
+ ).unsqueeze(1)
248
  q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
249
  return q, k
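
The rotary helpers follow the usual GPT-NeoX/LLaMA layout (note the (freqs, freqs) concatenation in the cache above), with per-token positions gathered through F.embedding. Two quick properties, illustrative only and assuming the elided parts of RotaryEmbedding do the standard cache initialisation:

import torch

x = torch.randn(1, 2, 4, 8)                           # B, H, L, HD (HD must be even)
assert torch.equal(rotate_half(rotate_half(x)), -x)   # two quarter-turns negate the input

rope = RotaryEmbedding(dim=8)
cos, sin = rope(x, seq_len=4)                         # cached tables, shape [seq_len, 1, dim]
position_id = torch.arange(4).unsqueeze(0)            # B, L
q, k = apply_rotary_pos_emb_index_bhs(x, x, cos, sin, position_id)
assert q.shape == k.shape == x.shape                  # the rotation is shape-preserving
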
250
 
251
+
252
  class VisionExpertAttention(nn.Module):
253
  def __init__(self, config):
254
  super().__init__()
 
259
  self.max_position_embeddings = config.max_position_embeddings
260
 
261
  self.rotary_emb = RotaryEmbedding(self.head_dim)
262
+ self.vision_expert_query_key_value = nn.Linear(
263
+ self.hidden_size, self.hidden_size * 3, bias=False
264
+ )
265
+ self.vision_expert_dense = nn.Linear(
266
+ self.hidden_size, self.hidden_size, bias=False
267
+ )
268
+ self.language_expert_query_key_value = nn.Linear(
269
+ self.hidden_size, self.hidden_size * 3, bias=False
270
+ )
271
+ self.language_expert_dense = nn.Linear(
272
+ self.hidden_size, self.hidden_size, bias=False
273
+ )
274
 
275
  def _transpose_for_scores(self, tensor):
276
  """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
 
279
  return tensor.permute(0, 2, 1, 3)
280
 
281
  def forward(
282
+ self,
283
+ hidden_states: torch.Tensor,
284
+ token_type_ids: torch.LongTensor,
285
+ position_ids: torch.LongTensor,
286
+ attention_mask: Optional[torch.Tensor] = None,
287
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
288
+ output_attentions: bool = False,
289
+ use_cache: bool = False,
290
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
291
  bsz, q_len, _ = hidden_states.size()
292
  vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
293
 
294
  shape = list(hidden_states.shape)
295
  shape[-1] = shape[-1] * 3
296
+ mixed_raw_layer = torch.empty(
297
+ shape, dtype=hidden_states.dtype, device=hidden_states.device
298
+ )
299
+ mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(
300
+ hidden_states[vision_token_mask]
301
+ )
302
+ mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(
303
+ hidden_states[language_token_mask]
304
+ )
305
 
306
+ query_states, key_states, value_states = torch.split(
307
+ mixed_raw_layer, self.hidden_size, dim=-1
308
+ )
309
  query_states = self._transpose_for_scores(query_states) # B, H, L, HD
310
  key_states = self._transpose_for_scores(key_states) # B, H, L, HD
311
  value_states = self._transpose_for_scores(value_states) # B, H, L, HD
 
315
  kv_seq_len += past_key_value[0].shape[-2]
316
 
317
  cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
318
+ query_states, key_states = apply_rotary_pos_emb_index_bhs(
319
+ query_states, key_states, cos, sin, position_ids
320
+ )
321
 
322
  if past_key_value is not None:
323
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
 
326
  past_key_value = (key_states, value_states) if use_cache else None
327
 
328
  context_layer = attention_fn(
329
+ query_layer=query_states,
330
+ key_layer=key_states,
331
+ value_layer=value_states,
332
+ attention_mask=attention_mask,
333
+ scaling_attention_score=True,
334
+ attention_dropout=None,
335
+ )
336
  if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
337
  raise ValueError(
338
  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
339
  f" {context_layer.size()}"
340
  )
341
+ context_layer = (
342
+ context_layer.transpose(1, 2)
343
+ .contiguous()
344
+ .reshape(bsz, q_len, self.hidden_size)
345
+ )
346
 
347
+ attn_output = torch.empty(
348
+ context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device
349
+ )
350
+ attn_output[vision_token_mask] = self.vision_expert_dense(
351
+ context_layer[vision_token_mask]
352
+ )
353
+ attn_output[language_token_mask] = self.language_expert_dense(
354
+ context_layer[language_token_mask]
355
+ )
356
 
357
  if output_attentions:
358
  warnings.warn("output_attentions is not implemented.")
359
 
360
  return attn_output, None, past_key_value
361
 
362
+
363
  class CrossAttention(nn.Module):
364
  def __init__(self, config):
365
  super().__init__()
 
372
  self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
373
  self.max_position_embeddings = config.max_position_embeddings
374
 
375
+ self.query = nn.Linear(
376
+ self.hidden_size, self.cross_compute_hidden_size, bias=False
377
+ )
378
+ self.key_value = nn.Linear(
379
+ self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False
380
+ )
381
+ self.dense = nn.Linear(
382
+ self.cross_compute_hidden_size, self.hidden_size, bias=False
383
+ )
384
 
385
  def _transpose_for_scores(self, tensor):
386
  """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
 
389
  return tensor.permute(0, 2, 1, 3)
390
 
391
  def forward(
392
+ self,
393
+ hidden_states: torch.Tensor,
394
+ encoder_outputs: torch.LongTensor,
395
+ attention_mask: Optional[torch.Tensor] = None,
396
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
397
+ output_attentions: bool = False,
398
+ use_cache: bool = False,
399
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
400
  bsz, q_len, _ = hidden_states.size()
401
 
 
405
  mixed_query_layer = self.query(hidden_states)
406
  if past_key_value is None:
407
  mixed_x_layer = self.key_value(encoder_outputs)
408
+ mixed_key_layer, mixed_value_layer = torch.split(
409
+ mixed_x_layer, self.cross_compute_hidden_size, dim=-1
410
+ )
411
  key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
412
  value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
413
  else:
 
418
  past_key_value = (key_states, value_states) if use_cache else None
419
 
420
  context_layer = attention_fn(
421
+ query_layer=query_states,
422
+ key_layer=key_states,
423
+ value_layer=value_states,
424
+ attention_mask=attention_mask,
425
+ scaling_attention_score=True,
426
+ attention_dropout=None,
427
+ )
428
  if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
429
  raise ValueError(
430
  f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
431
  f" {context_layer.size()}"
432
  )
433
+ context_layer = (
434
+ context_layer.transpose(1, 2)
435
+ .contiguous()
436
+ .reshape(bsz, q_len, self.cross_hidden_size)
437
+ )
438
 
439
  attn_output = self.dense(context_layer)
440
 
 
443
 
444
  return attn_output, None, past_key_value
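
Worth noting in the branch above: the key/value projection of the high-resolution image features is computed only while past_key_value is None, so during generation it runs once and is then replayed from the cache. A hedged sketch of that control flow, using a hypothetical miniature config (the real sizes come from CogAgentConfig):

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(                                    # hypothetical toy sizes
    hidden_size=32, num_attention_heads=4,
    cross_hidden_size=16, cross_compute_hidden_size=16,
    max_position_embeddings=128,
)
cross_attn = CrossAttention(cfg)
image_feats = torch.randn(1, 9, cfg.cross_hidden_size)    # stand-in cross-vision features

past = None
for _ in range(3):                                        # three decoding steps
    step = torch.randn(1, 1, cfg.hidden_size)
    out, _, past = cross_attn(
        hidden_states=step,
        encoder_outputs=image_feats,                      # only projected while past is None
        attention_mask=torch.zeros(1, 1),                 # additive mask: 0 = attend everywhere
        past_key_value=past,
        use_cache=True,
    )
assert out.shape == (1, 1, cfg.hidden_size)
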
445
 
446
+
447
  class CogAgentDecoderLayer(nn.Module):
448
  def __init__(self, config):
449
  super().__init__()
 
452
  self.cross_attn = CrossAttention(config=config)
453
  self.mlp = VisionExpertMLP(config)
454
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
455
+ self.post_attention_layernorm = RMSNorm(
456
+ config.hidden_size, eps=config.rms_norm_eps
457
+ )
458
+ self.post_cross_attention_layernorm = RMSNorm(
459
+ config.hidden_size, eps=config.rms_norm_eps
460
+ )
461
 
462
  def forward(
463
+ self,
464
+ hidden_states: torch.Tensor,
465
+ encoder_outputs: torch.Tensor,
466
+ token_type_ids: torch.LongTensor,
467
+ position_ids: torch.LongTensor,
468
+ attention_mask: Optional[torch.Tensor] = None,
469
+ cross_attention_mask: Optional[torch.Tensor] = None,
470
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
471
+ output_attentions: Optional[bool] = False,
472
+ use_cache: Optional[bool] = False,
473
+ ) -> Tuple[
474
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
475
+ ]:
476
  residual = hidden_states
477
 
478
  hidden_states = self.input_layernorm(hidden_states)
 
491
 
492
  cross_input = self.post_cross_attention_layernorm(hidden_states)
493
  # Fully Connected
494
+ attention_output, self_cross_attn_weights, present_cross_key_value = (
495
+ self.cross_attn(
496
+ hidden_states=cross_input,
497
+ encoder_outputs=encoder_outputs,
498
+ attention_mask=cross_attention_mask,
499
+ past_key_value=(
500
+ past_key_value[-2:] if past_key_value is not None else None
501
+ ),
502
+ output_attentions=output_attentions,
503
+ use_cache=use_cache,
504
  )
505
+ )
506
  hidden_states = hidden_states + attention_output
507
  mlp_input = self.post_attention_layernorm(hidden_states)
508
  mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
 
514
  outputs += (self_attn_weights,)
515
 
516
  if use_cache:
517
+ outputs += (present_key_value + present_cross_key_value,)
518
 
519
  return outputs # type: ignore
520
 
 
523
  config_class = CogAgentConfig
524
  base_model_prefix = "model"
525
  supports_gradient_checkpointing = False
526
+ _no_split_modules = ["CogAgentDecoderLayer", "TransformerLayer", "Block"]
527
  _skip_keys_device_placement = "past_key_values"
528
 
529
  def _init_weights(self, module):
 
547
  return True
548
 
549
 
550
+ def build_position_ids(
551
+ x: "torch.BoolTensor(B, L)",
552
+ attention_mask: Optional["torch.BoolTensor(B, L)"] = None,
553
+ ) -> "torch.LongTensor(B, L)":
554
  if attention_mask is not None:
555
  tmp = x.clone()
556
  tmp[~(attention_mask.bool())] = -1
 
558
  tmp = x.clone()
559
  # image boi eoi token as LANGUAGE_TOKEN_TYPE
560
  is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
561
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
562
+ tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
563
+ )
564
+ is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
565
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
566
+ tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
567
+ )
568
+ is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
569
  tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
570
  # final position ids
571
  y = torch.zeros_like(x, dtype=torch.long)
572
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
573
+ (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
574
+ )
575
  y = y.cumsum(dim=-1)
576
  return y
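
The effect of build_position_ids is easiest to see on a toy sequence: the BOI/EOI boundary tokens keep advancing the position counter, while the image patch tokens between them all share a single position. Illustrative only (0 = language token, 1 = vision token):

import torch

token_types = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])   # text, <BOI>, patch, patch, <EOI>, text, text
print(build_position_ids(token_types))
# tensor([[0, 1, 2, 2, 3, 4, 5]])
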
577
 
 
582
  self.padding_idx = config.pad_token_id
583
  self.vocab_size = config.vocab_size
584
 
585
+ self.embed_tokens = nn.Embedding(
586
+ config.vocab_size, config.hidden_size, self.padding_idx
587
+ )
588
+ self.layers = nn.ModuleList(
589
+ [CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)]
590
+ )
591
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
592
 
593
  self.vision = EVA2CLIPModel(config)
 
620
  images = torch.stack(images)
621
  encoder_outputs = self.cross_vision(images)
622
  return encoder_outputs
623
+
624
  def forward(
625
+ self,
626
+ input_ids: torch.LongTensor = None,
627
+ images: List[List[torch.Tensor]] = None,
628
+ cross_images: List[List[torch.Tensor]] = None,
629
+ token_type_ids: Optional[torch.LongTensor] = None,
630
+ attention_mask: Optional[torch.Tensor] = None,
631
+ cross_attention_mask: Optional[torch.Tensor] = None,
632
+ position_ids: Optional[torch.LongTensor] = None,
633
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
634
+ inputs_embeds: Optional[torch.FloatTensor] = None,
635
+ use_cache: Optional[bool] = None,
636
+ output_attentions: Optional[bool] = None,
637
+ output_hidden_states: Optional[bool] = None,
638
+ return_dict: Optional[bool] = None,
639
  ) -> Union[Tuple, BaseModelOutputWithPast]:
640
  """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
641
 
 
644
  # generate mode with past_key_values. the image features are already mapped
645
  else:
646
  # not allow for inputs_embeds, because we want to process image feature
647
+ assert (
648
+ input_ids is not None and inputs_embeds is None
649
+ ), f"{input_ids} {inputs_embeds}"
650
  if not is_empty(images): # multi-modality
651
+ assert (
652
+ token_type_ids is not None
653
+ ), f"multi-modality requires `token_type_ids`!"
654
  assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
655
  inputs_embeds = self.embed_tokens(input_ids)
656
  images_features = self.encode_images(images)
657
  encoder_outputs = self.encode_cross_images(cross_images)
658
+ images_features = rearrange(images_features, "b n d -> (b n) d")
659
+ images_features = images_features.to(
660
+ dtype=inputs_embeds.dtype, device=inputs_embeds.device
661
+ )
662
+ inputs_embeds = inputs_embeds.index_put(
663
+ [token_type_ids == VISION_TOKEN_TYPE], images_features
664
+ )
665
  else: # single-modality
666
  if token_type_ids is None:
667
+ token_type_ids = (
668
+ torch.ones_like(
669
+ input_ids, dtype=torch.long, device=input_ids.device
670
+ )
671
+ * LANGUAGE_TOKEN_TYPE
672
+ )
673
+ assert not (
674
+ token_type_ids == VISION_TOKEN_TYPE
675
+ ).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
676
  inputs_embeds = self.embed_tokens(input_ids)
677
  encoder_outputs = None
678
 
 
696
  )
697
 
698
  def llm_forward(
699
+ self,
700
+ input_ids: torch.LongTensor = None,
701
+ encoder_outputs: torch.LongTensor = None,
702
+ token_type_ids: torch.LongTensor = None,
703
+ attention_mask: Optional[torch.Tensor] = None,
704
+ cross_attention_mask: Optional[torch.Tensor] = None,
705
+ position_ids: Optional[torch.LongTensor] = None,
706
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
707
+ inputs_embeds: Optional[torch.FloatTensor] = None,
708
+ use_cache: Optional[bool] = None,
709
+ output_attentions: Optional[bool] = None,
710
+ output_hidden_states: Optional[bool] = None,
711
+ return_dict: Optional[bool] = None,
712
  ) -> Union[Tuple, BaseModelOutputWithPast]:
713
  """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
714
+ output_attentions = (
715
+ output_attentions
716
+ if output_attentions is not None
717
+ else self.config.output_attentions
718
+ )
719
  output_hidden_states = (
720
+ output_hidden_states
721
+ if output_hidden_states is not None
722
+ else self.config.output_hidden_states
723
  )
724
  use_cache = use_cache if use_cache is not None else self.config.use_cache
725
 
726
+ return_dict = (
727
+ return_dict if return_dict is not None else self.config.use_return_dict
728
+ )
729
 
730
  # retrieve input_ids and inputs_embeds
731
  if input_ids is not None and inputs_embeds is not None:
732
+ raise ValueError(
733
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
734
+ )
735
  elif input_ids is not None:
736
  batch_size, seq_length = input_ids.shape
737
  elif inputs_embeds is not None:
738
  batch_size, seq_length, _ = inputs_embeds.shape
739
  else:
740
+ raise ValueError(
741
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
742
+ )
743
 
744
  seq_length_with_past = seq_length
745
  past_key_values_length = 0
 
751
  if position_ids is None:
752
  device = input_ids.device if input_ids is not None else inputs_embeds.device
753
  position_ids = torch.arange(
754
+ past_key_values_length,
755
+ seq_length + past_key_values_length,
756
+ dtype=torch.long,
757
+ device=device,
758
  )
759
  position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
760
  else:
 
765
  # embed positions
766
  if attention_mask is None:
767
  attention_mask = torch.ones(
768
+ (batch_size, seq_length_with_past),
769
+ dtype=torch.bool,
770
+ device=inputs_embeds.device,
771
  )
772
  if cross_attention_mask is None:
773
  cross_attention_mask = torch.ones(
774
  (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
775
  )
776
  attention_mask = self._prepare_decoder_attention_mask(
777
+ attention_mask,
778
+ (batch_size, seq_length),
779
+ inputs_embeds,
780
+ past_key_values_length,
781
  )
782
 
783
  hidden_states = inputs_embeds
 
791
  if output_hidden_states:
792
  all_hidden_states += (hidden_states,)
793
 
794
+ past_key_value = (
795
+ past_key_values[idx] if past_key_values is not None else None
796
+ )
797
  layer_outputs = decoder_layer(
798
  hidden_states,
799
  encoder_outputs=encoder_outputs,
 
821
 
822
  next_cache = next_decoder_cache if use_cache else None
823
  if not return_dict:
824
+ return tuple(
825
+ v
826
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
827
+ if v is not None
828
+ )
829
  return BaseModelOutputWithPast(
830
  last_hidden_state=hidden_states,
831
  past_key_values=next_cache,
 
841
 
842
  # noinspection PyMethodMayBeStatic
843
  # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
844
+ def _prepare_decoder_attention_mask(
845
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
846
+ ):
847
  # create causal mask
848
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
849
  combined_attention_mask = None
 
857
 
858
  if attention_mask is not None:
859
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
860
+ expanded_attn_mask = _expand_mask(
861
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
862
+ ).to(inputs_embeds.device)
863
  combined_attention_mask = (
864
+ expanded_attn_mask
865
+ if combined_attention_mask is None
866
+ else expanded_attn_mask + combined_attention_mask
867
  )
868
 
869
  return combined_attention_mask
870
 
871
+
872
  def vqa_history_to_prompt(history, query):
873
  # Only support single round chat in vqa mode
874
  prompt = "<EOI>Question: "
 
877
  prompt += query + " Short answer:"
878
  return prompt
879
 
880
+
881
  def chat_old_history_to_prompt(history, query):
882
  prompt = "<EOI>Question: "
883
  for i, (old_query, response) in enumerate(history):
 
885
  prompt += query + " Answer:"
886
  return prompt
887
 
888
+
889
  def chat_history_to_prompt(history, query):
890
  prompt = " [INST] "
891
  for i, (old_query, response) in enumerate(history):
 
903
  "base": base_history_to_prompt,
904
  "chat": chat_history_to_prompt,
905
  "chat_old": chat_old_history_to_prompt,
906
+ "vqa": vqa_history_to_prompt,
907
  }
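
_history_to_prompt maps the configured template_version to one of the builders above; their per-turn loop bodies are collapsed in this diff, so only the surrounding markers are visible ("<EOI>Question: ... Short answer:" for vqa, "[INST]" for chat). A hedged sketch of how a caller in the build_conversation_input_ids style would select one:

prompt_builder = _history_to_prompt["vqa"]            # or "base", "chat", "chat_old"
prompt = prompt_builder([], "What is the total shown on the receipt?")
# With an empty history the vqa template should reduce to
# "<EOI>Question: <query> Short answer:".
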


  return self.model

  def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ images: List[List[torch.Tensor]] = None,
+ cross_images: List[List[torch.Tensor]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
  output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
  )

  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  outputs = self.model(

  )

  def _prepare_attention_mask_for_generation(
+ self,
+ inputs: torch.Tensor,
+ pad_token_id: Optional[int],
+ eos_token_id: Optional[Union[int, List[int]]],
  ) -> torch.LongTensor:
  return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)  # type: ignore

  def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ token_type_ids,
+ images=None,
+ cross_images=None,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ **kwargs,
  ):
  # build position_ids if needed
  position_ids = kwargs.get("position_ids", None)

  return model_inputs

  def _update_model_kwargs_for_generation(
+ self,
+ outputs: "ModelOutput",
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
  ) -> Dict[str, Any]:
  # update past_key_values
  model_kwargs["past_key_values"] = self._extract_past_from_model_output(

  # update token_type_ids with last value
  if "token_type_ids" in model_kwargs:
  token_type_ids = model_kwargs["token_type_ids"]
+ new_token_type_ids = (
+ torch.ones(
+ size=(token_type_ids.shape[0], 1),
+ dtype=token_type_ids.dtype,
+ device=token_type_ids.device,
+ )
+ * LANGUAGE_TOKEN_TYPE
+ )
+ model_kwargs["token_type_ids"] = torch.cat(
+ [token_type_ids, new_token_type_ids], dim=-1
+ )

  if not is_encoder_decoder:
  # update attention mask
  if "attention_mask" in model_kwargs:
  attention_mask = model_kwargs["attention_mask"]
  model_kwargs["attention_mask"] = torch.cat(
+ [
+ attention_mask,
+ attention_mask.new_ones((attention_mask.shape[0], 1)),
+ ],
+ dim=-1,
  )
  else:
  # update decoder attention mask
  if "decoder_attention_mask" in model_kwargs:
  decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  model_kwargs["decoder_attention_mask"] = torch.cat(
+ [
+ decoder_attention_mask,
+ decoder_attention_mask.new_ones(
+ (decoder_attention_mask.shape[0], 1)
+ ),
+ ],
  dim=-1,
  )


  reordered_past = ()
  for layer_past in past_key_values:
  reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past
+ ),
  )
  return reordered_past

  def build_conversation_input_ids(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ *,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ images: Optional[List["PIL.Image"]] = None,
+ template_version: Optional[Literal["base", "chat", "vqa"]] = None,
  ):
+ image_size: int = self.config.vision_config["image_size"]
  cross_image_size: int = self.config.cross_image_size
+ patch_size: int = self.config.vision_config["patch_size"]
  template_version = template_version or self.config.template_version
  assert images is None or len(images) <= 1, f"not support multi images by now."
  history = history or []

  transform = transforms.Compose(
  [
  transforms.Resize(
+ (image_size, image_size),
+ interpolation=transforms.InterpolationMode.BICUBIC,
  ),
  transforms.ToTensor(),
+ transforms.Normalize(
+ (0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711),
+ ),
  ]
  )
  images = [transform(ori[0])]
  cross_transform = transforms.Compose(
  [
  transforms.Resize(
+ (cross_image_size, cross_image_size),
+ interpolation=transforms.InterpolationMode.BICUBIC,
  ),
  transforms.ToTensor(),
+ transforms.Normalize(
+ (0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711),
+ ),
  ]
  )
  cross_images = [cross_transform(ori[0])]
  # language
+ vision_token_num = (image_size // patch_size) * (
+ image_size // patch_size
+ ) + 2
  input_ids += [tokenizer.pad_token_id] * vision_token_num
  token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
  text_ids = tokenizer.encode(text, add_special_tokens=False)

  attention_mask = [1] * len(input_ids)

  return {
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+ "images": images,
+ "cross_images": cross_images,
  }
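
# Illustrative usage, not part of the diff: wiring build_conversation_input_ids into
# generate(). The checkpoint/tokenizer names follow the upstream CogAgent model card
# and may differ for this repo; the image path is a placeholder. Batch dimensions and
# the nested image lists match the forward() signature shown above.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogagent-chat-hf", torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()

image = Image.open("screenshot.png").convert("RGB")          # placeholder image path
features = model.build_conversation_input_ids(
    tokenizer, query="What does this page show?", history=[], images=[image]
)
inputs = {
    "input_ids": features["input_ids"].unsqueeze(0),
    "token_type_ids": features["token_type_ids"].unsqueeze(0),
    "attention_mask": features["attention_mask"].unsqueeze(0),
    "images": [[features["images"][0].to(torch.bfloat16)]],
    "cross_images": [[features["cross_images"][0].to(torch.bfloat16)]],
}
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:]))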
visual.py CHANGED
@@ -1,14 +1,20 @@
- import torch
- from torch import nn
  from argparse import Namespace

  import xformers.ops as xops
  from transformers.activations import ACT2FN


  class PatchEmbedding(nn.Module):
  def __init__(self, config):
  super().__init__()
- self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)

  self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
  self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)

@@ -26,7 +32,7 @@ class Attention(nn.Module):
  super().__init__()
  self.num_heads = config.num_heads
  head_dim = config.hidden_size // config.num_heads
- self.scale = head_dim ** -0.5
  self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  self.output_dropout = torch.nn.Dropout(config.dropout_prob)
@@ -34,11 +40,16 @@ class Attention(nn.Module):
  def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
  B, L, _ = x.shape
  qkv = self.query_key_value(x)
- qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)  # 3, B, L, H, D

  q, k, v = qkv[0], qkv[1], qkv[2]

  out = xops.memory_efficient_attention(
- q, k, v, scale=self.scale,

  )
  output = self.dense(out.view(B, L, -1))
  output = self.output_dropout(output)
@@ -69,10 +80,14 @@ class MLP(nn.Module):
  class TransformerLayer(nn.Module):
  def __init__(self, config):
  super().__init__()
- self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  self.attention = Attention(config)
  self.mlp = MLP(config)
- self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)


  def forward(self, hidden_states):
  attention_input = hidden_states
@@ -87,7 +102,9 @@ class TransformerLayer(nn.Module):
  class Transformer(nn.Module):
  def __init__(self, config):
  super().__init__()
- self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])


  def forward(self, hidden_states):
  for layer_module in self.layers:
@@ -102,9 +119,15 @@ class GLU(nn.Module):
  self.norm1 = nn.LayerNorm(config.hidden_size)
  self.act1 = nn.GELU()
  self.act2 = nn.functional.silu
- self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
- self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)


  def forward(self, x):
  x = self.linear_proj(x)
@@ -123,7 +146,12 @@ class EVA2CLIPModel(nn.Module):
  self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
  self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
- self.pos_embed = nn.Parameter(torch.zeros((vision_config.image_size // vision_config.patch_size) ** 2, vision_config.hidden_size))


  def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
  x = self.patch_embedding(images)


  from argparse import Namespace
+
+ import torch
  import xformers.ops as xops
+ from torch import nn
  from transformers.activations import ACT2FN


  class PatchEmbedding(nn.Module):
  def __init__(self, config):
  super().__init__()
+ self.proj = nn.Conv2d(
+ config.in_channels,
+ config.hidden_size,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ )
  self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
  self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)


  super().__init__()
  self.num_heads = config.num_heads
  head_dim = config.hidden_size // config.num_heads
+ self.scale = head_dim**-0.5
  self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  self.output_dropout = torch.nn.Dropout(config.dropout_prob)

  def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
  B, L, _ = x.shape
  qkv = self.query_key_value(x)
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(
+ 2, 0, 1, 3, 4
+ )  # 3, B, L, H, D
  q, k, v = qkv[0], qkv[1], qkv[2]

  out = xops.memory_efficient_attention(
+ q,
+ k,
+ v,
+ scale=self.scale,
  )
  output = self.dense(out.view(B, L, -1))
  output = self.output_dropout(output)
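
# Illustrative sketch, not part of the diff: xops.memory_efficient_attention with
# (B, L, H, D) inputs is numerically equivalent to the reference attention below;
# passing scale=self.scale just makes the default 1/sqrt(head_dim) scaling explicit.
import torch

def reference_attention(q, k, v, scale):
    # q, k, v: (B, L, H, D) -> compute per head in (B, H, L, D), then restore layout
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return (attn @ v).transpose(1, 2)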
 
  class TransformerLayer(nn.Module):
  def __init__(self, config):
  super().__init__()
+ self.input_layernorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )
  self.attention = Attention(config)
  self.mlp = MLP(config)
+ self.post_attention_layernorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )

  def forward(self, hidden_states):
  attention_input = hidden_states

  class Transformer(nn.Module):
  def __init__(self, config):
  super().__init__()
+ self.layers = nn.ModuleList(
+ [TransformerLayer(config) for _ in range(config.num_hidden_layers)]
+ )

  def forward(self, hidden_states):
  for layer_module in self.layers:

  self.norm1 = nn.LayerNorm(config.hidden_size)
  self.act1 = nn.GELU()
  self.act2 = nn.functional.silu
+ self.dense_h_to_4h = nn.Linear(
+ config.hidden_size, config.intermediate_size, bias=False
+ )
+ self.gate_proj = nn.Linear(
+ config.hidden_size, config.intermediate_size, bias=False
+ )
+ self.dense_4h_to_h = nn.Linear(
+ config.intermediate_size, config.hidden_size, bias=False
+ )

  def forward(self, x):
  x = self.linear_proj(x)

  self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
  self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(
+ (vision_config.image_size // vision_config.patch_size) ** 2,
+ vision_config.hidden_size,
+ )
+ )

  def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
  x = self.patch_embedding(images)
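
# Illustrative check, not part of the diff: the visual tower emits one token per
# patch plus the boi/eoi tokens defined above, which is why build_conversation_input_ids
# (in modeling_cogagent.py) reserves (image_size // patch_size) ** 2 + 2 pad positions.
# The sizes below are example values, not read from this repo's config.
image_size, patch_size = 224, 14
num_patches = (image_size // patch_size) ** 2      # patch-grid tokens
vision_token_num = num_patches + 2                 # + <boi> and <eoi>
assert vision_token_num == (image_size // patch_size) * (image_size // patch_size) + 2
print(num_patches, vision_token_num)               # 256 258 for these example values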