Text Generation · Transformers · Safetensors · English · custom_code
qingsonglv committed · Commit 8523fb2 · 1 Parent(s): 952cd27

remove triton dependency

Files changed (1)
  1. modeling_cogagent.py +967 -917
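
The substance of the commit: the triton-backed FastRotaryEmbedding imported from .util is dropped (the import is commented out) and replaced with a plain PyTorch RotaryEmbedding cache plus rotate_half / apply_rotary_pos_emb_index_bhs, so the remote code no longer pulls in triton. The following is a minimal, self-contained sketch of the same pure-PyTorch rotary application, not the file itself; the helper names build_cos_sin_cache and apply_rotary are illustrative, and the singleton cache dimension kept in the actual file is omitted here.

    import torch
    import torch.nn.functional as F

    def build_cos_sin_cache(dim, seq_len, base=10000, device=None, dtype=torch.float32):
        # Inverse frequencies for each channel pair, then the full angle table.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device) / dim))
        t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)              # [seq_len, dim]
        return emb.cos().to(dtype), emb.sin().to(dtype)

    def rotate_half(x):
        # (x1, x2) -> (-x2, x1) on the last dimension.
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(q, k, cos, sin, position_ids):
        # Look up the cos/sin rows for each position id and rotate q and k.
        # q, k: [B, H, L, head_dim]; cos, sin: [seq_len, head_dim]; position_ids: [B, L]
        cos = F.embedding(position_ids, cos).unsqueeze(1)    # [B, 1, L, head_dim]
        sin = F.embedding(position_ids, sin).unsqueeze(1)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

This reproduces what the new VisionExpertAttention.forward does in the diff below: build cos/sin up to position_ids.max() + 1, then rotate the query and key states before the cache concatenation and attention call.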
modeling_cogagent.py CHANGED
@@ -1,917 +1,967 @@
1
- """largely copy from llama and adapt for CogAgent"""
2
- import warnings
3
- from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
-
5
- import math
6
- import torch
7
- from torch import nn
8
- from torch.nn import CrossEntropyLoss
9
- from torchvision import transforms
10
- from einops import rearrange
11
-
12
- from transformers import PreTrainedModel, PreTrainedTokenizer
13
- from transformers.utils.logging import get_logger
14
- from transformers.activations import ACT2FN
15
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
-
17
- from .configuration_cogagent import CogAgentConfig
18
- from .util import FastRotaryEmbedding
19
- from .visual import EVA2CLIPModel
20
- from .cross_visual import CrossVisionModel
21
-
22
- if TYPE_CHECKING:
23
- from transformers.utils import ModelOutput
24
-
25
- logger = get_logger(__name__)
26
-
27
- LANGUAGE_TOKEN_TYPE = 0
28
- VISION_TOKEN_TYPE = 1
29
-
30
-
31
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
- def _make_causal_mask(
33
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
34
- ):
35
- """
36
- Make causal mask used for bi-directional self-attention.
37
- """
38
- bsz, tgt_len = input_ids_shape
39
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
40
- mask_cond = torch.arange(mask.size(-1), device=device)
41
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
42
- mask = mask.to(dtype)
43
-
44
- if past_key_values_length > 0:
45
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
46
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
47
-
48
-
49
- # Copied from transformers.models.bart.modeling_bart._expand_mask
50
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
51
- """
52
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
53
- """
54
- bsz, src_len = mask.size()
55
- tgt_len = tgt_len if tgt_len is not None else src_len
56
-
57
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
58
-
59
- inverted_mask = 1.0 - expanded_mask
60
-
61
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
62
-
63
-
64
- class RMSNorm(nn.Module):
65
- def __init__(self, hidden_size, eps=1e-6):
66
- super().__init__()
67
- self.weight = nn.Parameter(torch.ones(hidden_size))
68
- self.variance_epsilon = eps
69
-
70
- def forward(self, hidden_states):
71
- input_dtype = hidden_states.dtype
72
- hidden_states = hidden_states.to(torch.float32)
73
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
74
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
75
- return (self.weight * hidden_states).to(input_dtype)
76
-
77
-
78
- class MLP(nn.Module):
79
- def __init__(self, config):
80
- super().__init__()
81
- self.hidden_size = config.hidden_size
82
- self.intermediate_size = config.intermediate_size
83
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
84
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
86
- self.act_fn = ACT2FN[config.hidden_act]
87
-
88
- def forward(self, x):
89
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
90
- return down_proj
91
-
92
-
93
- def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
94
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
95
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
96
- language_token_mask = ~vision_token_mask
97
- return vision_token_mask, language_token_mask
98
-
99
-
100
- class VisionExpertMLP(nn.Module):
101
- def __init__(self, config):
102
- super().__init__()
103
- self.language_mlp = MLP(config)
104
- self.vision_mlp = MLP(config)
105
-
106
- def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
107
- output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
108
- vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
109
- output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
110
- output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
111
- return output
112
-
113
-
114
- def attention_fn(
115
- query_layer: "torch.tensor(B, H, L, HD)",
116
- key_layer: "torch.tensor(B, H, L, HD)",
117
- value_layer: "torch.tensor(B, H, L, HD)",
118
- attention_mask: "torch.tensor(B, H, L, HD)",
119
- *,
120
- scaling_attention_score: bool = True,
121
- attention_dropout: nn.Module = None
122
- ):
123
- attention_mask_bool = (attention_mask == 0)
124
- is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
125
- is_full = (attention_mask_bool > 0).all()
126
- if not (int(torch.__version__.split('.')[0]) >= 2):
127
- warnings.warn("It's recommended to use torch2.0 or higher.")
128
- if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
129
- dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
130
- return torch.nn.functional.scaled_dot_product_attention(
131
- query_layer, key_layer, value_layer,
132
- attn_mask=None,
133
- dropout_p=dropout_p,
134
- is_causal=not is_full
135
- )
136
- else:
137
- if scaling_attention_score:
138
- query_layer = query_layer / math.sqrt(query_layer.shape[-1])
139
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
140
- attention_scores = attention_scores + attention_mask
141
- attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
142
- if attention_dropout is not None:
143
- attention_scores = attention_dropout(attention_scores)
144
- context_layer = torch.matmul(attention_scores, value_layer)
145
- return context_layer
146
-
147
-
148
- class VisionExpertAttention(nn.Module):
149
- def __init__(self, config):
150
- super().__init__()
151
- self.config = config
152
- self.hidden_size = config.hidden_size
153
- self.num_heads = config.num_attention_heads
154
- self.head_dim = self.hidden_size // self.num_heads
155
- self.max_position_embeddings = config.max_position_embeddings
156
-
157
- # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
158
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
159
- self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
160
- self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
161
- self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
162
- self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
163
-
164
- def _transpose_for_scores(self, tensor):
165
- """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
166
- new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
167
- tensor = tensor.view(*new_tensor_shape)
168
- return tensor.permute(0, 2, 1, 3)
169
-
170
- def forward(
171
- self,
172
- hidden_states: torch.Tensor,
173
- token_type_ids: torch.LongTensor,
174
- position_ids: torch.LongTensor,
175
- attention_mask: Optional[torch.Tensor] = None,
176
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
177
- output_attentions: bool = False,
178
- use_cache: bool = False,
179
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
180
- bsz, q_len, _ = hidden_states.size()
181
- vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
182
-
183
- shape = list(hidden_states.shape)
184
- shape[-1] = shape[-1] * 3
185
- mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
186
- mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
187
- mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
188
-
189
- query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
190
- query_states = self._transpose_for_scores(query_states) # B, H, L, HD
191
- key_states = self._transpose_for_scores(key_states) # B, H, L, HD
192
- value_states = self._transpose_for_scores(value_states) # B, H, L, HD
193
-
194
- kv_seq_len = key_states.shape[-2]
195
- if past_key_value is not None:
196
- kv_seq_len += past_key_value[0].shape[-2]
197
-
198
- query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
199
-
200
- if past_key_value is not None:
201
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
202
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
203
-
204
- past_key_value = (key_states, value_states) if use_cache else None
205
-
206
- context_layer = attention_fn(
207
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
208
- scaling_attention_score=True, attention_dropout=None)
209
- if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
210
- raise ValueError(
211
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
212
- f" {context_layer.size()}"
213
- )
214
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
215
-
216
- attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
217
- attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
218
- attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
219
-
220
- if output_attentions:
221
- warnings.warn("output_attentions is not implemented.")
222
-
223
- return attn_output, None, past_key_value
224
-
225
- class CrossAttention(nn.Module):
226
- def __init__(self, config):
227
- super().__init__()
228
- self.config = config
229
- self.hidden_size = config.hidden_size
230
- self.cross_hidden_size = config.cross_hidden_size
231
- self.cross_compute_hidden_size = config.cross_compute_hidden_size
232
- self.num_heads = config.num_attention_heads
233
- self.head_dim = self.hidden_size // self.num_heads
234
- self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
235
- self.max_position_embeddings = config.max_position_embeddings
236
-
237
- # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
238
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
239
- self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
240
- self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
241
- self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
242
-
243
- def _transpose_for_scores(self, tensor):
244
- """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
245
- new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
246
- tensor = tensor.view(*new_tensor_shape)
247
- return tensor.permute(0, 2, 1, 3)
248
-
249
- def forward(
250
- self,
251
- hidden_states: torch.Tensor,
252
- encoder_outputs: torch.LongTensor,
253
- attention_mask: Optional[torch.Tensor] = None,
254
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
- output_attentions: bool = False,
256
- use_cache: bool = False,
257
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
258
- bsz, q_len, _ = hidden_states.size()
259
-
260
- shape = list(hidden_states.shape)
261
- shape[-1] = shape[-1] * 3
262
-
263
- mixed_query_layer = self.query(hidden_states)
264
- if past_key_value is None:
265
- mixed_x_layer = self.key_value(encoder_outputs)
266
- mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
267
- key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
268
- value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
269
- else:
270
- key_states, value_states = past_key_value
271
-
272
- query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
273
-
274
- past_key_value = (key_states, value_states) if use_cache else None
275
-
276
- context_layer = attention_fn(
277
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
278
- scaling_attention_score=True, attention_dropout=None)
279
- if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
280
- raise ValueError(
281
- f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
282
- f" {context_layer.size()}"
283
- )
284
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
285
-
286
- attn_output = self.dense(context_layer)
287
-
288
- if output_attentions:
289
- warnings.warn("output_attentions is not implemented.")
290
-
291
- return attn_output, None, past_key_value
292
-
293
- class CogAgentDecoderLayer(nn.Module):
294
- def __init__(self, config):
295
- super().__init__()
296
- self.hidden_size = config.hidden_size
297
- self.self_attn = VisionExpertAttention(config=config)
298
- self.cross_attn = CrossAttention(config=config)
299
- self.mlp = VisionExpertMLP(config)
300
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302
- self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
303
-
304
- def forward(
305
- self,
306
- hidden_states: torch.Tensor,
307
- encoder_outputs: torch.Tensor,
308
- token_type_ids: torch.LongTensor,
309
- position_ids: torch.LongTensor,
310
- attention_mask: Optional[torch.Tensor] = None,
311
- cross_attention_mask: Optional[torch.Tensor] = None,
312
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
313
- output_attentions: Optional[bool] = False,
314
- use_cache: Optional[bool] = False,
315
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
316
- residual = hidden_states
317
-
318
- hidden_states = self.input_layernorm(hidden_states)
319
-
320
- # Self Attention
321
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
322
- hidden_states=hidden_states,
323
- token_type_ids=token_type_ids,
324
- position_ids=position_ids,
325
- attention_mask=attention_mask,
326
- past_key_value=past_key_value[:2] if past_key_value is not None else None,
327
- output_attentions=output_attentions,
328
- use_cache=use_cache,
329
- )
330
- hidden_states = residual + hidden_states
331
-
332
- cross_input = self.post_cross_attention_layernorm(hidden_states)
333
- # Fully Connected
334
- attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
335
- hidden_states=cross_input,
336
- encoder_outputs=encoder_outputs,
337
- attention_mask=cross_attention_mask,
338
- past_key_value=past_key_value[-2:] if past_key_value is not None else None,
339
- output_attentions=output_attentions,
340
- use_cache=use_cache,
341
- )
342
- hidden_states = hidden_states + attention_output
343
- mlp_input = self.post_attention_layernorm(hidden_states)
344
- mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
345
- hidden_states = mlp_output + hidden_states
346
-
347
- outputs = (hidden_states,)
348
-
349
- if output_attentions:
350
- outputs += (self_attn_weights,)
351
-
352
- if use_cache:
353
- outputs += (present_key_value+present_cross_key_value,)
354
-
355
- return outputs # type: ignore
356
-
357
-
358
- class CogAgentPreTrainedModel(PreTrainedModel):
359
- config_class = CogAgentConfig
360
- base_model_prefix = "model"
361
- supports_gradient_checkpointing = False
362
- _no_split_modules = ["CogAgentDecoderLayer"]
363
- _skip_keys_device_placement = "past_key_values"
364
-
365
- def _init_weights(self, module):
366
- std = self.config.initializer_range
367
- if isinstance(module, nn.Linear):
368
- module.weight.data.normal_(mean=0.0, std=std)
369
- if module.bias is not None:
370
- module.bias.data.zero_()
371
- elif isinstance(module, nn.Embedding):
372
- module.weight.data.normal_(mean=0.0, std=std)
373
- if module.padding_idx is not None:
374
- module.weight.data[module.padding_idx].zero_()
375
-
376
-
377
- def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
378
- if images_list is None or len(images_list) == 0:
379
- return True
380
- for image_list in images_list:
381
- if len(image_list):
382
- return False
383
- return True
384
-
385
-
386
- def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
387
- if attention_mask is not None:
388
- tmp = x.clone()
389
- tmp[~(attention_mask.bool())] = -1
390
- else:
391
- tmp = x.clone()
392
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
393
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
394
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
395
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
396
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
397
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
398
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
399
- # final position ids
400
- y = torch.zeros_like(x, dtype=torch.long)
401
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
402
- y = y.cumsum(dim=-1)
403
- return y
404
-
405
-
406
- class CogAgentModel(CogAgentPreTrainedModel):
407
- def __init__(self, config):
408
- super().__init__(config)
409
- self.padding_idx = config.pad_token_id
410
- self.vocab_size = config.vocab_size
411
-
412
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
413
- self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
414
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
415
-
416
- self.vision = EVA2CLIPModel(config)
417
- self.cross_vision = CrossVisionModel(config)
418
-
419
- self.gradient_checkpointing = False
420
- # Initialize weights and apply final processing
421
- self.post_init()
422
-
423
- def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
424
- images_list, images = images, []
425
-
426
- images = []
427
- for image_list in images_list:
428
- for image in image_list:
429
- images.append(image)
430
-
431
- images = torch.stack(images)
432
- images_features = self.vision(images)
433
- return images_features
434
-
435
- def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
436
- images_list, images = images, []
437
-
438
- images = []
439
- for image_list in images_list:
440
- for image in image_list:
441
- images.append(image)
442
-
443
- images = torch.stack(images)
444
- encoder_outputs = self.cross_vision(images)
445
- return encoder_outputs
446
-
447
- def forward(
448
- self,
449
- input_ids: torch.LongTensor = None,
450
- images: List[List[torch.Tensor]] = None,
451
- cross_images: List[List[torch.Tensor]] = None,
452
- token_type_ids: Optional[torch.LongTensor] = None,
453
- attention_mask: Optional[torch.Tensor] = None,
454
- cross_attention_mask: Optional[torch.Tensor] = None,
455
- position_ids: Optional[torch.LongTensor] = None,
456
- past_key_values: Optional[List[torch.FloatTensor]] = None,
457
- inputs_embeds: Optional[torch.FloatTensor] = None,
458
- use_cache: Optional[bool] = None,
459
- output_attentions: Optional[bool] = None,
460
- output_hidden_states: Optional[bool] = None,
461
- return_dict: Optional[bool] = None,
462
- ) -> Union[Tuple, BaseModelOutputWithPast]:
463
- """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
464
-
465
- if past_key_values is not None:
466
- encoder_outputs = None
467
- # generate mode with past_key_values. the image features are already mapped
468
- else:
469
- # not allow for inputs_embeds, because we want to process image feature
470
- assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
471
- if not is_empty(images): # multi-modality
472
- assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
473
- assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
474
- inputs_embeds = self.embed_tokens(input_ids)
475
- images_features = self.encode_images(images)
476
- encoder_outputs = self.encode_cross_images(cross_images)
477
- images_features = rearrange(images_features, 'b n d -> (b n) d')
478
- images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
479
- inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
480
- else: # single-modality
481
- if token_type_ids is None:
482
- token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
483
- assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
484
- inputs_embeds = self.embed_tokens(input_ids)
485
- encoder_outputs = None
486
-
487
- if position_ids is None:
488
- position_ids = build_position_ids(token_type_ids, attention_mask)
489
- input_ids = None
490
-
491
- return self.llm_forward(
492
- input_ids=input_ids,
493
- encoder_outputs=encoder_outputs,
494
- token_type_ids=token_type_ids,
495
- attention_mask=attention_mask,
496
- cross_attention_mask=cross_attention_mask,
497
- position_ids=position_ids,
498
- past_key_values=past_key_values,
499
- inputs_embeds=inputs_embeds,
500
- use_cache=use_cache,
501
- output_attentions=output_attentions,
502
- output_hidden_states=output_hidden_states,
503
- return_dict=return_dict,
504
- )
505
-
506
- def llm_forward(
507
- self,
508
- input_ids: torch.LongTensor = None,
509
- encoder_outputs: torch.LongTensor = None,
510
- token_type_ids: torch.LongTensor = None,
511
- attention_mask: Optional[torch.Tensor] = None,
512
- cross_attention_mask: Optional[torch.Tensor] = None,
513
- position_ids: Optional[torch.LongTensor] = None,
514
- past_key_values: Optional[List[torch.FloatTensor]] = None,
515
- inputs_embeds: Optional[torch.FloatTensor] = None,
516
- use_cache: Optional[bool] = None,
517
- output_attentions: Optional[bool] = None,
518
- output_hidden_states: Optional[bool] = None,
519
- return_dict: Optional[bool] = None,
520
- ) -> Union[Tuple, BaseModelOutputWithPast]:
521
- """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
522
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
523
- output_hidden_states = (
524
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
525
- )
526
- use_cache = use_cache if use_cache is not None else self.config.use_cache
527
-
528
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
529
-
530
- # retrieve input_ids and inputs_embeds
531
- if input_ids is not None and inputs_embeds is not None:
532
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
533
- elif input_ids is not None:
534
- batch_size, seq_length = input_ids.shape
535
- elif inputs_embeds is not None:
536
- batch_size, seq_length, _ = inputs_embeds.shape
537
- else:
538
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
539
-
540
- seq_length_with_past = seq_length
541
- past_key_values_length = 0
542
-
543
- if past_key_values is not None:
544
- past_key_values_length = past_key_values[0][0].shape[2]
545
- seq_length_with_past = seq_length_with_past + past_key_values_length
546
-
547
- if position_ids is None:
548
- device = input_ids.device if input_ids is not None else inputs_embeds.device
549
- position_ids = torch.arange(
550
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
551
- )
552
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
553
- else:
554
- position_ids = position_ids.view(-1, seq_length).long()
555
-
556
- if inputs_embeds is None:
557
- inputs_embeds = self.embed_tokens(input_ids)
558
- # embed positions
559
- if attention_mask is None:
560
- attention_mask = torch.ones(
561
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
562
- )
563
- if cross_attention_mask is None:
564
- cross_attention_mask = torch.ones(
565
- (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
566
- )
567
- attention_mask = self._prepare_decoder_attention_mask(
568
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
569
- )
570
-
571
- hidden_states = inputs_embeds
572
-
573
- # decoder layers
574
- all_hidden_states = () if output_hidden_states else None
575
- all_self_attns = () if output_attentions else None
576
- next_decoder_cache = () if use_cache else None
577
-
578
- for idx, decoder_layer in enumerate(self.layers):
579
- if output_hidden_states:
580
- all_hidden_states += (hidden_states,)
581
-
582
- past_key_value = past_key_values[idx] if past_key_values is not None else None
583
- layer_outputs = decoder_layer(
584
- hidden_states,
585
- encoder_outputs=encoder_outputs,
586
- token_type_ids=token_type_ids,
587
- attention_mask=attention_mask,
588
- cross_attention_mask=cross_attention_mask,
589
- position_ids=position_ids,
590
- past_key_value=past_key_value,
591
- output_attentions=output_attentions,
592
- use_cache=use_cache,
593
- )
594
- hidden_states = layer_outputs[0]
595
-
596
- if use_cache:
597
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
598
-
599
- if output_attentions:
600
- all_self_attns += (layer_outputs[1],)
601
-
602
- hidden_states = self.norm(hidden_states)
603
-
604
- # add hidden states from the last decoder layer
605
- if output_hidden_states:
606
- all_hidden_states += (hidden_states,)
607
-
608
- next_cache = next_decoder_cache if use_cache else None
609
- if not return_dict:
610
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
611
- return BaseModelOutputWithPast(
612
- last_hidden_state=hidden_states,
613
- past_key_values=next_cache,
614
- hidden_states=all_hidden_states,
615
- attentions=all_self_attns,
616
- )
617
-
618
- def get_input_embeddings(self):
619
- return self.embed_tokens
620
-
621
- def set_input_embeddings(self, value):
622
- self.embed_tokens = value
623
-
624
- # noinspection PyMethodMayBeStatic
625
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
626
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
627
- # create causal mask
628
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
629
- combined_attention_mask = None
630
- if input_shape[-1] > 1:
631
- combined_attention_mask = _make_causal_mask(
632
- input_shape,
633
- inputs_embeds.dtype,
634
- device=inputs_embeds.device,
635
- past_key_values_length=past_key_values_length,
636
- )
637
-
638
- if attention_mask is not None:
639
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
640
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
641
- inputs_embeds.device
642
- )
643
- combined_attention_mask = (
644
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
645
- )
646
-
647
- return combined_attention_mask
648
-
649
- def chat_old_history_to_prompt(history, query):
650
- prompt = "<EOI>Question: "
651
- for i, (old_query, response) in enumerate(history):
652
- prompt += old_query + " Answer: " + response + "\nQuestion: "
653
- prompt += query + " Answer:"
654
- return prompt
655
-
656
- def chat_history_to_prompt(history, query):
657
- prompt = " [INST] "
658
- for i, (old_query, response) in enumerate(history):
659
- prompt += old_query + " [/INST] " + response + " [INST] "
660
- prompt += query + " [/INST] "
661
- return prompt
662
-
663
-
664
- def base_history_to_prompt(history, query):
665
- prompt = query
666
- return prompt
667
-
668
-
669
- _history_to_prompt = {
670
- "base": base_history_to_prompt,
671
- "chat": chat_history_to_prompt,
672
- "chat_old": chat_old_history_to_prompt
673
- }
674
-
675
-
676
- class CogAgentForCausalLM(CogAgentPreTrainedModel):
677
- _auto_class = "AutoModelForCausalLM"
678
-
679
- def __init__(self, config):
680
- super().__init__(config)
681
- self.model = CogAgentModel(config)
682
- self.vocab_size = config.vocab_size
683
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
684
-
685
- # Initialize weights and apply final processing
686
- self.post_init()
687
-
688
- def get_input_embeddings(self):
689
- return self.model.embed_tokens
690
-
691
- def set_input_embeddings(self, value):
692
- self.model.embed_tokens = value
693
-
694
- def get_output_embeddings(self):
695
- return self.lm_head
696
-
697
- def set_output_embeddings(self, new_embeddings):
698
- self.lm_head = new_embeddings
699
-
700
- def set_decoder(self, decoder):
701
- self.model = decoder
702
-
703
- def get_decoder(self):
704
- return self.model
705
-
706
- def forward(
707
- self,
708
- input_ids: torch.LongTensor = None,
709
- images: List[List[torch.Tensor]] = None,
710
- cross_images: List[List[torch.Tensor]] = None,
711
- token_type_ids: Optional[torch.LongTensor] = None,
712
- attention_mask: Optional[torch.Tensor] = None,
713
- position_ids: Optional[torch.LongTensor] = None,
714
- past_key_values: Optional[List[torch.FloatTensor]] = None,
715
- inputs_embeds: Optional[torch.FloatTensor] = None,
716
- use_cache: Optional[bool] = None,
717
- output_attentions: Optional[bool] = None,
718
- output_hidden_states: Optional[bool] = None,
719
- return_dict: Optional[bool] = None,
720
- labels: Optional[torch.LongTensor] = None,
721
- ) -> Union[Tuple, CausalLMOutputWithPast]:
722
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
723
- output_hidden_states = (
724
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
725
- )
726
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
727
-
728
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
729
- outputs = self.model(
730
- input_ids=input_ids,
731
- images=images,
732
- cross_images=cross_images,
733
- token_type_ids=token_type_ids,
734
- attention_mask=attention_mask,
735
- position_ids=position_ids,
736
- past_key_values=past_key_values,
737
- inputs_embeds=inputs_embeds,
738
- use_cache=use_cache,
739
- output_attentions=output_attentions,
740
- output_hidden_states=output_hidden_states,
741
- return_dict=return_dict,
742
- )
743
-
744
- hidden_states = outputs[0]
745
- logits = self.lm_head(hidden_states)
746
- logits = logits.float()
747
-
748
- loss = None
749
- if labels is not None:
750
- # Shift so that tokens < n predict n
751
- shift_logits = logits[..., :-1, :].contiguous()
752
- shift_labels = labels[..., 1:].contiguous()
753
- # Flatten the tokens
754
- loss_fct = CrossEntropyLoss()
755
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
756
- shift_labels = shift_labels.view(-1)
757
- # Enable model parallelism
758
- shift_labels = shift_labels.to(shift_logits.device)
759
- loss = loss_fct(shift_logits, shift_labels)
760
-
761
- if not return_dict:
762
- output = (logits,) + outputs[1:]
763
- return (loss,) + output if loss is not None else output
764
-
765
- return CausalLMOutputWithPast(
766
- loss=loss,
767
- logits=logits,
768
- past_key_values=outputs.past_key_values,
769
- hidden_states=outputs.hidden_states,
770
- attentions=outputs.attentions,
771
- )
772
-
773
- def _prepare_attention_mask_for_generation(
774
- self,
775
- inputs: torch.Tensor,
776
- pad_token_id: Optional[int],
777
- eos_token_id: Optional[Union[int, List[int]]],
778
- ) -> torch.LongTensor:
779
- return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
780
-
781
- def prepare_inputs_for_generation(
782
- self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
783
- ):
784
- # build position_ids if needed
785
- position_ids = kwargs.get("position_ids", None)
786
- if position_ids is None:
787
- position_ids = build_position_ids(token_type_ids, attention_mask)
788
-
789
- if past_key_values:
790
- input_ids = input_ids[:, -1:]
791
- token_type_ids = token_type_ids[:, -1:]
792
- position_ids = position_ids[:, -1:]
793
-
794
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
795
- if inputs_embeds is not None and past_key_values is None:
796
- model_inputs = {"inputs_embeds": inputs_embeds}
797
- else:
798
- model_inputs = {"input_ids": input_ids}
799
-
800
- model_inputs.update(
801
- {
802
- "token_type_ids": token_type_ids,
803
- "images": images,
804
- "cross_images": cross_images,
805
- "position_ids": position_ids,
806
- "past_key_values": past_key_values,
807
- "use_cache": kwargs.get("use_cache"),
808
- "attention_mask": attention_mask,
809
- }
810
- )
811
- return model_inputs
812
-
813
- def _update_model_kwargs_for_generation(
814
- self,
815
- outputs: "ModelOutput",
816
- model_kwargs: Dict[str, Any],
817
- is_encoder_decoder: bool = False,
818
- standardize_cache_format: bool = False,
819
- ) -> Dict[str, Any]:
820
- # update past_key_values
821
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
822
- outputs, standardize_cache_format=standardize_cache_format
823
- )
824
- if getattr(outputs, "state", None) is not None:
825
- model_kwargs["state"] = outputs.state
826
-
827
- # update token_type_ids with last value
828
- if "token_type_ids" in model_kwargs:
829
- token_type_ids = model_kwargs["token_type_ids"]
830
- new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
831
- model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
832
-
833
- if not is_encoder_decoder:
834
- # update attention mask
835
- if "attention_mask" in model_kwargs:
836
- attention_mask = model_kwargs["attention_mask"]
837
- model_kwargs["attention_mask"] = torch.cat(
838
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
839
- )
840
- else:
841
- # update decoder attention mask
842
- if "decoder_attention_mask" in model_kwargs:
843
- decoder_attention_mask = model_kwargs["decoder_attention_mask"]
844
- model_kwargs["decoder_attention_mask"] = torch.cat(
845
- [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
846
- dim=-1,
847
- )
848
-
849
- return model_kwargs
850
-
851
- def _reorder_cache(self, past_key_values, beam_idx):
852
- reordered_past = ()
853
- for layer_past in past_key_values:
854
- reordered_past += (
855
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
856
- )
857
- return reordered_past
858
-
859
- def build_conversation_input_ids(
860
- self,
861
- tokenizer: "PreTrainedTokenizer",
862
- *,
863
- query: str,
864
- history: Optional[List[Tuple[str, str]]] = None,
865
- images: Optional[List["PIL.Image"]] = None,
866
- template_version: Optional[Literal["base", "chat", "vqa"]] = None,
867
- ):
868
- image_size: int = self.config.vision_config['image_size']
869
- cross_image_size: int = self.config.cross_image_size
870
- patch_size: int = self.config.vision_config['patch_size']
871
- template_version = template_version or self.config.template_version
872
- assert images is None or len(images) <= 1, f"not support multi images by now."
873
- history = history or []
874
- text = _history_to_prompt[template_version](history, query)
875
-
876
- input_ids = [tokenizer.bos_token_id]
877
- token_type_ids = [LANGUAGE_TOKEN_TYPE]
878
- if images is not None and len(images) == 1:
879
- ori = images
880
- # vision
881
- transform = transforms.Compose(
882
- [
883
- transforms.Resize(
884
- (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
885
- ),
886
- transforms.ToTensor(),
887
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
888
- ]
889
- )
890
- images = [transform(ori[0])]
891
- cross_transform = transforms.Compose(
892
- [
893
- transforms.Resize(
894
- (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
895
- ),
896
- transforms.ToTensor(),
897
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
898
- ]
899
- )
900
- cross_images = [cross_transform(ori[0])]
901
- # language
902
- vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
903
- input_ids += [tokenizer.pad_token_id] * vision_token_num
904
- token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
905
- text_ids = tokenizer.encode(text, add_special_tokens=False)
906
-
907
- input_ids += text_ids
908
- token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
909
- attention_mask = [1] * len(input_ids)
910
-
911
- return {
912
- 'input_ids': torch.tensor(input_ids, dtype=torch.long),
913
- 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
914
- 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
915
- 'images': images,
916
- 'cross_images': cross_images
917
- }
1
+ """largely copy from llama and adapt for CogAgent"""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+ from transformers.utils.logging import get_logger
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from .configuration_cogagent import CogAgentConfig
18
+ # from .util import FastRotaryEmbedding
19
+ from torch.nn import functional as F
20
+ from .visual import EVA2CLIPModel
21
+ from .cross_visual import CrossVisionModel
22
+
23
+ if TYPE_CHECKING:
24
+ from transformers.utils import ModelOutput
25
+
26
+ logger = get_logger(__name__)
27
+
28
+ LANGUAGE_TOKEN_TYPE = 0
29
+ VISION_TOKEN_TYPE = 1
30
+
31
+
32
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
33
+ def _make_causal_mask(
34
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
35
+ ):
36
+ """
37
+ Make causal mask used for bi-directional self-attention.
38
+ """
39
+ bsz, tgt_len = input_ids_shape
40
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
41
+ mask_cond = torch.arange(mask.size(-1), device=device)
42
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
43
+ mask = mask.to(dtype)
44
+
45
+ if past_key_values_length > 0:
46
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
47
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
48
+
49
+
50
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
51
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
52
+ """
53
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
54
+ """
55
+ bsz, src_len = mask.size()
56
+ tgt_len = tgt_len if tgt_len is not None else src_len
57
+
58
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
59
+
60
+ inverted_mask = 1.0 - expanded_mask
61
+
62
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
63
+
64
+
65
+ class RMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps=1e-6):
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(hidden_size))
69
+ self.variance_epsilon = eps
70
+
71
+ def forward(self, hidden_states):
72
+ input_dtype = hidden_states.dtype
73
+ hidden_states = hidden_states.to(torch.float32)
74
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
+ return (self.weight * hidden_states).to(input_dtype)
77
+
78
+
79
+ class MLP(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.hidden_size = config.hidden_size
83
+ self.intermediate_size = config.intermediate_size
84
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
86
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
87
+ self.act_fn = ACT2FN[config.hidden_act]
88
+
89
+ def forward(self, x):
90
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
91
+ return down_proj
92
+
93
+
94
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
95
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
96
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
97
+ language_token_mask = ~vision_token_mask
98
+ return vision_token_mask, language_token_mask
99
+
100
+
101
+ class VisionExpertMLP(nn.Module):
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.language_mlp = MLP(config)
105
+ self.vision_mlp = MLP(config)
106
+
107
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
108
+ output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
109
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
110
+ output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
111
+ output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
112
+ return output
113
+
114
+
115
+ def attention_fn(
116
+ query_layer: "torch.tensor(B, H, L, HD)",
117
+ key_layer: "torch.tensor(B, H, L, HD)",
118
+ value_layer: "torch.tensor(B, H, L, HD)",
119
+ attention_mask: "torch.tensor(B, H, L, HD)",
120
+ *,
121
+ scaling_attention_score: bool = True,
122
+ attention_dropout: nn.Module = None
123
+ ):
124
+ attention_mask_bool = (attention_mask == 0)
125
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
126
+ is_full = (attention_mask_bool > 0).all()
127
+ if not (int(torch.__version__.split('.')[0]) >= 2):
128
+ warnings.warn("It's recommended to use torch2.0 or higher.")
129
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
130
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
131
+ return torch.nn.functional.scaled_dot_product_attention(
132
+ query_layer, key_layer, value_layer,
133
+ attn_mask=None,
134
+ dropout_p=dropout_p,
135
+ is_causal=not is_full
136
+ )
137
+ else:
138
+ if scaling_attention_score:
139
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
140
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
141
+ attention_scores = attention_scores + attention_mask
142
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
143
+ if attention_dropout is not None:
144
+ attention_scores = attention_dropout(attention_scores)
145
+ context_layer = torch.matmul(attention_scores, value_layer)
146
+ return context_layer
147
+
148
+ class RotaryEmbedding(torch.nn.Module):
149
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
150
+ super().__init__()
151
+
152
+ self.dim = dim
153
+ self.max_position_embeddings = max_position_embeddings
154
+ self.base = base
155
+ inv_freq = self._compute_inv_freq(device)
156
+ self.register_buffer("inv_freq", inv_freq)
157
+ self.max_seq_len_cached = 0
158
+
159
+ def _compute_inv_freq(self, device=None):
160
+ return 1.0 / (
161
+ self.base
162
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
163
+ )
164
+
165
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
166
+ self.max_seq_len_cached = seq_len
167
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
168
+
169
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
170
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
171
+ emb = torch.cat((freqs, freqs), dim=-1)
172
+ self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
173
+ self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
174
+
175
+ def forward(self, x, seq_len):
176
+ # x: [bs, num_attention_heads, seq_len, head_size]
177
+ if seq_len > self.max_seq_len_cached:
178
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
179
+
180
+ return (
181
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
182
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
183
+ )
184
+
185
+
186
+ def rotate_half(x):
187
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
188
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
189
+
190
+
191
+ def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
192
+ # batch_size, num_head, seq_len, hidden_size
193
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
194
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
195
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
196
+ return q, k
197
+
198
+ class VisionExpertAttention(nn.Module):
199
+ def __init__(self, config):
200
+ super().__init__()
201
+ self.config = config
202
+ self.hidden_size = config.hidden_size
203
+ self.num_heads = config.num_attention_heads
204
+ self.head_dim = self.hidden_size // self.num_heads
205
+ self.max_position_embeddings = config.max_position_embeddings
206
+
207
+ self.rotary_emb = RotaryEmbedding(self.head_dim)
208
+ self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
209
+ self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
210
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
211
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
212
+
213
+ def _transpose_for_scores(self, tensor):
214
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
215
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
216
+ tensor = tensor.view(*new_tensor_shape)
217
+ return tensor.permute(0, 2, 1, 3)
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ token_type_ids: torch.LongTensor,
223
+ position_ids: torch.LongTensor,
224
+ attention_mask: Optional[torch.Tensor] = None,
225
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
226
+ output_attentions: bool = False,
227
+ use_cache: bool = False,
228
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
229
+ bsz, q_len, _ = hidden_states.size()
230
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
231
+
232
+ shape = list(hidden_states.shape)
233
+ shape[-1] = shape[-1] * 3
234
+ mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
235
+ mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
236
+ mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
237
+
238
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
239
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
240
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
241
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
242
+
243
+ kv_seq_len = key_states.shape[-2]
244
+ if past_key_value is not None:
245
+ kv_seq_len += past_key_value[0].shape[-2]
246
+
247
+ cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
248
+ query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
249
+
250
+ if past_key_value is not None:
251
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
252
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
253
+
254
+ past_key_value = (key_states, value_states) if use_cache else None
255
+
256
+ context_layer = attention_fn(
257
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
258
+ scaling_attention_score=True, attention_dropout=None)
259
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
260
+ raise ValueError(
261
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
262
+ f" {context_layer.size()}"
263
+ )
264
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
265
+
266
+ attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
267
+ attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
268
+ attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
269
+
270
+ if output_attentions:
271
+ warnings.warn("output_attentions is not implemented.")
272
+
273
+ return attn_output, None, past_key_value
274
+
275
+ class CrossAttention(nn.Module):
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.config = config
279
+ self.hidden_size = config.hidden_size
280
+ self.cross_hidden_size = config.cross_hidden_size
281
+ self.cross_compute_hidden_size = config.cross_compute_hidden_size
282
+ self.num_heads = config.num_attention_heads
283
+ self.head_dim = self.hidden_size // self.num_heads
284
+ self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
285
+ self.max_position_embeddings = config.max_position_embeddings
286
+
287
+ # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
288
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
289
+ self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
290
+ self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
291
+ self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
292
+
293
+ def _transpose_for_scores(self, tensor):
294
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
295
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
296
+ tensor = tensor.view(*new_tensor_shape)
297
+ return tensor.permute(0, 2, 1, 3)
298
+
299
+ def forward(
300
+ self,
301
+ hidden_states: torch.Tensor,
302
+ encoder_outputs: torch.LongTensor,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
305
+ output_attentions: bool = False,
306
+ use_cache: bool = False,
307
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
308
+ bsz, q_len, _ = hidden_states.size()
309
+
310
+ shape = list(hidden_states.shape)
311
+ shape[-1] = shape[-1] * 3
312
+
313
+ mixed_query_layer = self.query(hidden_states)
314
+ if past_key_value is None:
315
+ mixed_x_layer = self.key_value(encoder_outputs)
316
+ mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
317
+ key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
318
+ value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
319
+ else:
320
+ key_states, value_states = past_key_value
321
+
322
+ query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
323
+
324
+ past_key_value = (key_states, value_states) if use_cache else None
325
+
326
+ context_layer = attention_fn(
327
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
328
+ scaling_attention_score=True, attention_dropout=None)
329
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
330
+ raise ValueError(
331
+ f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
332
+ f" {context_layer.size()}"
333
+ )
334
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
335
+
336
+ attn_output = self.dense(context_layer)
337
+
338
+ if output_attentions:
339
+ warnings.warn("output_attentions is not implemented.")
340
+
341
+ return attn_output, None, past_key_value
342
+
343
+ class CogAgentDecoderLayer(nn.Module):
344
+ def __init__(self, config):
345
+ super().__init__()
346
+ self.hidden_size = config.hidden_size
347
+ self.self_attn = VisionExpertAttention(config=config)
348
+ self.cross_attn = CrossAttention(config=config)
349
+ self.mlp = VisionExpertMLP(config)
350
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
351
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
352
+ self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ encoder_outputs: torch.Tensor,
358
+ token_type_ids: torch.LongTensor,
359
+ position_ids: torch.LongTensor,
360
+ attention_mask: Optional[torch.Tensor] = None,
361
+ cross_attention_mask: Optional[torch.Tensor] = None,
362
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
363
+ output_attentions: Optional[bool] = False,
364
+ use_cache: Optional[bool] = False,
365
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
366
+ residual = hidden_states
367
+
368
+ hidden_states = self.input_layernorm(hidden_states)
369
+
370
+ # Self Attention
371
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
372
+ hidden_states=hidden_states,
373
+ token_type_ids=token_type_ids,
374
+ position_ids=position_ids,
375
+ attention_mask=attention_mask,
376
+ past_key_value=past_key_value[:2] if past_key_value is not None else None,
377
+ output_attentions=output_attentions,
378
+ use_cache=use_cache,
379
+ )
380
+ hidden_states = residual + hidden_states
381
+
382
+ cross_input = self.post_cross_attention_layernorm(hidden_states)
383
+ # Fully Connected
384
+ attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
385
+ hidden_states=cross_input,
386
+ encoder_outputs=encoder_outputs,
387
+ attention_mask=cross_attention_mask,
388
+ past_key_value=past_key_value[-2:] if past_key_value is not None else None,
389
+ output_attentions=output_attentions,
390
+ use_cache=use_cache,
391
+ )
392
+ hidden_states = hidden_states + attention_output
393
+ mlp_input = self.post_attention_layernorm(hidden_states)
394
+ mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
395
+ hidden_states = mlp_output + hidden_states
396
+
397
+ outputs = (hidden_states,)
398
+
399
+ if output_attentions:
400
+ outputs += (self_attn_weights,)
401
+
402
+ if use_cache:
403
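+ # per-layer cache is the 4-tuple (self_key, self_value, cross_key, cross_value);
+ # forward() splits it back with past_key_value[:2] and past_key_value[-2:]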
+ outputs += (present_key_value + present_cross_key_value,)
404
+
405
+ return outputs # type: ignore
406
+
407
+
408
+ class CogAgentPreTrainedModel(PreTrainedModel):
409
+ config_class = CogAgentConfig
410
+ base_model_prefix = "model"
411
+ supports_gradient_checkpointing = False
412
+ _no_split_modules = ["CogAgentDecoderLayer"]
413
+ _skip_keys_device_placement = "past_key_values"
414
+
415
+ def _init_weights(self, module):
416
+ std = self.config.initializer_range
417
+ if isinstance(module, nn.Linear):
418
+ module.weight.data.normal_(mean=0.0, std=std)
419
+ if module.bias is not None:
420
+ module.bias.data.zero_()
421
+ elif isinstance(module, nn.Embedding):
422
+ module.weight.data.normal_(mean=0.0, std=std)
423
+ if module.padding_idx is not None:
424
+ module.weight.data[module.padding_idx].zero_()
425
+
426
+
427
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
428
+ if images_list is None or len(images_list) == 0:
429
+ return True
430
+ for image_list in images_list:
431
+ if len(image_list):
432
+ return False
433
+ return True
434
+
435
+
436
+ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
437
+ if attention_mask is not None:
438
+ tmp = x.clone()
439
+ tmp[~(attention_mask.bool())] = -1
440
+ else:
441
+ tmp = x.clone()
442
+ # image boi eoi token as LANGUAGE_TOKEN_TYPE
443
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
444
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
445
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
446
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
447
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
448
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
449
+ # final position ids
450
+ y = torch.zeros_like(x, dtype=torch.long)
451
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
452
+ y = y.cumsum(dim=-1)
453
+ return y
454
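+ # Example (illustrative): for token_type_ids [LANG, VIS, VIS, VIS, VIS, LANG, LANG] with no
+ # attention_mask, the first and last vision tokens of the run are treated as boi/eoi and
+ # re-labelled as language tokens, so build_position_ids returns [0, 1, 2, 2, 3, 4, 5]:
+ # interior vision tokens share one position, while boi/eoi and language tokens each advance it.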
+
455
+
456
+ class CogAgentModel(CogAgentPreTrainedModel):
457
+ def __init__(self, config):
458
+ super().__init__(config)
459
+ self.padding_idx = config.pad_token_id
460
+ self.vocab_size = config.vocab_size
461
+
462
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
463
+ self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
464
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
465
+
466
+ self.vision = EVA2CLIPModel(config)
467
+ self.cross_vision = CrossVisionModel(config)
468
+
469
+ self.gradient_checkpointing = False
470
+ # Initialize weights and apply final processing
471
+ self.post_init()
472
+
473
+ def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
474
+ images_list = images
475
+
476
+ images = []
477
+ for image_list in images_list:
478
+ for image in image_list:
479
+ images.append(image)
480
+
481
+ images = torch.stack(images)
482
+ images_features = self.vision(images)
483
+ return images_features
484
+
485
+ def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
486
+ images_list = images
487
+
488
+ images = []
489
+ for image_list in images_list:
490
+ for image in image_list:
491
+ images.append(image)
492
+
493
+ images = torch.stack(images)
494
+ encoder_outputs = self.cross_vision(images)
495
+ return encoder_outputs
496
+
497
+ def forward(
498
+ self,
499
+ input_ids: torch.LongTensor = None,
500
+ images: List[List[torch.Tensor]] = None,
501
+ cross_images: List[List[torch.Tensor]] = None,
502
+ token_type_ids: Optional[torch.LongTensor] = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ cross_attention_mask: Optional[torch.Tensor] = None,
505
+ position_ids: Optional[torch.LongTensor] = None,
506
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
507
+ inputs_embeds: Optional[torch.FloatTensor] = None,
508
+ use_cache: Optional[bool] = None,
509
+ output_attentions: Optional[bool] = None,
510
+ output_hidden_states: Optional[bool] = None,
511
+ return_dict: Optional[bool] = None,
512
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
513
+ """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
514
+
515
+ if past_key_values is not None:
516
+ encoder_outputs = None
517
+ # generation mode with past_key_values: the image features were already injected during the first forward pass
518
+ else:
519
+ # inputs_embeds is not allowed here, because the image features must be built from input_ids
520
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
521
+ if not is_empty(images): # multi-modality
522
+ assert token_type_ids is not None, "multi-modality requires `token_type_ids`!"
523
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
524
+ inputs_embeds = self.embed_tokens(input_ids)
525
+ images_features = self.encode_images(images)
526
+ encoder_outputs = self.encode_cross_images(cross_images)
527
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
528
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
529
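+ # write the image features into the embedding positions marked with VISION_TOKEN_TYPE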
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
530
+ else: # single-modality
531
+ if token_type_ids is None:
532
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
533
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
534
+ inputs_embeds = self.embed_tokens(input_ids)
535
+ encoder_outputs = None
536
+
537
+ if position_ids is None:
538
+ position_ids = build_position_ids(token_type_ids, attention_mask)
539
+ input_ids = None
540
+
541
+ return self.llm_forward(
542
+ input_ids=input_ids,
543
+ encoder_outputs=encoder_outputs,
544
+ token_type_ids=token_type_ids,
545
+ attention_mask=attention_mask,
546
+ cross_attention_mask=cross_attention_mask,
547
+ position_ids=position_ids,
548
+ past_key_values=past_key_values,
549
+ inputs_embeds=inputs_embeds,
550
+ use_cache=use_cache,
551
+ output_attentions=output_attentions,
552
+ output_hidden_states=output_hidden_states,
553
+ return_dict=return_dict,
554
+ )
555
+
556
+ def llm_forward(
557
+ self,
558
+ input_ids: torch.LongTensor = None,
559
+ encoder_outputs: torch.LongTensor = None,
560
+ token_type_ids: torch.LongTensor = None,
561
+ attention_mask: Optional[torch.Tensor] = None,
562
+ cross_attention_mask: Optional[torch.Tensor] = None,
563
+ position_ids: Optional[torch.LongTensor] = None,
564
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
565
+ inputs_embeds: Optional[torch.FloatTensor] = None,
566
+ use_cache: Optional[bool] = None,
567
+ output_attentions: Optional[bool] = None,
568
+ output_hidden_states: Optional[bool] = None,
569
+ return_dict: Optional[bool] = None,
570
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
571
+ """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
572
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
573
+ output_hidden_states = (
574
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
575
+ )
576
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
577
+
578
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
579
+
580
+ # retrieve input_ids and inputs_embeds
581
+ if input_ids is not None and inputs_embeds is not None:
582
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
583
+ elif input_ids is not None:
584
+ batch_size, seq_length = input_ids.shape
585
+ elif inputs_embeds is not None:
586
+ batch_size, seq_length, _ = inputs_embeds.shape
587
+ else:
588
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
589
+
590
+ seq_length_with_past = seq_length
591
+ past_key_values_length = 0
592
+
593
+ if past_key_values is not None:
594
+ past_key_values_length = past_key_values[0][0].shape[2]
595
+ seq_length_with_past = seq_length_with_past + past_key_values_length
596
+
597
+ if position_ids is None:
598
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
599
+ position_ids = torch.arange(
600
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
601
+ )
602
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
603
+ else:
604
+ position_ids = position_ids.view(-1, seq_length).long()
605
+
606
+ if inputs_embeds is None:
607
+ inputs_embeds = self.embed_tokens(input_ids)
608
+ # embed positions
609
+ if attention_mask is None:
610
+ attention_mask = torch.ones(
611
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
612
+ )
613
+ if cross_attention_mask is None:
614
+ cross_attention_mask = torch.ones(
615
+ (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
616
+ )
617
+ attention_mask = self._prepare_decoder_attention_mask(
618
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
619
+ )
620
+
621
+ hidden_states = inputs_embeds
622
+
623
+ # decoder layers
624
+ all_hidden_states = () if output_hidden_states else None
625
+ all_self_attns = () if output_attentions else None
626
+ next_decoder_cache = () if use_cache else None
627
+
628
+ for idx, decoder_layer in enumerate(self.layers):
629
+ if output_hidden_states:
630
+ all_hidden_states += (hidden_states,)
631
+
632
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
633
+ layer_outputs = decoder_layer(
634
+ hidden_states,
635
+ encoder_outputs=encoder_outputs,
636
+ token_type_ids=token_type_ids,
637
+ attention_mask=attention_mask,
638
+ cross_attention_mask=cross_attention_mask,
639
+ position_ids=position_ids,
640
+ past_key_value=past_key_value,
641
+ output_attentions=output_attentions,
642
+ use_cache=use_cache,
643
+ )
644
+ hidden_states = layer_outputs[0]
645
+
646
+ if use_cache:
647
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
648
+
649
+ if output_attentions:
650
+ all_self_attns += (layer_outputs[1],)
651
+
652
+ hidden_states = self.norm(hidden_states)
653
+
654
+ # add hidden states from the last decoder layer
655
+ if output_hidden_states:
656
+ all_hidden_states += (hidden_states,)
657
+
658
+ next_cache = next_decoder_cache if use_cache else None
659
+ if not return_dict:
660
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
661
+ return BaseModelOutputWithPast(
662
+ last_hidden_state=hidden_states,
663
+ past_key_values=next_cache,
664
+ hidden_states=all_hidden_states,
665
+ attentions=all_self_attns,
666
+ )
667
+
668
+ def get_input_embeddings(self):
669
+ return self.embed_tokens
670
+
671
+ def set_input_embeddings(self, value):
672
+ self.embed_tokens = value
673
+
674
+ # noinspection PyMethodMayBeStatic
675
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
676
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
677
+ # create causal mask
678
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
679
+ combined_attention_mask = None
680
+ if input_shape[-1] > 1:
681
+ combined_attention_mask = _make_causal_mask(
682
+ input_shape,
683
+ inputs_embeds.dtype,
684
+ device=inputs_embeds.device,
685
+ past_key_values_length=past_key_values_length,
686
+ )
687
+
688
+ if attention_mask is not None:
689
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
690
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
691
+ inputs_embeds.device
692
+ )
693
+ combined_attention_mask = (
694
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
695
+ )
696
+
697
+ return combined_attention_mask
698
+
699
+ def chat_old_history_to_prompt(history, query):
700
+ prompt = "<EOI>Question: "
701
+ for i, (old_query, response) in enumerate(history):
702
+ prompt += old_query + " Answer: " + response + "\nQuestion: "
703
+ prompt += query + " Answer:"
704
+ return prompt
705
+
706
+ def chat_history_to_prompt(history, query):
707
+ prompt = " [INST] "
708
+ for i, (old_query, response) in enumerate(history):
709
+ prompt += old_query + " [/INST] " + response + " [INST] "
710
+ prompt += query + " [/INST] "
711
+ return prompt
712
+
713
+
714
+ def base_history_to_prompt(history, query):
715
+ prompt = query
716
+ return prompt
717
+
718
+
719
+ _history_to_prompt = {
720
+ "base": base_history_to_prompt,
721
+ "chat": chat_history_to_prompt,
722
+ "chat_old": chat_old_history_to_prompt
723
+ }
724
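+ # Example (illustrative): with the "chat" template, history=[("Describe the image", "A cat.")]
+ # and query="What color is it?" yield:
+ #   " [INST] Describe the image [/INST] A cat. [INST] What color is it? [/INST] "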
+
725
+
726
+ class CogAgentForCausalLM(CogAgentPreTrainedModel):
727
+ _auto_class = "AutoModelForCausalLM"
728
+
729
+ def __init__(self, config):
730
+ super().__init__(config)
731
+ self.model = CogAgentModel(config)
732
+ self.vocab_size = config.vocab_size
733
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
734
+
735
+ # Initialize weights and apply final processing
736
+ self.post_init()
737
+
738
+ def get_input_embeddings(self):
739
+ return self.model.embed_tokens
740
+
741
+ def set_input_embeddings(self, value):
742
+ self.model.embed_tokens = value
743
+
744
+ def get_output_embeddings(self):
745
+ return self.lm_head
746
+
747
+ def set_output_embeddings(self, new_embeddings):
748
+ self.lm_head = new_embeddings
749
+
750
+ def set_decoder(self, decoder):
751
+ self.model = decoder
752
+
753
+ def get_decoder(self):
754
+ return self.model
755
+
756
+ def forward(
757
+ self,
758
+ input_ids: torch.LongTensor = None,
759
+ images: List[List[torch.Tensor]] = None,
760
+ cross_images: List[List[torch.Tensor]] = None,
761
+ token_type_ids: Optional[torch.LongTensor] = None,
762
+ attention_mask: Optional[torch.Tensor] = None,
763
+ position_ids: Optional[torch.LongTensor] = None,
764
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
765
+ inputs_embeds: Optional[torch.FloatTensor] = None,
766
+ use_cache: Optional[bool] = None,
767
+ output_attentions: Optional[bool] = None,
768
+ output_hidden_states: Optional[bool] = None,
769
+ return_dict: Optional[bool] = None,
770
+ labels: Optional[torch.LongTensor] = None,
771
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
772
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
773
+ output_hidden_states = (
774
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
775
+ )
776
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
777
+
778
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
779
+ outputs = self.model(
780
+ input_ids=input_ids,
781
+ images=images,
782
+ cross_images=cross_images,
783
+ token_type_ids=token_type_ids,
784
+ attention_mask=attention_mask,
785
+ position_ids=position_ids,
786
+ past_key_values=past_key_values,
787
+ inputs_embeds=inputs_embeds,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ )
793
+
794
+ hidden_states = outputs[0]
795
+ logits = self.lm_head(hidden_states)
796
+ logits = logits.float()
797
+
798
+ loss = None
799
+ if labels is not None:
800
+ # Shift so that tokens < n predict n
801
+ shift_logits = logits[..., :-1, :].contiguous()
802
+ shift_labels = labels[..., 1:].contiguous()
803
+ # Flatten the tokens
804
+ loss_fct = CrossEntropyLoss()
805
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
806
+ shift_labels = shift_labels.view(-1)
807
+ # Enable model parallelism
808
+ shift_labels = shift_labels.to(shift_logits.device)
809
+ loss = loss_fct(shift_logits, shift_labels)
810
+
811
+ if not return_dict:
812
+ output = (logits,) + outputs[1:]
813
+ return (loss,) + output if loss is not None else output
814
+
815
+ return CausalLMOutputWithPast(
816
+ loss=loss,
817
+ logits=logits,
818
+ past_key_values=outputs.past_key_values,
819
+ hidden_states=outputs.hidden_states,
820
+ attentions=outputs.attentions,
821
+ )
822
+
823
+ def _prepare_attention_mask_for_generation(
824
+ self,
825
+ inputs: torch.Tensor,
826
+ pad_token_id: Optional[int],
827
+ eos_token_id: Optional[Union[int, List[int]]],
828
+ ) -> torch.LongTensor:
829
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
830
+
831
+ def prepare_inputs_for_generation(
832
+ self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
833
+ ):
834
+ # build position_ids if needed
835
+ position_ids = kwargs.get("position_ids", None)
836
+ if position_ids is None:
837
+ position_ids = build_position_ids(token_type_ids, attention_mask)
838
+
839
+ if past_key_values:
840
+ input_ids = input_ids[:, -1:]
841
+ token_type_ids = token_type_ids[:, -1:]
842
+ position_ids = position_ids[:, -1:]
843
+
844
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
845
+ if inputs_embeds is not None and past_key_values is None:
846
+ model_inputs = {"inputs_embeds": inputs_embeds}
847
+ else:
848
+ model_inputs = {"input_ids": input_ids}
849
+
850
+ model_inputs.update(
851
+ {
852
+ "token_type_ids": token_type_ids,
853
+ "images": images,
854
+ "cross_images": cross_images,
855
+ "position_ids": position_ids,
856
+ "past_key_values": past_key_values,
857
+ "use_cache": kwargs.get("use_cache"),
858
+ "attention_mask": attention_mask,
859
+ }
860
+ )
861
+ return model_inputs
862
+
863
+ def _update_model_kwargs_for_generation(
864
+ self,
865
+ outputs: "ModelOutput",
866
+ model_kwargs: Dict[str, Any],
867
+ is_encoder_decoder: bool = False,
868
+ standardize_cache_format: bool = False,
869
+ ) -> Dict[str, Any]:
870
+ # update past_key_values
871
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
872
+ outputs, standardize_cache_format=standardize_cache_format
873
+ )
874
+ if getattr(outputs, "state", None) is not None:
875
+ model_kwargs["state"] = outputs.state
876
+
877
+ # update token_type_ids with last value
878
+ if "token_type_ids" in model_kwargs:
879
+ token_type_ids = model_kwargs["token_type_ids"]
880
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
881
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
882
+
883
+ if not is_encoder_decoder:
884
+ # update attention mask
885
+ if "attention_mask" in model_kwargs:
886
+ attention_mask = model_kwargs["attention_mask"]
887
+ model_kwargs["attention_mask"] = torch.cat(
888
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
889
+ )
890
+ else:
891
+ # update decoder attention mask
892
+ if "decoder_attention_mask" in model_kwargs:
893
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
894
+ model_kwargs["decoder_attention_mask"] = torch.cat(
895
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
896
+ dim=-1,
897
+ )
898
+
899
+ return model_kwargs
900
+
901
+ def _reorder_cache(self, past_key_values, beam_idx):
902
+ reordered_past = ()
903
+ for layer_past in past_key_values:
904
+ reordered_past += (
905
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
906
+ )
907
+ return reordered_past
908
+
909
+ def build_conversation_input_ids(
910
+ self,
911
+ tokenizer: "PreTrainedTokenizer",
912
+ *,
913
+ query: str,
914
+ history: Optional[List[Tuple[str, str]]] = None,
915
+ images: Optional[List["PIL.Image.Image"]] = None,
916
+ template_version: Optional[Literal["base", "chat", "chat_old"]] = None,
917
+ ):
918
+ image_size: int = self.config.vision_config['image_size']
919
+ cross_image_size: int = self.config.cross_image_size
920
+ patch_size: int = self.config.vision_config['patch_size']
921
+ template_version = template_version or self.config.template_version
922
+ assert images is None or len(images) <= 1, "multiple images are not supported yet."
923
+ history = history or []
924
+ text = _history_to_prompt[template_version](history, query)
925
+
926
+ input_ids = [tokenizer.bos_token_id]
927
+ token_type_ids = [LANGUAGE_TOKEN_TYPE]
928
+ if images is not None and len(images) == 1:
929
+ ori = images
930
+ # vision
931
+ transform = transforms.Compose(
932
+ [
933
+ transforms.Resize(
934
+ (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
935
+ ),
936
+ transforms.ToTensor(),
937
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
938
+ ]
939
+ )
940
+ images = [transform(ori[0])]
941
+ cross_transform = transforms.Compose(
942
+ [
943
+ transforms.Resize(
944
+ (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
945
+ ),
946
+ transforms.ToTensor(),
947
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
948
+ ]
949
+ )
950
+ cross_images = [cross_transform(ori[0])]
951
+ # language
952
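+ # one placeholder token per vision patch, plus two boundary (boi/eoi) tokens added by the vision encoder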
+ vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
953
+ input_ids += [tokenizer.pad_token_id] * vision_token_num
954
+ token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
955
+ text_ids = tokenizer.encode(text, add_special_tokens=False)
956
+
957
+ input_ids += text_ids
958
+ token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
959
+ attention_mask = [1] * len(input_ids)
960
+
961
+ return {
962
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
963
+ 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
964
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
965
+ 'images': images,
966
+ 'cross_images': cross_images
967
+ }
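+ 
+ # Minimal usage sketch (illustrative only, kept as a comment so nothing runs at import time).
+ # It shows how the dict returned by build_conversation_input_ids is batched and fed to
+ # generate(). The checkpoint name, tokenizer, device and dtype below are assumptions made
+ # for the example, not requirements of this file:
+ #
+ #   import torch
+ #   from PIL import Image
+ #   from transformers import AutoModelForCausalLM, LlamaTokenizer
+ #
+ #   tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+ #   model = AutoModelForCausalLM.from_pretrained(
+ #       "THUDM/cogagent-chat-hf", torch_dtype=torch.bfloat16, trust_remote_code=True
+ #   ).to("cuda").eval()
+ #
+ #   image = Image.open("screenshot.png").convert("RGB")
+ #   inputs = model.build_conversation_input_ids(
+ #       tokenizer, query="What is shown in this screenshot?", history=[], images=[image]
+ #   )
+ #   inputs = {
+ #       "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+ #       "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
+ #       "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
+ #       "images": [[inputs["images"][0].to("cuda", dtype=torch.bfloat16)]],
+ #       "cross_images": [[inputs["cross_images"][0].to("cuda", dtype=torch.bfloat16)]],
+ #   }
+ #   with torch.no_grad():
+ #       outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
+ #       outputs = outputs[:, inputs["input_ids"].shape[1]:]
+ #   print(tokenizer.decode(outputs[0], skip_special_tokens=True))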