Spaces:
Running
on
Zero
Running
on
Zero
EvanTHU
commited on
Commit
•
445d3d1
1
Parent(s):
bc6c851
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +14 -0
- LICENSE +9 -0
- README copy.md +133 -0
- app copy.py +661 -0
- assets/application.png +0 -0
- assets/compare.png +0 -0
- assets/highlight.png +0 -0
- assets/logo.png +0 -0
- assets/system.png +0 -0
- generate.py +199 -0
- lit_gpt/__init__.py +15 -0
- lit_gpt/adapter.py +165 -0
- lit_gpt/adapter_v2.py +197 -0
- lit_gpt/config.py +1040 -0
- lit_gpt/lora.py +671 -0
- lit_gpt/model.py +355 -0
- lit_gpt/packed_dataset.py +235 -0
- lit_gpt/rmsnorm.py +26 -0
- lit_gpt/speed_monitor.py +425 -0
- lit_gpt/tokenizer.py +103 -0
- lit_gpt/utils.py +311 -0
- lit_llama/__init__.py +2 -0
- lit_llama/adapter.py +151 -0
- lit_llama/indexed_dataset.py +588 -0
- lit_llama/lora.py +232 -0
- lit_llama/model.py +246 -0
- lit_llama/quantization.py +281 -0
- lit_llama/tokenizer.py +49 -0
- lit_llama/utils.py +244 -0
- models/__init__.py +0 -0
- models/constants.py +18 -0
- models/encdec.py +67 -0
- models/evaluator_wrapper.py +92 -0
- models/modules.py +109 -0
- models/multimodal_encoder/builder.py +49 -0
- models/multimodal_encoder/clip_encoder.py +78 -0
- models/multimodal_encoder/languagebind/__init__.py +285 -0
- models/multimodal_encoder/languagebind/audio/configuration_audio.py +430 -0
- models/multimodal_encoder/languagebind/audio/modeling_audio.py +1030 -0
- models/multimodal_encoder/languagebind/audio/processing_audio.py +190 -0
- models/multimodal_encoder/languagebind/audio/tokenization_audio.py +77 -0
- models/multimodal_encoder/languagebind/depth/configuration_depth.py +425 -0
- models/multimodal_encoder/languagebind/depth/modeling_depth.py +1030 -0
- models/multimodal_encoder/languagebind/depth/processing_depth.py +108 -0
- models/multimodal_encoder/languagebind/depth/tokenization_depth.py +77 -0
- models/multimodal_encoder/languagebind/image/configuration_image.py +423 -0
- models/multimodal_encoder/languagebind/image/modeling_image.py +1030 -0
- models/multimodal_encoder/languagebind/image/processing_image.py +82 -0
- models/multimodal_encoder/languagebind/image/tokenization_image.py +77 -0
- models/multimodal_encoder/languagebind/thermal/configuration_thermal.py +423 -0
.gitignore
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**/*.pyc
|
2 |
+
**/__pycache__
|
3 |
+
__pycache__/
|
4 |
+
cache_dir/
|
5 |
+
checkpoints/
|
6 |
+
feedback/
|
7 |
+
temp/
|
8 |
+
models--LanguageBind--Video-LLaVA-7B/
|
9 |
+
*.jsonl
|
10 |
+
*.json
|
11 |
+
linghao
|
12 |
+
run.sh
|
13 |
+
examples/
|
14 |
+
assets/task.gif
|
LICENSE
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
License for Non-commercial Scientific Research Purposes
|
2 |
+
|
3 |
+
IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under IDEA’s copyright interests to reproduce, distribute, and create derivative works of the text, videos, codes solely for your non-commercial research purposes.
|
4 |
+
|
5 |
+
Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited.
|
6 |
+
|
7 |
+
Text and visualization results are owned by International Digital Economy Academy (IDEA).
|
8 |
+
|
9 |
+
You also need to obey the original license of the dependency models/data used in this service.
|
README copy.md
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MotionLLM: Understanding Human Behaviors from Human Motions and Videos
|
2 |
+
|
3 |
+
![task](./assets/task.gif)
|
4 |
+
|
5 |
+
[Ling-Hao Chen](https://lhchen.top)<sup>😎 1, 3</sup>,
|
6 |
+
[Shunlin Lu](https://shunlinlu.github.io)<sup>😎 2, 3</sup>,
|
7 |
+
[Ailing Zeng](https://ailingzeng.sit)<sup>3</sup>,
|
8 |
+
[Hao Zhang](https://haozhang534.github.io/)<sup>3, 4</sup>,
|
9 |
+
[Benyou Wang](https://wabyking.github.io/old.html)<sup>2</sup>,
|
10 |
+
[Ruimao Zhang](http://zhangruimao.site)<sup>2</sup>,
|
11 |
+
[Lei Zhang](https://leizhang.org)<sup>🤗 3</sup>
|
12 |
+
|
13 |
+
<sup>😎</sup>Co-first author. Listing order is random.
|
14 |
+
<sup>🤗</sup>Corresponding author.
|
15 |
+
|
16 |
+
<sup>1</sup>Tsinghua University,
|
17 |
+
<sup>2</sup>School of Data Science, The Chinese University of Hong Kong, Shenzhen (CUHK-SZ),
|
18 |
+
<sup>3</sup>International Digital Economy Academy (IDEA),
|
19 |
+
<sup>4</sup>The Hong Kong University of Science and Technology
|
20 |
+
|
21 |
+
<p align="center">
|
22 |
+
<a href='https://arxiv.org/abs/2304'>
|
23 |
+
<img src='https://img.shields.io/badge/Arxiv-2304.tomorrow-A42C25?style=flat&logo=arXiv&logoColor=A42C25'>
|
24 |
+
</a>
|
25 |
+
<a href='https://arxiv.org/pdf/2304.pdf'>
|
26 |
+
<img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'>
|
27 |
+
</a>
|
28 |
+
<a href='https://lhchen.top/MotionLLM'>
|
29 |
+
<img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
|
30 |
+
<a href='https://research.lhchen.top/blogpost/motionllm'>
|
31 |
+
<img src='https://img.shields.io/badge/Blog-post-4EABE6?style=flat&logoColor=4EABE6'></a>
|
32 |
+
<a href='https://github.com/IDEA-Research/MotionLLM'>
|
33 |
+
<img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
|
34 |
+
<a href='LICENSE'>
|
35 |
+
<img src='https://img.shields.io/badge/License-IDEA-blue.svg'>
|
36 |
+
</a>
|
37 |
+
<a href="" target='_blank'>
|
38 |
+
<img src="https://visitor-badge.laobi.icu/badge?page_id=IDEA-Research.MotionLLM&left_color=gray&right_color=%2342b983">
|
39 |
+
</a>
|
40 |
+
</p>
|
41 |
+
|
42 |
+
# 🤩 Abstract
|
43 |
+
|
44 |
+
This study delves into the realm of multi-modality (i.e., video and motion modalities) human behavior understanding by leveraging the powerful capabilities of Large Language Models (LLMs). Diverging from recent LLMs designed for video-only or motion-only understanding, we argue that understanding human behavior necessitates joint modeling from both videos and motion sequences (e.g., SMPL sequences) to capture nuanced body part dynamics and semantics effectively. In light of this, we present MotionLLM, a straightforward yet effective framework for human motion understanding, captioning, and reasoning. Specifically, MotionLLM adopts a unified video-motion training strategy that leverages the complementary advantages of existing coarse video-text data and fine-grained motion-text data to glean rich spatial-temporal insights. Furthermore, we collect a substantial dataset, MoVid, comprising diverse videos, motions, captions, and instructions. Additionally, we propose the MoVid-Bench, with carefully manual annotations, for better evaluation of human behavior understanding on video and motion. Extensive experiments show the superiority of MotionLLM in the caption, spatial-temporal comprehension, and reasoning ability.
|
45 |
+
|
46 |
+
## 🤩 Highlight Applications
|
47 |
+
|
48 |
+
![application](./assets/application.png)
|
49 |
+
|
50 |
+
## 🔧 Technical Solution
|
51 |
+
|
52 |
+
![system](./assets/system.png)
|
53 |
+
|
54 |
+
## 💻 Try it
|
55 |
+
|
56 |
+
We provide a simple online [demo](https://demo.humotionx.com/) for you to try MotionLLM. Below is the guidance to deploy the demo on your local machine.
|
57 |
+
|
58 |
+
### Step 1: Set up the environment
|
59 |
+
|
60 |
+
```bash
|
61 |
+
pip install -r requirements.txt
|
62 |
+
```
|
63 |
+
|
64 |
+
### Step 2: Download the pre-trained model
|
65 |
+
|
66 |
+
|
67 |
+
<details>
|
68 |
+
<summary><b> 2.1 Download the LLM </b></summary>
|
69 |
+
|
70 |
+
Please follow the instruction of [Lit-GPT](https://github.com/Lightning-AI/litgpt) to prepare the LLM model (vicuna 1.5-7B). These files will be:
|
71 |
+
```bah
|
72 |
+
./checkpoints/vicuna-7b-v1.5
|
73 |
+
├── generation_config.json
|
74 |
+
├── lit_config.json
|
75 |
+
├── lit_model.pth
|
76 |
+
├── pytorch_model-00001-of-00002.bin
|
77 |
+
├── pytorch_model-00002-of-00002.bin
|
78 |
+
├── pytorch_model.bin.index.json
|
79 |
+
├── tokenizer_config.json
|
80 |
+
└── tokenizer.model
|
81 |
+
```
|
82 |
+
|
83 |
+
If you have any confusion, we will update a more detailed instruction in couple of days.
|
84 |
+
|
85 |
+
</details>
|
86 |
+
|
87 |
+
<details>
|
88 |
+
<summary><b> 2.2 Dowload the LoRA and the projection layer of the MotionLLM </b></summary>
|
89 |
+
|
90 |
+
We now release one versions of the MotionLLM checkpoints, namely `v1.0` (download [here](https://drive.google.com/drive/folders/1d_5vaL34Hs2z9ACcMXyPEfZNyMs36xKx?usp=sharing)). Opening for the suggestions to Ling-Hao Chen and Shunlin Lu.
|
91 |
+
|
92 |
+
```bash
|
93 |
+
wget xxx
|
94 |
+
```
|
95 |
+
Keep them in a folder named and remember the path (`LINEAR_V` and `LORA`).
|
96 |
+
|
97 |
+
</details>
|
98 |
+
|
99 |
+
### 2.3 Run the demo
|
100 |
+
|
101 |
+
```bash
|
102 |
+
GRADIO_TEMP_DIR=temp python app.py --lora_path $LORA --mlp_path $LINEAR_V
|
103 |
+
```
|
104 |
+
If you have some error in downloading the huggingface model, you can try the following command with the mirror of huggingface.
|
105 |
+
```bash
|
106 |
+
HF_ENDPOINT=https://hf-mirror.com GRADIO_TEMP_DIR=temp python app.py --lora_path $LORA --mlp_path $LINEAR_V
|
107 |
+
```
|
108 |
+
The `GRADIO_TEMP_DIR=temp` defines a temporary directory as `./temp` for the Gradio to store the data. You can change it to your own path.
|
109 |
+
|
110 |
+
After thiess, you can open the browser and visit the local host via the command line output reminder. If it is not loaded, please change the IP address as your local IP address (via command `ifconfig`).
|
111 |
+
|
112 |
+
|
113 |
+
## 💼 To-Do
|
114 |
+
|
115 |
+
- [x] Release the video demo of MotionLLM.
|
116 |
+
- [ ] Release the motion demo of MotionLLM.
|
117 |
+
- [ ] Release the MoVid dataset and MoVid-Bench.
|
118 |
+
- [ ] Release the tuning instruction of MotionLLM.
|
119 |
+
|
120 |
+
|
121 |
+
## 💋 Acknowledgement
|
122 |
+
|
123 |
+
|
124 |
+
The author team would like to deliver many thanks to many people. Qing Jiang helps a lot with some parts of manual annotation on MoVid Bench and resolves some ethics issues of MotionLLM. Jingcheng Hu provided some technical suggestions for efficient training. Shilong Liu and Bojia Zi provided some significant technical suggestions on LLM tuning. Jiale Liu, Wenhao Yang, and Chenlai Qian provided some significant suggestions for us to polish the paper. Hongyang Li helped us a lot with the figure design. Yiren Pang provided GPT API keys when our keys were temporarily out of quota. The code is on the basis of [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA), [HumanTOMATO](https://lhchen.top/HumanTOMATO/), [MotionGPT](https://github.com/qiqiApink/MotionGPT). [lit-gpt](https://github.com/Lightning-AI/litgpt), and [HumanML3D](https://github.com/EricGuo5513/HumanML3D). Thanks to all contributors!
|
125 |
+
|
126 |
+
|
127 |
+
## 📚 License
|
128 |
+
|
129 |
+
This code is distributed under an [IDEA LICENSE](LICENSE). Note that our code depends on other libraries and datasets which each have their own respective licenses that must also be followed.
|
130 |
+
|
131 |
+
|
132 |
+
If you have any question, please contact at: thu [DOT] lhchen [AT] gmail [DOT] com AND shunlinlu0803 [AT] gmail [DOT] com.
|
133 |
+
|
app copy.py
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from fastapi import FastAPI
|
7 |
+
import os
|
8 |
+
from PIL import Image
|
9 |
+
import tempfile
|
10 |
+
from decord import VideoReader, cpu
|
11 |
+
import uvicorn
|
12 |
+
from transformers import TextStreamer
|
13 |
+
|
14 |
+
import hashlib
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import time
|
18 |
+
import warnings
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Optional
|
21 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
22 |
+
from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
|
23 |
+
|
24 |
+
import lightning as L
|
25 |
+
import numpy as np
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
from generate import generate as generate_
|
30 |
+
from lit_llama import Tokenizer, LLaMA, LLaMAConfig
|
31 |
+
from lit_llama.lora import lora
|
32 |
+
from lit_llama.utils import EmptyInitOnDevice
|
33 |
+
from lit_gpt.utils import lazy_load
|
34 |
+
from scripts.video_dataset.prepare_video_dataset_video_llava import generate_prompt_mlp
|
35 |
+
from options import option
|
36 |
+
import imageio
|
37 |
+
from tqdm import tqdm
|
38 |
+
|
39 |
+
from models.multimodal_encoder.builder import build_image_tower, build_video_tower
|
40 |
+
from models.multimodal_projector.builder import build_vision_projector
|
41 |
+
|
42 |
+
|
43 |
+
title_markdown = ("""<div class="embed_hidden" style="text-align: center;">
|
44 |
+
<h1>MotionLLM: Understanding Human Behaviors from Human Motions and Videos</h1>
|
45 |
+
<h3>
|
46 |
+
<a href="https://lhchen.top" target="_blank" rel="noopener noreferrer">Ling-Hao Chen</a><sup>😎 1, 3</sup>,
|
47 |
+
<a href="https://shunlinlu.github.io" target="_blank" rel="noopener noreferrer">Shunlin Lu</a><sup>😎 2, 3</sup>,
|
48 |
+
<br>
|
49 |
+
<a href="https://ailingzeng.sit" target="_blank" rel="noopener noreferrer">Ailing Zeng</a><sup>3</sup>,
|
50 |
+
<a href="https://haozhang534.github.io/" target="_blank" rel="noopener noreferrer">Hao Zhang</a><sup>3, 4</sup>,
|
51 |
+
<a href="https://wabyking.github.io/old.html" target="_blank" rel="noopener noreferrer">Benyou Wang</a><sup>2</sup>,
|
52 |
+
<a href="http://zhangruimao.site" target="_blank" rel="noopener noreferrer">Ruimao Zhang</a><sup>2</sup>,
|
53 |
+
<a href="https://leizhang.org" target="_blank" rel="noopener noreferrer">Lei Zhang</a><sup>🤗 3</sup>
|
54 |
+
</h3>
|
55 |
+
<h3><sup>😎</sup><i>Co-first author. Listing order is random.</i>   <sup>🤗</sup><i>Corresponding author.</i></h3>
|
56 |
+
<h3>
|
57 |
+
<sup>1</sup>THU  
|
58 |
+
<sup>2</sup>CUHK (SZ)  
|
59 |
+
<sup>3</sup>IDEA Research  
|
60 |
+
<sup>4</sup>HKUST
|
61 |
+
</h3>
|
62 |
+
</div>
|
63 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
64 |
+
<img src="https://lhchen.top/MotionLLM/assets/img/highlight.png" alt="MotionLLM" style="width:60%; height: auto; align-items: center;">
|
65 |
+
</div>
|
66 |
+
|
67 |
+
""")
|
68 |
+
|
69 |
+
block_css = """
|
70 |
+
#buttons button {
|
71 |
+
min-width: min(120px,100%);
|
72 |
+
}
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
tos_markdown = ("""
|
77 |
+
*We are now working to support the motion branch of the MotionLLM model.
|
78 |
+
|
79 |
+
### Terms of use
|
80 |
+
By using this service, users are required to agree to the following terms:
|
81 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content.
|
82 |
+
It is forbidden to use the service to generate content that is illegal, harmful, violent, racist, or sexual
|
83 |
+
The usage of this service is subject to the IDEA License.
|
84 |
+
""")
|
85 |
+
|
86 |
+
|
87 |
+
learn_more_markdown = ("""
|
88 |
+
### License
|
89 |
+
License for Non-commercial Scientific Research Purposes
|
90 |
+
|
91 |
+
IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under IDEA’s copyright interests to reproduce, distribute, and create derivative works of the text, videos, codes solely for your non-commercial research purposes.
|
92 |
+
|
93 |
+
Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited.
|
94 |
+
|
95 |
+
Text and visualization results are owned by International Digital Economy Academy (IDEA).
|
96 |
+
|
97 |
+
You also need to obey the original license of the dependency models/data used in this service.
|
98 |
+
""")
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
class LlavaMetaModel:
|
103 |
+
|
104 |
+
def __init__(self, config, pretrained_checkpoint):
|
105 |
+
super(LlavaMetaModel, self).__init__()
|
106 |
+
# import pdb; pdb.set_trace()
|
107 |
+
if hasattr(config, "mm_image_tower") or hasattr(config, "image_tower"):
|
108 |
+
self.image_tower = build_image_tower(config, delay_load=True)
|
109 |
+
self.mm_projector = build_vision_projector(config)
|
110 |
+
if hasattr(config, "mm_video_tower") or hasattr(config, "video_tower"):
|
111 |
+
self.video_tower = build_video_tower(config, delay_load=True)
|
112 |
+
self.mm_projector = build_vision_projector(config)
|
113 |
+
self.load_video_tower_pretrained(pretrained_checkpoint)
|
114 |
+
|
115 |
+
def get_image_tower(self):
|
116 |
+
image_tower = getattr(self, 'image_tower', None)
|
117 |
+
if type(image_tower) is list:
|
118 |
+
image_tower = image_tower[0]
|
119 |
+
return image_tower
|
120 |
+
|
121 |
+
def get_video_tower(self):
|
122 |
+
video_tower = getattr(self, 'video_tower', None)
|
123 |
+
|
124 |
+
if type(video_tower) is list:
|
125 |
+
video_tower = video_tower[0]
|
126 |
+
return video_tower
|
127 |
+
|
128 |
+
|
129 |
+
def get_all_tower(self, keys):
|
130 |
+
tower = {key: getattr(self, f'get_{key}_tower') for key in keys}
|
131 |
+
return tower
|
132 |
+
|
133 |
+
|
134 |
+
def load_video_tower_pretrained(self, pretrained_checkpoint):
|
135 |
+
self.mm_projector.load_state_dict(pretrained_checkpoint, strict=True)
|
136 |
+
|
137 |
+
|
138 |
+
def initialize_image_modules(self, model_args, fsdp=None):
|
139 |
+
image_tower = model_args.image_tower
|
140 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
141 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
142 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
143 |
+
|
144 |
+
self.config.mm_image_tower = image_tower
|
145 |
+
|
146 |
+
image_tower = build_image_tower(model_args)
|
147 |
+
|
148 |
+
if fsdp is not None and len(fsdp) > 0:
|
149 |
+
self.image_tower = [image_tower]
|
150 |
+
else:
|
151 |
+
self.image_tower = image_tower
|
152 |
+
|
153 |
+
self.config.use_mm_proj = True
|
154 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
155 |
+
self.config.mm_hidden_size = image_tower.hidden_size
|
156 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
157 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
158 |
+
|
159 |
+
self.mm_projector = build_vision_projector(self.config)
|
160 |
+
|
161 |
+
if pretrain_mm_mlp_adapter is not None:
|
162 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
163 |
+
def get_w(weights, keyword):
|
164 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
165 |
+
|
166 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
167 |
+
|
168 |
+
def initialize_video_modules(self, model_args, fsdp=None):
|
169 |
+
video_tower = model_args.video_tower
|
170 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
171 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
172 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
173 |
+
|
174 |
+
self.config.mm_video_tower = video_tower
|
175 |
+
|
176 |
+
video_tower = build_video_tower(model_args)
|
177 |
+
|
178 |
+
if fsdp is not None and len(fsdp) > 0:
|
179 |
+
self.video_tower = [video_tower]
|
180 |
+
else:
|
181 |
+
self.video_tower = video_tower
|
182 |
+
|
183 |
+
self.config.use_mm_proj = True
|
184 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
185 |
+
self.config.mm_hidden_size = video_tower.hidden_size
|
186 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
187 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
188 |
+
|
189 |
+
self.mm_projector = build_vision_projector(self.config)
|
190 |
+
|
191 |
+
if pretrain_mm_mlp_adapter is not None:
|
192 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
193 |
+
def get_w(weights, keyword):
|
194 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
195 |
+
|
196 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
197 |
+
|
198 |
+
def encode_images(self, images):
|
199 |
+
image_features = self.get_image_tower()(images)
|
200 |
+
image_features = self.mm_projector(image_features)
|
201 |
+
return image_features
|
202 |
+
|
203 |
+
def encode_videos(self, videos):
|
204 |
+
# import pdb; pdb.set_trace()
|
205 |
+
# videos: torch.Size([1, 3, 8, 224, 224])
|
206 |
+
video_features = self.get_video_tower()(videos) # torch.Size([1, 2048, 1024])
|
207 |
+
video_features = self.mm_projector(video_features.float()) # torch.Size([1, 2048, 4096])
|
208 |
+
return video_features
|
209 |
+
|
210 |
+
def get_multimodal_embeddings(self, X_modalities):
|
211 |
+
Xs, keys= X_modalities
|
212 |
+
|
213 |
+
X_features = getattr(self, f'encode_{keys[0]}s')(Xs) # expand to get batchsize
|
214 |
+
|
215 |
+
return X_features
|
216 |
+
|
217 |
+
|
218 |
+
class Projection(nn.Module):
|
219 |
+
def __init__(self, ):
|
220 |
+
super().__init__()
|
221 |
+
self.linear_proj = nn.Linear(512, 4096)
|
222 |
+
def forward(self, x):
|
223 |
+
return self.linear_proj(x)
|
224 |
+
|
225 |
+
|
226 |
+
class ProjectionNN(nn.Module):
|
227 |
+
def __init__(self, ):
|
228 |
+
super().__init__()
|
229 |
+
self.proj = nn.Sequential(
|
230 |
+
nn.Linear(512, 4096),
|
231 |
+
nn.GELU(),
|
232 |
+
nn.Linear(4096, 4096)
|
233 |
+
)
|
234 |
+
def forward(self, x):
|
235 |
+
return self.proj(x)
|
236 |
+
|
237 |
+
|
238 |
+
class Conversation():
|
239 |
+
def __init__(self, output=None, input_prompt=None, prompt=None):
|
240 |
+
if output is None:
|
241 |
+
self.messages = []
|
242 |
+
else:
|
243 |
+
self.messages = []
|
244 |
+
self.append_message(prompt, input_prompt, output)
|
245 |
+
|
246 |
+
def append_message(self, output, input_prompt, prompt, show_images):
|
247 |
+
# print(output)
|
248 |
+
# print(input_prompt)
|
249 |
+
# print(prompt)
|
250 |
+
# print(show_images)
|
251 |
+
self.messages.append((output, input_prompt, prompt, show_images))
|
252 |
+
|
253 |
+
def to_gradio_chatbot(self, show_images=None, output_text=None):
|
254 |
+
# return a list
|
255 |
+
if show_images is None:
|
256 |
+
show_images = self.messages[-1][3]
|
257 |
+
output_text = self.messages[-1][0]
|
258 |
+
return [
|
259 |
+
[show_images, output_text]
|
260 |
+
]
|
261 |
+
|
262 |
+
def get_info(self):
|
263 |
+
return self.messages[-1][0], self.messages[-1][1]
|
264 |
+
|
265 |
+
|
266 |
+
class ConversationBuffer():
|
267 |
+
def __init__(self, input_text):
|
268 |
+
self.buffer_ = []
|
269 |
+
self.buffer.append(input_text)
|
270 |
+
|
271 |
+
|
272 |
+
def init_conv():
|
273 |
+
conv = Conversation()
|
274 |
+
return conv
|
275 |
+
|
276 |
+
|
277 |
+
def get_processor(X, config, device, pretrained_checkpoint_tower, model_path = 'LanguageBind/MotionLLM-7B'):
|
278 |
+
mm_backbone_mlp_model = LlavaMetaModel(config, pretrained_checkpoint_tower)
|
279 |
+
|
280 |
+
processor = {}
|
281 |
+
if 'Image' in X:
|
282 |
+
image_tower = mm_backbone_mlp_model.get_image_tower() # LanguageBindImageTower()
|
283 |
+
if not image_tower.is_loaded:
|
284 |
+
image_tower.load_model()
|
285 |
+
image_tower.to(device=device, dtype=torch.float16)
|
286 |
+
image_processor = image_tower.image_processor
|
287 |
+
processor['image'] = image_processor
|
288 |
+
if 'Video' in X:
|
289 |
+
video_tower = mm_backbone_mlp_model.get_video_tower()
|
290 |
+
if not video_tower.is_loaded:
|
291 |
+
video_tower.load_model()
|
292 |
+
video_tower.to(device=device, dtype=torch.float16)
|
293 |
+
video_processor = video_tower.video_processor
|
294 |
+
processor['video'] = video_processor
|
295 |
+
|
296 |
+
return mm_backbone_mlp_model, processor
|
297 |
+
|
298 |
+
|
299 |
+
def motionllm(
|
300 |
+
args,
|
301 |
+
input_video_path: str,
|
302 |
+
text_en_in: str,
|
303 |
+
quantize: Optional[str] = None,
|
304 |
+
dtype: str = "float32",
|
305 |
+
max_new_tokens: int = 200,
|
306 |
+
top_k: int = 200,
|
307 |
+
temperature: float = 0.8,
|
308 |
+
accelerator: str = "auto",):
|
309 |
+
|
310 |
+
video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values']
|
311 |
+
|
312 |
+
if type(video_tensor) is list:
|
313 |
+
tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor]
|
314 |
+
else:
|
315 |
+
tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224)
|
316 |
+
|
317 |
+
X_modalities = [tensor,['video']]
|
318 |
+
video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities)
|
319 |
+
prompt = text_en_in
|
320 |
+
input_prompt = prompt
|
321 |
+
|
322 |
+
sample = {"instruction": prompt, "input": input_video_path}
|
323 |
+
|
324 |
+
prefix = generate_prompt_mlp(sample)
|
325 |
+
pre = torch.cat((tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, -1), tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1)
|
326 |
+
|
327 |
+
prompt = (pre, ". ASSISTANT: ")
|
328 |
+
encoded = (prompt[0], video_feature[0], tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1))
|
329 |
+
|
330 |
+
t0 = time.perf_counter()
|
331 |
+
|
332 |
+
output_seq = generate_(
|
333 |
+
model,
|
334 |
+
idx=encoded,
|
335 |
+
max_seq_length=4096,
|
336 |
+
max_new_tokens=max_new_tokens,
|
337 |
+
temperature=temperature,
|
338 |
+
top_k=top_k,
|
339 |
+
eos_id=tokenizer.eos_id,
|
340 |
+
tokenizer = tokenizer,
|
341 |
+
)
|
342 |
+
outputfull = tokenizer.decode(output_seq)
|
343 |
+
output = outputfull.split("ASSISTANT:")[-1].strip()
|
344 |
+
print("================================")
|
345 |
+
print(output)
|
346 |
+
print("================================")
|
347 |
+
|
348 |
+
return output, input_prompt, prompt
|
349 |
+
|
350 |
+
|
351 |
+
def save_image_to_local(image):
|
352 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
|
353 |
+
image = Image.open(image)
|
354 |
+
image.save(filename)
|
355 |
+
# print(filename)
|
356 |
+
return filename
|
357 |
+
|
358 |
+
|
359 |
+
def save_video_to_local(video_path):
|
360 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
|
361 |
+
shutil.copyfile(video_path, filename)
|
362 |
+
return filename
|
363 |
+
|
364 |
+
|
365 |
+
def generate(image1, video, textbox_in, first_run, state, images_tensor):
|
366 |
+
flag = 1
|
367 |
+
|
368 |
+
image1 = image1 if image1 else "none"
|
369 |
+
video = video if video else "none"
|
370 |
+
|
371 |
+
if type(state) is not Conversation:
|
372 |
+
state = init_conv()
|
373 |
+
images_tensor = [[], []]
|
374 |
+
|
375 |
+
first_run = False if len(state.messages) > 0 else True
|
376 |
+
text_en_in = textbox_in.replace("picture", "image")
|
377 |
+
output, input_prompt, prompt = motionllm(args, video, text_en_in)
|
378 |
+
|
379 |
+
text_en_out = output
|
380 |
+
textbox_out = text_en_out
|
381 |
+
|
382 |
+
show_images = ""
|
383 |
+
if os.path.exists(image1):
|
384 |
+
filename = save_image_to_local(image1)
|
385 |
+
show_images += f'<img src="./file={filename}" style="display: inline-block;width: 250px;max-height: 400px;">'
|
386 |
+
|
387 |
+
if os.path.exists(video):
|
388 |
+
filename = save_video_to_local(video)
|
389 |
+
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
|
390 |
+
|
391 |
+
show_images = textbox_in + "\n" + show_images
|
392 |
+
state.append_message(output, input_prompt, prompt, show_images)
|
393 |
+
|
394 |
+
torch.cuda.empty_cache()
|
395 |
+
|
396 |
+
return (state, state.to_gradio_chatbot(show_images, output), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
|
397 |
+
|
398 |
+
def regenerate(state):
|
399 |
+
if len(state.messages) > 0:
|
400 |
+
tobot = state.to_gradio_chatbot()
|
401 |
+
tobot[-1][1] = None
|
402 |
+
textbox = state.messages[-1][1]
|
403 |
+
state.messages.pop(-1)
|
404 |
+
return state, tobot, False, textbox
|
405 |
+
return (state, [], True)
|
406 |
+
|
407 |
+
|
408 |
+
def clear_history(state):
|
409 |
+
state = init_conv()
|
410 |
+
try:
|
411 |
+
tgt = state.to_gradio_chatbot()
|
412 |
+
except:
|
413 |
+
tgt = [None, None]
|
414 |
+
return (gr.update(value=None, interactive=True),
|
415 |
+
gr.update(value=None, interactive=True),\
|
416 |
+
gr.update(value=None, interactive=True),\
|
417 |
+
True, state, tgt, [[], []])
|
418 |
+
|
419 |
+
|
420 |
+
def get_md5(file_path):
|
421 |
+
hash_md5 = hashlib.md5()
|
422 |
+
with open(file_path, "rb") as f:
|
423 |
+
for chunk in iter(lambda: f.read(4096), b""):
|
424 |
+
hash_md5.update(chunk)
|
425 |
+
return hash_md5.hexdigest()
|
426 |
+
|
427 |
+
|
428 |
+
def logging_up(video, state):
|
429 |
+
try:
|
430 |
+
state.get_info()
|
431 |
+
except:
|
432 |
+
return False
|
433 |
+
action = "upvote"
|
434 |
+
# Get the current time
|
435 |
+
current_time = str(time.time())
|
436 |
+
|
437 |
+
# Create an md5 object
|
438 |
+
hash_object = hashlib.md5(current_time.encode())
|
439 |
+
|
440 |
+
# Get the hexadecimal representation of the hash
|
441 |
+
md5_hash = get_md5(video) + "-" + hash_object.hexdigest()
|
442 |
+
|
443 |
+
command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4"
|
444 |
+
os.system(command)
|
445 |
+
with open (f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f:
|
446 |
+
out, prp = state.get_info()
|
447 |
+
f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n")
|
448 |
+
return True
|
449 |
+
|
450 |
+
|
451 |
+
def logging_down(video, state):
|
452 |
+
try:
|
453 |
+
state.get_info()
|
454 |
+
except:
|
455 |
+
return False
|
456 |
+
action = "downvote"
|
457 |
+
# Get the current time
|
458 |
+
current_time = str(time.time())
|
459 |
+
|
460 |
+
# Create an md5 object
|
461 |
+
hash_object = hashlib.md5(current_time.encode())
|
462 |
+
|
463 |
+
# Get the hexadecimal representation of the hash
|
464 |
+
md5_hash = get_md5(video) + "-" + hash_object.hexdigest()
|
465 |
+
|
466 |
+
command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4"
|
467 |
+
os.system(command)
|
468 |
+
with open (f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f:
|
469 |
+
out, prp = state.get_info()
|
470 |
+
f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n")
|
471 |
+
return True
|
472 |
+
|
473 |
+
|
474 |
+
torch.set_float32_matmul_precision("high")
|
475 |
+
warnings.filterwarnings('ignore')
|
476 |
+
args = option.get_args_parser()
|
477 |
+
|
478 |
+
conv_mode = "llava_v1"
|
479 |
+
model_path = 'LanguageBind/Video-LLaVA-7B'
|
480 |
+
device = 'cuda'
|
481 |
+
load_8bit = False
|
482 |
+
load_4bit = True
|
483 |
+
dtype = torch.float16
|
484 |
+
|
485 |
+
if not os.path.exists("temp"):
|
486 |
+
os.makedirs("temp")
|
487 |
+
|
488 |
+
lora_path = Path(args.lora_path)
|
489 |
+
pretrained_llm_path = Path(f"./checkpoints/vicuna-7b-v1.5/lit_model.pth")
|
490 |
+
tokenizer_llm_path = Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model")
|
491 |
+
|
492 |
+
# assert lora_path.is_file()
|
493 |
+
assert pretrained_llm_path.is_file()
|
494 |
+
assert tokenizer_llm_path.is_file()
|
495 |
+
|
496 |
+
accelerator = "auto"
|
497 |
+
fabric = L.Fabric(accelerator=accelerator, devices=1)
|
498 |
+
|
499 |
+
dtype = "float32"
|
500 |
+
dt = getattr(torch, dtype, None)
|
501 |
+
if not isinstance(dt, torch.dtype):
|
502 |
+
raise ValueError(f"{dtype} is not a valid dtype.")
|
503 |
+
dtype = dt
|
504 |
+
|
505 |
+
quantize = None
|
506 |
+
t0 = time.time()
|
507 |
+
|
508 |
+
with EmptyInitOnDevice(
|
509 |
+
device=fabric.device, dtype=dtype, quantization_mode=quantize
|
510 |
+
), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True):
|
511 |
+
checkpoint_dir = Path("checkpoints/vicuna-7b-v1.5")
|
512 |
+
lora_query = True
|
513 |
+
lora_key = False
|
514 |
+
lora_value = True
|
515 |
+
lora_projection = False
|
516 |
+
lora_mlp = False
|
517 |
+
lora_head = False
|
518 |
+
config = Config.from_name(
|
519 |
+
name=checkpoint_dir.name,
|
520 |
+
r=args.lora_r,
|
521 |
+
alpha=args.lora_alpha,
|
522 |
+
dropout=args.lora_dropout,
|
523 |
+
to_query=lora_query,
|
524 |
+
to_key=lora_key,
|
525 |
+
to_value=lora_value,
|
526 |
+
to_projection=lora_projection,
|
527 |
+
to_mlp=lora_mlp,
|
528 |
+
to_head=lora_head,
|
529 |
+
)
|
530 |
+
model = GPT(config).bfloat16()
|
531 |
+
|
532 |
+
mlp_path = args.mlp_path
|
533 |
+
pretrained_checkpoint_mlp = torch.load(mlp_path)
|
534 |
+
|
535 |
+
X = ['Video']
|
536 |
+
|
537 |
+
mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, model_path = 'LanguageBind/Video-LLaVA-7B')
|
538 |
+
video_processor = processor['video']
|
539 |
+
|
540 |
+
linear_proj = mm_backbone_mlp_model.mm_projector
|
541 |
+
|
542 |
+
# 1. Load the pretrained weights
|
543 |
+
pretrained_llm_checkpoint = lazy_load(pretrained_llm_path)
|
544 |
+
# 2. Load the fine-tuned LoRA weights
|
545 |
+
lora_checkpoint = lazy_load(lora_path)
|
546 |
+
# 3. merge the two checkpoints
|
547 |
+
model_state_dict = {**pretrained_llm_checkpoint, **lora_checkpoint}
|
548 |
+
model.load_state_dict(model_state_dict, strict=True)
|
549 |
+
print('Load llm base model from', pretrained_llm_path)
|
550 |
+
print('Load lora model from', lora_path)
|
551 |
+
|
552 |
+
# load mlp again, to en sure, not neccessary actually
|
553 |
+
linear_proj.load_state_dict(pretrained_checkpoint_mlp)
|
554 |
+
linear_proj = linear_proj.cuda()
|
555 |
+
print('Load mlp model again from', mlp_path)
|
556 |
+
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
557 |
+
|
558 |
+
model.eval()
|
559 |
+
model = fabric.setup_module(model)
|
560 |
+
linear_proj.eval()
|
561 |
+
|
562 |
+
tokenizer = Tokenizer(tokenizer_llm_path)
|
563 |
+
print('Load tokenizer from', tokenizer_llm_path)
|
564 |
+
|
565 |
+
print(torch.cuda.memory_allocated())
|
566 |
+
print(torch.cuda.max_memory_allocated())
|
567 |
+
|
568 |
+
|
569 |
+
app = FastAPI()
|
570 |
+
|
571 |
+
textbox = gr.Textbox(
|
572 |
+
show_label=False, placeholder="Enter text and press ENTER", container=False
|
573 |
+
)
|
574 |
+
|
575 |
+
with gr.Blocks(title='MotionLLM', theme=gr.themes.Default(), css=block_css) as demo:
|
576 |
+
gr.Markdown(title_markdown)
|
577 |
+
state = gr.State()
|
578 |
+
buffer_ = gr.State()
|
579 |
+
first_run = gr.State()
|
580 |
+
images_tensor = gr.State()
|
581 |
+
|
582 |
+
with gr.Row():
|
583 |
+
with gr.Column(scale=3):
|
584 |
+
image1 = gr.State()
|
585 |
+
video = gr.Video(label="Input Video")
|
586 |
+
|
587 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
588 |
+
gr.Examples(
|
589 |
+
examples=[
|
590 |
+
[
|
591 |
+
f"{cur_dir}/examples/Play_Electric_guitar_16_clip1.mp4",
|
592 |
+
"why is the girl so happy",
|
593 |
+
],
|
594 |
+
[
|
595 |
+
f"{cur_dir}/examples/guoyoucai.mov",
|
596 |
+
"what is the feeling of him",
|
597 |
+
],
|
598 |
+
[
|
599 |
+
f"{cur_dir}/examples/sprint_run_18_clip1.mp4",
|
600 |
+
"Why is the man running so fast?",
|
601 |
+
],
|
602 |
+
[
|
603 |
+
f"{cur_dir}/examples/lift_weight.mp4",
|
604 |
+
"Assume you are a fitness coach, refer to the video of the professional athlete, please analyze specific action essentials in steps and give detailed instruction.",
|
605 |
+
],
|
606 |
+
[
|
607 |
+
f"{cur_dir}/examples/Shaolin_Kung_Fu_Wushu_Selfdefense_Sword_Form_Session_22_clip3.mp4",
|
608 |
+
"wow, can you teach me the motion, step by step in detail",
|
609 |
+
],
|
610 |
+
[
|
611 |
+
f"{cur_dir}/examples/mabaoguo.mp4",
|
612 |
+
"why is the video funny?",
|
613 |
+
],
|
614 |
+
[
|
615 |
+
f"{cur_dir}/examples/COBRA_PUSH_UPS_clip2.mp4",
|
616 |
+
"describe the body movement of the woman",
|
617 |
+
],
|
618 |
+
[
|
619 |
+
f"{cur_dir}/examples/sample_demo_1.mp4",
|
620 |
+
"Why is this video interesting?",
|
621 |
+
],
|
622 |
+
],
|
623 |
+
inputs=[video, textbox],
|
624 |
+
)
|
625 |
+
|
626 |
+
with gr.Column(scale=7):
|
627 |
+
chatbot = gr.Chatbot(label="MotionLLM", bubble_full_width=True).style(height=875)
|
628 |
+
with gr.Row():
|
629 |
+
with gr.Column(scale=8):
|
630 |
+
textbox.render()
|
631 |
+
with gr.Column(scale=1, min_width=50):
|
632 |
+
submit_btn = gr.Button(
|
633 |
+
value="Send", variant="primary", interactive=True
|
634 |
+
)
|
635 |
+
with gr.Row(elem_id="buttons") as button_row:
|
636 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
637 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
|
638 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
|
639 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
640 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
641 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
642 |
+
|
643 |
+
gr.Markdown(tos_markdown)
|
644 |
+
gr.Markdown(learn_more_markdown)
|
645 |
+
|
646 |
+
tmp = gr.State()
|
647 |
+
upvote_btn.click(logging_up, [video, state], [tmp])
|
648 |
+
|
649 |
+
downvote_btn.click(logging_down, [video, state], [tmp])
|
650 |
+
|
651 |
+
submit_btn.click(generate, [image1, video, textbox, first_run, state, images_tensor],
|
652 |
+
[state, chatbot, first_run, textbox, images_tensor, image1, video])
|
653 |
+
|
654 |
+
regenerate_btn.click(regenerate, [state], [state, chatbot, first_run, textbox]).then(
|
655 |
+
generate, [image1, video, textbox, first_run, state, images_tensor], [state, chatbot, first_run, textbox, images_tensor, image1, video])
|
656 |
+
|
657 |
+
clear_btn.click(clear_history, [state],
|
658 |
+
[image1, video, textbox, first_run, state, chatbot, images_tensor])
|
659 |
+
|
660 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
661 |
+
uvicorn.run(app, host="0.0.0.0", port=6657)
|
assets/application.png
ADDED
assets/compare.png
ADDED
assets/highlight.png
ADDED
assets/logo.png
ADDED
assets/system.png
ADDED
generate.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import time
|
3 |
+
import warnings
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import lightning as L
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from lit_llama import LLaMA, Tokenizer
|
11 |
+
from lit_llama.utils import EmptyInitOnDevice, lazy_load
|
12 |
+
|
13 |
+
|
14 |
+
@torch.no_grad()
|
15 |
+
def generate(
|
16 |
+
model: torch.nn.Module,
|
17 |
+
idx: torch.Tensor,
|
18 |
+
max_new_tokens: int,
|
19 |
+
max_seq_length: int,
|
20 |
+
temperature: float = 1.0,
|
21 |
+
top_k: Optional[int] = None,
|
22 |
+
eos_id: Optional[int] = None,
|
23 |
+
tokenizer = None,
|
24 |
+
) -> torch.Tensor:
|
25 |
+
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
26 |
+
|
27 |
+
The implementation of this function is modified from A. Karpathy's nanoGPT.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
model: The model to use.
|
31 |
+
idx: Tensor of shape (T) with indices of the prompt sequence.
|
32 |
+
max_new_tokens: The number of new tokens to generate.
|
33 |
+
max_seq_length: The maximum sequence length allowed.
|
34 |
+
temperature: Scales the predicted logits by 1 / temperature
|
35 |
+
top_k: If specified, only sample among the tokens with the k highest probabilities
|
36 |
+
eos_id: If specified, stop generating any more token once the <eos> token is triggered
|
37 |
+
"""
|
38 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
39 |
+
# import pdb; pdb.set_trace()
|
40 |
+
if type(idx) == tuple:
|
41 |
+
# import pdb; pdb.set_trace()
|
42 |
+
T = idx[0].shape[-1] + idx[2].shape[-1] + len(idx[1])
|
43 |
+
before_len = idx[0].shape[-1]
|
44 |
+
catted = torch.cat((idx[0], torch.zeros((1, len(idx[1]))).cuda(), idx[2]), dim=1).long()
|
45 |
+
idx = (catted, idx[1], before_len)
|
46 |
+
T_new = T + max_new_tokens
|
47 |
+
# import pdb; pdb.set_trace()
|
48 |
+
empty = torch.empty(T_new, dtype=idx[0].dtype, device=idx[0].device)
|
49 |
+
empty = torch.empty(T_new, dtype=idx[0].dtype, device=idx[0].device)
|
50 |
+
empty[:T] = idx[0]
|
51 |
+
idx = (empty, idx[1], [before_len])
|
52 |
+
# import pdb; pdb.set_trace()
|
53 |
+
else:
|
54 |
+
# import pdb; pdb.set_trace()
|
55 |
+
T = idx.size(0)
|
56 |
+
T_new = T + max_new_tokens
|
57 |
+
empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
|
58 |
+
empty[:T] = idx
|
59 |
+
idx = empty
|
60 |
+
|
61 |
+
# generate max_new_tokens tokens
|
62 |
+
# import pdb; pdb.set_trace()
|
63 |
+
for t in range(T, T_new):
|
64 |
+
if type(idx) == tuple:
|
65 |
+
idx_cond = idx[0][:t]
|
66 |
+
tmp = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]
|
67 |
+
# import pdb; pdb.set_trace()
|
68 |
+
idx_cond = (tmp.view(1, -1), idx[1].unsqueeze(0), idx[2])
|
69 |
+
else:
|
70 |
+
# ignore the not-filled-yet tokens
|
71 |
+
idx_cond = idx[:t]
|
72 |
+
# if the sequence context is growing too long we must crop it at max_seq_length
|
73 |
+
idx_cond = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]
|
74 |
+
|
75 |
+
# forward
|
76 |
+
if type(idx) == tuple:
|
77 |
+
logits = model(idx_cond, maxlen=idx_cond[0].size(1))
|
78 |
+
else:
|
79 |
+
logits = model(idx_cond.view(1, -1))
|
80 |
+
logits = logits[0, -1] / temperature
|
81 |
+
|
82 |
+
# import pdb; pdb.set_trace()
|
83 |
+
# optionally crop the logits to only the top k options
|
84 |
+
if top_k is not None:
|
85 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
86 |
+
logits[logits < v[[-1]]] = -float("Inf")
|
87 |
+
|
88 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
89 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
90 |
+
|
91 |
+
# concatenate the new generation
|
92 |
+
if type(idx) == tuple:
|
93 |
+
seq = idx[0]
|
94 |
+
seq[t] = idx_next
|
95 |
+
idx = (seq, idx[1], idx[2])
|
96 |
+
else:
|
97 |
+
idx[t] = idx_next
|
98 |
+
|
99 |
+
# if <eos> token is triggered, return the output (stop generation)
|
100 |
+
if idx_next == eos_id:
|
101 |
+
if type(idx) == tuple:
|
102 |
+
return idx[0][:t+1]
|
103 |
+
else:
|
104 |
+
return idx[:t + 1] # include the EOS token
|
105 |
+
if type(idx) == tuple:
|
106 |
+
return idx[0]
|
107 |
+
else:
|
108 |
+
return idx
|
109 |
+
|
110 |
+
|
111 |
+
def main(
|
112 |
+
prompt: str = "Hello, my name is",
|
113 |
+
*,
|
114 |
+
num_samples: int = 1,
|
115 |
+
max_new_tokens: int = 50,
|
116 |
+
top_k: int = 200,
|
117 |
+
temperature: float = 0.8,
|
118 |
+
checkpoint_path: Optional[Path] = None,
|
119 |
+
tokenizer_path: Optional[Path] = None,
|
120 |
+
model_size: str = "7B",
|
121 |
+
quantize: Optional[str] = None,
|
122 |
+
) -> None:
|
123 |
+
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
prompt: The prompt string to use for generating the samples.
|
127 |
+
num_samples: The number of text samples to generate.
|
128 |
+
max_new_tokens: The number of generation steps to take.
|
129 |
+
top_k: The number of top most probable tokens to consider in the sampling process.
|
130 |
+
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
|
131 |
+
samples.
|
132 |
+
checkpoint_path: The checkpoint path to load.
|
133 |
+
tokenizer_path: The tokenizer path to load.
|
134 |
+
model_size: The model size to load.
|
135 |
+
quantize: Whether to quantize the model and using which method:
|
136 |
+
``"llm.int8"``: LLM.int8() mode,
|
137 |
+
``"gptq.int4"``: GPTQ 4-bit mode.
|
138 |
+
"""
|
139 |
+
if not checkpoint_path:
|
140 |
+
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
|
141 |
+
if not tokenizer_path:
|
142 |
+
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
|
143 |
+
assert checkpoint_path.is_file(), checkpoint_path
|
144 |
+
assert tokenizer_path.is_file(), tokenizer_path
|
145 |
+
|
146 |
+
fabric = L.Fabric(accelerator="cuda", devices=1)
|
147 |
+
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
148 |
+
|
149 |
+
print("Loading model ...", file=sys.stderr)
|
150 |
+
t0 = time.time()
|
151 |
+
with EmptyInitOnDevice(
|
152 |
+
device=fabric.device, dtype=dtype, quantization_mode=quantize
|
153 |
+
):
|
154 |
+
model = LLaMA.from_name(model_size)
|
155 |
+
|
156 |
+
checkpoint = lazy_load(checkpoint_path)
|
157 |
+
model.load_state_dict(checkpoint)
|
158 |
+
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
159 |
+
|
160 |
+
model.eval()
|
161 |
+
model = fabric.setup_module(model)
|
162 |
+
|
163 |
+
tokenizer = Tokenizer(tokenizer_path)
|
164 |
+
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
|
165 |
+
|
166 |
+
L.seed_everything(1234)
|
167 |
+
t0 = time.perf_counter()
|
168 |
+
|
169 |
+
for _ in range(num_samples):
|
170 |
+
y = generate(
|
171 |
+
model,
|
172 |
+
encoded_prompt,
|
173 |
+
max_new_tokens,
|
174 |
+
model.config.block_size, # type: ignore[union-attr,arg-type]
|
175 |
+
temperature=temperature,
|
176 |
+
top_k=top_k,
|
177 |
+
)
|
178 |
+
print(tokenizer.decode(y))
|
179 |
+
|
180 |
+
t = time.perf_counter() - t0
|
181 |
+
print(f"\n\nTime for inference: {t:.02f} sec total, {num_samples * max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
|
182 |
+
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == "__main__":
|
186 |
+
from jsonargparse import CLI
|
187 |
+
|
188 |
+
torch.set_float32_matmul_precision("high")
|
189 |
+
warnings.filterwarnings(
|
190 |
+
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
|
191 |
+
"ignore",
|
192 |
+
message="ComplexHalf support is experimental and many operators don't support it yet"
|
193 |
+
)
|
194 |
+
warnings.filterwarnings(
|
195 |
+
# Triggered in bitsandbytes/autograd/_functions.py:298
|
196 |
+
"ignore",
|
197 |
+
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
|
198 |
+
)
|
199 |
+
CLI(main)
|
lit_gpt/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lit_gpt.model import GPT
|
2 |
+
from lit_gpt.config import Config
|
3 |
+
from lit_gpt.tokenizer import Tokenizer
|
4 |
+
|
5 |
+
from lightning_utilities.core.imports import RequirementCache
|
6 |
+
|
7 |
+
_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
|
8 |
+
# if not bool(_LIGHTNING_AVAILABLE):
|
9 |
+
# raise ImportError(
|
10 |
+
# "Lit-GPT requires lightning==2.1. Please run:\n"
|
11 |
+
# f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
|
12 |
+
# )
|
13 |
+
|
14 |
+
|
15 |
+
__all__ = ["GPT", "Config", "Tokenizer"]
|
lit_gpt/adapter.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implementation of the paper:
|
2 |
+
|
3 |
+
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
|
4 |
+
https://arxiv.org/abs/2303.16199
|
5 |
+
|
6 |
+
Port for Lit-GPT
|
7 |
+
"""
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from typing_extensions import Self
|
14 |
+
|
15 |
+
from lit_gpt.config import Config as BaseConfig
|
16 |
+
from lit_gpt.model import GPT as BaseModel
|
17 |
+
from lit_gpt.model import Block as BaseBlock
|
18 |
+
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class Config(BaseConfig):
|
23 |
+
adapter_prompt_length: int = 10
|
24 |
+
adapter_start_layer: int = 2
|
25 |
+
|
26 |
+
|
27 |
+
class GPT(BaseModel):
|
28 |
+
"""The implementation is identical to `lit_gpt.model.GPT` with the exception that
|
29 |
+
the `Block` saves the layer index and passes it down to the attention layer."""
|
30 |
+
|
31 |
+
def __init__(self, config: Config) -> None:
|
32 |
+
nn.Module.__init__(self)
|
33 |
+
assert config.padded_vocab_size is not None
|
34 |
+
self.config = config
|
35 |
+
|
36 |
+
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
|
37 |
+
self.transformer = nn.ModuleDict(
|
38 |
+
dict(
|
39 |
+
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
40 |
+
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
|
41 |
+
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
42 |
+
)
|
43 |
+
)
|
44 |
+
self.max_seq_length = self.config.block_size
|
45 |
+
self.mask_cache: Optional[torch.Tensor] = None
|
46 |
+
|
47 |
+
def forward(
|
48 |
+
self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
|
49 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
50 |
+
T = idx.size(1)
|
51 |
+
if self.max_seq_length < T:
|
52 |
+
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
|
53 |
+
|
54 |
+
if input_pos is not None: # use the kv cache
|
55 |
+
cos = self.cos.index_select(0, input_pos)
|
56 |
+
sin = self.sin.index_select(0, input_pos)
|
57 |
+
if self.mask_cache is None:
|
58 |
+
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
59 |
+
mask = self.mask_cache.index_select(2, input_pos)
|
60 |
+
else:
|
61 |
+
cos = self.cos[:T]
|
62 |
+
sin = self.sin[:T]
|
63 |
+
mask = None
|
64 |
+
|
65 |
+
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
66 |
+
for block in self.transformer.h:
|
67 |
+
x = block(x, cos, sin, mask, input_pos)
|
68 |
+
x = self.transformer.ln_f(x)
|
69 |
+
if lm_head_chunk_size > 0:
|
70 |
+
# chunk the lm head logits to reduce the peak memory used by autograd
|
71 |
+
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
|
72 |
+
return self.lm_head(x) # (b, t, vocab_size)
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
76 |
+
return cls(Config.from_name(name, **kwargs))
|
77 |
+
|
78 |
+
def _init_weights(self, module: nn.Module) -> None:
|
79 |
+
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
|
80 |
+
super()._init_weights(module)
|
81 |
+
if isinstance(module, CausalSelfAttention):
|
82 |
+
module.reset_parameters()
|
83 |
+
|
84 |
+
|
85 |
+
class Block(BaseBlock):
|
86 |
+
"""The implementation is identical to `lit_gpt.model.Block` with the exception that
|
87 |
+
we replace the attention layer where adaption is implemented."""
|
88 |
+
|
89 |
+
def __init__(self, config: Config, block_idx: int) -> None:
|
90 |
+
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
|
91 |
+
nn.Module.__init__(self)
|
92 |
+
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
93 |
+
self.attn = CausalSelfAttention(config, block_idx)
|
94 |
+
if not config.shared_attention_norm:
|
95 |
+
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
96 |
+
self.mlp = config.mlp_class(config)
|
97 |
+
|
98 |
+
self.config = config
|
99 |
+
|
100 |
+
|
101 |
+
class CausalSelfAttention(BaseCausalSelfAttention):
|
102 |
+
"""A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
|
103 |
+
over the adaption prompt."""
|
104 |
+
|
105 |
+
def __init__(self, config: Config, block_idx: int) -> None:
|
106 |
+
super().__init__(config)
|
107 |
+
if block_idx >= config.adapter_start_layer:
|
108 |
+
# adapter embedding layer
|
109 |
+
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
|
110 |
+
# gate for adaption
|
111 |
+
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
|
112 |
+
# kv cache for inference
|
113 |
+
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
114 |
+
self.block_idx = block_idx
|
115 |
+
|
116 |
+
def scaled_dot_product_attention(
|
117 |
+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
|
118 |
+
) -> torch.Tensor:
|
119 |
+
y = super().scaled_dot_product_attention(q, k, v, mask)
|
120 |
+
if self.block_idx < self.config.adapter_start_layer:
|
121 |
+
return y
|
122 |
+
|
123 |
+
aT = self.config.adapter_prompt_length
|
124 |
+
if self.adapter_kv_cache is not None:
|
125 |
+
# since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
|
126 |
+
# are the same every call
|
127 |
+
ak, av = self.adapter_kv_cache
|
128 |
+
else:
|
129 |
+
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
|
130 |
+
aqkv = self.attn(prefix)
|
131 |
+
q_per_kv = self.config.n_head // self.config.n_query_groups
|
132 |
+
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
|
133 |
+
aqkv = aqkv.permute(0, 2, 3, 1, 4)
|
134 |
+
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
|
135 |
+
if self.config.n_query_groups != 1:
|
136 |
+
# for MHA this is a no-op
|
137 |
+
ak = ak.repeat_interleave(q_per_kv, dim=2)
|
138 |
+
av = av.repeat_interleave(q_per_kv, dim=2)
|
139 |
+
ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
|
140 |
+
av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
|
141 |
+
self.adapter_kv_cache = (ak, av)
|
142 |
+
|
143 |
+
T = q.size(2)
|
144 |
+
amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
|
145 |
+
ay = super().scaled_dot_product_attention(q, ak, av, amask)
|
146 |
+
return y + self.gating_factor * ay
|
147 |
+
|
148 |
+
def reset_parameters(self) -> None:
|
149 |
+
torch.nn.init.zeros_(self.gating_factor)
|
150 |
+
|
151 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
152 |
+
"""For compatibility with older checkpoints."""
|
153 |
+
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
|
154 |
+
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
|
155 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
156 |
+
|
157 |
+
|
158 |
+
def mark_only_adapter_as_trainable(model: GPT) -> None:
|
159 |
+
"""Sets `requires_grad=False` for all non-adapter weights."""
|
160 |
+
for name, param in model.named_parameters():
|
161 |
+
param.requires_grad = adapter_filter(name, param)
|
162 |
+
|
163 |
+
|
164 |
+
def adapter_filter(key: str, value: Any) -> bool:
|
165 |
+
return "adapter_wte" in key or "gating_factor" in key
|
lit_gpt/adapter_v2.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implementation of the paper:
|
2 |
+
|
3 |
+
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
|
4 |
+
https://arxiv.org/abs/2304.15010
|
5 |
+
|
6 |
+
Port for Lit-GPT
|
7 |
+
"""
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Any, Dict, Optional, Tuple, Type
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from typing_extensions import Self
|
14 |
+
|
15 |
+
import lit_gpt
|
16 |
+
from lit_gpt.adapter import GPT as BaseModel
|
17 |
+
from lit_gpt.adapter import Block as BaseBlock
|
18 |
+
from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
|
19 |
+
from lit_gpt.adapter import Config as BaseConfig
|
20 |
+
from lit_gpt.model import KVCache
|
21 |
+
from lit_gpt.utils import map_old_state_dict_weights
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class Config(BaseConfig):
|
26 |
+
@property
|
27 |
+
def mlp_class(self) -> Type:
|
28 |
+
return getattr(lit_gpt.adapter_v2, self._mlp_class)
|
29 |
+
|
30 |
+
|
31 |
+
def adapter_filter(key: str, value: Any) -> bool:
|
32 |
+
adapter_substrings = (
|
33 |
+
# regular adapter v1 parameters
|
34 |
+
"adapter_wte",
|
35 |
+
"gating_factor",
|
36 |
+
# adapter v2: new bias and scale used in Linear
|
37 |
+
"adapter_scale",
|
38 |
+
"adapter_bias",
|
39 |
+
# adapter v2: Norm parameters are now trainable
|
40 |
+
"norm_1",
|
41 |
+
"norm_2",
|
42 |
+
"ln_f",
|
43 |
+
)
|
44 |
+
return any(s in key for s in adapter_substrings)
|
45 |
+
|
46 |
+
|
47 |
+
class AdapterV2Linear(torch.nn.Module):
|
48 |
+
def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
|
49 |
+
super().__init__()
|
50 |
+
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
|
51 |
+
self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
|
52 |
+
self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)
|
53 |
+
|
54 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
55 |
+
return self.adapter_scale * (self.linear(x) + self.adapter_bias)
|
56 |
+
|
57 |
+
def reset_parameters(self) -> None:
|
58 |
+
nn.init.zeros_(self.adapter_bias)
|
59 |
+
nn.init.ones_(self.adapter_scale)
|
60 |
+
|
61 |
+
|
62 |
+
class GPT(BaseModel):
|
63 |
+
def __init__(self, config: Config) -> None:
|
64 |
+
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
|
65 |
+
nn.Module.__init__(self)
|
66 |
+
assert config.padded_vocab_size is not None
|
67 |
+
self.config = config
|
68 |
+
|
69 |
+
self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
|
70 |
+
self.transformer = nn.ModuleDict(
|
71 |
+
dict(
|
72 |
+
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
73 |
+
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
|
74 |
+
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
75 |
+
)
|
76 |
+
)
|
77 |
+
self.max_seq_length = self.config.block_size
|
78 |
+
self.mask_cache: Optional[torch.Tensor] = None
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
82 |
+
return cls(Config.from_name(name, **kwargs))
|
83 |
+
|
84 |
+
def _init_weights(self, module: nn.Module) -> None:
|
85 |
+
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
|
86 |
+
super()._init_weights(module)
|
87 |
+
if isinstance(module, AdapterV2Linear):
|
88 |
+
module.reset_parameters()
|
89 |
+
|
90 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
91 |
+
"""For compatibility with base checkpoints."""
|
92 |
+
mapping = {"lm_head.weight": "lm_head.linear.weight"}
|
93 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
94 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
95 |
+
|
96 |
+
|
97 |
+
class Block(BaseBlock):
|
98 |
+
"""The implementation is identical to `lit_gpt.model.Block` with the exception that
|
99 |
+
we replace the attention layer where adaption is implemented."""
|
100 |
+
|
101 |
+
def __init__(self, config: Config, block_idx: int) -> None:
|
102 |
+
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
|
103 |
+
nn.Module.__init__(self)
|
104 |
+
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
105 |
+
self.attn = CausalSelfAttention(config, block_idx)
|
106 |
+
if not config.shared_attention_norm:
|
107 |
+
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
108 |
+
self.mlp = config.mlp_class(config)
|
109 |
+
|
110 |
+
self.config = config
|
111 |
+
|
112 |
+
|
113 |
+
class CausalSelfAttention(BaseCausalSelfAttention):
|
114 |
+
"""A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
|
115 |
+
|
116 |
+
def __init__(self, config: Config, block_idx: int) -> None:
|
117 |
+
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
|
118 |
+
nn.Module.__init__(self)
|
119 |
+
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
|
120 |
+
# key, query, value projections for all heads, but in a batch
|
121 |
+
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
|
122 |
+
# output projection
|
123 |
+
self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
|
124 |
+
# disabled by default
|
125 |
+
self.kv_cache: Optional[KVCache] = None
|
126 |
+
|
127 |
+
if block_idx >= config.adapter_start_layer:
|
128 |
+
# adapter embedding layer
|
129 |
+
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
|
130 |
+
# gate for adaption
|
131 |
+
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
|
132 |
+
# kv cache for inference
|
133 |
+
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
134 |
+
self.block_idx = block_idx
|
135 |
+
|
136 |
+
self.config = config
|
137 |
+
|
138 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
139 |
+
"""For compatibility with base checkpoints."""
|
140 |
+
mapping = {
|
141 |
+
"attn.weight": "attn.linear.weight",
|
142 |
+
"attn.bias": "attn.linear.bias",
|
143 |
+
"proj.weight": "proj.linear.weight",
|
144 |
+
"proj.bias": "proj.linear.bias",
|
145 |
+
}
|
146 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
147 |
+
# For compatibility with older checkpoints
|
148 |
+
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
|
149 |
+
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
|
150 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
151 |
+
|
152 |
+
|
153 |
+
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
|
154 |
+
def __init__(self, config: Config) -> None:
|
155 |
+
nn.Module.__init__(self)
|
156 |
+
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
157 |
+
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
158 |
+
|
159 |
+
self.config = config
|
160 |
+
|
161 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
162 |
+
"""For compatibility with base checkpoints."""
|
163 |
+
mapping = {
|
164 |
+
"fc.weight": "fc.linear.weight",
|
165 |
+
"fc.bias": "fc.linear.bias",
|
166 |
+
"proj.weight": "proj.linear.weight",
|
167 |
+
"proj.bias": "proj.linear.bias",
|
168 |
+
}
|
169 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
170 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
171 |
+
|
172 |
+
|
173 |
+
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
|
174 |
+
def __init__(self, config: Config) -> None:
|
175 |
+
nn.Module.__init__(self)
|
176 |
+
self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
177 |
+
self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
178 |
+
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
179 |
+
|
180 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
181 |
+
"""For compatibility with base checkpoints."""
|
182 |
+
mapping = {
|
183 |
+
"fc_1.weight": "fc_1.linear.weight",
|
184 |
+
"fc_1.bias": "fc_1.linear.bias",
|
185 |
+
"fc_2.weight": "fc_2.linear.weight",
|
186 |
+
"fc_2.bias": "fc_2.linear.bias",
|
187 |
+
"proj.weight": "proj.linear.weight",
|
188 |
+
"proj.bias": "proj.linear.bias",
|
189 |
+
}
|
190 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
191 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
192 |
+
|
193 |
+
|
194 |
+
def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
|
195 |
+
"""Sets requires_grad=False for all non-adapter weights"""
|
196 |
+
for name, param in model.named_parameters():
|
197 |
+
param.requires_grad = adapter_filter(name, param)
|
lit_gpt/config.py
ADDED
@@ -0,0 +1,1040 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Any, Literal, Optional, Type, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from typing_extensions import Self
|
8 |
+
|
9 |
+
import lit_gpt.model
|
10 |
+
from lit_gpt.utils import find_multiple
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class Config:
|
15 |
+
org: str = "Lightning-AI"
|
16 |
+
name: str = "lit-GPT"
|
17 |
+
block_size: int = 4096
|
18 |
+
vocab_size: int = 50254
|
19 |
+
padding_multiple: int = 512
|
20 |
+
padded_vocab_size: Optional[int] = None
|
21 |
+
n_layer: int = 16
|
22 |
+
n_head: int = 32
|
23 |
+
n_embd: int = 4096
|
24 |
+
rotary_percentage: float = 0.25
|
25 |
+
parallel_residual: bool = True
|
26 |
+
bias: bool = True
|
27 |
+
lm_head_bias: bool = False
|
28 |
+
# to use multi-head attention (MHA), set this to `n_head` (default)
|
29 |
+
# to use multi-query attention (MQA), set this to 1
|
30 |
+
# to use grouped-query attention (GQA), set this to a value in between
|
31 |
+
# Example with `n_head=4`
|
32 |
+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
33 |
+
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
|
34 |
+
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
35 |
+
# │ │ │ │ │ │ │
|
36 |
+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
37 |
+
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
|
38 |
+
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
39 |
+
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
|
40 |
+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
|
41 |
+
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
|
42 |
+
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
|
43 |
+
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
|
44 |
+
# MHA GQA MQA
|
45 |
+
# n_query_groups=4 n_query_groups=2 n_query_groups=1
|
46 |
+
#
|
47 |
+
# credit https://arxiv.org/pdf/2305.13245.pdf
|
48 |
+
n_query_groups: Optional[int] = None
|
49 |
+
shared_attention_norm: bool = False
|
50 |
+
_norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
|
51 |
+
norm_eps: float = 1e-5
|
52 |
+
_mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
|
53 |
+
gelu_approximate: str = "none"
|
54 |
+
intermediate_size: Optional[int] = None
|
55 |
+
rope_condense_ratio: int = 1
|
56 |
+
rope_base: int = 10000
|
57 |
+
|
58 |
+
def __post_init__(self):
|
59 |
+
assert self.n_embd % self.n_head == 0
|
60 |
+
self.head_size = self.n_embd // self.n_head
|
61 |
+
|
62 |
+
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
|
63 |
+
if self.padded_vocab_size is None:
|
64 |
+
self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
|
65 |
+
else:
|
66 |
+
# vocab size shouldn't be larger than padded vocab size
|
67 |
+
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
|
68 |
+
|
69 |
+
# compute the number of query groups
|
70 |
+
if self.n_query_groups is not None:
|
71 |
+
assert self.n_head % self.n_query_groups == 0
|
72 |
+
else:
|
73 |
+
self.n_query_groups = self.n_head
|
74 |
+
|
75 |
+
# compute the intermediate size for MLP if not set
|
76 |
+
if self.intermediate_size is None:
|
77 |
+
if self._mlp_class == "LLaMAMLP":
|
78 |
+
raise ValueError("The config needs to set the `intermediate_size`")
|
79 |
+
self.intermediate_size = 4 * self.n_embd
|
80 |
+
|
81 |
+
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
|
82 |
+
|
83 |
+
@classmethod
|
84 |
+
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
85 |
+
conf_dict = name_to_config[name].copy()
|
86 |
+
if "condense_ratio" in kwargs: # legacy name
|
87 |
+
kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
|
88 |
+
conf_dict.update(kwargs)
|
89 |
+
return cls(**conf_dict)
|
90 |
+
|
91 |
+
@classmethod
|
92 |
+
def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
|
93 |
+
with open(path, encoding="utf-8") as fp:
|
94 |
+
json_kwargs = json.load(fp)
|
95 |
+
if "condense_ratio" in json_kwargs: # legacy name
|
96 |
+
json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
|
97 |
+
if "condense_ratio" in kwargs: # legacy name
|
98 |
+
kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
|
99 |
+
json_kwargs.update(kwargs)
|
100 |
+
return cls(**json_kwargs)
|
101 |
+
|
102 |
+
@property
|
103 |
+
def mlp_class(self) -> Type:
|
104 |
+
# `self._mlp_class` cannot be the type to keep the config json serializable
|
105 |
+
return getattr(lit_gpt.model, self._mlp_class)
|
106 |
+
|
107 |
+
@property
|
108 |
+
def norm_class(self) -> Type:
|
109 |
+
# `self._norm_class` cannot be the type to keep the config json serializable
|
110 |
+
if self._norm_class == "RMSNorm":
|
111 |
+
from lit_gpt.rmsnorm import RMSNorm
|
112 |
+
|
113 |
+
return RMSNorm
|
114 |
+
return getattr(torch.nn, self._norm_class)
|
115 |
+
|
116 |
+
|
117 |
+
########################
|
118 |
+
# Stability AI StableLM
|
119 |
+
########################
|
120 |
+
configs = [
|
121 |
+
# https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
|
122 |
+
dict(org="stabilityai", name="stablelm-base-alpha-3b"),
|
123 |
+
# https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
|
124 |
+
dict(org="stabilityai", name="stablelm-base-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
|
125 |
+
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
|
126 |
+
dict(org="stabilityai", name="stablelm-tuned-alpha-3b", n_head=32),
|
127 |
+
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
|
128 |
+
dict(org="stabilityai", name="stablelm-tuned-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
|
129 |
+
]
|
130 |
+
|
131 |
+
####################
|
132 |
+
# EleutherAI Pythia
|
133 |
+
####################
|
134 |
+
pythia = [
|
135 |
+
# https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
|
136 |
+
dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128),
|
137 |
+
# https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
|
138 |
+
dict(
|
139 |
+
org="EleutherAI", name="pythia-160m", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128
|
140 |
+
),
|
141 |
+
# https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
|
142 |
+
dict(
|
143 |
+
org="EleutherAI", name="pythia-410m", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=128
|
144 |
+
),
|
145 |
+
# https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
|
146 |
+
dict(org="EleutherAI", name="pythia-1b", block_size=2048, n_embd=2048, n_head=8, padding_multiple=128),
|
147 |
+
# https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
|
148 |
+
dict(
|
149 |
+
org="EleutherAI", name="pythia-1.4b", block_size=2048, n_layer=24, n_embd=2048, n_head=16, padding_multiple=128
|
150 |
+
),
|
151 |
+
# https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
|
152 |
+
dict(org="EleutherAI", name="pythia-2.8b", block_size=2048, n_layer=32, n_embd=2560, padding_multiple=128),
|
153 |
+
# https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
|
154 |
+
dict(org="EleutherAI", name="pythia-6.9b", block_size=2048, n_layer=32, padding_multiple=256),
|
155 |
+
# https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
|
156 |
+
dict(org="EleutherAI", name="pythia-12b", block_size=2048, n_layer=36, n_embd=5120, n_head=40),
|
157 |
+
]
|
158 |
+
configs.extend(pythia)
|
159 |
+
for c in pythia:
|
160 |
+
copy = c.copy()
|
161 |
+
copy["name"] = f"{c['name']}-deduped"
|
162 |
+
configs.append(copy)
|
163 |
+
|
164 |
+
|
165 |
+
####################################
|
166 |
+
# togethercomputer RedPajama INCITE
|
167 |
+
####################################
|
168 |
+
redpajama_incite = [
|
169 |
+
# https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
|
170 |
+
dict(
|
171 |
+
org="togethercomputer",
|
172 |
+
name="RedPajama-INCITE-{}-3B-v1",
|
173 |
+
block_size=2048,
|
174 |
+
n_layer=32,
|
175 |
+
n_embd=2560,
|
176 |
+
padding_multiple=256,
|
177 |
+
rotary_percentage=1.0,
|
178 |
+
parallel_residual=False,
|
179 |
+
),
|
180 |
+
# https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
|
181 |
+
dict(
|
182 |
+
org="togethercomputer",
|
183 |
+
name="RedPajama-INCITE-7B-{}",
|
184 |
+
block_size=2048,
|
185 |
+
n_layer=32,
|
186 |
+
padding_multiple=256,
|
187 |
+
rotary_percentage=1.0,
|
188 |
+
parallel_residual=False,
|
189 |
+
),
|
190 |
+
# this redirects to the checkpoint above. kept for those who had the old weights already downloaded
|
191 |
+
dict(
|
192 |
+
org="togethercomputer",
|
193 |
+
name="RedPajama-INCITE-{}-7B-v0.1",
|
194 |
+
block_size=2048,
|
195 |
+
n_layer=32,
|
196 |
+
padding_multiple=256,
|
197 |
+
rotary_percentage=1.0,
|
198 |
+
parallel_residual=False,
|
199 |
+
),
|
200 |
+
]
|
201 |
+
for c in redpajama_incite:
|
202 |
+
for kind in ("Base", "Chat", "Instruct"):
|
203 |
+
copy = c.copy()
|
204 |
+
copy["name"] = c["name"].format(kind)
|
205 |
+
configs.append(copy)
|
206 |
+
|
207 |
+
|
208 |
+
#################
|
209 |
+
# TII UAE Falcon
|
210 |
+
#################
|
211 |
+
falcon = [
|
212 |
+
# https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
|
213 |
+
dict(
|
214 |
+
org="tiiuae",
|
215 |
+
name="falcon-7b{}",
|
216 |
+
block_size=2048,
|
217 |
+
vocab_size=65024,
|
218 |
+
padded_vocab_size=65024,
|
219 |
+
n_layer=32,
|
220 |
+
n_head=71,
|
221 |
+
n_embd=4544,
|
222 |
+
rotary_percentage=1.0,
|
223 |
+
n_query_groups=1,
|
224 |
+
bias=False,
|
225 |
+
# this is not in the config, but in the original model implementation, only for this config
|
226 |
+
shared_attention_norm=True,
|
227 |
+
),
|
228 |
+
# https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
|
229 |
+
dict(
|
230 |
+
org="tiiuae",
|
231 |
+
name="falcon-40b{}",
|
232 |
+
block_size=2048,
|
233 |
+
vocab_size=65024,
|
234 |
+
padded_vocab_size=65024,
|
235 |
+
n_layer=60,
|
236 |
+
n_head=128,
|
237 |
+
n_embd=8192,
|
238 |
+
rotary_percentage=1.0,
|
239 |
+
n_query_groups=8,
|
240 |
+
bias=False,
|
241 |
+
),
|
242 |
+
]
|
243 |
+
for c in falcon:
|
244 |
+
for kind in ("", "-instruct"):
|
245 |
+
copy = c.copy()
|
246 |
+
copy["name"] = c["name"].format(kind)
|
247 |
+
configs.append(copy)
|
248 |
+
|
249 |
+
# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
|
250 |
+
falcon180b = dict(
|
251 |
+
org="tiiuae",
|
252 |
+
name="falcon-180B{}",
|
253 |
+
block_size=2048,
|
254 |
+
vocab_size=65024,
|
255 |
+
padded_vocab_size=65024,
|
256 |
+
n_layer=80,
|
257 |
+
n_head=232,
|
258 |
+
n_embd=14848,
|
259 |
+
rotary_percentage=1.0,
|
260 |
+
n_query_groups=8,
|
261 |
+
bias=False,
|
262 |
+
)
|
263 |
+
|
264 |
+
for kind in ("", "-chat"):
|
265 |
+
copy = falcon180b.copy()
|
266 |
+
copy["name"] = falcon180b["name"].format(kind)
|
267 |
+
configs.append(copy)
|
268 |
+
|
269 |
+
|
270 |
+
#############################
|
271 |
+
# OpenLM Research Open LLaMA
|
272 |
+
#############################
|
273 |
+
open_LLaMA = [
|
274 |
+
# https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
|
275 |
+
dict(
|
276 |
+
org="openlm-research",
|
277 |
+
name="open_llama_3b",
|
278 |
+
block_size=2048,
|
279 |
+
vocab_size=32000,
|
280 |
+
padding_multiple=64,
|
281 |
+
n_layer=26,
|
282 |
+
n_embd=3200,
|
283 |
+
rotary_percentage=1.0,
|
284 |
+
parallel_residual=False,
|
285 |
+
bias=False,
|
286 |
+
_norm_class="RMSNorm",
|
287 |
+
norm_eps=1e-6,
|
288 |
+
_mlp_class="LLaMAMLP",
|
289 |
+
intermediate_size=8640,
|
290 |
+
),
|
291 |
+
# https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
|
292 |
+
dict(
|
293 |
+
org="openlm-research",
|
294 |
+
name="open_llama_7b",
|
295 |
+
block_size=2048,
|
296 |
+
vocab_size=32000,
|
297 |
+
padding_multiple=64,
|
298 |
+
n_layer=32,
|
299 |
+
rotary_percentage=1.0,
|
300 |
+
parallel_residual=False,
|
301 |
+
bias=False,
|
302 |
+
_norm_class="RMSNorm",
|
303 |
+
norm_eps=1e-6,
|
304 |
+
_mlp_class="LLaMAMLP",
|
305 |
+
intermediate_size=11008,
|
306 |
+
),
|
307 |
+
# https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
|
308 |
+
dict(
|
309 |
+
org="openlm-research",
|
310 |
+
name="open_llama_13b",
|
311 |
+
block_size=2048,
|
312 |
+
vocab_size=32000,
|
313 |
+
padding_multiple=64,
|
314 |
+
n_layer=40,
|
315 |
+
n_head=40,
|
316 |
+
n_embd=5120,
|
317 |
+
rotary_percentage=1.0,
|
318 |
+
parallel_residual=False,
|
319 |
+
bias=False,
|
320 |
+
_norm_class="RMSNorm",
|
321 |
+
norm_eps=1e-6,
|
322 |
+
_mlp_class="LLaMAMLP",
|
323 |
+
intermediate_size=13824,
|
324 |
+
),
|
325 |
+
]
|
326 |
+
configs.extend(open_LLaMA)
|
327 |
+
|
328 |
+
|
329 |
+
###############
|
330 |
+
# LMSYS Vicuna
|
331 |
+
###############
|
332 |
+
vicuna = [
|
333 |
+
# https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
|
334 |
+
dict(
|
335 |
+
org="lmsys",
|
336 |
+
name="vicuna-7b-v1.3",
|
337 |
+
block_size=2048,
|
338 |
+
vocab_size=32000,
|
339 |
+
padding_multiple=64,
|
340 |
+
n_layer=32,
|
341 |
+
rotary_percentage=1.0,
|
342 |
+
parallel_residual=False,
|
343 |
+
bias=False,
|
344 |
+
_norm_class="RMSNorm",
|
345 |
+
norm_eps=1e-6,
|
346 |
+
_mlp_class="LLaMAMLP",
|
347 |
+
intermediate_size=11008,
|
348 |
+
),
|
349 |
+
# https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
|
350 |
+
dict(
|
351 |
+
org="lmsys",
|
352 |
+
name="vicuna-13b-v1.3",
|
353 |
+
block_size=2048,
|
354 |
+
vocab_size=32000,
|
355 |
+
padding_multiple=64,
|
356 |
+
n_layer=40,
|
357 |
+
n_head=40,
|
358 |
+
n_embd=5120,
|
359 |
+
rotary_percentage=1.0,
|
360 |
+
parallel_residual=False,
|
361 |
+
bias=False,
|
362 |
+
_norm_class="RMSNorm",
|
363 |
+
norm_eps=1e-6,
|
364 |
+
_mlp_class="LLaMAMLP",
|
365 |
+
intermediate_size=13824,
|
366 |
+
),
|
367 |
+
# https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
|
368 |
+
dict(
|
369 |
+
org="lmsys",
|
370 |
+
name="vicuna-33b-v1.3",
|
371 |
+
block_size=2048,
|
372 |
+
vocab_size=32000,
|
373 |
+
padding_multiple=64,
|
374 |
+
n_layer=60,
|
375 |
+
n_head=52,
|
376 |
+
n_embd=6656,
|
377 |
+
rotary_percentage=1.0,
|
378 |
+
parallel_residual=False,
|
379 |
+
bias=False,
|
380 |
+
_norm_class="RMSNorm",
|
381 |
+
norm_eps=1e-6,
|
382 |
+
_mlp_class="LLaMAMLP",
|
383 |
+
intermediate_size=17920,
|
384 |
+
),
|
385 |
+
# https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
|
386 |
+
dict(
|
387 |
+
org="lmsys",
|
388 |
+
name="vicuna-7b-v1.5",
|
389 |
+
vocab_size=32000,
|
390 |
+
padding_multiple=64,
|
391 |
+
n_layer=32,
|
392 |
+
rotary_percentage=1.0,
|
393 |
+
parallel_residual=False,
|
394 |
+
bias=False,
|
395 |
+
_norm_class="RMSNorm",
|
396 |
+
_mlp_class="LLaMAMLP",
|
397 |
+
intermediate_size=11008,
|
398 |
+
),
|
399 |
+
# https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
|
400 |
+
dict(
|
401 |
+
org="lmsys",
|
402 |
+
name="vicuna-7b-v1.5-16k",
|
403 |
+
block_size=16384,
|
404 |
+
vocab_size=32000,
|
405 |
+
padding_multiple=64,
|
406 |
+
n_layer=32,
|
407 |
+
rotary_percentage=1.0,
|
408 |
+
parallel_residual=False,
|
409 |
+
bias=False,
|
410 |
+
_norm_class="RMSNorm",
|
411 |
+
_mlp_class="LLaMAMLP",
|
412 |
+
intermediate_size=11008,
|
413 |
+
rope_condense_ratio=4,
|
414 |
+
),
|
415 |
+
# https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
|
416 |
+
dict(
|
417 |
+
org="lmsys",
|
418 |
+
name="vicuna-13b-v1.5",
|
419 |
+
vocab_size=32000,
|
420 |
+
padding_multiple=64,
|
421 |
+
n_layer=40,
|
422 |
+
n_head=40,
|
423 |
+
n_embd=5120,
|
424 |
+
rotary_percentage=1.0,
|
425 |
+
parallel_residual=False,
|
426 |
+
bias=False,
|
427 |
+
_norm_class="RMSNorm",
|
428 |
+
_mlp_class="LLaMAMLP",
|
429 |
+
intermediate_size=13824,
|
430 |
+
),
|
431 |
+
# https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
|
432 |
+
dict(
|
433 |
+
org="lmsys",
|
434 |
+
name="vicuna-13b-v1.5-16k",
|
435 |
+
block_size=16384,
|
436 |
+
vocab_size=32000,
|
437 |
+
padding_multiple=64,
|
438 |
+
n_layer=40,
|
439 |
+
n_head=40,
|
440 |
+
n_embd=5120,
|
441 |
+
rotary_percentage=1.0,
|
442 |
+
parallel_residual=False,
|
443 |
+
bias=False,
|
444 |
+
_norm_class="RMSNorm",
|
445 |
+
_mlp_class="LLaMAMLP",
|
446 |
+
intermediate_size=13824,
|
447 |
+
rope_condense_ratio=4,
|
448 |
+
),
|
449 |
+
]
|
450 |
+
configs.extend(vicuna)
|
451 |
+
|
452 |
+
|
453 |
+
#################
|
454 |
+
# LMSYS LongChat
|
455 |
+
#################
|
456 |
+
long_chat = [
|
457 |
+
# https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
|
458 |
+
dict(
|
459 |
+
org="lmsys",
|
460 |
+
name="longchat-7b-16k",
|
461 |
+
block_size=16384,
|
462 |
+
vocab_size=32000,
|
463 |
+
padding_multiple=64,
|
464 |
+
n_layer=32,
|
465 |
+
rotary_percentage=1.0,
|
466 |
+
parallel_residual=False,
|
467 |
+
bias=False,
|
468 |
+
_norm_class="RMSNorm",
|
469 |
+
norm_eps=1e-6,
|
470 |
+
_mlp_class="LLaMAMLP",
|
471 |
+
intermediate_size=11008,
|
472 |
+
rope_condense_ratio=8,
|
473 |
+
),
|
474 |
+
# https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
|
475 |
+
dict(
|
476 |
+
org="lmsys",
|
477 |
+
name="longchat-13b-16k",
|
478 |
+
block_size=16384,
|
479 |
+
vocab_size=32000,
|
480 |
+
padding_multiple=64,
|
481 |
+
n_layer=40,
|
482 |
+
n_head=40,
|
483 |
+
n_embd=5120,
|
484 |
+
rotary_percentage=1.0,
|
485 |
+
parallel_residual=False,
|
486 |
+
bias=False,
|
487 |
+
_norm_class="RMSNorm",
|
488 |
+
norm_eps=1e-6,
|
489 |
+
_mlp_class="LLaMAMLP",
|
490 |
+
intermediate_size=13824,
|
491 |
+
rope_condense_ratio=8,
|
492 |
+
),
|
493 |
+
]
|
494 |
+
configs.extend(long_chat)
|
495 |
+
|
496 |
+
|
497 |
+
######################
|
498 |
+
# NousResearch Hermes
|
499 |
+
######################
|
500 |
+
nous_research = [
|
501 |
+
# https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
|
502 |
+
dict(
|
503 |
+
org="NousResearch",
|
504 |
+
name="Nous-Hermes-llama-2-7b",
|
505 |
+
padded_vocab_size=32000,
|
506 |
+
n_layer=32,
|
507 |
+
rotary_percentage=1.0,
|
508 |
+
parallel_residual=False,
|
509 |
+
bias=False,
|
510 |
+
_norm_class="RMSNorm",
|
511 |
+
norm_eps=1e-05,
|
512 |
+
_mlp_class="LLaMAMLP",
|
513 |
+
intermediate_size=11008,
|
514 |
+
),
|
515 |
+
# https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
|
516 |
+
dict(
|
517 |
+
org="NousResearch",
|
518 |
+
name="Nous-Hermes-13b",
|
519 |
+
block_size=2048,
|
520 |
+
vocab_size=32000,
|
521 |
+
padded_vocab_size=32001,
|
522 |
+
n_layer=40,
|
523 |
+
n_head=40,
|
524 |
+
n_embd=5120,
|
525 |
+
rotary_percentage=1.0,
|
526 |
+
parallel_residual=False,
|
527 |
+
bias=False,
|
528 |
+
_norm_class="RMSNorm",
|
529 |
+
norm_eps=1e-6,
|
530 |
+
_mlp_class="LLaMAMLP",
|
531 |
+
intermediate_size=13824,
|
532 |
+
),
|
533 |
+
# https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
|
534 |
+
dict(
|
535 |
+
org="NousResearch",
|
536 |
+
name="Nous-Hermes-Llama2-13b",
|
537 |
+
vocab_size=32000,
|
538 |
+
padded_vocab_size=32032,
|
539 |
+
n_layer=40,
|
540 |
+
n_head=40,
|
541 |
+
n_embd=5120,
|
542 |
+
rotary_percentage=1.0,
|
543 |
+
parallel_residual=False,
|
544 |
+
bias=False,
|
545 |
+
_norm_class="RMSNorm",
|
546 |
+
norm_eps=1e-05,
|
547 |
+
_mlp_class="LLaMAMLP",
|
548 |
+
intermediate_size=13824,
|
549 |
+
),
|
550 |
+
]
|
551 |
+
configs.extend(nous_research)
|
552 |
+
|
553 |
+
|
554 |
+
###############
|
555 |
+
# Meta LLaMA 2
|
556 |
+
###############
|
557 |
+
llama_2 = [
|
558 |
+
# https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
|
559 |
+
dict(
|
560 |
+
org="meta-llama",
|
561 |
+
name="Llama-2-7b{}-hf",
|
562 |
+
vocab_size=32000,
|
563 |
+
padding_multiple=64,
|
564 |
+
n_layer=32,
|
565 |
+
rotary_percentage=1.0,
|
566 |
+
parallel_residual=False,
|
567 |
+
bias=False,
|
568 |
+
_norm_class="RMSNorm",
|
569 |
+
_mlp_class="LLaMAMLP",
|
570 |
+
intermediate_size=11008,
|
571 |
+
),
|
572 |
+
# https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
|
573 |
+
dict(
|
574 |
+
org="meta-llama",
|
575 |
+
name="Llama-2-13b{}-hf",
|
576 |
+
vocab_size=32000,
|
577 |
+
padding_multiple=64,
|
578 |
+
n_layer=40,
|
579 |
+
n_head=40,
|
580 |
+
n_embd=5120,
|
581 |
+
rotary_percentage=1.0,
|
582 |
+
parallel_residual=False,
|
583 |
+
bias=False,
|
584 |
+
_norm_class="RMSNorm",
|
585 |
+
_mlp_class="LLaMAMLP",
|
586 |
+
intermediate_size=13824,
|
587 |
+
),
|
588 |
+
# https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
|
589 |
+
dict(
|
590 |
+
org="meta-llama",
|
591 |
+
name="Llama-2-70b{}-hf",
|
592 |
+
vocab_size=32000,
|
593 |
+
padding_multiple=64,
|
594 |
+
n_layer=80,
|
595 |
+
n_head=64,
|
596 |
+
n_embd=8192,
|
597 |
+
n_query_groups=8,
|
598 |
+
rotary_percentage=1.0,
|
599 |
+
parallel_residual=False,
|
600 |
+
bias=False,
|
601 |
+
_norm_class="RMSNorm",
|
602 |
+
_mlp_class="LLaMAMLP",
|
603 |
+
intermediate_size=28672,
|
604 |
+
),
|
605 |
+
]
|
606 |
+
for c in llama_2:
|
607 |
+
for kind in ("", "-chat"):
|
608 |
+
copy = c.copy()
|
609 |
+
copy["name"] = c["name"].format(kind)
|
610 |
+
configs.append(copy)
|
611 |
+
|
612 |
+
|
613 |
+
##########################
|
614 |
+
# Stability AI FreeWilly2
|
615 |
+
##########################
|
616 |
+
freewilly_2 = [
|
617 |
+
# https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
|
618 |
+
dict(
|
619 |
+
org="stabilityai",
|
620 |
+
name="FreeWilly2",
|
621 |
+
vocab_size=32000,
|
622 |
+
padding_multiple=64,
|
623 |
+
n_layer=80,
|
624 |
+
n_head=64,
|
625 |
+
n_embd=8192,
|
626 |
+
n_query_groups=8,
|
627 |
+
rotary_percentage=1.0,
|
628 |
+
parallel_residual=False,
|
629 |
+
bias=False,
|
630 |
+
_norm_class="RMSNorm",
|
631 |
+
_mlp_class="LLaMAMLP",
|
632 |
+
intermediate_size=28672,
|
633 |
+
)
|
634 |
+
]
|
635 |
+
configs.extend(freewilly_2)
|
636 |
+
|
637 |
+
|
638 |
+
##################
|
639 |
+
# Meta Code Llama
|
640 |
+
##################
|
641 |
+
code_llama = [
|
642 |
+
# https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
|
643 |
+
dict(
|
644 |
+
org="codellama",
|
645 |
+
name="CodeLlama-7b-hf",
|
646 |
+
block_size=16384,
|
647 |
+
vocab_size=32016,
|
648 |
+
padding_multiple=16,
|
649 |
+
n_layer=32,
|
650 |
+
rotary_percentage=1.0,
|
651 |
+
parallel_residual=False,
|
652 |
+
bias=False,
|
653 |
+
_norm_class="RMSNorm",
|
654 |
+
norm_eps=1e-05,
|
655 |
+
_mlp_class="LLaMAMLP",
|
656 |
+
intermediate_size=11008,
|
657 |
+
rope_base=1000000,
|
658 |
+
),
|
659 |
+
# https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
|
660 |
+
dict(
|
661 |
+
org="codellama",
|
662 |
+
name="CodeLlama-13b-hf",
|
663 |
+
block_size=16384,
|
664 |
+
vocab_size=32016,
|
665 |
+
padding_multiple=16,
|
666 |
+
n_layer=40,
|
667 |
+
n_head=40,
|
668 |
+
n_embd=5120,
|
669 |
+
rotary_percentage=1.0,
|
670 |
+
parallel_residual=False,
|
671 |
+
bias=False,
|
672 |
+
_norm_class="RMSNorm",
|
673 |
+
norm_eps=1e-05,
|
674 |
+
_mlp_class="LLaMAMLP",
|
675 |
+
intermediate_size=13824,
|
676 |
+
rope_base=1000000,
|
677 |
+
),
|
678 |
+
# https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
|
679 |
+
dict(
|
680 |
+
org="codellama",
|
681 |
+
name="CodeLlama-34b-hf",
|
682 |
+
block_size=16384,
|
683 |
+
vocab_size=32000,
|
684 |
+
padding_multiple=64,
|
685 |
+
n_layer=48,
|
686 |
+
n_head=64,
|
687 |
+
n_embd=8192,
|
688 |
+
n_query_groups=8,
|
689 |
+
rotary_percentage=1.0,
|
690 |
+
parallel_residual=False,
|
691 |
+
bias=False,
|
692 |
+
_norm_class="RMSNorm",
|
693 |
+
norm_eps=1e-05,
|
694 |
+
_mlp_class="LLaMAMLP",
|
695 |
+
intermediate_size=22016,
|
696 |
+
rope_base=1000000,
|
697 |
+
),
|
698 |
+
# https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
|
699 |
+
dict(
|
700 |
+
org="codellama",
|
701 |
+
name="CodeLlama-7b-Python-hf",
|
702 |
+
block_size=16384,
|
703 |
+
vocab_size=32000,
|
704 |
+
padding_multiple=64,
|
705 |
+
n_layer=32,
|
706 |
+
rotary_percentage=1.0,
|
707 |
+
parallel_residual=False,
|
708 |
+
bias=False,
|
709 |
+
_norm_class="RMSNorm",
|
710 |
+
norm_eps=1e-05,
|
711 |
+
_mlp_class="LLaMAMLP",
|
712 |
+
intermediate_size=11008,
|
713 |
+
rope_base=1000000,
|
714 |
+
),
|
715 |
+
# https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
|
716 |
+
dict(
|
717 |
+
org="codellama",
|
718 |
+
name="CodeLlama-13b-Python-hf",
|
719 |
+
block_size=16384,
|
720 |
+
vocab_size=32000,
|
721 |
+
padding_multiple=64,
|
722 |
+
n_layer=40,
|
723 |
+
n_head=40,
|
724 |
+
n_embd=5120,
|
725 |
+
rotary_percentage=1.0,
|
726 |
+
parallel_residual=False,
|
727 |
+
bias=False,
|
728 |
+
_norm_class="RMSNorm",
|
729 |
+
norm_eps=1e-05,
|
730 |
+
_mlp_class="LLaMAMLP",
|
731 |
+
intermediate_size=13824,
|
732 |
+
rope_base=1000000,
|
733 |
+
),
|
734 |
+
# https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
|
735 |
+
dict(
|
736 |
+
org="codellama",
|
737 |
+
name="CodeLlama-34b-Python-hf",
|
738 |
+
block_size=16384,
|
739 |
+
vocab_size=32000,
|
740 |
+
padding_multiple=64,
|
741 |
+
n_layer=48,
|
742 |
+
n_head=64,
|
743 |
+
n_embd=8192,
|
744 |
+
n_query_groups=8,
|
745 |
+
rotary_percentage=1.0,
|
746 |
+
parallel_residual=False,
|
747 |
+
bias=False,
|
748 |
+
_norm_class="RMSNorm",
|
749 |
+
norm_eps=1e-05,
|
750 |
+
_mlp_class="LLaMAMLP",
|
751 |
+
intermediate_size=22016,
|
752 |
+
rope_base=1000000,
|
753 |
+
),
|
754 |
+
# https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
|
755 |
+
dict(
|
756 |
+
org="codellama",
|
757 |
+
name="CodeLlama-7b-Instruct-hf",
|
758 |
+
block_size=16384,
|
759 |
+
vocab_size=32016,
|
760 |
+
padding_multiple=16,
|
761 |
+
n_layer=32,
|
762 |
+
rotary_percentage=1.0,
|
763 |
+
parallel_residual=False,
|
764 |
+
bias=False,
|
765 |
+
_norm_class="RMSNorm",
|
766 |
+
norm_eps=1e-05,
|
767 |
+
_mlp_class="LLaMAMLP",
|
768 |
+
intermediate_size=11008,
|
769 |
+
rope_base=1000000,
|
770 |
+
),
|
771 |
+
# https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
|
772 |
+
dict(
|
773 |
+
org="codellama",
|
774 |
+
name="CodeLlama-13b-Instruct-hf",
|
775 |
+
block_size=2048,
|
776 |
+
vocab_size=32016,
|
777 |
+
padding_multiple=16,
|
778 |
+
n_layer=40,
|
779 |
+
n_head=40,
|
780 |
+
n_embd=5120,
|
781 |
+
rotary_percentage=1.0,
|
782 |
+
parallel_residual=False,
|
783 |
+
bias=False,
|
784 |
+
_norm_class="RMSNorm",
|
785 |
+
norm_eps=1e-05,
|
786 |
+
_mlp_class="LLaMAMLP",
|
787 |
+
intermediate_size=13824,
|
788 |
+
rope_base=1000000,
|
789 |
+
),
|
790 |
+
# https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
|
791 |
+
dict(
|
792 |
+
org="codellama",
|
793 |
+
name="CodeLlama-34b-Instruct-hf",
|
794 |
+
block_size=16384,
|
795 |
+
vocab_size=32000,
|
796 |
+
padding_multiple=64,
|
797 |
+
n_layer=48,
|
798 |
+
n_head=64,
|
799 |
+
n_embd=8192,
|
800 |
+
n_query_groups=8,
|
801 |
+
rotary_percentage=1.0,
|
802 |
+
parallel_residual=False,
|
803 |
+
bias=False,
|
804 |
+
_norm_class="RMSNorm",
|
805 |
+
norm_eps=1e-05,
|
806 |
+
_mlp_class="LLaMAMLP",
|
807 |
+
intermediate_size=22016,
|
808 |
+
rope_base=1000000,
|
809 |
+
),
|
810 |
+
]
|
811 |
+
configs.extend(code_llama)
|
812 |
+
|
813 |
+
|
814 |
+
########################
|
815 |
+
# garage-bAInd Platypus
|
816 |
+
########################
|
817 |
+
platypus = [
|
818 |
+
# https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
|
819 |
+
dict(
|
820 |
+
org="garage-bAInd",
|
821 |
+
name="Platypus-30B",
|
822 |
+
block_size=2048,
|
823 |
+
padded_vocab_size=32000,
|
824 |
+
n_layer=60,
|
825 |
+
n_head=52,
|
826 |
+
n_embd=6656,
|
827 |
+
rotary_percentage=1.0,
|
828 |
+
parallel_residual=False,
|
829 |
+
bias=False,
|
830 |
+
_norm_class="RMSNorm",
|
831 |
+
norm_eps=1e-06,
|
832 |
+
_mlp_class="LLaMAMLP",
|
833 |
+
intermediate_size=17920,
|
834 |
+
),
|
835 |
+
# https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
|
836 |
+
dict(
|
837 |
+
org="garage-bAInd",
|
838 |
+
name="Platypus2-7B",
|
839 |
+
padded_vocab_size=32000,
|
840 |
+
n_layer=32,
|
841 |
+
rotary_percentage=1.0,
|
842 |
+
parallel_residual=False,
|
843 |
+
bias=False,
|
844 |
+
_norm_class="RMSNorm",
|
845 |
+
norm_eps=1e-05,
|
846 |
+
_mlp_class="LLaMAMLP",
|
847 |
+
intermediate_size=11008,
|
848 |
+
),
|
849 |
+
# https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
|
850 |
+
dict(
|
851 |
+
org="garage-bAInd",
|
852 |
+
name="Platypus2-13B",
|
853 |
+
padded_vocab_size=32000,
|
854 |
+
n_layer=40,
|
855 |
+
n_head=40,
|
856 |
+
n_embd=5120,
|
857 |
+
rotary_percentage=1.0,
|
858 |
+
parallel_residual=False,
|
859 |
+
bias=False,
|
860 |
+
_norm_class="RMSNorm",
|
861 |
+
norm_eps=1e-05,
|
862 |
+
_mlp_class="LLaMAMLP",
|
863 |
+
intermediate_size=13824,
|
864 |
+
),
|
865 |
+
# https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
|
866 |
+
dict(
|
867 |
+
org="garage-bAInd",
|
868 |
+
name="Platypus2-70B",
|
869 |
+
padded_vocab_size=32000,
|
870 |
+
n_layer=80,
|
871 |
+
n_head=64,
|
872 |
+
n_embd=8192,
|
873 |
+
rotary_percentage=1.0,
|
874 |
+
parallel_residual=False,
|
875 |
+
bias=False,
|
876 |
+
_norm_class="RMSNorm",
|
877 |
+
_mlp_class="LLaMAMLP",
|
878 |
+
intermediate_size=28672,
|
879 |
+
),
|
880 |
+
# https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
|
881 |
+
dict(
|
882 |
+
org="garage-bAInd",
|
883 |
+
name="Camel-Platypus2-13B",
|
884 |
+
padded_vocab_size=32000,
|
885 |
+
n_layer=40,
|
886 |
+
n_head=40,
|
887 |
+
n_embd=5120,
|
888 |
+
rotary_percentage=1.0,
|
889 |
+
parallel_residual=False,
|
890 |
+
bias=False,
|
891 |
+
_norm_class="RMSNorm",
|
892 |
+
_mlp_class="LLaMAMLP",
|
893 |
+
intermediate_size=13824,
|
894 |
+
),
|
895 |
+
# https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
|
896 |
+
dict(
|
897 |
+
org="garage-bAInd",
|
898 |
+
name="Camel-Platypus2-70B",
|
899 |
+
padded_vocab_size=32000,
|
900 |
+
n_layer=80,
|
901 |
+
n_head=64,
|
902 |
+
n_embd=8192,
|
903 |
+
n_query_groups=8,
|
904 |
+
rotary_percentage=1.0,
|
905 |
+
parallel_residual=False,
|
906 |
+
bias=False,
|
907 |
+
_norm_class="RMSNorm",
|
908 |
+
_mlp_class="LLaMAMLP",
|
909 |
+
intermediate_size=28672,
|
910 |
+
),
|
911 |
+
# https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
|
912 |
+
dict(
|
913 |
+
org="garage-bAInd",
|
914 |
+
name="Stable-Platypus2-13B",
|
915 |
+
padded_vocab_size=32000,
|
916 |
+
n_layer=40,
|
917 |
+
n_head=40,
|
918 |
+
n_embd=5120,
|
919 |
+
rotary_percentage=1.0,
|
920 |
+
parallel_residual=False,
|
921 |
+
bias=False,
|
922 |
+
_norm_class="RMSNorm",
|
923 |
+
_mlp_class="LLaMAMLP",
|
924 |
+
intermediate_size=13824,
|
925 |
+
),
|
926 |
+
# https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
|
927 |
+
dict(
|
928 |
+
org="garage-bAInd",
|
929 |
+
name="Platypus2-70B-instruct",
|
930 |
+
padded_vocab_size=32000,
|
931 |
+
n_layer=80,
|
932 |
+
n_head=64,
|
933 |
+
n_embd=8192,
|
934 |
+
n_query_groups=8,
|
935 |
+
rotary_percentage=1.0,
|
936 |
+
parallel_residual=False,
|
937 |
+
bias=False,
|
938 |
+
_norm_class="RMSNorm",
|
939 |
+
_mlp_class="LLaMAMLP",
|
940 |
+
intermediate_size=28672,
|
941 |
+
),
|
942 |
+
]
|
943 |
+
configs.extend(platypus)
|
944 |
+
|
945 |
+
|
946 |
+
##########################
|
947 |
+
# Stability AI StableCode
|
948 |
+
##########################
|
949 |
+
stablecode = [
|
950 |
+
# https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
|
951 |
+
dict(
|
952 |
+
org="stabilityai",
|
953 |
+
name="stablecode-completion-alpha-3b",
|
954 |
+
block_size=16384,
|
955 |
+
vocab_size=49152,
|
956 |
+
n_layer=32,
|
957 |
+
n_embd=2560,
|
958 |
+
),
|
959 |
+
# https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
|
960 |
+
dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k", vocab_size=49152, n_layer=32, n_embd=2560),
|
961 |
+
# https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
|
962 |
+
dict(org="stabilityai", name="stablecode-instruct-alpha-3b", vocab_size=49152, n_layer=32, n_embd=2560),
|
963 |
+
]
|
964 |
+
configs.extend(stablecode)
|
965 |
+
|
966 |
+
|
967 |
+
##################################
|
968 |
+
# togethercomputer LLaMA-2-7B-32K
|
969 |
+
##################################
|
970 |
+
together_llama2_32k = [
|
971 |
+
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
|
972 |
+
dict(
|
973 |
+
org="togethercomputer",
|
974 |
+
name="LLaMA-2-7B-32K",
|
975 |
+
vocab_size=32000,
|
976 |
+
padding_multiple=64,
|
977 |
+
n_layer=32,
|
978 |
+
rotary_percentage=1.0,
|
979 |
+
parallel_residual=False,
|
980 |
+
bias=False,
|
981 |
+
_norm_class="RMSNorm",
|
982 |
+
_mlp_class="LLaMAMLP",
|
983 |
+
intermediate_size=11008,
|
984 |
+
rope_condense_ratio=8,
|
985 |
+
)
|
986 |
+
]
|
987 |
+
configs.extend(together_llama2_32k)
|
988 |
+
|
989 |
+
|
990 |
+
################
|
991 |
+
# Microsoft Phi
|
992 |
+
################
|
993 |
+
phi = [
|
994 |
+
# https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
|
995 |
+
dict(
|
996 |
+
org="microsoft",
|
997 |
+
name="phi-1_5",
|
998 |
+
vocab_size=50257,
|
999 |
+
padded_vocab_size=51200,
|
1000 |
+
block_size=2048,
|
1001 |
+
n_embd=2048,
|
1002 |
+
n_layer=24,
|
1003 |
+
rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
|
1004 |
+
shared_attention_norm=True,
|
1005 |
+
lm_head_bias=True,
|
1006 |
+
gelu_approximate="tanh",
|
1007 |
+
)
|
1008 |
+
]
|
1009 |
+
configs.extend(phi)
|
1010 |
+
|
1011 |
+
|
1012 |
+
#############
|
1013 |
+
# Mistral AI
|
1014 |
+
#############
|
1015 |
+
mistral = [
|
1016 |
+
# https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
|
1017 |
+
dict(
|
1018 |
+
org="mistralai",
|
1019 |
+
name="Mistral-7B-{}v0.1",
|
1020 |
+
padded_vocab_size=32000,
|
1021 |
+
block_size=4096, # should be 32768 but sliding window attention is not implemented
|
1022 |
+
n_layer=32,
|
1023 |
+
n_query_groups=8,
|
1024 |
+
rotary_percentage=1.0,
|
1025 |
+
parallel_residual=False,
|
1026 |
+
bias=False,
|
1027 |
+
_norm_class="RMSNorm",
|
1028 |
+
norm_eps=1e-05,
|
1029 |
+
_mlp_class="LLaMAMLP",
|
1030 |
+
intermediate_size=14336,
|
1031 |
+
)
|
1032 |
+
]
|
1033 |
+
for c in mistral:
|
1034 |
+
for kind in ("", "Instruct-"):
|
1035 |
+
copy = c.copy()
|
1036 |
+
copy["name"] = c["name"].format(kind)
|
1037 |
+
configs.append(copy)
|
1038 |
+
|
1039 |
+
|
1040 |
+
name_to_config = {config["name"]: config for config in configs}
|
lit_gpt/lora.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Derived from https://github.com/microsoft/LoRA
|
2 |
+
# ------------------------------------------------------------------------------------------
|
3 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
4 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
5 |
+
# ------------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
r"""
|
8 |
+
Low Ranking Adaptation for LLMs scheme.
|
9 |
+
|
10 |
+
┌───────────────────┐
|
11 |
+
┆ h ┆
|
12 |
+
└───────────────────┘
|
13 |
+
▲
|
14 |
+
|
|
15 |
+
+
|
16 |
+
/ \
|
17 |
+
┌─────────────────┐ ╭───────────────╮ Matrix initialization:
|
18 |
+
┆ ┆ \ B / B = 0
|
19 |
+
┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
|
20 |
+
┆ weights ┆ ╰─────────╯
|
21 |
+
┆ ┆ | r | r - rank
|
22 |
+
┆ W e R^(d*d) ┆ | ◀─────▶ |
|
23 |
+
┆ ┆ ╭─────────╮
|
24 |
+
└─────────────────┘ / A \
|
25 |
+
▲ / d*r \
|
26 |
+
\ ╰───────────────╯
|
27 |
+
\ ▲
|
28 |
+
\ /
|
29 |
+
\ /
|
30 |
+
┌───────────────────┐
|
31 |
+
┆ x ┆
|
32 |
+
└───────────────────┘
|
33 |
+
|
34 |
+
With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,
|
35 |
+
we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates
|
36 |
+
for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of
|
37 |
+
course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen
|
38 |
+
pretrained weights and thus fine-tune the model.
|
39 |
+
|
40 |
+
The goal of this approach is to move weight updates into a separate matrix which is decomposed with
|
41 |
+
two matrices of a lower rank.
|
42 |
+
"""
|
43 |
+
|
44 |
+
import math
|
45 |
+
from dataclasses import dataclass
|
46 |
+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
47 |
+
|
48 |
+
import torch
|
49 |
+
import torch.nn as nn
|
50 |
+
from torch.nn import functional as F
|
51 |
+
from typing_extensions import Self
|
52 |
+
|
53 |
+
import lit_gpt
|
54 |
+
from lit_gpt.config import Config as BaseConfig
|
55 |
+
from lit_gpt.model import GPT as BaseModel
|
56 |
+
from lit_gpt.model import Block as BaseBlock
|
57 |
+
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
|
58 |
+
from lit_gpt.model import KVCache
|
59 |
+
from lit_gpt.utils import map_old_state_dict_weights
|
60 |
+
|
61 |
+
|
62 |
+
class LoRALayer(nn.Module):
|
63 |
+
def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
|
64 |
+
"""Store LoRA specific attributes in a class.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
68 |
+
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
69 |
+
lora_alpha: alpha is needed for scaling updates as alpha/r
|
70 |
+
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
71 |
+
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
72 |
+
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
73 |
+
"""
|
74 |
+
super().__init__()
|
75 |
+
assert r >= 0
|
76 |
+
self.r = r
|
77 |
+
self.lora_alpha = lora_alpha
|
78 |
+
# Optional dropout
|
79 |
+
if lora_dropout > 0.0:
|
80 |
+
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
81 |
+
else:
|
82 |
+
self.lora_dropout = lambda x: x
|
83 |
+
# Mark the weight as unmerged
|
84 |
+
self.merged = False
|
85 |
+
|
86 |
+
|
87 |
+
class LoRALinear(LoRALayer):
|
88 |
+
# LoRA implemented in a dense layer
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
# ↓ this part is for pretrained weights
|
92 |
+
in_features: int,
|
93 |
+
out_features: int,
|
94 |
+
# ↓ the remaining part is for LoRA
|
95 |
+
r: int = 0,
|
96 |
+
lora_alpha: int = 1,
|
97 |
+
lora_dropout: float = 0.0,
|
98 |
+
**kwargs,
|
99 |
+
):
|
100 |
+
"""LoRA wrapper around linear class.
|
101 |
+
|
102 |
+
This class has three weight matrices:
|
103 |
+
1. Pretrained weights are stored as `self.linear.weight`
|
104 |
+
2. LoRA A matrix as `self.lora_A`
|
105 |
+
3. LoRA B matrix as `self.lora_B`
|
106 |
+
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
in_features: number of input features of the pretrained weights
|
110 |
+
out_features: number of output features of the pretrained weights
|
111 |
+
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
112 |
+
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
113 |
+
lora_alpha: alpha is needed for scaling updates as alpha/r
|
114 |
+
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
115 |
+
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
116 |
+
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
117 |
+
"""
|
118 |
+
super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
|
119 |
+
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
|
120 |
+
|
121 |
+
# Actual trainable parameters
|
122 |
+
if r > 0:
|
123 |
+
self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
|
124 |
+
self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
|
125 |
+
self.scaling = self.lora_alpha / self.r
|
126 |
+
self.reset_parameters()
|
127 |
+
|
128 |
+
def reset_parameters(self) -> None:
|
129 |
+
"""Reset all the weights, even including pretrained ones."""
|
130 |
+
if hasattr(self, "lora_A"):
|
131 |
+
# initialize A the same way as the default for nn.Linear and B to zero
|
132 |
+
# Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
|
133 |
+
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
134 |
+
nn.init.zeros_(self.lora_B)
|
135 |
+
|
136 |
+
def merge(self) -> None:
|
137 |
+
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
|
138 |
+
if self.r > 0 and not self.merged:
|
139 |
+
# Merge the weights and mark it
|
140 |
+
self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
141 |
+
self.merged = True
|
142 |
+
|
143 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
144 |
+
# if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass;
|
145 |
+
# otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights
|
146 |
+
pretrained = self.linear(x)
|
147 |
+
if self.r == 0 or self.merged:
|
148 |
+
return pretrained
|
149 |
+
lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
|
150 |
+
return pretrained + lora
|
151 |
+
|
152 |
+
|
153 |
+
class LoRAQKVLinear(LoRALinear):
|
154 |
+
# LoRA implemented in a dense layer
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
# ↓ this part is for pretrained weights
|
158 |
+
in_features: int,
|
159 |
+
out_features: int,
|
160 |
+
# ↓ the remaining part is for LoRA
|
161 |
+
n_head: int,
|
162 |
+
n_query_groups: int,
|
163 |
+
r: int = 0,
|
164 |
+
lora_alpha: int = 1,
|
165 |
+
lora_dropout: float = 0.0,
|
166 |
+
enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
|
167 |
+
**kwargs,
|
168 |
+
):
|
169 |
+
"""LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
|
170 |
+
|
171 |
+
This class has three weight matrices:
|
172 |
+
1. Pretrained weights are stored as `self.linear.weight`
|
173 |
+
2. LoRA A matrix as `self.lora_A`
|
174 |
+
3. LoRA B matrix as `self.lora_B`
|
175 |
+
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
in_features: number of input features of the pretrained weights
|
179 |
+
out_features: number of output features of the pretrained weights
|
180 |
+
n_head: number of attention heads
|
181 |
+
n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
|
182 |
+
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
183 |
+
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
184 |
+
lora_alpha: alpha is needed for scaling updates as alpha/r
|
185 |
+
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
186 |
+
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
187 |
+
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
188 |
+
enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
|
189 |
+
don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query`
|
190 |
+
and `value` but keep `key` without weight updates we should pass `[True, False, True]`
|
191 |
+
"""
|
192 |
+
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
|
193 |
+
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
|
194 |
+
self.n_head = n_head
|
195 |
+
self.n_query_groups = n_query_groups
|
196 |
+
if isinstance(enable_lora, bool):
|
197 |
+
enable_lora = [enable_lora] * 3
|
198 |
+
assert len(enable_lora) == 3
|
199 |
+
self.enable_lora = enable_lora
|
200 |
+
|
201 |
+
# Actual trainable parameters
|
202 |
+
# To better understand initialization let's imagine that we have such parameters:
|
203 |
+
# ⚬ in_features: 128 (embeddings_size)
|
204 |
+
# ⚬ out_features: 384 (3 * embedding_size)
|
205 |
+
# ⚬ r: 2
|
206 |
+
# ⚬ enable_lora: [True, False, True]
|
207 |
+
if r > 0 and any(enable_lora):
|
208 |
+
self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features))) # (4, 128)
|
209 |
+
enable_q, enable_k, enable_v = enable_lora
|
210 |
+
self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
|
211 |
+
# qkv_shapes will be used to split a tensor with weights correctly
|
212 |
+
qkv_shapes = (
|
213 |
+
self.linear.in_features * enable_q,
|
214 |
+
self.kv_embd_size * enable_k,
|
215 |
+
self.kv_embd_size * enable_v,
|
216 |
+
)
|
217 |
+
self.qkv_shapes = [s for s in qkv_shapes if s]
|
218 |
+
self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r)) # (256, 2))
|
219 |
+
# Notes about shapes above
|
220 |
+
# - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
|
221 |
+
# 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
|
222 |
+
# F.linear function weights are automatically transposed. In addition conv1d requires channels to
|
223 |
+
# be before seq length
|
224 |
+
# - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
|
225 |
+
# 128*2; 2 tells to have two channels per group for group convolution
|
226 |
+
|
227 |
+
# Scaling:
|
228 |
+
# This balances the pretrained model`s knowledge and the new task-specific adaptation
|
229 |
+
# https://lightning.ai/pages/community/tutorial/lora-llm/
|
230 |
+
# So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
|
231 |
+
# alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
|
232 |
+
# tune these values to your needs. This value can be even slightly greater than 1.0!
|
233 |
+
# https://github.com/cloneofsimo/lora
|
234 |
+
self.scaling = self.lora_alpha / self.r
|
235 |
+
|
236 |
+
# Compute the indices
|
237 |
+
# Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
|
238 |
+
# but not keys, then the weights update should be:
|
239 |
+
#
|
240 |
+
# [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
|
241 |
+
# [....................................],
|
242 |
+
# [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
|
243 |
+
# ↑ ↑ ↑
|
244 |
+
# ________________________________________
|
245 |
+
# | query | key | value |
|
246 |
+
# ----------------------------------------
|
247 |
+
self.lora_ind = []
|
248 |
+
if enable_q:
|
249 |
+
self.lora_ind.extend(range(0, self.linear.in_features))
|
250 |
+
if enable_k:
|
251 |
+
self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
|
252 |
+
if enable_v:
|
253 |
+
self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
|
254 |
+
self.reset_parameters()
|
255 |
+
|
256 |
+
def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
|
257 |
+
"""Properly pad weight updates with zeros.
|
258 |
+
|
259 |
+
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
|
260 |
+
then the weights update should be:
|
261 |
+
|
262 |
+
[[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
|
263 |
+
[....................................],
|
264 |
+
[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
|
265 |
+
↑ ↑ ↑
|
266 |
+
________________________________________
|
267 |
+
| query | key | value |
|
268 |
+
----------------------------------------
|
269 |
+
|
270 |
+
Args:
|
271 |
+
x: tensor with weights update that will be padded with zeros if necessary
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
A tensor with weight updates and zeros for deselected q, k or v
|
275 |
+
"""
|
276 |
+
# we need to do zero padding only if LoRA is disabled for one of QKV matrices
|
277 |
+
if all(self.enable_lora):
|
278 |
+
return x
|
279 |
+
|
280 |
+
# Let's image that:
|
281 |
+
# ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
|
282 |
+
# ⚬ embeddings_size: 128
|
283 |
+
# ⚬ self.linear.out_features: 384 (3 * embeddings_size)
|
284 |
+
# ⚬ enable_lora: [True, False, True]
|
285 |
+
# Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
|
286 |
+
# embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
|
287 |
+
# only for key updates (this is where self.lora_ind comes in handy)
|
288 |
+
# Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
|
289 |
+
# for example when we want to merge/unmerge LoRA weights and pretrained weights
|
290 |
+
x = x.transpose(0, 1)
|
291 |
+
result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
|
292 |
+
result = result.view(-1, self.linear.out_features) # (4096, 384)
|
293 |
+
result = result.index_copy(
|
294 |
+
1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
|
295 |
+
) # (4096, 256)
|
296 |
+
return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)
|
297 |
+
|
298 |
+
def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
299 |
+
"""An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
|
300 |
+
|
301 |
+
If the number of heads is equal to the number of query groups - grouped queries are disabled
|
302 |
+
(see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
|
303 |
+
query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the
|
304 |
+
input and weight matrices will be splitted in equally sized parts and applied separately (like having multiple
|
305 |
+
conv layers side by side).
|
306 |
+
|
307 |
+
Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,
|
308 |
+
apply each part of the weight matrix to the corresponding input's part and concatenate the result.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
input: input matrix of shape (B, C, T)
|
312 |
+
weight: weight matrix of shape (C_output, rank, 1).
|
313 |
+
"C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
A tensor with a shape (B, C_output, T)
|
317 |
+
|
318 |
+
"""
|
319 |
+
if self.n_head == self.n_query_groups:
|
320 |
+
return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T)
|
321 |
+
|
322 |
+
# Notation:
|
323 |
+
# ⚬ N: number of enabled LoRA layers (self.enable_lora)
|
324 |
+
# ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
|
325 |
+
# ⚬ r: rank of all LoRA layers (equal in size)
|
326 |
+
|
327 |
+
input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
|
328 |
+
weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
|
329 |
+
return torch.cat(
|
330 |
+
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T)
|
331 |
+
) # (B, C_output, T)
|
332 |
+
|
333 |
+
def merge(self) -> None:
|
334 |
+
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
|
335 |
+
|
336 |
+
# Let's assume that:
|
337 |
+
# ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
|
338 |
+
# ⚬ self.lora_A.data: (4, 128)
|
339 |
+
# ⚬ self.lora_B.data: (256, 2)
|
340 |
+
if self.r > 0 and any(self.enable_lora) and not self.merged:
|
341 |
+
delta_w = self.conv1d(
|
342 |
+
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
|
343 |
+
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
|
344 |
+
).squeeze(
|
345 |
+
0
|
346 |
+
) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
|
347 |
+
# W = W + delta_W (merge)
|
348 |
+
self.linear.weight.data += self.zero_pad(delta_w * self.scaling) # (256, 128) after zero_pad (384, 128)
|
349 |
+
self.merged = True
|
350 |
+
|
351 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
352 |
+
"""Do the forward pass.
|
353 |
+
|
354 |
+
If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
|
355 |
+
If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
x: input tensor of shape (batch_size, context_length, embedding_size)
|
359 |
+
|
360 |
+
Returns:
|
361 |
+
Output tensor of shape (batch_size, context_length, 3 * embedding_size)
|
362 |
+
"""
|
363 |
+
|
364 |
+
# Let's assume that:
|
365 |
+
# ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
|
366 |
+
# ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
|
367 |
+
# ⚬ self.lora_A.data: (4, 128)
|
368 |
+
# ⚬ self.lora_B.data: (256, 2)
|
369 |
+
|
370 |
+
# if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass;
|
371 |
+
# otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights
|
372 |
+
pretrained = self.linear(x)
|
373 |
+
if self.r == 0 or not any(self.enable_lora) or self.merged:
|
374 |
+
return pretrained
|
375 |
+
after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
|
376 |
+
# For F.conv1d:
|
377 |
+
# ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
|
378 |
+
# ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
|
379 |
+
after_B = self.conv1d(
|
380 |
+
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
|
381 |
+
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
|
382 |
+
).transpose(
|
383 |
+
-2, -1
|
384 |
+
) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
|
385 |
+
lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
|
386 |
+
return pretrained + lora
|
387 |
+
|
388 |
+
|
389 |
+
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
|
390 |
+
"""Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
|
391 |
+
|
392 |
+
Args:
|
393 |
+
model: model with LoRA layers
|
394 |
+
bias:
|
395 |
+
``"none"``: all bias weights will be frozen,
|
396 |
+
``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
|
397 |
+
``"all"``: all bias weights will be unfrozen.
|
398 |
+
|
399 |
+
Raises:
|
400 |
+
NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
|
401 |
+
"""
|
402 |
+
# freeze all layers except LoRA's
|
403 |
+
for n, p in model.named_parameters():
|
404 |
+
if "lora_" not in n:
|
405 |
+
p.requires_grad = False
|
406 |
+
|
407 |
+
# depending on the `bias` value unfreeze bias weights
|
408 |
+
if bias == "none":
|
409 |
+
return
|
410 |
+
if bias == "all":
|
411 |
+
for n, p in model.named_parameters():
|
412 |
+
if "bias" in n:
|
413 |
+
p.requires_grad = True
|
414 |
+
elif bias == "lora_only":
|
415 |
+
for m in model.modules():
|
416 |
+
if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
|
417 |
+
m.bias.requires_grad = True
|
418 |
+
else:
|
419 |
+
raise NotImplementedError
|
420 |
+
|
421 |
+
|
422 |
+
def lora_filter(key: str, value: Any) -> bool:
|
423 |
+
return "lora_" in key
|
424 |
+
|
425 |
+
|
426 |
+
@dataclass
|
427 |
+
class Config(BaseConfig):
|
428 |
+
"""
|
429 |
+
Args:
|
430 |
+
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
431 |
+
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
432 |
+
alpha: alpha is needed for scaling updates as alpha/r
|
433 |
+
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
434 |
+
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
435 |
+
dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
436 |
+
to_*: either apply LoRA to the specified weights or not
|
437 |
+
"""
|
438 |
+
|
439 |
+
r: int = 0
|
440 |
+
alpha: int = 1
|
441 |
+
dropout: float = 0.0
|
442 |
+
to_query: bool = False
|
443 |
+
to_key: bool = False
|
444 |
+
to_value: bool = False
|
445 |
+
to_projection: bool = False
|
446 |
+
to_mlp: bool = False
|
447 |
+
to_head: bool = False
|
448 |
+
|
449 |
+
@property
|
450 |
+
def mlp_class(self) -> Type:
|
451 |
+
return getattr(lit_gpt.lora, self._mlp_class)
|
452 |
+
|
453 |
+
|
454 |
+
class GPT(BaseModel):
|
455 |
+
def __init__(self, config: Config) -> None:
|
456 |
+
nn.Module.__init__(self)
|
457 |
+
assert config.padded_vocab_size is not None
|
458 |
+
self.config = config
|
459 |
+
|
460 |
+
self.lm_head = LoRALinear(
|
461 |
+
config.n_embd,
|
462 |
+
config.padded_vocab_size,
|
463 |
+
bias=config.lm_head_bias,
|
464 |
+
r=(config.r if config.to_head else 0),
|
465 |
+
lora_alpha=config.alpha,
|
466 |
+
lora_dropout=config.dropout,
|
467 |
+
)
|
468 |
+
self.transformer = nn.ModuleDict(
|
469 |
+
dict(
|
470 |
+
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
471 |
+
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
472 |
+
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
473 |
+
)
|
474 |
+
)
|
475 |
+
self.max_seq_length = self.config.block_size
|
476 |
+
self.mask_cache: Optional[torch.Tensor] = None
|
477 |
+
|
478 |
+
def forward(
|
479 |
+
self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0, maxlen: int = None
|
480 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
481 |
+
T = idx.size(1) if maxlen is None else maxlen
|
482 |
+
if self.max_seq_length < T:
|
483 |
+
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
|
484 |
+
# import pdb; pdb.set_trace()
|
485 |
+
if input_pos is not None: # use the kv cache
|
486 |
+
cos = self.cos.index_select(0, input_pos)
|
487 |
+
sin = self.sin.index_select(0, input_pos)
|
488 |
+
if self.mask_cache is None:
|
489 |
+
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
490 |
+
mask = self.mask_cache.index_select(2, input_pos)
|
491 |
+
else:
|
492 |
+
cos = self.cos[:T]
|
493 |
+
sin = self.sin[:T]
|
494 |
+
mask = None
|
495 |
+
|
496 |
+
if type(idx) is tuple:
|
497 |
+
# import pdb; pdb.set_trace()
|
498 |
+
stack_before_tokens_x, motion_tokens, before_len = idx
|
499 |
+
# stack_before_tokens_x = stack_before_tokens_x.unsqueeze(0)
|
500 |
+
# motion_tokens = motion_tokens.unsqueeze(0)
|
501 |
+
# stack_before_tokens_x[0][before_len[0]: before_len[0] + len(motion_tokens[0])] = 1
|
502 |
+
# import pdb; pdb.set_trace()
|
503 |
+
x = self.transformer.wte(stack_before_tokens_x)
|
504 |
+
# import pdb; pdb.set_trace()
|
505 |
+
for i in range(len(x)):
|
506 |
+
x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i]
|
507 |
+
else:
|
508 |
+
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
509 |
+
for block in self.transformer.h:
|
510 |
+
x = block(x, cos, sin, mask, input_pos)
|
511 |
+
x = self.transformer.ln_f(x)
|
512 |
+
if lm_head_chunk_size > 0:
|
513 |
+
# chunk the lm head logits to reduce the peak memory used by autograd
|
514 |
+
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
|
515 |
+
return self.lm_head(x) # (B, T, vocab_size)
|
516 |
+
|
517 |
+
@classmethod
|
518 |
+
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
519 |
+
return cls(Config.from_name(name, **kwargs))
|
520 |
+
|
521 |
+
def _init_weights(self, module: nn.Module) -> None:
|
522 |
+
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
|
523 |
+
super()._init_weights(module)
|
524 |
+
if isinstance(module, LoRALinear):
|
525 |
+
module.reset_parameters()
|
526 |
+
|
527 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
528 |
+
"""For compatibility with base checkpoints."""
|
529 |
+
mapping = {"lm_head.weight": "lm_head.linear.weight"}
|
530 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
531 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
532 |
+
|
533 |
+
|
534 |
+
class Block(BaseBlock):
|
535 |
+
def __init__(self, config: Config) -> None:
|
536 |
+
nn.Module.__init__(self)
|
537 |
+
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
538 |
+
self.attn = CausalSelfAttention(config)
|
539 |
+
if not config.shared_attention_norm:
|
540 |
+
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
541 |
+
self.mlp = config.mlp_class(config)
|
542 |
+
|
543 |
+
self.config = config
|
544 |
+
|
545 |
+
|
546 |
+
class CausalSelfAttention(BaseCausalSelfAttention):
|
547 |
+
def __init__(self, config: Config) -> None:
|
548 |
+
# Skip the parent class __init__ altogether and replace it to avoid
|
549 |
+
# useless allocations
|
550 |
+
nn.Module.__init__(self)
|
551 |
+
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
|
552 |
+
# key, query, value projections for all heads, but in a batch
|
553 |
+
self.attn = LoRAQKVLinear(
|
554 |
+
in_features=config.n_embd,
|
555 |
+
out_features=shape,
|
556 |
+
r=config.r,
|
557 |
+
lora_alpha=config.alpha,
|
558 |
+
lora_dropout=config.dropout,
|
559 |
+
enable_lora=(config.to_query, config.to_key, config.to_value),
|
560 |
+
bias=config.bias,
|
561 |
+
# for MQA/GQA support
|
562 |
+
n_head=config.n_head,
|
563 |
+
n_query_groups=config.n_query_groups,
|
564 |
+
)
|
565 |
+
# output projection
|
566 |
+
self.proj = LoRALinear(
|
567 |
+
config.n_embd,
|
568 |
+
config.n_embd,
|
569 |
+
bias=config.bias,
|
570 |
+
r=(config.r if config.to_projection else 0),
|
571 |
+
lora_alpha=config.alpha,
|
572 |
+
lora_dropout=config.dropout,
|
573 |
+
)
|
574 |
+
# disabled by default
|
575 |
+
self.kv_cache: Optional[KVCache] = None
|
576 |
+
|
577 |
+
self.config = config
|
578 |
+
|
579 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
580 |
+
"""For compatibility with base checkpoints."""
|
581 |
+
mapping = {
|
582 |
+
"attn.weight": "attn.linear.weight",
|
583 |
+
"attn.bias": "attn.linear.bias",
|
584 |
+
"proj.weight": "proj.linear.weight",
|
585 |
+
"proj.bias": "proj.linear.bias",
|
586 |
+
}
|
587 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
588 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
589 |
+
|
590 |
+
|
591 |
+
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
|
592 |
+
def __init__(self, config: Config) -> None:
|
593 |
+
nn.Module.__init__(self)
|
594 |
+
self.fc = LoRALinear(
|
595 |
+
config.n_embd,
|
596 |
+
config.intermediate_size,
|
597 |
+
bias=config.bias,
|
598 |
+
r=(config.r if config.to_mlp else 0),
|
599 |
+
lora_alpha=config.alpha,
|
600 |
+
lora_dropout=config.dropout,
|
601 |
+
)
|
602 |
+
self.proj = LoRALinear(
|
603 |
+
config.intermediate_size,
|
604 |
+
config.n_embd,
|
605 |
+
bias=config.bias,
|
606 |
+
r=(config.r if config.to_mlp else 0),
|
607 |
+
lora_alpha=config.alpha,
|
608 |
+
lora_dropout=config.dropout,
|
609 |
+
)
|
610 |
+
|
611 |
+
self.config = config
|
612 |
+
|
613 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
614 |
+
"""For compatibility with base checkpoints."""
|
615 |
+
mapping = {
|
616 |
+
"fc.weight": "fc.linear.weight",
|
617 |
+
"fc.bias": "fc.linear.bias",
|
618 |
+
"proj.weight": "proj.linear.weight",
|
619 |
+
"proj.bias": "proj.linear.bias",
|
620 |
+
}
|
621 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
622 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
623 |
+
|
624 |
+
|
625 |
+
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
|
626 |
+
def __init__(self, config: Config) -> None:
|
627 |
+
nn.Module.__init__(self)
|
628 |
+
self.fc_1 = LoRALinear(
|
629 |
+
config.n_embd,
|
630 |
+
config.intermediate_size,
|
631 |
+
bias=config.bias,
|
632 |
+
r=(config.r if config.to_mlp else 0),
|
633 |
+
lora_alpha=config.alpha,
|
634 |
+
lora_dropout=config.dropout,
|
635 |
+
)
|
636 |
+
self.fc_2 = LoRALinear(
|
637 |
+
config.n_embd,
|
638 |
+
config.intermediate_size,
|
639 |
+
bias=config.bias,
|
640 |
+
r=(config.r if config.to_mlp else 0),
|
641 |
+
lora_alpha=config.alpha,
|
642 |
+
lora_dropout=config.dropout,
|
643 |
+
)
|
644 |
+
self.proj = LoRALinear(
|
645 |
+
config.intermediate_size,
|
646 |
+
config.n_embd,
|
647 |
+
bias=config.bias,
|
648 |
+
r=(config.r if config.to_mlp else 0),
|
649 |
+
lora_alpha=config.alpha,
|
650 |
+
lora_dropout=config.dropout,
|
651 |
+
)
|
652 |
+
|
653 |
+
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
|
654 |
+
"""For compatibility with base checkpoints."""
|
655 |
+
mapping = {
|
656 |
+
"fc_1.weight": "fc_1.linear.weight",
|
657 |
+
"fc_1.bias": "fc_1.linear.bias",
|
658 |
+
"fc_2.weight": "fc_2.linear.weight",
|
659 |
+
"fc_2.bias": "fc_2.linear.bias",
|
660 |
+
"proj.weight": "proj.linear.weight",
|
661 |
+
"proj.bias": "proj.linear.bias",
|
662 |
+
}
|
663 |
+
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
|
664 |
+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
665 |
+
|
666 |
+
|
667 |
+
def merge_lora_weights(model: GPT) -> None:
|
668 |
+
"""Merge LoRA weights into the full-rank weights to speed up inference."""
|
669 |
+
for module in model.modules():
|
670 |
+
if isinstance(module, LoRALinear):
|
671 |
+
module.merge()
|
lit_gpt/model.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Full definition of a GPT NeoX Language Model, all of it in this single file.
|
2 |
+
|
3 |
+
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
|
4 |
+
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
from typing import Any, Optional, Tuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from typing_extensions import Self
|
12 |
+
|
13 |
+
from lit_gpt.config import Config
|
14 |
+
|
15 |
+
|
16 |
+
class GPT(nn.Module):
|
17 |
+
def __init__(self, config: Config) -> None:
|
18 |
+
super().__init__()
|
19 |
+
assert config.padded_vocab_size is not None
|
20 |
+
self.config = config
|
21 |
+
|
22 |
+
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
|
23 |
+
self.transformer = nn.ModuleDict(
|
24 |
+
dict(
|
25 |
+
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
26 |
+
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
27 |
+
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
28 |
+
)
|
29 |
+
)
|
30 |
+
self.max_seq_length = self.config.block_size
|
31 |
+
self.mask_cache: Optional[torch.Tensor] = None
|
32 |
+
|
33 |
+
@property
|
34 |
+
def max_seq_length(self) -> int:
|
35 |
+
return self._max_seq_length
|
36 |
+
|
37 |
+
@max_seq_length.setter
|
38 |
+
def max_seq_length(self, value: int) -> None:
|
39 |
+
"""
|
40 |
+
When doing inference, the sequences used might be shorter than the model's context length.
|
41 |
+
This allows setting a smaller number to avoid allocating unused memory
|
42 |
+
"""
|
43 |
+
if value > self.config.block_size:
|
44 |
+
raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
|
45 |
+
self._max_seq_length = value
|
46 |
+
if not hasattr(self, "cos"):
|
47 |
+
# first call
|
48 |
+
cos, sin = self.rope_cache()
|
49 |
+
self.register_buffer("cos", cos, persistent=False)
|
50 |
+
self.register_buffer("sin", sin, persistent=False)
|
51 |
+
elif value != self.cos.size(0):
|
52 |
+
# override
|
53 |
+
self.cos, self.sin = self.rope_cache(device=self.cos.device)
|
54 |
+
# the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
|
55 |
+
# if the kv cache is expected
|
56 |
+
|
57 |
+
def reset_parameters(self) -> None:
|
58 |
+
# Trigger resetting the rope-cache
|
59 |
+
self.max_seq_length = self.config.block_size
|
60 |
+
|
61 |
+
def _init_weights(self, module: nn.Module) -> None:
|
62 |
+
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
|
63 |
+
if isinstance(module, nn.Linear):
|
64 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
65 |
+
if module.bias is not None:
|
66 |
+
torch.nn.init.zeros_(module.bias)
|
67 |
+
elif isinstance(module, nn.Embedding):
|
68 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
69 |
+
|
70 |
+
def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, maxlen: int = None) -> torch.Tensor:
|
71 |
+
T = idx.size(1) if maxlen is None else maxlen
|
72 |
+
# print(T, end=', ')
|
73 |
+
if self.max_seq_length < T:
|
74 |
+
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
|
75 |
+
|
76 |
+
# import pdb; pdb.set_trace()
|
77 |
+
if input_pos is not None: # use the kv cache
|
78 |
+
cos = self.cos.index_select(0, input_pos)
|
79 |
+
sin = self.sin.index_select(0, input_pos)
|
80 |
+
if self.mask_cache is None:
|
81 |
+
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
82 |
+
mask = self.mask_cache.index_select(2, input_pos)
|
83 |
+
else:
|
84 |
+
cos = self.cos[:T]
|
85 |
+
sin = self.sin[:T]
|
86 |
+
mask = None
|
87 |
+
|
88 |
+
if type(idx) is tuple:
|
89 |
+
stack_before_tokens_x, motion_tokens, before_len = idx
|
90 |
+
# stack_before_tokens_x = stack_before_tokens_x.unsqueeze(0)
|
91 |
+
# motion_tokens = motion_tokens.unsqueeze(0)
|
92 |
+
# stack_before_tokens_x[0][before_len[0]: before_len[0] + len(motion_tokens[0])] = 1
|
93 |
+
x = self.transformer.wte(stack_before_tokens_x.cuda())
|
94 |
+
# import pdb; pdb.set_trace()
|
95 |
+
for i in range(len(x)):
|
96 |
+
x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i].cuda()
|
97 |
+
else:
|
98 |
+
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
99 |
+
for block in self.transformer.h:
|
100 |
+
x = block(x, cos, sin, mask, input_pos)
|
101 |
+
x = self.transformer.ln_f(x)
|
102 |
+
return self.lm_head(x) # (b, t, vocab_size)
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
106 |
+
return cls(Config.from_name(name, **kwargs))
|
107 |
+
|
108 |
+
def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
109 |
+
return build_rope_cache(
|
110 |
+
seq_len=self.max_seq_length,
|
111 |
+
n_elem=self.config.rope_n_elem,
|
112 |
+
device=device,
|
113 |
+
condense_ratio=self.config.rope_condense_ratio,
|
114 |
+
base=self.config.rope_base,
|
115 |
+
)
|
116 |
+
|
117 |
+
def set_kv_cache(
|
118 |
+
self,
|
119 |
+
batch_size: int,
|
120 |
+
rope_cache_length: Optional[int] = None,
|
121 |
+
device: Optional[torch.device] = None,
|
122 |
+
dtype: Optional[torch.dtype] = None,
|
123 |
+
) -> None:
|
124 |
+
if rope_cache_length is None:
|
125 |
+
rope_cache_length = self.cos.size(-1)
|
126 |
+
max_seq_length = self.max_seq_length
|
127 |
+
|
128 |
+
# initialize the kv cache for all blocks
|
129 |
+
for block in self.transformer.h:
|
130 |
+
block.attn.kv_cache = block.attn.build_kv_cache(
|
131 |
+
batch_size, max_seq_length, rope_cache_length, device, dtype
|
132 |
+
)
|
133 |
+
|
134 |
+
if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
|
135 |
+
# passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
|
136 |
+
# for the kv-cache support (only during inference), we only create it in that situation
|
137 |
+
# this will be resolved by https://github.com/pytorch/pytorch/issues/96099
|
138 |
+
ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
|
139 |
+
self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
|
140 |
+
|
141 |
+
def clear_kv_cache(self) -> None:
|
142 |
+
self.mask_cache = None
|
143 |
+
for block in self.transformer.h:
|
144 |
+
block.attn.kv_cache = None
|
145 |
+
|
146 |
+
|
147 |
+
class Block(nn.Module):
|
148 |
+
def __init__(self, config: Config) -> None:
|
149 |
+
super().__init__()
|
150 |
+
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
151 |
+
self.attn = CausalSelfAttention(config)
|
152 |
+
self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
|
153 |
+
self.mlp = config.mlp_class(config)
|
154 |
+
|
155 |
+
self.config = config
|
156 |
+
|
157 |
+
def forward(
|
158 |
+
self,
|
159 |
+
x: torch.Tensor,
|
160 |
+
cos: torch.Tensor,
|
161 |
+
sin: torch.Tensor,
|
162 |
+
mask: Optional[torch.Tensor] = None,
|
163 |
+
input_pos: Optional[torch.Tensor] = None,
|
164 |
+
) -> torch.Tensor:
|
165 |
+
n_1 = self.norm_1(x)
|
166 |
+
h = self.attn(n_1, cos, sin, mask, input_pos)
|
167 |
+
if self.config.parallel_residual:
|
168 |
+
n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
|
169 |
+
x = self.mlp(n_2) + h + x
|
170 |
+
else:
|
171 |
+
if self.config.shared_attention_norm:
|
172 |
+
raise NotImplementedError(
|
173 |
+
"No checkpoint amongst the ones we support uses this configuration"
|
174 |
+
" (non-parallel residual and shared attention norm)."
|
175 |
+
)
|
176 |
+
x = h + x
|
177 |
+
x = self.mlp(self.norm_2(x)) + x
|
178 |
+
return x
|
179 |
+
|
180 |
+
|
181 |
+
class CausalSelfAttention(nn.Module):
|
182 |
+
def __init__(self, config: Config) -> None:
|
183 |
+
super().__init__()
|
184 |
+
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
|
185 |
+
# key, query, value projections for all heads, but in a batch
|
186 |
+
self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
|
187 |
+
# output projection
|
188 |
+
self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
189 |
+
# disabled by default
|
190 |
+
self.kv_cache: Optional[KVCache] = None
|
191 |
+
|
192 |
+
self.config = config
|
193 |
+
|
194 |
+
def forward(
|
195 |
+
self,
|
196 |
+
x: torch.Tensor,
|
197 |
+
cos: torch.Tensor,
|
198 |
+
sin: torch.Tensor,
|
199 |
+
mask: Optional[torch.Tensor] = None,
|
200 |
+
input_pos: Optional[torch.Tensor] = None,
|
201 |
+
) -> torch.Tensor:
|
202 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
203 |
+
|
204 |
+
qkv = self.attn(x)
|
205 |
+
|
206 |
+
# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
|
207 |
+
q_per_kv = self.config.n_head // self.config.n_query_groups
|
208 |
+
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
|
209 |
+
qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
|
210 |
+
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
|
211 |
+
|
212 |
+
# split batched computation into three
|
213 |
+
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
|
214 |
+
|
215 |
+
# maybe repeat k and v if for the non multi-head attention cases
|
216 |
+
# training: flash attention requires it
|
217 |
+
# inference: multi-query would require a full kv cache so avoid it to limit its memory usage
|
218 |
+
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
|
219 |
+
k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
|
220 |
+
v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
|
221 |
+
|
222 |
+
q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
|
223 |
+
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
|
224 |
+
v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
|
225 |
+
|
226 |
+
q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
|
227 |
+
k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
|
228 |
+
q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
|
229 |
+
k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
|
230 |
+
|
231 |
+
if input_pos is not None:
|
232 |
+
if not isinstance(self.kv_cache, KVCache):
|
233 |
+
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
234 |
+
k, v = self.kv_cache(input_pos, k, v)
|
235 |
+
|
236 |
+
y = self.scaled_dot_product_attention(q, k, v, mask)
|
237 |
+
|
238 |
+
y = y.reshape(B, T, C) # re-assemble all head outputs side by side
|
239 |
+
|
240 |
+
# output projection
|
241 |
+
return self.proj(y)
|
242 |
+
|
243 |
+
def scaled_dot_product_attention(
|
244 |
+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
|
245 |
+
) -> torch.Tensor:
|
246 |
+
scale = 1.0 / math.sqrt(self.config.head_size)
|
247 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
248 |
+
q, k, v, attn_mask=mask, dropout_p=0.0,
|
249 |
+
# scale=scale,
|
250 |
+
is_causal=mask is None
|
251 |
+
)
|
252 |
+
return y.transpose(1, 2)
|
253 |
+
|
254 |
+
def build_kv_cache(
|
255 |
+
self,
|
256 |
+
batch_size: int,
|
257 |
+
max_seq_length: int,
|
258 |
+
rope_cache_length: Optional[int] = None,
|
259 |
+
device: Optional[torch.device] = None,
|
260 |
+
dtype: Optional[torch.dtype] = None,
|
261 |
+
) -> "KVCache":
|
262 |
+
heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
|
263 |
+
v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
|
264 |
+
if rope_cache_length is None:
|
265 |
+
if self.config.rotary_percentage != 1.0:
|
266 |
+
raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
|
267 |
+
k_shape = v_shape
|
268 |
+
else:
|
269 |
+
k_shape = (
|
270 |
+
batch_size,
|
271 |
+
heads,
|
272 |
+
max_seq_length,
|
273 |
+
rope_cache_length + self.config.head_size - self.config.rope_n_elem,
|
274 |
+
)
|
275 |
+
return KVCache(k_shape, v_shape, device=device, dtype=dtype)
|
276 |
+
|
277 |
+
|
278 |
+
class GptNeoxMLP(nn.Module):
|
279 |
+
def __init__(self, config: Config) -> None:
|
280 |
+
super().__init__()
|
281 |
+
self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
282 |
+
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
283 |
+
|
284 |
+
self.config = config
|
285 |
+
|
286 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
287 |
+
x = self.fc(x)
|
288 |
+
x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
|
289 |
+
return self.proj(x)
|
290 |
+
|
291 |
+
|
292 |
+
class LLaMAMLP(nn.Module):
|
293 |
+
def __init__(self, config: Config) -> None:
|
294 |
+
super().__init__()
|
295 |
+
self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
296 |
+
self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
297 |
+
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
298 |
+
|
299 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
300 |
+
x_fc_1 = self.fc_1(x)
|
301 |
+
x_fc_2 = self.fc_2(x)
|
302 |
+
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
303 |
+
return self.proj(x)
|
304 |
+
|
305 |
+
|
306 |
+
def build_rope_cache(
|
307 |
+
seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
|
308 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
309 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
310 |
+
|
311 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
312 |
+
transformers/rope/__init__.py. MIT License:
|
313 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
314 |
+
"""
|
315 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
316 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
|
317 |
+
|
318 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
319 |
+
seq_idx = torch.arange(seq_len, device=device) / condense_ratio
|
320 |
+
|
321 |
+
# Calculate the product of position index and $\theta_i$
|
322 |
+
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
|
323 |
+
|
324 |
+
return torch.cos(idx_theta), torch.sin(idx_theta)
|
325 |
+
|
326 |
+
|
327 |
+
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
328 |
+
head_size = x.size(-1)
|
329 |
+
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
|
330 |
+
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
|
331 |
+
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
|
332 |
+
roped = (x * cos) + (rotated * sin)
|
333 |
+
return roped.type_as(x)
|
334 |
+
|
335 |
+
|
336 |
+
class KVCache(nn.Module):
|
337 |
+
def __init__(
|
338 |
+
self,
|
339 |
+
k_shape: Tuple[int, int, int, int],
|
340 |
+
v_shape: Tuple[int, int, int, int],
|
341 |
+
device: Optional[torch.device] = None,
|
342 |
+
dtype: Optional[torch.dtype] = None,
|
343 |
+
) -> None:
|
344 |
+
super().__init__()
|
345 |
+
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
|
346 |
+
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
|
347 |
+
|
348 |
+
def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
349 |
+
# move the buffer to the activation dtype for when AMP is used
|
350 |
+
self.k = self.k.to(k.dtype)
|
351 |
+
self.v = self.v.to(v.dtype)
|
352 |
+
# update the cache
|
353 |
+
k = self.k.index_copy_(2, input_pos, k)
|
354 |
+
v = self.v.index_copy_(2, input_pos, v)
|
355 |
+
return k, v
|
lit_gpt/packed_dataset.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
|
2 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
|
3 |
+
|
4 |
+
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import struct
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import IterableDataset, get_worker_info
|
12 |
+
|
13 |
+
dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}
|
14 |
+
|
15 |
+
|
16 |
+
def code(dtype):
|
17 |
+
for k in dtypes:
|
18 |
+
if dtypes[k] == dtype:
|
19 |
+
return k
|
20 |
+
raise ValueError(dtype)
|
21 |
+
|
22 |
+
|
23 |
+
HDR_MAGIC = b"LITPKDS"
|
24 |
+
HDR_SIZE = 24 # bytes
|
25 |
+
|
26 |
+
|
27 |
+
class PackedDataset(IterableDataset):
|
28 |
+
def __init__(
|
29 |
+
self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
|
30 |
+
):
|
31 |
+
self._filenames = filenames
|
32 |
+
self._n_chunks = n_chunks
|
33 |
+
self._block_size = block_size
|
34 |
+
self._seed = seed
|
35 |
+
self._shuffle = shuffle
|
36 |
+
self._wrap = wrap
|
37 |
+
self._num_processes = num_processes
|
38 |
+
self._process_rank = process_rank
|
39 |
+
|
40 |
+
def __iter__(self):
|
41 |
+
worker_info = get_worker_info()
|
42 |
+
num_workers = worker_info.num_workers if worker_info is not None else 1
|
43 |
+
worker_id = worker_info.id if worker_info is not None else 0
|
44 |
+
num_shards = num_workers * self._num_processes
|
45 |
+
shard_id = self._process_rank * num_workers + worker_id
|
46 |
+
|
47 |
+
max_num_files = len(self._filenames) // num_shards * num_shards
|
48 |
+
filenames = self._filenames[shard_id:max_num_files:num_shards]
|
49 |
+
|
50 |
+
return PackedDatasetIterator(
|
51 |
+
filenames=filenames,
|
52 |
+
n_chunks=self._n_chunks,
|
53 |
+
block_size=self._block_size,
|
54 |
+
seed=self._seed,
|
55 |
+
shuffle=self._shuffle,
|
56 |
+
wrap=self._wrap,
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
class PackedDatasetBuilder(object):
|
61 |
+
def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
|
62 |
+
if dtype == "auto":
|
63 |
+
if vocab_size is None:
|
64 |
+
raise ValueError("vocab_size cannot be None when dtype='auto'")
|
65 |
+
if vocab_size is not None and vocab_size < 65500:
|
66 |
+
self._dtype = np.uint16
|
67 |
+
else:
|
68 |
+
self._dtype = np.int32
|
69 |
+
else:
|
70 |
+
self._dtype = dtype
|
71 |
+
self._counter = 0
|
72 |
+
self._chunk_size = chunk_size
|
73 |
+
self._outdir = outdir
|
74 |
+
self._prefix = prefix
|
75 |
+
self._sep_token = sep_token
|
76 |
+
self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
|
77 |
+
self._arr.fill(self._sep_token)
|
78 |
+
self._idx = 0
|
79 |
+
self._version = 1
|
80 |
+
self._filenames = []
|
81 |
+
|
82 |
+
def _write_chunk(self):
|
83 |
+
filename = f"{self._prefix}_{self._counter:010d}.bin"
|
84 |
+
filename = os.path.join(self._outdir, filename)
|
85 |
+
|
86 |
+
with open(filename, "wb") as f:
|
87 |
+
f.write(HDR_MAGIC)
|
88 |
+
f.write(struct.pack("<Q", self._version))
|
89 |
+
f.write(struct.pack("<B", code(self._dtype)))
|
90 |
+
f.write(struct.pack("<Q", self._chunk_size))
|
91 |
+
f.write(self._arr.tobytes(order="C"))
|
92 |
+
|
93 |
+
self._filenames.append(filename)
|
94 |
+
self._counter += 1
|
95 |
+
self._arr.fill(self._sep_token)
|
96 |
+
self._idx = 0
|
97 |
+
|
98 |
+
@property
|
99 |
+
def dtype(self):
|
100 |
+
return self._dtype
|
101 |
+
|
102 |
+
@property
|
103 |
+
def filenames(self):
|
104 |
+
return self._filenames.copy()
|
105 |
+
|
106 |
+
def add_array(self, arr):
|
107 |
+
while self._idx + arr.shape[0] > self._chunk_size:
|
108 |
+
part_len = self._chunk_size - self._idx
|
109 |
+
self._arr[self._idx : self._idx + part_len] = arr[:part_len]
|
110 |
+
self._write_chunk()
|
111 |
+
arr = arr[part_len:]
|
112 |
+
|
113 |
+
arr_len = arr.shape[0]
|
114 |
+
self._arr[self._idx : self._idx + arr_len] = arr
|
115 |
+
self._idx += arr_len
|
116 |
+
|
117 |
+
def write_reminder(self):
|
118 |
+
self._write_chunk()
|
119 |
+
|
120 |
+
|
121 |
+
class PackedDatasetIterator:
|
122 |
+
def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
|
123 |
+
self._seed = seed
|
124 |
+
self._shuffle = shuffle
|
125 |
+
self._rng = np.random.default_rng(seed) if shuffle else None
|
126 |
+
self._block_idxs = None
|
127 |
+
|
128 |
+
self._wrap = wrap
|
129 |
+
|
130 |
+
# TODO: instead of filenames, we could have a single text stream
|
131 |
+
# (or text file) with the sequence of all files to be
|
132 |
+
# fetched/loaded.
|
133 |
+
self._filenames = filenames
|
134 |
+
self._file_idx = 0
|
135 |
+
|
136 |
+
self._n_chunks = n_chunks
|
137 |
+
|
138 |
+
self._dtype = None
|
139 |
+
self._block_size = block_size
|
140 |
+
self._n_blocks = None
|
141 |
+
|
142 |
+
self._mmaps = []
|
143 |
+
self._buffers = []
|
144 |
+
|
145 |
+
self._block_idxs = []
|
146 |
+
self._curr_idx = 0
|
147 |
+
|
148 |
+
self._load_n_chunks()
|
149 |
+
|
150 |
+
def _read_header(self, path):
|
151 |
+
with open(path, "rb") as f:
|
152 |
+
magic = f.read(len(HDR_MAGIC))
|
153 |
+
assert magic == HDR_MAGIC, "File doesn't match expected format."
|
154 |
+
version = struct.unpack("<Q", f.read(8))
|
155 |
+
assert version == (1,)
|
156 |
+
(dtype_code,) = struct.unpack("<B", f.read(1))
|
157 |
+
dtype = dtypes[dtype_code]
|
158 |
+
(chunk_size,) = struct.unpack("<Q", f.read(8))
|
159 |
+
return dtype, chunk_size
|
160 |
+
|
161 |
+
def _close_mmaps(self):
|
162 |
+
for mmap in self._mmaps:
|
163 |
+
mmap._mmap.close()
|
164 |
+
|
165 |
+
def _load_n_chunks(self):
|
166 |
+
self._close_mmaps()
|
167 |
+
self._mmaps = []
|
168 |
+
self._buffers = []
|
169 |
+
|
170 |
+
if self._n_chunks > len(self._filenames[self._file_idx :]):
|
171 |
+
if not self._wrap:
|
172 |
+
raise StopIteration
|
173 |
+
self._file_idx = 0
|
174 |
+
|
175 |
+
for i in range(self._n_chunks):
|
176 |
+
filename = self._filenames[self._file_idx + i]
|
177 |
+
if self._dtype is None:
|
178 |
+
self._dtype, self._chunk_size = self._read_header(filename)
|
179 |
+
self._n_blocks = self._chunk_size // self._block_size
|
180 |
+
# TODO: check header matches with previous files
|
181 |
+
mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
|
182 |
+
self._mmaps.append(mmap)
|
183 |
+
self._buffers.append(memoryview(mmap))
|
184 |
+
|
185 |
+
self._file_idx += self._n_chunks
|
186 |
+
n_all_blocks = self._n_chunks * self._n_blocks
|
187 |
+
|
188 |
+
self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
|
189 |
+
|
190 |
+
self._curr_idx = 0
|
191 |
+
|
192 |
+
def __del__(self):
|
193 |
+
self._close_mmaps()
|
194 |
+
del self._mmaps
|
195 |
+
del self._buffers
|
196 |
+
|
197 |
+
def __iter__(self):
|
198 |
+
return self
|
199 |
+
|
200 |
+
def __next__(self):
|
201 |
+
if self._curr_idx >= len(self._block_idxs):
|
202 |
+
self._load_n_chunks()
|
203 |
+
# TODO: trigger fetching next next n_chunks if remote
|
204 |
+
block_idx = self._block_idxs[self._curr_idx]
|
205 |
+
chunk_id = block_idx // self._n_blocks
|
206 |
+
buffer = self._buffers[chunk_id]
|
207 |
+
elem_id = (block_idx % self._n_blocks) * self._block_size
|
208 |
+
offset = np.dtype(self._dtype).itemsize * elem_id
|
209 |
+
arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
|
210 |
+
self._curr_idx += 1
|
211 |
+
return torch.from_numpy(arr.astype(np.int64))
|
212 |
+
|
213 |
+
|
214 |
+
class CombinedDataset(IterableDataset):
|
215 |
+
def __init__(self, datasets, seed, weights=None):
|
216 |
+
self._seed = seed
|
217 |
+
self._datasets = datasets
|
218 |
+
self._weights = weights
|
219 |
+
n_datasets = len(datasets)
|
220 |
+
if weights is None:
|
221 |
+
self._weights = [1 / n_datasets] * n_datasets
|
222 |
+
|
223 |
+
def __iter__(self):
|
224 |
+
return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
|
225 |
+
|
226 |
+
|
227 |
+
class CombinedDatasetIterator:
|
228 |
+
def __init__(self, datasets, seed, weights):
|
229 |
+
self._datasets = [iter(el) for el in datasets]
|
230 |
+
self._weights = weights
|
231 |
+
self._rng = random.Random(seed)
|
232 |
+
|
233 |
+
def __next__(self):
|
234 |
+
(dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
|
235 |
+
return next(dataset)
|
lit_gpt/rmsnorm.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class RMSNorm(torch.nn.Module):
|
5 |
+
"""Root Mean Square Layer Normalization.
|
6 |
+
|
7 |
+
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
8 |
+
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
|
12 |
+
super().__init__()
|
13 |
+
self.weight = torch.nn.Parameter(torch.ones(size))
|
14 |
+
self.eps = eps
|
15 |
+
self.dim = dim
|
16 |
+
|
17 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
18 |
+
dtype = x.dtype
|
19 |
+
x = x.float()
|
20 |
+
# NOTE: the original RMSNorm paper implementation is not equivalent
|
21 |
+
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
22 |
+
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
23 |
+
return (self.weight * x_normed).to(dtype=dtype)
|
24 |
+
|
25 |
+
def reset_parameters(self) -> None:
|
26 |
+
torch.nn.init.ones_(self.weight)
|
lit_gpt/speed_monitor.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from collections import deque
|
3 |
+
from contextlib import nullcontext
|
4 |
+
from typing import Any, Callable, Deque, Dict, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from lightning import Callback, Fabric, LightningModule, Trainer
|
8 |
+
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
|
9 |
+
from lightning.fabric.plugins import (
|
10 |
+
BitsandbytesPrecision,
|
11 |
+
DoublePrecision,
|
12 |
+
FSDPPrecision,
|
13 |
+
HalfPrecision,
|
14 |
+
MixedPrecision,
|
15 |
+
Precision,
|
16 |
+
TransformerEnginePrecision,
|
17 |
+
XLAPrecision,
|
18 |
+
)
|
19 |
+
from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
|
20 |
+
from lightning.pytorch.plugins import (
|
21 |
+
DoublePrecisionPlugin,
|
22 |
+
FSDPPrecisionPlugin,
|
23 |
+
HalfPrecisionPlugin,
|
24 |
+
MixedPrecisionPlugin,
|
25 |
+
XLAPrecisionPlugin,
|
26 |
+
)
|
27 |
+
from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
|
28 |
+
from torch.utils.flop_counter import FlopCounterMode
|
29 |
+
|
30 |
+
from lit_gpt import GPT
|
31 |
+
from lit_gpt.utils import num_parameters
|
32 |
+
|
33 |
+
GPU_AVAILABLE_FLOPS = {
|
34 |
+
# source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
|
35 |
+
# nvidia publishes spec sheet with a 2x sparsity factor
|
36 |
+
"h100-sxm": {
|
37 |
+
torch.float64: 67e12,
|
38 |
+
torch.float32: 67e12,
|
39 |
+
torch.bfloat16: 1.979e15 / 2,
|
40 |
+
torch.float16: 1.979e15 / 2,
|
41 |
+
torch.int8: 3.958e15 / 2,
|
42 |
+
},
|
43 |
+
"h100-pcie": {
|
44 |
+
torch.float64: 51e12,
|
45 |
+
torch.float32: 51e12,
|
46 |
+
torch.bfloat16: 1.513e15 / 2,
|
47 |
+
torch.float16: 1.513e15 / 2,
|
48 |
+
torch.int8: 3.026e15 / 2,
|
49 |
+
},
|
50 |
+
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
|
51 |
+
# sxm and pcie have same flop counts
|
52 |
+
"a100": {torch.float64: 19.5e12, torch.float32: 19.5e12, torch.bfloat16: 312e12, torch.float16: 312e12},
|
53 |
+
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
|
54 |
+
"a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12},
|
55 |
+
# source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
|
56 |
+
"v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12},
|
57 |
+
"v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12},
|
58 |
+
"v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12},
|
59 |
+
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
|
60 |
+
# sxm and pcie have same flop counts
|
61 |
+
"t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12},
|
62 |
+
# https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
|
63 |
+
"quadro rtx 5000": {torch.float32: 11.2e12, torch.float16: 89.2e12},
|
64 |
+
}
|
65 |
+
|
66 |
+
TPU_AVAILABLE_FLOPS = {
|
67 |
+
# flop count for each TPU generation is the same for all precisions
|
68 |
+
# since bfloat16 precision is always used for performing matrix operations
|
69 |
+
# for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
|
70 |
+
# source: https://arxiv.org/pdf/1907.10701.pdf
|
71 |
+
"v2": 45e12,
|
72 |
+
# source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
|
73 |
+
"v3": 123e12,
|
74 |
+
# source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
|
75 |
+
"v4": 275e12,
|
76 |
+
# source: https://cloud.google.com/tpu/docs/v5e-training
|
77 |
+
"v5litepod": 197e12,
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]:
|
82 |
+
if device.type == "cuda":
|
83 |
+
device_name = torch.cuda.get_device_name(device).lower()
|
84 |
+
if "h100" in device_name and "hbm3" in device_name:
|
85 |
+
device_name = "h100-sxm"
|
86 |
+
elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
|
87 |
+
device_name = "h100-pcie"
|
88 |
+
elif "a100" in device_name:
|
89 |
+
device_name = "a100"
|
90 |
+
elif "a10g" in device_name:
|
91 |
+
device_name = "a10g"
|
92 |
+
elif "v100-sxm" in device_name:
|
93 |
+
device_name = "v100-sxm"
|
94 |
+
elif "v100-pcie" in device_name:
|
95 |
+
device_name = "v100-pcie"
|
96 |
+
elif "t4" in device_name:
|
97 |
+
device_name = "t4"
|
98 |
+
elif "quadro rtx 5000" in device_name:
|
99 |
+
device_name = "quadro rtx 5000"
|
100 |
+
else:
|
101 |
+
device_name = None
|
102 |
+
|
103 |
+
if device_name is not None:
|
104 |
+
try:
|
105 |
+
return int(GPU_AVAILABLE_FLOPS[device_name][dtype])
|
106 |
+
except KeyError:
|
107 |
+
raise KeyError(
|
108 |
+
f"flop count not found for {device_name} with dtype: {dtype}; "
|
109 |
+
"MFU cannot be calculated and reported."
|
110 |
+
)
|
111 |
+
elif device.type == "xla":
|
112 |
+
if _XLA_GREATER_EQUAL_2_1:
|
113 |
+
from torch_xla._internal import tpu
|
114 |
+
else:
|
115 |
+
from torch_xla.experimental import tpu
|
116 |
+
|
117 |
+
device_name = tpu.get_tpu_env()["TYPE"].lower()
|
118 |
+
try:
|
119 |
+
return int(TPU_AVAILABLE_FLOPS[device_name])
|
120 |
+
except KeyError:
|
121 |
+
raise KeyError(
|
122 |
+
f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported."
|
123 |
+
)
|
124 |
+
|
125 |
+
return None
|
126 |
+
|
127 |
+
|
128 |
+
# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py
|
129 |
+
|
130 |
+
|
131 |
+
class SpeedMonitorBase:
|
132 |
+
"""Logs the training throughput and utilization.
|
133 |
+
|
134 |
+
+-------------------------------------+-----------------------------------------------------------+
|
135 |
+
| Key | Logged data |
|
136 |
+
+=====================================+===========================================================+
|
137 |
+
| | Rolling average (over `window_size` most recent |
|
138 |
+
| `throughput/batches_per_sec` | batches) of the number of batches processed per second |
|
139 |
+
| | |
|
140 |
+
+-------------------------------------+-----------------------------------------------------------+
|
141 |
+
| | Rolling average (over `window_size` most recent |
|
142 |
+
| `throughput/samples_per_sec` | batches) of the number of samples processed per second |
|
143 |
+
| | |
|
144 |
+
+-------------------------------------+-----------------------------------------------------------+
|
145 |
+
| | Rolling average (over `window_size` most recent |
|
146 |
+
| `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
|
147 |
+
| | This may include padding depending on dataset |
|
148 |
+
+-------------------------------------+-----------------------------------------------------------+
|
149 |
+
| | Estimates flops by `flops_per_batch * batches_per_sec` |
|
150 |
+
| `throughput/flops_per_sec` | |
|
151 |
+
| | |
|
152 |
+
+-------------------------------------+-----------------------------------------------------------+
|
153 |
+
| `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
|
154 |
+
+-------------------------------------+-----------------------------------------------------------+
|
155 |
+
| `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
|
156 |
+
+-------------------------------------+-----------------------------------------------------------+
|
157 |
+
| | `throughput/tokens_per_sec` divided by world size. This |
|
158 |
+
| `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset |
|
159 |
+
| | |
|
160 |
+
+-------------------------------------+-----------------------------------------------------------+
|
161 |
+
| | `throughput/flops_per_sec` divided by world size. Only |
|
162 |
+
| `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
|
163 |
+
| | |
|
164 |
+
+-------------------------------------+-----------------------------------------------------------+
|
165 |
+
| | `throughput/device/flops_per_sec` divided by world size. |
|
166 |
+
| `throughput/device/mfu` | |
|
167 |
+
| | |
|
168 |
+
+-------------------------------------+-----------------------------------------------------------+
|
169 |
+
| `time/train` | Total elapsed training time |
|
170 |
+
+-------------------------------------+-----------------------------------------------------------+
|
171 |
+
| `time/val` | Total elapsed validation time |
|
172 |
+
+-------------------------------------+-----------------------------------------------------------+
|
173 |
+
| `time/total` | Total elapsed time (time/train + time/val) |
|
174 |
+
+-------------------------------------+-----------------------------------------------------------+
|
175 |
+
|
176 |
+
Notes:
|
177 |
+
- The implementation assumes that devices are homogeneous as it normalizes by the world size.
|
178 |
+
- Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
|
179 |
+
batches/sec to measure throughput under this circumstance.
|
180 |
+
- Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
|
181 |
+
There is no widespread, realistic, and reliable implementation to compute them.
|
182 |
+
We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which
|
183 |
+
will almost always be an overestimate when compared to the true value.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
window_size (int, optional): Number of batches to use for a rolling average of throughput.
|
187 |
+
Defaults to 100.
|
188 |
+
time_unit (str, optional): Time unit to use for `time` logging. Can be one of
|
189 |
+
'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
|
190 |
+
"""
|
191 |
+
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
flops_available: float,
|
195 |
+
log_dict: Callable[[Dict, int], None],
|
196 |
+
window_size: int = 100,
|
197 |
+
time_unit: str = "hours",
|
198 |
+
):
|
199 |
+
self.flops_available = flops_available
|
200 |
+
self.log_dict = log_dict
|
201 |
+
|
202 |
+
# Track the batch num samples and wct to compute throughput over a window of batches
|
203 |
+
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
|
204 |
+
self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
|
205 |
+
self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
|
206 |
+
self.history_flops: Deque[int] = deque(maxlen=window_size + 1)
|
207 |
+
|
208 |
+
self.divider = 1
|
209 |
+
if time_unit == "seconds":
|
210 |
+
self.divider = 1
|
211 |
+
elif time_unit == "minutes":
|
212 |
+
self.divider = 60
|
213 |
+
elif time_unit == "hours":
|
214 |
+
self.divider = 60 * 60
|
215 |
+
elif time_unit == "days":
|
216 |
+
self.divider = 60 * 60 * 24
|
217 |
+
else:
|
218 |
+
raise ValueError(
|
219 |
+
f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
|
220 |
+
)
|
221 |
+
|
222 |
+
# Keep track of time spent evaluating
|
223 |
+
self.total_eval_wct = 0.0
|
224 |
+
self.step = -1
|
225 |
+
|
226 |
+
def on_train_batch_end(
|
227 |
+
self,
|
228 |
+
samples: int, # total samples seen (per device)
|
229 |
+
train_elapsed: float, # total training time (seconds)
|
230 |
+
world_size: int,
|
231 |
+
flops_per_batch: Optional[int] = None, # (per device)
|
232 |
+
lengths: Optional[int] = None, # total length of the samples seen (per device)
|
233 |
+
) -> None:
|
234 |
+
self.step += 1
|
235 |
+
step = self.step
|
236 |
+
metrics = {}
|
237 |
+
|
238 |
+
self.history_samples.append(samples)
|
239 |
+
if lengths is not None:
|
240 |
+
self.history_lengths.append(lengths)
|
241 |
+
# if lengths are passed, there should be as many values as samples
|
242 |
+
assert len(self.history_samples) == len(self.history_lengths)
|
243 |
+
self.history_wct.append(train_elapsed)
|
244 |
+
if len(self.history_wct) == self.history_wct.maxlen:
|
245 |
+
elapsed_batches = len(self.history_samples) - 1
|
246 |
+
elapsed_samples = self.history_samples[-1] - self.history_samples[0]
|
247 |
+
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
|
248 |
+
samples_per_sec = elapsed_samples * world_size / elapsed_wct
|
249 |
+
dev_samples_per_sec = elapsed_samples / elapsed_wct
|
250 |
+
metrics.update(
|
251 |
+
{
|
252 |
+
"throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
|
253 |
+
"throughput/samples_per_sec": samples_per_sec,
|
254 |
+
"throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
|
255 |
+
"throughput/device/samples_per_sec": dev_samples_per_sec,
|
256 |
+
}
|
257 |
+
)
|
258 |
+
if lengths is not None:
|
259 |
+
elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
|
260 |
+
avg_length = elapsed_lengths / elapsed_batches
|
261 |
+
metrics.update(
|
262 |
+
{
|
263 |
+
"throughput/tokens_per_sec": samples_per_sec * avg_length,
|
264 |
+
"throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
|
265 |
+
}
|
266 |
+
)
|
267 |
+
|
268 |
+
if flops_per_batch is not None:
|
269 |
+
# sum of flops per batch across ranks
|
270 |
+
self.history_flops.append(flops_per_batch * world_size)
|
271 |
+
if len(self.history_flops) == self.history_flops.maxlen:
|
272 |
+
elapsed_flops = sum(self.history_flops) - self.history_flops[0]
|
273 |
+
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
|
274 |
+
flops_per_sec = elapsed_flops / elapsed_wct
|
275 |
+
device_flops_per_sec = flops_per_sec / world_size
|
276 |
+
metrics.update(
|
277 |
+
{"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
|
278 |
+
)
|
279 |
+
if self.flops_available:
|
280 |
+
metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available
|
281 |
+
|
282 |
+
metrics.update(
|
283 |
+
{
|
284 |
+
"time/train": train_elapsed / self.divider,
|
285 |
+
"time/val": self.total_eval_wct / self.divider,
|
286 |
+
"time/total": (train_elapsed + self.total_eval_wct) / self.divider,
|
287 |
+
"samples": samples,
|
288 |
+
}
|
289 |
+
)
|
290 |
+
|
291 |
+
self.log_dict(metrics, step)
|
292 |
+
|
293 |
+
def eval_end(self, eval_elapsed: float) -> None:
|
294 |
+
self.total_eval_wct += eval_elapsed # seconds
|
295 |
+
|
296 |
+
|
297 |
+
def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
|
298 |
+
if isinstance(plugin, BitsandbytesPrecision):
|
299 |
+
return plugin.dtype
|
300 |
+
if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)):
|
301 |
+
return plugin._desired_input_dtype
|
302 |
+
if isinstance(plugin, MixedPrecisionPlugin):
|
303 |
+
return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
|
304 |
+
if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)):
|
305 |
+
return torch.double
|
306 |
+
if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)):
|
307 |
+
return plugin._desired_dtype
|
308 |
+
if isinstance(plugin, TransformerEnginePrecision):
|
309 |
+
return torch.int8
|
310 |
+
if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)):
|
311 |
+
return plugin.mixed_precision_config.reduce_dtype
|
312 |
+
if isinstance(plugin, Precision):
|
313 |
+
return torch.float32
|
314 |
+
raise NotImplementedError(plugin)
|
315 |
+
|
316 |
+
|
317 |
+
class SpeedMonitorFabric(SpeedMonitorBase):
|
318 |
+
def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
|
319 |
+
dtype = plugin_to_compute_dtype(fabric.strategy.precision)
|
320 |
+
flops_available = get_flops_available(fabric.device, dtype)
|
321 |
+
super().__init__(flops_available, fabric.log_dict, *args, **kwargs)
|
322 |
+
|
323 |
+
@fabric_rank_zero_only
|
324 |
+
def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
|
325 |
+
super().on_train_batch_end(*args, **kwargs)
|
326 |
+
|
327 |
+
|
328 |
+
class SpeedMonitorCallback(Callback):
|
329 |
+
def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
|
330 |
+
super().__init__()
|
331 |
+
self.speed_monitor: Optional[SpeedMonitorBase] = None
|
332 |
+
self.speed_monitor_kwargs = kwargs
|
333 |
+
self.length_fn = length_fn
|
334 |
+
self.batch_size = batch_size
|
335 |
+
self.eval_t0: int = 0
|
336 |
+
self.train_t0: int = 0
|
337 |
+
self.total_lengths: int = 0
|
338 |
+
|
339 |
+
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
|
340 |
+
if self.speed_monitor is not None:
|
341 |
+
return # already setup
|
342 |
+
dtype = plugin_to_compute_dtype(trainer.precision_plugin)
|
343 |
+
flops_available = get_flops_available(trainer.strategy.root_device, dtype)
|
344 |
+
self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)
|
345 |
+
|
346 |
+
@trainer_rank_zero_only
|
347 |
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
348 |
+
if trainer.fit_loop._should_accumulate():
|
349 |
+
return
|
350 |
+
|
351 |
+
self.train_t0 = time.perf_counter()
|
352 |
+
|
353 |
+
@trainer_rank_zero_only
|
354 |
+
def on_train_batch_end(
|
355 |
+
self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
|
356 |
+
) -> None:
|
357 |
+
self.total_lengths += self.length_fn(batch)
|
358 |
+
if trainer.fit_loop._should_accumulate():
|
359 |
+
return
|
360 |
+
train_elapsed = time.perf_counter() - self.train_t0
|
361 |
+
assert self.speed_monitor is not None
|
362 |
+
iter_num = trainer.fit_loop.total_batch_idx
|
363 |
+
assert (measured_flops := pl_module.measured_flops) is not None
|
364 |
+
self.speed_monitor.on_train_batch_end(
|
365 |
+
(iter_num + 1) * self.batch_size,
|
366 |
+
train_elapsed,
|
367 |
+
# this assumes that device FLOPs are the same and that all devices have the same batch size
|
368 |
+
trainer.world_size,
|
369 |
+
flops_per_batch=measured_flops,
|
370 |
+
lengths=self.total_lengths,
|
371 |
+
)
|
372 |
+
|
373 |
+
@trainer_rank_zero_only
|
374 |
+
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
375 |
+
self.eval_t0 = time.perf_counter()
|
376 |
+
|
377 |
+
@trainer_rank_zero_only
|
378 |
+
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
379 |
+
eval_elapsed = time.perf_counter() - self.eval_t0
|
380 |
+
assert self.speed_monitor is not None
|
381 |
+
self.speed_monitor.eval_end(eval_elapsed)
|
382 |
+
|
383 |
+
|
384 |
+
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
|
385 |
+
flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
|
386 |
+
# this assumes that all samples have a fixed length equal to the block size
|
387 |
+
# which is most likely false during finetuning
|
388 |
+
flops_per_seq = flops_per_token * max_seq_length
|
389 |
+
attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
|
390 |
+
return flops_per_seq + attn_flops_per_seq
|
391 |
+
|
392 |
+
|
393 |
+
def estimate_flops(model: GPT) -> int:
|
394 |
+
"""Measures estimated FLOPs for MFU.
|
395 |
+
|
396 |
+
Refs:
|
397 |
+
* https://ar5iv.labs.arxiv.org/html/2205.05198#A1
|
398 |
+
* https://ar5iv.labs.arxiv.org/html/2204.02311#A2
|
399 |
+
"""
|
400 |
+
# using all parameters for this is a naive over estimation because not all model parameters actually contribute to
|
401 |
+
# this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
|
402 |
+
# (~10%) compared to the measured FLOPs, making those lower but more realistic.
|
403 |
+
# For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
|
404 |
+
n_trainable_params = num_parameters(model, requires_grad=True)
|
405 |
+
trainable_flops = flops_per_param(
|
406 |
+
model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
|
407 |
+
)
|
408 |
+
# forward + backward + gradients (assumes no gradient accumulation)
|
409 |
+
ops_per_step = 3 if model.training else 1
|
410 |
+
n_frozen_params = num_parameters(model, requires_grad=False)
|
411 |
+
frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
|
412 |
+
# forward + backward
|
413 |
+
frozen_ops_per_step = 2 if model.training else 1
|
414 |
+
return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
|
415 |
+
|
416 |
+
|
417 |
+
def measure_flops(model: GPT, x: torch.Tensor) -> int:
|
418 |
+
"""Measures real FLOPs for HFU"""
|
419 |
+
flop_counter = FlopCounterMode(model, display=False)
|
420 |
+
ctx = nullcontext() if model.training else torch.no_grad()
|
421 |
+
with ctx, flop_counter:
|
422 |
+
y = model(x)
|
423 |
+
if model.training:
|
424 |
+
y.sum().backward()
|
425 |
+
return flop_counter.get_total_flops()
|
lit_gpt/tokenizer.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class Tokenizer:
|
9 |
+
def __init__(self, checkpoint_dir: Path) -> None:
|
10 |
+
self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
|
11 |
+
self.bos_id = None
|
12 |
+
self.eos_id = None
|
13 |
+
|
14 |
+
# some checkpoints have both files, `.model` takes precedence
|
15 |
+
if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
|
16 |
+
from sentencepiece import SentencePieceProcessor
|
17 |
+
|
18 |
+
self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
|
19 |
+
self.backend = "sentencepiece"
|
20 |
+
self.bos_id = self.processor.bos_id()
|
21 |
+
self.eos_id = self.processor.eos_id()
|
22 |
+
|
23 |
+
elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
|
24 |
+
from tokenizers import Tokenizer as HFTokenizer
|
25 |
+
|
26 |
+
self.processor = HFTokenizer.from_file(str(vocabulary_path))
|
27 |
+
self.backend = "huggingface"
|
28 |
+
|
29 |
+
if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
|
30 |
+
with open(special_tokens_path) as fp:
|
31 |
+
config = json.load(fp)
|
32 |
+
bos_token = config.get("bos_token")
|
33 |
+
self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
|
34 |
+
eos_token = config.get("eos_token")
|
35 |
+
self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
|
36 |
+
if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
|
37 |
+
with open(special_tokens_path) as fp:
|
38 |
+
config = json.load(fp)
|
39 |
+
if self.bos_id is None:
|
40 |
+
self.bos_id = config.get("bos_token_id")
|
41 |
+
if self.eos_id is None:
|
42 |
+
self.eos_id = config.get("eos_token_id")
|
43 |
+
else:
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
@property
|
47 |
+
def vocab_size(self) -> int:
|
48 |
+
if self.backend == "huggingface":
|
49 |
+
return self.processor.get_vocab_size(with_added_tokens=False)
|
50 |
+
if self.backend == "sentencepiece":
|
51 |
+
return self.processor.vocab_size()
|
52 |
+
raise RuntimeError
|
53 |
+
|
54 |
+
def token_to_id(self, token: str) -> int:
|
55 |
+
if self.backend == "huggingface":
|
56 |
+
id_ = self.processor.token_to_id(token)
|
57 |
+
elif self.backend == "sentencepiece":
|
58 |
+
id_ = self.processor.piece_to_id(token)
|
59 |
+
else:
|
60 |
+
raise RuntimeError
|
61 |
+
if id_ is None:
|
62 |
+
raise ValueError(f"token {token!r} not found in the collection.")
|
63 |
+
return id_
|
64 |
+
|
65 |
+
def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
|
66 |
+
if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
|
67 |
+
return False
|
68 |
+
with open(tokenizer_config_path) as fp:
|
69 |
+
config = json.load(fp)
|
70 |
+
if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")):
|
71 |
+
return True
|
72 |
+
# for examples that also use the Llama tokenizer, but do not have or set add_bos_token to True.
|
73 |
+
# ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
|
74 |
+
return config.get("add_bos_token") is None and config.get("tokenizer_class") == "LlamaTokenizer"
|
75 |
+
|
76 |
+
def encode(
|
77 |
+
self,
|
78 |
+
string: str,
|
79 |
+
device: Optional[torch.device] = None,
|
80 |
+
bos: Optional[bool] = None,
|
81 |
+
eos: bool = False,
|
82 |
+
max_length: int = -1,
|
83 |
+
) -> torch.Tensor:
|
84 |
+
if self.backend == "huggingface":
|
85 |
+
tokens = self.processor.encode(string).ids
|
86 |
+
elif self.backend == "sentencepiece":
|
87 |
+
tokens = self.processor.encode(string)
|
88 |
+
else:
|
89 |
+
raise RuntimeError
|
90 |
+
if bos or (bos is None and self.use_bos):
|
91 |
+
bos_id = self.bos_id
|
92 |
+
if bos_id is None:
|
93 |
+
raise NotImplementedError("This tokenizer does not have a defined a bos token")
|
94 |
+
tokens = [bos_id] + tokens
|
95 |
+
if eos:
|
96 |
+
tokens = tokens + [self.eos_id]
|
97 |
+
if max_length > 0:
|
98 |
+
tokens = tokens[:max_length]
|
99 |
+
return torch.tensor(tokens, dtype=torch.int, device=device)
|
100 |
+
|
101 |
+
def decode(self, tensor: torch.Tensor) -> str:
|
102 |
+
tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
|
103 |
+
return self.processor.decode(tokens)
|
lit_gpt/utils.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions for training and inference."""
|
2 |
+
import math
|
3 |
+
import pickle
|
4 |
+
import sys
|
5 |
+
from contextlib import nullcontext
|
6 |
+
from io import BytesIO
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import ContextManager, Dict, List, Mapping, Optional, TypeVar, Union
|
9 |
+
|
10 |
+
import lightning as L
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.utils._device
|
14 |
+
from lightning.fabric.strategies import FSDPStrategy
|
15 |
+
from lightning.fabric.utilities.load import _lazy_load as lazy_load
|
16 |
+
from torch.serialization import normalize_storage_type
|
17 |
+
|
18 |
+
|
19 |
+
def find_multiple(n: int, k: int) -> int:
|
20 |
+
assert k > 0
|
21 |
+
if n % k == 0:
|
22 |
+
return n
|
23 |
+
return n + k - (n % k)
|
24 |
+
|
25 |
+
|
26 |
+
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
|
27 |
+
total = 0
|
28 |
+
for p in module.parameters():
|
29 |
+
if requires_grad is None or p.requires_grad == requires_grad:
|
30 |
+
if hasattr(p, "quant_state"):
|
31 |
+
# bitsandbytes 4bit layer support
|
32 |
+
total += math.prod(p.quant_state[1])
|
33 |
+
else:
|
34 |
+
total += p.numel()
|
35 |
+
return total
|
36 |
+
|
37 |
+
|
38 |
+
def gptq_quantization(enabled: bool = False) -> ContextManager:
|
39 |
+
if not enabled:
|
40 |
+
return nullcontext()
|
41 |
+
|
42 |
+
from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager
|
43 |
+
|
44 |
+
from quantize.gptq import ColBlockQuantizedLinear
|
45 |
+
|
46 |
+
class QuantizedLinear(ColBlockQuantizedLinear):
|
47 |
+
def __init__(self, *args, **kwargs):
|
48 |
+
super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
|
49 |
+
|
50 |
+
return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
|
51 |
+
|
52 |
+
|
53 |
+
def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
|
54 |
+
files = {
|
55 |
+
"lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
|
56 |
+
"lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
|
57 |
+
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
|
58 |
+
checkpoint_dir / "tokenizer.model"
|
59 |
+
).is_file(),
|
60 |
+
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
|
61 |
+
}
|
62 |
+
if checkpoint_dir.is_dir():
|
63 |
+
if all(files.values()):
|
64 |
+
# we're good
|
65 |
+
return
|
66 |
+
problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
|
67 |
+
else:
|
68 |
+
problem = " is not a checkpoint directory"
|
69 |
+
|
70 |
+
# list locally available checkpoints
|
71 |
+
available = list(Path("checkpoints").glob("*/*"))
|
72 |
+
if available:
|
73 |
+
options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
|
74 |
+
extra = f"\nYou have downloaded locally:{options}\n"
|
75 |
+
else:
|
76 |
+
extra = ""
|
77 |
+
|
78 |
+
error_message = (
|
79 |
+
f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
|
80 |
+
"\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
|
81 |
+
f"{extra}\nSee all download options by running:\n python scripts/download.py"
|
82 |
+
)
|
83 |
+
print(error_message, file=sys.stderr)
|
84 |
+
raise SystemExit(1)
|
85 |
+
|
86 |
+
|
87 |
+
class SavingProxyForStorage:
|
88 |
+
def __init__(self, obj, saver, protocol_version=5):
|
89 |
+
self.protocol_version = protocol_version
|
90 |
+
self.saver = saver
|
91 |
+
if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
|
92 |
+
raise TypeError(f"expected storage, not {type(obj)}")
|
93 |
+
|
94 |
+
# this logic is taken from PyTorch 2.0+ torch/serialization.py
|
95 |
+
if isinstance(obj, torch.storage.TypedStorage):
|
96 |
+
# PT upstream wants to deprecate this eventually...
|
97 |
+
storage = obj._untyped_storage
|
98 |
+
storage_type_str = obj._pickle_storage_type()
|
99 |
+
storage_type = getattr(torch, storage_type_str)
|
100 |
+
storage_numel = obj._size()
|
101 |
+
else:
|
102 |
+
storage = obj
|
103 |
+
storage_type = normalize_storage_type(type(obj))
|
104 |
+
storage_numel = storage.nbytes()
|
105 |
+
|
106 |
+
storage_key = saver._write_storage_and_return_key(storage)
|
107 |
+
location = torch.serialization.location_tag(storage)
|
108 |
+
|
109 |
+
self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
|
110 |
+
|
111 |
+
def __reduce_ex__(self, protocol_version):
|
112 |
+
assert False, "this should be handled with out of band"
|
113 |
+
|
114 |
+
|
115 |
+
class SavingProxyForTensor:
|
116 |
+
def __init__(self, tensor, saver, protocol_version=5):
|
117 |
+
self.protocol_version = protocol_version
|
118 |
+
self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
|
119 |
+
if reduce_args[0] == torch._utils._rebuild_tensor_v2:
|
120 |
+
# for Tensors with Python attributes
|
121 |
+
(a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
|
122 |
+
assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
|
123 |
+
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
|
124 |
+
self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
|
125 |
+
else:
|
126 |
+
(storage, *other_reduce_args) = reduce_args
|
127 |
+
assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
|
128 |
+
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
|
129 |
+
self.reduce_args = (storage_proxy, *other_reduce_args)
|
130 |
+
|
131 |
+
def __reduce_ex__(self, protocol_version):
|
132 |
+
if protocol_version != self.protocol_version:
|
133 |
+
raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
|
134 |
+
return self.reduce_ret_fn, self.reduce_args
|
135 |
+
|
136 |
+
|
137 |
+
class IncrementalPyTorchPickler(pickle.Pickler):
|
138 |
+
def __init__(self, saver, *args, **kwargs):
|
139 |
+
super().__init__(*args, **kwargs)
|
140 |
+
self.storage_dtypes = {}
|
141 |
+
self.saver = saver
|
142 |
+
self.id_map = {}
|
143 |
+
|
144 |
+
# this logic is taken from PyTorch 2.0+ torch/serialization.py
|
145 |
+
def persistent_id(self, obj):
|
146 |
+
# FIXME: the docs say that persistent_id should only return a string
|
147 |
+
# but torch store returns tuples. This works only in the binary protocol
|
148 |
+
# see
|
149 |
+
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
|
150 |
+
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
|
151 |
+
if isinstance(obj, SavingProxyForStorage):
|
152 |
+
return obj.storage_info
|
153 |
+
|
154 |
+
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
|
155 |
+
if isinstance(obj, torch.storage.TypedStorage):
|
156 |
+
# TODO: Once we decide to break serialization FC, this case
|
157 |
+
# can be deleted
|
158 |
+
storage = obj._untyped_storage
|
159 |
+
storage_dtype = obj.dtype
|
160 |
+
storage_type_str = obj._pickle_storage_type()
|
161 |
+
storage_type = getattr(torch, storage_type_str)
|
162 |
+
storage_numel = obj._size()
|
163 |
+
|
164 |
+
else:
|
165 |
+
storage = obj
|
166 |
+
storage_dtype = torch.uint8
|
167 |
+
storage_type = normalize_storage_type(type(obj))
|
168 |
+
storage_numel = storage.nbytes()
|
169 |
+
|
170 |
+
# If storage is allocated, ensure that any other saved storages
|
171 |
+
# pointing to the same data all have the same dtype. If storage is
|
172 |
+
# not allocated, don't perform this check
|
173 |
+
if storage.data_ptr() != 0:
|
174 |
+
if storage.data_ptr() in self.storage_dtypes:
|
175 |
+
if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
|
176 |
+
raise RuntimeError(
|
177 |
+
"Cannot save multiple tensors or storages that view the same data as different types"
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
self.storage_dtypes[storage.data_ptr()] = storage_dtype
|
181 |
+
|
182 |
+
storage_key = self.id_map.get(storage._cdata)
|
183 |
+
if storage_key is None:
|
184 |
+
storage_key = self.saver._write_storage_and_return_key(storage)
|
185 |
+
self.id_map[storage._cdata] = storage_key
|
186 |
+
location = torch.serialization.location_tag(storage)
|
187 |
+
|
188 |
+
return ("storage", storage_type, storage_key, location, storage_numel)
|
189 |
+
|
190 |
+
return None
|
191 |
+
|
192 |
+
|
193 |
+
class incremental_save:
|
194 |
+
def __init__(self, name):
|
195 |
+
self.name = name
|
196 |
+
self.zipfile = torch._C.PyTorchFileWriter(str(name))
|
197 |
+
self.has_saved = False
|
198 |
+
self.next_key = 0
|
199 |
+
|
200 |
+
def __enter__(self):
|
201 |
+
return self
|
202 |
+
|
203 |
+
def store_early(self, tensor):
|
204 |
+
if isinstance(tensor, torch.Tensor):
|
205 |
+
return SavingProxyForTensor(tensor, self)
|
206 |
+
raise TypeError(f"can only store tensors early, not {type(tensor)}")
|
207 |
+
|
208 |
+
def save(self, obj):
|
209 |
+
if self.has_saved:
|
210 |
+
raise RuntimeError("have already saved")
|
211 |
+
# Write the pickle data for `obj`
|
212 |
+
data_buf = BytesIO()
|
213 |
+
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
|
214 |
+
pickler.dump(obj)
|
215 |
+
data_value = data_buf.getvalue()
|
216 |
+
self.zipfile.write_record("data.pkl", data_value, len(data_value))
|
217 |
+
self.has_saved = True
|
218 |
+
|
219 |
+
def _write_storage_and_return_key(self, storage):
|
220 |
+
if self.has_saved:
|
221 |
+
raise RuntimeError("have already saved")
|
222 |
+
key = self.next_key
|
223 |
+
self.next_key += 1
|
224 |
+
name = f"data/{key}"
|
225 |
+
if storage.device.type != "cpu":
|
226 |
+
storage = storage.cpu()
|
227 |
+
num_bytes = storage.nbytes()
|
228 |
+
self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
|
229 |
+
return key
|
230 |
+
|
231 |
+
def __exit__(self, type, value, traceback):
|
232 |
+
self.zipfile.write_end_of_file()
|
233 |
+
|
234 |
+
|
235 |
+
T = TypeVar("T")
|
236 |
+
|
237 |
+
|
238 |
+
def chunked_cross_entropy(
|
239 |
+
logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
|
240 |
+
) -> torch.Tensor:
|
241 |
+
# with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
|
242 |
+
# the memory usage in fine-tuning settings with low number of parameters.
|
243 |
+
# as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
|
244 |
+
# the memory spike's magnitude
|
245 |
+
|
246 |
+
# lm_head was chunked (we are fine-tuning)
|
247 |
+
if isinstance(logits, list):
|
248 |
+
# don't want to chunk cross entropy
|
249 |
+
if chunk_size == 0:
|
250 |
+
logits = torch.cat(logits, dim=1)
|
251 |
+
logits = logits.reshape(-1, logits.size(-1))
|
252 |
+
targets = targets.reshape(-1)
|
253 |
+
return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
|
254 |
+
|
255 |
+
# chunk cross entropy
|
256 |
+
logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
|
257 |
+
target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
|
258 |
+
loss_chunks = [
|
259 |
+
torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
|
260 |
+
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
|
261 |
+
]
|
262 |
+
return torch.cat(loss_chunks).mean()
|
263 |
+
|
264 |
+
# no chunking at all
|
265 |
+
logits = logits.reshape(-1, logits.size(-1))
|
266 |
+
targets = targets.reshape(-1)
|
267 |
+
if chunk_size == 0:
|
268 |
+
return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
|
269 |
+
|
270 |
+
# lm_head wasn't chunked, chunk cross entropy
|
271 |
+
logit_chunks = logits.split(chunk_size)
|
272 |
+
target_chunks = targets.split(chunk_size)
|
273 |
+
loss_chunks = [
|
274 |
+
torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
|
275 |
+
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
|
276 |
+
]
|
277 |
+
return torch.cat(loss_chunks).mean()
|
278 |
+
|
279 |
+
|
280 |
+
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
|
281 |
+
for checkpoint_name, attribute_name in mapping.items():
|
282 |
+
full_checkpoint_name = prefix + checkpoint_name
|
283 |
+
if full_checkpoint_name in state_dict:
|
284 |
+
full_attribute_name = prefix + attribute_name
|
285 |
+
state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
|
286 |
+
return state_dict
|
287 |
+
|
288 |
+
|
289 |
+
def get_default_supported_precision(training: bool) -> str:
|
290 |
+
"""Return default precision that is supported by the hardware: either `bf16` or `16`.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
training: `-mixed` or `-true` version of the precision to use
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
default precision that is suitable for the task and is supported by the hardware
|
297 |
+
"""
|
298 |
+
from lightning.fabric.accelerators import MPSAccelerator
|
299 |
+
|
300 |
+
if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
|
301 |
+
return "16-mixed" if training else "16-true"
|
302 |
+
return "bf16-mixed" if training else "bf16-true"
|
303 |
+
|
304 |
+
|
305 |
+
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
|
306 |
+
if isinstance(fabric.strategy, FSDPStrategy):
|
307 |
+
fabric.load_raw(checkpoint_path, model, strict=strict)
|
308 |
+
else:
|
309 |
+
state_dict = lazy_load(checkpoint_path)
|
310 |
+
state_dict = state_dict.get("model", state_dict)
|
311 |
+
model.load_state_dict(state_dict, strict=strict)
|
lit_llama/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
|
2 |
+
from lit_llama.tokenizer import Tokenizer
|
lit_llama/adapter.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implementation of the paper:
|
2 |
+
|
3 |
+
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
|
4 |
+
https://arxiv.org/abs/2303.16199
|
5 |
+
"""
|
6 |
+
# mypy: ignore-errors
|
7 |
+
import math
|
8 |
+
from dataclasses import dataclass
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
import lit_llama.model as llama
|
14 |
+
from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class LLaMAConfig(llama.LLaMAConfig):
|
19 |
+
adapter_prompt_length: int = 10
|
20 |
+
adapter_start_layer: int = 2
|
21 |
+
|
22 |
+
|
23 |
+
class CausalSelfAttention(nn.Module):
|
24 |
+
"""A modification of `lit_llama.model.CausalSelfAttention` that adds the attention
|
25 |
+
over the adaption prompt."""
|
26 |
+
|
27 |
+
def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
|
28 |
+
super().__init__()
|
29 |
+
assert config.n_embd % config.n_head == 0
|
30 |
+
|
31 |
+
# key, query, value projections for all heads, but in a batch
|
32 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
|
33 |
+
# output projection
|
34 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
35 |
+
|
36 |
+
if block_idx >= config.adapter_start_layer:
|
37 |
+
# adapter embedding layer
|
38 |
+
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
|
39 |
+
# gate for adaption
|
40 |
+
self.gating_factor = torch.nn.Parameter(torch.zeros(1))
|
41 |
+
|
42 |
+
self.n_head = config.n_head
|
43 |
+
self.n_embd = config.n_embd
|
44 |
+
self.block_size = config.block_size
|
45 |
+
self.block_idx = block_idx
|
46 |
+
self.adapter_prompt_length = config.adapter_prompt_length
|
47 |
+
self.adapter_start_layer = config.adapter_start_layer
|
48 |
+
self.rope_cache = None
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
51 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
52 |
+
|
53 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
54 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
55 |
+
|
56 |
+
head_size = C // self.n_head
|
57 |
+
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
|
58 |
+
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
|
59 |
+
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
|
60 |
+
|
61 |
+
if self.rope_cache is None:
|
62 |
+
# cache for future forward calls
|
63 |
+
self.rope_cache = build_rope_cache(
|
64 |
+
seq_len=self.block_size,
|
65 |
+
n_elem=self.n_embd // self.n_head,
|
66 |
+
dtype=x.dtype,
|
67 |
+
device=x.device,
|
68 |
+
)
|
69 |
+
|
70 |
+
q = apply_rope(q, self.rope_cache)
|
71 |
+
k = apply_rope(k, self.rope_cache)
|
72 |
+
|
73 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
74 |
+
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
75 |
+
# att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
76 |
+
# att = F.softmax(att, dim=-1)
|
77 |
+
# y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
78 |
+
|
79 |
+
# efficient attention using Flash Attention CUDA kernels
|
80 |
+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
|
81 |
+
|
82 |
+
if self.block_idx >= self.adapter_start_layer:
|
83 |
+
prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
|
84 |
+
|
85 |
+
aT = prefix.size(1)
|
86 |
+
_, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)
|
87 |
+
ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
|
88 |
+
av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
|
89 |
+
|
90 |
+
amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
|
91 |
+
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
|
92 |
+
y = y + self.gating_factor * ay
|
93 |
+
|
94 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
95 |
+
|
96 |
+
# output projection
|
97 |
+
y = self.c_proj(y)
|
98 |
+
|
99 |
+
return y
|
100 |
+
|
101 |
+
|
102 |
+
class Block(nn.Module):
|
103 |
+
"""The implementation is identical to `lit_llama.model.Block` with the exception that
|
104 |
+
we replace the attention layer where adaption is implemented."""
|
105 |
+
|
106 |
+
def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
|
107 |
+
super().__init__()
|
108 |
+
self.rms_1 = RMSNorm(config.n_embd)
|
109 |
+
self.attn = CausalSelfAttention(config, block_idx)
|
110 |
+
self.rms_2 = RMSNorm(config.n_embd)
|
111 |
+
self.mlp = MLP(config)
|
112 |
+
|
113 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
114 |
+
x = x + self.attn(self.rms_1(x))
|
115 |
+
x = x + self.mlp(self.rms_2(x))
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class LLaMA(llama.LLaMA):
|
120 |
+
"""The implementation is identical to `lit_llama.model.LLaMA` with the exception that
|
121 |
+
the `Block` saves the layer index and passes it down to the attention layer."""
|
122 |
+
|
123 |
+
def __init__(self, config: LLaMAConfig) -> None:
|
124 |
+
nn.Module.__init__(self)
|
125 |
+
assert config.vocab_size is not None
|
126 |
+
assert config.block_size is not None
|
127 |
+
self.config = config
|
128 |
+
|
129 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
130 |
+
self.transformer = nn.ModuleDict(
|
131 |
+
dict(
|
132 |
+
wte=nn.Embedding(config.vocab_size, config.n_embd),
|
133 |
+
h=nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
|
134 |
+
ln_f=RMSNorm(config.n_embd),
|
135 |
+
)
|
136 |
+
)
|
137 |
+
|
138 |
+
@classmethod
|
139 |
+
def from_name(cls, name: str):
|
140 |
+
return cls(LLaMAConfig.from_name(name))
|
141 |
+
|
142 |
+
|
143 |
+
def mark_only_adapter_as_trainable(model: LLaMA) -> None:
|
144 |
+
"""Sets `requires_grad=False` for all non-adapter weights."""
|
145 |
+
for name, param in model.named_parameters():
|
146 |
+
param.requires_grad = "adapter_wte" in name or "gating_factor" in name
|
147 |
+
|
148 |
+
|
149 |
+
def adapter_state_from_state_dict(state_dict: dict) -> dict:
|
150 |
+
"""Returns the model state dict with only the adapter weights for saving."""
|
151 |
+
return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name}
|
lit_llama/indexed_dataset.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
|
2 |
+
|
3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the MIT license found in the
|
6 |
+
# LICENSE file in the root directory of the FairSeq source tree.
|
7 |
+
|
8 |
+
# copied from fairseq/fairseq/data/indexed_dataset.py
|
9 |
+
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
10 |
+
# other slight modifications to remove fairseq dependencies
|
11 |
+
# Added document index to index file and made it accessible.
|
12 |
+
# An empty sentence no longer separates documents.
|
13 |
+
|
14 |
+
from functools import lru_cache
|
15 |
+
import os
|
16 |
+
import shutil
|
17 |
+
import struct
|
18 |
+
from itertools import accumulate
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
def __best_fitting_dtype(vocab_size=None):
|
25 |
+
if vocab_size is not None and vocab_size < 65500:
|
26 |
+
return np.uint16
|
27 |
+
else:
|
28 |
+
return np.int32
|
29 |
+
|
30 |
+
|
31 |
+
def get_available_dataset_impl():
|
32 |
+
return ['lazy', 'cached', 'mmap']
|
33 |
+
|
34 |
+
|
35 |
+
def infer_dataset_impl(path):
|
36 |
+
if IndexedDataset.exists(path):
|
37 |
+
with open(index_file_path(path), 'rb') as f:
|
38 |
+
magic = f.read(8)
|
39 |
+
if magic == IndexedDataset._HDR_MAGIC:
|
40 |
+
return 'cached'
|
41 |
+
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
42 |
+
return 'mmap'
|
43 |
+
else:
|
44 |
+
return None
|
45 |
+
else:
|
46 |
+
print(f"Dataset does not exist: {path}")
|
47 |
+
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
|
48 |
+
return None
|
49 |
+
|
50 |
+
|
51 |
+
def make_builder(out_file, impl, vocab_size=None):
|
52 |
+
if impl == 'mmap':
|
53 |
+
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
|
54 |
+
else:
|
55 |
+
return IndexedDatasetBuilder(out_file)
|
56 |
+
|
57 |
+
|
58 |
+
def make_dataset(path, impl, skip_warmup=False):
|
59 |
+
if not IndexedDataset.exists(path):
|
60 |
+
print(f"Dataset does not exist: {path}")
|
61 |
+
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
|
62 |
+
return None
|
63 |
+
if impl == 'infer':
|
64 |
+
impl = infer_dataset_impl(path)
|
65 |
+
if impl == 'lazy' and IndexedDataset.exists(path):
|
66 |
+
return IndexedDataset(path)
|
67 |
+
elif impl == 'cached' and IndexedDataset.exists(path):
|
68 |
+
return IndexedCachedDataset(path)
|
69 |
+
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
|
70 |
+
return MMapIndexedDataset(path, skip_warmup)
|
71 |
+
print(f"Unknown dataset implementation: {impl}")
|
72 |
+
return None
|
73 |
+
|
74 |
+
|
75 |
+
def dataset_exists(path, impl):
|
76 |
+
if impl == 'mmap':
|
77 |
+
return MMapIndexedDataset.exists(path)
|
78 |
+
else:
|
79 |
+
return IndexedDataset.exists(path)
|
80 |
+
|
81 |
+
|
82 |
+
def read_longs(f, n):
|
83 |
+
a = np.empty(n, dtype=np.int64)
|
84 |
+
f.readinto(a)
|
85 |
+
return a
|
86 |
+
|
87 |
+
|
88 |
+
def write_longs(f, a):
|
89 |
+
f.write(np.array(a, dtype=np.int64))
|
90 |
+
|
91 |
+
|
92 |
+
dtypes = {
|
93 |
+
1: np.uint8,
|
94 |
+
2: np.int8,
|
95 |
+
3: np.int16,
|
96 |
+
4: np.int32,
|
97 |
+
5: np.int64,
|
98 |
+
6: np.float32,
|
99 |
+
7: np.float64,
|
100 |
+
8: np.uint16
|
101 |
+
}
|
102 |
+
|
103 |
+
|
104 |
+
def code(dtype):
|
105 |
+
for k in dtypes.keys():
|
106 |
+
if dtypes[k] == dtype:
|
107 |
+
return k
|
108 |
+
raise ValueError(dtype)
|
109 |
+
|
110 |
+
|
111 |
+
def index_file_path(prefix_path):
|
112 |
+
return prefix_path + '.idx'
|
113 |
+
|
114 |
+
|
115 |
+
def data_file_path(prefix_path):
|
116 |
+
return prefix_path + '.bin'
|
117 |
+
|
118 |
+
|
119 |
+
def create_doc_idx(sizes):
|
120 |
+
doc_idx = [0]
|
121 |
+
for i, s in enumerate(sizes):
|
122 |
+
if s == 0:
|
123 |
+
doc_idx.append(i + 1)
|
124 |
+
return doc_idx
|
125 |
+
|
126 |
+
|
127 |
+
class IndexedDataset(torch.utils.data.Dataset):
|
128 |
+
"""Loader for IndexedDataset"""
|
129 |
+
_HDR_MAGIC = b'TNTIDX\x00\x00'
|
130 |
+
|
131 |
+
def __init__(self, path):
|
132 |
+
super().__init__()
|
133 |
+
self.path = path
|
134 |
+
self.data_file = None
|
135 |
+
self.read_index(path)
|
136 |
+
|
137 |
+
def read_index(self, path):
|
138 |
+
with open(index_file_path(path), 'rb') as f:
|
139 |
+
magic = f.read(8)
|
140 |
+
assert magic == self._HDR_MAGIC, (
|
141 |
+
'Index file doesn\'t match expected format. '
|
142 |
+
'Make sure that --dataset-impl is configured properly.'
|
143 |
+
)
|
144 |
+
version = f.read(8)
|
145 |
+
assert struct.unpack('<Q', version) == (1,)
|
146 |
+
code, self.element_size = struct.unpack('<QQ', f.read(16))
|
147 |
+
self.dtype = dtypes[code]
|
148 |
+
self._len, self.s = struct.unpack('<QQ', f.read(16))
|
149 |
+
self.doc_count = struct.unpack('<Q', f.read(8))
|
150 |
+
self.dim_offsets = read_longs(f, self._len + 1)
|
151 |
+
self.data_offsets = read_longs(f, self._len + 1)
|
152 |
+
self.sizes = read_longs(f, self.s)
|
153 |
+
self.doc_idx = read_longs(f, self.doc_count)
|
154 |
+
|
155 |
+
def read_data(self, path):
|
156 |
+
self.data_file = open(data_file_path(path), 'rb', buffering=0)
|
157 |
+
|
158 |
+
def check_index(self, i):
|
159 |
+
if i < 0 or i >= self._len:
|
160 |
+
raise IndexError('index out of range')
|
161 |
+
|
162 |
+
def __del__(self):
|
163 |
+
if self.data_file:
|
164 |
+
self.data_file.close()
|
165 |
+
|
166 |
+
# @lru_cache(maxsize=8)
|
167 |
+
def __getitem__(self, idx):
|
168 |
+
if not self.data_file:
|
169 |
+
self.read_data(self.path)
|
170 |
+
if isinstance(idx, int):
|
171 |
+
i = idx
|
172 |
+
self.check_index(i)
|
173 |
+
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
174 |
+
a = np.empty(tensor_size, dtype=self.dtype)
|
175 |
+
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
176 |
+
self.data_file.readinto(a)
|
177 |
+
return a
|
178 |
+
elif isinstance(idx, slice):
|
179 |
+
start, stop, step = idx.indices(len(self))
|
180 |
+
if step != 1:
|
181 |
+
raise ValueError("Slices into indexed_dataset must be contiguous")
|
182 |
+
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
|
183 |
+
size = sum(sizes)
|
184 |
+
a = np.empty(size, dtype=self.dtype)
|
185 |
+
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
186 |
+
self.data_file.readinto(a)
|
187 |
+
offsets = list(accumulate(sizes))
|
188 |
+
sents = np.split(a, offsets[:-1])
|
189 |
+
return sents
|
190 |
+
|
191 |
+
def __len__(self):
|
192 |
+
return self._len
|
193 |
+
|
194 |
+
def num_tokens(self, index):
|
195 |
+
return self.sizes[index]
|
196 |
+
|
197 |
+
def size(self, index):
|
198 |
+
return self.sizes[index]
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def exists(path):
|
202 |
+
return (
|
203 |
+
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
|
204 |
+
)
|
205 |
+
|
206 |
+
@property
|
207 |
+
def supports_prefetch(self):
|
208 |
+
return False # avoid prefetching to save memory
|
209 |
+
|
210 |
+
|
211 |
+
class IndexedCachedDataset(IndexedDataset):
|
212 |
+
|
213 |
+
def __init__(self, path):
|
214 |
+
super().__init__(path)
|
215 |
+
self.cache = None
|
216 |
+
self.cache_index = {}
|
217 |
+
|
218 |
+
@property
|
219 |
+
def supports_prefetch(self):
|
220 |
+
return True
|
221 |
+
|
222 |
+
def prefetch(self, indices):
|
223 |
+
if all(i in self.cache_index for i in indices):
|
224 |
+
return
|
225 |
+
if not self.data_file:
|
226 |
+
self.read_data(self.path)
|
227 |
+
indices = sorted(set(indices))
|
228 |
+
total_size = 0
|
229 |
+
for i in indices:
|
230 |
+
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
|
231 |
+
self.cache = np.empty(total_size, dtype=self.dtype)
|
232 |
+
ptx = 0
|
233 |
+
self.cache_index.clear()
|
234 |
+
for i in indices:
|
235 |
+
self.cache_index[i] = ptx
|
236 |
+
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
237 |
+
a = self.cache[ptx: ptx + size]
|
238 |
+
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
239 |
+
self.data_file.readinto(a)
|
240 |
+
ptx += size
|
241 |
+
if self.data_file:
|
242 |
+
# close and delete data file after prefetch so we can pickle
|
243 |
+
self.data_file.close()
|
244 |
+
self.data_file = None
|
245 |
+
|
246 |
+
# @lru_cache(maxsize=8)
|
247 |
+
def __getitem__(self, idx):
|
248 |
+
if isinstance(idx, int):
|
249 |
+
i = idx
|
250 |
+
self.check_index(i)
|
251 |
+
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
252 |
+
a = np.empty(tensor_size, dtype=self.dtype)
|
253 |
+
ptx = self.cache_index[i]
|
254 |
+
np.copyto(a, self.cache[ptx: ptx + a.size])
|
255 |
+
return a
|
256 |
+
elif isinstance(idx, slice):
|
257 |
+
# Hack just to make this work, can optimizer later if necessary
|
258 |
+
sents = []
|
259 |
+
for i in range(*idx.indices(len(self))):
|
260 |
+
sents.append(self[i])
|
261 |
+
return sents
|
262 |
+
|
263 |
+
|
264 |
+
class IndexedDatasetBuilder(object):
|
265 |
+
element_sizes = {
|
266 |
+
np.uint8: 1,
|
267 |
+
np.int8: 1,
|
268 |
+
np.int16: 2,
|
269 |
+
np.int32: 4,
|
270 |
+
np.int64: 8,
|
271 |
+
np.float32: 4,
|
272 |
+
np.float64: 8
|
273 |
+
}
|
274 |
+
|
275 |
+
def __init__(self, out_file, dtype=np.int32):
|
276 |
+
self.out_file = open(out_file, 'wb')
|
277 |
+
self.dtype = dtype
|
278 |
+
self.data_offsets = [0]
|
279 |
+
self.dim_offsets = [0]
|
280 |
+
self.sizes = []
|
281 |
+
self.element_size = self.element_sizes[self.dtype]
|
282 |
+
self.doc_idx = [0]
|
283 |
+
|
284 |
+
def add_item(self, tensor):
|
285 |
+
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
|
286 |
+
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
|
287 |
+
for s in tensor.size():
|
288 |
+
self.sizes.append(s)
|
289 |
+
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
|
290 |
+
|
291 |
+
def end_document(self):
|
292 |
+
self.doc_idx.append(len(self.sizes))
|
293 |
+
|
294 |
+
def merge_file_(self, another_file):
|
295 |
+
index = IndexedDataset(another_file)
|
296 |
+
assert index.dtype == self.dtype
|
297 |
+
|
298 |
+
doc_offset = len(self.sizes)
|
299 |
+
|
300 |
+
begin = self.data_offsets[-1]
|
301 |
+
for data_offset in index.data_offsets[1:]:
|
302 |
+
self.data_offsets.append(begin + data_offset)
|
303 |
+
self.sizes.extend(index.sizes)
|
304 |
+
|
305 |
+
begin = self.dim_offsets[-1]
|
306 |
+
for dim_offset in index.dim_offsets[1:]:
|
307 |
+
self.dim_offsets.append(begin + dim_offset)
|
308 |
+
|
309 |
+
self.doc_idx.extend((doc_offset + index.doc_idx)[1:])
|
310 |
+
|
311 |
+
with open(data_file_path(another_file), 'rb') as f:
|
312 |
+
while True:
|
313 |
+
data = f.read(1024)
|
314 |
+
if data:
|
315 |
+
self.out_file.write(data)
|
316 |
+
else:
|
317 |
+
break
|
318 |
+
|
319 |
+
def finalize(self, index_file):
|
320 |
+
self.out_file.close()
|
321 |
+
index = open(index_file, 'wb')
|
322 |
+
index.write(b'TNTIDX\x00\x00')
|
323 |
+
index.write(struct.pack('<Q', 1))
|
324 |
+
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
|
325 |
+
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
|
326 |
+
index.write(struct.pack('<Q', len(self.doc_idx)))
|
327 |
+
write_longs(index, self.dim_offsets)
|
328 |
+
write_longs(index, self.data_offsets)
|
329 |
+
write_longs(index, self.sizes)
|
330 |
+
write_longs(index, self.doc_idx)
|
331 |
+
index.close()
|
332 |
+
|
333 |
+
|
334 |
+
def _warmup_mmap_file(path):
|
335 |
+
with open(path, 'rb') as stream:
|
336 |
+
while stream.read(100 * 1024 * 1024):
|
337 |
+
pass
|
338 |
+
|
339 |
+
|
340 |
+
class MMapIndexedDataset(torch.utils.data.Dataset):
|
341 |
+
class Index(object):
|
342 |
+
_HDR_MAGIC = b'MMIDIDX\x00\x00'
|
343 |
+
|
344 |
+
@classmethod
|
345 |
+
def writer(cls, path, dtype):
|
346 |
+
class _Writer(object):
|
347 |
+
def __enter__(self):
|
348 |
+
self._file = open(path, 'wb')
|
349 |
+
|
350 |
+
self._file.write(cls._HDR_MAGIC)
|
351 |
+
self._file.write(struct.pack('<Q', 1))
|
352 |
+
self._file.write(struct.pack('<B', code(dtype)))
|
353 |
+
|
354 |
+
return self
|
355 |
+
|
356 |
+
@staticmethod
|
357 |
+
def _get_pointers(sizes):
|
358 |
+
dtype_size = dtype().itemsize
|
359 |
+
address = 0
|
360 |
+
pointers = []
|
361 |
+
|
362 |
+
for size in sizes:
|
363 |
+
pointers.append(address)
|
364 |
+
address += size * dtype_size
|
365 |
+
|
366 |
+
return pointers
|
367 |
+
|
368 |
+
def write(self, sizes, doc_idx):
|
369 |
+
pointers = self._get_pointers(sizes)
|
370 |
+
|
371 |
+
self._file.write(struct.pack('<Q', len(sizes)))
|
372 |
+
self._file.write(struct.pack('<Q', len(doc_idx)))
|
373 |
+
|
374 |
+
sizes = np.array(sizes, dtype=np.int32)
|
375 |
+
self._file.write(sizes.tobytes(order='C'))
|
376 |
+
del sizes
|
377 |
+
|
378 |
+
pointers = np.array(pointers, dtype=np.int64)
|
379 |
+
self._file.write(pointers.tobytes(order='C'))
|
380 |
+
del pointers
|
381 |
+
|
382 |
+
doc_idx = np.array(doc_idx, dtype=np.int64)
|
383 |
+
self._file.write(doc_idx.tobytes(order='C'))
|
384 |
+
|
385 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
386 |
+
self._file.close()
|
387 |
+
|
388 |
+
return _Writer()
|
389 |
+
|
390 |
+
def __init__(self, path, skip_warmup=False):
|
391 |
+
with open(path, 'rb') as stream:
|
392 |
+
magic_test = stream.read(9)
|
393 |
+
assert self._HDR_MAGIC == magic_test, (
|
394 |
+
'Index file doesn\'t match expected format. '
|
395 |
+
'Make sure that --dataset-impl is configured properly.'
|
396 |
+
)
|
397 |
+
version = struct.unpack('<Q', stream.read(8))
|
398 |
+
assert (1,) == version
|
399 |
+
|
400 |
+
dtype_code, = struct.unpack('<B', stream.read(1))
|
401 |
+
self._dtype = dtypes[dtype_code]
|
402 |
+
self._dtype_size = self._dtype().itemsize
|
403 |
+
|
404 |
+
self._len = struct.unpack('<Q', stream.read(8))[0]
|
405 |
+
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
|
406 |
+
offset = stream.tell()
|
407 |
+
|
408 |
+
if not skip_warmup:
|
409 |
+
print(" warming up index mmap file...")
|
410 |
+
_warmup_mmap_file(path)
|
411 |
+
|
412 |
+
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
|
413 |
+
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
414 |
+
print(" reading sizes...")
|
415 |
+
self._sizes = np.frombuffer(
|
416 |
+
self._bin_buffer,
|
417 |
+
dtype=np.int32,
|
418 |
+
count=self._len,
|
419 |
+
offset=offset)
|
420 |
+
print(" reading pointers...")
|
421 |
+
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
|
422 |
+
offset=offset + self._sizes.nbytes)
|
423 |
+
print(" reading document index...")
|
424 |
+
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
|
425 |
+
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
|
426 |
+
|
427 |
+
def __del__(self):
|
428 |
+
self._bin_buffer_mmap._mmap.close()
|
429 |
+
del self._bin_buffer_mmap
|
430 |
+
|
431 |
+
@property
|
432 |
+
def dtype(self):
|
433 |
+
return self._dtype
|
434 |
+
|
435 |
+
@property
|
436 |
+
def sizes(self):
|
437 |
+
return self._sizes
|
438 |
+
|
439 |
+
@property
|
440 |
+
def doc_idx(self):
|
441 |
+
return self._doc_idx
|
442 |
+
|
443 |
+
@lru_cache(maxsize=8)
|
444 |
+
def __getitem__(self, i):
|
445 |
+
return self._pointers[i], self._sizes[i]
|
446 |
+
|
447 |
+
def __len__(self):
|
448 |
+
return self._len
|
449 |
+
|
450 |
+
def __init__(self, path, skip_warmup=False):
|
451 |
+
super().__init__()
|
452 |
+
|
453 |
+
self._path = None
|
454 |
+
self._index = None
|
455 |
+
self._bin_buffer = None
|
456 |
+
|
457 |
+
self._do_init(path, skip_warmup)
|
458 |
+
|
459 |
+
def __getstate__(self):
|
460 |
+
return self._path
|
461 |
+
|
462 |
+
def __setstate__(self, state):
|
463 |
+
self._do_init(state, skip_warmup=True)
|
464 |
+
|
465 |
+
def _do_init(self, path, skip_warmup):
|
466 |
+
self._path = path
|
467 |
+
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
468 |
+
|
469 |
+
if not skip_warmup:
|
470 |
+
print(" warming up data mmap file...")
|
471 |
+
_warmup_mmap_file(data_file_path(self._path))
|
472 |
+
print(" creating numpy buffer of mmap...")
|
473 |
+
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
|
474 |
+
print(" creating memory view of numpy buffer...")
|
475 |
+
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
476 |
+
|
477 |
+
def __del__(self):
|
478 |
+
self._bin_buffer_mmap._mmap.close()
|
479 |
+
del self._bin_buffer_mmap
|
480 |
+
del self._index
|
481 |
+
|
482 |
+
def __len__(self):
|
483 |
+
return len(self._index)
|
484 |
+
|
485 |
+
# @lru_cache(maxsize=8)
|
486 |
+
def __getitem__(self, idx):
|
487 |
+
if isinstance(idx, (int, np.integer)):
|
488 |
+
ptr, size = self._index[idx]
|
489 |
+
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
490 |
+
count=size, offset=ptr)
|
491 |
+
return np_array
|
492 |
+
elif isinstance(idx, slice):
|
493 |
+
start, stop, step = idx.indices(len(self))
|
494 |
+
if step != 1:
|
495 |
+
raise ValueError("Slices into indexed_dataset must be contiguous")
|
496 |
+
ptr = self._index._pointers[start]
|
497 |
+
sizes = self._index._sizes[idx]
|
498 |
+
offsets = list(accumulate(sizes))
|
499 |
+
total_size = sum(sizes)
|
500 |
+
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
501 |
+
count=total_size, offset=ptr)
|
502 |
+
sents = np.split(np_array, offsets[:-1])
|
503 |
+
return sents
|
504 |
+
else:
|
505 |
+
raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
|
506 |
+
|
507 |
+
def get(self, idx, offset=0, length=None):
|
508 |
+
""" Retrieves a single item from the dataset with the option to only
|
509 |
+
return a portion of the item.
|
510 |
+
|
511 |
+
get(idx) is the same as [idx] but get() does not support slicing.
|
512 |
+
"""
|
513 |
+
ptr, size = self._index[idx]
|
514 |
+
if length is None:
|
515 |
+
length = size - offset
|
516 |
+
ptr += offset * np.dtype(self._index.dtype).itemsize
|
517 |
+
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
|
518 |
+
count=length, offset=ptr)
|
519 |
+
return np_array
|
520 |
+
|
521 |
+
@property
|
522 |
+
def sizes(self):
|
523 |
+
return self._index.sizes
|
524 |
+
|
525 |
+
@property
|
526 |
+
def doc_idx(self):
|
527 |
+
return self._index.doc_idx
|
528 |
+
|
529 |
+
def get_doc_idx(self):
|
530 |
+
return self._index._doc_idx
|
531 |
+
|
532 |
+
def set_doc_idx(self, doc_idx_):
|
533 |
+
self._index._doc_idx = doc_idx_
|
534 |
+
|
535 |
+
@property
|
536 |
+
def supports_prefetch(self):
|
537 |
+
return False
|
538 |
+
|
539 |
+
@staticmethod
|
540 |
+
def exists(path):
|
541 |
+
return (
|
542 |
+
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
|
543 |
+
)
|
544 |
+
|
545 |
+
|
546 |
+
class MMapIndexedDatasetBuilder(object):
|
547 |
+
def __init__(self, out_file, dtype=np.int64):
|
548 |
+
self._data_file = open(out_file, 'wb')
|
549 |
+
self._dtype = dtype
|
550 |
+
self._sizes = []
|
551 |
+
self._doc_idx = [0]
|
552 |
+
|
553 |
+
@property
|
554 |
+
def dtype(self):
|
555 |
+
return self._dtype
|
556 |
+
|
557 |
+
def add_item(self, np_array):
|
558 |
+
# np_array = np.array(tensor.numpy(), dtype=self._dtype)
|
559 |
+
self._data_file.write(np_array.tobytes(order='C'))
|
560 |
+
self._sizes.append(np_array.size)
|
561 |
+
|
562 |
+
def add_doc(self, np_array, sizes):
|
563 |
+
# np_array = np.array(tensor, dtype=self._dtype)
|
564 |
+
self._data_file.write(np_array.tobytes(order='C'))
|
565 |
+
self._sizes.extend(sizes)
|
566 |
+
self._doc_idx.append(len(self._sizes))
|
567 |
+
|
568 |
+
def end_document(self):
|
569 |
+
self._doc_idx.append(len(self._sizes))
|
570 |
+
|
571 |
+
def merge_file_(self, another_file):
|
572 |
+
# Concatenate index
|
573 |
+
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
574 |
+
assert index.dtype == self._dtype
|
575 |
+
|
576 |
+
offset = len(self._sizes)
|
577 |
+
self._sizes.extend(index.sizes)
|
578 |
+
self._doc_idx.extend((offset + index.doc_idx)[1:])
|
579 |
+
|
580 |
+
# Concatenate data
|
581 |
+
with open(data_file_path(another_file), 'rb') as f:
|
582 |
+
shutil.copyfileobj(f, self._data_file)
|
583 |
+
|
584 |
+
def finalize(self, index_file):
|
585 |
+
self._data_file.close()
|
586 |
+
|
587 |
+
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
588 |
+
index.write(self._sizes, self._doc_idx)
|
lit_llama/lora.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Derived from https://github.com/microsoft/LoRA
|
2 |
+
# ------------------------------------------------------------------------------------------
|
3 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
4 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
5 |
+
# ------------------------------------------------------------------------------------------
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import math
|
11 |
+
from typing import Dict, List
|
12 |
+
|
13 |
+
import lit_llama.model as llama
|
14 |
+
|
15 |
+
from contextlib import contextmanager
|
16 |
+
from dataclasses import dataclass
|
17 |
+
|
18 |
+
class LoRALayer():
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
r: int,
|
22 |
+
lora_alpha: int,
|
23 |
+
lora_dropout: float,
|
24 |
+
merge_weights: bool,
|
25 |
+
):
|
26 |
+
self.r = r
|
27 |
+
self.lora_alpha = lora_alpha
|
28 |
+
# Optional dropout
|
29 |
+
if lora_dropout > 0.:
|
30 |
+
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
31 |
+
else:
|
32 |
+
self.lora_dropout = lambda x: x
|
33 |
+
# Mark the weight as unmerged
|
34 |
+
self.merged = False
|
35 |
+
self.merge_weights = merge_weights
|
36 |
+
|
37 |
+
|
38 |
+
class MergedLinear(nn.Linear, LoRALayer):
|
39 |
+
# LoRA implemented in a dense layer
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
in_features: int,
|
43 |
+
out_features: int,
|
44 |
+
r: int = 0,
|
45 |
+
lora_alpha: int = 1,
|
46 |
+
lora_dropout: float = 0.,
|
47 |
+
enable_lora: List[bool] = [False],
|
48 |
+
fan_in_fan_out: bool = False,
|
49 |
+
merge_weights: bool = True,
|
50 |
+
**kwargs
|
51 |
+
):
|
52 |
+
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
53 |
+
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
54 |
+
merge_weights=merge_weights)
|
55 |
+
assert out_features % len(enable_lora) == 0, \
|
56 |
+
'The length of enable_lora must divide out_features'
|
57 |
+
self.enable_lora = enable_lora
|
58 |
+
self.fan_in_fan_out = fan_in_fan_out
|
59 |
+
# Actual trainable parameters
|
60 |
+
if r > 0 and any(enable_lora):
|
61 |
+
self.lora_A = nn.Parameter(
|
62 |
+
self.weight.new_zeros((r * sum(enable_lora), in_features)))
|
63 |
+
self.lora_B = nn.Parameter(
|
64 |
+
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
|
65 |
+
) # weights for Conv1D with groups=sum(enable_lora)
|
66 |
+
self.scaling = self.lora_alpha / self.r
|
67 |
+
# Freezing the pre-trained weight matrix
|
68 |
+
self.weight.requires_grad = False
|
69 |
+
# Compute the indices
|
70 |
+
self.lora_ind = self.weight.new_zeros(
|
71 |
+
(out_features, ), dtype=torch.bool
|
72 |
+
).view(len(enable_lora), -1)
|
73 |
+
self.lora_ind[enable_lora, :] = True
|
74 |
+
self.lora_ind = self.lora_ind.view(-1)
|
75 |
+
self.reset_parameters()
|
76 |
+
if fan_in_fan_out:
|
77 |
+
self.weight.data = self.weight.data.T
|
78 |
+
|
79 |
+
def reset_parameters(self):
|
80 |
+
nn.Linear.reset_parameters(self)
|
81 |
+
if hasattr(self, 'lora_A'):
|
82 |
+
# initialize A the same way as the default for nn.Linear and B to zero
|
83 |
+
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
84 |
+
nn.init.zeros_(self.lora_B)
|
85 |
+
|
86 |
+
def zero_pad(self, x):
|
87 |
+
result = x.new_zeros((*x.shape[:-1], self.out_features))
|
88 |
+
result = result.view(-1, self.out_features)
|
89 |
+
result[:, self.lora_ind] = x.reshape(
|
90 |
+
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
|
91 |
+
)
|
92 |
+
return result.view((*x.shape[:-1], self.out_features))
|
93 |
+
|
94 |
+
def train(self, mode: bool = True):
|
95 |
+
def T(w):
|
96 |
+
return w.T if self.fan_in_fan_out else w
|
97 |
+
nn.Linear.train(self, mode)
|
98 |
+
if self.merge_weights and self.merged:
|
99 |
+
# Make sure that the weights are not merged
|
100 |
+
if self.r > 0 and any(self.enable_lora):
|
101 |
+
delta_w = F.conv1d(
|
102 |
+
self.lora_A.data.unsqueeze(0),
|
103 |
+
self.lora_B.data.unsqueeze(-1),
|
104 |
+
groups=sum(self.enable_lora)
|
105 |
+
).squeeze(0)
|
106 |
+
self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
|
107 |
+
self.merged = False
|
108 |
+
|
109 |
+
def eval(self):
|
110 |
+
def T(w):
|
111 |
+
return w.T if self.fan_in_fan_out else w
|
112 |
+
nn.Linear.eval(self)
|
113 |
+
if self.merge_weights and not self.merged:
|
114 |
+
# Merge the weights and mark it
|
115 |
+
if self.r > 0 and any(self.enable_lora):
|
116 |
+
delta_w = F.conv1d(
|
117 |
+
self.lora_A.data.unsqueeze(0),
|
118 |
+
self.lora_B.data.unsqueeze(-1),
|
119 |
+
groups=sum(self.enable_lora)
|
120 |
+
).squeeze(0)
|
121 |
+
self.weight.data += self.zero_pad(T(delta_w * self.scaling))
|
122 |
+
self.merged = True
|
123 |
+
|
124 |
+
def forward(self, x: torch.Tensor):
|
125 |
+
def T(w):
|
126 |
+
return w.T if self.fan_in_fan_out else w
|
127 |
+
if self.merged:
|
128 |
+
return F.linear(x, T(self.weight), bias=self.bias)
|
129 |
+
else:
|
130 |
+
result = F.linear(x, T(self.weight), bias=self.bias)
|
131 |
+
if self.r > 0:
|
132 |
+
after_A = F.linear(self.lora_dropout(x), self.lora_A)
|
133 |
+
after_B = F.conv1d(
|
134 |
+
after_A.transpose(-2, -1),
|
135 |
+
self.lora_B.unsqueeze(-1),
|
136 |
+
groups=sum(self.enable_lora)
|
137 |
+
).transpose(-2, -1)
|
138 |
+
result += self.zero_pad(after_B) * self.scaling
|
139 |
+
return result
|
140 |
+
|
141 |
+
|
142 |
+
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
|
143 |
+
# import pdb; pdb.set_trace()
|
144 |
+
for n, p in model.named_parameters():
|
145 |
+
if 'lora_' not in n and 'motion_proj' not in n and 'llama_proj' not in n:
|
146 |
+
p.requires_grad = False
|
147 |
+
if bias == 'none':
|
148 |
+
return
|
149 |
+
elif bias == 'all':
|
150 |
+
for n, p in model.named_parameters():
|
151 |
+
if 'bias' in n:
|
152 |
+
p.requires_grad = True
|
153 |
+
elif bias == 'lora_only':
|
154 |
+
for m in model.modules():
|
155 |
+
if isinstance(m, LoRALayer) and \
|
156 |
+
hasattr(m, 'bias') and \
|
157 |
+
m.bias is not None:
|
158 |
+
m.bias.requires_grad = True
|
159 |
+
else:
|
160 |
+
raise NotImplementedError
|
161 |
+
|
162 |
+
|
163 |
+
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
|
164 |
+
my_state_dict = model.state_dict()
|
165 |
+
if bias == 'none':
|
166 |
+
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'llama_proj' in k or 'motion_proj' in k}
|
167 |
+
elif bias == 'all':
|
168 |
+
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k or 'llama_proj' in k or 'motion_proj' in k}
|
169 |
+
elif bias == 'lora_only':
|
170 |
+
to_return = {}
|
171 |
+
for k in my_state_dict:
|
172 |
+
if 'lora_' in k:
|
173 |
+
to_return[k] = my_state_dict[k]
|
174 |
+
bias_name = k.split('lora_')[0]+'bias'
|
175 |
+
if bias_name in my_state_dict:
|
176 |
+
to_return[bias_name] = my_state_dict[bias_name]
|
177 |
+
return to_return
|
178 |
+
else:
|
179 |
+
raise NotImplementedError
|
180 |
+
|
181 |
+
|
182 |
+
@dataclass
|
183 |
+
class LoRAConfig:
|
184 |
+
r: float = 0.0
|
185 |
+
alpha: float = 1.0
|
186 |
+
dropout: float = 0.0
|
187 |
+
|
188 |
+
|
189 |
+
class CausalSelfAttention(llama.CausalSelfAttention):
|
190 |
+
lora_config = None
|
191 |
+
|
192 |
+
def __init__(self, config: llama.LLaMAConfig) -> None:
|
193 |
+
# Skip the parent class __init__ altogether and replace it to avoid
|
194 |
+
# useless allocations
|
195 |
+
nn.Module.__init__(self)
|
196 |
+
assert config.n_embd % config.n_head == 0
|
197 |
+
|
198 |
+
# key, query, value projections for all heads, but in a batch
|
199 |
+
self.c_attn = MergedLinear(
|
200 |
+
in_features=config.n_embd,
|
201 |
+
out_features=3 * config.n_embd,
|
202 |
+
r=self.lora_config.r,
|
203 |
+
lora_alpha=self.lora_config.alpha,
|
204 |
+
lora_dropout=self.lora_config.dropout,
|
205 |
+
enable_lora=[True, False, True],
|
206 |
+
fan_in_fan_out = False,
|
207 |
+
merge_weights=True,
|
208 |
+
bias=False)
|
209 |
+
# output projection
|
210 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
211 |
+
# regularization
|
212 |
+
self.n_head = config.n_head
|
213 |
+
self.n_embd = config.n_embd
|
214 |
+
self.block_size = config.block_size
|
215 |
+
self.rope_cache = None
|
216 |
+
|
217 |
+
|
218 |
+
@contextmanager
|
219 |
+
def lora(r, alpha, dropout, enabled: bool = True):
|
220 |
+
"""A context manager under which you can instantiate the model with LoRA."""
|
221 |
+
if not enabled:
|
222 |
+
yield
|
223 |
+
return
|
224 |
+
|
225 |
+
CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
|
226 |
+
|
227 |
+
causal_self_attention = llama.CausalSelfAttention
|
228 |
+
llama.CausalSelfAttention = CausalSelfAttention
|
229 |
+
yield
|
230 |
+
llama.CausalSelfAttention = causal_self_attention
|
231 |
+
|
232 |
+
CausalSelfAttention.lora_config = None
|
lit_llama/model.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Full definition of a LLaMA Language Model, all of it in this single file.
|
2 |
+
|
3 |
+
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
4 |
+
"""
|
5 |
+
# mypy: ignore-errors
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from typing_extensions import Self
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class LLaMAConfig:
|
17 |
+
block_size: int = 4096
|
18 |
+
vocab_size: int = 32000
|
19 |
+
n_layer: int = 32
|
20 |
+
n_head: int = 32
|
21 |
+
n_embd: int = 4096
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def from_name(cls, name: str) -> Self:
|
25 |
+
return cls(**llama_configs[name])
|
26 |
+
|
27 |
+
|
28 |
+
llama_configs = {
|
29 |
+
"7B": dict(n_layer=32, n_head=32, n_embd=4096),
|
30 |
+
"13B": dict(n_layer=40, n_head=40, n_embd=5120),
|
31 |
+
"30B": dict(n_layer=60, n_head=52, n_embd=6656),
|
32 |
+
"65B": dict(n_layer=80, n_head=64, n_embd=8192),
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
class LLaMA(nn.Module):
|
37 |
+
def __init__(self, config: LLaMAConfig) -> None:
|
38 |
+
super().__init__()
|
39 |
+
assert config.vocab_size is not None
|
40 |
+
assert config.block_size is not None
|
41 |
+
self.config = config
|
42 |
+
|
43 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
44 |
+
self.transformer = nn.ModuleDict(
|
45 |
+
dict(
|
46 |
+
wte=nn.Embedding(config.vocab_size, config.n_embd),
|
47 |
+
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
48 |
+
ln_f=RMSNorm(config.n_embd),
|
49 |
+
)
|
50 |
+
)
|
51 |
+
# self.llama_proj = nn.Sequential(
|
52 |
+
# nn.Linear(256, 1024),
|
53 |
+
# nn.ReLU(),
|
54 |
+
# nn.Linear(1024, config.n_embd)
|
55 |
+
# )
|
56 |
+
self.llama_proj = nn.Linear(512, config.n_embd)
|
57 |
+
# self.motion_proj = nn.Sequential(
|
58 |
+
# nn.Linear(config.n_embd, 1024),
|
59 |
+
# nn.ReLU(),
|
60 |
+
# nn.Linear(1024, 256)
|
61 |
+
# )
|
62 |
+
self.motion_proj = nn.Linear(config.n_embd, 512)
|
63 |
+
|
64 |
+
def _init_weights(self, module: nn.Module) -> None:
|
65 |
+
if isinstance(module, nn.Linear):
|
66 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
|
67 |
+
elif isinstance(module, nn.Embedding):
|
68 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
|
69 |
+
|
70 |
+
def forward(self, idx: torch.Tensor) -> torch.Tensor:
|
71 |
+
# import pdb; pdb.set_trace()
|
72 |
+
_, t = idx.size()
|
73 |
+
assert (
|
74 |
+
t <= self.config.block_size
|
75 |
+
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
76 |
+
|
77 |
+
# forward the LLaMA model itself
|
78 |
+
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
79 |
+
|
80 |
+
for block in self.transformer.h:
|
81 |
+
x = block(x)
|
82 |
+
x = self.transformer.ln_f(x)
|
83 |
+
|
84 |
+
logits = self.lm_head(x) # (b, t, vocab_size)
|
85 |
+
|
86 |
+
return logits
|
87 |
+
|
88 |
+
@classmethod
|
89 |
+
def from_name(cls, name: str) -> Self:
|
90 |
+
return cls(LLaMAConfig.from_name(name))
|
91 |
+
|
92 |
+
|
93 |
+
class Block(nn.Module):
|
94 |
+
def __init__(self, config: LLaMAConfig) -> None:
|
95 |
+
super().__init__()
|
96 |
+
self.rms_1 = RMSNorm(config.n_embd)
|
97 |
+
self.attn = CausalSelfAttention(config)
|
98 |
+
self.rms_2 = RMSNorm(config.n_embd)
|
99 |
+
self.mlp = MLP(config)
|
100 |
+
|
101 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
102 |
+
x = x + self.attn(self.rms_1(x))
|
103 |
+
x = x + self.mlp(self.rms_2(x))
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class CausalSelfAttention(nn.Module):
|
108 |
+
def __init__(self, config: LLaMAConfig) -> None:
|
109 |
+
super().__init__()
|
110 |
+
assert config.n_embd % config.n_head == 0
|
111 |
+
|
112 |
+
# key, query, value projections for all heads, but in a batch
|
113 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
|
114 |
+
# output projection
|
115 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
116 |
+
|
117 |
+
self.n_head = config.n_head
|
118 |
+
self.n_embd = config.n_embd
|
119 |
+
self.block_size = config.block_size
|
120 |
+
self.rope_cache = None
|
121 |
+
|
122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
123 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
124 |
+
|
125 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
126 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
127 |
+
|
128 |
+
head_size = C // self.n_head
|
129 |
+
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
|
130 |
+
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
|
131 |
+
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
|
132 |
+
|
133 |
+
if self.rope_cache is None:
|
134 |
+
# cache for future forward calls
|
135 |
+
self.rope_cache = build_rope_cache(
|
136 |
+
seq_len=self.block_size,
|
137 |
+
n_elem=self.n_embd // self.n_head,
|
138 |
+
dtype=x.dtype,
|
139 |
+
device=x.device,
|
140 |
+
)
|
141 |
+
|
142 |
+
q = apply_rope(q, self.rope_cache)
|
143 |
+
k = apply_rope(k, self.rope_cache)
|
144 |
+
|
145 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
146 |
+
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
147 |
+
# att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
148 |
+
# att = F.softmax(att, dim=-1)
|
149 |
+
# y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
150 |
+
|
151 |
+
# efficient attention using Flash Attention CUDA kernels
|
152 |
+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
|
153 |
+
|
154 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
155 |
+
|
156 |
+
# output projection
|
157 |
+
y = self.c_proj(y)
|
158 |
+
|
159 |
+
return y
|
160 |
+
|
161 |
+
|
162 |
+
class MLP(nn.Module):
|
163 |
+
def __init__(self, config: LLaMAConfig) -> None:
|
164 |
+
super().__init__()
|
165 |
+
hidden_dim = 4 * config.n_embd
|
166 |
+
n_hidden = int(2 * hidden_dim / 3)
|
167 |
+
N = 256
|
168 |
+
# ensure n_hidden is multiple of N
|
169 |
+
n_hidden = ((n_hidden - 1) // N) * N + N
|
170 |
+
|
171 |
+
self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
|
172 |
+
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
|
173 |
+
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
|
174 |
+
|
175 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
176 |
+
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
|
177 |
+
x = self.c_proj(x)
|
178 |
+
return x
|
179 |
+
|
180 |
+
|
181 |
+
class RMSNorm(nn.Module):
|
182 |
+
"""Root Mean Square Layer Normalization.
|
183 |
+
|
184 |
+
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
185 |
+
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
186 |
+
"""
|
187 |
+
|
188 |
+
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
|
189 |
+
super().__init__()
|
190 |
+
self.scale = nn.Parameter(torch.ones(size))
|
191 |
+
self.eps = eps
|
192 |
+
self.dim = dim
|
193 |
+
|
194 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
195 |
+
# NOTE: the original RMSNorm paper implementation is not equivalent
|
196 |
+
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
|
197 |
+
# rms_x = norm_x * d_x ** (-1. / 2)
|
198 |
+
# x_normed = x / (rms_x + self.eps)
|
199 |
+
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
200 |
+
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
201 |
+
return self.scale * x_normed
|
202 |
+
|
203 |
+
|
204 |
+
def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
|
205 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
206 |
+
|
207 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
208 |
+
transformers/rope/__init__.py. MIT License:
|
209 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
210 |
+
"""
|
211 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
212 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
|
213 |
+
|
214 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
215 |
+
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
|
216 |
+
|
217 |
+
# Calculate the product of position index and $\theta_i$
|
218 |
+
idx_theta = torch.outer(seq_idx, theta)
|
219 |
+
|
220 |
+
# Compute cache. Because polar only takes float32 or float64, we need to cast
|
221 |
+
# when working with 16 bit floats (float16 or bfloat16)
|
222 |
+
dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
|
223 |
+
working_dtype = (
|
224 |
+
torch.float32 if dtype in dtypes_requiring_casting else dtype
|
225 |
+
)
|
226 |
+
complex_dtype = (
|
227 |
+
torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
|
228 |
+
)
|
229 |
+
cache = torch.polar(
|
230 |
+
torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
|
231 |
+
).to(complex_dtype)
|
232 |
+
return cache
|
233 |
+
|
234 |
+
|
235 |
+
def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
236 |
+
x = x.transpose(1, 2)
|
237 |
+
|
238 |
+
# truncate to support variable sizes
|
239 |
+
T = x.size(1)
|
240 |
+
rope_cache = rope_cache[:T]
|
241 |
+
|
242 |
+
# cast because `view_as_complex` does not support 16 bit tensors
|
243 |
+
xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
244 |
+
rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
|
245 |
+
x_out = torch.view_as_real(xc * rope_cache).flatten(3)
|
246 |
+
return x_out.transpose(1, 2).type_as(x)
|
lit_llama/quantization.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from contextlib import contextmanager
|
3 |
+
import warnings
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
# configuration for bitsandbytes before import
|
9 |
+
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
10 |
+
warnings.filterwarnings(
|
11 |
+
"ignore",
|
12 |
+
message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
|
13 |
+
)
|
14 |
+
warnings.filterwarnings(
|
15 |
+
"ignore",
|
16 |
+
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
|
17 |
+
)
|
18 |
+
warnings.filterwarnings(
|
19 |
+
"ignore",
|
20 |
+
message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
|
21 |
+
)
|
22 |
+
|
23 |
+
try:
|
24 |
+
import bitsandbytes as bnb # noqa: E402
|
25 |
+
except:
|
26 |
+
bnb = None
|
27 |
+
|
28 |
+
if bnb is not None:
|
29 |
+
class Linear8bitLt(bnb.nn.Linear8bitLt):
|
30 |
+
"""Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
|
31 |
+
re-quantizaton when loading the state dict.
|
32 |
+
|
33 |
+
|
34 |
+
This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
|
35 |
+
"""
|
36 |
+
def __init__(self, *args, **kwargs):
|
37 |
+
super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
|
38 |
+
# We quantize the initial weight here so we don't end up filling the device
|
39 |
+
# memory with float32 weights which could lead to OOM.
|
40 |
+
self._quantize_weight(self.weight.data)
|
41 |
+
|
42 |
+
def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
|
43 |
+
# There is only one key that ends with `*.weight`, the other one is the bias
|
44 |
+
weight_key = next((name for name in local_state_dict.keys() if name.endswith("weight")), None)
|
45 |
+
if weight_key is None:
|
46 |
+
return
|
47 |
+
|
48 |
+
# Load the weight from the state dict and re-quantize it
|
49 |
+
weight = local_state_dict.pop(weight_key)
|
50 |
+
self._quantize_weight(weight)
|
51 |
+
|
52 |
+
# If there is a bias, let nn.Module load it
|
53 |
+
if local_state_dict:
|
54 |
+
super()._load_from_state_dict(local_state_dict, *args, **kwargs)
|
55 |
+
|
56 |
+
def _quantize_weight(self, weight: torch.Tensor) -> None:
|
57 |
+
# This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
|
58 |
+
B = weight.contiguous().half().cuda()
|
59 |
+
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
60 |
+
del CBt
|
61 |
+
del SCBt
|
62 |
+
self.weight.data = CB
|
63 |
+
setattr(self.weight, "CB", CB)
|
64 |
+
setattr(self.weight, "SCB", SCB)
|
65 |
+
|
66 |
+
|
67 |
+
# for correctness but with terrible perf
|
68 |
+
class ColBlockQuantizedLinear(torch.nn.Module):
|
69 |
+
def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols):
|
70 |
+
super().__init__()
|
71 |
+
self.in_features = in_features
|
72 |
+
self.out_features = out_features
|
73 |
+
self.tile_cols = tile_cols if tile_cols != -1 else self.in_features
|
74 |
+
self.bits = bits
|
75 |
+
self.entries_per_byte = 8 // bits
|
76 |
+
assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8
|
77 |
+
assert in_features % self.entries_per_byte == 0
|
78 |
+
self.register_buffer("quant_weight", torch.empty((self.out_features, self.in_features // self.entries_per_byte), dtype=torch.uint8))
|
79 |
+
self.register_buffer("scales", torch.empty((self.out_features, (self.in_features + self.tile_cols - 1) // self.tile_cols)))
|
80 |
+
self.register_buffer("zeros", torch.empty_like(self.scales))
|
81 |
+
assert isinstance(bias, bool)
|
82 |
+
if bias:
|
83 |
+
self.register_buffer("bias", torch.empty((self.out_features,)))
|
84 |
+
else:
|
85 |
+
self.register_buffer("bias", None)
|
86 |
+
|
87 |
+
def pack_weight(self, weight):
|
88 |
+
weight = weight.to(device=self.quant_weight.device, copy=True)
|
89 |
+
for j in range(self.scales.size(1)):
|
90 |
+
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] /= self.scales[: , j: j+1]
|
91 |
+
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] += self.zeros[: , j: j+1]
|
92 |
+
weight = weight.clamp_(min=0, max=2 ** self.bits - 1).to(dtype=torch.uint8)
|
93 |
+
self.quant_weight.zero_()
|
94 |
+
for nr in range(self.entries_per_byte):
|
95 |
+
self.quant_weight += weight[:, nr::self.entries_per_byte] << (nr * self.bits)
|
96 |
+
|
97 |
+
def get_weight(self, dtype=torch.float):
|
98 |
+
weight = torch.empty((self.out_features, self.in_features), device=self.quant_weight.device, dtype=dtype)
|
99 |
+
mask = (1<<self.bits) - 1
|
100 |
+
for nr in range(self.entries_per_byte):
|
101 |
+
weight[:, nr::self.entries_per_byte] = ((self.quant_weight >> (nr * self.bits)) & mask).float()
|
102 |
+
self.quant_weight.to(dtype)
|
103 |
+
for j in range(self.scales.size(1)):
|
104 |
+
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] -= self.zeros[: , j: j+1]
|
105 |
+
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] *= self.scales[: , j: j+1]
|
106 |
+
return weight
|
107 |
+
|
108 |
+
def forward(self, inp):
|
109 |
+
weight = self.get_weight(dtype=inp.dtype)
|
110 |
+
return torch.nn.functional.linear(inp, weight, self.bias)
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
class GPTQQuantizer:
|
116 |
+
# The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/
|
117 |
+
# E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
118 |
+
# portions copyright by the authors licensed under the Apache License 2.0
|
119 |
+
# All errors are our own.
|
120 |
+
|
121 |
+
def __init__(self, linear_module, *, bits, perchannel=True, sym=False, blocksize=128, percdamp=.01, groupsize=-1, actorder=False):
|
122 |
+
assert isinstance(linear_module, torch.nn.Linear)
|
123 |
+
|
124 |
+
self.linear_module = linear_module
|
125 |
+
self.dev = self.linear_module.weight.device
|
126 |
+
self.rows = linear_module.weight.shape[0]
|
127 |
+
self.columns = linear_module.weight.shape[1]
|
128 |
+
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
|
129 |
+
self.nsamples = 0
|
130 |
+
self.bits = bits
|
131 |
+
self.maxq = 2 ** bits - 1
|
132 |
+
self.perchannel = perchannel
|
133 |
+
self.sym = sym
|
134 |
+
self.blocksize = blocksize
|
135 |
+
self.percdamp = percdamp
|
136 |
+
self.groupsize = groupsize
|
137 |
+
self.actorder = actorder
|
138 |
+
self.tile_cols = self.columns if groupsize == -1 else groupsize
|
139 |
+
self.scales = torch.zeros((self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols), dtype=self.linear_module.weight.dtype, device = self.dev)
|
140 |
+
self.zeros = torch.zeros_like(self.scales)
|
141 |
+
assert not (self.actorder and self.groupsize != -1), "The permutation trick does not work for grouped quantization"
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def quantize_weight(x, scale, zero, maxq):
|
145 |
+
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
|
146 |
+
x_rec = scale * (q - zero)
|
147 |
+
return x_rec
|
148 |
+
|
149 |
+
def find_params_weight(self, x):
|
150 |
+
dev = x.device
|
151 |
+
|
152 |
+
shape = x.shape
|
153 |
+
if self.perchannel:
|
154 |
+
x = x.flatten(1)
|
155 |
+
else:
|
156 |
+
x = x.flatten().unsqueeze(0)
|
157 |
+
|
158 |
+
tmp = torch.zeros(x.shape[0], device=dev)
|
159 |
+
xmin = torch.minimum(x.min(1)[0], tmp)
|
160 |
+
xmax = torch.maximum(x.max(1)[0], tmp)
|
161 |
+
|
162 |
+
if self.sym:
|
163 |
+
xmax = torch.maximum(torch.abs(xmin), xmax)
|
164 |
+
tmp = xmin < 0
|
165 |
+
if torch.any(tmp):
|
166 |
+
xmin[tmp] = -xmax[tmp]
|
167 |
+
tmp = (xmin == 0) & (xmax == 0)
|
168 |
+
xmin[tmp] = -1
|
169 |
+
xmax[tmp] = +1
|
170 |
+
|
171 |
+
scale = (xmax - xmin) / self.maxq
|
172 |
+
if self.sym:
|
173 |
+
zero = torch.full_like(scale, (self.maxq + 1) / 2)
|
174 |
+
else:
|
175 |
+
zero = torch.round(-xmin / scale)
|
176 |
+
|
177 |
+
if not self.perchannel:
|
178 |
+
tmp = shape[0]
|
179 |
+
scale = scale.repeat(tmp)
|
180 |
+
zero = zero.repeat(tmp)
|
181 |
+
|
182 |
+
shape = [-1] + [1] * (len(shape) - 1)
|
183 |
+
scale = scale.reshape(shape)
|
184 |
+
zero = zero.reshape(shape)
|
185 |
+
return scale, zero
|
186 |
+
|
187 |
+
def collect_input_stats(self, _1, inp, _2):
|
188 |
+
inp = inp[0].detach()
|
189 |
+
self.last_inp = inp
|
190 |
+
if len(inp.shape) == 2:
|
191 |
+
inp = inp.unsqueeze(0)
|
192 |
+
tmp = inp.shape[0]
|
193 |
+
if len(inp.shape) == 3:
|
194 |
+
inp = inp.reshape((-1, inp.shape[-1]))
|
195 |
+
inp = inp.t()
|
196 |
+
self.H *= self.nsamples / (self.nsamples + tmp)
|
197 |
+
self.nsamples += tmp
|
198 |
+
# inp = inp.float()
|
199 |
+
inp = math.sqrt(2 / self.nsamples) * inp.float()
|
200 |
+
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
|
201 |
+
self.H += inp.matmul(inp.t())
|
202 |
+
|
203 |
+
def quantize(self):
|
204 |
+
W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
|
205 |
+
|
206 |
+
scale, zero = self.find_params_weight(W)
|
207 |
+
self.scales[:] = scale
|
208 |
+
self.zeros[:] = zero
|
209 |
+
|
210 |
+
H = self.H
|
211 |
+
del self.H
|
212 |
+
dead = torch.diag(H) == 0
|
213 |
+
H[dead, dead] = 1
|
214 |
+
W[:, dead] = 0
|
215 |
+
if self.actorder:
|
216 |
+
perm = torch.argsort(torch.diag(H), descending=True)
|
217 |
+
W = W[:, perm]
|
218 |
+
H = H[perm][:, perm]
|
219 |
+
|
220 |
+
Losses = torch.zeros_like(W)
|
221 |
+
Q = torch.zeros_like(W)
|
222 |
+
|
223 |
+
damp = self.percdamp * torch.mean(torch.diag(H))
|
224 |
+
diag = torch.arange(self.columns, device=self.dev)
|
225 |
+
H[diag, diag] += damp
|
226 |
+
H = torch.linalg.cholesky(H)
|
227 |
+
H = torch.cholesky_inverse(H)
|
228 |
+
H = torch.linalg.cholesky(H, upper=True)
|
229 |
+
Hinv = H
|
230 |
+
|
231 |
+
for i1 in range(0, self.columns, self.blocksize):
|
232 |
+
i2 = min(i1 + self.blocksize, self.columns)
|
233 |
+
count = i2 - i1
|
234 |
+
|
235 |
+
W1 = W[:, i1:i2].clone()
|
236 |
+
Q1 = torch.zeros_like(W1)
|
237 |
+
Err1 = torch.zeros_like(W1)
|
238 |
+
Losses1 = torch.zeros_like(W1)
|
239 |
+
Hinv1 = Hinv[i1:i2, i1:i2]
|
240 |
+
|
241 |
+
for i in range(count):
|
242 |
+
w = W1[:, i]
|
243 |
+
d = Hinv1[i, i]
|
244 |
+
|
245 |
+
if self.groupsize != -1:
|
246 |
+
if (i1 + i) % self.groupsize == 0:
|
247 |
+
scale, zero = self.find_params_weight(W[:, (i1 + i):(i1 + i + self.groupsize)])
|
248 |
+
self.scales[:, (i1 + i) // self.groupsize] = scale
|
249 |
+
self.zeros[:, (i1 + i) // self.groupsize] = zeros
|
250 |
+
|
251 |
+
q = self.quantize_weight(
|
252 |
+
w.unsqueeze(1), scale, zero, self.maxq
|
253 |
+
)
|
254 |
+
q = q.squeeze(1)
|
255 |
+
assert q.dim() == 1
|
256 |
+
Q1[:, i] = q
|
257 |
+
Losses1[:, i] = (w - q) ** 2 / d ** 2
|
258 |
+
|
259 |
+
err1 = (w - q) / d
|
260 |
+
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
|
261 |
+
Err1[:, i] = err1
|
262 |
+
|
263 |
+
Q[:, i1:i2] = Q1
|
264 |
+
Losses[:, i1:i2] = Losses1 / 2
|
265 |
+
|
266 |
+
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
|
267 |
+
|
268 |
+
if self.actorder:
|
269 |
+
invperm = torch.argsort(perm)
|
270 |
+
Q = Q[:, invperm]
|
271 |
+
|
272 |
+
weight = Q.reshape(self.linear_module.weight.shape).to(self.linear_module.weight.data.dtype)
|
273 |
+
error = torch.sum(Losses).item()
|
274 |
+
|
275 |
+
q_module = ColBlockQuantizedLinear(self.linear_module.in_features, self.linear_module.out_features, self.linear_module.bias is not None,
|
276 |
+
bits=self.bits, tile_cols=self.groupsize).to(self.dev)
|
277 |
+
q_module.scales = self.scales
|
278 |
+
q_module.zeros = self.zeros
|
279 |
+
q_module.pack_weight(weight)
|
280 |
+
q_module.bias = self.linear_module.bias
|
281 |
+
return q_module, error
|
lit_llama/tokenizer.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
|
7 |
+
|
8 |
+
|
9 |
+
class Tokenizer:
|
10 |
+
"""Tokenizer for LLaMA."""
|
11 |
+
|
12 |
+
def __init__(self, model_path: Path) -> None:
|
13 |
+
self.processor = SentencePieceProcessor(model_file=str(model_path))
|
14 |
+
self.bos_id = self.processor.bos_id()
|
15 |
+
self.eos_id = self.processor.eos_id()
|
16 |
+
self.pad_id = self.processor.pad_id()
|
17 |
+
|
18 |
+
@property
|
19 |
+
def vocab_size(self) -> int:
|
20 |
+
return self.processor.vocab_size()
|
21 |
+
|
22 |
+
def encode(
|
23 |
+
self,
|
24 |
+
string: str,
|
25 |
+
bos: bool = True,
|
26 |
+
eos: bool = False,
|
27 |
+
max_length: int = -1,
|
28 |
+
pad: bool = False,
|
29 |
+
device: Optional[torch.device] = None
|
30 |
+
) -> torch.Tensor:
|
31 |
+
tokens = self.processor.encode(string)
|
32 |
+
if bos:
|
33 |
+
tokens = [self.bos_id] + tokens
|
34 |
+
if eos:
|
35 |
+
tokens = tokens + [self.eos_id]
|
36 |
+
if max_length > 0:
|
37 |
+
tokens = tokens[:max_length]
|
38 |
+
if pad and len(tokens) < max_length:
|
39 |
+
tokens += [self.pad_id] * (max_length - len(tokens))
|
40 |
+
|
41 |
+
return torch.tensor(tokens, dtype=torch.int, device=device)
|
42 |
+
|
43 |
+
def decode(self, tokens: torch.Tensor) -> str:
|
44 |
+
return self.processor.decode(tokens.tolist())
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def train(input: str, destination: str, vocab_size=32000) -> None:
|
48 |
+
model_prefix = os.path.join(destination, "tokenizer")
|
49 |
+
SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
|
lit_llama/utils.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions for training and inference."""
|
2 |
+
|
3 |
+
import functools
|
4 |
+
from pathlib import Path
|
5 |
+
import pickle
|
6 |
+
import warnings
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils._device
|
11 |
+
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
|
12 |
+
from torch.distributed.fsdp import FullStateDictConfig
|
13 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
14 |
+
from torch.distributed.fsdp import StateDictType
|
15 |
+
|
16 |
+
|
17 |
+
def save_model_checkpoint(fabric, model, file_path):
|
18 |
+
"""Handles boilerplate logic for retrieving and saving the state_dict.
|
19 |
+
|
20 |
+
This will be upstreamed to Fabric soon.
|
21 |
+
"""
|
22 |
+
file_path = Path(file_path)
|
23 |
+
|
24 |
+
if isinstance(fabric.strategy, DeepSpeedStrategy):
|
25 |
+
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
26 |
+
|
27 |
+
fabric.save(file_path, {"model": model})
|
28 |
+
fabric.barrier()
|
29 |
+
if fabric.global_rank == 0:
|
30 |
+
# Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
|
31 |
+
convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
|
32 |
+
return
|
33 |
+
|
34 |
+
if isinstance(fabric.strategy, FSDPStrategy):
|
35 |
+
save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
|
36 |
+
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
|
37 |
+
state_dict = model._forward_module.state_dict()
|
38 |
+
else:
|
39 |
+
state_dict = model.state_dict()
|
40 |
+
|
41 |
+
if fabric.global_rank == 0:
|
42 |
+
torch.save(state_dict, file_path)
|
43 |
+
fabric.barrier()
|
44 |
+
|
45 |
+
|
46 |
+
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
47 |
+
def __init__(self, device=None, dtype=None, quantization_mode=None):
|
48 |
+
"""
|
49 |
+
Create tensors with given device and dtype and don't run initialization
|
50 |
+
(but instead use "empty tensors", i.e. uninitialized memory).
|
51 |
+
|
52 |
+
device: `torch.device` to work with
|
53 |
+
dtype: `torch.dtype` to work with
|
54 |
+
quantization_mode: optional string, quantization mode to work with, default `None`.
|
55 |
+
Available modes: `llm.int8` bitsnbytes LLM.int8 quantization (only on GPU)
|
56 |
+
`qptq.int4`, `gptq.int8`: GPTQ pre-quantized models
|
57 |
+
|
58 |
+
Example::
|
59 |
+
with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
|
60 |
+
model = LLaMA.from_name('7B')
|
61 |
+
model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
|
62 |
+
|
63 |
+
self.quantization_mode = quantization_mode
|
64 |
+
self.quantized_linear_cls = None
|
65 |
+
if self.quantization_mode == 'llm.int8':
|
66 |
+
if device.type != "cuda":
|
67 |
+
raise ValueError("Quantization is only supported on the GPU.")
|
68 |
+
from .quantization import Linear8bitLt
|
69 |
+
self.quantized_linear_cls = Linear8bitLt
|
70 |
+
elif self.quantization_mode == 'gptq.int4':
|
71 |
+
from .quantization import ColBlockQuantizedLinear
|
72 |
+
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
|
73 |
+
elif self.quantization_mode == 'gptq.int8':
|
74 |
+
from .quantization import ColBlockQuantizedLinear
|
75 |
+
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
|
76 |
+
elif self.quantization_mode is not None:
|
77 |
+
raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
|
78 |
+
self.device = device
|
79 |
+
self.dtype = dtype
|
80 |
+
|
81 |
+
def __enter__(self):
|
82 |
+
if self.quantized_linear_cls != None:
|
83 |
+
self.torch_linear_cls = torch.nn.Linear
|
84 |
+
torch.nn.Linear = self.quantized_linear_cls
|
85 |
+
return super().__enter__()
|
86 |
+
|
87 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
88 |
+
if self.quantized_linear_cls != None:
|
89 |
+
torch.nn.Linear = self.torch_linear_cls
|
90 |
+
return super().__exit__(exc_type, exc_val, exc_tb)
|
91 |
+
|
92 |
+
def __torch_function__(self, func, types, args=(), kwargs=None):
|
93 |
+
kwargs = kwargs or {}
|
94 |
+
if getattr(func, "__module__", None) == "torch.nn.init":
|
95 |
+
if "tensor" in kwargs:
|
96 |
+
return kwargs["tensor"]
|
97 |
+
else:
|
98 |
+
return args[0]
|
99 |
+
if (
|
100 |
+
self.device is not None
|
101 |
+
and func in torch.utils._device._device_constructors()
|
102 |
+
and kwargs.get("device") is None
|
103 |
+
):
|
104 |
+
kwargs["device"] = self.device
|
105 |
+
if (
|
106 |
+
self.dtype is not None
|
107 |
+
and func in torch.utils._device._device_constructors()
|
108 |
+
and kwargs.get("dtype") is None
|
109 |
+
):
|
110 |
+
kwargs["dtype"] = self.dtype
|
111 |
+
return func(*args, **kwargs)
|
112 |
+
|
113 |
+
|
114 |
+
# this is taken from torchhacks https://github.com/lernapparat/torchhacks
|
115 |
+
|
116 |
+
|
117 |
+
class NotYetLoadedTensor:
|
118 |
+
def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
|
119 |
+
self.metatensor = metatensor
|
120 |
+
self.archiveinfo = archiveinfo
|
121 |
+
self.storageinfo = storageinfo
|
122 |
+
self.rebuild_args = rebuild_args
|
123 |
+
|
124 |
+
@classmethod
|
125 |
+
def rebuild(
|
126 |
+
cls,
|
127 |
+
storage,
|
128 |
+
storage_offset,
|
129 |
+
size,
|
130 |
+
stride,
|
131 |
+
requires_grad,
|
132 |
+
backward_hooks,
|
133 |
+
metadata=None,
|
134 |
+
archiveinfo=None,
|
135 |
+
):
|
136 |
+
rebuild_args = (
|
137 |
+
storage_offset,
|
138 |
+
size,
|
139 |
+
stride,
|
140 |
+
requires_grad,
|
141 |
+
backward_hooks,
|
142 |
+
metadata,
|
143 |
+
)
|
144 |
+
metatensor = torch._utils._rebuild_tensor_v2(
|
145 |
+
storage,
|
146 |
+
storage_offset,
|
147 |
+
size,
|
148 |
+
stride,
|
149 |
+
requires_grad,
|
150 |
+
backward_hooks,
|
151 |
+
metadata,
|
152 |
+
)
|
153 |
+
storageinfo = storage.archiveinfo
|
154 |
+
return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
|
155 |
+
|
156 |
+
def _load_tensor(self):
|
157 |
+
name, storage_cls, fn, device, size = self.storageinfo
|
158 |
+
dtype = self.metatensor.dtype
|
159 |
+
|
160 |
+
uts = (
|
161 |
+
self.archiveinfo.zipfile.get_storage_from_record(
|
162 |
+
f"data/{fn}",
|
163 |
+
size * torch._utils._element_size(dtype),
|
164 |
+
torch.UntypedStorage,
|
165 |
+
)
|
166 |
+
._typed_storage()
|
167 |
+
._untyped_storage
|
168 |
+
)
|
169 |
+
with warnings.catch_warnings():
|
170 |
+
warnings.simplefilter("ignore")
|
171 |
+
storage = torch.storage.TypedStorage(
|
172 |
+
wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
|
173 |
+
)
|
174 |
+
tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
|
175 |
+
return tensor
|
176 |
+
|
177 |
+
@classmethod
|
178 |
+
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
179 |
+
if kwargs is None:
|
180 |
+
kwargs = {}
|
181 |
+
loaded_args = [
|
182 |
+
(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
|
183 |
+
]
|
184 |
+
res = func(*loaded_args, **kwargs)
|
185 |
+
# gc.collect would be costly here, maybe do it optionally
|
186 |
+
return res
|
187 |
+
|
188 |
+
def __getattr__(self, name):
|
189 |
+
# properties
|
190 |
+
## TODO: device, is_...??
|
191 |
+
## TODO: mH, mT, H, T, data, imag, real
|
192 |
+
## name ???
|
193 |
+
if name in {
|
194 |
+
"dtype",
|
195 |
+
"grad",
|
196 |
+
"grad_fn",
|
197 |
+
"layout",
|
198 |
+
"names",
|
199 |
+
"ndim",
|
200 |
+
"output_nr",
|
201 |
+
"requires_grad",
|
202 |
+
"retains_grad",
|
203 |
+
"shape",
|
204 |
+
"volatile",
|
205 |
+
}:
|
206 |
+
return getattr(self.metatensor, name)
|
207 |
+
if name in {"size"}:
|
208 |
+
return getattr(self.metatensor, name)
|
209 |
+
# materializing with contiguous is needed for quantization
|
210 |
+
if name in {"contiguous"}:
|
211 |
+
return getattr(self._load_tensor(), name)
|
212 |
+
|
213 |
+
raise AttributeError(f"{type(self)} does not have {name}")
|
214 |
+
|
215 |
+
def __repr__(self):
|
216 |
+
return f"NotYetLoadedTensor({repr(self.metatensor)})"
|
217 |
+
|
218 |
+
|
219 |
+
class LazyLoadingUnpickler(pickle.Unpickler):
|
220 |
+
def __init__(self, file, zipfile):
|
221 |
+
super().__init__(file)
|
222 |
+
self.zipfile = zipfile
|
223 |
+
|
224 |
+
def find_class(self, module, name):
|
225 |
+
if module == "torch._utils" and name == "_rebuild_tensor_v2":
|
226 |
+
res = super().find_class(module, name)
|
227 |
+
return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
|
228 |
+
return super().find_class(module, name)
|
229 |
+
|
230 |
+
def persistent_load(self, pid):
|
231 |
+
name, cls, fn, device, size = pid
|
232 |
+
with warnings.catch_warnings():
|
233 |
+
warnings.simplefilter("ignore")
|
234 |
+
s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
|
235 |
+
s.archiveinfo = pid
|
236 |
+
return s
|
237 |
+
|
238 |
+
|
239 |
+
def lazy_load(fn):
|
240 |
+
zf = torch._C.PyTorchFileReader(str(fn))
|
241 |
+
with BytesIO(zf.get_record("data.pkl")) as pkl:
|
242 |
+
mup = LazyLoadingUnpickler(pkl, zf)
|
243 |
+
sd = mup.load()
|
244 |
+
return sd
|
models/__init__.py
ADDED
File without changes
|
models/constants.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
X_TOKEN_INDEX = {'IMAGE': -200, 'VIDEO': -201, 'AUDIO': -202, 'THERMAL': -203, 'DEPTH': -204}
|
9 |
+
X_INDEX_TOKEN = {v: k for k, v in X_TOKEN_INDEX.items()}
|
10 |
+
# IMAGE_TOKEN_INDEX = -200
|
11 |
+
DEFAULT_X_TOKEN = {'IMAGE': "<image>", 'VIDEO': "<video>", 'AUDIO': "<audio>", 'THERMAL': "<thermal>", 'DEPTH': "<depth>"}
|
12 |
+
# DEFAULT_IMAGE_TOKEN = "<image>"
|
13 |
+
DEFAULT_X_PATCH_TOKEN = {'IMAGE': "<im_patch>", 'VIDEO': "<vi_patch>", 'AUDIO': "<au_patch>", 'THERMAL': "<th_patch>", 'DEPTH': "<de_patch>"}
|
14 |
+
# DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
15 |
+
DEFAULT_X_START_TOKEN = {'IMAGE': "<im_start>", 'VIDEO': "<vi_start>", 'AUDIO': "<au_start>", 'THERMAL': "<th_start>", 'DEPTH': "<de_start>"}
|
16 |
+
# DEFAULT_IM_START_TOKEN = "<im_start>"
|
17 |
+
DEFAULT_X_END_TOKEN = {'IMAGE': "<im_end>", 'VIDEO': "<vi_end>", 'AUDIO': "<au_end>", 'THERMAL': "<th_end>", 'DEPTH': "<de_end>"}
|
18 |
+
# DEFAULT_IM_END_TOKEN = "<im_end>"
|
models/encdec.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from models.resnet import Resnet1D
|
3 |
+
|
4 |
+
class Encoder(nn.Module):
|
5 |
+
def __init__(self,
|
6 |
+
input_emb_width = 3,
|
7 |
+
output_emb_width = 512,
|
8 |
+
down_t = 3,
|
9 |
+
stride_t = 2,
|
10 |
+
width = 512,
|
11 |
+
depth = 3,
|
12 |
+
dilation_growth_rate = 3,
|
13 |
+
activation='relu',
|
14 |
+
norm=None):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
blocks = []
|
18 |
+
filter_t, pad_t = stride_t * 2, stride_t // 2
|
19 |
+
blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
|
20 |
+
blocks.append(nn.ReLU())
|
21 |
+
|
22 |
+
for i in range(down_t):
|
23 |
+
input_dim = width
|
24 |
+
block = nn.Sequential(
|
25 |
+
nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
|
26 |
+
Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm),
|
27 |
+
)
|
28 |
+
blocks.append(block)
|
29 |
+
blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
|
30 |
+
self.model = nn.Sequential(*blocks)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
return self.model(x)
|
34 |
+
|
35 |
+
class Decoder(nn.Module):
|
36 |
+
def __init__(self,
|
37 |
+
input_emb_width = 3,
|
38 |
+
output_emb_width = 512,
|
39 |
+
down_t = 3,
|
40 |
+
stride_t = 2,
|
41 |
+
width = 512,
|
42 |
+
depth = 3,
|
43 |
+
dilation_growth_rate = 3,
|
44 |
+
activation='relu',
|
45 |
+
norm=None):
|
46 |
+
super().__init__()
|
47 |
+
blocks = []
|
48 |
+
|
49 |
+
filter_t, pad_t = stride_t * 2, stride_t // 2
|
50 |
+
blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
|
51 |
+
blocks.append(nn.ReLU())
|
52 |
+
for i in range(down_t):
|
53 |
+
out_dim = width
|
54 |
+
block = nn.Sequential(
|
55 |
+
Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm),
|
56 |
+
nn.Upsample(scale_factor=2, mode='nearest'),
|
57 |
+
nn.Conv1d(width, out_dim, 3, 1, 1)
|
58 |
+
)
|
59 |
+
blocks.append(block)
|
60 |
+
blocks.append(nn.Conv1d(width, width, 3, 1, 1))
|
61 |
+
blocks.append(nn.ReLU())
|
62 |
+
blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
|
63 |
+
self.model = nn.Sequential(*blocks)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
return self.model(x)
|
67 |
+
|
models/evaluator_wrapper.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from os.path import join as pjoin
|
4 |
+
import numpy as np
|
5 |
+
from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
|
6 |
+
from utils.word_vectorizer import POS_enumerator
|
7 |
+
|
8 |
+
def build_models(opt):
|
9 |
+
movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
|
10 |
+
text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
|
11 |
+
pos_size=opt.dim_pos_ohot,
|
12 |
+
hidden_size=opt.dim_text_hidden,
|
13 |
+
output_size=opt.dim_coemb_hidden,
|
14 |
+
device=opt.device)
|
15 |
+
|
16 |
+
motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
|
17 |
+
hidden_size=opt.dim_motion_hidden,
|
18 |
+
output_size=opt.dim_coemb_hidden,
|
19 |
+
device=opt.device)
|
20 |
+
|
21 |
+
checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
|
22 |
+
map_location=opt.device)
|
23 |
+
movement_enc.load_state_dict(checkpoint['movement_encoder'])
|
24 |
+
text_enc.load_state_dict(checkpoint['text_encoder'])
|
25 |
+
motion_enc.load_state_dict(checkpoint['motion_encoder'])
|
26 |
+
print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
|
27 |
+
return text_enc, motion_enc, movement_enc
|
28 |
+
|
29 |
+
|
30 |
+
class EvaluatorModelWrapper(object):
|
31 |
+
|
32 |
+
def __init__(self, opt):
|
33 |
+
|
34 |
+
if opt.dataset_name == 't2m':
|
35 |
+
opt.dim_pose = 263
|
36 |
+
elif opt.dataset_name == 'kit':
|
37 |
+
opt.dim_pose = 251
|
38 |
+
else:
|
39 |
+
raise KeyError('Dataset not Recognized!!!')
|
40 |
+
|
41 |
+
opt.dim_word = 300
|
42 |
+
opt.max_motion_length = 196
|
43 |
+
opt.dim_pos_ohot = len(POS_enumerator)
|
44 |
+
opt.dim_motion_hidden = 1024
|
45 |
+
opt.max_text_len = 20
|
46 |
+
opt.dim_text_hidden = 512
|
47 |
+
opt.dim_coemb_hidden = 512
|
48 |
+
|
49 |
+
# print(opt)
|
50 |
+
|
51 |
+
self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
|
52 |
+
self.opt = opt
|
53 |
+
self.device = opt.device
|
54 |
+
|
55 |
+
self.text_encoder.to(opt.device)
|
56 |
+
self.motion_encoder.to(opt.device)
|
57 |
+
self.movement_encoder.to(opt.device)
|
58 |
+
|
59 |
+
self.text_encoder.eval()
|
60 |
+
self.motion_encoder.eval()
|
61 |
+
self.movement_encoder.eval()
|
62 |
+
|
63 |
+
# Please note that the results does not following the order of inputs
|
64 |
+
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
|
65 |
+
with torch.no_grad():
|
66 |
+
word_embs = word_embs.detach().to(self.device).float()
|
67 |
+
pos_ohot = pos_ohot.detach().to(self.device).float()
|
68 |
+
motions = motions.detach().to(self.device).float()
|
69 |
+
|
70 |
+
'''Movement Encoding'''
|
71 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
72 |
+
m_lens = m_lens // self.opt.unit_length
|
73 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
74 |
+
|
75 |
+
'''Text Encoding'''
|
76 |
+
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
|
77 |
+
return text_embedding, motion_embedding
|
78 |
+
|
79 |
+
# Please note that the results does not following the order of inputs
|
80 |
+
def get_motion_embeddings(self, motions, m_lens):
|
81 |
+
with torch.no_grad():
|
82 |
+
motions = motions.detach().to(self.device).float()
|
83 |
+
|
84 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
85 |
+
motions = motions[align_idx]
|
86 |
+
m_lens = m_lens[align_idx]
|
87 |
+
|
88 |
+
'''Movement Encoding'''
|
89 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
90 |
+
m_lens = m_lens // self.opt.unit_length
|
91 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
92 |
+
return motion_embedding
|
models/modules.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
4 |
+
|
5 |
+
def init_weight(m):
|
6 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
|
7 |
+
nn.init.xavier_normal_(m.weight)
|
8 |
+
# m.bias.data.fill_(0.01)
|
9 |
+
if m.bias is not None:
|
10 |
+
nn.init.constant_(m.bias, 0)
|
11 |
+
|
12 |
+
|
13 |
+
class MovementConvEncoder(nn.Module):
|
14 |
+
def __init__(self, input_size, hidden_size, output_size):
|
15 |
+
super(MovementConvEncoder, self).__init__()
|
16 |
+
self.main = nn.Sequential(
|
17 |
+
nn.Conv1d(input_size, hidden_size, 4, 2, 1),
|
18 |
+
nn.Dropout(0.2, inplace=True),
|
19 |
+
nn.LeakyReLU(0.2, inplace=True),
|
20 |
+
nn.Conv1d(hidden_size, output_size, 4, 2, 1),
|
21 |
+
nn.Dropout(0.2, inplace=True),
|
22 |
+
nn.LeakyReLU(0.2, inplace=True),
|
23 |
+
)
|
24 |
+
self.out_net = nn.Linear(output_size, output_size)
|
25 |
+
self.main.apply(init_weight)
|
26 |
+
self.out_net.apply(init_weight)
|
27 |
+
|
28 |
+
def forward(self, inputs):
|
29 |
+
inputs = inputs.permute(0, 2, 1)
|
30 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
31 |
+
# print(outputs.shape)
|
32 |
+
return self.out_net(outputs)
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
class TextEncoderBiGRUCo(nn.Module):
|
37 |
+
def __init__(self, word_size, pos_size, hidden_size, output_size, device):
|
38 |
+
super(TextEncoderBiGRUCo, self).__init__()
|
39 |
+
self.device = device
|
40 |
+
|
41 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
42 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
43 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
44 |
+
self.output_net = nn.Sequential(
|
45 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
46 |
+
nn.LayerNorm(hidden_size),
|
47 |
+
nn.LeakyReLU(0.2, inplace=True),
|
48 |
+
nn.Linear(hidden_size, output_size)
|
49 |
+
)
|
50 |
+
|
51 |
+
self.input_emb.apply(init_weight)
|
52 |
+
self.pos_emb.apply(init_weight)
|
53 |
+
self.output_net.apply(init_weight)
|
54 |
+
self.hidden_size = hidden_size
|
55 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
56 |
+
|
57 |
+
# input(batch_size, seq_len, dim)
|
58 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
59 |
+
num_samples = word_embs.shape[0]
|
60 |
+
|
61 |
+
pos_embs = self.pos_emb(pos_onehot)
|
62 |
+
inputs = word_embs + pos_embs
|
63 |
+
input_embs = self.input_emb(inputs)
|
64 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
65 |
+
|
66 |
+
cap_lens = cap_lens.data.tolist()
|
67 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
68 |
+
|
69 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
70 |
+
|
71 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
72 |
+
|
73 |
+
return self.output_net(gru_last)
|
74 |
+
|
75 |
+
|
76 |
+
class MotionEncoderBiGRUCo(nn.Module):
|
77 |
+
def __init__(self, input_size, hidden_size, output_size, device):
|
78 |
+
super(MotionEncoderBiGRUCo, self).__init__()
|
79 |
+
self.device = device
|
80 |
+
|
81 |
+
self.input_emb = nn.Linear(input_size, hidden_size)
|
82 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
83 |
+
self.output_net = nn.Sequential(
|
84 |
+
nn.Linear(hidden_size*2, hidden_size),
|
85 |
+
nn.LayerNorm(hidden_size),
|
86 |
+
nn.LeakyReLU(0.2, inplace=True),
|
87 |
+
nn.Linear(hidden_size, output_size)
|
88 |
+
)
|
89 |
+
|
90 |
+
self.input_emb.apply(init_weight)
|
91 |
+
self.output_net.apply(init_weight)
|
92 |
+
self.hidden_size = hidden_size
|
93 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
94 |
+
|
95 |
+
# input(batch_size, seq_len, dim)
|
96 |
+
def forward(self, inputs, m_lens):
|
97 |
+
num_samples = inputs.shape[0]
|
98 |
+
|
99 |
+
input_embs = self.input_emb(inputs)
|
100 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
101 |
+
|
102 |
+
cap_lens = m_lens.data.tolist()
|
103 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False)
|
104 |
+
|
105 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
106 |
+
|
107 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
108 |
+
|
109 |
+
return self.output_net(gru_last)
|
models/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
from .languagebind import LanguageBindImageTower, LanguageBindVideoTower
|
4 |
+
from .mae_encoder import MAEVisionTower
|
5 |
+
from transformers import CLIPModel
|
6 |
+
|
7 |
+
def build_image_tower(image_tower_cfg, **kwargs):
|
8 |
+
image_tower = getattr(image_tower_cfg, 'mm_image_tower', getattr(image_tower_cfg, 'image_tower', None))
|
9 |
+
is_absolute_path_exists = os.path.exists(image_tower)
|
10 |
+
if is_absolute_path_exists or image_tower.startswith("openai") or image_tower.startswith("laion"):
|
11 |
+
return CLIPVisionTower(image_tower, args=image_tower_cfg, **kwargs)
|
12 |
+
if image_tower.endswith('LanguageBind_Image'):
|
13 |
+
return LanguageBindImageTower(image_tower, args=image_tower_cfg, cache_dir='./cache_dir', **kwargs)
|
14 |
+
if 'mae' in image_tower:
|
15 |
+
print('maemaemaemaemaemaemaemae')
|
16 |
+
print('maemaemaemaemaemaemaemae')
|
17 |
+
print('maemaemaemaemaemaemaemae')
|
18 |
+
print('maemaemaemaemaemaemaemae')
|
19 |
+
print('maemaemaemaemaemaemaemae')
|
20 |
+
return MAEVisionTower(image_tower, args=image_tower_cfg, cache_dir='./cache_dir', **kwargs)
|
21 |
+
raise ValueError(f'Unknown image tower: {image_tower}')
|
22 |
+
|
23 |
+
def build_video_tower(video_tower_cfg, **kwargs):
|
24 |
+
video_tower = getattr(video_tower_cfg, 'mm_video_tower', getattr(video_tower_cfg, 'video_tower', None))
|
25 |
+
if video_tower.endswith('LanguageBind_Video_merge'):
|
26 |
+
return LanguageBindVideoTower(video_tower, args=video_tower_cfg, cache_dir='./cache_dir', **kwargs)
|
27 |
+
raise ValueError(f'Unknown video tower: {video_tower}')
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
# import os
|
32 |
+
# from .clip_encoder import CLIPVisionTower
|
33 |
+
# from .languagebind import LanguageBindImageTower, LanguageBindVideoTower
|
34 |
+
# from transformers import CLIPModel
|
35 |
+
|
36 |
+
# def build_image_tower(image_tower_cfg, **kwargs):
|
37 |
+
# image_tower = getattr(image_tower_cfg, 'mm_image_tower', getattr(image_tower_cfg, 'image_tower', None))
|
38 |
+
# is_absolute_path_exists = os.path.exists(image_tower)
|
39 |
+
# if is_absolute_path_exists or image_tower.startswith("openai") or image_tower.startswith("laion"):
|
40 |
+
# return CLIPVisionTower(image_tower, args=image_tower_cfg, **kwargs)
|
41 |
+
# if image_tower.endswith('LanguageBind_Image'):
|
42 |
+
# return LanguageBindImageTower(image_tower, args=image_tower_cfg, cache_dir='./cache_dir', **kwargs)
|
43 |
+
# raise ValueError(f'Unknown image tower: {image_tower}')
|
44 |
+
|
45 |
+
# def build_video_tower(video_tower_cfg, **kwargs):
|
46 |
+
# video_tower = getattr(video_tower_cfg, 'mm_video_tower', getattr(video_tower_cfg, 'video_tower', None))
|
47 |
+
# if video_tower.endswith('LanguageBind_Video'):
|
48 |
+
# return LanguageBindVideoTower(video_tower, args=video_tower_cfg, cache_dir='./cache_dir', **kwargs)
|
49 |
+
# raise ValueError(f'Unknown video tower: {video_tower}')
|
models/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.is_loaded = False
|
12 |
+
|
13 |
+
self.vision_tower_name = vision_tower
|
14 |
+
self.select_layer = args.mm_vision_select_layer
|
15 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
else:
|
20 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
24 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
25 |
+
self.vision_tower.requires_grad_(False)
|
26 |
+
|
27 |
+
self.is_loaded = True
|
28 |
+
|
29 |
+
def feature_select(self, image_forward_outs):
|
30 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
31 |
+
if self.select_feature == 'patch':
|
32 |
+
image_features = image_features[:, 1:]
|
33 |
+
elif self.select_feature == 'cls_patch':
|
34 |
+
image_features = image_features
|
35 |
+
else:
|
36 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
37 |
+
return image_features
|
38 |
+
|
39 |
+
@torch.no_grad()
|
40 |
+
def forward(self, images):
|
41 |
+
if type(images) is list:
|
42 |
+
image_features = []
|
43 |
+
for image in images:
|
44 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
45 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
46 |
+
image_features.append(image_feature)
|
47 |
+
else:
|
48 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
49 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
50 |
+
|
51 |
+
return image_features
|
52 |
+
|
53 |
+
@property
|
54 |
+
def dummy_feature(self):
|
55 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
56 |
+
|
57 |
+
@property
|
58 |
+
def dtype(self):
|
59 |
+
return self.vision_tower.dtype
|
60 |
+
|
61 |
+
@property
|
62 |
+
def device(self):
|
63 |
+
return self.vision_tower.device
|
64 |
+
|
65 |
+
@property
|
66 |
+
def config(self):
|
67 |
+
if self.is_loaded:
|
68 |
+
return self.vision_tower.config
|
69 |
+
else:
|
70 |
+
return self.cfg_only
|
71 |
+
|
72 |
+
@property
|
73 |
+
def hidden_size(self):
|
74 |
+
return self.config.hidden_size
|
75 |
+
|
76 |
+
@property
|
77 |
+
def num_patches(self):
|
78 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
models/multimodal_encoder/languagebind/__init__.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from transformers import AutoConfig
|
4 |
+
|
5 |
+
from .image.configuration_image import LanguageBindImageConfig
|
6 |
+
from .image.modeling_image import LanguageBindImage
|
7 |
+
from .image.tokenization_image import LanguageBindImageTokenizer
|
8 |
+
from .image.processing_image import LanguageBindImageProcessor
|
9 |
+
|
10 |
+
from .video.configuration_video import LanguageBindVideoConfig
|
11 |
+
from .video.modeling_video import LanguageBindVideo
|
12 |
+
from .video.tokenization_video import LanguageBindVideoTokenizer
|
13 |
+
from .video.processing_video import LanguageBindVideoProcessor
|
14 |
+
|
15 |
+
from .depth.configuration_depth import LanguageBindDepthConfig
|
16 |
+
from .depth.modeling_depth import LanguageBindDepth
|
17 |
+
from .depth.tokenization_depth import LanguageBindDepthTokenizer
|
18 |
+
from .depth.processing_depth import LanguageBindDepthProcessor
|
19 |
+
|
20 |
+
from .audio.configuration_audio import LanguageBindAudioConfig
|
21 |
+
from .audio.modeling_audio import LanguageBindAudio
|
22 |
+
from .audio.tokenization_audio import LanguageBindAudioTokenizer
|
23 |
+
from .audio.processing_audio import LanguageBindAudioProcessor
|
24 |
+
|
25 |
+
from .thermal.configuration_thermal import LanguageBindThermalConfig
|
26 |
+
from .thermal.modeling_thermal import LanguageBindThermal
|
27 |
+
from .thermal.tokenization_thermal import LanguageBindThermalTokenizer
|
28 |
+
from .thermal.processing_thermal import LanguageBindThermalProcessor
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
config_dict = {
|
33 |
+
'thermal': LanguageBindThermalConfig,
|
34 |
+
'image': LanguageBindImageConfig,
|
35 |
+
'video': LanguageBindVideoConfig,
|
36 |
+
'depth': LanguageBindDepthConfig,
|
37 |
+
'audio': LanguageBindAudioConfig
|
38 |
+
}
|
39 |
+
model_dict = {
|
40 |
+
'thermal': LanguageBindThermal,
|
41 |
+
'image': LanguageBindImage,
|
42 |
+
'video': LanguageBindVideo,
|
43 |
+
'depth': LanguageBindDepth,
|
44 |
+
'audio': LanguageBindAudio
|
45 |
+
}
|
46 |
+
transform_dict = {
|
47 |
+
'video': LanguageBindVideoProcessor,
|
48 |
+
'audio': LanguageBindAudioProcessor,
|
49 |
+
'depth': LanguageBindDepthProcessor,
|
50 |
+
'thermal': LanguageBindThermalProcessor,
|
51 |
+
'image': LanguageBindImageProcessor,
|
52 |
+
}
|
53 |
+
|
54 |
+
class LanguageBind(nn.Module):
|
55 |
+
def __init__(self, clip_type=('thermal', 'image', 'video', 'depth', 'audio'), use_temp=True, cache_dir='./cache_dir'):
|
56 |
+
super(LanguageBind, self).__init__()
|
57 |
+
self.use_temp = use_temp
|
58 |
+
self.modality_encoder = {}
|
59 |
+
self.modality_proj = {}
|
60 |
+
self.modality_scale = {}
|
61 |
+
self.modality_config = {}
|
62 |
+
for c in clip_type:
|
63 |
+
pretrained_ckpt = f'LanguageBind/LanguageBind_{c.capitalize()}'
|
64 |
+
model = model_dict[c].from_pretrained(pretrained_ckpt, cache_dir=cache_dir)
|
65 |
+
self.modality_encoder[c] = model.vision_model
|
66 |
+
self.modality_proj[c] = model.visual_projection
|
67 |
+
self.modality_scale[c] = model.logit_scale
|
68 |
+
self.modality_config[c] = model.config
|
69 |
+
self.modality_encoder['language'] = model.text_model
|
70 |
+
self.modality_proj['language'] = model.text_projection
|
71 |
+
|
72 |
+
self.modality_encoder = nn.ModuleDict(self.modality_encoder)
|
73 |
+
self.modality_proj = nn.ModuleDict(self.modality_proj)
|
74 |
+
|
75 |
+
def forward(self, inputs):
|
76 |
+
outputs = {}
|
77 |
+
for key, value in inputs.items():
|
78 |
+
value = self.modality_encoder[key](**value)[1]
|
79 |
+
value = self.modality_proj[key](value)
|
80 |
+
value = value / value.norm(p=2, dim=-1, keepdim=True)
|
81 |
+
if self.use_temp:
|
82 |
+
if key != 'language':
|
83 |
+
value = value * self.modality_scale[key].exp()
|
84 |
+
outputs[key] = value
|
85 |
+
return outputs
|
86 |
+
|
87 |
+
def to_device(x, device):
|
88 |
+
out_dict = {k: v.to(device) for k, v in x.items()}
|
89 |
+
return out_dict
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
class LanguageBindImageTower(nn.Module):
|
95 |
+
def __init__(self, image_tower, args, delay_load=False, cache_dir='./cache_dir'):
|
96 |
+
super().__init__()
|
97 |
+
# import pdb; pdb.set_trace()
|
98 |
+
self.is_loaded = False
|
99 |
+
|
100 |
+
self.image_tower_name = image_tower
|
101 |
+
self.select_layer = args.mm_vision_select_layer
|
102 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
103 |
+
|
104 |
+
self.cache_dir = cache_dir
|
105 |
+
|
106 |
+
if not delay_load:
|
107 |
+
self.load_model()
|
108 |
+
else:
|
109 |
+
# import pdb; pdb.set_trace()
|
110 |
+
self.cfg_only = LanguageBindImageConfig.from_pretrained(self.image_tower_name, cache_dir=self.cache_dir)
|
111 |
+
|
112 |
+
############################################################
|
113 |
+
def load_model(self):
|
114 |
+
model = LanguageBindImage.from_pretrained(self.image_tower_name, cache_dir=self.cache_dir)
|
115 |
+
self.image_tower = model.vision_model
|
116 |
+
self.image_tower.requires_grad_(False)
|
117 |
+
|
118 |
+
self.image_processor = LanguageBindImageProcessor(model.config)
|
119 |
+
|
120 |
+
self.is_loaded = True
|
121 |
+
|
122 |
+
def feature_select(self, image_forward_outs):
|
123 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
124 |
+
if self.select_feature == 'patch':
|
125 |
+
image_features = image_features[:, 1:]
|
126 |
+
elif self.select_feature == 'cls_patch':
|
127 |
+
image_features = image_features
|
128 |
+
else:
|
129 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
130 |
+
return image_features
|
131 |
+
|
132 |
+
@torch.no_grad()
|
133 |
+
def forward(self, images):
|
134 |
+
if type(images) is list:
|
135 |
+
image_features = []
|
136 |
+
for image in images:
|
137 |
+
image_forward_out = self.image_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
138 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
139 |
+
image_features.append(image_feature)
|
140 |
+
else:
|
141 |
+
# print('images', images.shape)
|
142 |
+
image_forward_outs = self.image_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
143 |
+
# print('image_forward_outs', len(image_forward_outs), image_forward_outs[0].shape)
|
144 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
145 |
+
# print('image_features', image_features.shape)
|
146 |
+
|
147 |
+
return image_features
|
148 |
+
|
149 |
+
@property
|
150 |
+
def dummy_feature(self):
|
151 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
152 |
+
|
153 |
+
@property
|
154 |
+
def dtype(self):
|
155 |
+
return self.image_tower.embeddings.class_embedding.dtype #############
|
156 |
+
|
157 |
+
@property
|
158 |
+
def device(self):
|
159 |
+
return self.image_tower.embeddings.class_embedding.device ##############
|
160 |
+
|
161 |
+
@property
|
162 |
+
def config(self):
|
163 |
+
if self.is_loaded:
|
164 |
+
return self.image_tower.config
|
165 |
+
else:
|
166 |
+
return self.cfg_only
|
167 |
+
|
168 |
+
@property
|
169 |
+
def hidden_size(self):
|
170 |
+
return self.config.hidden_size
|
171 |
+
|
172 |
+
@property
|
173 |
+
def num_patches(self):
|
174 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
175 |
+
|
176 |
+
class temp_model(nn.Module):
|
177 |
+
def __init__(self):
|
178 |
+
super(temp_model, self).__init__()
|
179 |
+
def forward(self, **kwargs):
|
180 |
+
return torch.randn(25, 1, 256, 1024)
|
181 |
+
|
182 |
+
|
183 |
+
class LanguageBindVideoTower(nn.Module):
|
184 |
+
def __init__(self, video_tower, args, delay_load=False, cache_dir='./cache_dir'):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
self.is_loaded = False
|
188 |
+
|
189 |
+
self.video_tower_name = video_tower
|
190 |
+
self.select_layer = args.mm_vision_select_layer
|
191 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
192 |
+
|
193 |
+
self.cache_dir = cache_dir
|
194 |
+
|
195 |
+
if not delay_load:
|
196 |
+
self.load_model()
|
197 |
+
else:
|
198 |
+
self.cfg_only = LanguageBindVideoConfig.from_pretrained(self.video_tower_name, cache_dir=self.cache_dir)
|
199 |
+
|
200 |
+
## 使用deley load, from_pretrained 之后,self.is_loaded 仍然是false
|
201 |
+
# import pdb; pdb.set_trace()
|
202 |
+
|
203 |
+
############################################################
|
204 |
+
def load_model(self):
|
205 |
+
model = LanguageBindVideo.from_pretrained(self.video_tower_name, cache_dir=self.cache_dir)
|
206 |
+
self.video_processor = LanguageBindVideoProcessor(model.config)
|
207 |
+
|
208 |
+
|
209 |
+
# model = LanguageBindImage.from_pretrained('LanguageBind/LanguageBind_Image', cache_dir=self.cache_dir)
|
210 |
+
self.video_tower = model.vision_model
|
211 |
+
self.video_tower.requires_grad_(False)
|
212 |
+
|
213 |
+
|
214 |
+
self.is_loaded = True
|
215 |
+
|
216 |
+
# def feature_select(self, image_forward_outs):
|
217 |
+
# image_features = image_forward_outs.hidden_states[self.select_layer]
|
218 |
+
# if self.select_feature == 'patch':
|
219 |
+
# image_features = image_features[:, 1:]
|
220 |
+
# elif self.select_feature == 'cls_patch':
|
221 |
+
# image_features = image_features
|
222 |
+
# else:
|
223 |
+
# raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
224 |
+
# return image_features
|
225 |
+
|
226 |
+
def feature_select(self, video_forward_outs):
|
227 |
+
# print('len(video_forward_outs.hidden_states)', len(video_forward_outs.hidden_states))
|
228 |
+
video_features = video_forward_outs.hidden_states[self.select_layer] # b t n c
|
229 |
+
b, t, n, c = video_features.shape
|
230 |
+
# print('video_features', video_features.shape)
|
231 |
+
if self.select_feature == 'patch':
|
232 |
+
# video_features = video_features[:, 1:]
|
233 |
+
video_features = video_features[:, :, 1:]
|
234 |
+
video_features = video_features.reshape(b, -1, c)
|
235 |
+
elif self.select_feature == 'cls_patch':
|
236 |
+
# video_features = video_features
|
237 |
+
video_features = video_features.reshape(b, -1, c)
|
238 |
+
else:
|
239 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
240 |
+
return video_features
|
241 |
+
|
242 |
+
@torch.no_grad()
|
243 |
+
def forward(self, videos):
|
244 |
+
# import pdb; pdb.set_trace()
|
245 |
+
if type(videos) is list:
|
246 |
+
video_features = []
|
247 |
+
for video in videos:
|
248 |
+
video_forward_out = self.video_tower(video.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
249 |
+
video_feature = self.feature_select(video_forward_out).to(video.dtype)
|
250 |
+
video_features.append(video_feature)
|
251 |
+
else:
|
252 |
+
# print(11111111111, videos.shape)
|
253 |
+
video_forward_outs = self.video_tower(videos.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
254 |
+
video_features = self.feature_select(video_forward_outs).to(videos.dtype)
|
255 |
+
|
256 |
+
return video_features
|
257 |
+
|
258 |
+
@property
|
259 |
+
def dummy_feature(self):
|
260 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
261 |
+
|
262 |
+
@property
|
263 |
+
def dtype(self):
|
264 |
+
return self.video_tower.embeddings.class_embedding.dtype #############
|
265 |
+
# return torch.randn(1).cuda().dtype
|
266 |
+
|
267 |
+
@property
|
268 |
+
def device(self):
|
269 |
+
return self.video_tower.embeddings.class_embedding.device ##############
|
270 |
+
# return torch.randn(1).cuda().device
|
271 |
+
|
272 |
+
@property
|
273 |
+
def config(self):
|
274 |
+
if self.is_loaded:
|
275 |
+
return self.video_tower.config
|
276 |
+
else:
|
277 |
+
return self.cfg_only
|
278 |
+
|
279 |
+
@property
|
280 |
+
def hidden_size(self):
|
281 |
+
return self.config.hidden_size
|
282 |
+
|
283 |
+
@property
|
284 |
+
def num_patches(self):
|
285 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
models/multimodal_encoder/languagebind/audio/configuration_audio.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
from transformers.utils import logging
|
7 |
+
|
8 |
+
logger = logging.get_logger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
class CLIPTextConfig(PretrainedConfig):
|
17 |
+
r"""
|
18 |
+
This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
|
19 |
+
text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
|
20 |
+
with the defaults will yield a similar configuration to that of the text encoder of the CLIP
|
21 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
22 |
+
|
23 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
24 |
+
documentation from [`PretrainedConfig`] for more information.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
vocab_size (`int`, *optional*, defaults to 49408):
|
28 |
+
Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
|
29 |
+
the `inputs_ids` passed when calling [`CLIPModel`].
|
30 |
+
hidden_size (`int`, *optional*, defaults to 512):
|
31 |
+
Dimensionality of the encoder layers and the pooler layer.
|
32 |
+
intermediate_size (`int`, *optional*, defaults to 2048):
|
33 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
34 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
35 |
+
Number of hidden layers in the Transformer encoder.
|
36 |
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
37 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
38 |
+
max_position_embeddings (`int`, *optional*, defaults to 77):
|
39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
42 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
43 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
44 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
45 |
+
The epsilon used by the layer normalization layers.
|
46 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
47 |
+
The dropout ratio for the attention probabilities.
|
48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
49 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
50 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
51 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
52 |
+
testing).
|
53 |
+
|
54 |
+
Example:
|
55 |
+
|
56 |
+
```python
|
57 |
+
>>> from transformers import CLIPTextConfig, CLIPTextModel
|
58 |
+
|
59 |
+
>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
|
60 |
+
>>> configuration = CLIPTextConfig()
|
61 |
+
|
62 |
+
>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
63 |
+
>>> model = CLIPTextModel(configuration)
|
64 |
+
|
65 |
+
>>> # Accessing the model configuration
|
66 |
+
>>> configuration = model.config
|
67 |
+
```"""
|
68 |
+
model_type = "clip_text_model"
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
vocab_size=49408,
|
73 |
+
hidden_size=512,
|
74 |
+
intermediate_size=2048,
|
75 |
+
projection_dim=512,
|
76 |
+
num_hidden_layers=12,
|
77 |
+
num_attention_heads=8,
|
78 |
+
max_position_embeddings=77,
|
79 |
+
hidden_act="quick_gelu",
|
80 |
+
layer_norm_eps=1e-5,
|
81 |
+
attention_dropout=0.0,
|
82 |
+
initializer_range=0.02,
|
83 |
+
initializer_factor=1.0,
|
84 |
+
# This differs from `CLIPTokenizer`'s default and from openai/clip
|
85 |
+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
86 |
+
pad_token_id=1,
|
87 |
+
bos_token_id=49406,
|
88 |
+
eos_token_id=49407,
|
89 |
+
**kwargs,
|
90 |
+
):
|
91 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
92 |
+
|
93 |
+
self.vocab_size = vocab_size
|
94 |
+
self.hidden_size = hidden_size
|
95 |
+
self.intermediate_size = intermediate_size
|
96 |
+
self.projection_dim = projection_dim
|
97 |
+
self.num_hidden_layers = num_hidden_layers
|
98 |
+
self.num_attention_heads = num_attention_heads
|
99 |
+
self.max_position_embeddings = max_position_embeddings
|
100 |
+
self.layer_norm_eps = layer_norm_eps
|
101 |
+
self.hidden_act = hidden_act
|
102 |
+
self.initializer_range = initializer_range
|
103 |
+
self.initializer_factor = initializer_factor
|
104 |
+
self.attention_dropout = attention_dropout
|
105 |
+
self.add_time_attn = False ######################################
|
106 |
+
|
107 |
+
@classmethod
|
108 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
109 |
+
cls._set_token_in_kwargs(kwargs)
|
110 |
+
|
111 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
112 |
+
|
113 |
+
# get the text config dict if we are loading from CLIPConfig
|
114 |
+
if config_dict.get("model_type") == "clip":
|
115 |
+
config_dict = config_dict["text_config"]
|
116 |
+
|
117 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
118 |
+
logger.warning(
|
119 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
120 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
121 |
+
)
|
122 |
+
|
123 |
+
return cls.from_dict(config_dict, **kwargs)
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class CLIPVisionConfig(PretrainedConfig):
|
129 |
+
r"""
|
130 |
+
This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
|
131 |
+
CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
132 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
|
133 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
134 |
+
|
135 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
136 |
+
documentation from [`PretrainedConfig`] for more information.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
140 |
+
Dimensionality of the encoder layers and the pooler layer.
|
141 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
142 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
143 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
144 |
+
Number of hidden layers in the Transformer encoder.
|
145 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
146 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
147 |
+
image_size (`int`, *optional*, defaults to 224):
|
148 |
+
The size (resolution) of each image.
|
149 |
+
patch_size (`int`, *optional*, defaults to 32):
|
150 |
+
The size (resolution) of each patch.
|
151 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
152 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
153 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
154 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
155 |
+
The epsilon used by the layer normalization layers.
|
156 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
157 |
+
The dropout ratio for the attention probabilities.
|
158 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
159 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
160 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
161 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
162 |
+
testing).
|
163 |
+
|
164 |
+
Example:
|
165 |
+
|
166 |
+
```python
|
167 |
+
>>> from transformers import CLIPVisionConfig, CLIPVisionModel
|
168 |
+
|
169 |
+
>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
|
170 |
+
>>> configuration = CLIPVisionConfig()
|
171 |
+
|
172 |
+
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
173 |
+
>>> model = CLIPVisionModel(configuration)
|
174 |
+
|
175 |
+
>>> # Accessing the model configuration
|
176 |
+
>>> configuration = model.config
|
177 |
+
```"""
|
178 |
+
|
179 |
+
model_type = "clip_vision_model"
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
hidden_size=768,
|
184 |
+
intermediate_size=3072,
|
185 |
+
projection_dim=512,
|
186 |
+
num_hidden_layers=12,
|
187 |
+
num_attention_heads=12,
|
188 |
+
num_channels=3,
|
189 |
+
image_size=224,
|
190 |
+
patch_size=32,
|
191 |
+
hidden_act="quick_gelu",
|
192 |
+
layer_norm_eps=1e-5,
|
193 |
+
attention_dropout=0.0,
|
194 |
+
initializer_range=0.02,
|
195 |
+
initializer_factor=1.0,
|
196 |
+
|
197 |
+
add_time_attn=False, ################################
|
198 |
+
num_frames=1, ################################
|
199 |
+
force_patch_dropout=0.0, ################################
|
200 |
+
lora_r=2, ################################
|
201 |
+
lora_alpha=16, ################################
|
202 |
+
lora_dropout=0.0, ################################
|
203 |
+
num_mel_bins=0.0, ################################
|
204 |
+
target_length=0.0, ################################
|
205 |
+
video_decode_backend='decord', #########################
|
206 |
+
audio_sample_rate=16000,
|
207 |
+
audio_mean=0.5,
|
208 |
+
audio_std=0.5,
|
209 |
+
**kwargs,
|
210 |
+
):
|
211 |
+
super().__init__(**kwargs)
|
212 |
+
|
213 |
+
self.hidden_size = hidden_size
|
214 |
+
self.intermediate_size = intermediate_size
|
215 |
+
self.projection_dim = projection_dim
|
216 |
+
self.num_hidden_layers = num_hidden_layers
|
217 |
+
self.num_attention_heads = num_attention_heads
|
218 |
+
self.num_channels = num_channels
|
219 |
+
self.patch_size = patch_size
|
220 |
+
self.image_size = image_size
|
221 |
+
self.initializer_range = initializer_range
|
222 |
+
self.initializer_factor = initializer_factor
|
223 |
+
self.attention_dropout = attention_dropout
|
224 |
+
self.layer_norm_eps = layer_norm_eps
|
225 |
+
self.hidden_act = hidden_act
|
226 |
+
|
227 |
+
self.add_time_attn = add_time_attn ################
|
228 |
+
self.num_frames = num_frames ################
|
229 |
+
self.force_patch_dropout = force_patch_dropout ################
|
230 |
+
self.lora_r = lora_r ################
|
231 |
+
self.lora_alpha = lora_alpha ################
|
232 |
+
self.lora_dropout = lora_dropout ################
|
233 |
+
self.num_mel_bins = num_mel_bins ################
|
234 |
+
self.target_length = target_length ################
|
235 |
+
self.video_decode_backend = video_decode_backend ################
|
236 |
+
|
237 |
+
self.audio_sample_rate = audio_sample_rate
|
238 |
+
self.audio_mean = audio_mean
|
239 |
+
self.audio_std = audio_std
|
240 |
+
|
241 |
+
@classmethod
|
242 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
243 |
+
cls._set_token_in_kwargs(kwargs)
|
244 |
+
|
245 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
246 |
+
|
247 |
+
# get the vision config dict if we are loading from CLIPConfig
|
248 |
+
if config_dict.get("model_type") == "clip":
|
249 |
+
config_dict = config_dict["vision_config"]
|
250 |
+
|
251 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
252 |
+
logger.warning(
|
253 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
254 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
255 |
+
)
|
256 |
+
|
257 |
+
return cls.from_dict(config_dict, **kwargs)
|
258 |
+
|
259 |
+
|
260 |
+
class LanguageBindAudioConfig(PretrainedConfig):
|
261 |
+
r"""
|
262 |
+
[`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
|
263 |
+
a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
|
264 |
+
a configuration with the defaults will yield a similar configuration to that of the CLIP
|
265 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
266 |
+
|
267 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
268 |
+
documentation from [`PretrainedConfig`] for more information.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
text_config (`dict`, *optional*):
|
272 |
+
Dictionary of configuration options used to initialize [`CLIPTextConfig`].
|
273 |
+
vision_config (`dict`, *optional*):
|
274 |
+
Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
|
275 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
276 |
+
Dimentionality of text and vision projection layers.
|
277 |
+
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
278 |
+
The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation.
|
279 |
+
kwargs (*optional*):
|
280 |
+
Dictionary of keyword arguments.
|
281 |
+
|
282 |
+
Example:
|
283 |
+
|
284 |
+
```python
|
285 |
+
>>> from transformers import CLIPConfig, CLIPModel
|
286 |
+
|
287 |
+
>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
|
288 |
+
>>> configuration = CLIPConfig()
|
289 |
+
|
290 |
+
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
291 |
+
>>> model = CLIPModel(configuration)
|
292 |
+
|
293 |
+
>>> # Accessing the model configuration
|
294 |
+
>>> configuration = model.config
|
295 |
+
|
296 |
+
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
|
297 |
+
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
|
298 |
+
|
299 |
+
>>> # Initializing a CLIPText and CLIPVision configuration
|
300 |
+
>>> config_text = CLIPTextConfig()
|
301 |
+
>>> config_vision = CLIPVisionConfig()
|
302 |
+
|
303 |
+
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
|
304 |
+
```"""
|
305 |
+
|
306 |
+
model_type = "LanguageBindAudio"
|
307 |
+
is_composition = True
|
308 |
+
|
309 |
+
def __init__(
|
310 |
+
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
311 |
+
):
|
312 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
313 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
|
314 |
+
# of confusion!).
|
315 |
+
text_config_dict = kwargs.pop("text_config_dict", None)
|
316 |
+
vision_config_dict = kwargs.pop("vision_config_dict", None)
|
317 |
+
|
318 |
+
super().__init__(**kwargs)
|
319 |
+
|
320 |
+
# Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
|
321 |
+
# `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
|
322 |
+
# cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
|
323 |
+
if text_config_dict is not None:
|
324 |
+
if text_config is None:
|
325 |
+
text_config = {}
|
326 |
+
|
327 |
+
# This is the complete result when using `text_config_dict`.
|
328 |
+
_text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
|
329 |
+
|
330 |
+
# Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
|
331 |
+
for key, value in _text_config_dict.items():
|
332 |
+
if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
|
333 |
+
# If specified in `text_config_dict`
|
334 |
+
if key in text_config_dict:
|
335 |
+
message = (
|
336 |
+
f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
|
337 |
+
f'The value `text_config_dict["{key}"]` will be used instead.'
|
338 |
+
)
|
339 |
+
# If inferred from default argument values (just to be super careful)
|
340 |
+
else:
|
341 |
+
message = (
|
342 |
+
f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
|
343 |
+
f'value `text_config["{key}"]` will be overriden.'
|
344 |
+
)
|
345 |
+
logger.warning(message)
|
346 |
+
|
347 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
348 |
+
text_config.update(_text_config_dict)
|
349 |
+
|
350 |
+
if vision_config_dict is not None:
|
351 |
+
if vision_config is None:
|
352 |
+
vision_config = {}
|
353 |
+
|
354 |
+
# This is the complete result when using `vision_config_dict`.
|
355 |
+
_vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
|
356 |
+
# convert keys to string instead of integer
|
357 |
+
if "id2label" in _vision_config_dict:
|
358 |
+
_vision_config_dict["id2label"] = {
|
359 |
+
str(key): value for key, value in _vision_config_dict["id2label"].items()
|
360 |
+
}
|
361 |
+
|
362 |
+
# Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
|
363 |
+
for key, value in _vision_config_dict.items():
|
364 |
+
if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
|
365 |
+
# If specified in `vision_config_dict`
|
366 |
+
if key in vision_config_dict:
|
367 |
+
message = (
|
368 |
+
f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
|
369 |
+
f'values. The value `vision_config_dict["{key}"]` will be used instead.'
|
370 |
+
)
|
371 |
+
# If inferred from default argument values (just to be super careful)
|
372 |
+
else:
|
373 |
+
message = (
|
374 |
+
f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
|
375 |
+
f'The value `vision_config["{key}"]` will be overriden.'
|
376 |
+
)
|
377 |
+
logger.warning(message)
|
378 |
+
|
379 |
+
# Update all values in `vision_config` with the ones in `_vision_config_dict`.
|
380 |
+
vision_config.update(_vision_config_dict)
|
381 |
+
|
382 |
+
if text_config is None:
|
383 |
+
text_config = {}
|
384 |
+
logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
|
385 |
+
|
386 |
+
if vision_config is None:
|
387 |
+
vision_config = {}
|
388 |
+
logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
|
389 |
+
|
390 |
+
self.text_config = CLIPTextConfig(**text_config)
|
391 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
392 |
+
|
393 |
+
self.projection_dim = projection_dim
|
394 |
+
self.logit_scale_init_value = logit_scale_init_value
|
395 |
+
self.initializer_factor = 1.0
|
396 |
+
|
397 |
+
@classmethod
|
398 |
+
def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
|
399 |
+
r"""
|
400 |
+
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
|
401 |
+
configuration.
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
[`CLIPConfig`]: An instance of a configuration object
|
405 |
+
"""
|
406 |
+
|
407 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
408 |
+
|
409 |
+
def to_dict(self):
|
410 |
+
"""
|
411 |
+
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
412 |
+
|
413 |
+
Returns:
|
414 |
+
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
415 |
+
"""
|
416 |
+
output = copy.deepcopy(self.__dict__)
|
417 |
+
output["text_config"] = self.text_config.to_dict()
|
418 |
+
output["vision_config"] = self.vision_config.to_dict()
|
419 |
+
output["model_type"] = self.__class__.model_type
|
420 |
+
return output
|
421 |
+
|
422 |
+
|
423 |
+
|
424 |
+
|
425 |
+
|
426 |
+
|
427 |
+
|
428 |
+
|
429 |
+
|
430 |
+
|
models/multimodal_encoder/languagebind/audio/modeling_audio.py
ADDED
@@ -0,0 +1,1030 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from peft import LoraConfig, get_peft_model
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from transformers import PreTrainedModel, add_start_docstrings
|
10 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
11 |
+
from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
|
12 |
+
CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss
|
13 |
+
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
14 |
+
|
15 |
+
from .configuration_audio import LanguageBindAudioConfig, CLIPVisionConfig, CLIPTextConfig
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class PatchDropout(nn.Module):
|
20 |
+
"""
|
21 |
+
https://arxiv.org/abs/2212.00794
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, prob, exclude_first_token=True):
|
25 |
+
super().__init__()
|
26 |
+
assert 0 <= prob < 1.
|
27 |
+
self.prob = prob
|
28 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
29 |
+
|
30 |
+
def forward(self, x, B, T):
|
31 |
+
if not self.training or self.prob == 0.:
|
32 |
+
return x
|
33 |
+
|
34 |
+
if self.exclude_first_token:
|
35 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
36 |
+
else:
|
37 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
38 |
+
|
39 |
+
batch = x.size()[0]
|
40 |
+
num_tokens = x.size()[1]
|
41 |
+
|
42 |
+
batch_indices = torch.arange(batch)
|
43 |
+
batch_indices = batch_indices[..., None]
|
44 |
+
|
45 |
+
keep_prob = 1 - self.prob
|
46 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
47 |
+
|
48 |
+
if T == 1:
|
49 |
+
rand = torch.randn(batch, num_tokens)
|
50 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
51 |
+
else:
|
52 |
+
rand = torch.randn(B, num_tokens)
|
53 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
54 |
+
patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1)
|
55 |
+
patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n')
|
56 |
+
|
57 |
+
|
58 |
+
x = x[batch_indices, patch_indices_keep]
|
59 |
+
|
60 |
+
if self.exclude_first_token:
|
61 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
62 |
+
|
63 |
+
return x
|
64 |
+
|
65 |
+
class CLIPEncoderLayer(nn.Module):
|
66 |
+
def __init__(self, config: LanguageBindAudioConfig):
|
67 |
+
super().__init__()
|
68 |
+
self.embed_dim = config.hidden_size
|
69 |
+
self.self_attn = CLIPAttention(config)
|
70 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
71 |
+
self.mlp = CLIPMLP(config)
|
72 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
73 |
+
|
74 |
+
self.add_time_attn = config.add_time_attn
|
75 |
+
if self.add_time_attn:
|
76 |
+
self.t = config.num_frames
|
77 |
+
self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
|
78 |
+
nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
|
79 |
+
|
80 |
+
self.embed_dim = config.hidden_size
|
81 |
+
self.temporal_attn = CLIPAttention(config)
|
82 |
+
self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
83 |
+
self.temporal_mlp = CLIPMLP(config)
|
84 |
+
self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
85 |
+
|
86 |
+
def forward(
|
87 |
+
self,
|
88 |
+
hidden_states: torch.Tensor,
|
89 |
+
attention_mask: torch.Tensor,
|
90 |
+
causal_attention_mask: torch.Tensor,
|
91 |
+
output_attentions: Optional[bool] = False,
|
92 |
+
) -> Tuple[torch.FloatTensor]:
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
96 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
97 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
98 |
+
`(config.encoder_attention_heads,)`.
|
99 |
+
output_attentions (`bool`, *optional*):
|
100 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
101 |
+
returned tensors for more detail.
|
102 |
+
"""
|
103 |
+
|
104 |
+
|
105 |
+
if self.add_time_attn:
|
106 |
+
bt, n, d = hidden_states.shape
|
107 |
+
t = self.t
|
108 |
+
|
109 |
+
# time embed
|
110 |
+
if t != 1:
|
111 |
+
n = hidden_states.shape[1]
|
112 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
113 |
+
hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
|
114 |
+
hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
115 |
+
|
116 |
+
# time attn
|
117 |
+
residual = hidden_states
|
118 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
119 |
+
# hidden_states = self.layer_norm1(hidden_states) # share layernorm
|
120 |
+
hidden_states = self.temporal_layer_norm1(hidden_states)
|
121 |
+
hidden_states, attn_weights = self.temporal_attn(
|
122 |
+
hidden_states=hidden_states,
|
123 |
+
attention_mask=attention_mask,
|
124 |
+
causal_attention_mask=causal_attention_mask,
|
125 |
+
output_attentions=output_attentions,
|
126 |
+
)
|
127 |
+
hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
128 |
+
|
129 |
+
residual = hidden_states
|
130 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
131 |
+
# hidden_states = self.layer_norm2(hidden_states) # share layernorm
|
132 |
+
hidden_states = self.temporal_layer_norm2(hidden_states)
|
133 |
+
hidden_states = self.temporal_mlp(hidden_states)
|
134 |
+
hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
135 |
+
|
136 |
+
# spatial attn
|
137 |
+
residual = hidden_states
|
138 |
+
|
139 |
+
hidden_states = self.layer_norm1(hidden_states)
|
140 |
+
hidden_states, attn_weights = self.self_attn(
|
141 |
+
hidden_states=hidden_states,
|
142 |
+
attention_mask=attention_mask,
|
143 |
+
causal_attention_mask=causal_attention_mask,
|
144 |
+
output_attentions=output_attentions,
|
145 |
+
)
|
146 |
+
hidden_states = residual + hidden_states
|
147 |
+
|
148 |
+
residual = hidden_states
|
149 |
+
hidden_states = self.layer_norm2(hidden_states)
|
150 |
+
hidden_states = self.mlp(hidden_states)
|
151 |
+
hidden_states = residual + hidden_states
|
152 |
+
|
153 |
+
outputs = (hidden_states,)
|
154 |
+
|
155 |
+
if output_attentions:
|
156 |
+
outputs += (attn_weights,)
|
157 |
+
|
158 |
+
return outputs
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
class CLIPPreTrainedModel(PreTrainedModel):
|
169 |
+
"""
|
170 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
171 |
+
models.
|
172 |
+
"""
|
173 |
+
|
174 |
+
config_class = LanguageBindAudioConfig
|
175 |
+
base_model_prefix = "clip"
|
176 |
+
supports_gradient_checkpointing = True
|
177 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
178 |
+
|
179 |
+
def _init_weights(self, module):
|
180 |
+
"""Initialize the weights"""
|
181 |
+
factor = self.config.initializer_factor
|
182 |
+
if isinstance(module, CLIPTextEmbeddings):
|
183 |
+
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
184 |
+
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
185 |
+
elif isinstance(module, CLIPVisionEmbeddings):
|
186 |
+
factor = self.config.initializer_factor
|
187 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
188 |
+
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
189 |
+
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
190 |
+
elif isinstance(module, CLIPAttention):
|
191 |
+
factor = self.config.initializer_factor
|
192 |
+
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
193 |
+
out_proj_std = (module.embed_dim**-0.5) * factor
|
194 |
+
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
|
195 |
+
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
|
196 |
+
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
|
197 |
+
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
198 |
+
elif isinstance(module, CLIPMLP):
|
199 |
+
factor = self.config.initializer_factor
|
200 |
+
in_proj_std = (
|
201 |
+
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
202 |
+
)
|
203 |
+
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
204 |
+
nn.init.normal_(module.fc1.weight, std=fc_std)
|
205 |
+
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
206 |
+
elif isinstance(module, LanguageBindAudio):
|
207 |
+
nn.init.normal_(
|
208 |
+
module.text_projection.weight,
|
209 |
+
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
210 |
+
)
|
211 |
+
nn.init.normal_(
|
212 |
+
module.visual_projection.weight,
|
213 |
+
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
214 |
+
)
|
215 |
+
elif isinstance(module, CLIPVisionModelWithProjection):
|
216 |
+
nn.init.normal_(
|
217 |
+
module.visual_projection.weight,
|
218 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
219 |
+
)
|
220 |
+
elif isinstance(module, CLIPTextModelWithProjection):
|
221 |
+
nn.init.normal_(
|
222 |
+
module.text_projection.weight,
|
223 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
224 |
+
)
|
225 |
+
|
226 |
+
if isinstance(module, nn.LayerNorm):
|
227 |
+
module.bias.data.zero_()
|
228 |
+
module.weight.data.fill_(1.0)
|
229 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
230 |
+
module.bias.data.zero_()
|
231 |
+
|
232 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
233 |
+
if isinstance(module, CLIPEncoder):
|
234 |
+
module.gradient_checkpointing = value
|
235 |
+
|
236 |
+
|
237 |
+
CLIP_START_DOCSTRING = r"""
|
238 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
239 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
240 |
+
etc.)
|
241 |
+
|
242 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
243 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
244 |
+
and behavior.
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
248 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
249 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
250 |
+
"""
|
251 |
+
|
252 |
+
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
253 |
+
Args:
|
254 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
255 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
256 |
+
it.
|
257 |
+
|
258 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
259 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
260 |
+
|
261 |
+
[What are input IDs?](../glossary#input-ids)
|
262 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
263 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
264 |
+
|
265 |
+
- 1 for tokens that are **not masked**,
|
266 |
+
- 0 for tokens that are **masked**.
|
267 |
+
|
268 |
+
[What are attention masks?](../glossary#attention-mask)
|
269 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
270 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
271 |
+
config.max_position_embeddings - 1]`.
|
272 |
+
|
273 |
+
[What are position IDs?](../glossary#position-ids)
|
274 |
+
output_attentions (`bool`, *optional*):
|
275 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
276 |
+
tensors for more detail.
|
277 |
+
output_hidden_states (`bool`, *optional*):
|
278 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
279 |
+
more detail.
|
280 |
+
return_dict (`bool`, *optional*):
|
281 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
282 |
+
"""
|
283 |
+
|
284 |
+
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
285 |
+
Args:
|
286 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
287 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
288 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
289 |
+
output_attentions (`bool`, *optional*):
|
290 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
291 |
+
tensors for more detail.
|
292 |
+
output_hidden_states (`bool`, *optional*):
|
293 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
294 |
+
more detail.
|
295 |
+
return_dict (`bool`, *optional*):
|
296 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
297 |
+
"""
|
298 |
+
|
299 |
+
CLIP_INPUTS_DOCSTRING = r"""
|
300 |
+
Args:
|
301 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
302 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
303 |
+
it.
|
304 |
+
|
305 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
306 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
307 |
+
|
308 |
+
[What are input IDs?](../glossary#input-ids)
|
309 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
310 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
311 |
+
|
312 |
+
- 1 for tokens that are **not masked**,
|
313 |
+
- 0 for tokens that are **masked**.
|
314 |
+
|
315 |
+
[What are attention masks?](../glossary#attention-mask)
|
316 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
317 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
318 |
+
config.max_position_embeddings - 1]`.
|
319 |
+
|
320 |
+
[What are position IDs?](../glossary#position-ids)
|
321 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
322 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
323 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
324 |
+
return_loss (`bool`, *optional*):
|
325 |
+
Whether or not to return the contrastive loss.
|
326 |
+
output_attentions (`bool`, *optional*):
|
327 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
328 |
+
tensors for more detail.
|
329 |
+
output_hidden_states (`bool`, *optional*):
|
330 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
331 |
+
more detail.
|
332 |
+
return_dict (`bool`, *optional*):
|
333 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
334 |
+
"""
|
335 |
+
|
336 |
+
|
337 |
+
class CLIPEncoder(nn.Module):
|
338 |
+
"""
|
339 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
340 |
+
[`CLIPEncoderLayer`].
|
341 |
+
|
342 |
+
Args:
|
343 |
+
config: CLIPConfig
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(self, config: LanguageBindAudioConfig):
|
347 |
+
super().__init__()
|
348 |
+
self.config = config
|
349 |
+
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
350 |
+
self.gradient_checkpointing = False
|
351 |
+
|
352 |
+
def forward(
|
353 |
+
self,
|
354 |
+
inputs_embeds,
|
355 |
+
attention_mask: Optional[torch.Tensor] = None,
|
356 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
357 |
+
output_attentions: Optional[bool] = None,
|
358 |
+
output_hidden_states: Optional[bool] = None,
|
359 |
+
return_dict: Optional[bool] = None,
|
360 |
+
) -> Union[Tuple, BaseModelOutput]:
|
361 |
+
r"""
|
362 |
+
Args:
|
363 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
364 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
365 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
366 |
+
than the model's internal embedding lookup matrix.
|
367 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
368 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
369 |
+
|
370 |
+
- 1 for tokens that are **not masked**,
|
371 |
+
- 0 for tokens that are **masked**.
|
372 |
+
|
373 |
+
[What are attention masks?](../glossary#attention-mask)
|
374 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
375 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
376 |
+
|
377 |
+
- 1 for tokens that are **not masked**,
|
378 |
+
- 0 for tokens that are **masked**.
|
379 |
+
|
380 |
+
[What are attention masks?](../glossary#attention-mask)
|
381 |
+
output_attentions (`bool`, *optional*):
|
382 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
383 |
+
returned tensors for more detail.
|
384 |
+
output_hidden_states (`bool`, *optional*):
|
385 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
386 |
+
for more detail.
|
387 |
+
return_dict (`bool`, *optional*):
|
388 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
389 |
+
"""
|
390 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
391 |
+
output_hidden_states = (
|
392 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
393 |
+
)
|
394 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
395 |
+
|
396 |
+
encoder_states = () if output_hidden_states else None
|
397 |
+
all_attentions = () if output_attentions else None
|
398 |
+
|
399 |
+
hidden_states = inputs_embeds
|
400 |
+
for idx, encoder_layer in enumerate(self.layers):
|
401 |
+
if output_hidden_states:
|
402 |
+
encoder_states = encoder_states + (hidden_states,)
|
403 |
+
if self.gradient_checkpointing and self.training:
|
404 |
+
|
405 |
+
def create_custom_forward(module):
|
406 |
+
def custom_forward(*inputs):
|
407 |
+
return module(*inputs, output_attentions)
|
408 |
+
|
409 |
+
return custom_forward
|
410 |
+
|
411 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
412 |
+
create_custom_forward(encoder_layer),
|
413 |
+
hidden_states,
|
414 |
+
attention_mask,
|
415 |
+
causal_attention_mask,
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
layer_outputs = encoder_layer(
|
419 |
+
hidden_states,
|
420 |
+
attention_mask,
|
421 |
+
causal_attention_mask,
|
422 |
+
output_attentions=output_attentions,
|
423 |
+
)
|
424 |
+
|
425 |
+
hidden_states = layer_outputs[0]
|
426 |
+
|
427 |
+
if output_attentions:
|
428 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
429 |
+
|
430 |
+
if output_hidden_states:
|
431 |
+
encoder_states = encoder_states + (hidden_states,)
|
432 |
+
|
433 |
+
if not return_dict:
|
434 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
435 |
+
return BaseModelOutput(
|
436 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
437 |
+
)
|
438 |
+
|
439 |
+
|
440 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
441 |
+
def _make_causal_mask(
|
442 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
443 |
+
):
|
444 |
+
"""
|
445 |
+
Make causal mask used for bi-directional self-attention.
|
446 |
+
"""
|
447 |
+
bsz, tgt_len = input_ids_shape
|
448 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
449 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
450 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
451 |
+
mask = mask.to(dtype)
|
452 |
+
|
453 |
+
if past_key_values_length > 0:
|
454 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
455 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
456 |
+
|
457 |
+
|
458 |
+
class CLIPTextTransformer(nn.Module):
|
459 |
+
def __init__(self, config: CLIPTextConfig):
|
460 |
+
super().__init__()
|
461 |
+
self.config = config
|
462 |
+
embed_dim = config.hidden_size
|
463 |
+
self.embeddings = CLIPTextEmbeddings(config)
|
464 |
+
self.encoder = CLIPEncoder(config)
|
465 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
466 |
+
|
467 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
468 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
469 |
+
def forward(
|
470 |
+
self,
|
471 |
+
input_ids: Optional[torch.Tensor] = None,
|
472 |
+
attention_mask: Optional[torch.Tensor] = None,
|
473 |
+
position_ids: Optional[torch.Tensor] = None,
|
474 |
+
output_attentions: Optional[bool] = None,
|
475 |
+
output_hidden_states: Optional[bool] = None,
|
476 |
+
return_dict: Optional[bool] = None,
|
477 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
478 |
+
r"""
|
479 |
+
Returns:
|
480 |
+
|
481 |
+
"""
|
482 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
483 |
+
output_hidden_states = (
|
484 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
485 |
+
)
|
486 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
487 |
+
|
488 |
+
if input_ids is None:
|
489 |
+
raise ValueError("You have to specify input_ids")
|
490 |
+
|
491 |
+
input_shape = input_ids.size()
|
492 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
493 |
+
|
494 |
+
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
495 |
+
|
496 |
+
# CLIP's text model uses causal mask, prepare it here.
|
497 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
498 |
+
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
|
499 |
+
# expand attention_mask
|
500 |
+
if attention_mask is not None:
|
501 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
502 |
+
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
503 |
+
|
504 |
+
encoder_outputs = self.encoder(
|
505 |
+
inputs_embeds=hidden_states,
|
506 |
+
attention_mask=attention_mask,
|
507 |
+
causal_attention_mask=causal_attention_mask,
|
508 |
+
output_attentions=output_attentions,
|
509 |
+
output_hidden_states=output_hidden_states,
|
510 |
+
return_dict=return_dict,
|
511 |
+
)
|
512 |
+
|
513 |
+
last_hidden_state = encoder_outputs[0]
|
514 |
+
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
515 |
+
|
516 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
517 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
518 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
519 |
+
pooled_output = last_hidden_state[
|
520 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
521 |
+
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
522 |
+
]
|
523 |
+
|
524 |
+
if not return_dict:
|
525 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
526 |
+
|
527 |
+
return BaseModelOutputWithPooling(
|
528 |
+
last_hidden_state=last_hidden_state,
|
529 |
+
pooler_output=pooled_output,
|
530 |
+
hidden_states=encoder_outputs.hidden_states,
|
531 |
+
attentions=encoder_outputs.attentions,
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
@add_start_docstrings(
|
536 |
+
"""The text model from CLIP without any head or projection on top.""",
|
537 |
+
CLIP_START_DOCSTRING,
|
538 |
+
)
|
539 |
+
class CLIPTextModel(CLIPPreTrainedModel):
|
540 |
+
config_class = CLIPTextConfig
|
541 |
+
|
542 |
+
_no_split_modules = ["CLIPEncoderLayer"]
|
543 |
+
|
544 |
+
def __init__(self, config: CLIPTextConfig):
|
545 |
+
super().__init__(config)
|
546 |
+
self.text_model = CLIPTextTransformer(config)
|
547 |
+
# Initialize weights and apply final processing
|
548 |
+
self.post_init()
|
549 |
+
|
550 |
+
def get_input_embeddings(self) -> nn.Module:
|
551 |
+
return self.text_model.embeddings.token_embedding
|
552 |
+
|
553 |
+
def set_input_embeddings(self, value):
|
554 |
+
self.text_model.embeddings.token_embedding = value
|
555 |
+
|
556 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
557 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
558 |
+
def forward(
|
559 |
+
self,
|
560 |
+
input_ids: Optional[torch.Tensor] = None,
|
561 |
+
attention_mask: Optional[torch.Tensor] = None,
|
562 |
+
position_ids: Optional[torch.Tensor] = None,
|
563 |
+
output_attentions: Optional[bool] = None,
|
564 |
+
output_hidden_states: Optional[bool] = None,
|
565 |
+
return_dict: Optional[bool] = None,
|
566 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
567 |
+
r"""
|
568 |
+
Returns:
|
569 |
+
|
570 |
+
Examples:
|
571 |
+
|
572 |
+
```python
|
573 |
+
>>> from transformers import AutoTokenizer, CLIPTextModel
|
574 |
+
|
575 |
+
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
576 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
577 |
+
|
578 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
579 |
+
|
580 |
+
>>> outputs = model(**inputs)
|
581 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
582 |
+
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
583 |
+
```"""
|
584 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
585 |
+
|
586 |
+
return self.text_model(
|
587 |
+
input_ids=input_ids,
|
588 |
+
attention_mask=attention_mask,
|
589 |
+
position_ids=position_ids,
|
590 |
+
output_attentions=output_attentions,
|
591 |
+
output_hidden_states=output_hidden_states,
|
592 |
+
return_dict=return_dict,
|
593 |
+
)
|
594 |
+
|
595 |
+
|
596 |
+
class CLIPVisionTransformer(nn.Module):
|
597 |
+
def __init__(self, config: CLIPVisionConfig):
|
598 |
+
super().__init__()
|
599 |
+
self.config = config
|
600 |
+
embed_dim = config.hidden_size
|
601 |
+
|
602 |
+
self.embeddings = CLIPVisionEmbeddings(config)
|
603 |
+
self.patch_dropout = PatchDropout(config.force_patch_dropout)
|
604 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
605 |
+
self.encoder = CLIPEncoder(config)
|
606 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
607 |
+
|
608 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
609 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
610 |
+
def forward(
|
611 |
+
self,
|
612 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
613 |
+
output_attentions: Optional[bool] = None,
|
614 |
+
output_hidden_states: Optional[bool] = None,
|
615 |
+
return_dict: Optional[bool] = None,
|
616 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
617 |
+
r"""
|
618 |
+
Returns:
|
619 |
+
|
620 |
+
"""
|
621 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
622 |
+
output_hidden_states = (
|
623 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
624 |
+
)
|
625 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
626 |
+
|
627 |
+
if pixel_values is None:
|
628 |
+
raise ValueError("You have to specify pixel_values")
|
629 |
+
######################################
|
630 |
+
if len(pixel_values.shape) == 7:
|
631 |
+
b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
|
632 |
+
# print(pixel_values.shape)
|
633 |
+
B = b_new * pair_new * bs_new
|
634 |
+
pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new)
|
635 |
+
|
636 |
+
elif len(pixel_values.shape) == 5:
|
637 |
+
B, _, T, _, _ = pixel_values.shape
|
638 |
+
# print(pixel_values.shape)
|
639 |
+
pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
|
640 |
+
else:
|
641 |
+
# print(pixel_values.shape)
|
642 |
+
B, _, _, _ = pixel_values.shape
|
643 |
+
T = 1
|
644 |
+
###########################
|
645 |
+
hidden_states = self.embeddings(pixel_values)
|
646 |
+
|
647 |
+
hidden_states = self.patch_dropout(hidden_states, B, T) ##############################################
|
648 |
+
|
649 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
650 |
+
|
651 |
+
encoder_outputs = self.encoder(
|
652 |
+
inputs_embeds=hidden_states,
|
653 |
+
output_attentions=output_attentions,
|
654 |
+
output_hidden_states=output_hidden_states,
|
655 |
+
return_dict=return_dict,
|
656 |
+
)
|
657 |
+
|
658 |
+
last_hidden_state = encoder_outputs[0]
|
659 |
+
pooled_output = last_hidden_state[:, 0, :]
|
660 |
+
pooled_output = self.post_layernorm(pooled_output)
|
661 |
+
|
662 |
+
pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################
|
663 |
+
|
664 |
+
if not return_dict:
|
665 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
666 |
+
|
667 |
+
return BaseModelOutputWithPooling(
|
668 |
+
last_hidden_state=last_hidden_state,
|
669 |
+
pooler_output=pooled_output,
|
670 |
+
hidden_states=encoder_outputs.hidden_states,
|
671 |
+
attentions=encoder_outputs.attentions,
|
672 |
+
)
|
673 |
+
|
674 |
+
|
675 |
+
@add_start_docstrings(
|
676 |
+
"""The vision model from CLIP without any head or projection on top.""",
|
677 |
+
CLIP_START_DOCSTRING,
|
678 |
+
)
|
679 |
+
class CLIPVisionModel(CLIPPreTrainedModel):
|
680 |
+
config_class = CLIPVisionConfig
|
681 |
+
main_input_name = "pixel_values"
|
682 |
+
|
683 |
+
def __init__(self, config: CLIPVisionConfig):
|
684 |
+
super().__init__(config)
|
685 |
+
self.vision_model = CLIPVisionTransformer(config)
|
686 |
+
# Initialize weights and apply final processing
|
687 |
+
self.post_init()
|
688 |
+
|
689 |
+
def get_input_embeddings(self) -> nn.Module:
|
690 |
+
return self.vision_model.embeddings.patch_embedding
|
691 |
+
|
692 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
693 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
694 |
+
def forward(
|
695 |
+
self,
|
696 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
697 |
+
output_attentions: Optional[bool] = None,
|
698 |
+
output_hidden_states: Optional[bool] = None,
|
699 |
+
return_dict: Optional[bool] = None,
|
700 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
701 |
+
r"""
|
702 |
+
Returns:
|
703 |
+
|
704 |
+
Examples:
|
705 |
+
|
706 |
+
```python
|
707 |
+
>>> from PIL import Image
|
708 |
+
>>> import requests
|
709 |
+
>>> from transformers import AutoProcessor, CLIPVisionModel
|
710 |
+
|
711 |
+
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
712 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
713 |
+
|
714 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
715 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
716 |
+
|
717 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
718 |
+
|
719 |
+
>>> outputs = model(**inputs)
|
720 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
721 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
722 |
+
```"""
|
723 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
724 |
+
|
725 |
+
return self.vision_model(
|
726 |
+
pixel_values=pixel_values,
|
727 |
+
output_attentions=output_attentions,
|
728 |
+
output_hidden_states=output_hidden_states,
|
729 |
+
return_dict=return_dict,
|
730 |
+
)
|
731 |
+
|
732 |
+
|
733 |
+
@add_start_docstrings(CLIP_START_DOCSTRING)
|
734 |
+
class LanguageBindAudio(CLIPPreTrainedModel):
|
735 |
+
config_class = LanguageBindAudioConfig
|
736 |
+
|
737 |
+
def __init__(self, config: LanguageBindAudioConfig):
|
738 |
+
super().__init__(config)
|
739 |
+
|
740 |
+
if not isinstance(config.text_config, CLIPTextConfig):
|
741 |
+
raise ValueError(
|
742 |
+
"config.text_config is expected to be of type CLIPTextConfig but is of type"
|
743 |
+
f" {type(config.text_config)}."
|
744 |
+
)
|
745 |
+
|
746 |
+
if not isinstance(config.vision_config, CLIPVisionConfig):
|
747 |
+
raise ValueError(
|
748 |
+
"config.vision_config is expected to be of type CLIPVisionConfig but is of type"
|
749 |
+
f" {type(config.vision_config)}."
|
750 |
+
)
|
751 |
+
|
752 |
+
text_config = config.text_config
|
753 |
+
vision_config = config.vision_config
|
754 |
+
self.add_time_attn = vision_config.add_time_attn
|
755 |
+
self.lora_r = vision_config.lora_r
|
756 |
+
self.lora_alpha = vision_config.lora_alpha
|
757 |
+
self.lora_dropout = vision_config.lora_dropout
|
758 |
+
|
759 |
+
self.projection_dim = config.projection_dim
|
760 |
+
self.text_embed_dim = text_config.hidden_size
|
761 |
+
self.vision_embed_dim = vision_config.hidden_size
|
762 |
+
|
763 |
+
self.text_model = CLIPTextTransformer(text_config)
|
764 |
+
self.vision_model = CLIPVisionTransformer(vision_config)
|
765 |
+
|
766 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
767 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
768 |
+
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
769 |
+
|
770 |
+
# Initialize weights and apply final processing
|
771 |
+
self.post_init()
|
772 |
+
self.convert_to_lora()
|
773 |
+
self.resize_pos(self.vision_model.embeddings, vision_config)
|
774 |
+
|
775 |
+
def convert_to_lora(self):
|
776 |
+
if self.lora_r == 0:
|
777 |
+
return
|
778 |
+
if self.add_time_attn:
|
779 |
+
target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
|
780 |
+
"temporal_attn.q_proj", "temporal_attn.out_proj",
|
781 |
+
"temporal_mlp.fc1", "temporal_mlp.fc2"]
|
782 |
+
else:
|
783 |
+
target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
|
784 |
+
config = LoraConfig(
|
785 |
+
r=self.lora_r, # 16
|
786 |
+
lora_alpha=self.lora_alpha, # 16
|
787 |
+
target_modules=target_modules, # self_attn.out_proj
|
788 |
+
lora_dropout=self.lora_dropout, # 0.1
|
789 |
+
bias="none",
|
790 |
+
modules_to_save=[],
|
791 |
+
)
|
792 |
+
self.vision_model.encoder.is_gradient_checkpointing = False
|
793 |
+
self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
|
794 |
+
|
795 |
+
def resize_pos(self, m, vision_config):
|
796 |
+
# convert embedding
|
797 |
+
if vision_config.num_mel_bins!=0 and vision_config.target_length!=0:
|
798 |
+
m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
|
799 |
+
m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
|
800 |
+
# pos resize
|
801 |
+
old_pos_embed_state_dict = m.position_embedding.state_dict()
|
802 |
+
old_pos_embed = old_pos_embed_state_dict['weight']
|
803 |
+
dtype = old_pos_embed.dtype
|
804 |
+
grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
|
805 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
806 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
807 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
808 |
+
# m.to(args.device)
|
809 |
+
return
|
810 |
+
|
811 |
+
m.num_patches = grid_size[0] * grid_size[1]
|
812 |
+
m.num_positions = m.num_patches + 1
|
813 |
+
m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
|
814 |
+
new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
|
815 |
+
|
816 |
+
if extra_tokens:
|
817 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
818 |
+
else:
|
819 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
820 |
+
old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
|
821 |
+
|
822 |
+
# if is_master(args):
|
823 |
+
# logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
824 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
825 |
+
pos_emb_img = F.interpolate(
|
826 |
+
pos_emb_img,
|
827 |
+
size=grid_size,
|
828 |
+
mode='bicubic',
|
829 |
+
antialias=True,
|
830 |
+
align_corners=False,
|
831 |
+
)
|
832 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
833 |
+
if pos_emb_tok is not None:
|
834 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
835 |
+
else:
|
836 |
+
new_pos_embed = pos_emb_img
|
837 |
+
old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
|
838 |
+
m.position_embedding = new_position_embedding
|
839 |
+
m.position_embedding.load_state_dict(old_pos_embed_state_dict)
|
840 |
+
|
841 |
+
# m.to(args.device)
|
842 |
+
|
843 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
844 |
+
def get_text_features(
|
845 |
+
self,
|
846 |
+
input_ids: Optional[torch.Tensor] = None,
|
847 |
+
attention_mask: Optional[torch.Tensor] = None,
|
848 |
+
position_ids: Optional[torch.Tensor] = None,
|
849 |
+
output_attentions: Optional[bool] = None,
|
850 |
+
output_hidden_states: Optional[bool] = None,
|
851 |
+
return_dict: Optional[bool] = None,
|
852 |
+
) -> torch.FloatTensor:
|
853 |
+
r"""
|
854 |
+
Returns:
|
855 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
856 |
+
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
857 |
+
|
858 |
+
Examples:
|
859 |
+
|
860 |
+
```python
|
861 |
+
>>> from transformers import AutoTokenizer, CLIPModel
|
862 |
+
|
863 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
864 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
865 |
+
|
866 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
867 |
+
>>> text_features = model.get_text_features(**inputs)
|
868 |
+
```"""
|
869 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
870 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
871 |
+
output_hidden_states = (
|
872 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
873 |
+
)
|
874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
875 |
+
|
876 |
+
text_outputs = self.text_model(
|
877 |
+
input_ids=input_ids,
|
878 |
+
attention_mask=attention_mask,
|
879 |
+
position_ids=position_ids,
|
880 |
+
output_attentions=output_attentions,
|
881 |
+
output_hidden_states=output_hidden_states,
|
882 |
+
return_dict=return_dict,
|
883 |
+
)
|
884 |
+
|
885 |
+
pooled_output = text_outputs[1]
|
886 |
+
text_features = self.text_projection(pooled_output)
|
887 |
+
|
888 |
+
return text_features
|
889 |
+
|
890 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
891 |
+
def get_image_features(
|
892 |
+
self,
|
893 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
894 |
+
output_attentions: Optional[bool] = None,
|
895 |
+
output_hidden_states: Optional[bool] = None,
|
896 |
+
return_dict: Optional[bool] = None,
|
897 |
+
) -> torch.FloatTensor:
|
898 |
+
r"""
|
899 |
+
Returns:
|
900 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
901 |
+
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
902 |
+
|
903 |
+
Examples:
|
904 |
+
|
905 |
+
```python
|
906 |
+
>>> from PIL import Image
|
907 |
+
>>> import requests
|
908 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
909 |
+
|
910 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
911 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
912 |
+
|
913 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
914 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
915 |
+
|
916 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
917 |
+
|
918 |
+
>>> image_features = model.get_image_features(**inputs)
|
919 |
+
```"""
|
920 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
921 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
922 |
+
output_hidden_states = (
|
923 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
924 |
+
)
|
925 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
926 |
+
|
927 |
+
vision_outputs = self.vision_model(
|
928 |
+
pixel_values=pixel_values,
|
929 |
+
output_attentions=output_attentions,
|
930 |
+
output_hidden_states=output_hidden_states,
|
931 |
+
return_dict=return_dict,
|
932 |
+
)
|
933 |
+
|
934 |
+
pooled_output = vision_outputs[1] # pooled_output
|
935 |
+
image_features = self.visual_projection(pooled_output)
|
936 |
+
|
937 |
+
return image_features
|
938 |
+
|
939 |
+
@add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
|
940 |
+
@replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindAudioConfig)
|
941 |
+
def forward(
|
942 |
+
self,
|
943 |
+
input_ids: Optional[torch.LongTensor] = None,
|
944 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
945 |
+
attention_mask: Optional[torch.Tensor] = None,
|
946 |
+
position_ids: Optional[torch.LongTensor] = None,
|
947 |
+
return_loss: Optional[bool] = None,
|
948 |
+
output_attentions: Optional[bool] = None,
|
949 |
+
output_hidden_states: Optional[bool] = None,
|
950 |
+
return_dict: Optional[bool] = None,
|
951 |
+
) -> Union[Tuple, CLIPOutput]:
|
952 |
+
r"""
|
953 |
+
Returns:
|
954 |
+
|
955 |
+
Examples:
|
956 |
+
|
957 |
+
```python
|
958 |
+
>>> from PIL import Image
|
959 |
+
>>> import requests
|
960 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
961 |
+
|
962 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
963 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
964 |
+
|
965 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
966 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
967 |
+
|
968 |
+
>>> inputs = processor(
|
969 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
970 |
+
... )
|
971 |
+
|
972 |
+
>>> outputs = model(**inputs)
|
973 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
974 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
975 |
+
```"""
|
976 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
977 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
978 |
+
output_hidden_states = (
|
979 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
980 |
+
)
|
981 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
982 |
+
|
983 |
+
vision_outputs = self.vision_model(
|
984 |
+
pixel_values=pixel_values,
|
985 |
+
output_attentions=output_attentions,
|
986 |
+
output_hidden_states=output_hidden_states,
|
987 |
+
return_dict=return_dict,
|
988 |
+
)
|
989 |
+
|
990 |
+
text_outputs = self.text_model(
|
991 |
+
input_ids=input_ids,
|
992 |
+
attention_mask=attention_mask,
|
993 |
+
position_ids=position_ids,
|
994 |
+
output_attentions=output_attentions,
|
995 |
+
output_hidden_states=output_hidden_states,
|
996 |
+
return_dict=return_dict,
|
997 |
+
)
|
998 |
+
|
999 |
+
image_embeds = vision_outputs[1]
|
1000 |
+
image_embeds = self.visual_projection(image_embeds)
|
1001 |
+
|
1002 |
+
text_embeds = text_outputs[1]
|
1003 |
+
text_embeds = self.text_projection(text_embeds)
|
1004 |
+
|
1005 |
+
# normalized features
|
1006 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
1007 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
1008 |
+
|
1009 |
+
# cosine similarity as logits
|
1010 |
+
logit_scale = self.logit_scale.exp()
|
1011 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
1012 |
+
logits_per_image = logits_per_text.t()
|
1013 |
+
|
1014 |
+
loss = None
|
1015 |
+
if return_loss:
|
1016 |
+
loss = clip_loss(logits_per_text)
|
1017 |
+
|
1018 |
+
if not return_dict:
|
1019 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
1020 |
+
return ((loss,) + output) if loss is not None else output
|
1021 |
+
|
1022 |
+
return CLIPOutput(
|
1023 |
+
loss=loss,
|
1024 |
+
logits_per_image=logits_per_image,
|
1025 |
+
logits_per_text=logits_per_text,
|
1026 |
+
text_embeds=text_embeds,
|
1027 |
+
image_embeds=image_embeds,
|
1028 |
+
text_model_output=text_outputs,
|
1029 |
+
vision_model_output=vision_outputs,
|
1030 |
+
)
|
models/multimodal_encoder/languagebind/audio/processing_audio.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
# import torchaudio
|
5 |
+
from torchvision import transforms
|
6 |
+
from transformers import ProcessorMixin, BatchEncoding
|
7 |
+
from transformers.image_processing_utils import BatchFeature
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
|
11 |
+
def make_list_of_images(x):
|
12 |
+
if not isinstance(x, list):
|
13 |
+
return [x]
|
14 |
+
return x
|
15 |
+
|
16 |
+
|
17 |
+
#torchaudio.set_audio_backend("soundfile")
|
18 |
+
|
19 |
+
def torchaudio_loader(path):
|
20 |
+
return torchaudio.load(path)
|
21 |
+
|
22 |
+
def int16_to_float32_torch(x):
|
23 |
+
return (x / 32767.0).type(torch.float32)
|
24 |
+
|
25 |
+
def float32_to_int16_torch(x):
|
26 |
+
x = torch.clamp(x, min=-1., max=1.)
|
27 |
+
return (x * 32767.).type(torch.int16)
|
28 |
+
|
29 |
+
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10
|
30 |
+
|
31 |
+
class AudioTransform:
|
32 |
+
def __init__(self, config):
|
33 |
+
self.sample_rate = config.audio_sample_rate
|
34 |
+
self.num_mel_bins = config.num_mel_bins
|
35 |
+
self.target_length = config.target_length
|
36 |
+
self.audio_mean = config.audio_mean
|
37 |
+
self.audio_std = config.audio_std
|
38 |
+
# mean=-4.2677393
|
39 |
+
# std=4.5689974
|
40 |
+
self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)
|
41 |
+
|
42 |
+
def __call__(self, audio_data_and_origin_sr):
|
43 |
+
audio_data, origin_sr = audio_data_and_origin_sr
|
44 |
+
if self.sample_rate != origin_sr:
|
45 |
+
# print(audio_data.shape, origin_sr)
|
46 |
+
audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate)
|
47 |
+
waveform_melspec = self.waveform2melspec(audio_data[0])
|
48 |
+
return self.norm(waveform_melspec)
|
49 |
+
|
50 |
+
def waveform2melspec(self, audio_data):
|
51 |
+
max_len = self.target_length * self.sample_rate // 100
|
52 |
+
if audio_data.shape[-1] > max_len:
|
53 |
+
mel = self.get_mel(audio_data)
|
54 |
+
# split to three parts
|
55 |
+
chunk_frames = self.target_length
|
56 |
+
total_frames = mel.shape[0]
|
57 |
+
ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
|
58 |
+
# print('total_frames-chunk_frames:', total_frames-chunk_frames,
|
59 |
+
# 'len(audio_data):', len(audio_data),
|
60 |
+
# 'chunk_frames:', chunk_frames,
|
61 |
+
# 'total_frames:', total_frames)
|
62 |
+
if len(ranges[1]) == 0: # if the audio is too short, we just use the first chunk
|
63 |
+
ranges[1] = [0]
|
64 |
+
if len(ranges[2]) == 0: # if the audio is too short, we just use the first chunk
|
65 |
+
ranges[2] = [0]
|
66 |
+
# randomly choose index for each part
|
67 |
+
# idx_front = np.random.choice(ranges[0])
|
68 |
+
# idx_middle = np.random.choice(ranges[1])
|
69 |
+
# idx_back = np.random.choice(ranges[2])
|
70 |
+
idx_front = ranges[0][0] # fixed
|
71 |
+
idx_middle = ranges[1][0]
|
72 |
+
idx_back = ranges[2][0]
|
73 |
+
# select mel
|
74 |
+
mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
|
75 |
+
mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
|
76 |
+
mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
|
77 |
+
# stack
|
78 |
+
mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
|
79 |
+
elif audio_data.shape[-1] < max_len: # padding if too short
|
80 |
+
n_repeat = int(max_len / len(audio_data))
|
81 |
+
audio_data = audio_data.repeat(n_repeat)
|
82 |
+
audio_data = F.pad(
|
83 |
+
audio_data,
|
84 |
+
(0, max_len - len(audio_data)),
|
85 |
+
mode="constant",
|
86 |
+
value=0,
|
87 |
+
)
|
88 |
+
mel = self.get_mel(audio_data)
|
89 |
+
mel_fusion = torch.stack([mel, mel, mel], dim=0)
|
90 |
+
else: # if equal
|
91 |
+
mel = self.get_mel(audio_data)
|
92 |
+
mel_fusion = torch.stack([mel, mel, mel], dim=0)
|
93 |
+
|
94 |
+
# twice check
|
95 |
+
p = self.target_length - mel_fusion.shape[1]
|
96 |
+
|
97 |
+
# if abs(p) / self.target_length > 0.2:
|
98 |
+
# logging.warning(
|
99 |
+
# "Large gap between audio n_frames(%d) and "
|
100 |
+
# "target_length (%d). Is the audio_target_length "
|
101 |
+
# "setting correct?",
|
102 |
+
# mel_fusion.shape[1],
|
103 |
+
# self.target_length,
|
104 |
+
# )
|
105 |
+
|
106 |
+
# cut and pad
|
107 |
+
if p > 0:
|
108 |
+
m = torch.nn.ZeroPad2d((0, 0, 0, p))
|
109 |
+
mel_fusion = m(mel_fusion)
|
110 |
+
elif p < 0:
|
111 |
+
mel_fusion = mel_fusion[:, 0: self.target_length, :]
|
112 |
+
|
113 |
+
mel_fusion = mel_fusion.transpose(1, 2) # [3, target_length, mel_bins] -> [3, mel_bins, target_length]
|
114 |
+
return mel_fusion
|
115 |
+
|
116 |
+
def get_mel(self, audio_data):
|
117 |
+
# mel shape: (n_mels, T)
|
118 |
+
audio_data -= audio_data.mean()
|
119 |
+
mel = torchaudio.compliance.kaldi.fbank(
|
120 |
+
audio_data.unsqueeze(0),
|
121 |
+
htk_compat=True,
|
122 |
+
sample_frequency=self.sample_rate,
|
123 |
+
use_energy=False,
|
124 |
+
window_type="hanning",
|
125 |
+
num_mel_bins=self.num_mel_bins,
|
126 |
+
dither=0.0,
|
127 |
+
frame_length=25,
|
128 |
+
frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
|
129 |
+
)
|
130 |
+
return mel # (T, n_mels)
|
131 |
+
|
132 |
+
def get_audio_transform(config):
|
133 |
+
config = config.vision_config
|
134 |
+
return AudioTransform(config)
|
135 |
+
|
136 |
+
|
137 |
+
def load_and_transform_audio(
|
138 |
+
audio_path,
|
139 |
+
transform,
|
140 |
+
):
|
141 |
+
waveform_and_sr = torchaudio_loader(audio_path)
|
142 |
+
audio_outputs = transform(waveform_and_sr)
|
143 |
+
|
144 |
+
return audio_outputs
|
145 |
+
|
146 |
+
class LanguageBindAudioProcessor(ProcessorMixin):
|
147 |
+
attributes = []
|
148 |
+
tokenizer_class = ("LanguageBindAudioTokenizer")
|
149 |
+
|
150 |
+
def __init__(self, config, tokenizer=None, **kwargs):
|
151 |
+
super().__init__(**kwargs)
|
152 |
+
self.config = config
|
153 |
+
self.transform = get_audio_transform(config)
|
154 |
+
self.image_processor = load_and_transform_audio
|
155 |
+
self.tokenizer = tokenizer
|
156 |
+
|
157 |
+
def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
|
158 |
+
if text is None and images is None:
|
159 |
+
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
160 |
+
|
161 |
+
if text is not None:
|
162 |
+
encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
|
163 |
+
truncation=True, return_tensors=return_tensors, **kwargs)
|
164 |
+
|
165 |
+
if images is not None:
|
166 |
+
images = make_list_of_images(images)
|
167 |
+
image_features = [self.image_processor(image, self.transform) for image in images]
|
168 |
+
image_features = torch.stack(image_features)
|
169 |
+
|
170 |
+
if text is not None and images is not None:
|
171 |
+
encoding["pixel_values"] = image_features
|
172 |
+
return encoding
|
173 |
+
elif text is not None:
|
174 |
+
return encoding
|
175 |
+
else:
|
176 |
+
return {"pixel_values": image_features}
|
177 |
+
|
178 |
+
def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
|
179 |
+
"""
|
180 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
181 |
+
refer to the docstring of this method for more information.
|
182 |
+
"""
|
183 |
+
return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
|
184 |
+
|
185 |
+
def decode(self, skip_special_tokens=True, *args, **kwargs):
|
186 |
+
"""
|
187 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
188 |
+
the docstring of this method for more information.
|
189 |
+
"""
|
190 |
+
return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
|
models/multimodal_encoder/languagebind/audio/tokenization_audio.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPTokenizer
|
2 |
+
from transformers.utils import logging
|
3 |
+
|
4 |
+
logger = logging.get_logger(__name__)
|
5 |
+
|
6 |
+
VOCAB_FILES_NAMES = {
|
7 |
+
"vocab_file": "vocab.json",
|
8 |
+
"merges_file": "merges.txt",
|
9 |
+
}
|
10 |
+
|
11 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
12 |
+
"vocab_file": {
|
13 |
+
"lb203/LanguageBind-Audio": "https://huggingface.co/lb203/LanguageBind-Audio/resolve/main/vocab.json",
|
14 |
+
},
|
15 |
+
"merges_file": {
|
16 |
+
"lb203/LanguageBind-Audio": "https://huggingface.co/lb203/LanguageBind-Audio/resolve/main/merges.txt",
|
17 |
+
},
|
18 |
+
}
|
19 |
+
|
20 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
21 |
+
"lb203/LanguageBind-Audio": 77,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
PRETRAINED_INIT_CONFIGURATION = {
|
26 |
+
"lb203/LanguageBind-Audio": {},
|
27 |
+
}
|
28 |
+
|
29 |
+
class LanguageBindAudioTokenizer(CLIPTokenizer):
|
30 |
+
"""
|
31 |
+
Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
|
32 |
+
|
33 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
34 |
+
this superclass for more information regarding those methods.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_file (`str`):
|
38 |
+
Path to the vocabulary file.
|
39 |
+
merges_file (`str`):
|
40 |
+
Path to the merges file.
|
41 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
42 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
43 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
44 |
+
unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
45 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
46 |
+
token instead.
|
47 |
+
bos_token (`str`, *optional*, defaults to `<|startoftext|>`):
|
48 |
+
The beginning of sequence token.
|
49 |
+
eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
50 |
+
The end of sequence token.
|
51 |
+
"""
|
52 |
+
|
53 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
54 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
55 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
56 |
+
model_input_names = ["input_ids", "attention_mask"]
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
vocab_file,
|
61 |
+
merges_file,
|
62 |
+
errors="replace",
|
63 |
+
unk_token="<|endoftext|>",
|
64 |
+
bos_token="<|startoftext|>",
|
65 |
+
eos_token="<|endoftext|>",
|
66 |
+
pad_token="<|endoftext|>", # hack to enable padding
|
67 |
+
**kwargs,
|
68 |
+
):
|
69 |
+
super(LanguageBindAudioTokenizer, self).__init__(
|
70 |
+
vocab_file,
|
71 |
+
merges_file,
|
72 |
+
errors,
|
73 |
+
unk_token,
|
74 |
+
bos_token,
|
75 |
+
eos_token,
|
76 |
+
pad_token, # hack to enable padding
|
77 |
+
**kwargs,)
|
models/multimodal_encoder/languagebind/depth/configuration_depth.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
from transformers.utils import logging
|
7 |
+
|
8 |
+
logger = logging.get_logger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
class CLIPTextConfig(PretrainedConfig):
|
17 |
+
r"""
|
18 |
+
This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
|
19 |
+
text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
|
20 |
+
with the defaults will yield a similar configuration to that of the text encoder of the CLIP
|
21 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
22 |
+
|
23 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
24 |
+
documentation from [`PretrainedConfig`] for more information.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
vocab_size (`int`, *optional*, defaults to 49408):
|
28 |
+
Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
|
29 |
+
the `inputs_ids` passed when calling [`CLIPModel`].
|
30 |
+
hidden_size (`int`, *optional*, defaults to 512):
|
31 |
+
Dimensionality of the encoder layers and the pooler layer.
|
32 |
+
intermediate_size (`int`, *optional*, defaults to 2048):
|
33 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
34 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
35 |
+
Number of hidden layers in the Transformer encoder.
|
36 |
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
37 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
38 |
+
max_position_embeddings (`int`, *optional*, defaults to 77):
|
39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
42 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
43 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
44 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
45 |
+
The epsilon used by the layer normalization layers.
|
46 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
47 |
+
The dropout ratio for the attention probabilities.
|
48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
49 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
50 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
51 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
52 |
+
testing).
|
53 |
+
|
54 |
+
Example:
|
55 |
+
|
56 |
+
```python
|
57 |
+
>>> from transformers import CLIPTextConfig, CLIPTextModel
|
58 |
+
|
59 |
+
>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
|
60 |
+
>>> configuration = CLIPTextConfig()
|
61 |
+
|
62 |
+
>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
63 |
+
>>> model = CLIPTextModel(configuration)
|
64 |
+
|
65 |
+
>>> # Accessing the model configuration
|
66 |
+
>>> configuration = model.config
|
67 |
+
```"""
|
68 |
+
model_type = "clip_text_model"
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
vocab_size=49408,
|
73 |
+
hidden_size=512,
|
74 |
+
intermediate_size=2048,
|
75 |
+
projection_dim=512,
|
76 |
+
num_hidden_layers=12,
|
77 |
+
num_attention_heads=8,
|
78 |
+
max_position_embeddings=77,
|
79 |
+
hidden_act="quick_gelu",
|
80 |
+
layer_norm_eps=1e-5,
|
81 |
+
attention_dropout=0.0,
|
82 |
+
initializer_range=0.02,
|
83 |
+
initializer_factor=1.0,
|
84 |
+
# This differs from `CLIPTokenizer`'s default and from openai/clip
|
85 |
+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
86 |
+
pad_token_id=1,
|
87 |
+
bos_token_id=49406,
|
88 |
+
eos_token_id=49407,
|
89 |
+
**kwargs,
|
90 |
+
):
|
91 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
92 |
+
|
93 |
+
self.vocab_size = vocab_size
|
94 |
+
self.hidden_size = hidden_size
|
95 |
+
self.intermediate_size = intermediate_size
|
96 |
+
self.projection_dim = projection_dim
|
97 |
+
self.num_hidden_layers = num_hidden_layers
|
98 |
+
self.num_attention_heads = num_attention_heads
|
99 |
+
self.max_position_embeddings = max_position_embeddings
|
100 |
+
self.layer_norm_eps = layer_norm_eps
|
101 |
+
self.hidden_act = hidden_act
|
102 |
+
self.initializer_range = initializer_range
|
103 |
+
self.initializer_factor = initializer_factor
|
104 |
+
self.attention_dropout = attention_dropout
|
105 |
+
self.add_time_attn = False ######################################
|
106 |
+
|
107 |
+
@classmethod
|
108 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
109 |
+
cls._set_token_in_kwargs(kwargs)
|
110 |
+
|
111 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
112 |
+
|
113 |
+
# get the text config dict if we are loading from CLIPConfig
|
114 |
+
if config_dict.get("model_type") == "clip":
|
115 |
+
config_dict = config_dict["text_config"]
|
116 |
+
|
117 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
118 |
+
logger.warning(
|
119 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
120 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
121 |
+
)
|
122 |
+
|
123 |
+
return cls.from_dict(config_dict, **kwargs)
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class CLIPVisionConfig(PretrainedConfig):
|
129 |
+
r"""
|
130 |
+
This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
|
131 |
+
CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
132 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
|
133 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
134 |
+
|
135 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
136 |
+
documentation from [`PretrainedConfig`] for more information.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
140 |
+
Dimensionality of the encoder layers and the pooler layer.
|
141 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
142 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
143 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
144 |
+
Number of hidden layers in the Transformer encoder.
|
145 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
146 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
147 |
+
image_size (`int`, *optional*, defaults to 224):
|
148 |
+
The size (resolution) of each image.
|
149 |
+
patch_size (`int`, *optional*, defaults to 32):
|
150 |
+
The size (resolution) of each patch.
|
151 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
152 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
153 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
154 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
155 |
+
The epsilon used by the layer normalization layers.
|
156 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
157 |
+
The dropout ratio for the attention probabilities.
|
158 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
159 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
160 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
161 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
162 |
+
testing).
|
163 |
+
|
164 |
+
Example:
|
165 |
+
|
166 |
+
```python
|
167 |
+
>>> from transformers import CLIPVisionConfig, CLIPVisionModel
|
168 |
+
|
169 |
+
>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
|
170 |
+
>>> configuration = CLIPVisionConfig()
|
171 |
+
|
172 |
+
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
173 |
+
>>> model = CLIPVisionModel(configuration)
|
174 |
+
|
175 |
+
>>> # Accessing the model configuration
|
176 |
+
>>> configuration = model.config
|
177 |
+
```"""
|
178 |
+
|
179 |
+
model_type = "clip_vision_model"
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
hidden_size=768,
|
184 |
+
intermediate_size=3072,
|
185 |
+
projection_dim=512,
|
186 |
+
num_hidden_layers=12,
|
187 |
+
num_attention_heads=12,
|
188 |
+
num_channels=3,
|
189 |
+
image_size=224,
|
190 |
+
patch_size=32,
|
191 |
+
hidden_act="quick_gelu",
|
192 |
+
layer_norm_eps=1e-5,
|
193 |
+
attention_dropout=0.0,
|
194 |
+
initializer_range=0.02,
|
195 |
+
initializer_factor=1.0,
|
196 |
+
|
197 |
+
add_time_attn=False, ################################
|
198 |
+
num_frames=1, ################################
|
199 |
+
force_patch_dropout=0.0, ################################
|
200 |
+
lora_r=2, ################################
|
201 |
+
lora_alpha=16, ################################
|
202 |
+
lora_dropout=0.0, ################################
|
203 |
+
num_mel_bins=0.0, ################################
|
204 |
+
target_length=0.0, ################################
|
205 |
+
max_depth=10,
|
206 |
+
video_decode_backend='decord', #########################
|
207 |
+
**kwargs,
|
208 |
+
):
|
209 |
+
super().__init__(**kwargs)
|
210 |
+
|
211 |
+
self.hidden_size = hidden_size
|
212 |
+
self.intermediate_size = intermediate_size
|
213 |
+
self.projection_dim = projection_dim
|
214 |
+
self.num_hidden_layers = num_hidden_layers
|
215 |
+
self.num_attention_heads = num_attention_heads
|
216 |
+
self.num_channels = num_channels
|
217 |
+
self.patch_size = patch_size
|
218 |
+
self.image_size = image_size
|
219 |
+
self.initializer_range = initializer_range
|
220 |
+
self.initializer_factor = initializer_factor
|
221 |
+
self.attention_dropout = attention_dropout
|
222 |
+
self.layer_norm_eps = layer_norm_eps
|
223 |
+
self.hidden_act = hidden_act
|
224 |
+
|
225 |
+
self.add_time_attn = add_time_attn ################
|
226 |
+
self.num_frames = num_frames ################
|
227 |
+
self.force_patch_dropout = force_patch_dropout ################
|
228 |
+
self.lora_r = lora_r ################
|
229 |
+
self.lora_alpha = lora_alpha ################
|
230 |
+
self.lora_dropout = lora_dropout ################
|
231 |
+
self.num_mel_bins = num_mel_bins ################
|
232 |
+
self.target_length = target_length ################
|
233 |
+
self.max_depth = max_depth ################
|
234 |
+
self.video_decode_backend = video_decode_backend ################
|
235 |
+
|
236 |
+
@classmethod
|
237 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
238 |
+
cls._set_token_in_kwargs(kwargs)
|
239 |
+
|
240 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
241 |
+
|
242 |
+
# get the vision config dict if we are loading from CLIPConfig
|
243 |
+
if config_dict.get("model_type") == "clip":
|
244 |
+
config_dict = config_dict["vision_config"]
|
245 |
+
|
246 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
247 |
+
logger.warning(
|
248 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
249 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
250 |
+
)
|
251 |
+
|
252 |
+
return cls.from_dict(config_dict, **kwargs)
|
253 |
+
|
254 |
+
|
255 |
+
class LanguageBindDepthConfig(PretrainedConfig):
|
256 |
+
r"""
|
257 |
+
[`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
|
258 |
+
a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
|
259 |
+
a configuration with the defaults will yield a similar configuration to that of the CLIP
|
260 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
261 |
+
|
262 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
263 |
+
documentation from [`PretrainedConfig`] for more information.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
text_config (`dict`, *optional*):
|
267 |
+
Dictionary of configuration options used to initialize [`CLIPTextConfig`].
|
268 |
+
vision_config (`dict`, *optional*):
|
269 |
+
Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
|
270 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
271 |
+
Dimentionality of text and vision projection layers.
|
272 |
+
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
273 |
+
The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation.
|
274 |
+
kwargs (*optional*):
|
275 |
+
Dictionary of keyword arguments.
|
276 |
+
|
277 |
+
Example:
|
278 |
+
|
279 |
+
```python
|
280 |
+
>>> from transformers import CLIPConfig, CLIPModel
|
281 |
+
|
282 |
+
>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
|
283 |
+
>>> configuration = CLIPConfig()
|
284 |
+
|
285 |
+
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
286 |
+
>>> model = CLIPModel(configuration)
|
287 |
+
|
288 |
+
>>> # Accessing the model configuration
|
289 |
+
>>> configuration = model.config
|
290 |
+
|
291 |
+
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
|
292 |
+
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
|
293 |
+
|
294 |
+
>>> # Initializing a CLIPText and CLIPVision configuration
|
295 |
+
>>> config_text = CLIPTextConfig()
|
296 |
+
>>> config_vision = CLIPVisionConfig()
|
297 |
+
|
298 |
+
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
|
299 |
+
```"""
|
300 |
+
|
301 |
+
model_type = "LanguageBindDepth"
|
302 |
+
is_composition = True
|
303 |
+
|
304 |
+
def __init__(
|
305 |
+
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
306 |
+
):
|
307 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
308 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
|
309 |
+
# of confusion!).
|
310 |
+
text_config_dict = kwargs.pop("text_config_dict", None)
|
311 |
+
vision_config_dict = kwargs.pop("vision_config_dict", None)
|
312 |
+
|
313 |
+
super().__init__(**kwargs)
|
314 |
+
|
315 |
+
# Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
|
316 |
+
# `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
|
317 |
+
# cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
|
318 |
+
if text_config_dict is not None:
|
319 |
+
if text_config is None:
|
320 |
+
text_config = {}
|
321 |
+
|
322 |
+
# This is the complete result when using `text_config_dict`.
|
323 |
+
_text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
|
324 |
+
|
325 |
+
# Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
|
326 |
+
for key, value in _text_config_dict.items():
|
327 |
+
if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
|
328 |
+
# If specified in `text_config_dict`
|
329 |
+
if key in text_config_dict:
|
330 |
+
message = (
|
331 |
+
f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
|
332 |
+
f'The value `text_config_dict["{key}"]` will be used instead.'
|
333 |
+
)
|
334 |
+
# If inferred from default argument values (just to be super careful)
|
335 |
+
else:
|
336 |
+
message = (
|
337 |
+
f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
|
338 |
+
f'value `text_config["{key}"]` will be overriden.'
|
339 |
+
)
|
340 |
+
logger.warning(message)
|
341 |
+
|
342 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
343 |
+
text_config.update(_text_config_dict)
|
344 |
+
|
345 |
+
if vision_config_dict is not None:
|
346 |
+
if vision_config is None:
|
347 |
+
vision_config = {}
|
348 |
+
|
349 |
+
# This is the complete result when using `vision_config_dict`.
|
350 |
+
_vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
|
351 |
+
# convert keys to string instead of integer
|
352 |
+
if "id2label" in _vision_config_dict:
|
353 |
+
_vision_config_dict["id2label"] = {
|
354 |
+
str(key): value for key, value in _vision_config_dict["id2label"].items()
|
355 |
+
}
|
356 |
+
|
357 |
+
# Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
|
358 |
+
for key, value in _vision_config_dict.items():
|
359 |
+
if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
|
360 |
+
# If specified in `vision_config_dict`
|
361 |
+
if key in vision_config_dict:
|
362 |
+
message = (
|
363 |
+
f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
|
364 |
+
f'values. The value `vision_config_dict["{key}"]` will be used instead.'
|
365 |
+
)
|
366 |
+
# If inferred from default argument values (just to be super careful)
|
367 |
+
else:
|
368 |
+
message = (
|
369 |
+
f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
|
370 |
+
f'The value `vision_config["{key}"]` will be overriden.'
|
371 |
+
)
|
372 |
+
logger.warning(message)
|
373 |
+
|
374 |
+
# Update all values in `vision_config` with the ones in `_vision_config_dict`.
|
375 |
+
vision_config.update(_vision_config_dict)
|
376 |
+
|
377 |
+
if text_config is None:
|
378 |
+
text_config = {}
|
379 |
+
logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
|
380 |
+
|
381 |
+
if vision_config is None:
|
382 |
+
vision_config = {}
|
383 |
+
logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
|
384 |
+
|
385 |
+
self.text_config = CLIPTextConfig(**text_config)
|
386 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
387 |
+
|
388 |
+
self.projection_dim = projection_dim
|
389 |
+
self.logit_scale_init_value = logit_scale_init_value
|
390 |
+
self.initializer_factor = 1.0
|
391 |
+
|
392 |
+
@classmethod
|
393 |
+
def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
|
394 |
+
r"""
|
395 |
+
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
|
396 |
+
configuration.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
[`CLIPConfig`]: An instance of a configuration object
|
400 |
+
"""
|
401 |
+
|
402 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
403 |
+
|
404 |
+
def to_dict(self):
|
405 |
+
"""
|
406 |
+
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
407 |
+
|
408 |
+
Returns:
|
409 |
+
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
410 |
+
"""
|
411 |
+
output = copy.deepcopy(self.__dict__)
|
412 |
+
output["text_config"] = self.text_config.to_dict()
|
413 |
+
output["vision_config"] = self.vision_config.to_dict()
|
414 |
+
output["model_type"] = self.__class__.model_type
|
415 |
+
return output
|
416 |
+
|
417 |
+
|
418 |
+
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
|
423 |
+
|
424 |
+
|
425 |
+
|
models/multimodal_encoder/languagebind/depth/modeling_depth.py
ADDED
@@ -0,0 +1,1030 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from peft import LoraConfig, get_peft_model
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from transformers import PreTrainedModel, add_start_docstrings
|
10 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
11 |
+
from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
|
12 |
+
CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss
|
13 |
+
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
14 |
+
|
15 |
+
from .configuration_depth import LanguageBindDepthConfig, CLIPVisionConfig, CLIPTextConfig
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class PatchDropout(nn.Module):
|
20 |
+
"""
|
21 |
+
https://arxiv.org/abs/2212.00794
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, prob, exclude_first_token=True):
|
25 |
+
super().__init__()
|
26 |
+
assert 0 <= prob < 1.
|
27 |
+
self.prob = prob
|
28 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
29 |
+
|
30 |
+
def forward(self, x, B, T):
|
31 |
+
if not self.training or self.prob == 0.:
|
32 |
+
return x
|
33 |
+
|
34 |
+
if self.exclude_first_token:
|
35 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
36 |
+
else:
|
37 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
38 |
+
|
39 |
+
batch = x.size()[0]
|
40 |
+
num_tokens = x.size()[1]
|
41 |
+
|
42 |
+
batch_indices = torch.arange(batch)
|
43 |
+
batch_indices = batch_indices[..., None]
|
44 |
+
|
45 |
+
keep_prob = 1 - self.prob
|
46 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
47 |
+
|
48 |
+
if T == 1:
|
49 |
+
rand = torch.randn(batch, num_tokens)
|
50 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
51 |
+
else:
|
52 |
+
rand = torch.randn(B, num_tokens)
|
53 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
54 |
+
patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1)
|
55 |
+
patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n')
|
56 |
+
|
57 |
+
|
58 |
+
x = x[batch_indices, patch_indices_keep]
|
59 |
+
|
60 |
+
if self.exclude_first_token:
|
61 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
62 |
+
|
63 |
+
return x
|
64 |
+
|
65 |
+
class CLIPEncoderLayer(nn.Module):
|
66 |
+
def __init__(self, config: LanguageBindDepthConfig):
|
67 |
+
super().__init__()
|
68 |
+
self.embed_dim = config.hidden_size
|
69 |
+
self.self_attn = CLIPAttention(config)
|
70 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
71 |
+
self.mlp = CLIPMLP(config)
|
72 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
73 |
+
|
74 |
+
self.add_time_attn = config.add_time_attn
|
75 |
+
if self.add_time_attn:
|
76 |
+
self.t = config.num_frames
|
77 |
+
self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
|
78 |
+
nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
|
79 |
+
|
80 |
+
self.embed_dim = config.hidden_size
|
81 |
+
self.temporal_attn = CLIPAttention(config)
|
82 |
+
self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
83 |
+
self.temporal_mlp = CLIPMLP(config)
|
84 |
+
self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
85 |
+
|
86 |
+
def forward(
|
87 |
+
self,
|
88 |
+
hidden_states: torch.Tensor,
|
89 |
+
attention_mask: torch.Tensor,
|
90 |
+
causal_attention_mask: torch.Tensor,
|
91 |
+
output_attentions: Optional[bool] = False,
|
92 |
+
) -> Tuple[torch.FloatTensor]:
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
96 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
97 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
98 |
+
`(config.encoder_attention_heads,)`.
|
99 |
+
output_attentions (`bool`, *optional*):
|
100 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
101 |
+
returned tensors for more detail.
|
102 |
+
"""
|
103 |
+
|
104 |
+
|
105 |
+
if self.add_time_attn:
|
106 |
+
bt, n, d = hidden_states.shape
|
107 |
+
t = self.t
|
108 |
+
|
109 |
+
# time embed
|
110 |
+
if t != 1:
|
111 |
+
n = hidden_states.shape[1]
|
112 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
113 |
+
hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
|
114 |
+
hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
115 |
+
|
116 |
+
# time attn
|
117 |
+
residual = hidden_states
|
118 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
119 |
+
# hidden_states = self.layer_norm1(hidden_states) # share layernorm
|
120 |
+
hidden_states = self.temporal_layer_norm1(hidden_states)
|
121 |
+
hidden_states, attn_weights = self.temporal_attn(
|
122 |
+
hidden_states=hidden_states,
|
123 |
+
attention_mask=attention_mask,
|
124 |
+
causal_attention_mask=causal_attention_mask,
|
125 |
+
output_attentions=output_attentions,
|
126 |
+
)
|
127 |
+
hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
128 |
+
|
129 |
+
residual = hidden_states
|
130 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
131 |
+
# hidden_states = self.layer_norm2(hidden_states) # share layernorm
|
132 |
+
hidden_states = self.temporal_layer_norm2(hidden_states)
|
133 |
+
hidden_states = self.temporal_mlp(hidden_states)
|
134 |
+
hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
135 |
+
|
136 |
+
# spatial attn
|
137 |
+
residual = hidden_states
|
138 |
+
|
139 |
+
hidden_states = self.layer_norm1(hidden_states)
|
140 |
+
hidden_states, attn_weights = self.self_attn(
|
141 |
+
hidden_states=hidden_states,
|
142 |
+
attention_mask=attention_mask,
|
143 |
+
causal_attention_mask=causal_attention_mask,
|
144 |
+
output_attentions=output_attentions,
|
145 |
+
)
|
146 |
+
hidden_states = residual + hidden_states
|
147 |
+
|
148 |
+
residual = hidden_states
|
149 |
+
hidden_states = self.layer_norm2(hidden_states)
|
150 |
+
hidden_states = self.mlp(hidden_states)
|
151 |
+
hidden_states = residual + hidden_states
|
152 |
+
|
153 |
+
outputs = (hidden_states,)
|
154 |
+
|
155 |
+
if output_attentions:
|
156 |
+
outputs += (attn_weights,)
|
157 |
+
|
158 |
+
return outputs
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
class CLIPPreTrainedModel(PreTrainedModel):
|
169 |
+
"""
|
170 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
171 |
+
models.
|
172 |
+
"""
|
173 |
+
|
174 |
+
config_class = LanguageBindDepthConfig
|
175 |
+
base_model_prefix = "clip"
|
176 |
+
supports_gradient_checkpointing = True
|
177 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
178 |
+
|
179 |
+
def _init_weights(self, module):
|
180 |
+
"""Initialize the weights"""
|
181 |
+
factor = self.config.initializer_factor
|
182 |
+
if isinstance(module, CLIPTextEmbeddings):
|
183 |
+
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
184 |
+
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
185 |
+
elif isinstance(module, CLIPVisionEmbeddings):
|
186 |
+
factor = self.config.initializer_factor
|
187 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
188 |
+
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
189 |
+
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
190 |
+
elif isinstance(module, CLIPAttention):
|
191 |
+
factor = self.config.initializer_factor
|
192 |
+
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
193 |
+
out_proj_std = (module.embed_dim**-0.5) * factor
|
194 |
+
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
|
195 |
+
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
|
196 |
+
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
|
197 |
+
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
198 |
+
elif isinstance(module, CLIPMLP):
|
199 |
+
factor = self.config.initializer_factor
|
200 |
+
in_proj_std = (
|
201 |
+
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
202 |
+
)
|
203 |
+
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
204 |
+
nn.init.normal_(module.fc1.weight, std=fc_std)
|
205 |
+
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
206 |
+
elif isinstance(module, LanguageBindDepth):
|
207 |
+
nn.init.normal_(
|
208 |
+
module.text_projection.weight,
|
209 |
+
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
210 |
+
)
|
211 |
+
nn.init.normal_(
|
212 |
+
module.visual_projection.weight,
|
213 |
+
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
214 |
+
)
|
215 |
+
elif isinstance(module, CLIPVisionModelWithProjection):
|
216 |
+
nn.init.normal_(
|
217 |
+
module.visual_projection.weight,
|
218 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
219 |
+
)
|
220 |
+
elif isinstance(module, CLIPTextModelWithProjection):
|
221 |
+
nn.init.normal_(
|
222 |
+
module.text_projection.weight,
|
223 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
224 |
+
)
|
225 |
+
|
226 |
+
if isinstance(module, nn.LayerNorm):
|
227 |
+
module.bias.data.zero_()
|
228 |
+
module.weight.data.fill_(1.0)
|
229 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
230 |
+
module.bias.data.zero_()
|
231 |
+
|
232 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
233 |
+
if isinstance(module, CLIPEncoder):
|
234 |
+
module.gradient_checkpointing = value
|
235 |
+
|
236 |
+
|
237 |
+
CLIP_START_DOCSTRING = r"""
|
238 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
239 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
240 |
+
etc.)
|
241 |
+
|
242 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
243 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
244 |
+
and behavior.
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
248 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
249 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
250 |
+
"""
|
251 |
+
|
252 |
+
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
253 |
+
Args:
|
254 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
255 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
256 |
+
it.
|
257 |
+
|
258 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
259 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
260 |
+
|
261 |
+
[What are input IDs?](../glossary#input-ids)
|
262 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
263 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
264 |
+
|
265 |
+
- 1 for tokens that are **not masked**,
|
266 |
+
- 0 for tokens that are **masked**.
|
267 |
+
|
268 |
+
[What are attention masks?](../glossary#attention-mask)
|
269 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
270 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
271 |
+
config.max_position_embeddings - 1]`.
|
272 |
+
|
273 |
+
[What are position IDs?](../glossary#position-ids)
|
274 |
+
output_attentions (`bool`, *optional*):
|
275 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
276 |
+
tensors for more detail.
|
277 |
+
output_hidden_states (`bool`, *optional*):
|
278 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
279 |
+
more detail.
|
280 |
+
return_dict (`bool`, *optional*):
|
281 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
282 |
+
"""
|
283 |
+
|
284 |
+
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
285 |
+
Args:
|
286 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
287 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
288 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
289 |
+
output_attentions (`bool`, *optional*):
|
290 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
291 |
+
tensors for more detail.
|
292 |
+
output_hidden_states (`bool`, *optional*):
|
293 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
294 |
+
more detail.
|
295 |
+
return_dict (`bool`, *optional*):
|
296 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
297 |
+
"""
|
298 |
+
|
299 |
+
CLIP_INPUTS_DOCSTRING = r"""
|
300 |
+
Args:
|
301 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
302 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
303 |
+
it.
|
304 |
+
|
305 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
306 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
307 |
+
|
308 |
+
[What are input IDs?](../glossary#input-ids)
|
309 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
310 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
311 |
+
|
312 |
+
- 1 for tokens that are **not masked**,
|
313 |
+
- 0 for tokens that are **masked**.
|
314 |
+
|
315 |
+
[What are attention masks?](../glossary#attention-mask)
|
316 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
317 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
318 |
+
config.max_position_embeddings - 1]`.
|
319 |
+
|
320 |
+
[What are position IDs?](../glossary#position-ids)
|
321 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
322 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
323 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
324 |
+
return_loss (`bool`, *optional*):
|
325 |
+
Whether or not to return the contrastive loss.
|
326 |
+
output_attentions (`bool`, *optional*):
|
327 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
328 |
+
tensors for more detail.
|
329 |
+
output_hidden_states (`bool`, *optional*):
|
330 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
331 |
+
more detail.
|
332 |
+
return_dict (`bool`, *optional*):
|
333 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
334 |
+
"""
|
335 |
+
|
336 |
+
|
337 |
+
class CLIPEncoder(nn.Module):
|
338 |
+
"""
|
339 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
340 |
+
[`CLIPEncoderLayer`].
|
341 |
+
|
342 |
+
Args:
|
343 |
+
config: CLIPConfig
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(self, config: LanguageBindDepthConfig):
|
347 |
+
super().__init__()
|
348 |
+
self.config = config
|
349 |
+
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
350 |
+
self.gradient_checkpointing = False
|
351 |
+
|
352 |
+
def forward(
|
353 |
+
self,
|
354 |
+
inputs_embeds,
|
355 |
+
attention_mask: Optional[torch.Tensor] = None,
|
356 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
357 |
+
output_attentions: Optional[bool] = None,
|
358 |
+
output_hidden_states: Optional[bool] = None,
|
359 |
+
return_dict: Optional[bool] = None,
|
360 |
+
) -> Union[Tuple, BaseModelOutput]:
|
361 |
+
r"""
|
362 |
+
Args:
|
363 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
364 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
365 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
366 |
+
than the model's internal embedding lookup matrix.
|
367 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
368 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
369 |
+
|
370 |
+
- 1 for tokens that are **not masked**,
|
371 |
+
- 0 for tokens that are **masked**.
|
372 |
+
|
373 |
+
[What are attention masks?](../glossary#attention-mask)
|
374 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
375 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
376 |
+
|
377 |
+
- 1 for tokens that are **not masked**,
|
378 |
+
- 0 for tokens that are **masked**.
|
379 |
+
|
380 |
+
[What are attention masks?](../glossary#attention-mask)
|
381 |
+
output_attentions (`bool`, *optional*):
|
382 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
383 |
+
returned tensors for more detail.
|
384 |
+
output_hidden_states (`bool`, *optional*):
|
385 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
386 |
+
for more detail.
|
387 |
+
return_dict (`bool`, *optional*):
|
388 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
389 |
+
"""
|
390 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
391 |
+
output_hidden_states = (
|
392 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
393 |
+
)
|
394 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
395 |
+
|
396 |
+
encoder_states = () if output_hidden_states else None
|
397 |
+
all_attentions = () if output_attentions else None
|
398 |
+
|
399 |
+
hidden_states = inputs_embeds
|
400 |
+
for idx, encoder_layer in enumerate(self.layers):
|
401 |
+
if output_hidden_states:
|
402 |
+
encoder_states = encoder_states + (hidden_states,)
|
403 |
+
if self.gradient_checkpointing and self.training:
|
404 |
+
|
405 |
+
def create_custom_forward(module):
|
406 |
+
def custom_forward(*inputs):
|
407 |
+
return module(*inputs, output_attentions)
|
408 |
+
|
409 |
+
return custom_forward
|
410 |
+
|
411 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
412 |
+
create_custom_forward(encoder_layer),
|
413 |
+
hidden_states,
|
414 |
+
attention_mask,
|
415 |
+
causal_attention_mask,
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
layer_outputs = encoder_layer(
|
419 |
+
hidden_states,
|
420 |
+
attention_mask,
|
421 |
+
causal_attention_mask,
|
422 |
+
output_attentions=output_attentions,
|
423 |
+
)
|
424 |
+
|
425 |
+
hidden_states = layer_outputs[0]
|
426 |
+
|
427 |
+
if output_attentions:
|
428 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
429 |
+
|
430 |
+
if output_hidden_states:
|
431 |
+
encoder_states = encoder_states + (hidden_states,)
|
432 |
+
|
433 |
+
if not return_dict:
|
434 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
435 |
+
return BaseModelOutput(
|
436 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
437 |
+
)
|
438 |
+
|
439 |
+
|
440 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
441 |
+
def _make_causal_mask(
|
442 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
443 |
+
):
|
444 |
+
"""
|
445 |
+
Make causal mask used for bi-directional self-attention.
|
446 |
+
"""
|
447 |
+
bsz, tgt_len = input_ids_shape
|
448 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
449 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
450 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
451 |
+
mask = mask.to(dtype)
|
452 |
+
|
453 |
+
if past_key_values_length > 0:
|
454 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
455 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
456 |
+
|
457 |
+
|
458 |
+
class CLIPTextTransformer(nn.Module):
|
459 |
+
def __init__(self, config: CLIPTextConfig):
|
460 |
+
super().__init__()
|
461 |
+
self.config = config
|
462 |
+
embed_dim = config.hidden_size
|
463 |
+
self.embeddings = CLIPTextEmbeddings(config)
|
464 |
+
self.encoder = CLIPEncoder(config)
|
465 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
466 |
+
|
467 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
468 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
469 |
+
def forward(
|
470 |
+
self,
|
471 |
+
input_ids: Optional[torch.Tensor] = None,
|
472 |
+
attention_mask: Optional[torch.Tensor] = None,
|
473 |
+
position_ids: Optional[torch.Tensor] = None,
|
474 |
+
output_attentions: Optional[bool] = None,
|
475 |
+
output_hidden_states: Optional[bool] = None,
|
476 |
+
return_dict: Optional[bool] = None,
|
477 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
478 |
+
r"""
|
479 |
+
Returns:
|
480 |
+
|
481 |
+
"""
|
482 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
483 |
+
output_hidden_states = (
|
484 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
485 |
+
)
|
486 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
487 |
+
|
488 |
+
if input_ids is None:
|
489 |
+
raise ValueError("You have to specify input_ids")
|
490 |
+
|
491 |
+
input_shape = input_ids.size()
|
492 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
493 |
+
|
494 |
+
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
495 |
+
|
496 |
+
# CLIP's text model uses causal mask, prepare it here.
|
497 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
498 |
+
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
|
499 |
+
# expand attention_mask
|
500 |
+
if attention_mask is not None:
|
501 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
502 |
+
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
503 |
+
|
504 |
+
encoder_outputs = self.encoder(
|
505 |
+
inputs_embeds=hidden_states,
|
506 |
+
attention_mask=attention_mask,
|
507 |
+
causal_attention_mask=causal_attention_mask,
|
508 |
+
output_attentions=output_attentions,
|
509 |
+
output_hidden_states=output_hidden_states,
|
510 |
+
return_dict=return_dict,
|
511 |
+
)
|
512 |
+
|
513 |
+
last_hidden_state = encoder_outputs[0]
|
514 |
+
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
515 |
+
|
516 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
517 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
518 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
519 |
+
pooled_output = last_hidden_state[
|
520 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
521 |
+
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
522 |
+
]
|
523 |
+
|
524 |
+
if not return_dict:
|
525 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
526 |
+
|
527 |
+
return BaseModelOutputWithPooling(
|
528 |
+
last_hidden_state=last_hidden_state,
|
529 |
+
pooler_output=pooled_output,
|
530 |
+
hidden_states=encoder_outputs.hidden_states,
|
531 |
+
attentions=encoder_outputs.attentions,
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
@add_start_docstrings(
|
536 |
+
"""The text model from CLIP without any head or projection on top.""",
|
537 |
+
CLIP_START_DOCSTRING,
|
538 |
+
)
|
539 |
+
class CLIPTextModel(CLIPPreTrainedModel):
|
540 |
+
config_class = CLIPTextConfig
|
541 |
+
|
542 |
+
_no_split_modules = ["CLIPEncoderLayer"]
|
543 |
+
|
544 |
+
def __init__(self, config: CLIPTextConfig):
|
545 |
+
super().__init__(config)
|
546 |
+
self.text_model = CLIPTextTransformer(config)
|
547 |
+
# Initialize weights and apply final processing
|
548 |
+
self.post_init()
|
549 |
+
|
550 |
+
def get_input_embeddings(self) -> nn.Module:
|
551 |
+
return self.text_model.embeddings.token_embedding
|
552 |
+
|
553 |
+
def set_input_embeddings(self, value):
|
554 |
+
self.text_model.embeddings.token_embedding = value
|
555 |
+
|
556 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
557 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
558 |
+
def forward(
|
559 |
+
self,
|
560 |
+
input_ids: Optional[torch.Tensor] = None,
|
561 |
+
attention_mask: Optional[torch.Tensor] = None,
|
562 |
+
position_ids: Optional[torch.Tensor] = None,
|
563 |
+
output_attentions: Optional[bool] = None,
|
564 |
+
output_hidden_states: Optional[bool] = None,
|
565 |
+
return_dict: Optional[bool] = None,
|
566 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
567 |
+
r"""
|
568 |
+
Returns:
|
569 |
+
|
570 |
+
Examples:
|
571 |
+
|
572 |
+
```python
|
573 |
+
>>> from transformers import AutoTokenizer, CLIPTextModel
|
574 |
+
|
575 |
+
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
576 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
577 |
+
|
578 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
579 |
+
|
580 |
+
>>> outputs = model(**inputs)
|
581 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
582 |
+
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
583 |
+
```"""
|
584 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
585 |
+
|
586 |
+
return self.text_model(
|
587 |
+
input_ids=input_ids,
|
588 |
+
attention_mask=attention_mask,
|
589 |
+
position_ids=position_ids,
|
590 |
+
output_attentions=output_attentions,
|
591 |
+
output_hidden_states=output_hidden_states,
|
592 |
+
return_dict=return_dict,
|
593 |
+
)
|
594 |
+
|
595 |
+
|
596 |
+
class CLIPVisionTransformer(nn.Module):
|
597 |
+
def __init__(self, config: CLIPVisionConfig):
|
598 |
+
super().__init__()
|
599 |
+
self.config = config
|
600 |
+
embed_dim = config.hidden_size
|
601 |
+
|
602 |
+
self.embeddings = CLIPVisionEmbeddings(config)
|
603 |
+
self.patch_dropout = PatchDropout(config.force_patch_dropout)
|
604 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
605 |
+
self.encoder = CLIPEncoder(config)
|
606 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
607 |
+
|
608 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
609 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
610 |
+
def forward(
|
611 |
+
self,
|
612 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
613 |
+
output_attentions: Optional[bool] = None,
|
614 |
+
output_hidden_states: Optional[bool] = None,
|
615 |
+
return_dict: Optional[bool] = None,
|
616 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
617 |
+
r"""
|
618 |
+
Returns:
|
619 |
+
|
620 |
+
"""
|
621 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
622 |
+
output_hidden_states = (
|
623 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
624 |
+
)
|
625 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
626 |
+
|
627 |
+
if pixel_values is None:
|
628 |
+
raise ValueError("You have to specify pixel_values")
|
629 |
+
######################################
|
630 |
+
if len(pixel_values.shape) == 7:
|
631 |
+
b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
|
632 |
+
# print(pixel_values.shape)
|
633 |
+
B = b_new * pair_new * bs_new
|
634 |
+
pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new)
|
635 |
+
|
636 |
+
elif len(pixel_values.shape) == 5:
|
637 |
+
B, _, T, _, _ = pixel_values.shape
|
638 |
+
# print(pixel_values.shape)
|
639 |
+
pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
|
640 |
+
else:
|
641 |
+
# print(pixel_values.shape)
|
642 |
+
B, _, _, _ = pixel_values.shape
|
643 |
+
T = 1
|
644 |
+
###########################
|
645 |
+
hidden_states = self.embeddings(pixel_values)
|
646 |
+
|
647 |
+
hidden_states = self.patch_dropout(hidden_states, B, T) ##############################################
|
648 |
+
|
649 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
650 |
+
|
651 |
+
encoder_outputs = self.encoder(
|
652 |
+
inputs_embeds=hidden_states,
|
653 |
+
output_attentions=output_attentions,
|
654 |
+
output_hidden_states=output_hidden_states,
|
655 |
+
return_dict=return_dict,
|
656 |
+
)
|
657 |
+
|
658 |
+
last_hidden_state = encoder_outputs[0]
|
659 |
+
pooled_output = last_hidden_state[:, 0, :]
|
660 |
+
pooled_output = self.post_layernorm(pooled_output)
|
661 |
+
|
662 |
+
pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################
|
663 |
+
|
664 |
+
if not return_dict:
|
665 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
666 |
+
|
667 |
+
return BaseModelOutputWithPooling(
|
668 |
+
last_hidden_state=last_hidden_state,
|
669 |
+
pooler_output=pooled_output,
|
670 |
+
hidden_states=encoder_outputs.hidden_states,
|
671 |
+
attentions=encoder_outputs.attentions,
|
672 |
+
)
|
673 |
+
|
674 |
+
|
675 |
+
@add_start_docstrings(
|
676 |
+
"""The vision model from CLIP without any head or projection on top.""",
|
677 |
+
CLIP_START_DOCSTRING,
|
678 |
+
)
|
679 |
+
class CLIPVisionModel(CLIPPreTrainedModel):
|
680 |
+
config_class = CLIPVisionConfig
|
681 |
+
main_input_name = "pixel_values"
|
682 |
+
|
683 |
+
def __init__(self, config: CLIPVisionConfig):
|
684 |
+
super().__init__(config)
|
685 |
+
self.vision_model = CLIPVisionTransformer(config)
|
686 |
+
# Initialize weights and apply final processing
|
687 |
+
self.post_init()
|
688 |
+
|
689 |
+
def get_input_embeddings(self) -> nn.Module:
|
690 |
+
return self.vision_model.embeddings.patch_embedding
|
691 |
+
|
692 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
693 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
694 |
+
def forward(
|
695 |
+
self,
|
696 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
697 |
+
output_attentions: Optional[bool] = None,
|
698 |
+
output_hidden_states: Optional[bool] = None,
|
699 |
+
return_dict: Optional[bool] = None,
|
700 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
701 |
+
r"""
|
702 |
+
Returns:
|
703 |
+
|
704 |
+
Examples:
|
705 |
+
|
706 |
+
```python
|
707 |
+
>>> from PIL import Image
|
708 |
+
>>> import requests
|
709 |
+
>>> from transformers import AutoProcessor, CLIPVisionModel
|
710 |
+
|
711 |
+
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
712 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
713 |
+
|
714 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
715 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
716 |
+
|
717 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
718 |
+
|
719 |
+
>>> outputs = model(**inputs)
|
720 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
721 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
722 |
+
```"""
|
723 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
724 |
+
|
725 |
+
return self.vision_model(
|
726 |
+
pixel_values=pixel_values,
|
727 |
+
output_attentions=output_attentions,
|
728 |
+
output_hidden_states=output_hidden_states,
|
729 |
+
return_dict=return_dict,
|
730 |
+
)
|
731 |
+
|
732 |
+
|
733 |
+
@add_start_docstrings(CLIP_START_DOCSTRING)
|
734 |
+
class LanguageBindDepth(CLIPPreTrainedModel):
|
735 |
+
config_class = LanguageBindDepthConfig
|
736 |
+
|
737 |
+
def __init__(self, config: LanguageBindDepthConfig):
|
738 |
+
super().__init__(config)
|
739 |
+
|
740 |
+
if not isinstance(config.text_config, CLIPTextConfig):
|
741 |
+
raise ValueError(
|
742 |
+
"config.text_config is expected to be of type CLIPTextConfig but is of type"
|
743 |
+
f" {type(config.text_config)}."
|
744 |
+
)
|
745 |
+
|
746 |
+
if not isinstance(config.vision_config, CLIPVisionConfig):
|
747 |
+
raise ValueError(
|
748 |
+
"config.vision_config is expected to be of type CLIPVisionConfig but is of type"
|
749 |
+
f" {type(config.vision_config)}."
|
750 |
+
)
|
751 |
+
|
752 |
+
text_config = config.text_config
|
753 |
+
vision_config = config.vision_config
|
754 |
+
self.add_time_attn = vision_config.add_time_attn
|
755 |
+
self.lora_r = vision_config.lora_r
|
756 |
+
self.lora_alpha = vision_config.lora_alpha
|
757 |
+
self.lora_dropout = vision_config.lora_dropout
|
758 |
+
|
759 |
+
self.projection_dim = config.projection_dim
|
760 |
+
self.text_embed_dim = text_config.hidden_size
|
761 |
+
self.vision_embed_dim = vision_config.hidden_size
|
762 |
+
|
763 |
+
self.text_model = CLIPTextTransformer(text_config)
|
764 |
+
self.vision_model = CLIPVisionTransformer(vision_config)
|
765 |
+
|
766 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
767 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
768 |
+
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
769 |
+
|
770 |
+
# Initialize weights and apply final processing
|
771 |
+
self.post_init()
|
772 |
+
self.convert_to_lora()
|
773 |
+
self.resize_pos(self.vision_model.embeddings, vision_config)
|
774 |
+
|
775 |
+
def convert_to_lora(self):
|
776 |
+
if self.lora_r == 0:
|
777 |
+
return
|
778 |
+
if self.add_time_attn:
|
779 |
+
target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
|
780 |
+
"temporal_attn.q_proj", "temporal_attn.out_proj",
|
781 |
+
"temporal_mlp.fc1", "temporal_mlp.fc2"]
|
782 |
+
else:
|
783 |
+
target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
|
784 |
+
config = LoraConfig(
|
785 |
+
r=self.lora_r, # 16
|
786 |
+
lora_alpha=self.lora_alpha, # 16
|
787 |
+
target_modules=target_modules, # self_attn.out_proj
|
788 |
+
lora_dropout=self.lora_dropout, # 0.1
|
789 |
+
bias="none",
|
790 |
+
modules_to_save=[],
|
791 |
+
)
|
792 |
+
self.vision_model.encoder.is_gradient_checkpointing = False
|
793 |
+
self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
|
794 |
+
|
795 |
+
def resize_pos(self, m, vision_config):
|
796 |
+
# convert embedding
|
797 |
+
if vision_config.num_mel_bins!=0 and vision_config.target_length!=0:
|
798 |
+
m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
|
799 |
+
m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
|
800 |
+
# pos resize
|
801 |
+
old_pos_embed_state_dict = m.position_embedding.state_dict()
|
802 |
+
old_pos_embed = old_pos_embed_state_dict['weight']
|
803 |
+
dtype = old_pos_embed.dtype
|
804 |
+
grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
|
805 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
806 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
807 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
808 |
+
# m.to(args.device)
|
809 |
+
return
|
810 |
+
|
811 |
+
m.num_patches = grid_size[0] * grid_size[1]
|
812 |
+
m.num_positions = m.num_patches + 1
|
813 |
+
m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
|
814 |
+
new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
|
815 |
+
|
816 |
+
if extra_tokens:
|
817 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
818 |
+
else:
|
819 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
820 |
+
old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
|
821 |
+
|
822 |
+
# if is_master(args):
|
823 |
+
# logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
824 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
825 |
+
pos_emb_img = F.interpolate(
|
826 |
+
pos_emb_img,
|
827 |
+
size=grid_size,
|
828 |
+
mode='bicubic',
|
829 |
+
antialias=True,
|
830 |
+
align_corners=False,
|
831 |
+
)
|
832 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
833 |
+
if pos_emb_tok is not None:
|
834 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
835 |
+
else:
|
836 |
+
new_pos_embed = pos_emb_img
|
837 |
+
old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
|
838 |
+
m.position_embedding = new_position_embedding
|
839 |
+
m.position_embedding.load_state_dict(old_pos_embed_state_dict)
|
840 |
+
|
841 |
+
# m.to(args.device)
|
842 |
+
|
843 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
844 |
+
def get_text_features(
|
845 |
+
self,
|
846 |
+
input_ids: Optional[torch.Tensor] = None,
|
847 |
+
attention_mask: Optional[torch.Tensor] = None,
|
848 |
+
position_ids: Optional[torch.Tensor] = None,
|
849 |
+
output_attentions: Optional[bool] = None,
|
850 |
+
output_hidden_states: Optional[bool] = None,
|
851 |
+
return_dict: Optional[bool] = None,
|
852 |
+
) -> torch.FloatTensor:
|
853 |
+
r"""
|
854 |
+
Returns:
|
855 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
856 |
+
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
857 |
+
|
858 |
+
Examples:
|
859 |
+
|
860 |
+
```python
|
861 |
+
>>> from transformers import AutoTokenizer, CLIPModel
|
862 |
+
|
863 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
864 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
865 |
+
|
866 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
867 |
+
>>> text_features = model.get_text_features(**inputs)
|
868 |
+
```"""
|
869 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
870 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
871 |
+
output_hidden_states = (
|
872 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
873 |
+
)
|
874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
875 |
+
|
876 |
+
text_outputs = self.text_model(
|
877 |
+
input_ids=input_ids,
|
878 |
+
attention_mask=attention_mask,
|
879 |
+
position_ids=position_ids,
|
880 |
+
output_attentions=output_attentions,
|
881 |
+
output_hidden_states=output_hidden_states,
|
882 |
+
return_dict=return_dict,
|
883 |
+
)
|
884 |
+
|
885 |
+
pooled_output = text_outputs[1]
|
886 |
+
text_features = self.text_projection(pooled_output)
|
887 |
+
|
888 |
+
return text_features
|
889 |
+
|
890 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
891 |
+
def get_image_features(
|
892 |
+
self,
|
893 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
894 |
+
output_attentions: Optional[bool] = None,
|
895 |
+
output_hidden_states: Optional[bool] = None,
|
896 |
+
return_dict: Optional[bool] = None,
|
897 |
+
) -> torch.FloatTensor:
|
898 |
+
r"""
|
899 |
+
Returns:
|
900 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
901 |
+
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
902 |
+
|
903 |
+
Examples:
|
904 |
+
|
905 |
+
```python
|
906 |
+
>>> from PIL import Image
|
907 |
+
>>> import requests
|
908 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
909 |
+
|
910 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
911 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
912 |
+
|
913 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
914 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
915 |
+
|
916 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
917 |
+
|
918 |
+
>>> image_features = model.get_image_features(**inputs)
|
919 |
+
```"""
|
920 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
921 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
922 |
+
output_hidden_states = (
|
923 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
924 |
+
)
|
925 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
926 |
+
|
927 |
+
vision_outputs = self.vision_model(
|
928 |
+
pixel_values=pixel_values,
|
929 |
+
output_attentions=output_attentions,
|
930 |
+
output_hidden_states=output_hidden_states,
|
931 |
+
return_dict=return_dict,
|
932 |
+
)
|
933 |
+
|
934 |
+
pooled_output = vision_outputs[1] # pooled_output
|
935 |
+
image_features = self.visual_projection(pooled_output)
|
936 |
+
|
937 |
+
return image_features
|
938 |
+
|
939 |
+
@add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
|
940 |
+
@replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindDepthConfig)
|
941 |
+
def forward(
|
942 |
+
self,
|
943 |
+
input_ids: Optional[torch.LongTensor] = None,
|
944 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
945 |
+
attention_mask: Optional[torch.Tensor] = None,
|
946 |
+
position_ids: Optional[torch.LongTensor] = None,
|
947 |
+
return_loss: Optional[bool] = None,
|
948 |
+
output_attentions: Optional[bool] = None,
|
949 |
+
output_hidden_states: Optional[bool] = None,
|
950 |
+
return_dict: Optional[bool] = None,
|
951 |
+
) -> Union[Tuple, CLIPOutput]:
|
952 |
+
r"""
|
953 |
+
Returns:
|
954 |
+
|
955 |
+
Examples:
|
956 |
+
|
957 |
+
```python
|
958 |
+
>>> from PIL import Image
|
959 |
+
>>> import requests
|
960 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
961 |
+
|
962 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
963 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
964 |
+
|
965 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
966 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
967 |
+
|
968 |
+
>>> inputs = processor(
|
969 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
970 |
+
... )
|
971 |
+
|
972 |
+
>>> outputs = model(**inputs)
|
973 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
974 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
975 |
+
```"""
|
976 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
977 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
978 |
+
output_hidden_states = (
|
979 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
980 |
+
)
|
981 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
982 |
+
|
983 |
+
vision_outputs = self.vision_model(
|
984 |
+
pixel_values=pixel_values,
|
985 |
+
output_attentions=output_attentions,
|
986 |
+
output_hidden_states=output_hidden_states,
|
987 |
+
return_dict=return_dict,
|
988 |
+
)
|
989 |
+
|
990 |
+
text_outputs = self.text_model(
|
991 |
+
input_ids=input_ids,
|
992 |
+
attention_mask=attention_mask,
|
993 |
+
position_ids=position_ids,
|
994 |
+
output_attentions=output_attentions,
|
995 |
+
output_hidden_states=output_hidden_states,
|
996 |
+
return_dict=return_dict,
|
997 |
+
)
|
998 |
+
|
999 |
+
image_embeds = vision_outputs[1]
|
1000 |
+
image_embeds = self.visual_projection(image_embeds)
|
1001 |
+
|
1002 |
+
text_embeds = text_outputs[1]
|
1003 |
+
text_embeds = self.text_projection(text_embeds)
|
1004 |
+
|
1005 |
+
# normalized features
|
1006 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
1007 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
1008 |
+
|
1009 |
+
# cosine similarity as logits
|
1010 |
+
logit_scale = self.logit_scale.exp()
|
1011 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
1012 |
+
logits_per_image = logits_per_text.t()
|
1013 |
+
|
1014 |
+
loss = None
|
1015 |
+
if return_loss:
|
1016 |
+
loss = clip_loss(logits_per_text)
|
1017 |
+
|
1018 |
+
if not return_dict:
|
1019 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
1020 |
+
return ((loss,) + output) if loss is not None else output
|
1021 |
+
|
1022 |
+
return CLIPOutput(
|
1023 |
+
loss=loss,
|
1024 |
+
logits_per_image=logits_per_image,
|
1025 |
+
logits_per_text=logits_per_text,
|
1026 |
+
text_embeds=text_embeds,
|
1027 |
+
image_embeds=image_embeds,
|
1028 |
+
text_model_output=text_outputs,
|
1029 |
+
vision_model_output=vision_outputs,
|
1030 |
+
)
|
models/multimodal_encoder/languagebind/depth/processing_depth.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from torch import nn
|
5 |
+
from torchvision import transforms
|
6 |
+
from transformers import ProcessorMixin, BatchEncoding
|
7 |
+
from transformers.image_processing_utils import BatchFeature
|
8 |
+
|
9 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
10 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
11 |
+
|
12 |
+
def make_list_of_images(x):
|
13 |
+
if not isinstance(x, list):
|
14 |
+
return [x]
|
15 |
+
return x
|
16 |
+
|
17 |
+
def opencv_loader(path):
|
18 |
+
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32')
|
19 |
+
|
20 |
+
|
21 |
+
class DepthNorm(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
max_depth=0,
|
25 |
+
min_depth=0.01,
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
self.max_depth = max_depth
|
29 |
+
self.min_depth = min_depth
|
30 |
+
self.scale = 1000.0 # nyuv2 abs.depth
|
31 |
+
|
32 |
+
def forward(self, image):
|
33 |
+
# image = np.array(image)
|
34 |
+
depth_img = image / self.scale # (H, W) in meters
|
35 |
+
depth_img = depth_img.clip(min=self.min_depth)
|
36 |
+
if self.max_depth != 0:
|
37 |
+
depth_img = depth_img.clip(max=self.max_depth)
|
38 |
+
depth_img /= self.max_depth # 0-1
|
39 |
+
else:
|
40 |
+
depth_img /= depth_img.max()
|
41 |
+
depth_img = torch.from_numpy(depth_img).unsqueeze(0).repeat(3, 1, 1) # assume image
|
42 |
+
return depth_img.to(torch.get_default_dtype())
|
43 |
+
|
44 |
+
def get_depth_transform(config):
|
45 |
+
config = config.vision_config
|
46 |
+
transform = transforms.Compose(
|
47 |
+
[
|
48 |
+
DepthNorm(max_depth=config.max_depth),
|
49 |
+
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
|
50 |
+
transforms.CenterCrop(224),
|
51 |
+
transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD), # assume image
|
52 |
+
# transforms.Normalize((0.5, ), (0.5, )) # 0-1 to norm distribution
|
53 |
+
# transforms.Normalize((0.0418, ), (0.0295, )) # sun rgb-d imagebind
|
54 |
+
# transforms.Normalize((0.02, ), (0.00295, )) # nyuv2
|
55 |
+
]
|
56 |
+
)
|
57 |
+
return transform
|
58 |
+
|
59 |
+
def load_and_transform_depth(depth_path, transform):
|
60 |
+
depth = opencv_loader(depth_path)
|
61 |
+
depth_outputs = transform(depth)
|
62 |
+
return depth_outputs
|
63 |
+
|
64 |
+
class LanguageBindDepthProcessor(ProcessorMixin):
|
65 |
+
attributes = []
|
66 |
+
tokenizer_class = ("LanguageBindDepthTokenizer")
|
67 |
+
|
68 |
+
def __init__(self, config, tokenizer=None, **kwargs):
|
69 |
+
super().__init__(**kwargs)
|
70 |
+
self.config = config
|
71 |
+
self.transform = get_depth_transform(config)
|
72 |
+
self.image_processor = load_and_transform_depth
|
73 |
+
self.tokenizer = tokenizer
|
74 |
+
|
75 |
+
def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
|
76 |
+
if text is None and images is None:
|
77 |
+
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
78 |
+
|
79 |
+
if text is not None:
|
80 |
+
encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
|
81 |
+
truncation=True, return_tensors=return_tensors, **kwargs)
|
82 |
+
|
83 |
+
if images is not None:
|
84 |
+
images = make_list_of_images(images)
|
85 |
+
image_features = [self.image_processor(image, self.transform) for image in images]
|
86 |
+
image_features = torch.stack(image_features)
|
87 |
+
|
88 |
+
if text is not None and images is not None:
|
89 |
+
encoding["pixel_values"] = image_features
|
90 |
+
return encoding
|
91 |
+
elif text is not None:
|
92 |
+
return encoding
|
93 |
+
else:
|
94 |
+
return {"pixel_values": image_features}
|
95 |
+
|
96 |
+
def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
|
97 |
+
"""
|
98 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
99 |
+
refer to the docstring of this method for more information.
|
100 |
+
"""
|
101 |
+
return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
|
102 |
+
|
103 |
+
def decode(self, skip_special_tokens=True, *args, **kwargs):
|
104 |
+
"""
|
105 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
106 |
+
the docstring of this method for more information.
|
107 |
+
"""
|
108 |
+
return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
|
models/multimodal_encoder/languagebind/depth/tokenization_depth.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPTokenizer
|
2 |
+
from transformers.utils import logging
|
3 |
+
|
4 |
+
logger = logging.get_logger(__name__)
|
5 |
+
|
6 |
+
VOCAB_FILES_NAMES = {
|
7 |
+
"vocab_file": "vocab.json",
|
8 |
+
"merges_file": "merges.txt",
|
9 |
+
}
|
10 |
+
|
11 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
12 |
+
"vocab_file": {
|
13 |
+
"lb203/LanguageBind-Depth": "https://huggingface.co/lb203/LanguageBind-Depth/resolve/main/vocab.json",
|
14 |
+
},
|
15 |
+
"merges_file": {
|
16 |
+
"lb203/LanguageBind-Depth": "https://huggingface.co/lb203/LanguageBind-Depth/resolve/main/merges.txt",
|
17 |
+
},
|
18 |
+
}
|
19 |
+
|
20 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
21 |
+
"lb203/LanguageBind-Depth": 77,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
PRETRAINED_INIT_CONFIGURATION = {
|
26 |
+
"lb203/LanguageBind-Thermal": {},
|
27 |
+
}
|
28 |
+
|
29 |
+
class LanguageBindDepthTokenizer(CLIPTokenizer):
|
30 |
+
"""
|
31 |
+
Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
|
32 |
+
|
33 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
34 |
+
this superclass for more information regarding those methods.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_file (`str`):
|
38 |
+
Path to the vocabulary file.
|
39 |
+
merges_file (`str`):
|
40 |
+
Path to the merges file.
|
41 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
42 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
43 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
44 |
+
unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
45 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
46 |
+
token instead.
|
47 |
+
bos_token (`str`, *optional*, defaults to `<|startoftext|>`):
|
48 |
+
The beginning of sequence token.
|
49 |
+
eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
50 |
+
The end of sequence token.
|
51 |
+
"""
|
52 |
+
|
53 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
54 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
55 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
56 |
+
model_input_names = ["input_ids", "attention_mask"]
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
vocab_file,
|
61 |
+
merges_file,
|
62 |
+
errors="replace",
|
63 |
+
unk_token="<|endoftext|>",
|
64 |
+
bos_token="<|startoftext|>",
|
65 |
+
eos_token="<|endoftext|>",
|
66 |
+
pad_token="<|endoftext|>", # hack to enable padding
|
67 |
+
**kwargs,
|
68 |
+
):
|
69 |
+
super(LanguageBindDepthTokenizer, self).__init__(
|
70 |
+
vocab_file,
|
71 |
+
merges_file,
|
72 |
+
errors,
|
73 |
+
unk_token,
|
74 |
+
bos_token,
|
75 |
+
eos_token,
|
76 |
+
pad_token, # hack to enable padding
|
77 |
+
**kwargs,)
|
models/multimodal_encoder/languagebind/image/configuration_image.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
from transformers.utils import logging
|
7 |
+
|
8 |
+
logger = logging.get_logger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
class CLIPTextConfig(PretrainedConfig):
|
17 |
+
r"""
|
18 |
+
This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
|
19 |
+
text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
|
20 |
+
with the defaults will yield a similar configuration to that of the text encoder of the CLIP
|
21 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
22 |
+
|
23 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
24 |
+
documentation from [`PretrainedConfig`] for more information.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
vocab_size (`int`, *optional*, defaults to 49408):
|
28 |
+
Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
|
29 |
+
the `inputs_ids` passed when calling [`CLIPModel`].
|
30 |
+
hidden_size (`int`, *optional*, defaults to 512):
|
31 |
+
Dimensionality of the encoder layers and the pooler layer.
|
32 |
+
intermediate_size (`int`, *optional*, defaults to 2048):
|
33 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
34 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
35 |
+
Number of hidden layers in the Transformer encoder.
|
36 |
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
37 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
38 |
+
max_position_embeddings (`int`, *optional*, defaults to 77):
|
39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
42 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
43 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
44 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
45 |
+
The epsilon used by the layer normalization layers.
|
46 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
47 |
+
The dropout ratio for the attention probabilities.
|
48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
49 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
50 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
51 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
52 |
+
testing).
|
53 |
+
|
54 |
+
Example:
|
55 |
+
|
56 |
+
```python
|
57 |
+
>>> from transformers import CLIPTextConfig, CLIPTextModel
|
58 |
+
|
59 |
+
>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
|
60 |
+
>>> configuration = CLIPTextConfig()
|
61 |
+
|
62 |
+
>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
63 |
+
>>> model = CLIPTextModel(configuration)
|
64 |
+
|
65 |
+
>>> # Accessing the model configuration
|
66 |
+
>>> configuration = model.config
|
67 |
+
```"""
|
68 |
+
model_type = "clip_text_model"
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
vocab_size=49408,
|
73 |
+
hidden_size=512,
|
74 |
+
intermediate_size=2048,
|
75 |
+
projection_dim=512,
|
76 |
+
num_hidden_layers=12,
|
77 |
+
num_attention_heads=8,
|
78 |
+
max_position_embeddings=77,
|
79 |
+
hidden_act="quick_gelu",
|
80 |
+
layer_norm_eps=1e-5,
|
81 |
+
attention_dropout=0.0,
|
82 |
+
initializer_range=0.02,
|
83 |
+
initializer_factor=1.0,
|
84 |
+
# This differs from `CLIPTokenizer`'s default and from openai/clip
|
85 |
+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
86 |
+
pad_token_id=1,
|
87 |
+
bos_token_id=49406,
|
88 |
+
eos_token_id=49407,
|
89 |
+
**kwargs,
|
90 |
+
):
|
91 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
92 |
+
|
93 |
+
self.vocab_size = vocab_size
|
94 |
+
self.hidden_size = hidden_size
|
95 |
+
self.intermediate_size = intermediate_size
|
96 |
+
self.projection_dim = projection_dim
|
97 |
+
self.num_hidden_layers = num_hidden_layers
|
98 |
+
self.num_attention_heads = num_attention_heads
|
99 |
+
self.max_position_embeddings = max_position_embeddings
|
100 |
+
self.layer_norm_eps = layer_norm_eps
|
101 |
+
self.hidden_act = hidden_act
|
102 |
+
self.initializer_range = initializer_range
|
103 |
+
self.initializer_factor = initializer_factor
|
104 |
+
self.attention_dropout = attention_dropout
|
105 |
+
self.add_time_attn = False ######################################
|
106 |
+
|
107 |
+
@classmethod
|
108 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
109 |
+
cls._set_token_in_kwargs(kwargs)
|
110 |
+
|
111 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
112 |
+
|
113 |
+
# get the text config dict if we are loading from CLIPConfig
|
114 |
+
if config_dict.get("model_type") == "clip":
|
115 |
+
config_dict = config_dict["text_config"]
|
116 |
+
|
117 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
118 |
+
logger.warning(
|
119 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
120 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
121 |
+
)
|
122 |
+
|
123 |
+
return cls.from_dict(config_dict, **kwargs)
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class CLIPVisionConfig(PretrainedConfig):
|
129 |
+
r"""
|
130 |
+
This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
|
131 |
+
CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
132 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
|
133 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
134 |
+
|
135 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
136 |
+
documentation from [`PretrainedConfig`] for more information.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
140 |
+
Dimensionality of the encoder layers and the pooler layer.
|
141 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
142 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
143 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
144 |
+
Number of hidden layers in the Transformer encoder.
|
145 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
146 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
147 |
+
image_size (`int`, *optional*, defaults to 224):
|
148 |
+
The size (resolution) of each image.
|
149 |
+
patch_size (`int`, *optional*, defaults to 32):
|
150 |
+
The size (resolution) of each patch.
|
151 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
152 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
153 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
154 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
155 |
+
The epsilon used by the layer normalization layers.
|
156 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
157 |
+
The dropout ratio for the attention probabilities.
|
158 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
159 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
160 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
161 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
162 |
+
testing).
|
163 |
+
|
164 |
+
Example:
|
165 |
+
|
166 |
+
```python
|
167 |
+
>>> from transformers import CLIPVisionConfig, CLIPVisionModel
|
168 |
+
|
169 |
+
>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
|
170 |
+
>>> configuration = CLIPVisionConfig()
|
171 |
+
|
172 |
+
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
173 |
+
>>> model = CLIPVisionModel(configuration)
|
174 |
+
|
175 |
+
>>> # Accessing the model configuration
|
176 |
+
>>> configuration = model.config
|
177 |
+
```"""
|
178 |
+
|
179 |
+
model_type = "clip_vision_model"
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
hidden_size=768,
|
184 |
+
intermediate_size=3072,
|
185 |
+
projection_dim=512,
|
186 |
+
num_hidden_layers=12,
|
187 |
+
num_attention_heads=12,
|
188 |
+
num_channels=3,
|
189 |
+
image_size=224,
|
190 |
+
patch_size=32,
|
191 |
+
hidden_act="quick_gelu",
|
192 |
+
layer_norm_eps=1e-5,
|
193 |
+
attention_dropout=0.0,
|
194 |
+
initializer_range=0.02,
|
195 |
+
initializer_factor=1.0,
|
196 |
+
|
197 |
+
add_time_attn=False, ################################
|
198 |
+
num_frames=1, ################################
|
199 |
+
force_patch_dropout=0.0, ################################
|
200 |
+
lora_r=2, ################################
|
201 |
+
lora_alpha=16, ################################
|
202 |
+
lora_dropout=0.0, ################################
|
203 |
+
num_mel_bins=0.0, ################################
|
204 |
+
target_length=0.0, ################################
|
205 |
+
video_decode_backend='decord', #########################
|
206 |
+
**kwargs,
|
207 |
+
):
|
208 |
+
super().__init__(**kwargs)
|
209 |
+
|
210 |
+
self.hidden_size = hidden_size
|
211 |
+
self.intermediate_size = intermediate_size
|
212 |
+
self.projection_dim = projection_dim
|
213 |
+
self.num_hidden_layers = num_hidden_layers
|
214 |
+
self.num_attention_heads = num_attention_heads
|
215 |
+
self.num_channels = num_channels
|
216 |
+
self.patch_size = patch_size
|
217 |
+
self.image_size = image_size
|
218 |
+
self.initializer_range = initializer_range
|
219 |
+
self.initializer_factor = initializer_factor
|
220 |
+
self.attention_dropout = attention_dropout
|
221 |
+
self.layer_norm_eps = layer_norm_eps
|
222 |
+
self.hidden_act = hidden_act
|
223 |
+
|
224 |
+
self.add_time_attn = add_time_attn ################
|
225 |
+
self.num_frames = num_frames ################
|
226 |
+
self.force_patch_dropout = force_patch_dropout ################
|
227 |
+
self.lora_r = lora_r ################
|
228 |
+
self.lora_alpha = lora_alpha ################
|
229 |
+
self.lora_dropout = lora_dropout ################
|
230 |
+
self.num_mel_bins = num_mel_bins ################
|
231 |
+
self.target_length = target_length ################
|
232 |
+
self.video_decode_backend = video_decode_backend ################
|
233 |
+
|
234 |
+
@classmethod
|
235 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
236 |
+
cls._set_token_in_kwargs(kwargs)
|
237 |
+
|
238 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
239 |
+
|
240 |
+
# get the vision config dict if we are loading from CLIPConfig
|
241 |
+
if config_dict.get("model_type") == "clip":
|
242 |
+
config_dict = config_dict["vision_config"]
|
243 |
+
|
244 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
245 |
+
logger.warning(
|
246 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
247 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
248 |
+
)
|
249 |
+
|
250 |
+
return cls.from_dict(config_dict, **kwargs)
|
251 |
+
|
252 |
+
|
253 |
+
class LanguageBindImageConfig(PretrainedConfig):
|
254 |
+
r"""
|
255 |
+
[`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
|
256 |
+
a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
|
257 |
+
a configuration with the defaults will yield a similar configuration to that of the CLIP
|
258 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
259 |
+
|
260 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
261 |
+
documentation from [`PretrainedConfig`] for more information.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
text_config (`dict`, *optional*):
|
265 |
+
Dictionary of configuration options used to initialize [`CLIPTextConfig`].
|
266 |
+
vision_config (`dict`, *optional*):
|
267 |
+
Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
|
268 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
269 |
+
Dimentionality of text and vision projection layers.
|
270 |
+
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
271 |
+
The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation.
|
272 |
+
kwargs (*optional*):
|
273 |
+
Dictionary of keyword arguments.
|
274 |
+
|
275 |
+
Example:
|
276 |
+
|
277 |
+
```python
|
278 |
+
>>> from transformers import CLIPConfig, CLIPModel
|
279 |
+
|
280 |
+
>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
|
281 |
+
>>> configuration = CLIPConfig()
|
282 |
+
|
283 |
+
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
284 |
+
>>> model = CLIPModel(configuration)
|
285 |
+
|
286 |
+
>>> # Accessing the model configuration
|
287 |
+
>>> configuration = model.config
|
288 |
+
|
289 |
+
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
|
290 |
+
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
|
291 |
+
|
292 |
+
>>> # Initializing a CLIPText and CLIPVision configuration
|
293 |
+
>>> config_text = CLIPTextConfig()
|
294 |
+
>>> config_vision = CLIPVisionConfig()
|
295 |
+
|
296 |
+
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
|
297 |
+
```"""
|
298 |
+
|
299 |
+
model_type = "LanguageBindImage"
|
300 |
+
is_composition = True
|
301 |
+
|
302 |
+
def __init__(
|
303 |
+
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
304 |
+
):
|
305 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
306 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
|
307 |
+
# of confusion!).
|
308 |
+
text_config_dict = kwargs.pop("text_config_dict", None)
|
309 |
+
vision_config_dict = kwargs.pop("vision_config_dict", None)
|
310 |
+
|
311 |
+
super().__init__(**kwargs)
|
312 |
+
|
313 |
+
# Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
|
314 |
+
# `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
|
315 |
+
# cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
|
316 |
+
if text_config_dict is not None:
|
317 |
+
if text_config is None:
|
318 |
+
text_config = {}
|
319 |
+
|
320 |
+
# This is the complete result when using `text_config_dict`.
|
321 |
+
_text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
|
322 |
+
|
323 |
+
# Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
|
324 |
+
for key, value in _text_config_dict.items():
|
325 |
+
if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
|
326 |
+
# If specified in `text_config_dict`
|
327 |
+
if key in text_config_dict:
|
328 |
+
message = (
|
329 |
+
f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
|
330 |
+
f'The value `text_config_dict["{key}"]` will be used instead.'
|
331 |
+
)
|
332 |
+
# If inferred from default argument values (just to be super careful)
|
333 |
+
else:
|
334 |
+
message = (
|
335 |
+
f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
|
336 |
+
f'value `text_config["{key}"]` will be overriden.'
|
337 |
+
)
|
338 |
+
logger.warning(message)
|
339 |
+
|
340 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
341 |
+
text_config.update(_text_config_dict)
|
342 |
+
|
343 |
+
if vision_config_dict is not None:
|
344 |
+
if vision_config is None:
|
345 |
+
vision_config = {}
|
346 |
+
|
347 |
+
# This is the complete result when using `vision_config_dict`.
|
348 |
+
_vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
|
349 |
+
# convert keys to string instead of integer
|
350 |
+
if "id2label" in _vision_config_dict:
|
351 |
+
_vision_config_dict["id2label"] = {
|
352 |
+
str(key): value for key, value in _vision_config_dict["id2label"].items()
|
353 |
+
}
|
354 |
+
|
355 |
+
# Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
|
356 |
+
for key, value in _vision_config_dict.items():
|
357 |
+
if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
|
358 |
+
# If specified in `vision_config_dict`
|
359 |
+
if key in vision_config_dict:
|
360 |
+
message = (
|
361 |
+
f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
|
362 |
+
f'values. The value `vision_config_dict["{key}"]` will be used instead.'
|
363 |
+
)
|
364 |
+
# If inferred from default argument values (just to be super careful)
|
365 |
+
else:
|
366 |
+
message = (
|
367 |
+
f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
|
368 |
+
f'The value `vision_config["{key}"]` will be overriden.'
|
369 |
+
)
|
370 |
+
logger.warning(message)
|
371 |
+
|
372 |
+
# Update all values in `vision_config` with the ones in `_vision_config_dict`.
|
373 |
+
vision_config.update(_vision_config_dict)
|
374 |
+
|
375 |
+
if text_config is None:
|
376 |
+
text_config = {}
|
377 |
+
logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
|
378 |
+
|
379 |
+
if vision_config is None:
|
380 |
+
vision_config = {}
|
381 |
+
logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
|
382 |
+
|
383 |
+
self.text_config = CLIPTextConfig(**text_config)
|
384 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
385 |
+
|
386 |
+
self.projection_dim = projection_dim
|
387 |
+
self.logit_scale_init_value = logit_scale_init_value
|
388 |
+
self.initializer_factor = 1.0
|
389 |
+
|
390 |
+
@classmethod
|
391 |
+
def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
|
392 |
+
r"""
|
393 |
+
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
|
394 |
+
configuration.
|
395 |
+
|
396 |
+
Returns:
|
397 |
+
[`CLIPConfig`]: An instance of a configuration object
|
398 |
+
"""
|
399 |
+
|
400 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
401 |
+
|
402 |
+
def to_dict(self):
|
403 |
+
"""
|
404 |
+
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
405 |
+
|
406 |
+
Returns:
|
407 |
+
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
408 |
+
"""
|
409 |
+
output = copy.deepcopy(self.__dict__)
|
410 |
+
output["text_config"] = self.text_config.to_dict()
|
411 |
+
output["vision_config"] = self.vision_config.to_dict()
|
412 |
+
output["model_type"] = self.__class__.model_type
|
413 |
+
return output
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
|
418 |
+
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
|
423 |
+
|
models/multimodal_encoder/languagebind/image/modeling_image.py
ADDED
@@ -0,0 +1,1030 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from peft import LoraConfig, get_peft_model
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from transformers import PreTrainedModel, add_start_docstrings
|
10 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
11 |
+
from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
|
12 |
+
CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss
|
13 |
+
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
14 |
+
|
15 |
+
from .configuration_image import LanguageBindImageConfig, CLIPVisionConfig, CLIPTextConfig
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class PatchDropout(nn.Module):
|
20 |
+
"""
|
21 |
+
https://arxiv.org/abs/2212.00794
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, prob, exclude_first_token=True):
|
25 |
+
super().__init__()
|
26 |
+
assert 0 <= prob < 1.
|
27 |
+
self.prob = prob
|
28 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
29 |
+
|
30 |
+
def forward(self, x, B, T):
|
31 |
+
if not self.training or self.prob == 0.:
|
32 |
+
return x
|
33 |
+
|
34 |
+
if self.exclude_first_token:
|
35 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
36 |
+
else:
|
37 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
38 |
+
|
39 |
+
batch = x.size()[0]
|
40 |
+
num_tokens = x.size()[1]
|
41 |
+
|
42 |
+
batch_indices = torch.arange(batch)
|
43 |
+
batch_indices = batch_indices[..., None]
|
44 |
+
|
45 |
+
keep_prob = 1 - self.prob
|
46 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
47 |
+
|
48 |
+
if T == 1:
|
49 |
+
rand = torch.randn(batch, num_tokens)
|
50 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
51 |
+
else:
|
52 |
+
rand = torch.randn(B, num_tokens)
|
53 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
54 |
+
patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1)
|
55 |
+
patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n')
|
56 |
+
|
57 |
+
|
58 |
+
x = x[batch_indices, patch_indices_keep]
|
59 |
+
|
60 |
+
if self.exclude_first_token:
|
61 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
62 |
+
|
63 |
+
return x
|
64 |
+
|
65 |
+
class CLIPEncoderLayer(nn.Module):
|
66 |
+
def __init__(self, config: LanguageBindImageConfig):
|
67 |
+
super().__init__()
|
68 |
+
self.embed_dim = config.hidden_size
|
69 |
+
self.self_attn = CLIPAttention(config)
|
70 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
71 |
+
self.mlp = CLIPMLP(config)
|
72 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
73 |
+
|
74 |
+
self.add_time_attn = config.add_time_attn
|
75 |
+
if self.add_time_attn:
|
76 |
+
self.t = config.num_frames
|
77 |
+
self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
|
78 |
+
nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
|
79 |
+
|
80 |
+
self.embed_dim = config.hidden_size
|
81 |
+
self.temporal_attn = CLIPAttention(config)
|
82 |
+
self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
83 |
+
self.temporal_mlp = CLIPMLP(config)
|
84 |
+
self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
85 |
+
|
86 |
+
def forward(
|
87 |
+
self,
|
88 |
+
hidden_states: torch.Tensor,
|
89 |
+
attention_mask: torch.Tensor,
|
90 |
+
causal_attention_mask: torch.Tensor,
|
91 |
+
output_attentions: Optional[bool] = False,
|
92 |
+
) -> Tuple[torch.FloatTensor]:
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
96 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
97 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
98 |
+
`(config.encoder_attention_heads,)`.
|
99 |
+
output_attentions (`bool`, *optional*):
|
100 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
101 |
+
returned tensors for more detail.
|
102 |
+
"""
|
103 |
+
|
104 |
+
|
105 |
+
if self.add_time_attn:
|
106 |
+
bt, n, d = hidden_states.shape
|
107 |
+
t = self.t
|
108 |
+
|
109 |
+
# time embed
|
110 |
+
if t != 1:
|
111 |
+
n = hidden_states.shape[1]
|
112 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
113 |
+
hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
|
114 |
+
hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
115 |
+
|
116 |
+
# time attn
|
117 |
+
residual = hidden_states
|
118 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
119 |
+
# hidden_states = self.layer_norm1(hidden_states) # share layernorm
|
120 |
+
hidden_states = self.temporal_layer_norm1(hidden_states)
|
121 |
+
hidden_states, attn_weights = self.temporal_attn(
|
122 |
+
hidden_states=hidden_states,
|
123 |
+
attention_mask=attention_mask,
|
124 |
+
causal_attention_mask=causal_attention_mask,
|
125 |
+
output_attentions=output_attentions,
|
126 |
+
)
|
127 |
+
hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
128 |
+
|
129 |
+
residual = hidden_states
|
130 |
+
hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
|
131 |
+
# hidden_states = self.layer_norm2(hidden_states) # share layernorm
|
132 |
+
hidden_states = self.temporal_layer_norm2(hidden_states)
|
133 |
+
hidden_states = self.temporal_mlp(hidden_states)
|
134 |
+
hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
|
135 |
+
|
136 |
+
# spatial attn
|
137 |
+
residual = hidden_states
|
138 |
+
|
139 |
+
hidden_states = self.layer_norm1(hidden_states)
|
140 |
+
hidden_states, attn_weights = self.self_attn(
|
141 |
+
hidden_states=hidden_states,
|
142 |
+
attention_mask=attention_mask,
|
143 |
+
causal_attention_mask=causal_attention_mask,
|
144 |
+
output_attentions=output_attentions,
|
145 |
+
)
|
146 |
+
hidden_states = residual + hidden_states
|
147 |
+
|
148 |
+
residual = hidden_states
|
149 |
+
hidden_states = self.layer_norm2(hidden_states)
|
150 |
+
hidden_states = self.mlp(hidden_states)
|
151 |
+
hidden_states = residual + hidden_states
|
152 |
+
|
153 |
+
outputs = (hidden_states,)
|
154 |
+
|
155 |
+
if output_attentions:
|
156 |
+
outputs += (attn_weights,)
|
157 |
+
|
158 |
+
return outputs
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
class CLIPPreTrainedModel(PreTrainedModel):
|
169 |
+
"""
|
170 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
171 |
+
models.
|
172 |
+
"""
|
173 |
+
|
174 |
+
config_class = LanguageBindImageConfig
|
175 |
+
base_model_prefix = "clip"
|
176 |
+
supports_gradient_checkpointing = True
|
177 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
178 |
+
|
179 |
+
def _init_weights(self, module):
|
180 |
+
"""Initialize the weights"""
|
181 |
+
factor = self.config.initializer_factor
|
182 |
+
if isinstance(module, CLIPTextEmbeddings):
|
183 |
+
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
184 |
+
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
185 |
+
elif isinstance(module, CLIPVisionEmbeddings):
|
186 |
+
factor = self.config.initializer_factor
|
187 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
188 |
+
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
189 |
+
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
190 |
+
elif isinstance(module, CLIPAttention):
|
191 |
+
factor = self.config.initializer_factor
|
192 |
+
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
193 |
+
out_proj_std = (module.embed_dim**-0.5) * factor
|
194 |
+
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
|
195 |
+
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
|
196 |
+
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
|
197 |
+
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
198 |
+
elif isinstance(module, CLIPMLP):
|
199 |
+
factor = self.config.initializer_factor
|
200 |
+
in_proj_std = (
|
201 |
+
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
202 |
+
)
|
203 |
+
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
204 |
+
nn.init.normal_(module.fc1.weight, std=fc_std)
|
205 |
+
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
206 |
+
elif isinstance(module, LanguageBindImage):
|
207 |
+
nn.init.normal_(
|
208 |
+
module.text_projection.weight,
|
209 |
+
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
210 |
+
)
|
211 |
+
nn.init.normal_(
|
212 |
+
module.visual_projection.weight,
|
213 |
+
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
214 |
+
)
|
215 |
+
elif isinstance(module, CLIPVisionModelWithProjection):
|
216 |
+
nn.init.normal_(
|
217 |
+
module.visual_projection.weight,
|
218 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
219 |
+
)
|
220 |
+
elif isinstance(module, CLIPTextModelWithProjection):
|
221 |
+
nn.init.normal_(
|
222 |
+
module.text_projection.weight,
|
223 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
224 |
+
)
|
225 |
+
|
226 |
+
if isinstance(module, nn.LayerNorm):
|
227 |
+
module.bias.data.zero_()
|
228 |
+
module.weight.data.fill_(1.0)
|
229 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
230 |
+
module.bias.data.zero_()
|
231 |
+
|
232 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
233 |
+
if isinstance(module, CLIPEncoder):
|
234 |
+
module.gradient_checkpointing = value
|
235 |
+
|
236 |
+
|
237 |
+
CLIP_START_DOCSTRING = r"""
|
238 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
239 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
240 |
+
etc.)
|
241 |
+
|
242 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
243 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
244 |
+
and behavior.
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
248 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
249 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
250 |
+
"""
|
251 |
+
|
252 |
+
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
253 |
+
Args:
|
254 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
255 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
256 |
+
it.
|
257 |
+
|
258 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
259 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
260 |
+
|
261 |
+
[What are input IDs?](../glossary#input-ids)
|
262 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
263 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
264 |
+
|
265 |
+
- 1 for tokens that are **not masked**,
|
266 |
+
- 0 for tokens that are **masked**.
|
267 |
+
|
268 |
+
[What are attention masks?](../glossary#attention-mask)
|
269 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
270 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
271 |
+
config.max_position_embeddings - 1]`.
|
272 |
+
|
273 |
+
[What are position IDs?](../glossary#position-ids)
|
274 |
+
output_attentions (`bool`, *optional*):
|
275 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
276 |
+
tensors for more detail.
|
277 |
+
output_hidden_states (`bool`, *optional*):
|
278 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
279 |
+
more detail.
|
280 |
+
return_dict (`bool`, *optional*):
|
281 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
282 |
+
"""
|
283 |
+
|
284 |
+
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
285 |
+
Args:
|
286 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
287 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
288 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
289 |
+
output_attentions (`bool`, *optional*):
|
290 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
291 |
+
tensors for more detail.
|
292 |
+
output_hidden_states (`bool`, *optional*):
|
293 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
294 |
+
more detail.
|
295 |
+
return_dict (`bool`, *optional*):
|
296 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
297 |
+
"""
|
298 |
+
|
299 |
+
CLIP_INPUTS_DOCSTRING = r"""
|
300 |
+
Args:
|
301 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
302 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
303 |
+
it.
|
304 |
+
|
305 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
306 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
307 |
+
|
308 |
+
[What are input IDs?](../glossary#input-ids)
|
309 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
310 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
311 |
+
|
312 |
+
- 1 for tokens that are **not masked**,
|
313 |
+
- 0 for tokens that are **masked**.
|
314 |
+
|
315 |
+
[What are attention masks?](../glossary#attention-mask)
|
316 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
317 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
318 |
+
config.max_position_embeddings - 1]`.
|
319 |
+
|
320 |
+
[What are position IDs?](../glossary#position-ids)
|
321 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
322 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
323 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
324 |
+
return_loss (`bool`, *optional*):
|
325 |
+
Whether or not to return the contrastive loss.
|
326 |
+
output_attentions (`bool`, *optional*):
|
327 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
328 |
+
tensors for more detail.
|
329 |
+
output_hidden_states (`bool`, *optional*):
|
330 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
331 |
+
more detail.
|
332 |
+
return_dict (`bool`, *optional*):
|
333 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
334 |
+
"""
|
335 |
+
|
336 |
+
|
337 |
+
class CLIPEncoder(nn.Module):
|
338 |
+
"""
|
339 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
340 |
+
[`CLIPEncoderLayer`].
|
341 |
+
|
342 |
+
Args:
|
343 |
+
config: CLIPConfig
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(self, config: LanguageBindImageConfig):
|
347 |
+
super().__init__()
|
348 |
+
self.config = config
|
349 |
+
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
350 |
+
self.gradient_checkpointing = False
|
351 |
+
|
352 |
+
def forward(
|
353 |
+
self,
|
354 |
+
inputs_embeds,
|
355 |
+
attention_mask: Optional[torch.Tensor] = None,
|
356 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
357 |
+
output_attentions: Optional[bool] = None,
|
358 |
+
output_hidden_states: Optional[bool] = None,
|
359 |
+
return_dict: Optional[bool] = None,
|
360 |
+
) -> Union[Tuple, BaseModelOutput]:
|
361 |
+
r"""
|
362 |
+
Args:
|
363 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
364 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
365 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
366 |
+
than the model's internal embedding lookup matrix.
|
367 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
368 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
369 |
+
|
370 |
+
- 1 for tokens that are **not masked**,
|
371 |
+
- 0 for tokens that are **masked**.
|
372 |
+
|
373 |
+
[What are attention masks?](../glossary#attention-mask)
|
374 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
375 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
376 |
+
|
377 |
+
- 1 for tokens that are **not masked**,
|
378 |
+
- 0 for tokens that are **masked**.
|
379 |
+
|
380 |
+
[What are attention masks?](../glossary#attention-mask)
|
381 |
+
output_attentions (`bool`, *optional*):
|
382 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
383 |
+
returned tensors for more detail.
|
384 |
+
output_hidden_states (`bool`, *optional*):
|
385 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
386 |
+
for more detail.
|
387 |
+
return_dict (`bool`, *optional*):
|
388 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
389 |
+
"""
|
390 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
391 |
+
output_hidden_states = (
|
392 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
393 |
+
)
|
394 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
395 |
+
|
396 |
+
encoder_states = () if output_hidden_states else None
|
397 |
+
all_attentions = () if output_attentions else None
|
398 |
+
|
399 |
+
hidden_states = inputs_embeds
|
400 |
+
for idx, encoder_layer in enumerate(self.layers):
|
401 |
+
if output_hidden_states:
|
402 |
+
encoder_states = encoder_states + (hidden_states,)
|
403 |
+
if self.gradient_checkpointing and self.training:
|
404 |
+
|
405 |
+
def create_custom_forward(module):
|
406 |
+
def custom_forward(*inputs):
|
407 |
+
return module(*inputs, output_attentions)
|
408 |
+
|
409 |
+
return custom_forward
|
410 |
+
|
411 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
412 |
+
create_custom_forward(encoder_layer),
|
413 |
+
hidden_states,
|
414 |
+
attention_mask,
|
415 |
+
causal_attention_mask,
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
layer_outputs = encoder_layer(
|
419 |
+
hidden_states,
|
420 |
+
attention_mask,
|
421 |
+
causal_attention_mask,
|
422 |
+
output_attentions=output_attentions,
|
423 |
+
)
|
424 |
+
|
425 |
+
hidden_states = layer_outputs[0]
|
426 |
+
|
427 |
+
if output_attentions:
|
428 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
429 |
+
|
430 |
+
if output_hidden_states:
|
431 |
+
encoder_states = encoder_states + (hidden_states,)
|
432 |
+
|
433 |
+
if not return_dict:
|
434 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
435 |
+
return BaseModelOutput(
|
436 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
437 |
+
)
|
438 |
+
|
439 |
+
|
440 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
441 |
+
def _make_causal_mask(
|
442 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
443 |
+
):
|
444 |
+
"""
|
445 |
+
Make causal mask used for bi-directional self-attention.
|
446 |
+
"""
|
447 |
+
bsz, tgt_len = input_ids_shape
|
448 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
449 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
450 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
451 |
+
mask = mask.to(dtype)
|
452 |
+
|
453 |
+
if past_key_values_length > 0:
|
454 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
455 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
456 |
+
|
457 |
+
|
458 |
+
class CLIPTextTransformer(nn.Module):
|
459 |
+
def __init__(self, config: CLIPTextConfig):
|
460 |
+
super().__init__()
|
461 |
+
self.config = config
|
462 |
+
embed_dim = config.hidden_size
|
463 |
+
self.embeddings = CLIPTextEmbeddings(config)
|
464 |
+
self.encoder = CLIPEncoder(config)
|
465 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
466 |
+
|
467 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
468 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
469 |
+
def forward(
|
470 |
+
self,
|
471 |
+
input_ids: Optional[torch.Tensor] = None,
|
472 |
+
attention_mask: Optional[torch.Tensor] = None,
|
473 |
+
position_ids: Optional[torch.Tensor] = None,
|
474 |
+
output_attentions: Optional[bool] = None,
|
475 |
+
output_hidden_states: Optional[bool] = None,
|
476 |
+
return_dict: Optional[bool] = None,
|
477 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
478 |
+
r"""
|
479 |
+
Returns:
|
480 |
+
|
481 |
+
"""
|
482 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
483 |
+
output_hidden_states = (
|
484 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
485 |
+
)
|
486 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
487 |
+
|
488 |
+
if input_ids is None:
|
489 |
+
raise ValueError("You have to specify input_ids")
|
490 |
+
|
491 |
+
input_shape = input_ids.size()
|
492 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
493 |
+
|
494 |
+
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
495 |
+
|
496 |
+
# CLIP's text model uses causal mask, prepare it here.
|
497 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
498 |
+
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
|
499 |
+
# expand attention_mask
|
500 |
+
if attention_mask is not None:
|
501 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
502 |
+
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
503 |
+
|
504 |
+
encoder_outputs = self.encoder(
|
505 |
+
inputs_embeds=hidden_states,
|
506 |
+
attention_mask=attention_mask,
|
507 |
+
causal_attention_mask=causal_attention_mask,
|
508 |
+
output_attentions=output_attentions,
|
509 |
+
output_hidden_states=output_hidden_states,
|
510 |
+
return_dict=return_dict,
|
511 |
+
)
|
512 |
+
|
513 |
+
last_hidden_state = encoder_outputs[0]
|
514 |
+
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
515 |
+
|
516 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
517 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
518 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
519 |
+
pooled_output = last_hidden_state[
|
520 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
521 |
+
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
522 |
+
]
|
523 |
+
|
524 |
+
if not return_dict:
|
525 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
526 |
+
|
527 |
+
return BaseModelOutputWithPooling(
|
528 |
+
last_hidden_state=last_hidden_state,
|
529 |
+
pooler_output=pooled_output,
|
530 |
+
hidden_states=encoder_outputs.hidden_states,
|
531 |
+
attentions=encoder_outputs.attentions,
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
@add_start_docstrings(
|
536 |
+
"""The text model from CLIP without any head or projection on top.""",
|
537 |
+
CLIP_START_DOCSTRING,
|
538 |
+
)
|
539 |
+
class CLIPTextModel(CLIPPreTrainedModel):
|
540 |
+
config_class = CLIPTextConfig
|
541 |
+
|
542 |
+
_no_split_modules = ["CLIPEncoderLayer"]
|
543 |
+
|
544 |
+
def __init__(self, config: CLIPTextConfig):
|
545 |
+
super().__init__(config)
|
546 |
+
self.text_model = CLIPTextTransformer(config)
|
547 |
+
# Initialize weights and apply final processing
|
548 |
+
self.post_init()
|
549 |
+
|
550 |
+
def get_input_embeddings(self) -> nn.Module:
|
551 |
+
return self.text_model.embeddings.token_embedding
|
552 |
+
|
553 |
+
def set_input_embeddings(self, value):
|
554 |
+
self.text_model.embeddings.token_embedding = value
|
555 |
+
|
556 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
557 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
558 |
+
def forward(
|
559 |
+
self,
|
560 |
+
input_ids: Optional[torch.Tensor] = None,
|
561 |
+
attention_mask: Optional[torch.Tensor] = None,
|
562 |
+
position_ids: Optional[torch.Tensor] = None,
|
563 |
+
output_attentions: Optional[bool] = None,
|
564 |
+
output_hidden_states: Optional[bool] = None,
|
565 |
+
return_dict: Optional[bool] = None,
|
566 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
567 |
+
r"""
|
568 |
+
Returns:
|
569 |
+
|
570 |
+
Examples:
|
571 |
+
|
572 |
+
```python
|
573 |
+
>>> from transformers import AutoTokenizer, CLIPTextModel
|
574 |
+
|
575 |
+
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
576 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
577 |
+
|
578 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
579 |
+
|
580 |
+
>>> outputs = model(**inputs)
|
581 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
582 |
+
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
583 |
+
```"""
|
584 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
585 |
+
|
586 |
+
return self.text_model(
|
587 |
+
input_ids=input_ids,
|
588 |
+
attention_mask=attention_mask,
|
589 |
+
position_ids=position_ids,
|
590 |
+
output_attentions=output_attentions,
|
591 |
+
output_hidden_states=output_hidden_states,
|
592 |
+
return_dict=return_dict,
|
593 |
+
)
|
594 |
+
|
595 |
+
|
596 |
+
class CLIPVisionTransformer(nn.Module):
|
597 |
+
def __init__(self, config: CLIPVisionConfig):
|
598 |
+
super().__init__()
|
599 |
+
self.config = config
|
600 |
+
embed_dim = config.hidden_size
|
601 |
+
|
602 |
+
self.embeddings = CLIPVisionEmbeddings(config)
|
603 |
+
self.patch_dropout = PatchDropout(config.force_patch_dropout)
|
604 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
605 |
+
self.encoder = CLIPEncoder(config)
|
606 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
607 |
+
|
608 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
609 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
610 |
+
def forward(
|
611 |
+
self,
|
612 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
613 |
+
output_attentions: Optional[bool] = None,
|
614 |
+
output_hidden_states: Optional[bool] = None,
|
615 |
+
return_dict: Optional[bool] = None,
|
616 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
617 |
+
r"""
|
618 |
+
Returns:
|
619 |
+
|
620 |
+
"""
|
621 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
622 |
+
output_hidden_states = (
|
623 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
624 |
+
)
|
625 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
626 |
+
|
627 |
+
if pixel_values is None:
|
628 |
+
raise ValueError("You have to specify pixel_values")
|
629 |
+
######################################
|
630 |
+
if len(pixel_values.shape) == 7:
|
631 |
+
b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
|
632 |
+
# print(pixel_values.shape)
|
633 |
+
B = b_new * pair_new * bs_new
|
634 |
+
pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new)
|
635 |
+
|
636 |
+
elif len(pixel_values.shape) == 5:
|
637 |
+
B, _, T, _, _ = pixel_values.shape
|
638 |
+
# print(pixel_values.shape)
|
639 |
+
pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
|
640 |
+
else:
|
641 |
+
# print(pixel_values.shape)
|
642 |
+
B, _, _, _ = pixel_values.shape
|
643 |
+
T = 1
|
644 |
+
###########################
|
645 |
+
hidden_states = self.embeddings(pixel_values)
|
646 |
+
|
647 |
+
hidden_states = self.patch_dropout(hidden_states, B, T) ##############################################
|
648 |
+
|
649 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
650 |
+
|
651 |
+
encoder_outputs = self.encoder(
|
652 |
+
inputs_embeds=hidden_states,
|
653 |
+
output_attentions=output_attentions,
|
654 |
+
output_hidden_states=output_hidden_states,
|
655 |
+
return_dict=return_dict,
|
656 |
+
)
|
657 |
+
|
658 |
+
last_hidden_state = encoder_outputs[0]
|
659 |
+
pooled_output = last_hidden_state[:, 0, :]
|
660 |
+
pooled_output = self.post_layernorm(pooled_output)
|
661 |
+
|
662 |
+
pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################
|
663 |
+
|
664 |
+
if not return_dict:
|
665 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
666 |
+
|
667 |
+
return BaseModelOutputWithPooling(
|
668 |
+
last_hidden_state=last_hidden_state,
|
669 |
+
pooler_output=pooled_output,
|
670 |
+
hidden_states=encoder_outputs.hidden_states,
|
671 |
+
attentions=encoder_outputs.attentions,
|
672 |
+
)
|
673 |
+
|
674 |
+
|
675 |
+
@add_start_docstrings(
|
676 |
+
"""The vision model from CLIP without any head or projection on top.""",
|
677 |
+
CLIP_START_DOCSTRING,
|
678 |
+
)
|
679 |
+
class CLIPVisionModel(CLIPPreTrainedModel):
|
680 |
+
config_class = CLIPVisionConfig
|
681 |
+
main_input_name = "pixel_values"
|
682 |
+
|
683 |
+
def __init__(self, config: CLIPVisionConfig):
|
684 |
+
super().__init__(config)
|
685 |
+
self.vision_model = CLIPVisionTransformer(config)
|
686 |
+
# Initialize weights and apply final processing
|
687 |
+
self.post_init()
|
688 |
+
|
689 |
+
def get_input_embeddings(self) -> nn.Module:
|
690 |
+
return self.vision_model.embeddings.patch_embedding
|
691 |
+
|
692 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
693 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
694 |
+
def forward(
|
695 |
+
self,
|
696 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
697 |
+
output_attentions: Optional[bool] = None,
|
698 |
+
output_hidden_states: Optional[bool] = None,
|
699 |
+
return_dict: Optional[bool] = None,
|
700 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
701 |
+
r"""
|
702 |
+
Returns:
|
703 |
+
|
704 |
+
Examples:
|
705 |
+
|
706 |
+
```python
|
707 |
+
>>> from PIL import Image
|
708 |
+
>>> import requests
|
709 |
+
>>> from transformers import AutoProcessor, CLIPVisionModel
|
710 |
+
|
711 |
+
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
712 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
713 |
+
|
714 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
715 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
716 |
+
|
717 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
718 |
+
|
719 |
+
>>> outputs = model(**inputs)
|
720 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
721 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
722 |
+
```"""
|
723 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
724 |
+
|
725 |
+
return self.vision_model(
|
726 |
+
pixel_values=pixel_values,
|
727 |
+
output_attentions=output_attentions,
|
728 |
+
output_hidden_states=output_hidden_states,
|
729 |
+
return_dict=return_dict,
|
730 |
+
)
|
731 |
+
|
732 |
+
|
733 |
+
@add_start_docstrings(CLIP_START_DOCSTRING)
|
734 |
+
class LanguageBindImage(CLIPPreTrainedModel):
|
735 |
+
config_class = LanguageBindImageConfig
|
736 |
+
|
737 |
+
def __init__(self, config: LanguageBindImageConfig):
|
738 |
+
super().__init__(config)
|
739 |
+
|
740 |
+
if not isinstance(config.text_config, CLIPTextConfig):
|
741 |
+
raise ValueError(
|
742 |
+
"config.text_config is expected to be of type CLIPTextConfig but is of type"
|
743 |
+
f" {type(config.text_config)}."
|
744 |
+
)
|
745 |
+
|
746 |
+
if not isinstance(config.vision_config, CLIPVisionConfig):
|
747 |
+
raise ValueError(
|
748 |
+
"config.vision_config is expected to be of type CLIPVisionConfig but is of type"
|
749 |
+
f" {type(config.vision_config)}."
|
750 |
+
)
|
751 |
+
|
752 |
+
text_config = config.text_config
|
753 |
+
vision_config = config.vision_config
|
754 |
+
self.add_time_attn = vision_config.add_time_attn
|
755 |
+
self.lora_r = vision_config.lora_r
|
756 |
+
self.lora_alpha = vision_config.lora_alpha
|
757 |
+
self.lora_dropout = vision_config.lora_dropout
|
758 |
+
|
759 |
+
self.projection_dim = config.projection_dim
|
760 |
+
self.text_embed_dim = text_config.hidden_size
|
761 |
+
self.vision_embed_dim = vision_config.hidden_size
|
762 |
+
|
763 |
+
self.text_model = CLIPTextTransformer(text_config)
|
764 |
+
self.vision_model = CLIPVisionTransformer(vision_config)
|
765 |
+
|
766 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
767 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
768 |
+
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
|
769 |
+
|
770 |
+
# Initialize weights and apply final processing
|
771 |
+
self.post_init()
|
772 |
+
self.convert_to_lora()
|
773 |
+
# self.resize_pos(self.vision_model.embeddings, vision_config)
|
774 |
+
|
775 |
+
def convert_to_lora(self):
|
776 |
+
if self.lora_r == 0:
|
777 |
+
return
|
778 |
+
if self.add_time_attn:
|
779 |
+
target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
|
780 |
+
"temporal_attn.q_proj", "temporal_attn.out_proj",
|
781 |
+
"temporal_mlp.fc1", "temporal_mlp.fc2"]
|
782 |
+
else:
|
783 |
+
target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
|
784 |
+
config = LoraConfig(
|
785 |
+
r=self.lora_r, # 16
|
786 |
+
lora_alpha=self.lora_alpha, # 16
|
787 |
+
target_modules=target_modules, # self_attn.out_proj
|
788 |
+
lora_dropout=self.lora_dropout, # 0.1
|
789 |
+
bias="none",
|
790 |
+
modules_to_save=[],
|
791 |
+
)
|
792 |
+
self.vision_model.encoder.is_gradient_checkpointing = False
|
793 |
+
self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
|
794 |
+
|
795 |
+
def resize_pos(self, m, vision_config):
|
796 |
+
# convert embedding
|
797 |
+
if vision_config.num_mel_bins!=0 and vision_config.target_length!=0:
|
798 |
+
m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
|
799 |
+
m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
|
800 |
+
# pos resize
|
801 |
+
old_pos_embed_state_dict = m.position_embedding.state_dict()
|
802 |
+
old_pos_embed = old_pos_embed_state_dict['weight']
|
803 |
+
dtype = old_pos_embed.dtype
|
804 |
+
grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
|
805 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
806 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
807 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
808 |
+
# m.to(args.device)
|
809 |
+
return
|
810 |
+
|
811 |
+
m.num_patches = grid_size[0] * grid_size[1]
|
812 |
+
m.num_positions = m.num_patches + 1
|
813 |
+
m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
|
814 |
+
new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
|
815 |
+
|
816 |
+
if extra_tokens:
|
817 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
818 |
+
else:
|
819 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
820 |
+
old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
|
821 |
+
|
822 |
+
# if is_master(args):
|
823 |
+
# logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
824 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
825 |
+
pos_emb_img = F.interpolate(
|
826 |
+
pos_emb_img,
|
827 |
+
size=grid_size,
|
828 |
+
mode='bicubic',
|
829 |
+
antialias=True,
|
830 |
+
align_corners=False,
|
831 |
+
)
|
832 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
833 |
+
if pos_emb_tok is not None:
|
834 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
835 |
+
else:
|
836 |
+
new_pos_embed = pos_emb_img
|
837 |
+
old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
|
838 |
+
m.position_embedding = new_position_embedding
|
839 |
+
m.position_embedding.load_state_dict(old_pos_embed_state_dict)
|
840 |
+
|
841 |
+
# m.to(args.device)
|
842 |
+
|
843 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
844 |
+
def get_text_features(
|
845 |
+
self,
|
846 |
+
input_ids: Optional[torch.Tensor] = None,
|
847 |
+
attention_mask: Optional[torch.Tensor] = None,
|
848 |
+
position_ids: Optional[torch.Tensor] = None,
|
849 |
+
output_attentions: Optional[bool] = None,
|
850 |
+
output_hidden_states: Optional[bool] = None,
|
851 |
+
return_dict: Optional[bool] = None,
|
852 |
+
) -> torch.FloatTensor:
|
853 |
+
r"""
|
854 |
+
Returns:
|
855 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
856 |
+
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
857 |
+
|
858 |
+
Examples:
|
859 |
+
|
860 |
+
```python
|
861 |
+
>>> from transformers import AutoTokenizer, CLIPModel
|
862 |
+
|
863 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
864 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
865 |
+
|
866 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
867 |
+
>>> text_features = model.get_text_features(**inputs)
|
868 |
+
```"""
|
869 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
870 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
871 |
+
output_hidden_states = (
|
872 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
873 |
+
)
|
874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
875 |
+
|
876 |
+
text_outputs = self.text_model(
|
877 |
+
input_ids=input_ids,
|
878 |
+
attention_mask=attention_mask,
|
879 |
+
position_ids=position_ids,
|
880 |
+
output_attentions=output_attentions,
|
881 |
+
output_hidden_states=output_hidden_states,
|
882 |
+
return_dict=return_dict,
|
883 |
+
)
|
884 |
+
|
885 |
+
pooled_output = text_outputs[1]
|
886 |
+
text_features = self.text_projection(pooled_output)
|
887 |
+
|
888 |
+
return text_features
|
889 |
+
|
890 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
891 |
+
def get_image_features(
|
892 |
+
self,
|
893 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
894 |
+
output_attentions: Optional[bool] = None,
|
895 |
+
output_hidden_states: Optional[bool] = None,
|
896 |
+
return_dict: Optional[bool] = None,
|
897 |
+
) -> torch.FloatTensor:
|
898 |
+
r"""
|
899 |
+
Returns:
|
900 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
901 |
+
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
902 |
+
|
903 |
+
Examples:
|
904 |
+
|
905 |
+
```python
|
906 |
+
>>> from PIL import Image
|
907 |
+
>>> import requests
|
908 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
909 |
+
|
910 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
911 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
912 |
+
|
913 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
914 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
915 |
+
|
916 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
917 |
+
|
918 |
+
>>> image_features = model.get_image_features(**inputs)
|
919 |
+
```"""
|
920 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
921 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
922 |
+
output_hidden_states = (
|
923 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
924 |
+
)
|
925 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
926 |
+
|
927 |
+
vision_outputs = self.vision_model(
|
928 |
+
pixel_values=pixel_values,
|
929 |
+
output_attentions=output_attentions,
|
930 |
+
output_hidden_states=output_hidden_states,
|
931 |
+
return_dict=return_dict,
|
932 |
+
)
|
933 |
+
|
934 |
+
pooled_output = vision_outputs[1] # pooled_output
|
935 |
+
image_features = self.visual_projection(pooled_output)
|
936 |
+
|
937 |
+
return image_features
|
938 |
+
|
939 |
+
@add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
|
940 |
+
@replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindImageConfig)
|
941 |
+
def forward(
|
942 |
+
self,
|
943 |
+
input_ids: Optional[torch.LongTensor] = None,
|
944 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
945 |
+
attention_mask: Optional[torch.Tensor] = None,
|
946 |
+
position_ids: Optional[torch.LongTensor] = None,
|
947 |
+
return_loss: Optional[bool] = None,
|
948 |
+
output_attentions: Optional[bool] = None,
|
949 |
+
output_hidden_states: Optional[bool] = None,
|
950 |
+
return_dict: Optional[bool] = None,
|
951 |
+
) -> Union[Tuple, CLIPOutput]:
|
952 |
+
r"""
|
953 |
+
Returns:
|
954 |
+
|
955 |
+
Examples:
|
956 |
+
|
957 |
+
```python
|
958 |
+
>>> from PIL import Image
|
959 |
+
>>> import requests
|
960 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
961 |
+
|
962 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
963 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
964 |
+
|
965 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
966 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
967 |
+
|
968 |
+
>>> inputs = processor(
|
969 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
970 |
+
... )
|
971 |
+
|
972 |
+
>>> outputs = model(**inputs)
|
973 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
974 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
975 |
+
```"""
|
976 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
977 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
978 |
+
output_hidden_states = (
|
979 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
980 |
+
)
|
981 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
982 |
+
|
983 |
+
vision_outputs = self.vision_model(
|
984 |
+
pixel_values=pixel_values,
|
985 |
+
output_attentions=output_attentions,
|
986 |
+
output_hidden_states=output_hidden_states,
|
987 |
+
return_dict=return_dict,
|
988 |
+
)
|
989 |
+
|
990 |
+
text_outputs = self.text_model(
|
991 |
+
input_ids=input_ids,
|
992 |
+
attention_mask=attention_mask,
|
993 |
+
position_ids=position_ids,
|
994 |
+
output_attentions=output_attentions,
|
995 |
+
output_hidden_states=output_hidden_states,
|
996 |
+
return_dict=return_dict,
|
997 |
+
)
|
998 |
+
|
999 |
+
image_embeds = vision_outputs[1]
|
1000 |
+
image_embeds = self.visual_projection(image_embeds)
|
1001 |
+
|
1002 |
+
text_embeds = text_outputs[1]
|
1003 |
+
text_embeds = self.text_projection(text_embeds)
|
1004 |
+
|
1005 |
+
# normalized features
|
1006 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
1007 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
1008 |
+
|
1009 |
+
# cosine similarity as logits
|
1010 |
+
logit_scale = self.logit_scale.exp()
|
1011 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
1012 |
+
logits_per_image = logits_per_text.t()
|
1013 |
+
|
1014 |
+
loss = None
|
1015 |
+
if return_loss:
|
1016 |
+
loss = clip_loss(logits_per_text)
|
1017 |
+
|
1018 |
+
if not return_dict:
|
1019 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
1020 |
+
return ((loss,) + output) if loss is not None else output
|
1021 |
+
|
1022 |
+
return CLIPOutput(
|
1023 |
+
loss=loss,
|
1024 |
+
logits_per_image=logits_per_image,
|
1025 |
+
logits_per_text=logits_per_text,
|
1026 |
+
text_embeds=text_embeds,
|
1027 |
+
image_embeds=image_embeds,
|
1028 |
+
text_model_output=text_outputs,
|
1029 |
+
vision_model_output=vision_outputs,
|
1030 |
+
)
|
models/multimodal_encoder/languagebind/image/processing_image.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from torchvision import transforms
|
4 |
+
from transformers import ProcessorMixin, BatchEncoding
|
5 |
+
from transformers.image_processing_utils import BatchFeature
|
6 |
+
|
7 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
8 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
9 |
+
|
10 |
+
def make_list_of_images(x):
|
11 |
+
if not isinstance(x, list):
|
12 |
+
return [x]
|
13 |
+
return x
|
14 |
+
|
15 |
+
def get_image_transform(config):
|
16 |
+
config = config.vision_config
|
17 |
+
transform = transforms.Compose(
|
18 |
+
[
|
19 |
+
transforms.ToTensor(),
|
20 |
+
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
|
21 |
+
transforms.CenterCrop(224),
|
22 |
+
transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) # assume image
|
23 |
+
]
|
24 |
+
)
|
25 |
+
return transform
|
26 |
+
|
27 |
+
|
28 |
+
def load_and_transform_image(image_path, transform):
|
29 |
+
image = Image.open(image_path).convert('RGB') if isinstance(image_path, str) else image_path
|
30 |
+
image_outputs = transform(image)
|
31 |
+
return image_outputs
|
32 |
+
|
33 |
+
class LanguageBindImageProcessor(ProcessorMixin):
|
34 |
+
attributes = []
|
35 |
+
tokenizer_class = ("LanguageBindImageTokenizer")
|
36 |
+
|
37 |
+
def __init__(self, config, tokenizer=None, **kwargs):
|
38 |
+
super().__init__(**kwargs)
|
39 |
+
self.config = config
|
40 |
+
self.transform = get_image_transform(config)
|
41 |
+
self.image_processor = load_and_transform_image
|
42 |
+
self.tokenizer = tokenizer
|
43 |
+
self.image_mean = OPENAI_DATASET_MEAN
|
44 |
+
self.crop_size = {'height': 224, 'width': 224}
|
45 |
+
|
46 |
+
def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
|
47 |
+
if text is None and images is None:
|
48 |
+
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
49 |
+
|
50 |
+
if text is not None:
|
51 |
+
encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
|
52 |
+
truncation=True, return_tensors=return_tensors, **kwargs)
|
53 |
+
|
54 |
+
if images is not None:
|
55 |
+
images = make_list_of_images(images)
|
56 |
+
image_features = [self.image_processor(image, self.transform) for image in images]
|
57 |
+
image_features = torch.stack(image_features)
|
58 |
+
|
59 |
+
if text is not None and images is not None:
|
60 |
+
encoding["pixel_values"] = image_features
|
61 |
+
return encoding
|
62 |
+
elif text is not None:
|
63 |
+
return encoding
|
64 |
+
else:
|
65 |
+
return {"pixel_values": image_features}
|
66 |
+
|
67 |
+
def preprocess(self, images, return_tensors):
|
68 |
+
return self.__call__(images=images, return_tensors=return_tensors)
|
69 |
+
|
70 |
+
def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
|
71 |
+
"""
|
72 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
73 |
+
refer to the docstring of this method for more information.
|
74 |
+
"""
|
75 |
+
return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
|
76 |
+
|
77 |
+
def decode(self, skip_special_tokens=True, *args, **kwargs):
|
78 |
+
"""
|
79 |
+
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
80 |
+
the docstring of this method for more information.
|
81 |
+
"""
|
82 |
+
return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
|
models/multimodal_encoder/languagebind/image/tokenization_image.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPTokenizer
|
2 |
+
from transformers.utils import logging
|
3 |
+
|
4 |
+
logger = logging.get_logger(__name__)
|
5 |
+
|
6 |
+
VOCAB_FILES_NAMES = {
|
7 |
+
"vocab_file": "vocab.json",
|
8 |
+
"merges_file": "merges.txt",
|
9 |
+
}
|
10 |
+
|
11 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
12 |
+
"vocab_file": {
|
13 |
+
"lb203/LanguageBind-Image": "https://huggingface.co/lb203/LanguageBind-Image/resolve/main/vocab.json",
|
14 |
+
},
|
15 |
+
"merges_file": {
|
16 |
+
"lb203/LanguageBind-Image": "https://huggingface.co/lb203/LanguageBind-Image/resolve/main/merges.txt",
|
17 |
+
},
|
18 |
+
}
|
19 |
+
|
20 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
21 |
+
"lb203/LanguageBind-Image": 77,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
PRETRAINED_INIT_CONFIGURATION = {
|
26 |
+
"lb203/LanguageBind-Image": {},
|
27 |
+
}
|
28 |
+
|
29 |
+
class LanguageBindImageTokenizer(CLIPTokenizer):
|
30 |
+
"""
|
31 |
+
Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
|
32 |
+
|
33 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
34 |
+
this superclass for more information regarding those methods.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_file (`str`):
|
38 |
+
Path to the vocabulary file.
|
39 |
+
merges_file (`str`):
|
40 |
+
Path to the merges file.
|
41 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
42 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
43 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
44 |
+
unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
45 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
46 |
+
token instead.
|
47 |
+
bos_token (`str`, *optional*, defaults to `<|startoftext|>`):
|
48 |
+
The beginning of sequence token.
|
49 |
+
eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
50 |
+
The end of sequence token.
|
51 |
+
"""
|
52 |
+
|
53 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
54 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
55 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
56 |
+
model_input_names = ["input_ids", "attention_mask"]
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
vocab_file,
|
61 |
+
merges_file,
|
62 |
+
errors="replace",
|
63 |
+
unk_token="<|endoftext|>",
|
64 |
+
bos_token="<|startoftext|>",
|
65 |
+
eos_token="<|endoftext|>",
|
66 |
+
pad_token="<|endoftext|>", # hack to enable padding
|
67 |
+
**kwargs,
|
68 |
+
):
|
69 |
+
super(LanguageBindImageTokenizer, self).__init__(
|
70 |
+
vocab_file,
|
71 |
+
merges_file,
|
72 |
+
errors,
|
73 |
+
unk_token,
|
74 |
+
bos_token,
|
75 |
+
eos_token,
|
76 |
+
pad_token, # hack to enable padding
|
77 |
+
**kwargs,)
|
models/multimodal_encoder/languagebind/thermal/configuration_thermal.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
from transformers.utils import logging
|
7 |
+
|
8 |
+
logger = logging.get_logger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
class CLIPTextConfig(PretrainedConfig):
|
17 |
+
r"""
|
18 |
+
This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
|
19 |
+
text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
|
20 |
+
with the defaults will yield a similar configuration to that of the text encoder of the CLIP
|
21 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
22 |
+
|
23 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
24 |
+
documentation from [`PretrainedConfig`] for more information.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
vocab_size (`int`, *optional*, defaults to 49408):
|
28 |
+
Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
|
29 |
+
the `inputs_ids` passed when calling [`CLIPModel`].
|
30 |
+
hidden_size (`int`, *optional*, defaults to 512):
|
31 |
+
Dimensionality of the encoder layers and the pooler layer.
|
32 |
+
intermediate_size (`int`, *optional*, defaults to 2048):
|
33 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
34 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
35 |
+
Number of hidden layers in the Transformer encoder.
|
36 |
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
37 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
38 |
+
max_position_embeddings (`int`, *optional*, defaults to 77):
|
39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
42 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
43 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
44 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
45 |
+
The epsilon used by the layer normalization layers.
|
46 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
47 |
+
The dropout ratio for the attention probabilities.
|
48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
49 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
50 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
51 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
52 |
+
testing).
|
53 |
+
|
54 |
+
Example:
|
55 |
+
|
56 |
+
```python
|
57 |
+
>>> from transformers import CLIPTextConfig, CLIPTextModel
|
58 |
+
|
59 |
+
>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
|
60 |
+
>>> configuration = CLIPTextConfig()
|
61 |
+
|
62 |
+
>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
63 |
+
>>> model = CLIPTextModel(configuration)
|
64 |
+
|
65 |
+
>>> # Accessing the model configuration
|
66 |
+
>>> configuration = model.config
|
67 |
+
```"""
|
68 |
+
model_type = "clip_text_model"
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
vocab_size=49408,
|
73 |
+
hidden_size=512,
|
74 |
+
intermediate_size=2048,
|
75 |
+
projection_dim=512,
|
76 |
+
num_hidden_layers=12,
|
77 |
+
num_attention_heads=8,
|
78 |
+
max_position_embeddings=77,
|
79 |
+
hidden_act="quick_gelu",
|
80 |
+
layer_norm_eps=1e-5,
|
81 |
+
attention_dropout=0.0,
|
82 |
+
initializer_range=0.02,
|
83 |
+
initializer_factor=1.0,
|
84 |
+
# This differs from `CLIPTokenizer`'s default and from openai/clip
|
85 |
+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
86 |
+
pad_token_id=1,
|
87 |
+
bos_token_id=49406,
|
88 |
+
eos_token_id=49407,
|
89 |
+
**kwargs,
|
90 |
+
):
|
91 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
92 |
+
|
93 |
+
self.vocab_size = vocab_size
|
94 |
+
self.hidden_size = hidden_size
|
95 |
+
self.intermediate_size = intermediate_size
|
96 |
+
self.projection_dim = projection_dim
|
97 |
+
self.num_hidden_layers = num_hidden_layers
|
98 |
+
self.num_attention_heads = num_attention_heads
|
99 |
+
self.max_position_embeddings = max_position_embeddings
|
100 |
+
self.layer_norm_eps = layer_norm_eps
|
101 |
+
self.hidden_act = hidden_act
|
102 |
+
self.initializer_range = initializer_range
|
103 |
+
self.initializer_factor = initializer_factor
|
104 |
+
self.attention_dropout = attention_dropout
|
105 |
+
self.add_time_attn = False ######################################
|
106 |
+
|
107 |
+
@classmethod
|
108 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
109 |
+
cls._set_token_in_kwargs(kwargs)
|
110 |
+
|
111 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
112 |
+
|
113 |
+
# get the text config dict if we are loading from CLIPConfig
|
114 |
+
if config_dict.get("model_type") == "clip":
|
115 |
+
config_dict = config_dict["text_config"]
|
116 |
+
|
117 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
118 |
+
logger.warning(
|
119 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
120 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
121 |
+
)
|
122 |
+
|
123 |
+
return cls.from_dict(config_dict, **kwargs)
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class CLIPVisionConfig(PretrainedConfig):
|
129 |
+
r"""
|
130 |
+
This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
|
131 |
+
CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
132 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
|
133 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
134 |
+
|
135 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
136 |
+
documentation from [`PretrainedConfig`] for more information.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
140 |
+
Dimensionality of the encoder layers and the pooler layer.
|
141 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
142 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
143 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
144 |
+
Number of hidden layers in the Transformer encoder.
|
145 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
146 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
147 |
+
image_size (`int`, *optional*, defaults to 224):
|
148 |
+
The size (resolution) of each image.
|
149 |
+
patch_size (`int`, *optional*, defaults to 32):
|
150 |
+
The size (resolution) of each patch.
|
151 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
152 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
153 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
154 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
155 |
+
The epsilon used by the layer normalization layers.
|
156 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
157 |
+
The dropout ratio for the attention probabilities.
|
158 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
159 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
160 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
161 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
162 |
+
testing).
|
163 |
+
|
164 |
+
Example:
|
165 |
+
|
166 |
+
```python
|
167 |
+
>>> from transformers import CLIPVisionConfig, CLIPVisionModel
|
168 |
+
|
169 |
+
>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
|
170 |
+
>>> configuration = CLIPVisionConfig()
|
171 |
+
|
172 |
+
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
173 |
+
>>> model = CLIPVisionModel(configuration)
|
174 |
+
|
175 |
+
>>> # Accessing the model configuration
|
176 |
+
>>> configuration = model.config
|
177 |
+
```"""
|
178 |
+
|
179 |
+
model_type = "clip_vision_model"
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
hidden_size=768,
|
184 |
+
intermediate_size=3072,
|
185 |
+
projection_dim=512,
|
186 |
+
num_hidden_layers=12,
|
187 |
+
num_attention_heads=12,
|
188 |
+
num_channels=3,
|
189 |
+
image_size=224,
|
190 |
+
patch_size=32,
|
191 |
+
hidden_act="quick_gelu",
|
192 |
+
layer_norm_eps=1e-5,
|
193 |
+
attention_dropout=0.0,
|
194 |
+
initializer_range=0.02,
|
195 |
+
initializer_factor=1.0,
|
196 |
+
|
197 |
+
add_time_attn=False, ################################
|
198 |
+
num_frames=1, ################################
|
199 |
+
force_patch_dropout=0.0, ################################
|
200 |
+
lora_r=2, ################################
|
201 |
+
lora_alpha=16, ################################
|
202 |
+
lora_dropout=0.0, ################################
|
203 |
+
num_mel_bins=0.0, ################################
|
204 |
+
target_length=0.0, ################################
|
205 |
+
video_decode_backend='decord', #########################
|
206 |
+
**kwargs,
|
207 |
+
):
|
208 |
+
super().__init__(**kwargs)
|
209 |
+
|
210 |
+
self.hidden_size = hidden_size
|
211 |
+
self.intermediate_size = intermediate_size
|
212 |
+
self.projection_dim = projection_dim
|
213 |
+
self.num_hidden_layers = num_hidden_layers
|
214 |
+
self.num_attention_heads = num_attention_heads
|
215 |
+
self.num_channels = num_channels
|
216 |
+
self.patch_size = patch_size
|
217 |
+
self.image_size = image_size
|
218 |
+
self.initializer_range = initializer_range
|
219 |
+
self.initializer_factor = initializer_factor
|
220 |
+
self.attention_dropout = attention_dropout
|
221 |
+
self.layer_norm_eps = layer_norm_eps
|
222 |
+
self.hidden_act = hidden_act
|
223 |
+
|
224 |
+
self.add_time_attn = add_time_attn ################
|
225 |
+
self.num_frames = num_frames ################
|
226 |
+
self.force_patch_dropout = force_patch_dropout ################
|
227 |
+
self.lora_r = lora_r ################
|
228 |
+
self.lora_alpha = lora_alpha ################
|
229 |
+
self.lora_dropout = lora_dropout ################
|
230 |
+
self.num_mel_bins = num_mel_bins ################
|
231 |
+
self.target_length = target_length ################
|
232 |
+
self.video_decode_backend = video_decode_backend ################
|
233 |
+
|
234 |
+
@classmethod
|
235 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
236 |
+
cls._set_token_in_kwargs(kwargs)
|
237 |
+
|
238 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
239 |
+
|
240 |
+
# get the vision config dict if we are loading from CLIPConfig
|
241 |
+
if config_dict.get("model_type") == "clip":
|
242 |
+
config_dict = config_dict["vision_config"]
|
243 |
+
|
244 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
245 |
+
logger.warning(
|
246 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
247 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
248 |
+
)
|
249 |
+
|
250 |
+
return cls.from_dict(config_dict, **kwargs)
|
251 |
+
|
252 |
+
|
253 |
+
class LanguageBindThermalConfig(PretrainedConfig):
|
254 |
+
r"""
|
255 |
+
[`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
|
256 |
+
a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
|
257 |
+
a configuration with the defaults will yield a similar configuration to that of the CLIP
|
258 |
+
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
|
259 |
+
|
260 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
261 |
+
documentation from [`PretrainedConfig`] for more information.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
text_config (`dict`, *optional*):
|
265 |
+
Dictionary of configuration options used to initialize [`CLIPTextConfig`].
|
266 |
+
vision_config (`dict`, *optional*):
|
267 |
+
Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
|
268 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
269 |
+
Dimentionality of text and vision projection layers.
|
270 |
+
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
271 |
+
The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation.
|
272 |
+
kwargs (*optional*):
|
273 |
+
Dictionary of keyword arguments.
|
274 |
+
|
275 |
+
Example:
|
276 |
+
|
277 |
+
```python
|
278 |
+
>>> from transformers import CLIPConfig, CLIPModel
|
279 |
+
|
280 |
+
>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
|
281 |
+
>>> configuration = CLIPConfig()
|
282 |
+
|
283 |
+
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
|
284 |
+
>>> model = CLIPModel(configuration)
|
285 |
+
|
286 |
+
>>> # Accessing the model configuration
|
287 |
+
>>> configuration = model.config
|
288 |
+
|
289 |
+
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
|
290 |
+
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
|
291 |
+
|
292 |
+
>>> # Initializing a CLIPText and CLIPVision configuration
|
293 |
+
>>> config_text = CLIPTextConfig()
|
294 |
+
>>> config_vision = CLIPVisionConfig()
|
295 |
+
|
296 |
+
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
|
297 |
+
```"""
|
298 |
+
|
299 |
+
model_type = "LanguageBindThermal"
|
300 |
+
is_composition = True
|
301 |
+
|
302 |
+
def __init__(
|
303 |
+
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
|
304 |
+
):
|
305 |
+
# If `_config_dict` exist, we use them for the backward compatibility.
|
306 |
+
# We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
|
307 |
+
# of confusion!).
|
308 |
+
text_config_dict = kwargs.pop("text_config_dict", None)
|
309 |
+
vision_config_dict = kwargs.pop("vision_config_dict", None)
|
310 |
+
|
311 |
+
super().__init__(**kwargs)
|
312 |
+
|
313 |
+
# Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
|
314 |
+
# `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
|
315 |
+
# cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
|
316 |
+
if text_config_dict is not None:
|
317 |
+
if text_config is None:
|
318 |
+
text_config = {}
|
319 |
+
|
320 |
+
# This is the complete result when using `text_config_dict`.
|
321 |
+
_text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
|
322 |
+
|
323 |
+
# Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
|
324 |
+
for key, value in _text_config_dict.items():
|
325 |
+
if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
|
326 |
+
# If specified in `text_config_dict`
|
327 |
+
if key in text_config_dict:
|
328 |
+
message = (
|
329 |
+
f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
|
330 |
+
f'The value `text_config_dict["{key}"]` will be used instead.'
|
331 |
+
)
|
332 |
+
# If inferred from default argument values (just to be super careful)
|
333 |
+
else:
|
334 |
+
message = (
|
335 |
+
f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
|
336 |
+
f'value `text_config["{key}"]` will be overriden.'
|
337 |
+
)
|
338 |
+
logger.warning(message)
|
339 |
+
|
340 |
+
# Update all values in `text_config` with the ones in `_text_config_dict`.
|
341 |
+
text_config.update(_text_config_dict)
|
342 |
+
|
343 |
+
if vision_config_dict is not None:
|
344 |
+
if vision_config is None:
|
345 |
+
vision_config = {}
|
346 |
+
|
347 |
+
# This is the complete result when using `vision_config_dict`.
|
348 |
+
_vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
|
349 |
+
# convert keys to string instead of integer
|
350 |
+
if "id2label" in _vision_config_dict:
|
351 |
+
_vision_config_dict["id2label"] = {
|
352 |
+
str(key): value for key, value in _vision_config_dict["id2label"].items()
|
353 |
+
}
|
354 |
+
|
355 |
+
# Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
|
356 |
+
for key, value in _vision_config_dict.items():
|
357 |
+
if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
|
358 |
+
# If specified in `vision_config_dict`
|
359 |
+
if key in vision_config_dict:
|
360 |
+
message = (
|
361 |
+
f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
|
362 |
+
f'values. The value `vision_config_dict["{key}"]` will be used instead.'
|
363 |
+
)
|
364 |
+
# If inferred from default argument values (just to be super careful)
|
365 |
+
else:
|
366 |
+
message = (
|
367 |
+
f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
|
368 |
+
f'The value `vision_config["{key}"]` will be overriden.'
|
369 |
+
)
|
370 |
+
logger.warning(message)
|
371 |
+
|
372 |
+
# Update all values in `vision_config` with the ones in `_vision_config_dict`.
|
373 |
+
vision_config.update(_vision_config_dict)
|
374 |
+
|
375 |
+
if text_config is None:
|
376 |
+
text_config = {}
|
377 |
+
logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
|
378 |
+
|
379 |
+
if vision_config is None:
|
380 |
+
vision_config = {}
|
381 |
+
logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
|
382 |
+
|
383 |
+
self.text_config = CLIPTextConfig(**text_config)
|
384 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
385 |
+
|
386 |
+
self.projection_dim = projection_dim
|
387 |
+
self.logit_scale_init_value = logit_scale_init_value
|
388 |
+
self.initializer_factor = 1.0
|
389 |
+
|
390 |
+
@classmethod
|
391 |
+
def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
|
392 |
+
r"""
|
393 |
+
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
|
394 |
+
configuration.
|
395 |
+
|
396 |
+
Returns:
|
397 |
+
[`CLIPConfig`]: An instance of a configuration object
|
398 |
+
"""
|
399 |
+
|
400 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
401 |
+
|
402 |
+
def to_dict(self):
|
403 |
+
"""
|
404 |
+
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
405 |
+
|
406 |
+
Returns:
|
407 |
+
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
408 |
+
"""
|
409 |
+
output = copy.deepcopy(self.__dict__)
|
410 |
+
output["text_config"] = self.text_config.to_dict()
|
411 |
+
output["vision_config"] = self.vision_config.to_dict()
|
412 |
+
output["model_type"] = self.__class__.model_type
|
413 |
+
return output
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
|
418 |
+
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
|
423 |
+
|