qingsonglv
commited on
Commit
•
a2c2a1e
1
Parent(s):
ff4a7b3
upload model
Browse files- README.md +153 -0
- config.json +43 -0
- configuration_cogagent.py +51 -0
- cross_visual.py +797 -0
- generation_config.json +7 -0
- model-00001-of-00008.safetensors +3 -0
- model-00002-of-00008.safetensors +3 -0
- model-00003-of-00008.safetensors +3 -0
- model-00004-of-00008.safetensors +3 -0
- model-00005-of-00008.safetensors +3 -0
- model-00006-of-00008.safetensors +3 -0
- model-00007-of-00008.safetensors +3 -0
- model-00008-of-00008.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_cogagent.py +917 -0
- util.py +483 -0
- visual.py +136 -0
README.md
CHANGED
@@ -1,3 +1,156 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
---
|
6 |
+
# CogAgent
|
7 |
+
|
8 |
+
## Introduction
|
9 |
+
|
10 |
+
**CogAgent** is an open-source visual language model improved based on **CogVLM**.
|
11 |
+
|
12 |
+
📖 Paper: https://arxiv.org/abs/2312.08914
|
13 |
+
|
14 |
+
**CogAgent-18B** has 11 billion visual parameters and 7 billion language parameters and achieves state-of-the-art generalist performance on 9 classic cross-modal benchmarks, including:
|
15 |
+
+ VQAv2
|
16 |
+
+ OK-VQ
|
17 |
+
+ TextVQA
|
18 |
+
+ ST-VQA
|
19 |
+
+ ChartQA
|
20 |
+
+ infoVQA
|
21 |
+
+ DocVQA
|
22 |
+
+ MM-Vet
|
23 |
+
+ POPE
|
24 |
+
|
25 |
+
**CogAgent-18B** significantly surpasses existing models on GUI operation datasets such as AITW and Mind2Web.
|
26 |
+
|
27 |
+
In addition to all the features already present in **CogVLM** (visual multi-round dialogue, visual grounding), **CogAgent**:
|
28 |
+
|
29 |
+
1. Supports higher resolution visual input and dialogue question-answering. It supports ultra-high-resolution image inputs of **1120x1120**.
|
30 |
+
|
31 |
+
2. Possesses the capabilities of a visual Agent, being able to return a plan, next action, and specific operations with coordinates for any given task on any GUI screenshot.
|
32 |
+
|
33 |
+
3. Enhanced GUI-related question-answering capabilities, allowing it to handle questions about any GUI screenshot, such as web pages, PC apps, mobile applications, etc.
|
34 |
+
|
35 |
+
4. Enhanced capabilities in OCR-related tasks through improved pre-training and fine-tuning.
|
36 |
+
|
37 |
+
<div align="center">
|
38 |
+
<img src="https://raw.githubusercontent.com/THUDM/CogVLM/master/assets/cogagent_function.jpg" alt="img" style="zoom: 50%;" />
|
39 |
+
</div>
|
40 |
+
|
41 |
+
## Quick Start
|
42 |
+
|
43 |
+
use this python code to get started quickly in `cli_demo.py`:
|
44 |
+
|
45 |
+
```python
|
46 |
+
import torch
|
47 |
+
from PIL import Image
|
48 |
+
from transformers import AutoModelForCausalLM, LlamaTokenizer
|
49 |
+
import argparse
|
50 |
+
|
51 |
+
parser = argparse.ArgumentParser()
|
52 |
+
parser.add_argument("--quant", choices=[4], type=int, default=None, help='quantization bits')
|
53 |
+
parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat-hf", help='pretrained ckpt')
|
54 |
+
parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
|
55 |
+
parser.add_argument("--fp16", action="store_true")
|
56 |
+
parser.add_argument("--bf16", action="store_true")
|
57 |
+
|
58 |
+
args = parser.parse_args()
|
59 |
+
MODEL_PATH = args.from_pretrained
|
60 |
+
TOKENIZER_PATH = args.local_tokenizer
|
61 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
62 |
+
|
63 |
+
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
|
64 |
+
if args.bf16:
|
65 |
+
torch_type = torch.bfloat16
|
66 |
+
else:
|
67 |
+
torch_type = torch.float16
|
68 |
+
|
69 |
+
print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
|
70 |
+
|
71 |
+
if args.quant:
|
72 |
+
model = AutoModelForCausalLM.from_pretrained(
|
73 |
+
MODEL_PATH,
|
74 |
+
torch_dtype=torch_type,
|
75 |
+
low_cpu_mem_usage=True,
|
76 |
+
load_in_4bit=True,
|
77 |
+
trust_remote_code=True
|
78 |
+
).eval()
|
79 |
+
else:
|
80 |
+
model = AutoModelForCausalLM.from_pretrained(
|
81 |
+
MODEL_PATH,
|
82 |
+
torch_dtype=torch_type,
|
83 |
+
low_cpu_mem_usage=True,
|
84 |
+
load_in_4bit=args.quant is not None,
|
85 |
+
trust_remote_code=True
|
86 |
+
).to(DEVICE).eval()
|
87 |
+
|
88 |
+
while True:
|
89 |
+
image_path = input("image path >>>>> ")
|
90 |
+
if image_path == "stop":
|
91 |
+
break
|
92 |
+
|
93 |
+
image = Image.open(image_path).convert('RGB')
|
94 |
+
history = []
|
95 |
+
while True:
|
96 |
+
query = input("Human:")
|
97 |
+
if query == "clear":
|
98 |
+
break
|
99 |
+
input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
|
100 |
+
inputs = {
|
101 |
+
'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
|
102 |
+
'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
|
103 |
+
'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
|
104 |
+
'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
|
105 |
+
}
|
106 |
+
if 'cross_images' in input_by_model and input_by_model['cross_images']:
|
107 |
+
inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
|
108 |
+
|
109 |
+
# add any transformers params here.
|
110 |
+
gen_kwargs = {"max_length": 2048,
|
111 |
+
"temperature": 0.9,
|
112 |
+
"do_sample": False}
|
113 |
+
with torch.no_grad():
|
114 |
+
outputs = model.generate(**inputs, **gen_kwargs)
|
115 |
+
outputs = outputs[:, inputs['input_ids'].shape[1]:]
|
116 |
+
response = tokenizer.decode(outputs[0])
|
117 |
+
response = response.split("</s>")[0]
|
118 |
+
print("\nCog:", response)
|
119 |
+
history.append((query, response))
|
120 |
+
```
|
121 |
+
|
122 |
+
Then run:
|
123 |
+
|
124 |
+
```bash
|
125 |
+
python cli_demo_hf.py --bf16
|
126 |
+
```
|
127 |
+
for more information such as Web Demo and Finetune, please refer to [Our GitHub](https://github.com/THUDM/CogVLM/)
|
128 |
+
|
129 |
+
## License
|
130 |
+
|
131 |
+
The code in this repository is open source under the [Apache-2.0 license](./LICENSE), while the use of CogAgent and CogVLM model weights must comply with the [Model License](./MODEL_LICENSE).
|
132 |
+
|
133 |
+
## Citation & Acknowledgements
|
134 |
+
|
135 |
+
If you find our work helpful, please consider citing the following papers
|
136 |
+
|
137 |
+
```
|
138 |
+
@misc{hong2023cogagent,
|
139 |
+
title={CogAgent: A Visual Language Model for GUI Agents},
|
140 |
+
author={Wenyi Hong and Weihan Wang and Qingsong Lv and Jiazheng Xu and Wenmeng Yu and Junhui Ji and Yan Wang and Zihan Wang and Yuxiao Dong and Ming Ding and Jie Tang},
|
141 |
+
year={2023},
|
142 |
+
eprint={2312.08914},
|
143 |
+
archivePrefix={arXiv},
|
144 |
+
primaryClass={cs.CV}
|
145 |
+
}
|
146 |
+
|
147 |
+
@misc{wang2023cogvlm,
|
148 |
+
title={CogVLM: Visual Expert for Pretrained Language Models},
|
149 |
+
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
|
150 |
+
year={2023},
|
151 |
+
eprint={2311.03079},
|
152 |
+
archivePrefix={arXiv},
|
153 |
+
primaryClass={cs.CV}
|
154 |
+
}
|
155 |
+
```
|
156 |
+
In the instruction fine-tuning phase of the CogVLM, there are some English image-text data from the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4), [LLAVA](https://github.com/haotian-liu/LLaVA), [LRV-Instruction](https://github.com/FuxiaoLiu/LRV-Instruction), [LLaVAR](https://github.com/SALT-NLP/LLaVAR) and [Shikra](https://github.com/shikras/shikra) projects, as well as many classic cross-modal work datasets. We sincerely thank them for their contributions.
|
config.json
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "cogagent",
|
3 |
+
"architectures": [
|
4 |
+
"CogAgentForCausalLM"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_cogagent.CogAgentConfig",
|
8 |
+
"AutoModelForCausalLM": "modeling_cogagent.CogAgentForCausalLM"
|
9 |
+
},
|
10 |
+
"bos_token_id": 1,
|
11 |
+
"cross_compute_hidden_size": 1024,
|
12 |
+
"cross_hidden_size": 1024,
|
13 |
+
"cross_image_size": 1120,
|
14 |
+
"eos_token_id": 2,
|
15 |
+
"hidden_act": "silu",
|
16 |
+
"hidden_size": 4096,
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"intermediate_size": 11008,
|
19 |
+
"max_position_embeddings": 2048,
|
20 |
+
"num_attention_heads": 32,
|
21 |
+
"num_hidden_layers": 32,
|
22 |
+
"pad_token_id": 0,
|
23 |
+
"rms_norm_eps": 1e-05,
|
24 |
+
"template_version": "chat_old",
|
25 |
+
"tie_word_embeddings": false,
|
26 |
+
"torch_dtype": "bfloat16",
|
27 |
+
"transformers_version": "4.36.0.dev0",
|
28 |
+
"use_cache": true,
|
29 |
+
"vision_config": {
|
30 |
+
"dropout_prob": 0.0,
|
31 |
+
"hidden_act": "gelu",
|
32 |
+
"hidden_size": 1792,
|
33 |
+
"image_size": 224,
|
34 |
+
"in_channels": 3,
|
35 |
+
"intermediate_size": 15360,
|
36 |
+
"layer_norm_eps": 1e-06,
|
37 |
+
"num_heads": 16,
|
38 |
+
"num_hidden_layers": 63,
|
39 |
+
"num_positions": 257,
|
40 |
+
"patch_size": 14
|
41 |
+
},
|
42 |
+
"vocab_size": 32000
|
43 |
+
}
|
configuration_cogagent.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from transformers import PretrainedConfig
|
3 |
+
|
4 |
+
|
5 |
+
class CogAgentConfig(PretrainedConfig):
|
6 |
+
_auto_class = "AutoConfig"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
vocab_size=32000,
|
11 |
+
hidden_size=4096,
|
12 |
+
cross_hidden_size=1024,
|
13 |
+
cross_compute_hidden_size=1024,
|
14 |
+
cross_image_size=1120,
|
15 |
+
intermediate_size=11008,
|
16 |
+
num_hidden_layers=32,
|
17 |
+
num_attention_heads=32,
|
18 |
+
hidden_act='silu',
|
19 |
+
max_position_embeddings=2048,
|
20 |
+
initializer_range=0.02,
|
21 |
+
rms_norm_eps=1e-06,
|
22 |
+
template_version: Literal["base", "chat"] = "chat",
|
23 |
+
|
24 |
+
pad_token_id=0,
|
25 |
+
bos_token_id=1,
|
26 |
+
eos_token_id=2,
|
27 |
+
tie_word_embeddings=False,
|
28 |
+
use_cache=True,
|
29 |
+
**kwargs,
|
30 |
+
):
|
31 |
+
self.hidden_size = hidden_size
|
32 |
+
self.cross_hidden_size = cross_hidden_size
|
33 |
+
self.cross_compute_hidden_size = cross_compute_hidden_size
|
34 |
+
self.cross_image_size = cross_image_size
|
35 |
+
self.intermediate_size = intermediate_size
|
36 |
+
self.num_attention_heads = num_attention_heads
|
37 |
+
self.max_position_embeddings = max_position_embeddings
|
38 |
+
self.rms_norm_eps = rms_norm_eps
|
39 |
+
self.initializer_range = initializer_range
|
40 |
+
self.vocab_size = vocab_size
|
41 |
+
self.num_hidden_layers = num_hidden_layers
|
42 |
+
self.hidden_act = hidden_act
|
43 |
+
self.template_version = template_version
|
44 |
+
self.use_cache = use_cache
|
45 |
+
super().__init__(
|
46 |
+
pad_token_id=pad_token_id,
|
47 |
+
bos_token_id=bos_token_id,
|
48 |
+
eos_token_id=eos_token_id,
|
49 |
+
tie_word_embeddings=tie_word_embeddings,
|
50 |
+
**kwargs,
|
51 |
+
)
|
cross_visual.py
ADDED
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import logging
|
6 |
+
|
7 |
+
def broadcat(tensors, dim = -1):
|
8 |
+
num_tensors = len(tensors)
|
9 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
10 |
+
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
|
11 |
+
shape_len = list(shape_lens)[0]
|
12 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
13 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
14 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
15 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
|
16 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
17 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
18 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
19 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
20 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
21 |
+
return torch.cat(tensors, dim = dim)
|
22 |
+
|
23 |
+
def rotate_half(x):
|
24 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
25 |
+
x1, x2 = x.unbind(dim = -1)
|
26 |
+
x = torch.stack((-x2, x1), dim = -1)
|
27 |
+
return rearrange(x, '... d r -> ... (d r)')
|
28 |
+
|
29 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
dim,
|
33 |
+
pt_seq_len,
|
34 |
+
ft_seq_len=None,
|
35 |
+
custom_freqs = None,
|
36 |
+
freqs_for = 'lang',
|
37 |
+
theta = 10000,
|
38 |
+
max_freq = 10,
|
39 |
+
num_freqs = 1,
|
40 |
+
patch_dropout = 0.
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
if custom_freqs:
|
44 |
+
freqs = custom_freqs
|
45 |
+
elif freqs_for == 'lang':
|
46 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
47 |
+
elif freqs_for == 'pixel':
|
48 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
49 |
+
elif freqs_for == 'constant':
|
50 |
+
freqs = torch.ones(num_freqs).float()
|
51 |
+
else:
|
52 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
53 |
+
|
54 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
55 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
56 |
+
|
57 |
+
freqs = torch.einsum('..., f -> ... f', t, freqs)
|
58 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
59 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
|
60 |
+
|
61 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
62 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
63 |
+
|
64 |
+
self.patch_dropout = patch_dropout
|
65 |
+
|
66 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
67 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
68 |
+
|
69 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
70 |
+
|
71 |
+
def forward(self, t, patch_indices_keep=None):
|
72 |
+
if patch_indices_keep is not None:
|
73 |
+
batch = t.size()[0]
|
74 |
+
batch_indices = torch.arange(batch)
|
75 |
+
batch_indices = batch_indices[..., None]
|
76 |
+
|
77 |
+
freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
78 |
+
freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
79 |
+
|
80 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
81 |
+
freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
|
82 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
83 |
+
freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
|
84 |
+
|
85 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
86 |
+
|
87 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
88 |
+
|
89 |
+
import torch.nn as nn
|
90 |
+
import os
|
91 |
+
from dataclasses import dataclass
|
92 |
+
from typing import Optional, Tuple, Union
|
93 |
+
from functools import partial
|
94 |
+
|
95 |
+
import numpy as np
|
96 |
+
import torch
|
97 |
+
import torch.nn.functional as F
|
98 |
+
from torch import nn
|
99 |
+
|
100 |
+
# --------------------------------------------------------
|
101 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
102 |
+
# --------------------------------------------------------
|
103 |
+
import math
|
104 |
+
import os
|
105 |
+
from functools import partial
|
106 |
+
import torch
|
107 |
+
import torch.nn as nn
|
108 |
+
import torch.nn.functional as F
|
109 |
+
import logging
|
110 |
+
try:
|
111 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
112 |
+
except:
|
113 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
114 |
+
|
115 |
+
class PatchDropout(nn.Module):
|
116 |
+
"""
|
117 |
+
https://arxiv.org/abs/2212.00794
|
118 |
+
"""
|
119 |
+
|
120 |
+
def __init__(self, prob, exclude_first_token=True):
|
121 |
+
super().__init__()
|
122 |
+
assert 0 <= prob < 1.
|
123 |
+
self.prob = prob
|
124 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
125 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
if not self.training or self.prob == 0.:
|
129 |
+
return x
|
130 |
+
|
131 |
+
if self.exclude_first_token:
|
132 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
133 |
+
else:
|
134 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
135 |
+
|
136 |
+
batch = x.size()[0]
|
137 |
+
num_tokens = x.size()[1]
|
138 |
+
|
139 |
+
batch_indices = torch.arange(batch)
|
140 |
+
batch_indices = batch_indices[..., None]
|
141 |
+
|
142 |
+
keep_prob = 1 - self.prob
|
143 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
144 |
+
|
145 |
+
rand = torch.randn(batch, num_tokens)
|
146 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
147 |
+
|
148 |
+
x = x[batch_indices, patch_indices_keep]
|
149 |
+
|
150 |
+
if self.exclude_first_token:
|
151 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
152 |
+
|
153 |
+
if self.training and os.getenv('RoPE') == '1':
|
154 |
+
return x, patch_indices_keep
|
155 |
+
|
156 |
+
return x
|
157 |
+
|
158 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
159 |
+
try:
|
160 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
161 |
+
except:
|
162 |
+
from torch.utils.checkpoint import checkpoint
|
163 |
+
else:
|
164 |
+
from torch.utils.checkpoint import checkpoint
|
165 |
+
|
166 |
+
import xformers.ops as xops
|
167 |
+
|
168 |
+
class DropPath(nn.Module):
|
169 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
170 |
+
"""
|
171 |
+
def __init__(self, drop_prob=None):
|
172 |
+
super(DropPath, self).__init__()
|
173 |
+
self.drop_prob = drop_prob
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
return drop_path(x, self.drop_prob, self.training)
|
177 |
+
|
178 |
+
def extra_repr(self) -> str:
|
179 |
+
return 'p={}'.format(self.drop_prob)
|
180 |
+
|
181 |
+
|
182 |
+
class Mlp(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
in_features,
|
186 |
+
hidden_features=None,
|
187 |
+
out_features=None,
|
188 |
+
act_layer=nn.GELU,
|
189 |
+
norm_layer=nn.LayerNorm,
|
190 |
+
drop=0.,
|
191 |
+
subln=False,
|
192 |
+
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
out_features = out_features or in_features
|
196 |
+
hidden_features = hidden_features or in_features
|
197 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
198 |
+
self.act = act_layer()
|
199 |
+
|
200 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
201 |
+
|
202 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
203 |
+
self.drop = nn.Dropout(drop)
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
x = self.fc1(x)
|
207 |
+
x = self.act(x)
|
208 |
+
# x = self.drop(x)
|
209 |
+
# commit this for the orignal BERT implement
|
210 |
+
x = self.ffn_ln(x)
|
211 |
+
|
212 |
+
x = self.fc2(x)
|
213 |
+
x = self.drop(x)
|
214 |
+
return x
|
215 |
+
|
216 |
+
class SwiGLU(nn.Module):
|
217 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
218 |
+
norm_layer=nn.LayerNorm, subln=False):
|
219 |
+
super().__init__()
|
220 |
+
out_features = out_features or in_features
|
221 |
+
hidden_features = hidden_features or in_features
|
222 |
+
|
223 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
224 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
225 |
+
|
226 |
+
self.act = act_layer()
|
227 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
228 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
229 |
+
|
230 |
+
self.drop = nn.Dropout(drop)
|
231 |
+
|
232 |
+
def forward(self, x):
|
233 |
+
x1 = self.w1(x)
|
234 |
+
x2 = self.w2(x)
|
235 |
+
hidden = self.act(x1) * x2
|
236 |
+
x = self.ffn_ln(hidden)
|
237 |
+
x = self.w3(x)
|
238 |
+
x = self.drop(x)
|
239 |
+
return x
|
240 |
+
|
241 |
+
class Attention(nn.Module):
|
242 |
+
def __init__(
|
243 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
244 |
+
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
|
245 |
+
super().__init__()
|
246 |
+
self.num_heads = num_heads
|
247 |
+
head_dim = dim // num_heads
|
248 |
+
if attn_head_dim is not None:
|
249 |
+
head_dim = attn_head_dim
|
250 |
+
all_head_dim = head_dim * self.num_heads
|
251 |
+
self.scale = qk_scale or head_dim ** -0.5
|
252 |
+
|
253 |
+
self.subln = subln
|
254 |
+
if self.subln:
|
255 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
256 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
257 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
258 |
+
else:
|
259 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
260 |
+
|
261 |
+
if qkv_bias:
|
262 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
263 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
264 |
+
else:
|
265 |
+
self.q_bias = None
|
266 |
+
self.v_bias = None
|
267 |
+
|
268 |
+
if window_size:
|
269 |
+
self.window_size = window_size
|
270 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
271 |
+
self.relative_position_bias_table = nn.Parameter(
|
272 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
273 |
+
# cls to token & token 2 cls & cls to cls
|
274 |
+
|
275 |
+
# get pair-wise relative position index for each token inside the window
|
276 |
+
coords_h = torch.arange(window_size[0])
|
277 |
+
coords_w = torch.arange(window_size[1])
|
278 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
279 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
280 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
281 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
282 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
283 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
284 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
285 |
+
relative_position_index = \
|
286 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
287 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
288 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
289 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
290 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
291 |
+
|
292 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
293 |
+
else:
|
294 |
+
self.window_size = None
|
295 |
+
self.relative_position_bias_table = None
|
296 |
+
self.relative_position_index = None
|
297 |
+
|
298 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
299 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
300 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
301 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
302 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
303 |
+
self.xattn = xattn
|
304 |
+
self.xattn_drop = attn_drop
|
305 |
+
|
306 |
+
self.rope = rope
|
307 |
+
|
308 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
309 |
+
B, N, C = x.shape
|
310 |
+
if self.subln:
|
311 |
+
if self.q_proj.weight.dtype == torch.uint8:
|
312 |
+
import bitsandbytes as bnb
|
313 |
+
q = bnb.matmul_4bit(x, self.q_proj.weight.t(), bias=self.q_bias, quant_state=self.q_proj.weight.quant_state)
|
314 |
+
k = bnb.matmul_4bit(x, self.k_proj.weight.t(), bias=None, quant_state=self.k_proj.weight.quant_state)
|
315 |
+
v = bnb.matmul_4bit(x, self.v_proj.weight.t(), bias=self.v_bias, quant_state=self.v_proj.weight.quant_state)
|
316 |
+
else:
|
317 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
318 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
319 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
320 |
+
|
321 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
322 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
323 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
324 |
+
else:
|
325 |
+
|
326 |
+
qkv_bias = None
|
327 |
+
if self.q_bias is not None:
|
328 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
329 |
+
|
330 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
331 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
|
332 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
333 |
+
|
334 |
+
if self.rope:
|
335 |
+
# slightly fast impl
|
336 |
+
q_t = q[:, :, 1:, :]
|
337 |
+
ro_q_t = self.rope(q_t)
|
338 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
339 |
+
|
340 |
+
k_t = k[:, :, 1:, :]
|
341 |
+
ro_k_t = self.rope(k_t)
|
342 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
343 |
+
|
344 |
+
if self.xattn:
|
345 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
346 |
+
k = k.permute(0, 2, 1, 3)
|
347 |
+
v = v.permute(0, 2, 1, 3)
|
348 |
+
|
349 |
+
x = xops.memory_efficient_attention(
|
350 |
+
q, k, v,
|
351 |
+
p=self.xattn_drop,
|
352 |
+
scale=self.scale,
|
353 |
+
)
|
354 |
+
x = x.reshape(B, N, -1)
|
355 |
+
x = self.inner_attn_ln(x)
|
356 |
+
x = self.proj(x)
|
357 |
+
x = self.proj_drop(x)
|
358 |
+
else:
|
359 |
+
q = q * self.scale
|
360 |
+
attn = (q @ k.transpose(-2, -1))
|
361 |
+
|
362 |
+
if self.relative_position_bias_table is not None:
|
363 |
+
relative_position_bias = \
|
364 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
365 |
+
self.window_size[0] * self.window_size[1] + 1,
|
366 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
367 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
368 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
369 |
+
|
370 |
+
if rel_pos_bias is not None:
|
371 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
372 |
+
|
373 |
+
if attn_mask is not None:
|
374 |
+
attn_mask = attn_mask.bool()
|
375 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
376 |
+
|
377 |
+
attn = attn.softmax(dim=-1)
|
378 |
+
attn = self.attn_drop(attn)
|
379 |
+
|
380 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
381 |
+
x = self.inner_attn_ln(x)
|
382 |
+
x = self.proj(x)
|
383 |
+
x = self.proj_drop(x)
|
384 |
+
return x
|
385 |
+
|
386 |
+
|
387 |
+
class Block(nn.Module):
|
388 |
+
|
389 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
390 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
391 |
+
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
|
392 |
+
subln=False, naiveswiglu=False):
|
393 |
+
super().__init__()
|
394 |
+
self.norm1 = norm_layer(dim)
|
395 |
+
self.attn = Attention(
|
396 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
397 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
|
398 |
+
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
|
399 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
400 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
401 |
+
self.norm2 = norm_layer(dim)
|
402 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
403 |
+
|
404 |
+
if naiveswiglu:
|
405 |
+
self.mlp = SwiGLU(
|
406 |
+
in_features=dim,
|
407 |
+
hidden_features=mlp_hidden_dim,
|
408 |
+
subln=subln,
|
409 |
+
norm_layer=norm_layer,
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
self.mlp = Mlp(
|
413 |
+
in_features=dim,
|
414 |
+
hidden_features=mlp_hidden_dim,
|
415 |
+
act_layer=act_layer,
|
416 |
+
subln=subln,
|
417 |
+
drop=drop
|
418 |
+
)
|
419 |
+
|
420 |
+
if init_values is not None and init_values > 0:
|
421 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
422 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
423 |
+
else:
|
424 |
+
self.gamma_1, self.gamma_2 = None, None
|
425 |
+
|
426 |
+
self.postnorm = postnorm
|
427 |
+
|
428 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
429 |
+
if self.gamma_1 is None:
|
430 |
+
if self.postnorm:
|
431 |
+
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
432 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
433 |
+
else:
|
434 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
435 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
436 |
+
else:
|
437 |
+
if self.postnorm:
|
438 |
+
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
439 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
440 |
+
else:
|
441 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
442 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
443 |
+
return x
|
444 |
+
|
445 |
+
|
446 |
+
class PatchEmbed(nn.Module):
|
447 |
+
""" Image to Patch Embedding
|
448 |
+
"""
|
449 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
450 |
+
super().__init__()
|
451 |
+
img_size = to_2tuple(img_size)
|
452 |
+
patch_size = to_2tuple(patch_size)
|
453 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
454 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
455 |
+
self.img_size = img_size
|
456 |
+
self.patch_size = patch_size
|
457 |
+
self.num_patches = num_patches
|
458 |
+
|
459 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
460 |
+
|
461 |
+
def forward(self, x, **kwargs):
|
462 |
+
B, C, H, W = x.shape
|
463 |
+
# FIXME look at relaxing size constraints
|
464 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
465 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
466 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
467 |
+
return x
|
468 |
+
|
469 |
+
|
470 |
+
class RelativePositionBias(nn.Module):
|
471 |
+
|
472 |
+
def __init__(self, window_size, num_heads):
|
473 |
+
super().__init__()
|
474 |
+
self.window_size = window_size
|
475 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
476 |
+
self.relative_position_bias_table = nn.Parameter(
|
477 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
478 |
+
# cls to token & token 2 cls & cls to cls
|
479 |
+
|
480 |
+
# get pair-wise relative position index for each token inside the window
|
481 |
+
coords_h = torch.arange(window_size[0])
|
482 |
+
coords_w = torch.arange(window_size[1])
|
483 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
484 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
485 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
486 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
487 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
488 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
489 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
490 |
+
relative_position_index = \
|
491 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
492 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
493 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
494 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
495 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
496 |
+
|
497 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
498 |
+
|
499 |
+
def forward(self):
|
500 |
+
relative_position_bias = \
|
501 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
502 |
+
self.window_size[0] * self.window_size[1] + 1,
|
503 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
504 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
505 |
+
|
506 |
+
|
507 |
+
class EVAVisionTransformer(nn.Module):
|
508 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
509 |
+
"""
|
510 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
511 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
512 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
|
513 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
|
514 |
+
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
|
515 |
+
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
|
516 |
+
super().__init__()
|
517 |
+
self.image_size = img_size
|
518 |
+
self.num_classes = num_classes
|
519 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
520 |
+
|
521 |
+
self.patch_embed = PatchEmbed(
|
522 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
523 |
+
num_patches = self.patch_embed.num_patches
|
524 |
+
|
525 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
526 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
527 |
+
if use_abs_pos_emb:
|
528 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
529 |
+
else:
|
530 |
+
self.pos_embed = None
|
531 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
532 |
+
|
533 |
+
if use_shared_rel_pos_bias:
|
534 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
535 |
+
else:
|
536 |
+
self.rel_pos_bias = None
|
537 |
+
|
538 |
+
if rope:
|
539 |
+
half_head_dim = embed_dim // num_heads // 2
|
540 |
+
hw_seq_len = img_size // patch_size
|
541 |
+
self.rope = VisionRotaryEmbeddingFast(
|
542 |
+
dim=half_head_dim,
|
543 |
+
pt_seq_len=pt_hw_seq_len,
|
544 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
545 |
+
# patch_dropout=patch_dropout
|
546 |
+
)
|
547 |
+
else:
|
548 |
+
self.rope = None
|
549 |
+
|
550 |
+
self.naiveswiglu = naiveswiglu
|
551 |
+
|
552 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
553 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
554 |
+
self.blocks = nn.ModuleList([
|
555 |
+
Block(
|
556 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
557 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
558 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
559 |
+
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
|
560 |
+
for i in range(depth)])
|
561 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
562 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
563 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
564 |
+
|
565 |
+
if self.pos_embed is not None:
|
566 |
+
trunc_normal_(self.pos_embed, std=.02)
|
567 |
+
|
568 |
+
trunc_normal_(self.cls_token, std=.02)
|
569 |
+
# trunc_normal_(self.mask_token, std=.02)
|
570 |
+
|
571 |
+
self.apply(self._init_weights)
|
572 |
+
self.fix_init_weight()
|
573 |
+
|
574 |
+
if isinstance(self.head, nn.Linear):
|
575 |
+
trunc_normal_(self.head.weight, std=.02)
|
576 |
+
self.head.weight.data.mul_(init_scale)
|
577 |
+
self.head.bias.data.mul_(init_scale)
|
578 |
+
|
579 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
580 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
581 |
+
|
582 |
+
self.grad_checkpointing = grad_checkpointing
|
583 |
+
|
584 |
+
def fix_init_weight(self):
|
585 |
+
def rescale(param, layer_id):
|
586 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
587 |
+
|
588 |
+
for layer_id, layer in enumerate(self.blocks):
|
589 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
590 |
+
if self.naiveswiglu:
|
591 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
592 |
+
else:
|
593 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
594 |
+
|
595 |
+
def get_cast_dtype(self) -> torch.dtype:
|
596 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
597 |
+
|
598 |
+
def _init_weights(self, m):
|
599 |
+
if isinstance(m, nn.Linear):
|
600 |
+
trunc_normal_(m.weight, std=.02)
|
601 |
+
if m.bias is not None:
|
602 |
+
nn.init.constant_(m.bias, 0)
|
603 |
+
elif isinstance(m, nn.LayerNorm):
|
604 |
+
nn.init.constant_(m.bias, 0)
|
605 |
+
nn.init.constant_(m.weight, 1.0)
|
606 |
+
|
607 |
+
def get_num_layers(self):
|
608 |
+
return len(self.blocks)
|
609 |
+
|
610 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
611 |
+
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
612 |
+
for param in self.parameters():
|
613 |
+
param.requires_grad = False
|
614 |
+
|
615 |
+
@torch.jit.ignore
|
616 |
+
def set_grad_checkpointing(self, enable=True):
|
617 |
+
self.grad_checkpointing = enable
|
618 |
+
|
619 |
+
@torch.jit.ignore
|
620 |
+
def no_weight_decay(self):
|
621 |
+
return {'pos_embed', 'cls_token'}
|
622 |
+
|
623 |
+
def get_classifier(self):
|
624 |
+
return self.head
|
625 |
+
|
626 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
627 |
+
self.num_classes = num_classes
|
628 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
629 |
+
|
630 |
+
def forward_features(self, x, return_all_features=False):
|
631 |
+
|
632 |
+
x = self.patch_embed(x)
|
633 |
+
batch_size, seq_len, _ = x.size()
|
634 |
+
|
635 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
636 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
637 |
+
if self.pos_embed is not None:
|
638 |
+
x = x + self.pos_embed
|
639 |
+
x = self.pos_drop(x)
|
640 |
+
|
641 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
642 |
+
if os.getenv('RoPE') == '1':
|
643 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
644 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
645 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
|
646 |
+
else:
|
647 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
648 |
+
x = self.patch_dropout(x)
|
649 |
+
else:
|
650 |
+
x = self.patch_dropout(x)
|
651 |
+
|
652 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
653 |
+
for i, blk in enumerate(self.blocks):
|
654 |
+
if i == len(self.blocks)-1:
|
655 |
+
continue
|
656 |
+
if self.grad_checkpointing:
|
657 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
658 |
+
else:
|
659 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
660 |
+
|
661 |
+
if not return_all_features:
|
662 |
+
x = self.norm(x)
|
663 |
+
if self.fc_norm is not None:
|
664 |
+
return self.fc_norm(x.mean(1))
|
665 |
+
else:
|
666 |
+
return x[:, 0]
|
667 |
+
return x
|
668 |
+
|
669 |
+
def forward(self, x, return_all_features=False):
|
670 |
+
if return_all_features:
|
671 |
+
return self.forward_features(x, return_all_features)
|
672 |
+
x = self.forward_features(x)
|
673 |
+
x = self.head(x)
|
674 |
+
return x
|
675 |
+
|
676 |
+
class LayerNorm(nn.LayerNorm):
|
677 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
678 |
+
|
679 |
+
def forward(self, x: torch.Tensor):
|
680 |
+
orig_type = x.dtype
|
681 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
682 |
+
return x.to(orig_type)
|
683 |
+
|
684 |
+
try:
|
685 |
+
from apex.normalization import FusedLayerNorm
|
686 |
+
except:
|
687 |
+
FusedLayerNorm = LayerNorm
|
688 |
+
print("Please 'pip install apex'")
|
689 |
+
|
690 |
+
|
691 |
+
@dataclass
|
692 |
+
class CLIPVisionCfg:
|
693 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
694 |
+
width: int = 768
|
695 |
+
head_width: int = 64
|
696 |
+
mlp_ratio: float = 4.0
|
697 |
+
patch_size: int = 16
|
698 |
+
image_size: Union[Tuple[int, int], int] = 224
|
699 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
700 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
701 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
702 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
703 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
704 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
705 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
706 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
707 |
+
timm_proj_bias: bool = False # enable bias final projection
|
708 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
709 |
+
qkv_bias: bool = True
|
710 |
+
fusedLN: bool = False
|
711 |
+
xattn: bool = False
|
712 |
+
postnorm: bool = False
|
713 |
+
rope: bool = False
|
714 |
+
pt_hw_seq_len: int = 16 # 224/14
|
715 |
+
intp_freq: bool = False
|
716 |
+
naiveswiglu: bool = False
|
717 |
+
subln: bool = False
|
718 |
+
|
719 |
+
|
720 |
+
def _build_vision_tower(
|
721 |
+
embed_dim: int,
|
722 |
+
vision_cfg: CLIPVisionCfg
|
723 |
+
):
|
724 |
+
if isinstance(vision_cfg, dict):
|
725 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
726 |
+
|
727 |
+
if vision_cfg.eva_model_name:
|
728 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
729 |
+
norm_layer = LayerNorm
|
730 |
+
visual = EVAVisionTransformer(
|
731 |
+
img_size=vision_cfg.image_size,
|
732 |
+
patch_size=vision_cfg.patch_size,
|
733 |
+
num_classes=embed_dim,
|
734 |
+
use_mean_pooling=vision_cfg.global_average_pool, #False
|
735 |
+
init_values=vision_cfg.ls_init_value,
|
736 |
+
patch_dropout=vision_cfg.patch_dropout,
|
737 |
+
embed_dim=vision_cfg.width,
|
738 |
+
depth=vision_cfg.layers,
|
739 |
+
num_heads=vision_heads,
|
740 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
741 |
+
qkv_bias=vision_cfg.qkv_bias,
|
742 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
743 |
+
norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
|
744 |
+
xattn=vision_cfg.xattn,
|
745 |
+
rope=vision_cfg.rope,
|
746 |
+
postnorm=vision_cfg.postnorm,
|
747 |
+
pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
|
748 |
+
intp_freq= vision_cfg.intp_freq,
|
749 |
+
naiveswiglu= vision_cfg.naiveswiglu,
|
750 |
+
subln= vision_cfg.subln
|
751 |
+
)
|
752 |
+
|
753 |
+
return visual
|
754 |
+
|
755 |
+
class Eva2LargeEncoder(nn.Module):
|
756 |
+
def __init__(self, image_size=224):
|
757 |
+
super(Eva2LargeEncoder, self).__init__()
|
758 |
+
self.config = {
|
759 |
+
"embed_dim": 768,
|
760 |
+
"vision_cfg": {
|
761 |
+
"image_size": 336,
|
762 |
+
"layers": 24,
|
763 |
+
"width": 1024,
|
764 |
+
"drop_path_rate": 0,
|
765 |
+
"head_width": 64,
|
766 |
+
"mlp_ratio": 2.6667,
|
767 |
+
"patch_size": 14,
|
768 |
+
"eva_model_name": "eva-clip-l-14-336",
|
769 |
+
"xattn": True,
|
770 |
+
"fusedLN": True,
|
771 |
+
"rope": True,
|
772 |
+
"pt_hw_seq_len": 16,
|
773 |
+
"intp_freq": True,
|
774 |
+
"naiveswiglu": True,
|
775 |
+
"subln": True
|
776 |
+
}
|
777 |
+
}
|
778 |
+
self.config['vision_cfg']['image_size'] = image_size
|
779 |
+
|
780 |
+
import os
|
781 |
+
os.environ['delRoPE'] = '1' # to avoid error in rope params when changing image size
|
782 |
+
self.model = _build_vision_tower(**self.config)
|
783 |
+
|
784 |
+
|
785 |
+
def forward(self, images):
|
786 |
+
encode = self.model(images, return_all_features=True)[:, 1:, :]
|
787 |
+
return encode
|
788 |
+
|
789 |
+
class CrossVisionModel(nn.Module):
|
790 |
+
def __init__(self, config):
|
791 |
+
super().__init__()
|
792 |
+
self.vit = Eva2LargeEncoder(image_size=config.cross_image_size)
|
793 |
+
self.pos_embed = nn.Parameter(torch.zeros((self.vit.config['vision_cfg']['image_size'] // self.vit.config['vision_cfg']['patch_size']) ** 2, self.vit.config['vision_cfg']['width']))
|
794 |
+
|
795 |
+
def forward(self, images):
|
796 |
+
enc = self.vit(images)
|
797 |
+
return enc + self.pos_embed.unsqueeze(0)
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 2,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.36.0.dev0"
|
7 |
+
}
|
model-00001-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:15c2451e4e0dd5caae61e91de67ea0dac7b554c4e2b39d54e67ffbe232460063
|
3 |
+
size 4974581824
|
model-00002-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:351e987e5f1f28124a8c839085aeb3111e8f956393d3092439e07c7a285a5d90
|
3 |
+
size 4982995648
|
model-00003-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ff93f061698e4a8bbc042d9f243efd9647b19889f0e5f87188cd6bfc0400f839
|
3 |
+
size 4982995728
|
model-00004-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54c8f6eca491986fb61914029014d80a7813c1538807a55296530c6337fe2829
|
3 |
+
size 4982995728
|
model-00005-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4218cb8e5bda353d4491a70054b2e8f384f00c55bcc453807309982760fc48fd
|
3 |
+
size 4982995728
|
model-00006-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:23f2f3123c49dc4fe16fa44f6ee3dd388b18490fa05966f9dd21e806e3cbd22d
|
3 |
+
size 4950060832
|
model-00007-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d6eec2f2b61a1eb39dfb7703231639e16394bba2348108f1ed970268abac86e
|
3 |
+
size 4945866712
|
model-00008-of-00008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6df8c2486c4d21b20c5b2a72a699b47fd8ac62199ba9ae69db32054ef4fab1a2
|
3 |
+
size 1783098344
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_cogagent.py
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""largely copy from llama and adapt for CogAgent"""
|
2 |
+
import warnings
|
3 |
+
from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
|
4 |
+
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import CrossEntropyLoss
|
9 |
+
from torchvision import transforms
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
13 |
+
from transformers.utils.logging import get_logger
|
14 |
+
from transformers.activations import ACT2FN
|
15 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
16 |
+
|
17 |
+
from .configuration_cogagent import CogAgentConfig
|
18 |
+
from .util import FastRotaryEmbedding
|
19 |
+
from .visual import EVA2CLIPModel
|
20 |
+
from .cross_visual import CrossVisionModel
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from transformers.utils import ModelOutput
|
24 |
+
|
25 |
+
logger = get_logger(__name__)
|
26 |
+
|
27 |
+
LANGUAGE_TOKEN_TYPE = 0
|
28 |
+
VISION_TOKEN_TYPE = 1
|
29 |
+
|
30 |
+
|
31 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
32 |
+
def _make_causal_mask(
|
33 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Make causal mask used for bi-directional self-attention.
|
37 |
+
"""
|
38 |
+
bsz, tgt_len = input_ids_shape
|
39 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
40 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
41 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
42 |
+
mask = mask.to(dtype)
|
43 |
+
|
44 |
+
if past_key_values_length > 0:
|
45 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
46 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
47 |
+
|
48 |
+
|
49 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
50 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
51 |
+
"""
|
52 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
53 |
+
"""
|
54 |
+
bsz, src_len = mask.size()
|
55 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
56 |
+
|
57 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
58 |
+
|
59 |
+
inverted_mask = 1.0 - expanded_mask
|
60 |
+
|
61 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
62 |
+
|
63 |
+
|
64 |
+
class RMSNorm(nn.Module):
|
65 |
+
def __init__(self, hidden_size, eps=1e-6):
|
66 |
+
super().__init__()
|
67 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
68 |
+
self.variance_epsilon = eps
|
69 |
+
|
70 |
+
def forward(self, hidden_states):
|
71 |
+
input_dtype = hidden_states.dtype
|
72 |
+
hidden_states = hidden_states.to(torch.float32)
|
73 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
74 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
75 |
+
return (self.weight * hidden_states).to(input_dtype)
|
76 |
+
|
77 |
+
|
78 |
+
class MLP(nn.Module):
|
79 |
+
def __init__(self, config):
|
80 |
+
super().__init__()
|
81 |
+
self.hidden_size = config.hidden_size
|
82 |
+
self.intermediate_size = config.intermediate_size
|
83 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
84 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
85 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
86 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
90 |
+
return down_proj
|
91 |
+
|
92 |
+
|
93 |
+
def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
|
94 |
+
vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
|
95 |
+
vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
|
96 |
+
language_token_mask = ~vision_token_mask
|
97 |
+
return vision_token_mask, language_token_mask
|
98 |
+
|
99 |
+
|
100 |
+
class VisionExpertMLP(nn.Module):
|
101 |
+
def __init__(self, config):
|
102 |
+
super().__init__()
|
103 |
+
self.language_mlp = MLP(config)
|
104 |
+
self.vision_mlp = MLP(config)
|
105 |
+
|
106 |
+
def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
|
107 |
+
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
108 |
+
vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
|
109 |
+
output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
|
110 |
+
output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
|
111 |
+
return output
|
112 |
+
|
113 |
+
|
114 |
+
def attention_fn(
|
115 |
+
query_layer: "torch.tensor(B, H, L, HD)",
|
116 |
+
key_layer: "torch.tensor(B, H, L, HD)",
|
117 |
+
value_layer: "torch.tensor(B, H, L, HD)",
|
118 |
+
attention_mask: "torch.tensor(B, H, L, HD)",
|
119 |
+
*,
|
120 |
+
scaling_attention_score: bool = True,
|
121 |
+
attention_dropout: nn.Module = None
|
122 |
+
):
|
123 |
+
attention_mask_bool = (attention_mask == 0)
|
124 |
+
is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
|
125 |
+
is_full = (attention_mask_bool > 0).all()
|
126 |
+
if not (int(torch.__version__.split('.')[0]) >= 2):
|
127 |
+
warnings.warn("It's recommended to use torch2.0 or higher.")
|
128 |
+
if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
|
129 |
+
dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
|
130 |
+
return torch.nn.functional.scaled_dot_product_attention(
|
131 |
+
query_layer, key_layer, value_layer,
|
132 |
+
attn_mask=None,
|
133 |
+
dropout_p=dropout_p,
|
134 |
+
is_causal=not is_full
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
if scaling_attention_score:
|
138 |
+
query_layer = query_layer / math.sqrt(query_layer.shape[-1])
|
139 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
140 |
+
attention_scores = attention_scores + attention_mask
|
141 |
+
attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
|
142 |
+
if attention_dropout is not None:
|
143 |
+
attention_scores = attention_dropout(attention_scores)
|
144 |
+
context_layer = torch.matmul(attention_scores, value_layer)
|
145 |
+
return context_layer
|
146 |
+
|
147 |
+
|
148 |
+
class VisionExpertAttention(nn.Module):
|
149 |
+
def __init__(self, config):
|
150 |
+
super().__init__()
|
151 |
+
self.config = config
|
152 |
+
self.hidden_size = config.hidden_size
|
153 |
+
self.num_heads = config.num_attention_heads
|
154 |
+
self.head_dim = self.hidden_size // self.num_heads
|
155 |
+
self.max_position_embeddings = config.max_position_embeddings
|
156 |
+
|
157 |
+
# self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
|
158 |
+
self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
|
159 |
+
self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
|
160 |
+
self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
161 |
+
self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
|
162 |
+
self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
163 |
+
|
164 |
+
def _transpose_for_scores(self, tensor):
|
165 |
+
"""Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
|
166 |
+
new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
|
167 |
+
tensor = tensor.view(*new_tensor_shape)
|
168 |
+
return tensor.permute(0, 2, 1, 3)
|
169 |
+
|
170 |
+
def forward(
|
171 |
+
self,
|
172 |
+
hidden_states: torch.Tensor,
|
173 |
+
token_type_ids: torch.LongTensor,
|
174 |
+
position_ids: torch.LongTensor,
|
175 |
+
attention_mask: Optional[torch.Tensor] = None,
|
176 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
177 |
+
output_attentions: bool = False,
|
178 |
+
use_cache: bool = False,
|
179 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
180 |
+
bsz, q_len, _ = hidden_states.size()
|
181 |
+
vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
|
182 |
+
|
183 |
+
shape = list(hidden_states.shape)
|
184 |
+
shape[-1] = shape[-1] * 3
|
185 |
+
mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
186 |
+
mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
|
187 |
+
mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
|
188 |
+
|
189 |
+
query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
|
190 |
+
query_states = self._transpose_for_scores(query_states) # B, H, L, HD
|
191 |
+
key_states = self._transpose_for_scores(key_states) # B, H, L, HD
|
192 |
+
value_states = self._transpose_for_scores(value_states) # B, H, L, HD
|
193 |
+
|
194 |
+
kv_seq_len = key_states.shape[-2]
|
195 |
+
if past_key_value is not None:
|
196 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
197 |
+
|
198 |
+
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
|
199 |
+
|
200 |
+
if past_key_value is not None:
|
201 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
202 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
203 |
+
|
204 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
205 |
+
|
206 |
+
context_layer = attention_fn(
|
207 |
+
query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
|
208 |
+
scaling_attention_score=True, attention_dropout=None)
|
209 |
+
if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
210 |
+
raise ValueError(
|
211 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
212 |
+
f" {context_layer.size()}"
|
213 |
+
)
|
214 |
+
context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
|
215 |
+
|
216 |
+
attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
217 |
+
attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
|
218 |
+
attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
|
219 |
+
|
220 |
+
if output_attentions:
|
221 |
+
warnings.warn("output_attentions is not implemented.")
|
222 |
+
|
223 |
+
return attn_output, None, past_key_value
|
224 |
+
|
225 |
+
class CrossAttention(nn.Module):
|
226 |
+
def __init__(self, config):
|
227 |
+
super().__init__()
|
228 |
+
self.config = config
|
229 |
+
self.hidden_size = config.hidden_size
|
230 |
+
self.cross_hidden_size = config.cross_hidden_size
|
231 |
+
self.cross_compute_hidden_size = config.cross_compute_hidden_size
|
232 |
+
self.num_heads = config.num_attention_heads
|
233 |
+
self.head_dim = self.hidden_size // self.num_heads
|
234 |
+
self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
|
235 |
+
self.max_position_embeddings = config.max_position_embeddings
|
236 |
+
|
237 |
+
# self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
|
238 |
+
self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
|
239 |
+
self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
|
240 |
+
self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
|
241 |
+
self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
|
242 |
+
|
243 |
+
def _transpose_for_scores(self, tensor):
|
244 |
+
"""Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
|
245 |
+
new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
|
246 |
+
tensor = tensor.view(*new_tensor_shape)
|
247 |
+
return tensor.permute(0, 2, 1, 3)
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
hidden_states: torch.Tensor,
|
252 |
+
encoder_outputs: torch.LongTensor,
|
253 |
+
attention_mask: Optional[torch.Tensor] = None,
|
254 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
255 |
+
output_attentions: bool = False,
|
256 |
+
use_cache: bool = False,
|
257 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
258 |
+
bsz, q_len, _ = hidden_states.size()
|
259 |
+
|
260 |
+
shape = list(hidden_states.shape)
|
261 |
+
shape[-1] = shape[-1] * 3
|
262 |
+
|
263 |
+
mixed_query_layer = self.query(hidden_states)
|
264 |
+
if past_key_value is None:
|
265 |
+
mixed_x_layer = self.key_value(encoder_outputs)
|
266 |
+
mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
|
267 |
+
key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
|
268 |
+
value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
|
269 |
+
else:
|
270 |
+
key_states, value_states = past_key_value
|
271 |
+
|
272 |
+
query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
|
273 |
+
|
274 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
275 |
+
|
276 |
+
context_layer = attention_fn(
|
277 |
+
query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
|
278 |
+
scaling_attention_score=True, attention_dropout=None)
|
279 |
+
if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
|
280 |
+
raise ValueError(
|
281 |
+
f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
|
282 |
+
f" {context_layer.size()}"
|
283 |
+
)
|
284 |
+
context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
|
285 |
+
|
286 |
+
attn_output = self.dense(context_layer)
|
287 |
+
|
288 |
+
if output_attentions:
|
289 |
+
warnings.warn("output_attentions is not implemented.")
|
290 |
+
|
291 |
+
return attn_output, None, past_key_value
|
292 |
+
|
293 |
+
class CogAgentDecoderLayer(nn.Module):
|
294 |
+
def __init__(self, config):
|
295 |
+
super().__init__()
|
296 |
+
self.hidden_size = config.hidden_size
|
297 |
+
self.self_attn = VisionExpertAttention(config=config)
|
298 |
+
self.cross_attn = CrossAttention(config=config)
|
299 |
+
self.mlp = VisionExpertMLP(config)
|
300 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
301 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
302 |
+
self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
303 |
+
|
304 |
+
def forward(
|
305 |
+
self,
|
306 |
+
hidden_states: torch.Tensor,
|
307 |
+
encoder_outputs: torch.Tensor,
|
308 |
+
token_type_ids: torch.LongTensor,
|
309 |
+
position_ids: torch.LongTensor,
|
310 |
+
attention_mask: Optional[torch.Tensor] = None,
|
311 |
+
cross_attention_mask: Optional[torch.Tensor] = None,
|
312 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
313 |
+
output_attentions: Optional[bool] = False,
|
314 |
+
use_cache: Optional[bool] = False,
|
315 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
316 |
+
residual = hidden_states
|
317 |
+
|
318 |
+
hidden_states = self.input_layernorm(hidden_states)
|
319 |
+
|
320 |
+
# Self Attention
|
321 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
322 |
+
hidden_states=hidden_states,
|
323 |
+
token_type_ids=token_type_ids,
|
324 |
+
position_ids=position_ids,
|
325 |
+
attention_mask=attention_mask,
|
326 |
+
past_key_value=past_key_value[:2] if past_key_value is not None else None,
|
327 |
+
output_attentions=output_attentions,
|
328 |
+
use_cache=use_cache,
|
329 |
+
)
|
330 |
+
hidden_states = residual + hidden_states
|
331 |
+
|
332 |
+
cross_input = self.post_cross_attention_layernorm(hidden_states)
|
333 |
+
# Fully Connected
|
334 |
+
attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
|
335 |
+
hidden_states=cross_input,
|
336 |
+
encoder_outputs=encoder_outputs,
|
337 |
+
attention_mask=cross_attention_mask,
|
338 |
+
past_key_value=past_key_value[-2:] if past_key_value is not None else None,
|
339 |
+
output_attentions=output_attentions,
|
340 |
+
use_cache=use_cache,
|
341 |
+
)
|
342 |
+
hidden_states = hidden_states + attention_output
|
343 |
+
mlp_input = self.post_attention_layernorm(hidden_states)
|
344 |
+
mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
|
345 |
+
hidden_states = mlp_output + hidden_states
|
346 |
+
|
347 |
+
outputs = (hidden_states,)
|
348 |
+
|
349 |
+
if output_attentions:
|
350 |
+
outputs += (self_attn_weights,)
|
351 |
+
|
352 |
+
if use_cache:
|
353 |
+
outputs += (present_key_value+present_cross_key_value,)
|
354 |
+
|
355 |
+
return outputs # type: ignore
|
356 |
+
|
357 |
+
|
358 |
+
class CogAgentPreTrainedModel(PreTrainedModel):
|
359 |
+
config_class = CogAgentConfig
|
360 |
+
base_model_prefix = "model"
|
361 |
+
supports_gradient_checkpointing = False
|
362 |
+
_no_split_modules = ["CogAgentDecoderLayer"]
|
363 |
+
_skip_keys_device_placement = "past_key_values"
|
364 |
+
|
365 |
+
def _init_weights(self, module):
|
366 |
+
std = self.config.initializer_range
|
367 |
+
if isinstance(module, nn.Linear):
|
368 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
369 |
+
if module.bias is not None:
|
370 |
+
module.bias.data.zero_()
|
371 |
+
elif isinstance(module, nn.Embedding):
|
372 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
373 |
+
if module.padding_idx is not None:
|
374 |
+
module.weight.data[module.padding_idx].zero_()
|
375 |
+
|
376 |
+
|
377 |
+
def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
|
378 |
+
if images_list is None or len(images_list) == 0:
|
379 |
+
return True
|
380 |
+
for image_list in images_list:
|
381 |
+
if len(image_list):
|
382 |
+
return False
|
383 |
+
return True
|
384 |
+
|
385 |
+
|
386 |
+
def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
|
387 |
+
if attention_mask is not None:
|
388 |
+
tmp = x.clone()
|
389 |
+
tmp[~(attention_mask.bool())] = -1
|
390 |
+
else:
|
391 |
+
tmp = x.clone()
|
392 |
+
# image boi eoi token as LANGUAGE_TOKEN_TYPE
|
393 |
+
is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
|
394 |
+
is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
|
395 |
+
is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
|
396 |
+
is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
|
397 |
+
is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
|
398 |
+
tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
|
399 |
+
# final position ids
|
400 |
+
y = torch.zeros_like(x, dtype=torch.long)
|
401 |
+
y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
|
402 |
+
y = y.cumsum(dim=-1)
|
403 |
+
return y
|
404 |
+
|
405 |
+
|
406 |
+
class CogAgentModel(CogAgentPreTrainedModel):
|
407 |
+
def __init__(self, config):
|
408 |
+
super().__init__(config)
|
409 |
+
self.padding_idx = config.pad_token_id
|
410 |
+
self.vocab_size = config.vocab_size
|
411 |
+
|
412 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
413 |
+
self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
414 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
415 |
+
|
416 |
+
self.vision = EVA2CLIPModel(config)
|
417 |
+
self.cross_vision = CrossVisionModel(config)
|
418 |
+
|
419 |
+
self.gradient_checkpointing = False
|
420 |
+
# Initialize weights and apply final processing
|
421 |
+
self.post_init()
|
422 |
+
|
423 |
+
def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
|
424 |
+
images_list, images = images, []
|
425 |
+
|
426 |
+
images = []
|
427 |
+
for image_list in images_list:
|
428 |
+
for image in image_list:
|
429 |
+
images.append(image)
|
430 |
+
|
431 |
+
images = torch.stack(images)
|
432 |
+
images_features = self.vision(images)
|
433 |
+
return images_features
|
434 |
+
|
435 |
+
def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
|
436 |
+
images_list, images = images, []
|
437 |
+
|
438 |
+
images = []
|
439 |
+
for image_list in images_list:
|
440 |
+
for image in image_list:
|
441 |
+
images.append(image)
|
442 |
+
|
443 |
+
images = torch.stack(images)
|
444 |
+
encoder_outputs = self.cross_vision(images)
|
445 |
+
return encoder_outputs
|
446 |
+
|
447 |
+
def forward(
|
448 |
+
self,
|
449 |
+
input_ids: torch.LongTensor = None,
|
450 |
+
images: List[List[torch.Tensor]] = None,
|
451 |
+
cross_images: List[List[torch.Tensor]] = None,
|
452 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
453 |
+
attention_mask: Optional[torch.Tensor] = None,
|
454 |
+
cross_attention_mask: Optional[torch.Tensor] = None,
|
455 |
+
position_ids: Optional[torch.LongTensor] = None,
|
456 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
457 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
458 |
+
use_cache: Optional[bool] = None,
|
459 |
+
output_attentions: Optional[bool] = None,
|
460 |
+
output_hidden_states: Optional[bool] = None,
|
461 |
+
return_dict: Optional[bool] = None,
|
462 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
463 |
+
"""take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
|
464 |
+
|
465 |
+
if past_key_values is not None:
|
466 |
+
encoder_outputs = None
|
467 |
+
# generate mode with past_key_values. the image features are already mapped
|
468 |
+
else:
|
469 |
+
# not allow for inputs_embeds, because we want to process image feature
|
470 |
+
assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
|
471 |
+
if not is_empty(images): # multi-modality
|
472 |
+
assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
|
473 |
+
assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
|
474 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
475 |
+
images_features = self.encode_images(images)
|
476 |
+
encoder_outputs = self.encode_cross_images(cross_images)
|
477 |
+
images_features = rearrange(images_features, 'b n d -> (b n) d')
|
478 |
+
images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
479 |
+
inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
|
480 |
+
else: # single-modality
|
481 |
+
if token_type_ids is None:
|
482 |
+
token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
|
483 |
+
assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
|
484 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
485 |
+
encoder_outputs = None
|
486 |
+
|
487 |
+
if position_ids is None:
|
488 |
+
position_ids = build_position_ids(token_type_ids, attention_mask)
|
489 |
+
input_ids = None
|
490 |
+
|
491 |
+
return self.llm_forward(
|
492 |
+
input_ids=input_ids,
|
493 |
+
encoder_outputs=encoder_outputs,
|
494 |
+
token_type_ids=token_type_ids,
|
495 |
+
attention_mask=attention_mask,
|
496 |
+
cross_attention_mask=cross_attention_mask,
|
497 |
+
position_ids=position_ids,
|
498 |
+
past_key_values=past_key_values,
|
499 |
+
inputs_embeds=inputs_embeds,
|
500 |
+
use_cache=use_cache,
|
501 |
+
output_attentions=output_attentions,
|
502 |
+
output_hidden_states=output_hidden_states,
|
503 |
+
return_dict=return_dict,
|
504 |
+
)
|
505 |
+
|
506 |
+
def llm_forward(
|
507 |
+
self,
|
508 |
+
input_ids: torch.LongTensor = None,
|
509 |
+
encoder_outputs: torch.LongTensor = None,
|
510 |
+
token_type_ids: torch.LongTensor = None,
|
511 |
+
attention_mask: Optional[torch.Tensor] = None,
|
512 |
+
cross_attention_mask: Optional[torch.Tensor] = None,
|
513 |
+
position_ids: Optional[torch.LongTensor] = None,
|
514 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
515 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
516 |
+
use_cache: Optional[bool] = None,
|
517 |
+
output_attentions: Optional[bool] = None,
|
518 |
+
output_hidden_states: Optional[bool] = None,
|
519 |
+
return_dict: Optional[bool] = None,
|
520 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
521 |
+
"""largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
|
522 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
523 |
+
output_hidden_states = (
|
524 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
525 |
+
)
|
526 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
527 |
+
|
528 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
529 |
+
|
530 |
+
# retrieve input_ids and inputs_embeds
|
531 |
+
if input_ids is not None and inputs_embeds is not None:
|
532 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
533 |
+
elif input_ids is not None:
|
534 |
+
batch_size, seq_length = input_ids.shape
|
535 |
+
elif inputs_embeds is not None:
|
536 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
537 |
+
else:
|
538 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
539 |
+
|
540 |
+
seq_length_with_past = seq_length
|
541 |
+
past_key_values_length = 0
|
542 |
+
|
543 |
+
if past_key_values is not None:
|
544 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
545 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
546 |
+
|
547 |
+
if position_ids is None:
|
548 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
549 |
+
position_ids = torch.arange(
|
550 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
551 |
+
)
|
552 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
553 |
+
else:
|
554 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
555 |
+
|
556 |
+
if inputs_embeds is None:
|
557 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
558 |
+
# embed positions
|
559 |
+
if attention_mask is None:
|
560 |
+
attention_mask = torch.ones(
|
561 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
562 |
+
)
|
563 |
+
if cross_attention_mask is None:
|
564 |
+
cross_attention_mask = torch.ones(
|
565 |
+
(batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
|
566 |
+
)
|
567 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
568 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
569 |
+
)
|
570 |
+
|
571 |
+
hidden_states = inputs_embeds
|
572 |
+
|
573 |
+
# decoder layers
|
574 |
+
all_hidden_states = () if output_hidden_states else None
|
575 |
+
all_self_attns = () if output_attentions else None
|
576 |
+
next_decoder_cache = () if use_cache else None
|
577 |
+
|
578 |
+
for idx, decoder_layer in enumerate(self.layers):
|
579 |
+
if output_hidden_states:
|
580 |
+
all_hidden_states += (hidden_states,)
|
581 |
+
|
582 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
583 |
+
layer_outputs = decoder_layer(
|
584 |
+
hidden_states,
|
585 |
+
encoder_outputs=encoder_outputs,
|
586 |
+
token_type_ids=token_type_ids,
|
587 |
+
attention_mask=attention_mask,
|
588 |
+
cross_attention_mask=cross_attention_mask,
|
589 |
+
position_ids=position_ids,
|
590 |
+
past_key_value=past_key_value,
|
591 |
+
output_attentions=output_attentions,
|
592 |
+
use_cache=use_cache,
|
593 |
+
)
|
594 |
+
hidden_states = layer_outputs[0]
|
595 |
+
|
596 |
+
if use_cache:
|
597 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
598 |
+
|
599 |
+
if output_attentions:
|
600 |
+
all_self_attns += (layer_outputs[1],)
|
601 |
+
|
602 |
+
hidden_states = self.norm(hidden_states)
|
603 |
+
|
604 |
+
# add hidden states from the last decoder layer
|
605 |
+
if output_hidden_states:
|
606 |
+
all_hidden_states += (hidden_states,)
|
607 |
+
|
608 |
+
next_cache = next_decoder_cache if use_cache else None
|
609 |
+
if not return_dict:
|
610 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
611 |
+
return BaseModelOutputWithPast(
|
612 |
+
last_hidden_state=hidden_states,
|
613 |
+
past_key_values=next_cache,
|
614 |
+
hidden_states=all_hidden_states,
|
615 |
+
attentions=all_self_attns,
|
616 |
+
)
|
617 |
+
|
618 |
+
def get_input_embeddings(self):
|
619 |
+
return self.embed_tokens
|
620 |
+
|
621 |
+
def set_input_embeddings(self, value):
|
622 |
+
self.embed_tokens = value
|
623 |
+
|
624 |
+
# noinspection PyMethodMayBeStatic
|
625 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
626 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
627 |
+
# create causal mask
|
628 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
629 |
+
combined_attention_mask = None
|
630 |
+
if input_shape[-1] > 1:
|
631 |
+
combined_attention_mask = _make_causal_mask(
|
632 |
+
input_shape,
|
633 |
+
inputs_embeds.dtype,
|
634 |
+
device=inputs_embeds.device,
|
635 |
+
past_key_values_length=past_key_values_length,
|
636 |
+
)
|
637 |
+
|
638 |
+
if attention_mask is not None:
|
639 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
640 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
641 |
+
inputs_embeds.device
|
642 |
+
)
|
643 |
+
combined_attention_mask = (
|
644 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
645 |
+
)
|
646 |
+
|
647 |
+
return combined_attention_mask
|
648 |
+
|
649 |
+
def chat_old_history_to_prompt(history, query):
|
650 |
+
prompt = "<EOI>Question: "
|
651 |
+
for i, (old_query, response) in enumerate(history):
|
652 |
+
prompt += old_query + " Answer: " + response + "\nQuestion: "
|
653 |
+
prompt += query + " Answer:"
|
654 |
+
return prompt
|
655 |
+
|
656 |
+
def chat_history_to_prompt(history, query):
|
657 |
+
prompt = " [INST] "
|
658 |
+
for i, (old_query, response) in enumerate(history):
|
659 |
+
prompt += old_query + " [/INST] " + response + " [INST] "
|
660 |
+
prompt += query + " [/INST] "
|
661 |
+
return prompt
|
662 |
+
|
663 |
+
|
664 |
+
def base_history_to_prompt(history, query):
|
665 |
+
prompt = query
|
666 |
+
return prompt
|
667 |
+
|
668 |
+
|
669 |
+
_history_to_prompt = {
|
670 |
+
"base": base_history_to_prompt,
|
671 |
+
"chat": chat_history_to_prompt,
|
672 |
+
"chat_old": chat_old_history_to_prompt
|
673 |
+
}
|
674 |
+
|
675 |
+
|
676 |
+
class CogAgentForCausalLM(CogAgentPreTrainedModel):
|
677 |
+
_auto_class = "AutoModelForCausalLM"
|
678 |
+
|
679 |
+
def __init__(self, config):
|
680 |
+
super().__init__(config)
|
681 |
+
self.model = CogAgentModel(config)
|
682 |
+
self.vocab_size = config.vocab_size
|
683 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
684 |
+
|
685 |
+
# Initialize weights and apply final processing
|
686 |
+
self.post_init()
|
687 |
+
|
688 |
+
def get_input_embeddings(self):
|
689 |
+
return self.model.embed_tokens
|
690 |
+
|
691 |
+
def set_input_embeddings(self, value):
|
692 |
+
self.model.embed_tokens = value
|
693 |
+
|
694 |
+
def get_output_embeddings(self):
|
695 |
+
return self.lm_head
|
696 |
+
|
697 |
+
def set_output_embeddings(self, new_embeddings):
|
698 |
+
self.lm_head = new_embeddings
|
699 |
+
|
700 |
+
def set_decoder(self, decoder):
|
701 |
+
self.model = decoder
|
702 |
+
|
703 |
+
def get_decoder(self):
|
704 |
+
return self.model
|
705 |
+
|
706 |
+
def forward(
|
707 |
+
self,
|
708 |
+
input_ids: torch.LongTensor = None,
|
709 |
+
images: List[List[torch.Tensor]] = None,
|
710 |
+
cross_images: List[List[torch.Tensor]] = None,
|
711 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
712 |
+
attention_mask: Optional[torch.Tensor] = None,
|
713 |
+
position_ids: Optional[torch.LongTensor] = None,
|
714 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
715 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
716 |
+
use_cache: Optional[bool] = None,
|
717 |
+
output_attentions: Optional[bool] = None,
|
718 |
+
output_hidden_states: Optional[bool] = None,
|
719 |
+
return_dict: Optional[bool] = None,
|
720 |
+
labels: Optional[torch.LongTensor] = None,
|
721 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
722 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
723 |
+
output_hidden_states = (
|
724 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
725 |
+
)
|
726 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
727 |
+
|
728 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
729 |
+
outputs = self.model(
|
730 |
+
input_ids=input_ids,
|
731 |
+
images=images,
|
732 |
+
cross_images=cross_images,
|
733 |
+
token_type_ids=token_type_ids,
|
734 |
+
attention_mask=attention_mask,
|
735 |
+
position_ids=position_ids,
|
736 |
+
past_key_values=past_key_values,
|
737 |
+
inputs_embeds=inputs_embeds,
|
738 |
+
use_cache=use_cache,
|
739 |
+
output_attentions=output_attentions,
|
740 |
+
output_hidden_states=output_hidden_states,
|
741 |
+
return_dict=return_dict,
|
742 |
+
)
|
743 |
+
|
744 |
+
hidden_states = outputs[0]
|
745 |
+
logits = self.lm_head(hidden_states)
|
746 |
+
logits = logits.float()
|
747 |
+
|
748 |
+
loss = None
|
749 |
+
if labels is not None:
|
750 |
+
# Shift so that tokens < n predict n
|
751 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
752 |
+
shift_labels = labels[..., 1:].contiguous()
|
753 |
+
# Flatten the tokens
|
754 |
+
loss_fct = CrossEntropyLoss()
|
755 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
756 |
+
shift_labels = shift_labels.view(-1)
|
757 |
+
# Enable model parallelism
|
758 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
759 |
+
loss = loss_fct(shift_logits, shift_labels)
|
760 |
+
|
761 |
+
if not return_dict:
|
762 |
+
output = (logits,) + outputs[1:]
|
763 |
+
return (loss,) + output if loss is not None else output
|
764 |
+
|
765 |
+
return CausalLMOutputWithPast(
|
766 |
+
loss=loss,
|
767 |
+
logits=logits,
|
768 |
+
past_key_values=outputs.past_key_values,
|
769 |
+
hidden_states=outputs.hidden_states,
|
770 |
+
attentions=outputs.attentions,
|
771 |
+
)
|
772 |
+
|
773 |
+
def _prepare_attention_mask_for_generation(
|
774 |
+
self,
|
775 |
+
inputs: torch.Tensor,
|
776 |
+
pad_token_id: Optional[int],
|
777 |
+
eos_token_id: Optional[Union[int, List[int]]],
|
778 |
+
) -> torch.LongTensor:
|
779 |
+
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
|
780 |
+
|
781 |
+
def prepare_inputs_for_generation(
|
782 |
+
self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
783 |
+
):
|
784 |
+
# build position_ids if needed
|
785 |
+
position_ids = kwargs.get("position_ids", None)
|
786 |
+
if position_ids is None:
|
787 |
+
position_ids = build_position_ids(token_type_ids, attention_mask)
|
788 |
+
|
789 |
+
if past_key_values:
|
790 |
+
input_ids = input_ids[:, -1:]
|
791 |
+
token_type_ids = token_type_ids[:, -1:]
|
792 |
+
position_ids = position_ids[:, -1:]
|
793 |
+
|
794 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
795 |
+
if inputs_embeds is not None and past_key_values is None:
|
796 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
797 |
+
else:
|
798 |
+
model_inputs = {"input_ids": input_ids}
|
799 |
+
|
800 |
+
model_inputs.update(
|
801 |
+
{
|
802 |
+
"token_type_ids": token_type_ids,
|
803 |
+
"images": images,
|
804 |
+
"cross_images": cross_images,
|
805 |
+
"position_ids": position_ids,
|
806 |
+
"past_key_values": past_key_values,
|
807 |
+
"use_cache": kwargs.get("use_cache"),
|
808 |
+
"attention_mask": attention_mask,
|
809 |
+
}
|
810 |
+
)
|
811 |
+
return model_inputs
|
812 |
+
|
813 |
+
def _update_model_kwargs_for_generation(
|
814 |
+
self,
|
815 |
+
outputs: "ModelOutput",
|
816 |
+
model_kwargs: Dict[str, Any],
|
817 |
+
is_encoder_decoder: bool = False,
|
818 |
+
standardize_cache_format: bool = False,
|
819 |
+
) -> Dict[str, Any]:
|
820 |
+
# update past_key_values
|
821 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
822 |
+
outputs, standardize_cache_format=standardize_cache_format
|
823 |
+
)
|
824 |
+
if getattr(outputs, "state", None) is not None:
|
825 |
+
model_kwargs["state"] = outputs.state
|
826 |
+
|
827 |
+
# update token_type_ids with last value
|
828 |
+
if "token_type_ids" in model_kwargs:
|
829 |
+
token_type_ids = model_kwargs["token_type_ids"]
|
830 |
+
new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
|
831 |
+
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
|
832 |
+
|
833 |
+
if not is_encoder_decoder:
|
834 |
+
# update attention mask
|
835 |
+
if "attention_mask" in model_kwargs:
|
836 |
+
attention_mask = model_kwargs["attention_mask"]
|
837 |
+
model_kwargs["attention_mask"] = torch.cat(
|
838 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
839 |
+
)
|
840 |
+
else:
|
841 |
+
# update decoder attention mask
|
842 |
+
if "decoder_attention_mask" in model_kwargs:
|
843 |
+
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
844 |
+
model_kwargs["decoder_attention_mask"] = torch.cat(
|
845 |
+
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
846 |
+
dim=-1,
|
847 |
+
)
|
848 |
+
|
849 |
+
return model_kwargs
|
850 |
+
|
851 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
852 |
+
reordered_past = ()
|
853 |
+
for layer_past in past_key_values:
|
854 |
+
reordered_past += (
|
855 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
856 |
+
)
|
857 |
+
return reordered_past
|
858 |
+
|
859 |
+
def build_conversation_input_ids(
|
860 |
+
self,
|
861 |
+
tokenizer: "PreTrainedTokenizer",
|
862 |
+
*,
|
863 |
+
query: str,
|
864 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
865 |
+
images: Optional[List["PIL.Image"]] = None,
|
866 |
+
template_version: Optional[Literal["base", "chat", "vqa"]] = None,
|
867 |
+
):
|
868 |
+
image_size: int = self.config.vision_config['image_size']
|
869 |
+
cross_image_size: int = self.config.cross_image_size
|
870 |
+
patch_size: int = self.config.vision_config['patch_size']
|
871 |
+
template_version = template_version or self.config.template_version
|
872 |
+
assert images is None or len(images) <= 1, f"not support multi images by now."
|
873 |
+
history = history or []
|
874 |
+
text = _history_to_prompt[template_version](history, query)
|
875 |
+
|
876 |
+
input_ids = [tokenizer.bos_token_id]
|
877 |
+
token_type_ids = [LANGUAGE_TOKEN_TYPE]
|
878 |
+
if images is not None and len(images) == 1:
|
879 |
+
ori = images
|
880 |
+
# vision
|
881 |
+
transform = transforms.Compose(
|
882 |
+
[
|
883 |
+
transforms.Resize(
|
884 |
+
(image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
|
885 |
+
),
|
886 |
+
transforms.ToTensor(),
|
887 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
888 |
+
]
|
889 |
+
)
|
890 |
+
images = [transform(ori[0])]
|
891 |
+
cross_transform = transforms.Compose(
|
892 |
+
[
|
893 |
+
transforms.Resize(
|
894 |
+
(cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
|
895 |
+
),
|
896 |
+
transforms.ToTensor(),
|
897 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
898 |
+
]
|
899 |
+
)
|
900 |
+
cross_images = [cross_transform(ori[0])]
|
901 |
+
# language
|
902 |
+
vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
|
903 |
+
input_ids += [tokenizer.pad_token_id] * vision_token_num
|
904 |
+
token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
|
905 |
+
text_ids = tokenizer.encode(text, add_special_tokens=False)
|
906 |
+
|
907 |
+
input_ids += text_ids
|
908 |
+
token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
|
909 |
+
attention_mask = [1] * len(input_ids)
|
910 |
+
|
911 |
+
return {
|
912 |
+
'input_ids': torch.tensor(input_ids, dtype=torch.long),
|
913 |
+
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
|
914 |
+
'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
|
915 |
+
'images': images,
|
916 |
+
'cross_images': cross_images
|
917 |
+
}
|
util.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import triton
|
8 |
+
import triton.language as tl
|
9 |
+
|
10 |
+
|
11 |
+
# @triton.autotune(
|
12 |
+
# configs=[
|
13 |
+
# triton.Config({"BLOCK_M": 2}),
|
14 |
+
# triton.Config({"BLOCK_M": 4}),
|
15 |
+
# triton.Config({"BLOCK_M": 8}),
|
16 |
+
# triton.Config({"BLOCK_M": 16}),
|
17 |
+
# ],
|
18 |
+
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
|
19 |
+
# )
|
20 |
+
@triton.jit
|
21 |
+
def rotary_kernel(
|
22 |
+
OUT, # Pointers to matrices
|
23 |
+
X,
|
24 |
+
COS,
|
25 |
+
SIN,
|
26 |
+
CU_SEQLENS,
|
27 |
+
SEQLEN_OFFSETS, # this could be int or a pointer
|
28 |
+
# Matrix dimensions
|
29 |
+
seqlen,
|
30 |
+
nheads,
|
31 |
+
rotary_dim,
|
32 |
+
seqlen_ro,
|
33 |
+
CACHE_KEY_SEQLEN,
|
34 |
+
# strides
|
35 |
+
stride_out_batch,
|
36 |
+
stride_out_nheads,
|
37 |
+
stride_out_seqlen,
|
38 |
+
stride_out_headdim,
|
39 |
+
stride_x_batch,
|
40 |
+
stride_x_nheads,
|
41 |
+
stride_x_seqlen,
|
42 |
+
stride_x_headdim,
|
43 |
+
# Meta-parameters
|
44 |
+
BLOCK_K: tl.constexpr,
|
45 |
+
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
|
46 |
+
IS_VARLEN: tl.constexpr,
|
47 |
+
INTERLEAVED: tl.constexpr,
|
48 |
+
CONJUGATE: tl.constexpr,
|
49 |
+
BLOCK_M: tl.constexpr,
|
50 |
+
):
|
51 |
+
pid_m = tl.program_id(axis=0)
|
52 |
+
pid_batch = tl.program_id(axis=1)
|
53 |
+
pid_head = tl.program_id(axis=2)
|
54 |
+
rotary_dim_half = rotary_dim // 2
|
55 |
+
|
56 |
+
if not IS_VARLEN:
|
57 |
+
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
|
58 |
+
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
|
59 |
+
COS = COS + pid_batch * seqlen_ro * rotary_dim_half
|
60 |
+
SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
|
61 |
+
else:
|
62 |
+
start_idx = tl.load(CU_SEQLENS + pid_batch)
|
63 |
+
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
|
64 |
+
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
|
65 |
+
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
|
66 |
+
|
67 |
+
if pid_m * BLOCK_M >= seqlen:
|
68 |
+
return
|
69 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
70 |
+
if not IS_SEQLEN_OFFSETS_TENSOR:
|
71 |
+
rm_cs = rm + SEQLEN_OFFSETS
|
72 |
+
else:
|
73 |
+
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
|
74 |
+
rk = tl.arange(0, BLOCK_K)
|
75 |
+
rk_half = tl.arange(0, BLOCK_K // 2)
|
76 |
+
|
77 |
+
if not INTERLEAVED:
|
78 |
+
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
|
79 |
+
X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
|
80 |
+
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
81 |
+
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
82 |
+
cos = tl.load(
|
83 |
+
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
|
84 |
+
)
|
85 |
+
sin = tl.load(
|
86 |
+
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
87 |
+
)
|
88 |
+
x0 = tl.load(
|
89 |
+
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
90 |
+
)
|
91 |
+
x1 = tl.load(
|
92 |
+
X + rotary_dim_half * stride_x_headdim,
|
93 |
+
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
94 |
+
other=0.0,
|
95 |
+
)
|
96 |
+
if CONJUGATE:
|
97 |
+
sin = -sin
|
98 |
+
o0 = x0 * cos - x1 * sin
|
99 |
+
o1 = x0 * sin + x1 * cos
|
100 |
+
# write back result
|
101 |
+
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
|
102 |
+
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
|
103 |
+
tl.store(
|
104 |
+
OUT + rotary_dim_half * stride_out_headdim,
|
105 |
+
o1,
|
106 |
+
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
|
110 |
+
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
|
111 |
+
# Loading x0 will be fast but x1 will be slow.
|
112 |
+
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
|
113 |
+
# Then we do the calculation and use tl.where to pick put the right outputs for the even
|
114 |
+
# and for the odd indices.
|
115 |
+
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
|
116 |
+
rk_repeat = tl.arange(0, BLOCK_K) // 2
|
117 |
+
X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
|
118 |
+
X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
|
119 |
+
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
120 |
+
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
121 |
+
cos = tl.load(
|
122 |
+
COS,
|
123 |
+
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
124 |
+
other=1.0,
|
125 |
+
).to(tl.float32)
|
126 |
+
sin = tl.load(
|
127 |
+
SIN,
|
128 |
+
mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
|
129 |
+
other=0.0,
|
130 |
+
).to(tl.float32)
|
131 |
+
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
|
132 |
+
tl.float32
|
133 |
+
)
|
134 |
+
x1 = tl.load(
|
135 |
+
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
136 |
+
).to(tl.float32)
|
137 |
+
if CONJUGATE:
|
138 |
+
sin = -sin
|
139 |
+
x0_cos = x0 * cos
|
140 |
+
x1_sin = x1 * sin
|
141 |
+
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
|
142 |
+
OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
|
143 |
+
tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
|
144 |
+
|
145 |
+
|
146 |
+
def apply_rotary(
|
147 |
+
x: torch.Tensor,
|
148 |
+
cos: torch.Tensor,
|
149 |
+
sin: torch.Tensor,
|
150 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
151 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
152 |
+
max_seqlen: Optional[int] = None,
|
153 |
+
interleaved=False,
|
154 |
+
inplace=False,
|
155 |
+
conjugate=False,
|
156 |
+
) -> torch.Tensor:
|
157 |
+
"""
|
158 |
+
Arguments:
|
159 |
+
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
|
160 |
+
else (total_seqlen, nheads, headdim).
|
161 |
+
cos: (seqlen_ro, rotary_dim / 2)
|
162 |
+
sin: (seqlen_ro, rotary_dim / 2)
|
163 |
+
seqlen_offsets: integer or integer tensor of size (batch,)
|
164 |
+
cu_seqlens: (batch + 1,) or None
|
165 |
+
max_seqlen: int
|
166 |
+
Returns:
|
167 |
+
y: (batch, seqlen, nheads, headdim)
|
168 |
+
"""
|
169 |
+
|
170 |
+
batch, nheads, seqlen, headdim = x.shape
|
171 |
+
|
172 |
+
batch_ro, seqlen_ro, rotary_dim = cos.shape
|
173 |
+
|
174 |
+
assert batch == batch_ro
|
175 |
+
assert sin.shape == cos.shape
|
176 |
+
rotary_dim *= 2
|
177 |
+
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
|
178 |
+
assert headdim <= 256, "Only support headdim <= 256"
|
179 |
+
|
180 |
+
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
|
181 |
+
|
182 |
+
assert (
|
183 |
+
cos.dtype == sin.dtype
|
184 |
+
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
|
185 |
+
assert (
|
186 |
+
x.dtype == cos.dtype
|
187 |
+
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
|
188 |
+
|
189 |
+
cos, sin = cos.contiguous(), sin.contiguous()
|
190 |
+
if isinstance(seqlen_offsets, torch.Tensor):
|
191 |
+
assert seqlen_offsets.shape == (batch,)
|
192 |
+
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
|
193 |
+
seqlen_offsets = seqlen_offsets.contiguous()
|
194 |
+
else:
|
195 |
+
assert seqlen_offsets + seqlen <= seqlen_ro
|
196 |
+
|
197 |
+
output = torch.empty_like(x) if not inplace else x
|
198 |
+
if rotary_dim < headdim and not inplace:
|
199 |
+
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
200 |
+
|
201 |
+
BLOCK_K = (
|
202 |
+
32
|
203 |
+
if rotary_dim <= 32
|
204 |
+
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
|
205 |
+
)
|
206 |
+
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
|
207 |
+
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
|
208 |
+
|
209 |
+
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
210 |
+
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
211 |
+
with torch.cuda.device(x.device.index):
|
212 |
+
rotary_kernel[grid](
|
213 |
+
output, # data ptrs
|
214 |
+
x,
|
215 |
+
cos,
|
216 |
+
sin,
|
217 |
+
cu_seqlens,
|
218 |
+
seqlen_offsets,
|
219 |
+
seqlen, # shapes
|
220 |
+
nheads,
|
221 |
+
rotary_dim,
|
222 |
+
seqlen_ro,
|
223 |
+
seqlen // 128, # key for triton cache (limit number of compilations)
|
224 |
+
output.stride(0), # batch_strides
|
225 |
+
output.stride(-3), # nheads_stride
|
226 |
+
output.stride(-2), # seqlen_stride
|
227 |
+
output.stride(-1), # headdim_stride
|
228 |
+
x.stride(0), # batch_strides
|
229 |
+
x.stride(-3), # nheads stride
|
230 |
+
x.stride(-2), # seqlen stride
|
231 |
+
x.stride(-1), # headdim stride
|
232 |
+
BLOCK_K,
|
233 |
+
isinstance(seqlen_offsets, torch.Tensor),
|
234 |
+
False,
|
235 |
+
interleaved,
|
236 |
+
conjugate,
|
237 |
+
BLOCK_M,
|
238 |
+
)
|
239 |
+
return output
|
240 |
+
|
241 |
+
|
242 |
+
class ApplyRotaryEmb(torch.autograd.Function):
|
243 |
+
@staticmethod
|
244 |
+
def forward(
|
245 |
+
ctx,
|
246 |
+
x,
|
247 |
+
cos,
|
248 |
+
sin,
|
249 |
+
interleaved=False,
|
250 |
+
inplace=False,
|
251 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
252 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
253 |
+
max_seqlen: Optional[int] = None,
|
254 |
+
):
|
255 |
+
out = apply_rotary(
|
256 |
+
x,
|
257 |
+
cos,
|
258 |
+
sin,
|
259 |
+
seqlen_offsets=seqlen_offsets,
|
260 |
+
cu_seqlens=cu_seqlens,
|
261 |
+
max_seqlen=max_seqlen,
|
262 |
+
interleaved=interleaved,
|
263 |
+
inplace=inplace,
|
264 |
+
)
|
265 |
+
if isinstance(seqlen_offsets, int):
|
266 |
+
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
|
267 |
+
ctx.seqlen_offsets = seqlen_offsets
|
268 |
+
else:
|
269 |
+
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
270 |
+
ctx.seqlen_offsets = None
|
271 |
+
ctx.interleaved = interleaved
|
272 |
+
ctx.inplace = inplace
|
273 |
+
ctx.max_seqlen = max_seqlen
|
274 |
+
return out if not inplace else x
|
275 |
+
|
276 |
+
@staticmethod
|
277 |
+
def backward(ctx, do):
|
278 |
+
seqlen_offsets = ctx.seqlen_offsets
|
279 |
+
if seqlen_offsets is None:
|
280 |
+
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
281 |
+
else:
|
282 |
+
cos, sin, cu_seqlens = ctx.saved_tensors
|
283 |
+
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
|
284 |
+
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
|
285 |
+
if not ctx.interleaved and not ctx.inplace:
|
286 |
+
do = do.clone()
|
287 |
+
dx = apply_rotary(
|
288 |
+
do,
|
289 |
+
cos,
|
290 |
+
sin,
|
291 |
+
seqlen_offsets=seqlen_offsets,
|
292 |
+
cu_seqlens=cu_seqlens,
|
293 |
+
max_seqlen=ctx.max_seqlen,
|
294 |
+
interleaved=ctx.interleaved,
|
295 |
+
inplace=ctx.inplace,
|
296 |
+
conjugate=True,
|
297 |
+
)
|
298 |
+
return dx, None, None, None, None, None, None, None
|
299 |
+
|
300 |
+
|
301 |
+
def apply_rotary_emb(
|
302 |
+
x,
|
303 |
+
cos,
|
304 |
+
sin,
|
305 |
+
interleaved=False,
|
306 |
+
inplace=False,
|
307 |
+
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
308 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
309 |
+
max_seqlen: Optional[int] = None,
|
310 |
+
):
|
311 |
+
"""
|
312 |
+
Arguments:
|
313 |
+
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
314 |
+
else (total_seqlen, nheads, headdim)
|
315 |
+
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
316 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
317 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
318 |
+
inplace: if True, apply rotary embedding in-place.
|
319 |
+
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
320 |
+
Most commonly used in inference when we have KV cache.
|
321 |
+
cu_seqlens: (batch + 1,) or None
|
322 |
+
max_seqlen: int
|
323 |
+
Return:
|
324 |
+
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
325 |
+
else (total_seqlen, nheads, headdim)
|
326 |
+
rotary_dim must be <= headdim
|
327 |
+
Apply rotary embedding to the first rotary_dim of x.
|
328 |
+
"""
|
329 |
+
return ApplyRotaryEmb.apply(
|
330 |
+
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
|
331 |
+
)
|
332 |
+
|
333 |
+
|
334 |
+
# For backward compatibility
|
335 |
+
apply_rotary_emb_func = apply_rotary_emb
|
336 |
+
|
337 |
+
|
338 |
+
class FastRotaryEmbedding(torch.nn.Module):
|
339 |
+
"""
|
340 |
+
The rotary position embeddings from RoFormer_ (Su et. al).
|
341 |
+
A crucial insight from the method is that the query and keys are
|
342 |
+
transformed by rotation matrices which depend on the relative positions.
|
343 |
+
|
344 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
345 |
+
GPT-NeoX_, GPT-NeoX was an inspiration
|
346 |
+
|
347 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
348 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
349 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
350 |
+
|
351 |
+
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
352 |
+
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
353 |
+
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
354 |
+
"""
|
355 |
+
|
356 |
+
def __init__(
|
357 |
+
self,
|
358 |
+
dim: int,
|
359 |
+
base=10000,
|
360 |
+
interleaved=False,
|
361 |
+
scale_base=None,
|
362 |
+
pos_idx_in_fp32=True,
|
363 |
+
device=None,
|
364 |
+
):
|
365 |
+
"""
|
366 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
367 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
368 |
+
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
369 |
+
otherwise they might be in lower precision.
|
370 |
+
This option was added because previously (before 2023-07-02), when we construct
|
371 |
+
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
372 |
+
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
373 |
+
self.inv_freq would be bf16, and the position indices are also in bf16.
|
374 |
+
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
375 |
+
embeddings for some positions will coincide.
|
376 |
+
To maintain compatibility with models previously trained in pure bf16,
|
377 |
+
we add this option.
|
378 |
+
"""
|
379 |
+
super().__init__()
|
380 |
+
self.dim = dim
|
381 |
+
self.base = base
|
382 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
383 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
384 |
+
inv_freq = self._compute_inv_freq(device)
|
385 |
+
self.register_buffer("inv_freq", inv_freq)
|
386 |
+
self.interleaved = interleaved
|
387 |
+
self.scale_base = scale_base
|
388 |
+
scale = (
|
389 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
390 |
+
if scale_base is not None
|
391 |
+
else None
|
392 |
+
)
|
393 |
+
self.register_buffer("scale", scale, persistent=False)
|
394 |
+
|
395 |
+
self._seq_len_cached = 0
|
396 |
+
self._cos_cached = None
|
397 |
+
self._sin_cached = None
|
398 |
+
self._cos_k_cached = None
|
399 |
+
self._sin_k_cached = None
|
400 |
+
self.cos = None
|
401 |
+
self.sin = None
|
402 |
+
|
403 |
+
def _compute_inv_freq(self, device=None):
|
404 |
+
return 1.0 / (
|
405 |
+
self.base
|
406 |
+
** (torch.arange(0, self.dim, 2, device=device) / self.dim)
|
407 |
+
# ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
|
408 |
+
)
|
409 |
+
|
410 |
+
def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
|
411 |
+
|
412 |
+
if (
|
413 |
+
seqlen > self._seq_len_cached
|
414 |
+
):
|
415 |
+
self._seq_len_cached = seqlen
|
416 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
417 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
418 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
419 |
+
if self.pos_idx_in_fp32:
|
420 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
421 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
422 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
423 |
+
# cos & sin output to change significantly.
|
424 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
425 |
+
if self.inv_freq.dtype != torch.float32:
|
426 |
+
inv_freq = self._compute_inv_freq(device=device)
|
427 |
+
else:
|
428 |
+
inv_freq = self.inv_freq
|
429 |
+
else:
|
430 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
431 |
+
inv_freq = self.inv_freq
|
432 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
433 |
+
if self.scale is None:
|
434 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
435 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
436 |
+
|
437 |
+
else:
|
438 |
+
power = (
|
439 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
440 |
+
- seqlen // 2
|
441 |
+
) / self.scale_base
|
442 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
443 |
+
# We want the multiplication by scale to happen in fp32
|
444 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
445 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
446 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
447 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
448 |
+
|
449 |
+
def forward(
|
450 |
+
self,
|
451 |
+
q: torch.Tensor,
|
452 |
+
k: torch.Tensor,
|
453 |
+
position_ids: torch.Tensor,
|
454 |
+
max_seqlen,
|
455 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
456 |
+
"""
|
457 |
+
q: (batch, nheads, seqlen, headdim)
|
458 |
+
k: (batch, nheads, seqlen, headdim)
|
459 |
+
position_id: (batch, seqlen)
|
460 |
+
max_seqlen: int
|
461 |
+
layer_id: int
|
462 |
+
only if layer_id == 0, then update cons and sin
|
463 |
+
Apply rotary embedding *inplace* to q k.
|
464 |
+
"""
|
465 |
+
|
466 |
+
self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
|
467 |
+
cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
|
468 |
+
|
469 |
+
q = apply_rotary_emb_func(
|
470 |
+
q,
|
471 |
+
cos,
|
472 |
+
sin,
|
473 |
+
interleaved=self.interleaved,
|
474 |
+
inplace=True
|
475 |
+
)
|
476 |
+
k = apply_rotary_emb_func(
|
477 |
+
k,
|
478 |
+
cos,
|
479 |
+
sin,
|
480 |
+
interleaved=self.interleaved,
|
481 |
+
inplace=True
|
482 |
+
)
|
483 |
+
return q, k
|
visual.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from argparse import Namespace
|
4 |
+
import xformers.ops as xops
|
5 |
+
from transformers.activations import ACT2FN
|
6 |
+
|
7 |
+
|
8 |
+
class PatchEmbedding(nn.Module):
|
9 |
+
def __init__(self, config):
|
10 |
+
super().__init__()
|
11 |
+
self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
|
12 |
+
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
|
13 |
+
self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
|
14 |
+
|
15 |
+
def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
|
16 |
+
x = self.proj(images)
|
17 |
+
x = x.flatten(2).transpose(1, 2)
|
18 |
+
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
|
19 |
+
x = torch.cat((cls_token, x), dim=1)
|
20 |
+
x += self.position_embedding.weight.unsqueeze(0)
|
21 |
+
return x
|
22 |
+
|
23 |
+
|
24 |
+
class Attention(nn.Module):
|
25 |
+
def __init__(self, config):
|
26 |
+
super().__init__()
|
27 |
+
self.num_heads = config.num_heads
|
28 |
+
head_dim = config.hidden_size // config.num_heads
|
29 |
+
self.scale = head_dim ** -0.5
|
30 |
+
self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
|
31 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
32 |
+
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
33 |
+
|
34 |
+
def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
|
35 |
+
B, L, _ = x.shape
|
36 |
+
qkv = self.query_key_value(x)
|
37 |
+
qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
|
38 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
39 |
+
|
40 |
+
out = xops.memory_efficient_attention(
|
41 |
+
q, k, v, scale=self.scale,
|
42 |
+
)
|
43 |
+
output = self.dense(out.view(B, L, -1))
|
44 |
+
output = self.output_dropout(output)
|
45 |
+
return output
|
46 |
+
|
47 |
+
def attention(self, q, k, v):
|
48 |
+
attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
|
49 |
+
attn_weights = attn_weights.softmax(dim=-1)
|
50 |
+
output = torch.matmul(attn_weights, v)
|
51 |
+
return output
|
52 |
+
|
53 |
+
|
54 |
+
class MLP(nn.Module):
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.config = config
|
58 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
59 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
60 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
61 |
+
|
62 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
63 |
+
x = self.fc1(x)
|
64 |
+
x = self.activation_fn(x)
|
65 |
+
x = self.fc2(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
class TransformerLayer(nn.Module):
|
70 |
+
def __init__(self, config):
|
71 |
+
super().__init__()
|
72 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
73 |
+
self.attention = Attention(config)
|
74 |
+
self.mlp = MLP(config)
|
75 |
+
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
76 |
+
|
77 |
+
def forward(self, hidden_states):
|
78 |
+
attention_input = hidden_states
|
79 |
+
attention_output = self.input_layernorm(self.attention(attention_input))
|
80 |
+
hidden_states = attention_input + attention_output
|
81 |
+
mlp_input = hidden_states
|
82 |
+
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
|
83 |
+
output = mlp_input + mlp_output
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
class Transformer(nn.Module):
|
88 |
+
def __init__(self, config):
|
89 |
+
super().__init__()
|
90 |
+
self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
|
91 |
+
|
92 |
+
def forward(self, hidden_states):
|
93 |
+
for layer_module in self.layers:
|
94 |
+
hidden_states = layer_module(hidden_states)
|
95 |
+
return hidden_states
|
96 |
+
|
97 |
+
|
98 |
+
class GLU(nn.Module):
|
99 |
+
def __init__(self, config, in_features):
|
100 |
+
super().__init__()
|
101 |
+
self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
|
102 |
+
self.norm1 = nn.LayerNorm(config.hidden_size)
|
103 |
+
self.act1 = nn.GELU()
|
104 |
+
self.act2 = nn.functional.silu
|
105 |
+
self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
106 |
+
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
107 |
+
self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
x = self.linear_proj(x)
|
111 |
+
x = self.act1(self.norm1(x))
|
112 |
+
x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
|
113 |
+
x = self.dense_4h_to_h(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class EVA2CLIPModel(nn.Module):
|
118 |
+
def __init__(self, config):
|
119 |
+
super().__init__()
|
120 |
+
vision_config = Namespace(**config.vision_config)
|
121 |
+
self.patch_embedding = PatchEmbedding(vision_config)
|
122 |
+
self.transformer = Transformer(vision_config)
|
123 |
+
self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
|
124 |
+
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
125 |
+
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
126 |
+
self.pos_embed = nn.Parameter(torch.zeros((vision_config.image_size // vision_config.patch_size) ** 2, vision_config.hidden_size))
|
127 |
+
|
128 |
+
def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
|
129 |
+
x = self.patch_embedding(images)
|
130 |
+
x = self.transformer(x)
|
131 |
+
x = x[:, 1:]
|
132 |
+
x = self.linear_proj(x + self.pos_embed.unsqueeze(0))
|
133 |
+
boi = self.boi.expand(x.shape[0], -1, -1)
|
134 |
+
eoi = self.eoi.expand(x.shape[0], -1, -1)
|
135 |
+
x = torch.cat((boi, x, eoi), dim=1)
|
136 |
+
return x
|