grim3000 chenkq committed on
Commit
afb44cc
0 Parent(s):

Duplicate from THUDM/cogvlm-chat-hf


Co-authored-by: chenkq <chenkq@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,174 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ ---
6
+ # CogVLM
7
+
8
+ **CogVLM** is a powerful **open-source visual language model** (**VLM**). CogVLM-17B has 10 billion vision parameters and 7 billion language parameters. It achieves state-of-the-art performance on 10 classic cross-modal benchmarks, including NoCaps, Flickr30k captioning, RefCOCO, RefCOCO+, RefCOCOg, Visual7W, GQA, ScienceQA, VizWiz VQA and TDIUC, and ranks second on VQAv2, OKVQA, TextVQA, COCO captioning, etc., **surpassing or matching PaLI-X 55B**. You can also [chat with CogVLM](http://36.103.203.44:7861/) about images through the online demo.
11
+
12
+ <div align="center">
13
+ <img src="https://github.com/THUDM/CogVLM/raw/main/assets/metrics-min.png" alt="img" style="zoom: 50%;" />
14
+ </div>
15
+
16
+ # Quickstart
17
+
18
+ Hardware requirements
+
+ Inference needs roughly 40 GB of GPU VRAM. If no single GPU has more than 40 GB of VRAM, use the `accelerate` library to dispatch the model across multiple GPUs with smaller VRAM.
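+
+ A quick way to see how much VRAM each visible GPU offers (an optional, illustrative snippet, not part of the original card):
+
+ ```python
+ import torch
+
+ # List every visible CUDA device and its total memory in GiB.
+ for i in range(torch.cuda.device_count()):
+     props = torch.cuda.get_device_properties(i)
+     print(f"cuda:{i}: {props.name}, {props.total_memory / 1024**3:.1f} GiB")
+ ```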
23
+
24
+ Install dependencies
25
+
26
+ ```bash
27
+ pip install torch==2.1.0 transformers==4.35.0 accelerate==0.24.1 sentencepiece==0.1.99 einops==0.7.0 xformers==0.0.22.post7 triton==2.1.0
28
+ ```
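+
+ Optionally, confirm that the pinned versions were picked up (an illustrative check, not part of the original card):
+
+ ```python
+ import torch, transformers, accelerate, triton
+
+ # The example below was written against torch 2.1.0 / transformers 4.35.0.
+ print(torch.__version__, transformers.__version__, accelerate.__version__, triton.__version__)
+ ```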
29
+
30
+ Code example
31
+
32
+ ```python
33
+ import torch
34
+ import requests
35
+ from PIL import Image
36
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
37
+
38
+ tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
39
+ model = AutoModelForCausalLM.from_pretrained(
40
+ 'THUDM/cogvlm-chat-hf',
41
+ torch_dtype=torch.bfloat16,
42
+ low_cpu_mem_usage=True,
43
+ trust_remote_code=True
44
+ ).to('cuda').eval()
45
+
46
+
47
+ # chat example
48
+ query = 'Describe this image'
49
+ image = Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true', stream=True).raw).convert('RGB')
50
+ inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image]) # chat mode
51
+ inputs = {
52
+ 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
53
+ 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
54
+ 'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
55
+ 'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
56
+ }
57
+ gen_kwargs = {"max_length": 2048, "do_sample": False}
58
+
59
+ with torch.no_grad():
60
+     outputs = model.generate(**inputs, **gen_kwargs)
61
+     outputs = outputs[:, inputs['input_ids'].shape[1]:]
62
+     print(tokenizer.decode(outputs[0]))
63
+
64
+ # This image captures a moment from a basketball game. Two players are prominently featured: one wearing a yellow jersey with the number
65
+ # 24 and the word 'Lakers' written on it, and the other wearing a navy blue jersey with the word 'Washington' and the number 34. The player
66
+ # in yellow is holding a basketball and appears to be dribbling it, while the player in navy blue is reaching out with his arm, possibly
67
+ # trying to block or defend. The background shows a filled stadium with spectators, indicating that this is a professional game.</s>
68
+
69
+
70
+
71
+ # vqa example
72
+ query = 'How many houses are there in this cartoon?'
73
+ image = Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true', stream=True).raw).convert('RGB')
74
+ inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image], template_version='vqa') # vqa mode
75
+ inputs = {
76
+ 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
77
+ 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
78
+ 'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
79
+ 'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
80
+ }
81
+ gen_kwargs = {"max_length": 2048, "do_sample": False}
82
+
83
+ with torch.no_grad():
84
+     outputs = model.generate(**inputs, **gen_kwargs)
85
+     outputs = outputs[:, inputs['input_ids'].shape[1]:]
86
+     print(tokenizer.decode(outputs[0]))
87
+
88
+ # 4</s>
89
+ ```
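+
+ For a multi-turn conversation, pass the previous (query, response) pairs as `history`; `build_conversation_input_ids` folds them into the prompt. A minimal follow-up sketch reusing the variables from the example above (the question text is only an illustration):
+
+ ```python
+ # Answer from the previous turn (outputs has already been sliced to the newly generated tokens).
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ history = [(query, response)]
+
+ query = 'What makes you say that?'  # hypothetical follow-up question
+ inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+ # ...then build the tensor dict and call model.generate exactly as above.
+ ```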
90
+
91
+ When a single GPU does not have enough VRAM, you can dispatch the model across several GPUs with smaller VRAM. The example below assumes two 24GB GPUs and 16GB of CPU memory.
+ Change the arguments of `infer_auto_device_map` to match your own setup. Note that the GPU memory limits are set a little below the physical capacity to reserve room for the intermediate states produced during inference.
96
+
97
+ ```python
98
+ import torch
99
+ import requests
100
+ from PIL import Image
101
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
102
+ from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
103
+
104
+ tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
105
+ with init_empty_weights():
106
+     model = AutoModelForCausalLM.from_pretrained(
107
+         'THUDM/cogvlm-chat-hf',
108
+         torch_dtype=torch.bfloat16,
109
+         low_cpu_mem_usage=True,
110
+         trust_remote_code=True,
111
+     )
112
+ device_map = infer_auto_device_map(model, max_memory={0:'20GiB',1:'20GiB','cpu':'16GiB'}, no_split_module_classes=['CogVLMDecoderLayer'])
113
+ model = load_checkpoint_and_dispatch(
114
+ model,
115
+ 'local/path/to/hf/version/chat/model', # typically '~/.cache/huggingface/hub/models--THUDM--cogvlm-chat-hf/snapshots/<snapshot-hash>'
116
+ device_map=device_map,
117
+ )
118
+ model = model.eval()
119
+
120
+ # optionally, check which device each weight was placed on
121
+ for n, p in model.named_parameters():
122
+     print(f"{n}: {p.device}")
123
+
124
+ # chat example
125
+ query = 'Describe this image'
126
+ image = Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true', stream=True).raw).convert('RGB')
127
+ inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image]) # chat mode
128
+ inputs = {
129
+ 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
130
+ 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
131
+ 'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
132
+ 'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
133
+ }
134
+ gen_kwargs = {"max_length": 2048, "do_sample": False}
135
+
136
+ with torch.no_grad():
137
+     outputs = model.generate(**inputs, **gen_kwargs)
138
+     outputs = outputs[:, inputs['input_ids'].shape[1]:]
139
+     print(tokenizer.decode(outputs[0]))
140
+ ```
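+
+ As an alternative, recent `transformers`/`accelerate` versions can shard the checkpoint directly at load time via `device_map`. Whether this works smoothly for this remote-code model should be verified on your own setup; a hedged sketch:
+
+ ```python
+ # Sketch only: let accelerate place layers across the available GPUs at load time.
+ model = AutoModelForCausalLM.from_pretrained(
+     'THUDM/cogvlm-chat-hf',
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True,
+     device_map='auto',
+     max_memory={0: '20GiB', 1: '20GiB', 'cpu': '16GiB'},
+ ).eval()
+ ```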
141
+
142
+
143
+
144
+ # Method
145
+
146
+ The CogVLM model comprises four fundamental components: a vision transformer (ViT) encoder, an MLP adapter, a pretrained large language model (GPT), and a **visual expert module**. See the [Paper](https://github.com/THUDM/CogVLM/blob/main/assets/cogvlm-paper.pdf) for more details.
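+
+ In the Hugging Face implementation shipped in this repo (`modeling_cogvlm.py`), these components map onto concrete submodules, which you can inspect once the model from the Quickstart is loaded. An illustrative look (attribute names taken from `modeling_cogvlm.py`):
+
+ ```python
+ # ViT encoder + MLP adapter (EVA2CLIPModel)
+ print(model.model.vision)
+ # Visual expert in attention: separate vision/language QKV and output projections
+ print(model.model.layers[0].self_attn)
+ # Visual expert in the FFN: separate vision/language MLPs
+ print(model.model.layers[0].mlp)
+ ```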
149
+
150
+ <div align="center">
151
+ <img src="https://github.com/THUDM/CogVLM/raw/main/assets/method-min.png" style="zoom:50%;" />
152
+ </div>
153
+
154
+ # License
155
+
156
+ The code in this repository is open source under the [Apache-2.0 license](https://github.com/THUDM/CogVLM/raw/main/LICENSE), while use of the CogVLM model weights must comply with the [Model License](https://github.com/THUDM/CogVLM/raw/main/MODEL_LICENSE).
159
+
160
+
161
+
162
+ # Citation
163
+
164
+ If you find our work helpful, please consider citing the following paper:
165
+ ```
166
+ @article{wang2023cogvlm,
167
+ title={CogVLM: Visual Expert for Pretrained Language Models},
168
+ author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
169
+ year={2023},
170
+ eprint={2311.03079},
171
+ archivePrefix={arXiv},
172
+ primaryClass={cs.CV}
173
+ }
174
+ ```
config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "_name_or_path": "cogvlm-chat-v1.1",
3
+ "architectures": [
4
+ "CogVLMForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_cogvlm.CogVLMConfig",
8
+ "AutoModelForCausalLM": "modeling_cogvlm.CogVLMForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 4096,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 11008,
16
+ "max_position_embeddings": 2048,
17
+ "num_attention_heads": 32,
18
+ "num_hidden_layers": 32,
19
+ "pad_token_id": 0,
20
+ "rms_norm_eps": 1e-05,
21
+ "template_version": "chat",
22
+ "tie_word_embeddings": false,
23
+ "torch_dtype": "bfloat16",
24
+ "transformers_version": "4.35.0",
25
+ "use_cache": true,
26
+ "vision_config": {
27
+ "dropout_prob": 0.0,
28
+ "hidden_act": "gelu",
29
+ "hidden_size": 1792,
30
+ "image_size": 490,
31
+ "in_channels": 3,
32
+ "intermediate_size": 15360,
33
+ "layer_norm_eps": 1e-06,
34
+ "num_heads": 16,
35
+ "num_hidden_layers": 63,
36
+ "num_positions": 1226,
37
+ "patch_size": 14
38
+ },
39
+ "vocab_size": 32000
40
+ }
configuration_cogvlm.py ADDED
@@ -0,0 +1,45 @@
1
+ from typing import Literal
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class CogVLMConfig(PretrainedConfig):
6
+ _auto_class = "AutoConfig"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ intermediate_size=11008,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ hidden_act='silu',
16
+ max_position_embeddings=2048,
17
+ initializer_range=0.02,
18
+ rms_norm_eps=1e-06,
19
+ template_version: Literal["base", "chat"] = "chat",
20
+
21
+ pad_token_id=0,
22
+ bos_token_id=1,
23
+ eos_token_id=2,
24
+ tie_word_embeddings=False,
25
+ use_cache=True,
26
+ **kwargs,
27
+ ):
28
+ self.hidden_size = hidden_size
29
+ self.intermediate_size = intermediate_size
30
+ self.num_attention_heads = num_attention_heads
31
+ self.max_position_embeddings = max_position_embeddings
32
+ self.rms_norm_eps = rms_norm_eps
33
+ self.initializer_range = initializer_range
34
+ self.vocab_size = vocab_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.hidden_act = hidden_act
37
+ self.template_version = template_version
38
+ self.use_cache = use_cache
39
+ super().__init__(
40
+ pad_token_id=pad_token_id,
41
+ bos_token_id=bos_token_id,
42
+ eos_token_id=eos_token_id,
43
+ tie_word_embeddings=tie_word_embeddings,
44
+ **kwargs,
45
+ )
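+
+ # Illustrative usage (an assumption, not part of the original file): AutoConfig resolves this
+ # class through the `auto_map` entry in config.json, e.g.
+ #   from transformers import AutoConfig
+ #   cfg = AutoConfig.from_pretrained('THUDM/cogvlm-chat-hf', trust_remote_code=True)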
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.35.0"
7
+ }
model-00001-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e29f6ec471ca55789ab14947b527729b9c30313ceb1e7726590b85f9f6406cca
3
+ size 4938885184
model-00002-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e82356882701b1a778408f31e676d17c2aff799c543e8596ed74bc805b4a1213
3
+ size 4947290688
model-00003-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04096f84f42798d0c89319ff8254995a2a3512c16ec88dfd078ce421867d92ec
3
+ size 4947307592
model-00004-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b42af0bb16647959b3e55def4b3c66ab8c3a25fd948a5245c81d070f2b4313d
3
+ size 4991331080
model-00005-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38c07825790e055dd169376479994a58a4f59775ba7cf31d5ca25d8a465e7b0c
3
+ size 4991331088
model-00006-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d01880ca5677e69a5f8632f9dda62814f0c549b5a40d4f7e136065e5d64c1a7d
3
+ size 4970162920
model-00007-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e70b0e10d2ac8800e69e514b6a9b04ac28cd7db43985ce62daa4e0e639b4e5ba
3
+ size 4960543792
model-00008-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a756381ef65b92af7f1fb97da3d59cb04586080982de86d76805299898223294
3
+ size 532677104
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_cogvlm.py ADDED
@@ -0,0 +1,783 @@
1
+ """Largely copied from LLaMA and adapted for CogVLM."""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+ from transformers.utils.logging import get_logger
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from .configuration_cogvlm import CogVLMConfig
18
+ from .util import FastRotaryEmbedding
19
+ from .visual import EVA2CLIPModel
20
+
21
+ if TYPE_CHECKING:
22
+ from transformers.utils import ModelOutput
23
+
24
+ logger = get_logger(__name__)
25
+
26
+ LANGUAGE_TOKEN_TYPE = 0
27
+ VISION_TOKEN_TYPE = 1
28
+
29
+
30
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
31
+ def _make_causal_mask(
32
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
33
+ ):
34
+ """
35
+ Make causal mask used for bi-directional self-attention.
36
+ """
37
+ bsz, tgt_len = input_ids_shape
38
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
39
+ mask_cond = torch.arange(mask.size(-1), device=device)
40
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
41
+ mask = mask.to(dtype)
42
+
43
+ if past_key_values_length > 0:
44
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
45
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
46
+
47
+
48
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
49
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
50
+ """
51
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
52
+ """
53
+ bsz, src_len = mask.size()
54
+ tgt_len = tgt_len if tgt_len is not None else src_len
55
+
56
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
57
+
58
+ inverted_mask = 1.0 - expanded_mask
59
+
60
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
61
+
62
+
63
+ class RMSNorm(nn.Module):
64
+ def __init__(self, hidden_size, eps=1e-6):
65
+ super().__init__()
66
+ self.weight = nn.Parameter(torch.ones(hidden_size))
67
+ self.variance_epsilon = eps
68
+
69
+ def forward(self, hidden_states):
70
+ input_dtype = hidden_states.dtype
71
+ hidden_states = hidden_states.to(torch.float32)
72
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
73
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
74
+ return (self.weight * hidden_states).to(input_dtype)
75
+
76
+
77
+ class MLP(nn.Module):
78
+ def __init__(self, config):
79
+ super().__init__()
80
+ self.hidden_size = config.hidden_size
81
+ self.intermediate_size = config.intermediate_size
82
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
83
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
84
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
85
+ self.act_fn = ACT2FN[config.hidden_act]
86
+
87
+ def forward(self, x):
88
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
89
+ return down_proj
90
+
91
+
92
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
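+ # A token is routed to the vision expert only when both it and the next token are vision tokens;
+ # the trailing image boundary token therefore goes through the language expert.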
93
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
94
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
95
+ language_token_mask = ~vision_token_mask
96
+ return vision_token_mask, language_token_mask
97
+
98
+
99
+ class VisionExpertMLP(nn.Module):
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.language_mlp = MLP(config)
103
+ self.vision_mlp = MLP(config)
104
+
105
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
106
+ output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
107
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
108
+ output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
109
+ output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
110
+ return output
111
+
112
+
113
+ def attention_fn(
114
+ query_layer: "torch.tensor(B, H, L, HD)",
115
+ key_layer: "torch.tensor(B, H, L, HD)",
116
+ value_layer: "torch.tensor(B, H, L, HD)",
117
+ attention_mask: "torch.tensor(B, H, L, HD)",
118
+ *,
119
+ scaling_attention_score: bool = True,
120
+ attention_dropout: nn.Module = None
121
+ ):
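+ # For the common full or lower-triangular masks on PyTorch >= 2.0 this dispatches to
+ # torch.nn.functional.scaled_dot_product_attention; otherwise it falls back to an explicit
+ # softmax(QK^T / sqrt(d)) V computation using the additive attention_mask.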
122
+ attention_mask_bool = (attention_mask == 0)
123
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
124
+ is_full = (attention_mask_bool > 0).all()
125
+ if not (int(torch.__version__.split('.')[0]) >= 2):
126
+ warnings.warn("It's recommended to use torch2.0 or higher.")
127
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
128
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
129
+ return torch.nn.functional.scaled_dot_product_attention(
130
+ query_layer, key_layer, value_layer,
131
+ attn_mask=None,
132
+ dropout_p=dropout_p,
133
+ is_causal=not is_full
134
+ )
135
+ else:
136
+ if scaling_attention_score:
137
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
138
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
139
+ attention_scores = attention_scores + attention_mask
140
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
141
+ if attention_dropout is not None:
142
+ attention_scores = attention_dropout(attention_scores)
143
+ context_layer = torch.matmul(attention_scores, value_layer)
144
+ return context_layer
145
+
146
+
147
+ class VisionExpertAttention(nn.Module):
148
+ def __init__(self, config):
149
+ super().__init__()
150
+ self.config = config
151
+ self.hidden_size = config.hidden_size
152
+ self.num_heads = config.num_attention_heads
153
+ self.head_dim = self.hidden_size // self.num_heads
154
+ self.max_position_embeddings = config.max_position_embeddings
155
+
156
+ # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
157
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
158
+ self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
159
+ self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
160
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
161
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
162
+
163
+ def _transpose_for_scores(self, tensor):
164
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
165
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
166
+ tensor = tensor.view(*new_tensor_shape)
167
+ return tensor.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ token_type_ids: torch.LongTensor,
173
+ position_ids: torch.LongTensor,
174
+ attention_mask: Optional[torch.Tensor] = None,
175
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
176
+ output_attentions: bool = False,
177
+ use_cache: bool = False,
178
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
179
+ bsz, q_len, _ = hidden_states.size()
180
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
181
+
182
+ shape = list(hidden_states.shape)
183
+ shape[-1] = shape[-1] * 3
184
+ mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
185
+ mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
186
+ mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
187
+
188
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
189
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
190
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
191
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
192
+
193
+ kv_seq_len = key_states.shape[-2]
194
+ if past_key_value is not None:
195
+ kv_seq_len += past_key_value[0].shape[-2]
196
+
197
+ query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
198
+
199
+ if past_key_value is not None:
200
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
201
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
202
+
203
+ past_key_value = (key_states, value_states) if use_cache else None
204
+
205
+ context_layer = attention_fn(
206
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
207
+ scaling_attention_score=True, attention_dropout=None)
208
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
209
+ raise ValueError(
210
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
211
+ f" {context_layer.size()}"
212
+ )
213
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
214
+
215
+ attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
216
+ attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
217
+ attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
218
+
219
+ if output_attentions:
220
+ warnings.warn("output_attentions is not implemented.")
221
+
222
+ return attn_output, None, past_key_value
223
+
224
+
225
+ class CogVLMDecoderLayer(nn.Module):
226
+ def __init__(self, config):
227
+ super().__init__()
228
+ self.hidden_size = config.hidden_size
229
+ self.self_attn = VisionExpertAttention(config=config)
230
+ self.mlp = VisionExpertMLP(config)
231
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
232
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor,
237
+ token_type_ids: torch.LongTensor,
238
+ position_ids: torch.LongTensor,
239
+ attention_mask: Optional[torch.Tensor] = None,
240
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
241
+ output_attentions: Optional[bool] = False,
242
+ use_cache: Optional[bool] = False,
243
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
244
+ residual = hidden_states
245
+
246
+ hidden_states = self.input_layernorm(hidden_states)
247
+
248
+ # Self Attention
249
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
250
+ hidden_states=hidden_states,
251
+ token_type_ids=token_type_ids,
252
+ position_ids=position_ids,
253
+ attention_mask=attention_mask,
254
+ past_key_value=past_key_value,
255
+ output_attentions=output_attentions,
256
+ use_cache=use_cache,
257
+ )
258
+ hidden_states = residual + hidden_states
259
+
260
+ # Fully Connected
261
+ residual = hidden_states
262
+ hidden_states = self.post_attention_layernorm(hidden_states)
263
+ hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids)
264
+ hidden_states = residual + hidden_states
265
+
266
+ outputs = (hidden_states,)
267
+
268
+ if output_attentions:
269
+ outputs += (self_attn_weights,)
270
+
271
+ if use_cache:
272
+ outputs += (present_key_value,)
273
+
274
+ return outputs # type: ignore
275
+
276
+
277
+ class CogVLMPreTrainedModel(PreTrainedModel):
278
+ config_class = CogVLMConfig
279
+ base_model_prefix = "model"
280
+ supports_gradient_checkpointing = False
281
+ _no_split_modules = ["CogVLMDecoderLayer", "TransformerLayer"]
282
+ _skip_keys_device_placement = "past_key_values"
283
+
284
+ def _init_weights(self, module):
285
+ std = self.config.initializer_range
286
+ if isinstance(module, nn.Linear):
287
+ module.weight.data.normal_(mean=0.0, std=std)
288
+ if module.bias is not None:
289
+ module.bias.data.zero_()
290
+ elif isinstance(module, nn.Embedding):
291
+ module.weight.data.normal_(mean=0.0, std=std)
292
+ if module.padding_idx is not None:
293
+ module.weight.data[module.padding_idx].zero_()
294
+
295
+
296
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
297
+ if images_list is None or len(images_list) == 0:
298
+ return True
299
+ for image_list in images_list:
300
+ if len(image_list):
301
+ return False
302
+ return True
303
+
304
+
305
+ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
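+ # e.g. token types [LANG, VIS, VIS, VIS, VIS, LANG] -> position ids [0, 1, 2, 2, 3, 4]:
+ # the boi/eoi boundary tokens advance the counter, interior vision tokens share one position.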
306
+ if attention_mask is not None:
307
+ tmp = x.clone()
308
+ tmp[~(attention_mask.bool())] = -1
309
+ else:
310
+ tmp = x.clone()
311
+ # treat the image boi/eoi boundary tokens as LANGUAGE_TOKEN_TYPE so they advance the position counter
312
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
313
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
314
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
315
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
316
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
317
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
318
+ # final position ids
319
+ y = torch.zeros_like(x, dtype=torch.long)
320
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
321
+ y = y.cumsum(dim=-1)
322
+ return y
323
+
324
+
325
+ class CogVLMModel(CogVLMPreTrainedModel):
326
+ def __init__(self, config):
327
+ super().__init__(config)
328
+ self.padding_idx = config.pad_token_id
329
+ self.vocab_size = config.vocab_size
330
+
331
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
332
+ self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
333
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
334
+
335
+ self.vision = EVA2CLIPModel(config)
336
+
337
+ self.gradient_checkpointing = False
338
+ # Initialize weights and apply final processing
339
+ self.post_init()
340
+
341
+ def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
342
+ images_list, images = images, []
343
+
344
+ images = []
345
+ for image_list in images_list:
346
+ for image in image_list:
347
+ images.append(image)
348
+
349
+ images = torch.stack(images)
350
+ images_features = self.vision(images)
351
+ return images_features
352
+
353
+ def forward(
354
+ self,
355
+ input_ids: torch.LongTensor = None,
356
+ images: List[List[torch.Tensor]] = None,
357
+ token_type_ids: Optional[torch.LongTensor] = None,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ position_ids: Optional[torch.LongTensor] = None,
360
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
361
+ inputs_embeds: Optional[torch.FloatTensor] = None,
362
+ use_cache: Optional[bool] = None,
363
+ output_attentions: Optional[bool] = None,
364
+ output_hidden_states: Optional[bool] = None,
365
+ return_dict: Optional[bool] = None,
366
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
367
+ """Handles image encoding, token_type_ids and position_ids; attention_mask may be None."""
368
+
369
+ if past_key_values is not None:
370
+ pass # generate mode with past_key_values. the image features are already mapped
371
+ else:
372
+ # inputs_embeds is not allowed here, because the image features still have to be injected into the embeddings
373
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
374
+ if not is_empty(images): # multi-modality
375
+ assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
376
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
377
+ inputs_embeds = self.embed_tokens(input_ids)
378
+ images_features = self.encode_images(images)
379
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
380
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
381
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
382
+ else: # single-modality
383
+ if token_type_ids is None:
384
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
385
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
386
+ inputs_embeds = self.embed_tokens(input_ids)
387
+
388
+ if position_ids is None:
389
+ position_ids = build_position_ids(token_type_ids, attention_mask)
390
+ input_ids = None
391
+
392
+ return self.llm_forward(
393
+ input_ids=input_ids,
394
+ token_type_ids=token_type_ids,
395
+ attention_mask=attention_mask,
396
+ position_ids=position_ids,
397
+ past_key_values=past_key_values,
398
+ inputs_embeds=inputs_embeds,
399
+ use_cache=use_cache,
400
+ output_attentions=output_attentions,
401
+ output_hidden_states=output_hidden_states,
402
+ return_dict=return_dict,
403
+ )
404
+
405
+ def llm_forward(
406
+ self,
407
+ input_ids: torch.LongTensor = None,
408
+ token_type_ids: torch.LongTensor = None,
409
+ attention_mask: Optional[torch.Tensor] = None,
410
+ position_ids: Optional[torch.LongTensor] = None,
411
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
412
+ inputs_embeds: Optional[torch.FloatTensor] = None,
413
+ use_cache: Optional[bool] = None,
414
+ output_attentions: Optional[bool] = None,
415
+ output_hidden_states: Optional[bool] = None,
416
+ return_dict: Optional[bool] = None,
417
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
418
+ """Largely copied from the LLaMA forward pass and adapted for CogVLM with `token_type_ids`."""
419
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
420
+ output_hidden_states = (
421
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
422
+ )
423
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
424
+
425
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
426
+
427
+ # retrieve input_ids and inputs_embeds
428
+ if input_ids is not None and inputs_embeds is not None:
429
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
430
+ elif input_ids is not None:
431
+ batch_size, seq_length = input_ids.shape
432
+ elif inputs_embeds is not None:
433
+ batch_size, seq_length, _ = inputs_embeds.shape
434
+ else:
435
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
436
+
437
+ seq_length_with_past = seq_length
438
+ past_key_values_length = 0
439
+
440
+ if past_key_values is not None:
441
+ past_key_values_length = past_key_values[0][0].shape[2]
442
+ seq_length_with_past = seq_length_with_past + past_key_values_length
443
+
444
+ if position_ids is None:
445
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
446
+ position_ids = torch.arange(
447
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
448
+ )
449
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
450
+ else:
451
+ position_ids = position_ids.view(-1, seq_length).long()
452
+
453
+ if inputs_embeds is None:
454
+ inputs_embeds = self.embed_tokens(input_ids)
455
+ # embed positions
456
+ if attention_mask is None:
457
+ attention_mask = torch.ones(
458
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
459
+ )
460
+ attention_mask = self._prepare_decoder_attention_mask(
461
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
462
+ )
463
+
464
+ hidden_states = inputs_embeds
465
+
466
+ # decoder layers
467
+ all_hidden_states = () if output_hidden_states else None
468
+ all_self_attns = () if output_attentions else None
469
+ next_decoder_cache = () if use_cache else None
470
+
471
+ for idx, decoder_layer in enumerate(self.layers):
472
+ if output_hidden_states:
473
+ all_hidden_states += (hidden_states,)
474
+
475
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
476
+ layer_outputs = decoder_layer(
477
+ hidden_states,
478
+ token_type_ids=token_type_ids,
479
+ attention_mask=attention_mask,
480
+ position_ids=position_ids,
481
+ past_key_value=past_key_value,
482
+ output_attentions=output_attentions,
483
+ use_cache=use_cache,
484
+ )
485
+ hidden_states = layer_outputs[0]
486
+
487
+ if use_cache:
488
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
489
+
490
+ if output_attentions:
491
+ all_self_attns += (layer_outputs[1],)
492
+
493
+ hidden_states = self.norm(hidden_states)
494
+
495
+ # add hidden states from the last decoder layer
496
+ if output_hidden_states:
497
+ all_hidden_states += (hidden_states,)
498
+
499
+ next_cache = next_decoder_cache if use_cache else None
500
+ if not return_dict:
501
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
502
+ return BaseModelOutputWithPast(
503
+ last_hidden_state=hidden_states,
504
+ past_key_values=next_cache,
505
+ hidden_states=all_hidden_states,
506
+ attentions=all_self_attns,
507
+ )
508
+
509
+ def get_input_embeddings(self):
510
+ return self.embed_tokens
511
+
512
+ def set_input_embeddings(self, value):
513
+ self.embed_tokens = value
514
+
515
+ # noinspection PyMethodMayBeStatic
516
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
517
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
518
+ # create causal mask
519
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
520
+ combined_attention_mask = None
521
+ if input_shape[-1] > 1:
522
+ combined_attention_mask = _make_causal_mask(
523
+ input_shape,
524
+ inputs_embeds.dtype,
525
+ device=inputs_embeds.device,
526
+ past_key_values_length=past_key_values_length,
527
+ )
528
+
529
+ if attention_mask is not None:
530
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
531
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
532
+ inputs_embeds.device
533
+ )
534
+ combined_attention_mask = (
535
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
536
+ )
537
+
538
+ return combined_attention_mask
539
+
540
+
541
+ def _history_to_prompt(signal_type, history, query):
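+ # Produces e.g. (chat template): "Question: <q1> Answer: <a1>\nQuestion: <q2> Answer:";
+ # the 'vqa' template uses "Short answer:" instead, and 'base' returns the raw query.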
542
+ if signal_type == 'base':
543
+ return query
544
+ elif signal_type == 'vqa':
545
+ answer_format = 'Short answer:'
546
+ elif signal_type == 'chat':
547
+ answer_format = 'Answer:'
548
+ else:
549
+ assert False, f"Unknown signal type {signal_type}"
550
+
551
+ prompt = ''
552
+ for i, (old_query, response) in enumerate(history):
553
+ prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
554
+ prompt += 'Question: {} {}'.format(query, answer_format)
555
+ return prompt
556
+
557
+
558
+ class CogVLMForCausalLM(CogVLMPreTrainedModel):
559
+ _auto_class = "AutoModelForCausalLM"
560
+
561
+ def __init__(self, config):
562
+ super().__init__(config)
563
+ self.model = CogVLMModel(config)
564
+ self.vocab_size = config.vocab_size
565
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
566
+
567
+ # Initialize weights and apply final processing
568
+ self.post_init()
569
+
570
+ def get_input_embeddings(self):
571
+ return self.model.embed_tokens
572
+
573
+ def set_input_embeddings(self, value):
574
+ self.model.embed_tokens = value
575
+
576
+ def get_output_embeddings(self):
577
+ return self.lm_head
578
+
579
+ def set_output_embeddings(self, new_embeddings):
580
+ self.lm_head = new_embeddings
581
+
582
+ def set_decoder(self, decoder):
583
+ self.model = decoder
584
+
585
+ def get_decoder(self):
586
+ return self.model
587
+
588
+ def forward(
589
+ self,
590
+ input_ids: torch.LongTensor = None,
591
+ images: List[List[torch.Tensor]] = None,
592
+ token_type_ids: Optional[torch.LongTensor] = None,
593
+ attention_mask: Optional[torch.Tensor] = None,
594
+ position_ids: Optional[torch.LongTensor] = None,
595
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
596
+ inputs_embeds: Optional[torch.FloatTensor] = None,
597
+ use_cache: Optional[bool] = None,
598
+ output_attentions: Optional[bool] = None,
599
+ output_hidden_states: Optional[bool] = None,
600
+ return_dict: Optional[bool] = None,
601
+ labels: Optional[torch.LongTensor] = None,
602
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
603
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
604
+ output_hidden_states = (
605
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
606
+ )
607
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
608
+
609
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
610
+ outputs = self.model(
611
+ input_ids=input_ids,
612
+ images=images,
613
+ token_type_ids=token_type_ids,
614
+ attention_mask=attention_mask,
615
+ position_ids=position_ids,
616
+ past_key_values=past_key_values,
617
+ inputs_embeds=inputs_embeds,
618
+ use_cache=use_cache,
619
+ output_attentions=output_attentions,
620
+ output_hidden_states=output_hidden_states,
621
+ return_dict=return_dict,
622
+ )
623
+
624
+ hidden_states = outputs[0]
625
+ logits = self.lm_head(hidden_states)
626
+ logits = logits.float()
627
+
628
+ loss = None
629
+ if labels is not None:
630
+ # Shift so that tokens < n predict n
631
+ shift_logits = logits[..., :-1, :].contiguous()
632
+ shift_labels = labels[..., 1:].contiguous()
633
+ # Flatten the tokens
634
+ loss_fct = CrossEntropyLoss()
635
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
636
+ shift_labels = shift_labels.view(-1)
637
+ # Enable model parallelism
638
+ shift_labels = shift_labels.to(shift_logits.device)
639
+ loss = loss_fct(shift_logits, shift_labels)
640
+
641
+ if not return_dict:
642
+ output = (logits,) + outputs[1:]
643
+ return (loss,) + output if loss is not None else output
644
+
645
+ return CausalLMOutputWithPast(
646
+ loss=loss,
647
+ logits=logits,
648
+ past_key_values=outputs.past_key_values,
649
+ hidden_states=outputs.hidden_states,
650
+ attentions=outputs.attentions,
651
+ )
652
+
653
+ def _prepare_attention_mask_for_generation(
654
+ self,
655
+ inputs: torch.Tensor,
656
+ pad_token_id: Optional[int],
657
+ eos_token_id: Optional[Union[int, List[int]]],
658
+ ) -> torch.LongTensor:
659
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
660
+
661
+ def prepare_inputs_for_generation(
662
+ self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
663
+ ):
664
+ # build position_ids if needed
665
+ position_ids = kwargs.get("position_ids", None)
666
+ if position_ids is None:
667
+ position_ids = build_position_ids(token_type_ids, attention_mask)
668
+
669
+ if past_key_values:
670
+ input_ids = input_ids[:, -1:]
671
+ token_type_ids = token_type_ids[:, -1:]
672
+ position_ids = position_ids[:, -1:]
673
+
674
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
675
+ if inputs_embeds is not None and past_key_values is None:
676
+ model_inputs = {"inputs_embeds": inputs_embeds}
677
+ else:
678
+ model_inputs = {"input_ids": input_ids}
679
+
680
+ model_inputs.update(
681
+ {
682
+ "token_type_ids": token_type_ids,
683
+ "images": images,
684
+ "position_ids": position_ids,
685
+ "past_key_values": past_key_values,
686
+ "use_cache": kwargs.get("use_cache"),
687
+ "attention_mask": attention_mask,
688
+ }
689
+ )
690
+ return model_inputs
691
+
692
+ def _update_model_kwargs_for_generation(
693
+ self,
694
+ outputs: "ModelOutput",
695
+ model_kwargs: Dict[str, Any],
696
+ is_encoder_decoder: bool = False,
697
+ standardize_cache_format: bool = False,
698
+ ) -> Dict[str, Any]:
699
+ # update past_key_values
700
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
701
+ outputs, standardize_cache_format=standardize_cache_format
702
+ )
703
+ if getattr(outputs, "state", None) is not None:
704
+ model_kwargs["state"] = outputs.state
705
+
706
+ # update token_type_ids with last value
707
+ if "token_type_ids" in model_kwargs:
708
+ token_type_ids = model_kwargs["token_type_ids"]
709
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
710
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
711
+
712
+ if not is_encoder_decoder:
713
+ # update attention mask
714
+ if "attention_mask" in model_kwargs:
715
+ attention_mask = model_kwargs["attention_mask"]
716
+ model_kwargs["attention_mask"] = torch.cat(
717
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
718
+ )
719
+ else:
720
+ # update decoder attention mask
721
+ if "decoder_attention_mask" in model_kwargs:
722
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
723
+ model_kwargs["decoder_attention_mask"] = torch.cat(
724
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
725
+ dim=-1,
726
+ )
727
+
728
+ return model_kwargs
729
+
730
+ def _reorder_cache(self, past_key_values, beam_idx):
731
+ reordered_past = ()
732
+ for layer_past in past_key_values:
733
+ reordered_past += (
734
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
735
+ )
736
+ return reordered_past
737
+
738
+ def build_conversation_input_ids(
739
+ self,
740
+ tokenizer: "PreTrainedTokenizer",
741
+ *,
742
+ query: str,
743
+ history: Optional[List[Tuple[str, str]]] = None,
744
+ images: Optional[List["PIL.Image"]] = None,
745
+ template_version: Optional[Literal["base", "chat", "vqa"]] = None,
746
+ ):
747
+ image_size: int = self.config.vision_config['image_size']
748
+ patch_size: int = self.config.vision_config['patch_size']
749
+ template_version = template_version or self.config.template_version
750
+ assert images is None or len(images) <= 1, "only a single image per conversation is supported for now"
751
+ history = history or []
752
+ text = _history_to_prompt(template_version, history, query)
753
+
754
+ input_ids = [tokenizer.bos_token_id]
755
+ token_type_ids = [LANGUAGE_TOKEN_TYPE]
756
+ if images is not None and len(images) == 1:
757
+ # vision
758
+ transform = transforms.Compose(
759
+ [
760
+ transforms.Resize(
761
+ (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
762
+ ),
763
+ transforms.ToTensor(),
764
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
765
+ ]
766
+ )
767
+ images = [transform(images[0])]
768
+ # language
769
+ vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
770
+ input_ids += [tokenizer.pad_token_id] * vision_token_num
771
+ token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
772
+ text_ids = tokenizer.encode(text, add_special_tokens=False)
773
+
774
+ input_ids += text_ids
775
+ token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
776
+ attention_mask = [1] * len(input_ids)
777
+
778
+ return {
779
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
780
+ 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
781
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
782
+ 'images': images,
783
+ }
util.py ADDED
@@ -0,0 +1,483 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+ import torch.nn.functional as F
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
+ # @triton.autotune(
12
+ # configs=[
13
+ # triton.Config({"BLOCK_M": 2}),
14
+ # triton.Config({"BLOCK_M": 4}),
15
+ # triton.Config({"BLOCK_M": 8}),
16
+ # triton.Config({"BLOCK_M": 16}),
17
+ # ],
18
+ # key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
19
+ # )
20
+ @triton.jit
21
+ def rotary_kernel(
22
+ OUT, # Pointers to matrices
23
+ X,
24
+ COS,
25
+ SIN,
26
+ CU_SEQLENS,
27
+ SEQLEN_OFFSETS, # this could be int or a pointer
28
+ # Matrix dimensions
29
+ seqlen,
30
+ nheads,
31
+ rotary_dim,
32
+ seqlen_ro,
33
+ CACHE_KEY_SEQLEN,
34
+ # strides
35
+ stride_out_batch,
36
+ stride_out_nheads,
37
+ stride_out_seqlen,
38
+ stride_out_headdim,
39
+ stride_x_batch,
40
+ stride_x_nheads,
41
+ stride_x_seqlen,
42
+ stride_x_headdim,
43
+ # Meta-parameters
44
+ BLOCK_K: tl.constexpr,
45
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
46
+ IS_VARLEN: tl.constexpr,
47
+ INTERLEAVED: tl.constexpr,
48
+ CONJUGATE: tl.constexpr,
49
+ BLOCK_M: tl.constexpr,
50
+ ):
51
+ pid_m = tl.program_id(axis=0)
52
+ pid_batch = tl.program_id(axis=1)
53
+ pid_head = tl.program_id(axis=2)
54
+ rotary_dim_half = rotary_dim // 2
55
+
56
+ if not IS_VARLEN:
57
+ X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
58
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
59
+ COS = COS + pid_batch * seqlen_ro * rotary_dim_half
60
+ SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
61
+ else:
62
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
63
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
64
+ X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
65
+ OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
66
+
67
+ if pid_m * BLOCK_M >= seqlen:
68
+ return
69
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
70
+ if not IS_SEQLEN_OFFSETS_TENSOR:
71
+ rm_cs = rm + SEQLEN_OFFSETS
72
+ else:
73
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
74
+ rk = tl.arange(0, BLOCK_K)
75
+ rk_half = tl.arange(0, BLOCK_K // 2)
76
+
77
+ if not INTERLEAVED:
78
+ # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
79
+ X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
80
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
81
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
82
+ cos = tl.load(
83
+ COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
84
+ )
85
+ sin = tl.load(
86
+ SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
87
+ )
88
+ x0 = tl.load(
89
+ X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
90
+ )
91
+ x1 = tl.load(
92
+ X + rotary_dim_half * stride_x_headdim,
93
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
94
+ other=0.0,
95
+ )
96
+ if CONJUGATE:
97
+ sin = -sin
98
+ o0 = x0 * cos - x1 * sin
99
+ o1 = x0 * sin + x1 * cos
100
+ # write back result
101
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
102
+ tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
103
+ tl.store(
104
+ OUT + rotary_dim_half * stride_out_headdim,
105
+ o1,
106
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
107
+ )
108
+ else:
109
+ # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
110
+ # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
111
+ # Loading x0 will be fast but x1 will be slow.
112
+ # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
113
+ # Then we do the calculation and use tl.where to pick put the right outputs for the even
114
+ # and for the odd indices.
115
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
116
+ rk_repeat = tl.arange(0, BLOCK_K) // 2
117
+ X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
118
+ X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
119
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
120
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
121
+ cos = tl.load(
122
+ COS,
123
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
124
+ other=1.0,
125
+ ).to(tl.float32)
126
+ sin = tl.load(
127
+ SIN,
128
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
129
+ other=0.0,
130
+ ).to(tl.float32)
131
+ x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
132
+ tl.float32
133
+ )
134
+ x1 = tl.load(
135
+ X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
136
+ ).to(tl.float32)
137
+ if CONJUGATE:
138
+ sin = -sin
139
+ x0_cos = x0 * cos
140
+ x1_sin = x1 * sin
141
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
142
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
143
+ tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
144
+
145
+
146
+ def apply_rotary(
147
+ x: torch.Tensor,
148
+ cos: torch.Tensor,
149
+ sin: torch.Tensor,
150
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
151
+ cu_seqlens: Optional[torch.Tensor] = None,
152
+ max_seqlen: Optional[int] = None,
153
+ interleaved=False,
154
+ inplace=False,
155
+ conjugate=False,
156
+ ) -> torch.Tensor:
157
+ """
158
+ Arguments:
159
+ x: (batch, nheads, seqlen, headdim) if cu_seqlens is None
160
+ else (total_seqlen, nheads, headdim).
161
+ cos: (batch, seqlen_ro, rotary_dim / 2)
162
+ sin: (batch, seqlen_ro, rotary_dim / 2)
163
+ seqlen_offsets: integer or integer tensor of size (batch,)
164
+ cu_seqlens: (batch + 1,) or None
165
+ max_seqlen: int
166
+ Returns:
167
+ y: (batch, nheads, seqlen, headdim)
168
+ """
169
+
170
+ batch, nheads, seqlen, headdim = x.shape
171
+
172
+ batch_ro, seqlen_ro, rotary_dim = cos.shape
173
+
174
+ assert batch == batch_ro
175
+ assert sin.shape == cos.shape
176
+ rotary_dim *= 2
177
+ assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
178
+ assert headdim <= 256, "Only support headdim <= 256"
179
+
180
+ assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
181
+
182
+ assert (
183
+ cos.dtype == sin.dtype
184
+ ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
185
+ assert (
186
+ x.dtype == cos.dtype
187
+ ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
188
+
189
+ cos, sin = cos.contiguous(), sin.contiguous()
190
+ if isinstance(seqlen_offsets, torch.Tensor):
191
+ assert seqlen_offsets.shape == (batch,)
192
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
193
+ seqlen_offsets = seqlen_offsets.contiguous()
194
+ else:
195
+ assert seqlen_offsets + seqlen <= seqlen_ro
196
+
197
+ output = torch.empty_like(x) if not inplace else x
198
+ if rotary_dim < headdim and not inplace:
199
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
200
+
201
+ BLOCK_K = (
202
+ 32
203
+ if rotary_dim <= 32
204
+ else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
205
+ )
206
+ grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
207
+ BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
208
+
209
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
210
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
211
+ with torch.cuda.device(x.device.index):
212
+ rotary_kernel[grid](
213
+ output, # data ptrs
214
+ x,
215
+ cos,
216
+ sin,
217
+ cu_seqlens,
218
+ seqlen_offsets,
219
+ seqlen, # shapes
220
+ nheads,
221
+ rotary_dim,
222
+ seqlen_ro,
223
+ seqlen // 128, # key for triton cache (limit number of compilations)
224
+ output.stride(0), # batch_strides
225
+ output.stride(-3), # nheads_stride
226
+ output.stride(-2), # seqlen_stride
227
+ output.stride(-1), # headdim_stride
228
+ x.stride(0), # batch_strides
229
+ x.stride(-3), # nheads stride
230
+ x.stride(-2), # seqlen stride
231
+ x.stride(-1), # headdim stride
232
+ BLOCK_K,
233
+ isinstance(seqlen_offsets, torch.Tensor),
234
+ False,
235
+ interleaved,
236
+ conjugate,
237
+ BLOCK_M,
238
+ )
239
+ return output
240
+
241
+
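Note that the shapes unpacked here differ from the stock flash-attention version: `x` is laid out as (batch, nheads, seqlen, headdim), and `cos`/`sin` carry a leading batch dimension because `FastRotaryEmbedding` below gathers them per position with `F.embedding`. A minimal usage sketch under those assumptions (illustrative sizes; a CUDA device and triton are required to launch the kernel):

```python
import torch

batch, nheads, seqlen, headdim = 2, 32, 128, 128   # made-up sizes
x = torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16)

# In real use cos/sin come from a position-embedding table; random angles suffice here.
angles = torch.randn(batch, seqlen, headdim // 2, device="cuda")
cos, sin = angles.cos().half(), angles.sin().half()

y = apply_rotary(x, cos, sin)                 # out-of-place
x = apply_rotary(x, cos, sin, inplace=True)   # rotates x in place and returns it
```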
242
+ class ApplyRotaryEmb(torch.autograd.Function):
243
+ @staticmethod
244
+ def forward(
245
+ ctx,
246
+ x,
247
+ cos,
248
+ sin,
249
+ interleaved=False,
250
+ inplace=False,
251
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
252
+ cu_seqlens: Optional[torch.Tensor] = None,
253
+ max_seqlen: Optional[int] = None,
254
+ ):
255
+ out = apply_rotary(
256
+ x,
257
+ cos,
258
+ sin,
259
+ seqlen_offsets=seqlen_offsets,
260
+ cu_seqlens=cu_seqlens,
261
+ max_seqlen=max_seqlen,
262
+ interleaved=interleaved,
263
+ inplace=inplace,
264
+ )
265
+ if isinstance(seqlen_offsets, int):
266
+ ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
267
+ ctx.seqlen_offsets = seqlen_offsets
268
+ else:
269
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
270
+ ctx.seqlen_offsets = None
271
+ ctx.interleaved = interleaved
272
+ ctx.inplace = inplace
273
+ ctx.max_seqlen = max_seqlen
274
+ return out if not inplace else x
275
+
276
+ @staticmethod
277
+ def backward(ctx, do):
278
+ seqlen_offsets = ctx.seqlen_offsets
279
+ if seqlen_offsets is None:
280
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
281
+ else:
282
+ cos, sin, cu_seqlens = ctx.saved_tensors
283
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
284
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
285
+ if not ctx.interleaved and not ctx.inplace:
286
+ do = do.clone()
287
+ dx = apply_rotary(
288
+ do,
289
+ cos,
290
+ sin,
291
+ seqlen_offsets=seqlen_offsets,
292
+ cu_seqlens=cu_seqlens,
293
+ max_seqlen=ctx.max_seqlen,
294
+ interleaved=ctx.interleaved,
295
+ inplace=ctx.inplace,
296
+ conjugate=True,
297
+ )
298
+ return dx, None, None, None, None, None, None, None
299
+
300
+
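The backward pass works because rotating a pair by angle θ is an orthogonal transform, so its gradient is a rotation by -θ; `conjugate=True` simply flips the sign of `sin`, letting the same kernel serve both directions. Written out for one rotated pair (a worked derivation, not code from this file):

```python
# forward:   o0 = x0*cos - x1*sin      o1 = x0*sin + x1*cos
# backward:  dx0 = do0*cos + do1*sin   dx1 = -do0*sin + do1*cos
# i.e. the forward formula applied to (do0, do1) with sin -> -sin,
# which is exactly what conjugate=True selects in apply_rotary.
```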
301
+ def apply_rotary_emb(
302
+ x,
303
+ cos,
304
+ sin,
305
+ interleaved=False,
306
+ inplace=False,
307
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
308
+ cu_seqlens: Optional[torch.Tensor] = None,
309
+ max_seqlen: Optional[int] = None,
310
+ ):
311
+ """
312
+ Arguments:
313
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
314
+ else (total_seqlen, nheads, headdim)
315
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
316
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
317
+ of 1st half and 2nd half (GPT-NeoX style).
318
+ inplace: if True, apply rotary embedding in-place.
319
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
320
+ Most commonly used in inference when we have KV cache.
321
+ cu_seqlens: (batch + 1,) or None
322
+ max_seqlen: int
323
+ Return:
324
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
325
+ else (total_seqlen, nheads, headdim)
326
+ rotary_dim must be <= headdim
327
+ Apply rotary embedding to the first rotary_dim of x.
328
+ """
329
+ return ApplyRotaryEmb.apply(
330
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
331
+ )
332
+
333
+
334
+ # For backward compatibility
335
+ apply_rotary_emb_func = apply_rotary_emb
336
+
337
+
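When `rotary_dim` (twice the last dimension of `cos`) is smaller than `headdim`, only the first `rotary_dim` channels are rotated; the out-of-place path copies the remaining channels through explicitly, and the in-place path leaves them untouched by construction. A small sketch of that invariant (illustrative sizes, assumes a CUDA device):

```python
import torch

batch, nheads, seqlen, headdim, rotary_dim = 1, 4, 16, 64, 32
x = torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16)

angles = torch.arange(seqlen, device="cuda").float()[None, :, None]
angles = angles.expand(batch, seqlen, rotary_dim // 2)
cos, sin = angles.cos().half(), angles.sin().half()

out = apply_rotary_emb(x, cos, sin)
# channels past rotary_dim pass through unchanged
assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])
```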
338
+ class FastRotaryEmbedding(torch.nn.Module):
339
+ """
340
+ The rotary position embeddings from RoFormer_ (Su et al.).
341
+ A crucial insight from the method is that the query and keys are
342
+ transformed by rotation matrices which depend on the relative positions.
343
+
344
+ Other implementations are available in the Rotary Transformer repo_ and in
345
+ GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.
346
+
347
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
348
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
349
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
350
+
351
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
352
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
353
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
354
+ """
355
+
356
+ def __init__(
357
+ self,
358
+ dim: int,
359
+ base=10000,
360
+ interleaved=False,
361
+ scale_base=None,
362
+ pos_idx_in_fp32=True,
363
+ device=None,
364
+ ):
365
+ """
366
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
367
+ of 1st half and 2nd half (GPT-NeoX style).
368
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
369
+ otherwise they might be in lower precision.
370
+ This option was added because previously (before 2023-07-02), when we construct
371
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
372
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
373
+ self.inv_freq would be bf16, and the position indices are also in bf16.
374
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
375
+ embeddings for some positions will coincide.
376
+ To maintain compatibility with models previously trained in pure bf16,
377
+ we add this option.
378
+ """
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.base = base
382
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
383
+ # Generate and save the inverse frequency buffer (non trainable)
384
+ inv_freq = self._compute_inv_freq(device)
385
+ self.register_buffer("inv_freq", inv_freq)
386
+ self.interleaved = interleaved
387
+ self.scale_base = scale_base
388
+ scale = (
389
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
390
+ if scale_base is not None
391
+ else None
392
+ )
393
+ self.register_buffer("scale", scale, persistent=False)
394
+
395
+ self._seq_len_cached = 0
396
+ self._cos_cached = None
397
+ self._sin_cached = None
398
+ self._cos_k_cached = None
399
+ self._sin_k_cached = None
400
+ self.cos = None
401
+ self.sin = None
402
+
403
+ def _compute_inv_freq(self, device=None):
404
+ return 1.0 / (
405
+ self.base
406
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
407
+ # ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
408
+ )
409
+
410
+ def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
411
+
412
+ if seqlen > self._seq_len_cached:
415
+ self._seq_len_cached = seqlen
416
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
417
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
418
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
419
+ if self.pos_idx_in_fp32:
420
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
421
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
422
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
423
+ # cos & sin output to change significantly.
424
+ # We want to recompute self.inv_freq if it was not loaded in fp32
425
+ if self.inv_freq.dtype != torch.float32:
426
+ inv_freq = self._compute_inv_freq(device=device)
427
+ else:
428
+ inv_freq = self.inv_freq
429
+ else:
430
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
431
+ inv_freq = self.inv_freq
432
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
433
+ if self.scale is None:
434
+ self._cos_cached = torch.cos(freqs).to(dtype)
435
+ self._sin_cached = torch.sin(freqs).to(dtype)
436
+
437
+ else:
438
+ power = (
439
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
440
+ - seqlen // 2
441
+ ) / self.scale_base
442
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
443
+ # We want the multiplication by scale to happen in fp32
444
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
445
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
446
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
447
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
448
+
449
+ def forward(
450
+ self,
451
+ q: torch.Tensor,
452
+ k: torch.Tensor,
453
+ position_ids: torch.Tensor,
454
+ max_seqlen,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ """
457
+ q: (batch, nheads, seqlen, headdim)
458
+ k: (batch, nheads, seqlen, headdim)
459
+ position_ids: (batch, seqlen)
460
+ max_seqlen: int
461
+ The cos/sin cache is recomputed whenever max_seqlen exceeds the cached length.
463
+ Apply rotary embedding *inplace* to q k.
464
+ """
465
+
466
+ self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
467
+ cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
468
+
469
+ q = apply_rotary_emb_func(
470
+ q,
471
+ cos,
472
+ sin,
473
+ interleaved=self.interleaved,
474
+ inplace=True
475
+ )
476
+ k = apply_rotary_emb_func(
477
+ k,
478
+ cos,
479
+ sin,
480
+ interleaved=self.interleaved,
481
+ inplace=True
482
+ )
483
+ return q, k
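`FastRotaryEmbedding` ties the pieces together: it lazily builds the cos/sin tables up to `max_seqlen` (in fp32 by default), gathers them per token with `F.embedding(position_ids, ...)`, and applies the Triton kernel in place to `q` and `k`. A hedged usage sketch with made-up sizes (requires CUDA; the actual `dim`/`base` values are whatever the surrounding model config specifies):

```python
import torch

batch, nheads, seqlen, head_dim = 2, 32, 64, 128
rope = FastRotaryEmbedding(dim=head_dim, base=10000).cuda()

q = torch.randn(batch, nheads, seqlen, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
position_ids = torch.arange(seqlen, device="cuda")[None, :].expand(batch, seqlen)

q, k = rope(q, k, position_ids, max_seqlen=seqlen)  # rotated in place
```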
visual.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ from torch import nn
3
+ from argparse import Namespace
4
+ import xformers.ops as xops
5
+ from transformers.activations import ACT2FN
6
+
7
+
8
+ class PatchEmbedding(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
12
+ self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
13
+ self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
14
+
15
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
16
+ x = self.proj(images)
17
+ x = x.flatten(2).transpose(1, 2)
18
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
19
+ x = torch.cat((cls_token, x), dim=1)
20
+ x += self.position_embedding.weight.unsqueeze(0)
21
+ return x
22
+
23
+
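`PatchEmbedding` patchifies the image with a strided convolution, prepends a learned [CLS] token, and adds absolute position embeddings, so the sequence length is (H / patch_size) * (W / patch_size) + 1 and `config.num_positions` must match it. For example (illustrative values, not necessarily the shipped config):

```python
image_size, patch_size = 490, 14                 # assumed example values
num_patches = (image_size // patch_size) ** 2    # 35 * 35 = 1225
seq_len = num_patches + 1                        # +1 for the [CLS] token = 1226
# config.num_positions should equal seq_len
```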
24
+ class Attention(nn.Module):
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.num_heads = config.num_heads
28
+ head_dim = config.hidden_size // config.num_heads
29
+ self.scale = head_dim ** -0.5
30
+ self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
31
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
32
+ self.output_dropout = torch.nn.Dropout(config.dropout_prob)
33
+
34
+ def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
35
+ B, L, _ = x.shape
36
+ qkv = self.query_key_value(x)
37
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
38
+ q, k, v = qkv[0], qkv[1], qkv[2]
39
+
40
+ out = xops.memory_efficient_attention(
41
+ q, k, v, scale=self.scale,
42
+ )
43
+ output = self.dense(out.view(B, L, -1))
44
+ output = self.output_dropout(output)
45
+ return output
46
+
47
+ def attention(self, q, k, v):
48
+ attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
49
+ attn_weights = attn_weights.softmax(dim=-1)
50
+ output = torch.matmul(attn_weights, v)
51
+ return output
52
+
53
+
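Note the layout convention: `forward` keeps q, k, v as (B, L, H, D), which is the layout `xops.memory_efficient_attention` expects, while the unused `attention` helper is a naive softmax reference that assumes a heads-first (B, H, L, D) layout. A commented sketch of how the two could be compared (a hypothetical check, not part of the file):

```python
# q, k, v: (B, L, H, D) as built in forward()
# out_xops  = xops.memory_efficient_attention(q, k, v, scale=self.scale)
# out_naive = self.attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
# out_naive = out_naive.transpose(1, 2)           # back to (B, L, H, D)
# torch.allclose(out_xops, out_naive, atol=1e-3)  # should hold up to numerics
```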
54
+ class MLP(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.config = config
58
+ self.activation_fn = ACT2FN[config.hidden_act]
59
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
60
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ x = self.fc1(x)
64
+ x = self.activation_fn(x)
65
+ x = self.fc2(x)
66
+ return x
67
+
68
+
69
+ class TransformerLayer(nn.Module):
70
+ def __init__(self, config):
71
+ super().__init__()
72
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
73
+ self.attention = Attention(config)
74
+ self.mlp = MLP(config)
75
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
76
+
77
+ def forward(self, hidden_states):
78
+ attention_input = hidden_states
79
+ attention_output = self.input_layernorm(self.attention(attention_input))
80
+ hidden_states = attention_input + attention_output
81
+ mlp_input = hidden_states
82
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
83
+ output = mlp_input + mlp_output
84
+ return output
85
+
86
+
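Despite the names `input_layernorm` and `post_attention_layernorm`, the normalization here is applied to each sublayer's output before the residual addition, not to its input. Equivalent to the forward above, written out for clarity (`layer` stands for a `TransformerLayer` instance):

```python
hidden = x + layer.input_layernorm(layer.attention(x))
out = hidden + layer.post_attention_layernorm(layer.mlp(hidden))
```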
87
+ class Transformer(nn.Module):
88
+ def __init__(self, config):
89
+ super().__init__()
90
+ self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
91
+
92
+ def forward(self, hidden_states):
93
+ for layer_module in self.layers:
94
+ hidden_states = layer_module(hidden_states)
95
+ return hidden_states
96
+
97
+
98
+ class GLU(nn.Module):
99
+ def __init__(self, config, in_features):
100
+ super().__init__()
101
+ self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
102
+ self.norm1 = nn.LayerNorm(config.hidden_size)
103
+ self.act1 = nn.GELU()
104
+ self.act2 = nn.functional.silu
105
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
106
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
107
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
108
+
109
+ def forward(self, x):
110
+ x = self.linear_proj(x)
111
+ x = self.act1(self.norm1(x))
112
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
113
+ x = self.dense_4h_to_h(x)
114
+ return x
115
+
116
+
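`GLU` is the vision-to-language adapter: it first projects the vision hidden size to the language model's hidden size, then applies a SwiGLU-style gated MLP. The forward pass above is equivalent to the following restatement (`glu` stands for a `GLU` instance):

```python
h = glu.act1(glu.norm1(glu.linear_proj(x)))            # GELU(LayerNorm(W_in x))
h = glu.act2(glu.gate_proj(h)) * glu.dense_h_to_4h(h)  # SiLU gate * up-projection
out = glu.dense_4h_to_h(h)                             # down-projection to hidden_size
```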
117
+ class EVA2CLIPModel(nn.Module):
118
+ def __init__(self, config):
119
+ super().__init__()
120
+ vision_config = Namespace(**config.vision_config)
121
+ self.patch_embedding = PatchEmbedding(vision_config)
122
+ self.transformer = Transformer(vision_config)
123
+ self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
124
+ self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
125
+ self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
126
+
127
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
128
+ x = self.patch_embedding(images)
129
+ x = self.transformer(x)
130
+ x = x[:, 1:]
131
+ x = self.linear_proj(x)
132
+ boi = self.boi.expand(x.shape[0], -1, -1)
133
+ eoi = self.eoi.expand(x.shape[0], -1, -1)
134
+ x = torch.cat((boi, x, eoi), dim=1)
135
+ return x
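Putting it together, `EVA2CLIPModel` encodes an image into a sequence of language-model-sized features: patch embedding and the ViT transformer produce vision features, the [CLS] token is dropped (`x[:, 1:]`), the GLU adapter projects to `config.hidden_size`, and learned begin/end-of-image embeddings bracket the result. The shapes flow roughly as follows (a sketch; `num_patches` and the hidden sizes depend on the config):

```python
# images:                     (B, 3, H, W)
# patch_embedding + ViT:      (B, 1 + num_patches, vision_hidden_size)
# drop CLS, GLU projection:   (B, num_patches, hidden_size)
# add boi/eoi tokens:         (B, num_patches + 2, hidden_size)
```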