Text Generation · Transformers · Safetensors · English · custom_code
qingsonglv committed
Commit a2c2a1e · 1 parent: ff4a7b3

upload model

README.md CHANGED
@@ -1,3 +1,156 @@
  ---
  license: apache-2.0
+ language:
+ - en
  ---
+ # CogAgent
+
+ ## Introduction
+
+ **CogAgent** is an open-source visual language model improved upon **CogVLM**.
+
+ 📖 Paper: https://arxiv.org/abs/2312.08914
+
+ **CogAgent-18B** has 11 billion visual parameters and 7 billion language parameters, and achieves state-of-the-art generalist performance on 9 classic cross-modal benchmarks, including:
+ + VQAv2
+ + OK-VQA
+ + TextVQA
+ + ST-VQA
+ + ChartQA
+ + infoVQA
+ + DocVQA
+ + MM-Vet
+ + POPE
+
+ **CogAgent-18B** significantly surpasses existing models on GUI operation datasets such as AITW and Mind2Web.
+
+ In addition to all the features already present in **CogVLM** (visual multi-round dialogue, visual grounding), **CogAgent**:
+
+ 1. Supports higher-resolution visual input and dialogue question answering, with ultra-high-resolution image inputs of **1120x1120**.
+
+ 2. Possesses the capabilities of a visual agent, being able to return a plan, the next action, and specific operations with coordinates for any given task on any GUI screenshot.
+
+ 3. Offers enhanced GUI-related question answering, handling questions about any GUI screenshot, such as web pages, PC apps, and mobile applications.
+
+ 4. Offers enhanced OCR-related capabilities through improved pre-training and fine-tuning.
+
+ <div align="center">
+     <img src="https://raw.githubusercontent.com/THUDM/CogVLM/master/assets/cogagent_function.jpg" alt="img" style="zoom: 50%;" />
+ </div>
+
+ ## Quick Start
+
+ Use the Python code below (the `cli_demo_hf.py` script) to get started quickly:
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--quant", choices=[4], type=int, default=None, help='quantization bits')
+ parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat-hf", help='pretrained ckpt')
+ parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
+ parser.add_argument("--fp16", action="store_true")
+ parser.add_argument("--bf16", action="store_true")
+
+ args = parser.parse_args()
+ MODEL_PATH = args.from_pretrained
+ TOKENIZER_PATH = args.local_tokenizer
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
+ if args.bf16:
+     torch_type = torch.bfloat16
+ else:
+     torch_type = torch.float16
+
+ print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
+
+ if args.quant:
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_PATH,
+         torch_dtype=torch_type,
+         low_cpu_mem_usage=True,
+         load_in_4bit=True,
+         trust_remote_code=True
+     ).eval()
+ else:
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_PATH,
+         torch_dtype=torch_type,
+         low_cpu_mem_usage=True,
+         load_in_4bit=args.quant is not None,
+         trust_remote_code=True
+     ).to(DEVICE).eval()
+
+ while True:
+     image_path = input("image path >>>>> ")
+     if image_path == "stop":
+         break
+
+     image = Image.open(image_path).convert('RGB')
+     history = []
+     while True:
+         query = input("Human:")
+         if query == "clear":
+             break
+         input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+         inputs = {
+             'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+             'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+             'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+             'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
+         }
+         if 'cross_images' in input_by_model and input_by_model['cross_images']:
+             inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+         # add any transformers params here.
+         gen_kwargs = {"max_length": 2048,
+                       "temperature": 0.9,
+                       "do_sample": False}
+         with torch.no_grad():
+             outputs = model.generate(**inputs, **gen_kwargs)
+             outputs = outputs[:, inputs['input_ids'].shape[1]:]
+             response = tokenizer.decode(outputs[0])
+             response = response.split("</s>")[0]
+             print("\nCog:", response)
+         history.append((query, response))
+ ```
+
+ Then run:
+
+ ```bash
+ python cli_demo_hf.py --bf16
+ ```
+ For more information, such as the Web Demo and fine-tuning, please refer to [our GitHub](https://github.com/THUDM/CogVLM/).
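If GPU memory is limited, the checkpoint can also be loaded in 4-bit precision without going through the CLI script. The snippet below is a minimal sketch of what the `--quant 4` branch above does; it assumes the `bitsandbytes` package is installed, and no explicit `.to(DEVICE)` call is needed because the quantized weights are placed on the GPU at load time.

```python
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

# Minimal 4-bit loading sketch, mirroring the --quant 4 path of cli_demo_hf.py.
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogagent-chat-hf",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    load_in_4bit=True,        # requires bitsandbytes
    trust_remote_code=True,
).eval()
```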
+
+ ## License
+
+ The code in this repository is open-sourced under the [Apache-2.0 license](./LICENSE), while the use of the CogAgent and CogVLM model weights must comply with the [Model License](./MODEL_LICENSE).
+
+ ## Citation & Acknowledgements
+
+ If you find our work helpful, please consider citing the following papers:
+
+ ```
+ @misc{hong2023cogagent,
+       title={CogAgent: A Visual Language Model for GUI Agents},
+       author={Wenyi Hong and Weihan Wang and Qingsong Lv and Jiazheng Xu and Wenmeng Yu and Junhui Ji and Yan Wang and Zihan Wang and Yuxiao Dong and Ming Ding and Jie Tang},
+       year={2023},
+       eprint={2312.08914},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+
+ @misc{wang2023cogvlm,
+       title={CogVLM: Visual Expert for Pretrained Language Models},
+       author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
+       year={2023},
+       eprint={2311.03079},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```
+ In the instruction fine-tuning phase of CogVLM, we used some English image-text data from the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4), [LLaVA](https://github.com/haotian-liu/LLaVA), [LRV-Instruction](https://github.com/FuxiaoLiu/LRV-Instruction), [LLaVAR](https://github.com/SALT-NLP/LLaVAR), and [Shikra](https://github.com/shikras/shikra) projects, as well as many classic cross-modal datasets. We sincerely thank them for their contributions.
config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "_name_or_path": "cogagent",
+   "architectures": [
+     "CogAgentForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_cogagent.CogAgentConfig",
+     "AutoModelForCausalLM": "modeling_cogagent.CogAgentForCausalLM"
+   },
+   "bos_token_id": 1,
+   "cross_compute_hidden_size": 1024,
+   "cross_hidden_size": 1024,
+   "cross_image_size": 1120,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "max_position_embeddings": 2048,
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "pad_token_id": 0,
+   "rms_norm_eps": 1e-05,
+   "template_version": "chat_old",
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.36.0.dev0",
+   "use_cache": true,
+   "vision_config": {
+     "dropout_prob": 0.0,
+     "hidden_act": "gelu",
+     "hidden_size": 1792,
+     "image_size": 224,
+     "in_channels": 3,
+     "intermediate_size": 15360,
+     "layer_norm_eps": 1e-06,
+     "num_heads": 16,
+     "num_hidden_layers": 63,
+     "num_positions": 257,
+     "patch_size": 14
+   },
+   "vocab_size": 32000
+ }
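For reference, the `auto_map` block above is what routes the stock `Auto*` loaders to the custom classes shipped with this repository; a minimal sketch (repo id as used in the README, `trust_remote_code=True` required because the classes live in this repo's Python files):

```python
from transformers import AutoConfig

# Resolves to CogAgentConfig defined in configuration_cogagent.py via auto_map.
config = AutoConfig.from_pretrained("THUDM/cogagent-chat-hf", trust_remote_code=True)
print(type(config).__name__)    # CogAgentConfig
print(config.cross_image_size)  # 1120 -- resolution of the high-res cross branch
```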
configuration_cogagent.py ADDED
@@ -0,0 +1,51 @@
+ from typing import Literal
+ from transformers import PretrainedConfig
+
+
+ class CogAgentConfig(PretrainedConfig):
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=4096,
+         cross_hidden_size=1024,
+         cross_compute_hidden_size=1024,
+         cross_image_size=1120,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         hidden_act='silu',
+         max_position_embeddings=2048,
+         initializer_range=0.02,
+         rms_norm_eps=1e-06,
+         template_version: Literal["base", "chat"] = "chat",
+
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         use_cache=True,
+         **kwargs,
+     ):
+         self.hidden_size = hidden_size
+         self.cross_hidden_size = cross_hidden_size
+         self.cross_compute_hidden_size = cross_compute_hidden_size
+         self.cross_image_size = cross_image_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.rms_norm_eps = rms_norm_eps
+         self.initializer_range = initializer_range
+         self.vocab_size = vocab_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hidden_act = hidden_act
+         self.template_version = template_version
+         self.use_cache = use_cache
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
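As a quick sanity check of the defaults above, the class can also be instantiated directly; a small sketch assuming the file is importable from the working directory (note that the shipped `config.json` overrides `rms_norm_eps` to 1e-05 and `template_version` to "chat_old"):

```python
from configuration_cogagent import CogAgentConfig

# Defaults as defined in the class above.
cfg = CogAgentConfig()
print(cfg.hidden_size, cfg.num_hidden_layers)       # 4096 32
print(cfg.cross_image_size, cfg.cross_hidden_size)  # 1120 1024
print(cfg.template_version)                         # chat
```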
cross_visual.py ADDED
@@ -0,0 +1,797 @@
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+ def broadcat(tensors, dim = -1):
8
+ num_tensors = len(tensors)
9
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
+ shape_len = list(shape_lens)[0]
12
+ dim = (dim + shape_len) if dim < 0 else dim
13
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
16
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
+ expanded_dims.insert(dim, (dim, dims[dim]))
19
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
+ return torch.cat(tensors, dim = dim)
22
+
23
+ def rotate_half(x):
24
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
+ x1, x2 = x.unbind(dim = -1)
26
+ x = torch.stack((-x2, x1), dim = -1)
27
+ return rearrange(x, '... d r -> ... (d r)')
28
+
29
+ class VisionRotaryEmbeddingFast(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim,
33
+ pt_seq_len,
34
+ ft_seq_len=None,
35
+ custom_freqs = None,
36
+ freqs_for = 'lang',
37
+ theta = 10000,
38
+ max_freq = 10,
39
+ num_freqs = 1,
40
+ patch_dropout = 0.
41
+ ):
42
+ super().__init__()
43
+ if custom_freqs:
44
+ freqs = custom_freqs
45
+ elif freqs_for == 'lang':
46
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
+ elif freqs_for == 'pixel':
48
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
+ elif freqs_for == 'constant':
50
+ freqs = torch.ones(num_freqs).float()
51
+ else:
52
+ raise ValueError(f'unknown modality {freqs_for}')
53
+
54
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
55
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
+
57
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
58
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
59
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
60
+
61
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
62
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
63
+
64
+ self.patch_dropout = patch_dropout
65
+
66
+ self.register_buffer("freqs_cos", freqs_cos)
67
+ self.register_buffer("freqs_sin", freqs_sin)
68
+
69
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
70
+
71
+ def forward(self, t, patch_indices_keep=None):
72
+ if patch_indices_keep is not None:
73
+ batch = t.size()[0]
74
+ batch_indices = torch.arange(batch)
75
+ batch_indices = batch_indices[..., None]
76
+
77
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
78
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
79
+
80
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
81
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
82
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
83
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
84
+
85
+ return t * freqs_cos + rotate_half(t) * freqs_sin
86
+
87
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
88
+
89
+ import torch.nn as nn
90
+ import os
91
+ from dataclasses import dataclass
92
+ from typing import Optional, Tuple, Union
93
+ from functools import partial
94
+
95
+ import numpy as np
96
+ import torch
97
+ import torch.nn.functional as F
98
+ from torch import nn
99
+
100
+ # --------------------------------------------------------
101
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
102
+ # --------------------------------------------------------
103
+ import math
104
+ import os
105
+ from functools import partial
106
+ import torch
107
+ import torch.nn as nn
108
+ import torch.nn.functional as F
109
+ import logging
110
+ try:
111
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
112
+ except:
113
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
114
+
115
+ class PatchDropout(nn.Module):
116
+ """
117
+ https://arxiv.org/abs/2212.00794
118
+ """
119
+
120
+ def __init__(self, prob, exclude_first_token=True):
121
+ super().__init__()
122
+ assert 0 <= prob < 1.
123
+ self.prob = prob
124
+ self.exclude_first_token = exclude_first_token # exclude CLS token
125
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
126
+
127
+ def forward(self, x):
128
+ if not self.training or self.prob == 0.:
129
+ return x
130
+
131
+ if self.exclude_first_token:
132
+ cls_tokens, x = x[:, :1], x[:, 1:]
133
+ else:
134
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
135
+
136
+ batch = x.size()[0]
137
+ num_tokens = x.size()[1]
138
+
139
+ batch_indices = torch.arange(batch)
140
+ batch_indices = batch_indices[..., None]
141
+
142
+ keep_prob = 1 - self.prob
143
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
144
+
145
+ rand = torch.randn(batch, num_tokens)
146
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
147
+
148
+ x = x[batch_indices, patch_indices_keep]
149
+
150
+ if self.exclude_first_token:
151
+ x = torch.cat((cls_tokens, x), dim=1)
152
+
153
+ if self.training and os.getenv('RoPE') == '1':
154
+ return x, patch_indices_keep
155
+
156
+ return x
157
+
158
+ if os.getenv('ENV_TYPE') == 'deepspeed':
159
+ try:
160
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
161
+ except:
162
+ from torch.utils.checkpoint import checkpoint
163
+ else:
164
+ from torch.utils.checkpoint import checkpoint
165
+
166
+ import xformers.ops as xops
167
+
168
+ class DropPath(nn.Module):
169
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
170
+ """
171
+ def __init__(self, drop_prob=None):
172
+ super(DropPath, self).__init__()
173
+ self.drop_prob = drop_prob
174
+
175
+ def forward(self, x):
176
+ return drop_path(x, self.drop_prob, self.training)
177
+
178
+ def extra_repr(self) -> str:
179
+ return 'p={}'.format(self.drop_prob)
180
+
181
+
182
+ class Mlp(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_features,
186
+ hidden_features=None,
187
+ out_features=None,
188
+ act_layer=nn.GELU,
189
+ norm_layer=nn.LayerNorm,
190
+ drop=0.,
191
+ subln=False,
192
+
193
+ ):
194
+ super().__init__()
195
+ out_features = out_features or in_features
196
+ hidden_features = hidden_features or in_features
197
+ self.fc1 = nn.Linear(in_features, hidden_features)
198
+ self.act = act_layer()
199
+
200
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
201
+
202
+ self.fc2 = nn.Linear(hidden_features, out_features)
203
+ self.drop = nn.Dropout(drop)
204
+
205
+ def forward(self, x):
206
+ x = self.fc1(x)
207
+ x = self.act(x)
208
+ # x = self.drop(x)
209
+ # commit this for the orignal BERT implement
210
+ x = self.ffn_ln(x)
211
+
212
+ x = self.fc2(x)
213
+ x = self.drop(x)
214
+ return x
215
+
216
+ class SwiGLU(nn.Module):
217
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
218
+ norm_layer=nn.LayerNorm, subln=False):
219
+ super().__init__()
220
+ out_features = out_features or in_features
221
+ hidden_features = hidden_features or in_features
222
+
223
+ self.w1 = nn.Linear(in_features, hidden_features)
224
+ self.w2 = nn.Linear(in_features, hidden_features)
225
+
226
+ self.act = act_layer()
227
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
228
+ self.w3 = nn.Linear(hidden_features, out_features)
229
+
230
+ self.drop = nn.Dropout(drop)
231
+
232
+ def forward(self, x):
233
+ x1 = self.w1(x)
234
+ x2 = self.w2(x)
235
+ hidden = self.act(x1) * x2
236
+ x = self.ffn_ln(hidden)
237
+ x = self.w3(x)
238
+ x = self.drop(x)
239
+ return x
240
+
241
+ class Attention(nn.Module):
242
+ def __init__(
243
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
244
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
245
+ super().__init__()
246
+ self.num_heads = num_heads
247
+ head_dim = dim // num_heads
248
+ if attn_head_dim is not None:
249
+ head_dim = attn_head_dim
250
+ all_head_dim = head_dim * self.num_heads
251
+ self.scale = qk_scale or head_dim ** -0.5
252
+
253
+ self.subln = subln
254
+ if self.subln:
255
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
256
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
257
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
258
+ else:
259
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
260
+
261
+ if qkv_bias:
262
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
263
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
264
+ else:
265
+ self.q_bias = None
266
+ self.v_bias = None
267
+
268
+ if window_size:
269
+ self.window_size = window_size
270
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
271
+ self.relative_position_bias_table = nn.Parameter(
272
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
273
+ # cls to token & token 2 cls & cls to cls
274
+
275
+ # get pair-wise relative position index for each token inside the window
276
+ coords_h = torch.arange(window_size[0])
277
+ coords_w = torch.arange(window_size[1])
278
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
279
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
280
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
281
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
282
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
283
+ relative_coords[:, :, 1] += window_size[1] - 1
284
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
285
+ relative_position_index = \
286
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
287
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
288
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
289
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
290
+ relative_position_index[0, 0] = self.num_relative_distance - 1
291
+
292
+ self.register_buffer("relative_position_index", relative_position_index)
293
+ else:
294
+ self.window_size = None
295
+ self.relative_position_bias_table = None
296
+ self.relative_position_index = None
297
+
298
+ self.attn_drop = nn.Dropout(attn_drop)
299
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
300
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
301
+ self.proj = nn.Linear(all_head_dim, dim)
302
+ self.proj_drop = nn.Dropout(proj_drop)
303
+ self.xattn = xattn
304
+ self.xattn_drop = attn_drop
305
+
306
+ self.rope = rope
307
+
308
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
309
+ B, N, C = x.shape
310
+ if self.subln:
311
+ if self.q_proj.weight.dtype == torch.uint8:
312
+ import bitsandbytes as bnb
313
+ q = bnb.matmul_4bit(x, self.q_proj.weight.t(), bias=self.q_bias, quant_state=self.q_proj.weight.quant_state)
314
+ k = bnb.matmul_4bit(x, self.k_proj.weight.t(), bias=None, quant_state=self.k_proj.weight.quant_state)
315
+ v = bnb.matmul_4bit(x, self.v_proj.weight.t(), bias=self.v_bias, quant_state=self.v_proj.weight.quant_state)
316
+ else:
317
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
318
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
319
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
320
+
321
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
322
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
323
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
324
+ else:
325
+
326
+ qkv_bias = None
327
+ if self.q_bias is not None:
328
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
329
+
330
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
331
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
332
+ q, k, v = qkv[0], qkv[1], qkv[2]
333
+
334
+ if self.rope:
335
+ # slightly fast impl
336
+ q_t = q[:, :, 1:, :]
337
+ ro_q_t = self.rope(q_t)
338
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
339
+
340
+ k_t = k[:, :, 1:, :]
341
+ ro_k_t = self.rope(k_t)
342
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
343
+
344
+ if self.xattn:
345
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
346
+ k = k.permute(0, 2, 1, 3)
347
+ v = v.permute(0, 2, 1, 3)
348
+
349
+ x = xops.memory_efficient_attention(
350
+ q, k, v,
351
+ p=self.xattn_drop,
352
+ scale=self.scale,
353
+ )
354
+ x = x.reshape(B, N, -1)
355
+ x = self.inner_attn_ln(x)
356
+ x = self.proj(x)
357
+ x = self.proj_drop(x)
358
+ else:
359
+ q = q * self.scale
360
+ attn = (q @ k.transpose(-2, -1))
361
+
362
+ if self.relative_position_bias_table is not None:
363
+ relative_position_bias = \
364
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
365
+ self.window_size[0] * self.window_size[1] + 1,
366
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
368
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
369
+
370
+ if rel_pos_bias is not None:
371
+ attn = attn + rel_pos_bias.type_as(attn)
372
+
373
+ if attn_mask is not None:
374
+ attn_mask = attn_mask.bool()
375
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
376
+
377
+ attn = attn.softmax(dim=-1)
378
+ attn = self.attn_drop(attn)
379
+
380
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
381
+ x = self.inner_attn_ln(x)
382
+ x = self.proj(x)
383
+ x = self.proj_drop(x)
384
+ return x
385
+
386
+
387
+ class Block(nn.Module):
388
+
389
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
390
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
391
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
392
+ subln=False, naiveswiglu=False):
393
+ super().__init__()
394
+ self.norm1 = norm_layer(dim)
395
+ self.attn = Attention(
396
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
397
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
398
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
399
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
400
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
401
+ self.norm2 = norm_layer(dim)
402
+ mlp_hidden_dim = int(dim * mlp_ratio)
403
+
404
+ if naiveswiglu:
405
+ self.mlp = SwiGLU(
406
+ in_features=dim,
407
+ hidden_features=mlp_hidden_dim,
408
+ subln=subln,
409
+ norm_layer=norm_layer,
410
+ )
411
+ else:
412
+ self.mlp = Mlp(
413
+ in_features=dim,
414
+ hidden_features=mlp_hidden_dim,
415
+ act_layer=act_layer,
416
+ subln=subln,
417
+ drop=drop
418
+ )
419
+
420
+ if init_values is not None and init_values > 0:
421
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
422
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
423
+ else:
424
+ self.gamma_1, self.gamma_2 = None, None
425
+
426
+ self.postnorm = postnorm
427
+
428
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
429
+ if self.gamma_1 is None:
430
+ if self.postnorm:
431
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
432
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
433
+ else:
434
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
435
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
436
+ else:
437
+ if self.postnorm:
438
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
439
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
440
+ else:
441
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
442
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
443
+ return x
444
+
445
+
446
+ class PatchEmbed(nn.Module):
447
+ """ Image to Patch Embedding
448
+ """
449
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
450
+ super().__init__()
451
+ img_size = to_2tuple(img_size)
452
+ patch_size = to_2tuple(patch_size)
453
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
454
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
455
+ self.img_size = img_size
456
+ self.patch_size = patch_size
457
+ self.num_patches = num_patches
458
+
459
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
460
+
461
+ def forward(self, x, **kwargs):
462
+ B, C, H, W = x.shape
463
+ # FIXME look at relaxing size constraints
464
+ assert H == self.img_size[0] and W == self.img_size[1], \
465
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
466
+ x = self.proj(x).flatten(2).transpose(1, 2)
467
+ return x
468
+
469
+
470
+ class RelativePositionBias(nn.Module):
471
+
472
+ def __init__(self, window_size, num_heads):
473
+ super().__init__()
474
+ self.window_size = window_size
475
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
476
+ self.relative_position_bias_table = nn.Parameter(
477
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
478
+ # cls to token & token 2 cls & cls to cls
479
+
480
+ # get pair-wise relative position index for each token inside the window
481
+ coords_h = torch.arange(window_size[0])
482
+ coords_w = torch.arange(window_size[1])
483
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
484
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
485
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
486
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
487
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
488
+ relative_coords[:, :, 1] += window_size[1] - 1
489
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
490
+ relative_position_index = \
491
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
492
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
493
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
494
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
495
+ relative_position_index[0, 0] = self.num_relative_distance - 1
496
+
497
+ self.register_buffer("relative_position_index", relative_position_index)
498
+
499
+ def forward(self):
500
+ relative_position_bias = \
501
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
502
+ self.window_size[0] * self.window_size[1] + 1,
503
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
504
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
505
+
506
+
507
+ class EVAVisionTransformer(nn.Module):
508
+ """ Vision Transformer with support for patch or hybrid CNN input stage
509
+ """
510
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
511
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
512
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
513
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
514
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
515
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
516
+ super().__init__()
517
+ self.image_size = img_size
518
+ self.num_classes = num_classes
519
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
520
+
521
+ self.patch_embed = PatchEmbed(
522
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
523
+ num_patches = self.patch_embed.num_patches
524
+
525
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
526
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
527
+ if use_abs_pos_emb:
528
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
529
+ else:
530
+ self.pos_embed = None
531
+ self.pos_drop = nn.Dropout(p=drop_rate)
532
+
533
+ if use_shared_rel_pos_bias:
534
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
535
+ else:
536
+ self.rel_pos_bias = None
537
+
538
+ if rope:
539
+ half_head_dim = embed_dim // num_heads // 2
540
+ hw_seq_len = img_size // patch_size
541
+ self.rope = VisionRotaryEmbeddingFast(
542
+ dim=half_head_dim,
543
+ pt_seq_len=pt_hw_seq_len,
544
+ ft_seq_len=hw_seq_len if intp_freq else None,
545
+ # patch_dropout=patch_dropout
546
+ )
547
+ else:
548
+ self.rope = None
549
+
550
+ self.naiveswiglu = naiveswiglu
551
+
552
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
553
+ self.use_rel_pos_bias = use_rel_pos_bias
554
+ self.blocks = nn.ModuleList([
555
+ Block(
556
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
557
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
558
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
559
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
560
+ for i in range(depth)])
561
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
562
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
563
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
564
+
565
+ if self.pos_embed is not None:
566
+ trunc_normal_(self.pos_embed, std=.02)
567
+
568
+ trunc_normal_(self.cls_token, std=.02)
569
+ # trunc_normal_(self.mask_token, std=.02)
570
+
571
+ self.apply(self._init_weights)
572
+ self.fix_init_weight()
573
+
574
+ if isinstance(self.head, nn.Linear):
575
+ trunc_normal_(self.head.weight, std=.02)
576
+ self.head.weight.data.mul_(init_scale)
577
+ self.head.bias.data.mul_(init_scale)
578
+
579
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
580
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
581
+
582
+ self.grad_checkpointing = grad_checkpointing
583
+
584
+ def fix_init_weight(self):
585
+ def rescale(param, layer_id):
586
+ param.div_(math.sqrt(2.0 * layer_id))
587
+
588
+ for layer_id, layer in enumerate(self.blocks):
589
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
590
+ if self.naiveswiglu:
591
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
592
+ else:
593
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
594
+
595
+ def get_cast_dtype(self) -> torch.dtype:
596
+ return self.blocks[0].mlp.fc2.weight.dtype
597
+
598
+ def _init_weights(self, m):
599
+ if isinstance(m, nn.Linear):
600
+ trunc_normal_(m.weight, std=.02)
601
+ if m.bias is not None:
602
+ nn.init.constant_(m.bias, 0)
603
+ elif isinstance(m, nn.LayerNorm):
604
+ nn.init.constant_(m.bias, 0)
605
+ nn.init.constant_(m.weight, 1.0)
606
+
607
+ def get_num_layers(self):
608
+ return len(self.blocks)
609
+
610
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
611
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
612
+ for param in self.parameters():
613
+ param.requires_grad = False
614
+
615
+ @torch.jit.ignore
616
+ def set_grad_checkpointing(self, enable=True):
617
+ self.grad_checkpointing = enable
618
+
619
+ @torch.jit.ignore
620
+ def no_weight_decay(self):
621
+ return {'pos_embed', 'cls_token'}
622
+
623
+ def get_classifier(self):
624
+ return self.head
625
+
626
+ def reset_classifier(self, num_classes, global_pool=''):
627
+ self.num_classes = num_classes
628
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
629
+
630
+ def forward_features(self, x, return_all_features=False):
631
+
632
+ x = self.patch_embed(x)
633
+ batch_size, seq_len, _ = x.size()
634
+
635
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
636
+ x = torch.cat((cls_tokens, x), dim=1)
637
+ if self.pos_embed is not None:
638
+ x = x + self.pos_embed
639
+ x = self.pos_drop(x)
640
+
641
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
642
+ if os.getenv('RoPE') == '1':
643
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
644
+ x, patch_indices_keep = self.patch_dropout(x)
645
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
646
+ else:
647
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
648
+ x = self.patch_dropout(x)
649
+ else:
650
+ x = self.patch_dropout(x)
651
+
652
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
653
+ for i, blk in enumerate(self.blocks):
654
+ if i == len(self.blocks)-1:
655
+ continue
656
+ if self.grad_checkpointing:
657
+ x = checkpoint(blk, x, (rel_pos_bias,))
658
+ else:
659
+ x = blk(x, rel_pos_bias=rel_pos_bias)
660
+
661
+ if not return_all_features:
662
+ x = self.norm(x)
663
+ if self.fc_norm is not None:
664
+ return self.fc_norm(x.mean(1))
665
+ else:
666
+ return x[:, 0]
667
+ return x
668
+
669
+ def forward(self, x, return_all_features=False):
670
+ if return_all_features:
671
+ return self.forward_features(x, return_all_features)
672
+ x = self.forward_features(x)
673
+ x = self.head(x)
674
+ return x
675
+
676
+ class LayerNorm(nn.LayerNorm):
677
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
678
+
679
+ def forward(self, x: torch.Tensor):
680
+ orig_type = x.dtype
681
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
682
+ return x.to(orig_type)
683
+
684
+ try:
685
+ from apex.normalization import FusedLayerNorm
686
+ except:
687
+ FusedLayerNorm = LayerNorm
688
+ print("Please 'pip install apex'")
689
+
690
+
691
+ @dataclass
692
+ class CLIPVisionCfg:
693
+ layers: Union[Tuple[int, int, int, int], int] = 12
694
+ width: int = 768
695
+ head_width: int = 64
696
+ mlp_ratio: float = 4.0
697
+ patch_size: int = 16
698
+ image_size: Union[Tuple[int, int], int] = 224
699
+ ls_init_value: Optional[float] = None # layer scale initial value
700
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
701
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
702
+ drop_path_rate: Optional[float] = None # drop path rate
703
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
704
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
705
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
706
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
707
+ timm_proj_bias: bool = False # enable bias final projection
708
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
709
+ qkv_bias: bool = True
710
+ fusedLN: bool = False
711
+ xattn: bool = False
712
+ postnorm: bool = False
713
+ rope: bool = False
714
+ pt_hw_seq_len: int = 16 # 224/14
715
+ intp_freq: bool = False
716
+ naiveswiglu: bool = False
717
+ subln: bool = False
718
+
719
+
720
+ def _build_vision_tower(
721
+ embed_dim: int,
722
+ vision_cfg: CLIPVisionCfg
723
+ ):
724
+ if isinstance(vision_cfg, dict):
725
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
726
+
727
+ if vision_cfg.eva_model_name:
728
+ vision_heads = vision_cfg.width // vision_cfg.head_width
729
+ norm_layer = LayerNorm
730
+ visual = EVAVisionTransformer(
731
+ img_size=vision_cfg.image_size,
732
+ patch_size=vision_cfg.patch_size,
733
+ num_classes=embed_dim,
734
+ use_mean_pooling=vision_cfg.global_average_pool, #False
735
+ init_values=vision_cfg.ls_init_value,
736
+ patch_dropout=vision_cfg.patch_dropout,
737
+ embed_dim=vision_cfg.width,
738
+ depth=vision_cfg.layers,
739
+ num_heads=vision_heads,
740
+ mlp_ratio=vision_cfg.mlp_ratio,
741
+ qkv_bias=vision_cfg.qkv_bias,
742
+ drop_path_rate=vision_cfg.drop_path_rate,
743
+ norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
744
+ xattn=vision_cfg.xattn,
745
+ rope=vision_cfg.rope,
746
+ postnorm=vision_cfg.postnorm,
747
+ pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
748
+ intp_freq= vision_cfg.intp_freq,
749
+ naiveswiglu= vision_cfg.naiveswiglu,
750
+ subln= vision_cfg.subln
751
+ )
752
+
753
+ return visual
754
+
755
+ class Eva2LargeEncoder(nn.Module):
756
+ def __init__(self, image_size=224):
757
+ super(Eva2LargeEncoder, self).__init__()
758
+ self.config = {
759
+ "embed_dim": 768,
760
+ "vision_cfg": {
761
+ "image_size": 336,
762
+ "layers": 24,
763
+ "width": 1024,
764
+ "drop_path_rate": 0,
765
+ "head_width": 64,
766
+ "mlp_ratio": 2.6667,
767
+ "patch_size": 14,
768
+ "eva_model_name": "eva-clip-l-14-336",
769
+ "xattn": True,
770
+ "fusedLN": True,
771
+ "rope": True,
772
+ "pt_hw_seq_len": 16,
773
+ "intp_freq": True,
774
+ "naiveswiglu": True,
775
+ "subln": True
776
+ }
777
+ }
778
+ self.config['vision_cfg']['image_size'] = image_size
779
+
780
+ import os
781
+ os.environ['delRoPE'] = '1' # to avoid error in rope params when changing image size
782
+ self.model = _build_vision_tower(**self.config)
783
+
784
+
785
+ def forward(self, images):
786
+ encode = self.model(images, return_all_features=True)[:, 1:, :]
787
+ return encode
788
+
789
+ class CrossVisionModel(nn.Module):
790
+ def __init__(self, config):
791
+ super().__init__()
792
+ self.vit = Eva2LargeEncoder(image_size=config.cross_image_size)
793
+ self.pos_embed = nn.Parameter(torch.zeros((self.vit.config['vision_cfg']['image_size'] // self.vit.config['vision_cfg']['patch_size']) ** 2, self.vit.config['vision_cfg']['width']))
794
+
795
+ def forward(self, images):
796
+ enc = self.vit(images)
797
+ return enc + self.pos_embed.unsqueeze(0)
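A rough note on the shapes implied by `CrossVisionModel` above: `Eva2LargeEncoder` drops the CLS token, so the module returns one 1024-dimensional feature per image patch. A trivial sketch of the token count under the shipped `cross_image_size=1120` and `patch_size=14`:

```python
# Token count of the high-resolution cross-vision branch (values from config.json).
cross_image_size, patch_size, width = 1120, 14, 1024
num_patches = (cross_image_size // patch_size) ** 2
print(num_patches, width)  # 6400 patch tokens of dimension 1024, fed to CrossAttention
```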
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "pad_token_id": 0,
+   "transformers_version": "4.36.0.dev0"
+ }
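For reference, these values are what `model.generation_config` carries after loading; a minimal sketch of reading the file through the standard API (repo id as in the README):

```python
from transformers import GenerationConfig

# Loads the generation_config.json shown above.
gen_cfg = GenerationConfig.from_pretrained("THUDM/cogagent-chat-hf")
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id)  # 1 2 0
```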
model-00001-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15c2451e4e0dd5caae61e91de67ea0dac7b554c4e2b39d54e67ffbe232460063
+ size 4974581824
model-00002-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:351e987e5f1f28124a8c839085aeb3111e8f956393d3092439e07c7a285a5d90
+ size 4982995648
model-00003-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff93f061698e4a8bbc042d9f243efd9647b19889f0e5f87188cd6bfc0400f839
+ size 4982995728
model-00004-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54c8f6eca491986fb61914029014d80a7813c1538807a55296530c6337fe2829
+ size 4982995728
model-00005-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4218cb8e5bda353d4491a70054b2e8f384f00c55bcc453807309982760fc48fd
+ size 4982995728
model-00006-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23f2f3123c49dc4fe16fa44f6ee3dd388b18490fa05966f9dd21e806e3cbd22d
+ size 4950060832
model-00007-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d6eec2f2b61a1eb39dfb7703231639e16394bba2348108f1ed970268abac86e
+ size 4945866712
model-00008-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6df8c2486c4d21b20c5b2a72a699b47fd8ac62199ba9ae69db32054ef4fab1a2
+ size 1783098344
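Each shard entry above is a Git LFS pointer: the repository stores only the shard's sha256 (`oid`) and byte size, while the actual weights are fetched on download. A small sketch for verifying a downloaded shard against its pointer (the local path is illustrative):

```python
import hashlib
import os

# Verify a downloaded shard against the oid/size recorded in its LFS pointer.
path = "model-00008-of-00008.safetensors"  # illustrative local path
sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
print(os.path.getsize(path))  # expect 1783098344 for this shard
print(sha.hexdigest())        # expect the sha256 oid shown above
```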
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_cogagent.py ADDED
@@ -0,0 +1,917 @@
1
+ """largely copy from llama and adapt for CogAgent"""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+ from transformers.utils.logging import get_logger
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from .configuration_cogagent import CogAgentConfig
18
+ from .util import FastRotaryEmbedding
19
+ from .visual import EVA2CLIPModel
20
+ from .cross_visual import CrossVisionModel
21
+
22
+ if TYPE_CHECKING:
23
+ from transformers.utils import ModelOutput
24
+
25
+ logger = get_logger(__name__)
26
+
27
+ LANGUAGE_TOKEN_TYPE = 0
28
+ VISION_TOKEN_TYPE = 1
29
+
30
+
31
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
+ def _make_causal_mask(
33
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
34
+ ):
35
+ """
36
+ Make causal mask used for bi-directional self-attention.
37
+ """
38
+ bsz, tgt_len = input_ids_shape
39
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
40
+ mask_cond = torch.arange(mask.size(-1), device=device)
41
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
42
+ mask = mask.to(dtype)
43
+
44
+ if past_key_values_length > 0:
45
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
46
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
47
+
48
+
49
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
50
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
51
+ """
52
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
53
+ """
54
+ bsz, src_len = mask.size()
55
+ tgt_len = tgt_len if tgt_len is not None else src_len
56
+
57
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
58
+
59
+ inverted_mask = 1.0 - expanded_mask
60
+
61
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
62
+
63
+
64
+ class RMSNorm(nn.Module):
65
+ def __init__(self, hidden_size, eps=1e-6):
66
+ super().__init__()
67
+ self.weight = nn.Parameter(torch.ones(hidden_size))
68
+ self.variance_epsilon = eps
69
+
70
+ def forward(self, hidden_states):
71
+ input_dtype = hidden_states.dtype
72
+ hidden_states = hidden_states.to(torch.float32)
73
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
74
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
75
+ return (self.weight * hidden_states).to(input_dtype)
76
+
77
+
78
+ class MLP(nn.Module):
79
+ def __init__(self, config):
80
+ super().__init__()
81
+ self.hidden_size = config.hidden_size
82
+ self.intermediate_size = config.intermediate_size
83
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
84
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
86
+ self.act_fn = ACT2FN[config.hidden_act]
87
+
88
+ def forward(self, x):
89
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
90
+ return down_proj
91
+
92
+
93
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
94
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
95
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
96
+ language_token_mask = ~vision_token_mask
97
+ return vision_token_mask, language_token_mask
98
+
99
+
100
+ class VisionExpertMLP(nn.Module):
101
+ def __init__(self, config):
102
+ super().__init__()
103
+ self.language_mlp = MLP(config)
104
+ self.vision_mlp = MLP(config)
105
+
106
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
107
+ output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
108
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
109
+ output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
110
+ output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
111
+ return output
112
+
113
+
114
+ def attention_fn(
115
+ query_layer: "torch.tensor(B, H, L, HD)",
116
+ key_layer: "torch.tensor(B, H, L, HD)",
117
+ value_layer: "torch.tensor(B, H, L, HD)",
118
+ attention_mask: "torch.tensor(B, H, L, HD)",
119
+ *,
120
+ scaling_attention_score: bool = True,
121
+ attention_dropout: nn.Module = None
122
+ ):
123
+ attention_mask_bool = (attention_mask == 0)
124
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
125
+ is_full = (attention_mask_bool > 0).all()
126
+ if not (int(torch.__version__.split('.')[0]) >= 2):
127
+ warnings.warn("It's recommended to use torch2.0 or higher.")
128
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
129
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
130
+ return torch.nn.functional.scaled_dot_product_attention(
131
+ query_layer, key_layer, value_layer,
132
+ attn_mask=None,
133
+ dropout_p=dropout_p,
134
+ is_causal=not is_full
135
+ )
136
+ else:
137
+ if scaling_attention_score:
138
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
139
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
140
+ attention_scores = attention_scores + attention_mask
141
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
142
+ if attention_dropout is not None:
143
+ attention_scores = attention_dropout(attention_scores)
144
+ context_layer = torch.matmul(attention_scores, value_layer)
145
+ return context_layer
146
+
147
+
148
+ class VisionExpertAttention(nn.Module):
149
+ def __init__(self, config):
150
+ super().__init__()
151
+ self.config = config
152
+ self.hidden_size = config.hidden_size
153
+ self.num_heads = config.num_attention_heads
154
+ self.head_dim = self.hidden_size // self.num_heads
155
+ self.max_position_embeddings = config.max_position_embeddings
156
+
157
+ # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
158
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
159
+ self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
160
+ self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
161
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
162
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
163
+
164
+ def _transpose_for_scores(self, tensor):
165
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
166
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
167
+ tensor = tensor.view(*new_tensor_shape)
168
+ return tensor.permute(0, 2, 1, 3)
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states: torch.Tensor,
173
+ token_type_ids: torch.LongTensor,
174
+ position_ids: torch.LongTensor,
175
+ attention_mask: Optional[torch.Tensor] = None,
176
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
177
+ output_attentions: bool = False,
178
+ use_cache: bool = False,
179
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
180
+ bsz, q_len, _ = hidden_states.size()
181
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
182
+
183
+ shape = list(hidden_states.shape)
184
+ shape[-1] = shape[-1] * 3
185
+ mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
186
+ mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
187
+ mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
188
+
189
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
190
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
191
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
192
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
193
+
194
+ kv_seq_len = key_states.shape[-2]
195
+ if past_key_value is not None:
196
+ kv_seq_len += past_key_value[0].shape[-2]
197
+
198
+ query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
199
+
200
+ if past_key_value is not None:
201
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
202
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
203
+
204
+ past_key_value = (key_states, value_states) if use_cache else None
205
+
206
+ context_layer = attention_fn(
207
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
208
+ scaling_attention_score=True, attention_dropout=None)
209
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
210
+ raise ValueError(
211
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
212
+ f" {context_layer.size()}"
213
+ )
214
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
215
+
216
+ attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
217
+ attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
218
+ attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
219
+
220
+ if output_attentions:
221
+ warnings.warn("output_attentions is not implemented.")
222
+
223
+ return attn_output, None, past_key_value
224
+
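The block above is the heart of the "visual expert" attention: image tokens and text tokens share the same attention computation, but each group is projected by its own QKV and output matrices, selected with boolean masks over `token_type_ids`. Below is a minimal, self-contained sketch of that routing pattern with hypothetical sizes; the real `get_expert_mask` defined earlier in this file derives the masks from `token_type_ids`, and a plain equality check stands in for it here.

```python
import torch
import torch.nn as nn

# Assumed token type ids (the real constants are defined near the top of this file).
LANGUAGE_TOKEN_TYPE, VISION_TOKEN_TYPE = 0, 1

hidden_size, bsz, seq_len = 8, 2, 5
hidden_states = torch.randn(bsz, seq_len, hidden_size)
token_type_ids = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]])  # 1 = vision token, 0 = language token

vision_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)    # vision expert
language_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)  # language expert

vision_mask = token_type_ids == VISION_TOKEN_TYPE
language_mask = ~vision_mask

# Each token is projected by the expert of its modality; the results are scattered
# back into one tensor so the attention itself stays shared across modalities.
mixed_qkv = torch.empty(bsz, seq_len, hidden_size * 3)
mixed_qkv[vision_mask] = vision_qkv(hidden_states[vision_mask])
mixed_qkv[language_mask] = language_qkv(hidden_states[language_mask])
print(mixed_qkv.shape)  # torch.Size([2, 5, 24])
```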
225
+ class CrossAttention(nn.Module):
226
+ def __init__(self, config):
227
+ super().__init__()
228
+ self.config = config
229
+ self.hidden_size = config.hidden_size
230
+ self.cross_hidden_size = config.cross_hidden_size
231
+ self.cross_compute_hidden_size = config.cross_compute_hidden_size
232
+ self.num_heads = config.num_attention_heads
233
+ self.head_dim = self.hidden_size // self.num_heads
234
+ self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
235
+ self.max_position_embeddings = config.max_position_embeddings
236
+
237
+ # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
238
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
239
+ self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
240
+ self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
241
+ self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
242
+
243
+ def _transpose_for_scores(self, tensor):
244
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
245
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
246
+ tensor = tensor.view(*new_tensor_shape)
247
+ return tensor.permute(0, 2, 1, 3)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states: torch.Tensor,
252
+ encoder_outputs: torch.LongTensor,
253
+ attention_mask: Optional[torch.Tensor] = None,
254
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
+ output_attentions: bool = False,
256
+ use_cache: bool = False,
257
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
258
+ bsz, q_len, _ = hidden_states.size()
259
+
260
+ shape = list(hidden_states.shape)
261
+ shape[-1] = shape[-1] * 3
262
+
263
+ mixed_query_layer = self.query(hidden_states)
264
+ if past_key_value is None:
265
+ mixed_x_layer = self.key_value(encoder_outputs)
266
+ mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
267
+ key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
268
+ value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
269
+ else:
270
+ key_states, value_states = past_key_value
271
+
272
+ query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
273
+
274
+ past_key_value = (key_states, value_states) if use_cache else None
275
+
276
+ context_layer = attention_fn(
277
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
278
+ scaling_attention_score=True, attention_dropout=None)
279
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
280
+ raise ValueError(
281
+ f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
282
+ f" {context_layer.size()}"
283
+ )
284
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_compute_hidden_size)
285
+
286
+ attn_output = self.dense(context_layer)
287
+
288
+ if output_attentions:
289
+ warnings.warn("output_attentions is not implemented.")
290
+
291
+ return attn_output, None, past_key_value
292
+
293
+ class CogAgentDecoderLayer(nn.Module):
294
+ def __init__(self, config):
295
+ super().__init__()
296
+ self.hidden_size = config.hidden_size
297
+ self.self_attn = VisionExpertAttention(config=config)
298
+ self.cross_attn = CrossAttention(config=config)
299
+ self.mlp = VisionExpertMLP(config)
300
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302
+ self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ encoder_outputs: torch.Tensor,
308
+ token_type_ids: torch.LongTensor,
309
+ position_ids: torch.LongTensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ cross_attention_mask: Optional[torch.Tensor] = None,
312
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
313
+ output_attentions: Optional[bool] = False,
314
+ use_cache: Optional[bool] = False,
315
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
316
+ residual = hidden_states
317
+
318
+ hidden_states = self.input_layernorm(hidden_states)
319
+
320
+ # Self Attention
321
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
322
+ hidden_states=hidden_states,
323
+ token_type_ids=token_type_ids,
324
+ position_ids=position_ids,
325
+ attention_mask=attention_mask,
326
+ past_key_value=past_key_value[:2] if past_key_value is not None else None,
327
+ output_attentions=output_attentions,
328
+ use_cache=use_cache,
329
+ )
330
+ hidden_states = residual + hidden_states
331
+
332
+ cross_input = self.post_cross_attention_layernorm(hidden_states)
333
+ # Cross Attention (over the high-resolution image features)
334
+ attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
335
+ hidden_states=cross_input,
336
+ encoder_outputs=encoder_outputs,
337
+ attention_mask=cross_attention_mask,
338
+ past_key_value=past_key_value[-2:] if past_key_value is not None else None,
339
+ output_attentions=output_attentions,
340
+ use_cache=use_cache,
341
+ )
342
+ hidden_states = hidden_states + attention_output
343
+ mlp_input = self.post_attention_layernorm(hidden_states)
344
+ mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
345
+ hidden_states = mlp_output + hidden_states
346
+
347
+ outputs = (hidden_states,)
348
+
349
+ if output_attentions:
350
+ outputs += (self_attn_weights,)
351
+
352
+ if use_cache:
353
+ outputs += (present_key_value+present_cross_key_value,)
354
+
355
+ return outputs # type: ignore
356
+
357
+
358
+ class CogAgentPreTrainedModel(PreTrainedModel):
359
+ config_class = CogAgentConfig
360
+ base_model_prefix = "model"
361
+ supports_gradient_checkpointing = False
362
+ _no_split_modules = ["CogAgentDecoderLayer"]
363
+ _skip_keys_device_placement = "past_key_values"
364
+
365
+ def _init_weights(self, module):
366
+ std = self.config.initializer_range
367
+ if isinstance(module, nn.Linear):
368
+ module.weight.data.normal_(mean=0.0, std=std)
369
+ if module.bias is not None:
370
+ module.bias.data.zero_()
371
+ elif isinstance(module, nn.Embedding):
372
+ module.weight.data.normal_(mean=0.0, std=std)
373
+ if module.padding_idx is not None:
374
+ module.weight.data[module.padding_idx].zero_()
375
+
376
+
377
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
378
+ if images_list is None or len(images_list) == 0:
379
+ return True
380
+ for image_list in images_list:
381
+ if len(image_list):
382
+ return False
383
+ return True
384
+
385
+
386
+ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
387
+ if attention_mask is not None:
388
+ tmp = x.clone()
389
+ tmp[~(attention_mask.bool())] = -1
390
+ else:
391
+ tmp = x.clone()
392
+ # treat image boi/eoi tokens as LANGUAGE_TOKEN_TYPE
393
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
394
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
395
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
396
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
397
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
398
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
399
+ # final position ids
400
+ y = torch.zeros_like(x, dtype=torch.long)
401
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
402
+ y = y.cumsum(dim=-1)
403
+ return y
404
+
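`build_position_ids` collapses each image into (almost) a single position: the counter advances on language tokens and on the first vision token after a language token, while the boi/eoi boundary tokens are re-typed as language. A small worked example, assuming the function above and the usual token-type constants (0 for language, 1 for vision, as assumed here) are in scope:

```python
import torch

# bos, then boi + 3 image patches + eoi (all typed as vision), then two text tokens
token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 1, 0, 0]])
print(build_position_ids(token_type_ids))
# tensor([[0, 1, 2, 2, 2, 3, 4, 5]])  -> the three patch tokens share position 2
```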
405
+
406
+ class CogAgentModel(CogAgentPreTrainedModel):
407
+ def __init__(self, config):
408
+ super().__init__(config)
409
+ self.padding_idx = config.pad_token_id
410
+ self.vocab_size = config.vocab_size
411
+
412
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
413
+ self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
414
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
415
+
416
+ self.vision = EVA2CLIPModel(config)
417
+ self.cross_vision = CrossVisionModel(config)
418
+
419
+ self.gradient_checkpointing = False
420
+ # Initialize weights and apply final processing
421
+ self.post_init()
422
+
423
+ def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
424
+ images_list, images = images, []
425
+
426
+ images = []
427
+ for image_list in images_list:
428
+ for image in image_list:
429
+ images.append(image)
430
+
431
+ images = torch.stack(images)
432
+ images_features = self.vision(images)
433
+ return images_features
434
+
435
+ def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
436
+ images_list, images = images, []
437
+
438
+ images = []
439
+ for image_list in images_list:
440
+ for image in image_list:
441
+ images.append(image)
442
+
443
+ images = torch.stack(images)
444
+ encoder_outputs = self.cross_vision(images)
445
+ return encoder_outputs
446
+
447
+ def forward(
448
+ self,
449
+ input_ids: torch.LongTensor = None,
450
+ images: List[List[torch.Tensor]] = None,
451
+ cross_images: List[List[torch.Tensor]] = None,
452
+ token_type_ids: Optional[torch.LongTensor] = None,
453
+ attention_mask: Optional[torch.Tensor] = None,
454
+ cross_attention_mask: Optional[torch.Tensor] = None,
455
+ position_ids: Optional[torch.LongTensor] = None,
456
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
457
+ inputs_embeds: Optional[torch.FloatTensor] = None,
458
+ use_cache: Optional[bool] = None,
459
+ output_attentions: Optional[bool] = None,
460
+ output_hidden_states: Optional[bool] = None,
461
+ return_dict: Optional[bool] = None,
462
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
463
+ """Takes care of image encoding, token_type_ids and position_ids (attention_mask=None is fine)."""
464
+
465
+ if past_key_values is not None:
466
+ encoder_outputs = None
467
+ # generation mode with past_key_values: the image features were already consumed in the first forward pass
468
+ else:
469
+ # inputs_embeds is not allowed here, because the image features must be built from input_ids
470
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
471
+ if not is_empty(images): # multi-modality
472
+ assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
473
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
474
+ inputs_embeds = self.embed_tokens(input_ids)
475
+ images_features = self.encode_images(images)
476
+ encoder_outputs = self.encode_cross_images(cross_images)
477
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
478
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
479
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
480
+ else: # single-modality
481
+ if token_type_ids is None:
482
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
483
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
484
+ inputs_embeds = self.embed_tokens(input_ids)
485
+ encoder_outputs = None
486
+
487
+ if position_ids is None:
488
+ position_ids = build_position_ids(token_type_ids, attention_mask)
489
+ input_ids = None
490
+
491
+ return self.llm_forward(
492
+ input_ids=input_ids,
493
+ encoder_outputs=encoder_outputs,
494
+ token_type_ids=token_type_ids,
495
+ attention_mask=attention_mask,
496
+ cross_attention_mask=cross_attention_mask,
497
+ position_ids=position_ids,
498
+ past_key_values=past_key_values,
499
+ inputs_embeds=inputs_embeds,
500
+ use_cache=use_cache,
501
+ output_attentions=output_attentions,
502
+ output_hidden_states=output_hidden_states,
503
+ return_dict=return_dict,
504
+ )
505
+
506
+ def llm_forward(
507
+ self,
508
+ input_ids: torch.LongTensor = None,
509
+ encoder_outputs: torch.LongTensor = None,
510
+ token_type_ids: torch.LongTensor = None,
511
+ attention_mask: Optional[torch.Tensor] = None,
512
+ cross_attention_mask: Optional[torch.Tensor] = None,
513
+ position_ids: Optional[torch.LongTensor] = None,
514
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
515
+ inputs_embeds: Optional[torch.FloatTensor] = None,
516
+ use_cache: Optional[bool] = None,
517
+ output_attentions: Optional[bool] = None,
518
+ output_hidden_states: Optional[bool] = None,
519
+ return_dict: Optional[bool] = None,
520
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
521
+ """Largely copied from the llama forward and adapted for CogAgent with `token_type_ids`."""
522
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
523
+ output_hidden_states = (
524
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
525
+ )
526
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
527
+
528
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
529
+
530
+ # retrieve input_ids and inputs_embeds
531
+ if input_ids is not None and inputs_embeds is not None:
532
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
533
+ elif input_ids is not None:
534
+ batch_size, seq_length = input_ids.shape
535
+ elif inputs_embeds is not None:
536
+ batch_size, seq_length, _ = inputs_embeds.shape
537
+ else:
538
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
539
+
540
+ seq_length_with_past = seq_length
541
+ past_key_values_length = 0
542
+
543
+ if past_key_values is not None:
544
+ past_key_values_length = past_key_values[0][0].shape[2]
545
+ seq_length_with_past = seq_length_with_past + past_key_values_length
546
+
547
+ if position_ids is None:
548
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
549
+ position_ids = torch.arange(
550
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
551
+ )
552
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
553
+ else:
554
+ position_ids = position_ids.view(-1, seq_length).long()
555
+
556
+ if inputs_embeds is None:
557
+ inputs_embeds = self.embed_tokens(input_ids)
558
+ # embed positions
559
+ if attention_mask is None:
560
+ attention_mask = torch.ones(
561
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
562
+ )
563
+ if cross_attention_mask is None:
564
+ cross_attention_mask = torch.ones(
565
+ (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
566
+ )
567
+ attention_mask = self._prepare_decoder_attention_mask(
568
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
569
+ )
570
+
571
+ hidden_states = inputs_embeds
572
+
573
+ # decoder layers
574
+ all_hidden_states = () if output_hidden_states else None
575
+ all_self_attns = () if output_attentions else None
576
+ next_decoder_cache = () if use_cache else None
577
+
578
+ for idx, decoder_layer in enumerate(self.layers):
579
+ if output_hidden_states:
580
+ all_hidden_states += (hidden_states,)
581
+
582
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
583
+ layer_outputs = decoder_layer(
584
+ hidden_states,
585
+ encoder_outputs=encoder_outputs,
586
+ token_type_ids=token_type_ids,
587
+ attention_mask=attention_mask,
588
+ cross_attention_mask=cross_attention_mask,
589
+ position_ids=position_ids,
590
+ past_key_value=past_key_value,
591
+ output_attentions=output_attentions,
592
+ use_cache=use_cache,
593
+ )
594
+ hidden_states = layer_outputs[0]
595
+
596
+ if use_cache:
597
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
598
+
599
+ if output_attentions:
600
+ all_self_attns += (layer_outputs[1],)
601
+
602
+ hidden_states = self.norm(hidden_states)
603
+
604
+ # add hidden states from the last decoder layer
605
+ if output_hidden_states:
606
+ all_hidden_states += (hidden_states,)
607
+
608
+ next_cache = next_decoder_cache if use_cache else None
609
+ if not return_dict:
610
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
611
+ return BaseModelOutputWithPast(
612
+ last_hidden_state=hidden_states,
613
+ past_key_values=next_cache,
614
+ hidden_states=all_hidden_states,
615
+ attentions=all_self_attns,
616
+ )
617
+
618
+ def get_input_embeddings(self):
619
+ return self.embed_tokens
620
+
621
+ def set_input_embeddings(self, value):
622
+ self.embed_tokens = value
623
+
624
+ # noinspection PyMethodMayBeStatic
625
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
626
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
627
+ # create causal mask
628
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
629
+ combined_attention_mask = None
630
+ if input_shape[-1] > 1:
631
+ combined_attention_mask = _make_causal_mask(
632
+ input_shape,
633
+ inputs_embeds.dtype,
634
+ device=inputs_embeds.device,
635
+ past_key_values_length=past_key_values_length,
636
+ )
637
+
638
+ if attention_mask is not None:
639
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
640
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
641
+ inputs_embeds.device
642
+ )
643
+ combined_attention_mask = (
644
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
645
+ )
646
+
647
+ return combined_attention_mask
648
+
649
+ def chat_old_history_to_prompt(history, query):
650
+ prompt = "<EOI>Question: "
651
+ for i, (old_query, response) in enumerate(history):
652
+ prompt += old_query + " Answer: " + response + "\nQuestion: "
653
+ prompt += query + " Answer:"
654
+ return prompt
655
+
656
+ def chat_history_to_prompt(history, query):
657
+ prompt = " [INST] "
658
+ for i, (old_query, response) in enumerate(history):
659
+ prompt += old_query + " [/INST] " + response + " [INST] "
660
+ prompt += query + " [/INST] "
661
+ return prompt
662
+
663
+
664
+ def base_history_to_prompt(history, query):
665
+ prompt = query
666
+ return prompt
667
+
668
+
669
+ _history_to_prompt = {
670
+ "base": base_history_to_prompt,
671
+ "chat": chat_history_to_prompt,
672
+ "chat_old": chat_old_history_to_prompt
673
+ }
674
+
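The three template builders only produce the raw prompt string that is later tokenized in `build_conversation_input_ids`. For example, the outputs below follow directly from the functions above (newline shown escaped):

```python
history = [("What is on the desk?", "A laptop and a mug.")]

print(chat_history_to_prompt(history, "What color is the mug?"))
# ' [INST] What is on the desk? [/INST] A laptop and a mug. [INST] What color is the mug? [/INST] '

print(chat_old_history_to_prompt(history, "What color is the mug?"))
# '<EOI>Question: What is on the desk? Answer: A laptop and a mug.\nQuestion: What color is the mug? Answer:'
```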
675
+
676
+ class CogAgentForCausalLM(CogAgentPreTrainedModel):
677
+ _auto_class = "AutoModelForCausalLM"
678
+
679
+ def __init__(self, config):
680
+ super().__init__(config)
681
+ self.model = CogAgentModel(config)
682
+ self.vocab_size = config.vocab_size
683
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
684
+
685
+ # Initialize weights and apply final processing
686
+ self.post_init()
687
+
688
+ def get_input_embeddings(self):
689
+ return self.model.embed_tokens
690
+
691
+ def set_input_embeddings(self, value):
692
+ self.model.embed_tokens = value
693
+
694
+ def get_output_embeddings(self):
695
+ return self.lm_head
696
+
697
+ def set_output_embeddings(self, new_embeddings):
698
+ self.lm_head = new_embeddings
699
+
700
+ def set_decoder(self, decoder):
701
+ self.model = decoder
702
+
703
+ def get_decoder(self):
704
+ return self.model
705
+
706
+ def forward(
707
+ self,
708
+ input_ids: torch.LongTensor = None,
709
+ images: List[List[torch.Tensor]] = None,
710
+ cross_images: List[List[torch.Tensor]] = None,
711
+ token_type_ids: Optional[torch.LongTensor] = None,
712
+ attention_mask: Optional[torch.Tensor] = None,
713
+ position_ids: Optional[torch.LongTensor] = None,
714
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
715
+ inputs_embeds: Optional[torch.FloatTensor] = None,
716
+ use_cache: Optional[bool] = None,
717
+ output_attentions: Optional[bool] = None,
718
+ output_hidden_states: Optional[bool] = None,
719
+ return_dict: Optional[bool] = None,
720
+ labels: Optional[torch.LongTensor] = None,
721
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
722
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
723
+ output_hidden_states = (
724
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
725
+ )
726
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
727
+
728
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
729
+ outputs = self.model(
730
+ input_ids=input_ids,
731
+ images=images,
732
+ cross_images=cross_images,
733
+ token_type_ids=token_type_ids,
734
+ attention_mask=attention_mask,
735
+ position_ids=position_ids,
736
+ past_key_values=past_key_values,
737
+ inputs_embeds=inputs_embeds,
738
+ use_cache=use_cache,
739
+ output_attentions=output_attentions,
740
+ output_hidden_states=output_hidden_states,
741
+ return_dict=return_dict,
742
+ )
743
+
744
+ hidden_states = outputs[0]
745
+ logits = self.lm_head(hidden_states)
746
+ logits = logits.float()
747
+
748
+ loss = None
749
+ if labels is not None:
750
+ # Shift so that tokens < n predict n
751
+ shift_logits = logits[..., :-1, :].contiguous()
752
+ shift_labels = labels[..., 1:].contiguous()
753
+ # Flatten the tokens
754
+ loss_fct = CrossEntropyLoss()
755
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
756
+ shift_labels = shift_labels.view(-1)
757
+ # Enable model parallelism
758
+ shift_labels = shift_labels.to(shift_logits.device)
759
+ loss = loss_fct(shift_logits, shift_labels)
760
+
761
+ if not return_dict:
762
+ output = (logits,) + outputs[1:]
763
+ return (loss,) + output if loss is not None else output
764
+
765
+ return CausalLMOutputWithPast(
766
+ loss=loss,
767
+ logits=logits,
768
+ past_key_values=outputs.past_key_values,
769
+ hidden_states=outputs.hidden_states,
770
+ attentions=outputs.attentions,
771
+ )
772
+
773
+ def _prepare_attention_mask_for_generation(
774
+ self,
775
+ inputs: torch.Tensor,
776
+ pad_token_id: Optional[int],
777
+ eos_token_id: Optional[Union[int, List[int]]],
778
+ ) -> torch.LongTensor:
779
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
780
+
781
+ def prepare_inputs_for_generation(
782
+ self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
783
+ ):
784
+ # build position_ids if needed
785
+ position_ids = kwargs.get("position_ids", None)
786
+ if position_ids is None:
787
+ position_ids = build_position_ids(token_type_ids, attention_mask)
788
+
789
+ if past_key_values:
790
+ input_ids = input_ids[:, -1:]
791
+ token_type_ids = token_type_ids[:, -1:]
792
+ position_ids = position_ids[:, -1:]
793
+
794
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
795
+ if inputs_embeds is not None and past_key_values is None:
796
+ model_inputs = {"inputs_embeds": inputs_embeds}
797
+ else:
798
+ model_inputs = {"input_ids": input_ids}
799
+
800
+ model_inputs.update(
801
+ {
802
+ "token_type_ids": token_type_ids,
803
+ "images": images,
804
+ "cross_images": cross_images,
805
+ "position_ids": position_ids,
806
+ "past_key_values": past_key_values,
807
+ "use_cache": kwargs.get("use_cache"),
808
+ "attention_mask": attention_mask,
809
+ }
810
+ )
811
+ return model_inputs
812
+
813
+ def _update_model_kwargs_for_generation(
814
+ self,
815
+ outputs: "ModelOutput",
816
+ model_kwargs: Dict[str, Any],
817
+ is_encoder_decoder: bool = False,
818
+ standardize_cache_format: bool = False,
819
+ ) -> Dict[str, Any]:
820
+ # update past_key_values
821
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
822
+ outputs, standardize_cache_format=standardize_cache_format
823
+ )
824
+ if getattr(outputs, "state", None) is not None:
825
+ model_kwargs["state"] = outputs.state
826
+
827
+ # update token_type_ids with last value
828
+ if "token_type_ids" in model_kwargs:
829
+ token_type_ids = model_kwargs["token_type_ids"]
830
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
831
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
832
+
833
+ if not is_encoder_decoder:
834
+ # update attention mask
835
+ if "attention_mask" in model_kwargs:
836
+ attention_mask = model_kwargs["attention_mask"]
837
+ model_kwargs["attention_mask"] = torch.cat(
838
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
839
+ )
840
+ else:
841
+ # update decoder attention mask
842
+ if "decoder_attention_mask" in model_kwargs:
843
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
844
+ model_kwargs["decoder_attention_mask"] = torch.cat(
845
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
846
+ dim=-1,
847
+ )
848
+
849
+ return model_kwargs
850
+
851
+ def _reorder_cache(self, past_key_values, beam_idx):
852
+ reordered_past = ()
853
+ for layer_past in past_key_values:
854
+ reordered_past += (
855
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
856
+ )
857
+ return reordered_past
858
+
859
+ def build_conversation_input_ids(
860
+ self,
861
+ tokenizer: "PreTrainedTokenizer",
862
+ *,
863
+ query: str,
864
+ history: Optional[List[Tuple[str, str]]] = None,
865
+ images: Optional[List["PIL.Image"]] = None,
866
+ template_version: Optional[Literal["base", "chat", "chat_old"]] = None,
867
+ ):
868
+ image_size: int = self.config.vision_config['image_size']
869
+ cross_image_size: int = self.config.cross_image_size
870
+ patch_size: int = self.config.vision_config['patch_size']
871
+ template_version = template_version or self.config.template_version
872
+ assert images is None or len(images) <= 1, "multiple images are not supported yet."
873
+ history = history or []
874
+ text = _history_to_prompt[template_version](history, query)
875
+
876
+ input_ids = [tokenizer.bos_token_id]
877
+ token_type_ids = [LANGUAGE_TOKEN_TYPE]
878
+ cross_images = []  # keep defined for the text-only path below
+ if images is not None and len(images) == 1:
879
+ ori = images
880
+ # vision
881
+ transform = transforms.Compose(
882
+ [
883
+ transforms.Resize(
884
+ (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
885
+ ),
886
+ transforms.ToTensor(),
887
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
888
+ ]
889
+ )
890
+ images = [transform(ori[0])]
891
+ cross_transform = transforms.Compose(
892
+ [
893
+ transforms.Resize(
894
+ (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
895
+ ),
896
+ transforms.ToTensor(),
897
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
898
+ ]
899
+ )
900
+ cross_images = [cross_transform(ori[0])]
901
+ # language
902
+ vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
903
+ input_ids += [tokenizer.pad_token_id] * vision_token_num
904
+ token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
905
+ text_ids = tokenizer.encode(text, add_special_tokens=False)
906
+
907
+ input_ids += text_ids
908
+ token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
909
+ attention_mask = [1] * len(input_ids)
910
+
911
+ return {
912
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
913
+ 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
914
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
915
+ 'images': images,
916
+ 'cross_images': cross_images
917
+ }
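`build_conversation_input_ids` returns unbatched tensors plus the preprocessed low- and high-resolution images. One plausible way to feed its output to `generate` (a hedged sketch: `model`, `tokenizer` and a PIL `image` are assumed to be set up as in the Quick Start, and the exact dtype/device handling depends on how the model was loaded):

```python
import torch

query = "Describe this screenshot."
inputs = model.build_conversation_input_ids(
    tokenizer, query=query, history=[], images=[image], template_version="chat"
)
inputs = {
    "input_ids": inputs["input_ids"].unsqueeze(0).to(model.device),
    "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to(model.device),
    "attention_mask": inputs["attention_mask"].unsqueeze(0).to(model.device),
    "images": [[inputs["images"][0].to(model.device).to(model.dtype)]],
    "cross_images": [[inputs["cross_images"][0].to(model.device).to(model.dtype)]],
}
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    outputs = outputs[:, inputs["input_ids"].shape[1]:]
    print(tokenizer.decode(outputs[0]))
```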
util.py ADDED
@@ -0,0 +1,483 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+ import torch.nn.functional as F
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
+ # @triton.autotune(
12
+ # configs=[
13
+ # triton.Config({"BLOCK_M": 2}),
14
+ # triton.Config({"BLOCK_M": 4}),
15
+ # triton.Config({"BLOCK_M": 8}),
16
+ # triton.Config({"BLOCK_M": 16}),
17
+ # ],
18
+ # key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
19
+ # )
20
+ @triton.jit
21
+ def rotary_kernel(
22
+ OUT, # Pointers to matrices
23
+ X,
24
+ COS,
25
+ SIN,
26
+ CU_SEQLENS,
27
+ SEQLEN_OFFSETS, # this could be int or a pointer
28
+ # Matrix dimensions
29
+ seqlen,
30
+ nheads,
31
+ rotary_dim,
32
+ seqlen_ro,
33
+ CACHE_KEY_SEQLEN,
34
+ # strides
35
+ stride_out_batch,
36
+ stride_out_nheads,
37
+ stride_out_seqlen,
38
+ stride_out_headdim,
39
+ stride_x_batch,
40
+ stride_x_nheads,
41
+ stride_x_seqlen,
42
+ stride_x_headdim,
43
+ # Meta-parameters
44
+ BLOCK_K: tl.constexpr,
45
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
46
+ IS_VARLEN: tl.constexpr,
47
+ INTERLEAVED: tl.constexpr,
48
+ CONJUGATE: tl.constexpr,
49
+ BLOCK_M: tl.constexpr,
50
+ ):
51
+ pid_m = tl.program_id(axis=0)
52
+ pid_batch = tl.program_id(axis=1)
53
+ pid_head = tl.program_id(axis=2)
54
+ rotary_dim_half = rotary_dim // 2
55
+
56
+ if not IS_VARLEN:
57
+ X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
58
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
59
+ COS = COS + pid_batch * seqlen_ro * rotary_dim_half
60
+ SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
61
+ else:
62
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
63
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
64
+ X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
65
+ OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
66
+
67
+ if pid_m * BLOCK_M >= seqlen:
68
+ return
69
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
70
+ if not IS_SEQLEN_OFFSETS_TENSOR:
71
+ rm_cs = rm + SEQLEN_OFFSETS
72
+ else:
73
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
74
+ rk = tl.arange(0, BLOCK_K)
75
+ rk_half = tl.arange(0, BLOCK_K // 2)
76
+
77
+ if not INTERLEAVED:
78
+ # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
79
+ X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
80
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
81
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
82
+ cos = tl.load(
83
+ COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
84
+ )
85
+ sin = tl.load(
86
+ SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
87
+ )
88
+ x0 = tl.load(
89
+ X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
90
+ )
91
+ x1 = tl.load(
92
+ X + rotary_dim_half * stride_x_headdim,
93
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
94
+ other=0.0,
95
+ )
96
+ if CONJUGATE:
97
+ sin = -sin
98
+ o0 = x0 * cos - x1 * sin
99
+ o1 = x0 * sin + x1 * cos
100
+ # write back result
101
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
102
+ tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
103
+ tl.store(
104
+ OUT + rotary_dim_half * stride_out_headdim,
105
+ o1,
106
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
107
+ )
108
+ else:
109
+ # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
110
+ # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
111
+ # Loading x0 will be fast but x1 will be slow.
112
+ # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
113
+ # Then we do the calculation and use tl.where to pick out the right outputs for the even
114
+ # and for the odd indices.
115
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
116
+ rk_repeat = tl.arange(0, BLOCK_K) // 2
117
+ X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
118
+ X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
119
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
120
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
121
+ cos = tl.load(
122
+ COS,
123
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
124
+ other=1.0,
125
+ ).to(tl.float32)
126
+ sin = tl.load(
127
+ SIN,
128
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
129
+ other=0.0,
130
+ ).to(tl.float32)
131
+ x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
132
+ tl.float32
133
+ )
134
+ x1 = tl.load(
135
+ X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
136
+ ).to(tl.float32)
137
+ if CONJUGATE:
138
+ sin = -sin
139
+ x0_cos = x0 * cos
140
+ x1_sin = x1 * sin
141
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
142
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
143
+ tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
144
+
145
+
146
+ def apply_rotary(
147
+ x: torch.Tensor,
148
+ cos: torch.Tensor,
149
+ sin: torch.Tensor,
150
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
151
+ cu_seqlens: Optional[torch.Tensor] = None,
152
+ max_seqlen: Optional[int] = None,
153
+ interleaved=False,
154
+ inplace=False,
155
+ conjugate=False,
156
+ ) -> torch.Tensor:
157
+ """
158
+ Arguments:
159
+ x: (batch, nheads, seqlen, headdim)
161
+ (this adaptation always expects the 4-D layout; the varlen path is not used)
162
+ cos: (batch, seqlen_ro, rotary_dim / 2)
163
+ sin: (batch, seqlen_ro, rotary_dim / 2)
163
+ seqlen_offsets: integer or integer tensor of size (batch,)
164
+ cu_seqlens: (batch + 1,) or None
165
+ max_seqlen: int
166
+ Returns:
167
+ y: (batch, nheads, seqlen, headdim)
168
+ """
169
+
170
+ batch, nheads, seqlen, headdim = x.shape
171
+
172
+ batch_ro, seqlen_ro, rotary_dim = cos.shape
173
+
174
+ assert batch == batch_ro
175
+ assert sin.shape == cos.shape
176
+ rotary_dim *= 2
177
+ assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
178
+ assert headdim <= 256, "Only support headdim <= 256"
179
+
180
+ assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
181
+
182
+ assert (
183
+ cos.dtype == sin.dtype
184
+ ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
185
+ assert (
186
+ x.dtype == cos.dtype
187
+ ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
188
+
189
+ cos, sin = cos.contiguous(), sin.contiguous()
190
+ if isinstance(seqlen_offsets, torch.Tensor):
191
+ assert seqlen_offsets.shape == (batch,)
192
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
193
+ seqlen_offsets = seqlen_offsets.contiguous()
194
+ else:
195
+ assert seqlen_offsets + seqlen <= seqlen_ro
196
+
197
+ output = torch.empty_like(x) if not inplace else x
198
+ if rotary_dim < headdim and not inplace:
199
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
200
+
201
+ BLOCK_K = (
202
+ 32
203
+ if rotary_dim <= 32
204
+ else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
205
+ )
206
+ grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
207
+ BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
208
+
209
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
210
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
211
+ with torch.cuda.device(x.device.index):
212
+ rotary_kernel[grid](
213
+ output, # data ptrs
214
+ x,
215
+ cos,
216
+ sin,
217
+ cu_seqlens,
218
+ seqlen_offsets,
219
+ seqlen, # shapes
220
+ nheads,
221
+ rotary_dim,
222
+ seqlen_ro,
223
+ seqlen // 128, # key for triton cache (limit number of compilations)
224
+ output.stride(0), # batch_strides
225
+ output.stride(-3), # nheads_stride
226
+ output.stride(-2), # seqlen_stride
227
+ output.stride(-1), # headdim_stride
228
+ x.stride(0), # batch_strides
229
+ x.stride(-3), # nheads stride
230
+ x.stride(-2), # seqlen stride
231
+ x.stride(-1), # headdim stride
232
+ BLOCK_K,
233
+ isinstance(seqlen_offsets, torch.Tensor),
234
+ False,
235
+ interleaved,
236
+ conjugate,
237
+ BLOCK_M,
238
+ )
239
+ return output
240
+
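For the non-interleaved (GPT-NeoX style) path the kernel above reduces to the usual half-rotation `o0 = x0*cos - x1*sin`, `o1 = x0*sin + x1*cos` on the first `rotary_dim` channels, with the remaining channels passed through. A plain PyTorch reference is sketched below (ignoring `seqlen_offsets` and the varlen path); it is only meant for checking the kernel on small inputs and is not used by the model:

```python
import torch

def rotary_reference(x, cos, sin):
    """x: (batch, nheads, seqlen, headdim); cos, sin: (batch, seqlen, rotary_dim // 2)."""
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x0, x1 = x_rot.chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat([o0, o1, x_pass], dim=-1)
```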
241
+
242
+ class ApplyRotaryEmb(torch.autograd.Function):
243
+ @staticmethod
244
+ def forward(
245
+ ctx,
246
+ x,
247
+ cos,
248
+ sin,
249
+ interleaved=False,
250
+ inplace=False,
251
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
252
+ cu_seqlens: Optional[torch.Tensor] = None,
253
+ max_seqlen: Optional[int] = None,
254
+ ):
255
+ out = apply_rotary(
256
+ x,
257
+ cos,
258
+ sin,
259
+ seqlen_offsets=seqlen_offsets,
260
+ cu_seqlens=cu_seqlens,
261
+ max_seqlen=max_seqlen,
262
+ interleaved=interleaved,
263
+ inplace=inplace,
264
+ )
265
+ if isinstance(seqlen_offsets, int):
266
+ ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
267
+ ctx.seqlen_offsets = seqlen_offsets
268
+ else:
269
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
270
+ ctx.seqlen_offsets = None
271
+ ctx.interleaved = interleaved
272
+ ctx.inplace = inplace
273
+ ctx.max_seqlen = max_seqlen
274
+ return out if not inplace else x
275
+
276
+ @staticmethod
277
+ def backward(ctx, do):
278
+ seqlen_offsets = ctx.seqlen_offsets
279
+ if seqlen_offsets is None:
280
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
281
+ else:
282
+ cos, sin, cu_seqlens = ctx.saved_tensors
283
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
284
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
285
+ if not ctx.interleaved and not ctx.inplace:
286
+ do = do.clone()
287
+ dx = apply_rotary(
288
+ do,
289
+ cos,
290
+ sin,
291
+ seqlen_offsets=seqlen_offsets,
292
+ cu_seqlens=cu_seqlens,
293
+ max_seqlen=ctx.max_seqlen,
294
+ interleaved=ctx.interleaved,
295
+ inplace=ctx.inplace,
296
+ conjugate=True,
297
+ )
298
+ return dx, None, None, None, None, None, None, None
299
+
300
+
301
+ def apply_rotary_emb(
302
+ x,
303
+ cos,
304
+ sin,
305
+ interleaved=False,
306
+ inplace=False,
307
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
308
+ cu_seqlens: Optional[torch.Tensor] = None,
309
+ max_seqlen: Optional[int] = None,
310
+ ):
311
+ """
312
+ Arguments:
313
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
314
+ else (total_seqlen, nheads, headdim)
315
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
316
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
317
+ of 1st half and 2nd half (GPT-NeoX style).
318
+ inplace: if True, apply rotary embedding in-place.
319
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
320
+ Most commonly used in inference when we have KV cache.
321
+ cu_seqlens: (batch + 1,) or None
322
+ max_seqlen: int
323
+ Return:
324
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
325
+ else (total_seqlen, nheads, headdim)
326
+ rotary_dim must be <= headdim
327
+ Apply rotary embedding to the first rotary_dim of x.
328
+ """
329
+ return ApplyRotaryEmb.apply(
330
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
331
+ )
332
+
333
+
334
+ # For backward compatibility
335
+ apply_rotary_emb_func = apply_rotary_emb
336
+
337
+
338
+ class FastRotaryEmbedding(torch.nn.Module):
339
+ """
340
+ The rotary position embeddings from RoFormer_ (Su et al.).
341
+ A crucial insight from the method is that the query and keys are
342
+ transformed by rotation matrices which depend on the relative positions.
343
+
344
+ Other implementations are available in the Rotary Transformer repo_ and in
345
+ GPT-NeoX_; GPT-NeoX was an inspiration.
346
+
347
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
348
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
349
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
350
+
351
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
352
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
353
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
354
+ """
355
+
356
+ def __init__(
357
+ self,
358
+ dim: int,
359
+ base=10000,
360
+ interleaved=False,
361
+ scale_base=None,
362
+ pos_idx_in_fp32=True,
363
+ device=None,
364
+ ):
365
+ """
366
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
367
+ of 1st half and 2nd half (GPT-NeoX style).
368
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
369
+ otherwise they might be in lower precision.
370
+ This option was added because previously (before 2023-07-02), when we construct
371
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
372
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
373
+ self.inv_freq would be bf16, and the position indices are also in bf16.
374
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
375
+ embeddings for some positions will coincide.
376
+ To maintain compatibility with models previously trained in pure bf16,
377
+ we add this option.
378
+ """
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.base = base
382
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
383
+ # Generate and save the inverse frequency buffer (non trainable)
384
+ inv_freq = self._compute_inv_freq(device)
385
+ self.register_buffer("inv_freq", inv_freq)
386
+ self.interleaved = interleaved
387
+ self.scale_base = scale_base
388
+ scale = (
389
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
390
+ if scale_base is not None
391
+ else None
392
+ )
393
+ self.register_buffer("scale", scale, persistent=False)
394
+
395
+ self._seq_len_cached = 0
396
+ self._cos_cached = None
397
+ self._sin_cached = None
398
+ self._cos_k_cached = None
399
+ self._sin_k_cached = None
400
+ self.cos = None
401
+ self.sin = None
402
+
403
+ def _compute_inv_freq(self, device=None):
404
+ return 1.0 / (
405
+ self.base
406
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
407
+ # ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
408
+ )
409
+
410
+ def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
411
+
412
+ if (
413
+ seqlen > self._seq_len_cached
414
+ ):
415
+ self._seq_len_cached = seqlen
416
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
417
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
418
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
419
+ if self.pos_idx_in_fp32:
420
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
421
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
422
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
423
+ # cos & sin output to change significantly.
424
+ # We want to recompute self.inv_freq if it was not loaded in fp32
425
+ if self.inv_freq.dtype != torch.float32:
426
+ inv_freq = self._compute_inv_freq(device=device)
427
+ else:
428
+ inv_freq = self.inv_freq
429
+ else:
430
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
431
+ inv_freq = self.inv_freq
432
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
433
+ if self.scale is None:
434
+ self._cos_cached = torch.cos(freqs).to(dtype)
435
+ self._sin_cached = torch.sin(freqs).to(dtype)
436
+
437
+ else:
438
+ power = (
439
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
440
+ - seqlen // 2
441
+ ) / self.scale_base
442
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
443
+ # We want the multiplication by scale to happen in fp32
444
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
445
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
446
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
447
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
448
+
449
+ def forward(
450
+ self,
451
+ q: torch.Tensor,
452
+ k: torch.Tensor,
453
+ position_ids: torch.Tensor,
454
+ max_seqlen,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ """
457
+ q: (batch, nheads, seqlen, headdim)
458
+ k: (batch, nheads, seqlen, headdim)
459
+ position_ids: (batch, seqlen)
460
+ max_seqlen: int, used to (re)build the cos/sin cache
461
+ Apply rotary embedding *inplace* to q and k.
464
+ """
465
+
466
+ self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
467
+ cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
468
+
469
+ q = apply_rotary_emb_func(
470
+ q,
471
+ cos,
472
+ sin,
473
+ interleaved=self.interleaved,
474
+ inplace=True
475
+ )
476
+ k = apply_rotary_emb_func(
477
+ k,
478
+ cos,
479
+ sin,
480
+ interleaved=self.interleaved,
481
+ inplace=True
482
+ )
483
+ return q, k
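A small usage sketch of `FastRotaryEmbedding` with hypothetical sizes (it needs a CUDA device, since the rotation is executed as a Triton kernel):

```python
import torch

if torch.cuda.is_available():
    rotary = FastRotaryEmbedding(dim=64, pos_idx_in_fp32=False).cuda()
    q = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)  # (batch, nheads, seqlen, headdim)
    k = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    position_ids = torch.arange(16, device="cuda").unsqueeze(0)        # (batch, seqlen)
    q, k = rotary(q, k, position_ids=position_ids, max_seqlen=16)
    print(q.shape, k.shape)  # torch.Size([1, 8, 16, 64]) twice
```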
visual.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ from torch import nn
3
+ from argparse import Namespace
4
+ import xformers.ops as xops
5
+ from transformers.activations import ACT2FN
6
+
7
+
8
+ class PatchEmbedding(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
12
+ self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
13
+ self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
14
+
15
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
16
+ x = self.proj(images)
17
+ x = x.flatten(2).transpose(1, 2)
18
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
19
+ x = torch.cat((cls_token, x), dim=1)
20
+ x += self.position_embedding.weight.unsqueeze(0)
21
+ return x
22
+
23
+
24
+ class Attention(nn.Module):
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.num_heads = config.num_heads
28
+ head_dim = config.hidden_size // config.num_heads
29
+ self.scale = head_dim ** -0.5
30
+ self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
31
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
32
+ self.output_dropout = torch.nn.Dropout(config.dropout_prob)
33
+
34
+ def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
35
+ B, L, _ = x.shape
36
+ qkv = self.query_key_value(x)
37
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
38
+ q, k, v = qkv[0], qkv[1], qkv[2]
39
+
40
+ out = xops.memory_efficient_attention(
41
+ q, k, v, scale=self.scale,
42
+ )
43
+ output = self.dense(out.view(B, L, -1))
44
+ output = self.output_dropout(output)
45
+ return output
46
+
47
+ def attention(self, q, k, v):
48
+ attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
49
+ attn_weights = attn_weights.softmax(dim=-1)
50
+ output = torch.matmul(attn_weights, v)
51
+ return output
52
+
53
+
54
+ class MLP(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.config = config
58
+ self.activation_fn = ACT2FN[config.hidden_act]
59
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
60
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ x = self.fc1(x)
64
+ x = self.activation_fn(x)
65
+ x = self.fc2(x)
66
+ return x
67
+
68
+
69
+ class TransformerLayer(nn.Module):
70
+ def __init__(self, config):
71
+ super().__init__()
72
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
73
+ self.attention = Attention(config)
74
+ self.mlp = MLP(config)
75
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
76
+
77
+ def forward(self, hidden_states):
78
+ attention_input = hidden_states
79
+ attention_output = self.input_layernorm(self.attention(attention_input))
80
+ hidden_states = attention_input + attention_output
81
+ mlp_input = hidden_states
82
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
83
+ output = mlp_input + mlp_output
84
+ return output
85
+
86
+
87
+ class Transformer(nn.Module):
88
+ def __init__(self, config):
89
+ super().__init__()
90
+ self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
91
+
92
+ def forward(self, hidden_states):
93
+ for layer_module in self.layers:
94
+ hidden_states = layer_module(hidden_states)
95
+ return hidden_states
96
+
97
+
98
+ class GLU(nn.Module):
99
+ def __init__(self, config, in_features):
100
+ super().__init__()
101
+ self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
102
+ self.norm1 = nn.LayerNorm(config.hidden_size)
103
+ self.act1 = nn.GELU()
104
+ self.act2 = nn.functional.silu
105
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
106
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
107
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
108
+
109
+ def forward(self, x):
110
+ x = self.linear_proj(x)
111
+ x = self.act1(self.norm1(x))
112
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
113
+ x = self.dense_4h_to_h(x)
114
+ return x
115
+
116
+
117
+ class EVA2CLIPModel(nn.Module):
118
+ def __init__(self, config):
119
+ super().__init__()
120
+ vision_config = Namespace(**config.vision_config)
121
+ self.patch_embedding = PatchEmbedding(vision_config)
122
+ self.transformer = Transformer(vision_config)
123
+ self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
124
+ self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
125
+ self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
126
+ self.pos_embed = nn.Parameter(torch.zeros((vision_config.image_size // vision_config.patch_size) ** 2, vision_config.hidden_size))
127
+
128
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
129
+ x = self.patch_embedding(images)
130
+ x = self.transformer(x)
131
+ x = x[:, 1:]
132
+ x = self.linear_proj(x + self.pos_embed.unsqueeze(0))
133
+ boi = self.boi.expand(x.shape[0], -1, -1)
134
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
135
+ x = torch.cat((boi, x, eoi), dim=1)
136
+ return x
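Note how the `boi`/`eoi` parameters above line up with the token bookkeeping in `build_conversation_input_ids`: the low-resolution branch contributes one placeholder token per patch plus two for the boi/eoi embeddings. A tiny illustration with hypothetical config values (the real ones live in `config.vision_config`):

```python
image_size, patch_size = 224, 14
num_patches = (image_size // patch_size) ** 2   # 256 patch features from EVA2CLIPModel
vision_token_num = num_patches + 2              # plus the boi and eoi embeddings
print(vision_token_num)                         # 258 pad placeholders in input_ids
```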