EvanTHU committed on
Commit
445d3d1
1 Parent(s): bc6c851
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +14 -0
  2. LICENSE +9 -0
  3. README copy.md +133 -0
  4. app copy.py +661 -0
  5. assets/application.png +0 -0
  6. assets/compare.png +0 -0
  7. assets/highlight.png +0 -0
  8. assets/logo.png +0 -0
  9. assets/system.png +0 -0
  10. generate.py +199 -0
  11. lit_gpt/__init__.py +15 -0
  12. lit_gpt/adapter.py +165 -0
  13. lit_gpt/adapter_v2.py +197 -0
  14. lit_gpt/config.py +1040 -0
  15. lit_gpt/lora.py +671 -0
  16. lit_gpt/model.py +355 -0
  17. lit_gpt/packed_dataset.py +235 -0
  18. lit_gpt/rmsnorm.py +26 -0
  19. lit_gpt/speed_monitor.py +425 -0
  20. lit_gpt/tokenizer.py +103 -0
  21. lit_gpt/utils.py +311 -0
  22. lit_llama/__init__.py +2 -0
  23. lit_llama/adapter.py +151 -0
  24. lit_llama/indexed_dataset.py +588 -0
  25. lit_llama/lora.py +232 -0
  26. lit_llama/model.py +246 -0
  27. lit_llama/quantization.py +281 -0
  28. lit_llama/tokenizer.py +49 -0
  29. lit_llama/utils.py +244 -0
  30. models/__init__.py +0 -0
  31. models/constants.py +18 -0
  32. models/encdec.py +67 -0
  33. models/evaluator_wrapper.py +92 -0
  34. models/modules.py +109 -0
  35. models/multimodal_encoder/builder.py +49 -0
  36. models/multimodal_encoder/clip_encoder.py +78 -0
  37. models/multimodal_encoder/languagebind/__init__.py +285 -0
  38. models/multimodal_encoder/languagebind/audio/configuration_audio.py +430 -0
  39. models/multimodal_encoder/languagebind/audio/modeling_audio.py +1030 -0
  40. models/multimodal_encoder/languagebind/audio/processing_audio.py +190 -0
  41. models/multimodal_encoder/languagebind/audio/tokenization_audio.py +77 -0
  42. models/multimodal_encoder/languagebind/depth/configuration_depth.py +425 -0
  43. models/multimodal_encoder/languagebind/depth/modeling_depth.py +1030 -0
  44. models/multimodal_encoder/languagebind/depth/processing_depth.py +108 -0
  45. models/multimodal_encoder/languagebind/depth/tokenization_depth.py +77 -0
  46. models/multimodal_encoder/languagebind/image/configuration_image.py +423 -0
  47. models/multimodal_encoder/languagebind/image/modeling_image.py +1030 -0
  48. models/multimodal_encoder/languagebind/image/processing_image.py +82 -0
  49. models/multimodal_encoder/languagebind/image/tokenization_image.py +77 -0
  50. models/multimodal_encoder/languagebind/thermal/configuration_thermal.py +423 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
+ **/*.pyc
+ **/__pycache__
+ __pycache__/
+ cache_dir/
+ checkpoints/
+ feedback/
+ temp/
+ models--LanguageBind--Video-LLaVA-7B/
+ *.jsonl
+ *.json
+ linghao
+ run.sh
+ examples/
+ assets/task.gif
LICENSE ADDED
@@ -0,0 +1,9 @@
+ License for Non-commercial Scientific Research Purposes
+
+ IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free and limited license under IDEA’s copyright interests to reproduce, distribute, and create derivative works of the text, videos, and code solely for your non-commercial research purposes.
+
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance purposes, is prohibited.
+
+ Text and visualization results are owned by the International Digital Economy Academy (IDEA).
+
+ You also need to obey the original licenses of the dependency models/data used in this service.
README copy.md ADDED
@@ -0,0 +1,133 @@
+ # MotionLLM: Understanding Human Behaviors from Human Motions and Videos
+
+ ![task](./assets/task.gif)
+
+ [Ling-Hao Chen](https://lhchen.top)<sup>😎 1, 3</sup>,
+ [Shunlin Lu](https://shunlinlu.github.io)<sup>😎 2, 3</sup>,
+ [Ailing Zeng](https://ailingzeng.sit)<sup>3</sup>,
+ [Hao Zhang](https://haozhang534.github.io/)<sup>3, 4</sup>,
+ [Benyou Wang](https://wabyking.github.io/old.html)<sup>2</sup>,
+ [Ruimao Zhang](http://zhangruimao.site)<sup>2</sup>,
+ [Lei Zhang](https://leizhang.org)<sup>🤗 3</sup>
+
+ <sup>😎</sup>Co-first author. Listing order is random.
+ <sup>🤗</sup>Corresponding author.
+
+ <sup>1</sup>Tsinghua University,
+ <sup>2</sup>School of Data Science, The Chinese University of Hong Kong, Shenzhen (CUHK-SZ),
+ <sup>3</sup>International Digital Economy Academy (IDEA),
+ <sup>4</sup>The Hong Kong University of Science and Technology
+
+ <p align="center">
+ <a href='https://arxiv.org/abs/2304'>
+ <img src='https://img.shields.io/badge/Arxiv-2304.tomorrow-A42C25?style=flat&logo=arXiv&logoColor=A42C25'>
+ </a>
+ <a href='https://arxiv.org/pdf/2304.pdf'>
+ <img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'>
+ </a>
+ <a href='https://lhchen.top/MotionLLM'>
+ <img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
+ <a href='https://research.lhchen.top/blogpost/motionllm'>
+ <img src='https://img.shields.io/badge/Blog-post-4EABE6?style=flat&logoColor=4EABE6'></a>
+ <a href='https://github.com/IDEA-Research/MotionLLM'>
+ <img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
+ <a href='LICENSE'>
+ <img src='https://img.shields.io/badge/License-IDEA-blue.svg'>
+ </a>
+ <a href="" target='_blank'>
+ <img src="https://visitor-badge.laobi.icu/badge?page_id=IDEA-Research.MotionLLM&left_color=gray&right_color=%2342b983">
+ </a>
+ </p>
+
+ ## 🤩 Abstract
+
+ This study delves into the realm of multi-modality (i.e., video and motion modalities) human behavior understanding by leveraging the powerful capabilities of Large Language Models (LLMs). Diverging from recent LLMs designed for video-only or motion-only understanding, we argue that understanding human behavior necessitates joint modeling from both videos and motion sequences (e.g., SMPL sequences) to capture nuanced body part dynamics and semantics effectively. In light of this, we present MotionLLM, a straightforward yet effective framework for human motion understanding, captioning, and reasoning. Specifically, MotionLLM adopts a unified video-motion training strategy that leverages the complementary advantages of existing coarse video-text data and fine-grained motion-text data to glean rich spatial-temporal insights. Furthermore, we collect a substantial dataset, MoVid, comprising diverse videos, motions, captions, and instructions. Additionally, we propose MoVid-Bench, with careful manual annotations, for better evaluation of human behavior understanding on video and motion. Extensive experiments show the superiority of MotionLLM in captioning, spatial-temporal comprehension, and reasoning ability.
+
+ ## 🤩 Highlight Applications
+
+ ![application](./assets/application.png)
+
+ ## 🔧 Technical Solution
+
+ ![system](./assets/system.png)
+
+ ## 💻 Try it
+
+ We provide a simple online [demo](https://demo.humotionx.com/) for you to try MotionLLM. Below is guidance for deploying the demo on your local machine.
+
+ ### Step 1: Set up the environment
+
+ ```bash
+ pip install -r requirements.txt
+ ```
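If you prefer an isolated environment, a minimal sketch is below; the environment name and Python version are assumptions, not pinned by the repository.

```bash
conda create -n motionllm python=3.10 -y   # Python version is an assumption
conda activate motionllm
pip install -r requirements.txt
```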
63
+
64
+ ### Step 2: Download the pre-trained model
65
+
66
+
67
+ <details>
68
+ <summary><b> 2.1 Download the LLM </b></summary>
69
+
70
+ Please follow the instruction of [Lit-GPT](https://github.com/Lightning-AI/litgpt) to prepare the LLM model (vicuna 1.5-7B). These files will be:
71
+ ```bah
72
+ ./checkpoints/vicuna-7b-v1.5
73
+ ├── generation_config.json
74
+ ├── lit_config.json
75
+ ├── lit_model.pth
76
+ ├── pytorch_model-00001-of-00002.bin
77
+ ├── pytorch_model-00002-of-00002.bin
78
+ ├── pytorch_model.bin.index.json
79
+ ├── tokenizer_config.json
80
+ └── tokenizer.model
81
+ ```
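For reference, one way to produce this layout is via Lit-GPT's download and conversion scripts, sketched below. The script names and flags follow older Lit-GPT releases and may have changed; the Hugging Face repo id and the final rename are assumptions.

```bash
# Download the Hugging Face weights and convert them to the Lit-GPT format (assumed workflow)
python scripts/download.py --repo_id lmsys/vicuna-7b-v1.5
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/lmsys/vicuna-7b-v1.5
# Move the converted checkpoint to the path this repo expects (assumption)
mv checkpoints/lmsys/vicuna-7b-v1.5 checkpoints/vicuna-7b-v1.5
```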
+
+ If you have any confusion, we will provide more detailed instructions in a couple of days.
+
+ </details>
+
+ <details>
+ <summary><b> 2.2 Download the LoRA and the projection layer of MotionLLM </b></summary>
+
+ We currently release one version of the MotionLLM checkpoints, namely `v1.0` (download [here](https://drive.google.com/drive/folders/1d_5vaL34Hs2z9ACcMXyPEfZNyMs36xKx?usp=sharing)). Suggestions are welcome; please send them to Ling-Hao Chen and Shunlin Lu.
+
+ ```bash
+ wget xxx
+ ```
+ Keep the files in a folder and remember their paths (`LINEAR_V` and `LORA`).
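For example, you could export the two paths as environment variables so they can be passed to `app.py` later; the file names below are placeholders, not the actual checkpoint names.

```bash
export LORA=/path/to/motionllm-v1.0/lora.pth        # LoRA weights (placeholder path)
export LINEAR_V=/path/to/motionllm-v1.0/proj.pth    # projection layer (placeholder path)
```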
+
+ </details>
+
+ ### Step 3: Run the demo
+
+ ```bash
+ GRADIO_TEMP_DIR=temp python app.py --lora_path $LORA --mlp_path $LINEAR_V
+ ```
+ If you encounter errors while downloading the Hugging Face model, you can try the following command, which uses a Hugging Face mirror.
+ ```bash
+ HF_ENDPOINT=https://hf-mirror.com GRADIO_TEMP_DIR=temp python app.py --lora_path $LORA --mlp_path $LINEAR_V
+ ```
+ `GRADIO_TEMP_DIR=temp` sets `./temp` as the temporary directory where Gradio stores data. You can change it to your own path.
+
+ After this, open your browser and visit the local host address shown in the command-line output. If the page does not load, replace the IP address with your local IP address (found via the `ifconfig` command).
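For reference, the bundled `app copy.py` serves the Gradio demo through uvicorn on `host="0.0.0.0"` and port `6657`, so the page is typically reachable at `http://<your-local-ip>:6657`.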
+
+
+ ## 💼 To-Do
+
+ - [x] Release the video demo of MotionLLM.
+ - [ ] Release the motion demo of MotionLLM.
+ - [ ] Release the MoVid dataset and MoVid-Bench.
+ - [ ] Release the tuning instructions of MotionLLM.
+
+
+ ## 💋 Acknowledgement
+
+ The author team would like to thank many people. Qing Jiang helped a lot with parts of the manual annotation on MoVid-Bench and resolved some ethics issues of MotionLLM. Jingcheng Hu provided technical suggestions for efficient training. Shilong Liu and Bojia Zi provided significant technical suggestions on LLM tuning. Jiale Liu, Wenhao Yang, and Chenlai Qian provided significant suggestions for polishing the paper. Hongyang Li helped us a lot with the figure design. Yiren Pang provided GPT API keys when our keys were temporarily out of quota. The code is built on [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA), [HumanTOMATO](https://lhchen.top/HumanTOMATO/), [MotionGPT](https://github.com/qiqiApink/MotionGPT), [lit-gpt](https://github.com/Lightning-AI/litgpt), and [HumanML3D](https://github.com/EricGuo5513/HumanML3D). Thanks to all contributors!
+
+
+ ## 📚 License
+
+ This code is distributed under an [IDEA LICENSE](LICENSE). Note that our code depends on other libraries and datasets which each have their own respective licenses that must also be followed.
+
+
+ If you have any questions, please contact: thu [DOT] lhchen [AT] gmail [DOT] com AND shunlinlu0803 [AT] gmail [DOT] com.
+
app copy.py ADDED
@@ -0,0 +1,661 @@
1
+ import shutil
2
+ import subprocess
3
+
4
+ import torch
5
+ import gradio as gr
6
+ from fastapi import FastAPI
7
+ import os
8
+ from PIL import Image
9
+ import tempfile
10
+ from decord import VideoReader, cpu
11
+ import uvicorn
12
+ from transformers import TextStreamer
13
+
14
+ import hashlib
15
+ import os
16
+ import sys
17
+ import time
18
+ import warnings
19
+ from pathlib import Path
20
+ from typing import Optional
21
+ from typing import Dict, List, Literal, Optional, Tuple
22
+ from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
23
+
24
+ import lightning as L
25
+ import numpy as np
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+ from generate import generate as generate_
30
+ from lit_llama import Tokenizer, LLaMA, LLaMAConfig
31
+ from lit_llama.lora import lora
32
+ from lit_llama.utils import EmptyInitOnDevice
33
+ from lit_gpt.utils import lazy_load
34
+ from scripts.video_dataset.prepare_video_dataset_video_llava import generate_prompt_mlp
35
+ from options import option
36
+ import imageio
37
+ from tqdm import tqdm
38
+
39
+ from models.multimodal_encoder.builder import build_image_tower, build_video_tower
40
+ from models.multimodal_projector.builder import build_vision_projector
41
+
42
+
43
+ title_markdown = ("""<div class="embed_hidden" style="text-align: center;">
44
+ <h1>MotionLLM: Understanding Human Behaviors from Human Motions and Videos</h1>
45
+ <h3>
46
+ <a href="https://lhchen.top" target="_blank" rel="noopener noreferrer">Ling-Hao Chen</a><sup>😎 1, 3</sup>,
47
+ <a href="https://shunlinlu.github.io" target="_blank" rel="noopener noreferrer">Shunlin Lu</a><sup>😎 2, 3</sup>,
48
+ <br>
49
+ <a href="https://ailingzeng.sit" target="_blank" rel="noopener noreferrer">Ailing Zeng</a><sup>3</sup>,
50
+ <a href="https://haozhang534.github.io/" target="_blank" rel="noopener noreferrer">Hao Zhang</a><sup>3, 4</sup>,
51
+ <a href="https://wabyking.github.io/old.html" target="_blank" rel="noopener noreferrer">Benyou Wang</a><sup>2</sup>,
52
+ <a href="http://zhangruimao.site" target="_blank" rel="noopener noreferrer">Ruimao Zhang</a><sup>2</sup>,
53
+ <a href="https://leizhang.org" target="_blank" rel="noopener noreferrer">Lei Zhang</a><sup>🤗 3</sup>
54
+ </h3>
55
+ <h3><sup>😎</sup><i>Co-first author. Listing order is random.</i> &emsp; <sup>🤗</sup><i>Corresponding author.</i></h3>
56
+ <h3>
57
+ <sup>1</sup>THU &emsp;
58
+ <sup>2</sup>CUHK (SZ) &emsp;
59
+ <sup>3</sup>IDEA Research &emsp;
60
+ <sup>4</sup>HKUST
61
+ </h3>
62
+ </div>
63
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
64
+ <img src="https://lhchen.top/MotionLLM/assets/img/highlight.png" alt="MotionLLM" style="width:60%; height: auto; align-items: center;">
65
+ </div>
66
+
67
+ """)
68
+
69
+ block_css = """
70
+ #buttons button {
71
+ min-width: min(120px,100%);
72
+ }
73
+ """
74
+
75
+
76
+ tos_markdown = ("""
77
+ *We are now working to support the motion branch of the MotionLLM model.
78
+
79
+ ### Terms of use
80
+ By using this service, users are required to agree to the following terms:
81
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content.
82
+ It is forbidden to use the service to generate content that is illegal, harmful, violent, racist, or sexual
83
+ The usage of this service is subject to the IDEA License.
84
+ """)
85
+
86
+
87
+ learn_more_markdown = ("""
88
+ ### License
89
+ License for Non-commercial Scientific Research Purposes
90
+
91
+ IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under IDEA’s copyright interests to reproduce, distribute, and create derivative works of the text, videos, codes solely for your non-commercial research purposes.
92
+
93
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited.
94
+
95
+ Text and visualization results are owned by International Digital Economy Academy (IDEA).
96
+
97
+ You also need to obey the original license of the dependency models/data used in this service.
98
+ """)
99
+
100
+
101
+
102
+ class LlavaMetaModel:
103
+
104
+ def __init__(self, config, pretrained_checkpoint):
105
+ super(LlavaMetaModel, self).__init__()
106
+ # import pdb; pdb.set_trace()
107
+ if hasattr(config, "mm_image_tower") or hasattr(config, "image_tower"):
108
+ self.image_tower = build_image_tower(config, delay_load=True)
109
+ self.mm_projector = build_vision_projector(config)
110
+ if hasattr(config, "mm_video_tower") or hasattr(config, "video_tower"):
111
+ self.video_tower = build_video_tower(config, delay_load=True)
112
+ self.mm_projector = build_vision_projector(config)
113
+ self.load_video_tower_pretrained(pretrained_checkpoint)
114
+
115
+ def get_image_tower(self):
116
+ image_tower = getattr(self, 'image_tower', None)
117
+ if type(image_tower) is list:
118
+ image_tower = image_tower[0]
119
+ return image_tower
120
+
121
+ def get_video_tower(self):
122
+ video_tower = getattr(self, 'video_tower', None)
123
+
124
+ if type(video_tower) is list:
125
+ video_tower = video_tower[0]
126
+ return video_tower
127
+
128
+
129
+ def get_all_tower(self, keys):
130
+ tower = {key: getattr(self, f'get_{key}_tower') for key in keys}
131
+ return tower
132
+
133
+
134
+ def load_video_tower_pretrained(self, pretrained_checkpoint):
135
+ self.mm_projector.load_state_dict(pretrained_checkpoint, strict=True)
136
+
137
+
138
+ def initialize_image_modules(self, model_args, fsdp=None):
139
+ image_tower = model_args.image_tower
140
+ mm_vision_select_layer = model_args.mm_vision_select_layer
141
+ mm_vision_select_feature = model_args.mm_vision_select_feature
142
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
143
+
144
+ self.config.mm_image_tower = image_tower
145
+
146
+ image_tower = build_image_tower(model_args)
147
+
148
+ if fsdp is not None and len(fsdp) > 0:
149
+ self.image_tower = [image_tower]
150
+ else:
151
+ self.image_tower = image_tower
152
+
153
+ self.config.use_mm_proj = True
154
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
155
+ self.config.mm_hidden_size = image_tower.hidden_size
156
+ self.config.mm_vision_select_layer = mm_vision_select_layer
157
+ self.config.mm_vision_select_feature = mm_vision_select_feature
158
+
159
+ self.mm_projector = build_vision_projector(self.config)
160
+
161
+ if pretrain_mm_mlp_adapter is not None:
162
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
163
+ def get_w(weights, keyword):
164
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
165
+
166
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
167
+
168
+ def initialize_video_modules(self, model_args, fsdp=None):
169
+ video_tower = model_args.video_tower
170
+ mm_vision_select_layer = model_args.mm_vision_select_layer
171
+ mm_vision_select_feature = model_args.mm_vision_select_feature
172
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
173
+
174
+ self.config.mm_video_tower = video_tower
175
+
176
+ video_tower = build_video_tower(model_args)
177
+
178
+ if fsdp is not None and len(fsdp) > 0:
179
+ self.video_tower = [video_tower]
180
+ else:
181
+ self.video_tower = video_tower
182
+
183
+ self.config.use_mm_proj = True
184
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
185
+ self.config.mm_hidden_size = video_tower.hidden_size
186
+ self.config.mm_vision_select_layer = mm_vision_select_layer
187
+ self.config.mm_vision_select_feature = mm_vision_select_feature
188
+
189
+ self.mm_projector = build_vision_projector(self.config)
190
+
191
+ if pretrain_mm_mlp_adapter is not None:
192
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
193
+ def get_w(weights, keyword):
194
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
195
+
196
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
197
+
198
+ def encode_images(self, images):
199
+ image_features = self.get_image_tower()(images)
200
+ image_features = self.mm_projector(image_features)
201
+ return image_features
202
+
203
+ def encode_videos(self, videos):
204
+ # import pdb; pdb.set_trace()
205
+ # videos: torch.Size([1, 3, 8, 224, 224])
206
+ video_features = self.get_video_tower()(videos) # torch.Size([1, 2048, 1024])
207
+ video_features = self.mm_projector(video_features.float()) # torch.Size([1, 2048, 4096])
208
+ return video_features
209
+
210
+ def get_multimodal_embeddings(self, X_modalities):
211
+ Xs, keys= X_modalities
212
+
213
+ X_features = getattr(self, f'encode_{keys[0]}s')(Xs) # expand to get batchsize
214
+
215
+ return X_features
216
+
217
+
218
+ class Projection(nn.Module):
219
+ def __init__(self, ):
220
+ super().__init__()
221
+ self.linear_proj = nn.Linear(512, 4096)
222
+ def forward(self, x):
223
+ return self.linear_proj(x)
224
+
225
+
226
+ class ProjectionNN(nn.Module):
227
+ def __init__(self, ):
228
+ super().__init__()
229
+ self.proj = nn.Sequential(
230
+ nn.Linear(512, 4096),
231
+ nn.GELU(),
232
+ nn.Linear(4096, 4096)
233
+ )
234
+ def forward(self, x):
235
+ return self.proj(x)
236
+
237
+
238
+ class Conversation():
239
+ def __init__(self, output=None, input_prompt=None, prompt=None):
240
+ if output is None:
241
+ self.messages = []
242
+ else:
243
+ self.messages = []
244
+ self.append_message(prompt, input_prompt, output)
245
+
246
+ def append_message(self, output, input_prompt, prompt, show_images):
247
+ # print(output)
248
+ # print(input_prompt)
249
+ # print(prompt)
250
+ # print(show_images)
251
+ self.messages.append((output, input_prompt, prompt, show_images))
252
+
253
+ def to_gradio_chatbot(self, show_images=None, output_text=None):
254
+ # return a list
255
+ if show_images is None:
256
+ show_images = self.messages[-1][3]
257
+ output_text = self.messages[-1][0]
258
+ return [
259
+ [show_images, output_text]
260
+ ]
261
+
262
+ def get_info(self):
263
+ return self.messages[-1][0], self.messages[-1][1]
264
+
265
+
266
+ class ConversationBuffer():
267
+ def __init__(self, input_text):
268
+ self.buffer_ = []
269
+ self.buffer_.append(input_text)  # append to the buffer defined above (was self.buffer, which does not exist)
270
+
271
+
272
+ def init_conv():
273
+ conv = Conversation()
274
+ return conv
275
+
276
+
277
+ def get_processor(X, config, device, pretrained_checkpoint_tower, model_path = 'LanguageBind/MotionLLM-7B'):
278
+ mm_backbone_mlp_model = LlavaMetaModel(config, pretrained_checkpoint_tower)
279
+
280
+ processor = {}
281
+ if 'Image' in X:
282
+ image_tower = mm_backbone_mlp_model.get_image_tower() # LanguageBindImageTower()
283
+ if not image_tower.is_loaded:
284
+ image_tower.load_model()
285
+ image_tower.to(device=device, dtype=torch.float16)
286
+ image_processor = image_tower.image_processor
287
+ processor['image'] = image_processor
288
+ if 'Video' in X:
289
+ video_tower = mm_backbone_mlp_model.get_video_tower()
290
+ if not video_tower.is_loaded:
291
+ video_tower.load_model()
292
+ video_tower.to(device=device, dtype=torch.float16)
293
+ video_processor = video_tower.video_processor
294
+ processor['video'] = video_processor
295
+
296
+ return mm_backbone_mlp_model, processor
297
+
298
+
299
+ def motionllm(
300
+ args,
301
+ input_video_path: str,
302
+ text_en_in: str,
303
+ quantize: Optional[str] = None,
304
+ dtype: str = "float32",
305
+ max_new_tokens: int = 200,
306
+ top_k: int = 200,
307
+ temperature: float = 0.8,
308
+ accelerator: str = "auto",):
309
+
310
+ video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values']
311
+
312
+ if type(video_tensor) is list:
313
+ tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor]
314
+ else:
315
+ tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224)
316
+
317
+ X_modalities = [tensor,['video']]
318
+ video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities)
319
+ prompt = text_en_in
320
+ input_prompt = prompt
321
+
322
+ sample = {"instruction": prompt, "input": input_video_path}
323
+
324
+ prefix = generate_prompt_mlp(sample)
325
+ pre = torch.cat((tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, -1), tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1)
326
+
327
+ prompt = (pre, ". ASSISTANT: ")
328
+ encoded = (prompt[0], video_feature[0], tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1))
329
+
330
+ t0 = time.perf_counter()
331
+
332
+ output_seq = generate_(
333
+ model,
334
+ idx=encoded,
335
+ max_seq_length=4096,
336
+ max_new_tokens=max_new_tokens,
337
+ temperature=temperature,
338
+ top_k=top_k,
339
+ eos_id=tokenizer.eos_id,
340
+ tokenizer = tokenizer,
341
+ )
342
+ outputfull = tokenizer.decode(output_seq)
343
+ output = outputfull.split("ASSISTANT:")[-1].strip()
344
+ print("================================")
345
+ print(output)
346
+ print("================================")
347
+
348
+ return output, input_prompt, prompt
349
+
350
+
351
+ def save_image_to_local(image):
352
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
353
+ image = Image.open(image)
354
+ image.save(filename)
355
+ # print(filename)
356
+ return filename
357
+
358
+
359
+ def save_video_to_local(video_path):
360
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
361
+ shutil.copyfile(video_path, filename)
362
+ return filename
363
+
364
+
365
+ def generate(image1, video, textbox_in, first_run, state, images_tensor):
366
+ flag = 1
367
+
368
+ image1 = image1 if image1 else "none"
369
+ video = video if video else "none"
370
+
371
+ if type(state) is not Conversation:
372
+ state = init_conv()
373
+ images_tensor = [[], []]
374
+
375
+ first_run = False if len(state.messages) > 0 else True
376
+ text_en_in = textbox_in.replace("picture", "image")
377
+ output, input_prompt, prompt = motionllm(args, video, text_en_in)
378
+
379
+ text_en_out = output
380
+ textbox_out = text_en_out
381
+
382
+ show_images = ""
383
+ if os.path.exists(image1):
384
+ filename = save_image_to_local(image1)
385
+ show_images += f'<img src="./file={filename}" style="display: inline-block;width: 250px;max-height: 400px;">'
386
+
387
+ if os.path.exists(video):
388
+ filename = save_video_to_local(video)
389
+ show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
390
+
391
+ show_images = textbox_in + "\n" + show_images
392
+ state.append_message(output, input_prompt, prompt, show_images)
393
+
394
+ torch.cuda.empty_cache()
395
+
396
+ return (state, state.to_gradio_chatbot(show_images, output), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
397
+
398
+ def regenerate(state):
399
+ if len(state.messages) > 0:
400
+ tobot = state.to_gradio_chatbot()
401
+ tobot[-1][1] = None
402
+ textbox = state.messages[-1][1]
403
+ state.messages.pop(-1)
404
+ return state, tobot, False, textbox
405
+ return (state, [], True)
406
+
407
+
408
+ def clear_history(state):
409
+ state = init_conv()
410
+ try:
411
+ tgt = state.to_gradio_chatbot()
412
+ except:
413
+ tgt = [None, None]
414
+ return (gr.update(value=None, interactive=True),
415
+ gr.update(value=None, interactive=True),\
416
+ gr.update(value=None, interactive=True),\
417
+ True, state, tgt, [[], []])
418
+
419
+
420
+ def get_md5(file_path):
421
+ hash_md5 = hashlib.md5()
422
+ with open(file_path, "rb") as f:
423
+ for chunk in iter(lambda: f.read(4096), b""):
424
+ hash_md5.update(chunk)
425
+ return hash_md5.hexdigest()
426
+
427
+
428
+ def logging_up(video, state):
429
+ try:
430
+ state.get_info()
431
+ except:
432
+ return False
433
+ action = "upvote"
434
+ # Get the current time
435
+ current_time = str(time.time())
436
+
437
+ # Create an md5 object
438
+ hash_object = hashlib.md5(current_time.encode())
439
+
440
+ # Get the hexadecimal representation of the hash
441
+ md5_hash = get_md5(video) + "-" + hash_object.hexdigest()
442
+
443
+ command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4"
444
+ os.system(command)
445
+ with open (f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f:
446
+ out, prp = state.get_info()
447
+ f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n")
448
+ return True
449
+
450
+
451
+ def logging_down(video, state):
452
+ try:
453
+ state.get_info()
454
+ except:
455
+ return False
456
+ action = "downvote"
457
+ # Get the current time
458
+ current_time = str(time.time())
459
+
460
+ # Create an md5 object
461
+ hash_object = hashlib.md5(current_time.encode())
462
+
463
+ # Get the hexadecimal representation of the hash
464
+ md5_hash = get_md5(video) + "-" + hash_object.hexdigest()
465
+
466
+ command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4"
467
+ os.system(command)
468
+ with open (f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f:
469
+ out, prp = state.get_info()
470
+ f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n")
471
+ return True
472
+
473
+
474
+ torch.set_float32_matmul_precision("high")
475
+ warnings.filterwarnings('ignore')
476
+ args = option.get_args_parser()
477
+
478
+ conv_mode = "llava_v1"
479
+ model_path = 'LanguageBind/Video-LLaVA-7B'
480
+ device = 'cuda'
481
+ load_8bit = False
482
+ load_4bit = True
483
+ dtype = torch.float16
484
+
485
+ if not os.path.exists("temp"):
486
+ os.makedirs("temp")
487
+
488
+ lora_path = Path(args.lora_path)
489
+ pretrained_llm_path = Path(f"./checkpoints/vicuna-7b-v1.5/lit_model.pth")
490
+ tokenizer_llm_path = Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model")
491
+
492
+ # assert lora_path.is_file()
493
+ assert pretrained_llm_path.is_file()
494
+ assert tokenizer_llm_path.is_file()
495
+
496
+ accelerator = "auto"
497
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
498
+
499
+ dtype = "float32"
500
+ dt = getattr(torch, dtype, None)
501
+ if not isinstance(dt, torch.dtype):
502
+ raise ValueError(f"{dtype} is not a valid dtype.")
503
+ dtype = dt
504
+
505
+ quantize = None
506
+ t0 = time.time()
507
+
508
+ with EmptyInitOnDevice(
509
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
510
+ ), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True):
511
+ checkpoint_dir = Path("checkpoints/vicuna-7b-v1.5")
512
+ lora_query = True
513
+ lora_key = False
514
+ lora_value = True
515
+ lora_projection = False
516
+ lora_mlp = False
517
+ lora_head = False
518
+ config = Config.from_name(
519
+ name=checkpoint_dir.name,
520
+ r=args.lora_r,
521
+ alpha=args.lora_alpha,
522
+ dropout=args.lora_dropout,
523
+ to_query=lora_query,
524
+ to_key=lora_key,
525
+ to_value=lora_value,
526
+ to_projection=lora_projection,
527
+ to_mlp=lora_mlp,
528
+ to_head=lora_head,
529
+ )
530
+ model = GPT(config).bfloat16()
531
+
532
+ mlp_path = args.mlp_path
533
+ pretrained_checkpoint_mlp = torch.load(mlp_path)
534
+
535
+ X = ['Video']
536
+
537
+ mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, model_path = 'LanguageBind/Video-LLaVA-7B')
538
+ video_processor = processor['video']
539
+
540
+ linear_proj = mm_backbone_mlp_model.mm_projector
541
+
542
+ # 1. Load the pretrained weights
543
+ pretrained_llm_checkpoint = lazy_load(pretrained_llm_path)
544
+ # 2. Load the fine-tuned LoRA weights
545
+ lora_checkpoint = lazy_load(lora_path)
546
+ # 3. merge the two checkpoints
547
+ model_state_dict = {**pretrained_llm_checkpoint, **lora_checkpoint}
548
+ model.load_state_dict(model_state_dict, strict=True)
549
+ print('Load llm base model from', pretrained_llm_path)
550
+ print('Load lora model from', lora_path)
551
+
552
+ # load the MLP projection weights again to be safe; not strictly necessary
553
+ linear_proj.load_state_dict(pretrained_checkpoint_mlp)
554
+ linear_proj = linear_proj.cuda()
555
+ print('Load mlp model again from', mlp_path)
556
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
557
+
558
+ model.eval()
559
+ model = fabric.setup_module(model)
560
+ linear_proj.eval()
561
+
562
+ tokenizer = Tokenizer(tokenizer_llm_path)
563
+ print('Load tokenizer from', tokenizer_llm_path)
564
+
565
+ print(torch.cuda.memory_allocated())
566
+ print(torch.cuda.max_memory_allocated())
567
+
568
+
569
+ app = FastAPI()
570
+
571
+ textbox = gr.Textbox(
572
+ show_label=False, placeholder="Enter text and press ENTER", container=False
573
+ )
574
+
575
+ with gr.Blocks(title='MotionLLM', theme=gr.themes.Default(), css=block_css) as demo:
576
+ gr.Markdown(title_markdown)
577
+ state = gr.State()
578
+ buffer_ = gr.State()
579
+ first_run = gr.State()
580
+ images_tensor = gr.State()
581
+
582
+ with gr.Row():
583
+ with gr.Column(scale=3):
584
+ image1 = gr.State()
585
+ video = gr.Video(label="Input Video")
586
+
587
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
588
+ gr.Examples(
589
+ examples=[
590
+ [
591
+ f"{cur_dir}/examples/Play_Electric_guitar_16_clip1.mp4",
592
+ "why is the girl so happy",
593
+ ],
594
+ [
595
+ f"{cur_dir}/examples/guoyoucai.mov",
596
+ "what is the feeling of him",
597
+ ],
598
+ [
599
+ f"{cur_dir}/examples/sprint_run_18_clip1.mp4",
600
+ "Why is the man running so fast?",
601
+ ],
602
+ [
603
+ f"{cur_dir}/examples/lift_weight.mp4",
604
+ "Assume you are a fitness coach, refer to the video of the professional athlete, please analyze specific action essentials in steps and give detailed instruction.",
605
+ ],
606
+ [
607
+ f"{cur_dir}/examples/Shaolin_Kung_Fu_Wushu_Selfdefense_Sword_Form_Session_22_clip3.mp4",
608
+ "wow, can you teach me the motion, step by step in detail",
609
+ ],
610
+ [
611
+ f"{cur_dir}/examples/mabaoguo.mp4",
612
+ "why is the video funny?",
613
+ ],
614
+ [
615
+ f"{cur_dir}/examples/COBRA_PUSH_UPS_clip2.mp4",
616
+ "describe the body movement of the woman",
617
+ ],
618
+ [
619
+ f"{cur_dir}/examples/sample_demo_1.mp4",
620
+ "Why is this video interesting?",
621
+ ],
622
+ ],
623
+ inputs=[video, textbox],
624
+ )
625
+
626
+ with gr.Column(scale=7):
627
+ chatbot = gr.Chatbot(label="MotionLLM", bubble_full_width=True).style(height=875)
628
+ with gr.Row():
629
+ with gr.Column(scale=8):
630
+ textbox.render()
631
+ with gr.Column(scale=1, min_width=50):
632
+ submit_btn = gr.Button(
633
+ value="Send", variant="primary", interactive=True
634
+ )
635
+ with gr.Row(elem_id="buttons") as button_row:
636
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
637
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
638
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
639
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
640
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
641
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
642
+
643
+ gr.Markdown(tos_markdown)
644
+ gr.Markdown(learn_more_markdown)
645
+
646
+ tmp = gr.State()
647
+ upvote_btn.click(logging_up, [video, state], [tmp])
648
+
649
+ downvote_btn.click(logging_down, [video, state], [tmp])
650
+
651
+ submit_btn.click(generate, [image1, video, textbox, first_run, state, images_tensor],
652
+ [state, chatbot, first_run, textbox, images_tensor, image1, video])
653
+
654
+ regenerate_btn.click(regenerate, [state], [state, chatbot, first_run, textbox]).then(
655
+ generate, [image1, video, textbox, first_run, state, images_tensor], [state, chatbot, first_run, textbox, images_tensor, image1, video])
656
+
657
+ clear_btn.click(clear_history, [state],
658
+ [image1, video, textbox, first_run, state, chatbot, images_tensor])
659
+
660
+ app = gr.mount_gradio_app(app, demo, path="/")
661
+ uvicorn.run(app, host="0.0.0.0", port=6657)
assets/application.png ADDED
assets/compare.png ADDED
assets/highlight.png ADDED
assets/logo.png ADDED
assets/system.png ADDED
generate.py ADDED
@@ -0,0 +1,199 @@
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ from lit_llama import LLaMA, Tokenizer
11
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load
12
+
13
+
14
+ @torch.no_grad()
15
+ def generate(
16
+ model: torch.nn.Module,
17
+ idx: torch.Tensor,
18
+ max_new_tokens: int,
19
+ max_seq_length: int,
20
+ temperature: float = 1.0,
21
+ top_k: Optional[int] = None,
22
+ eos_id: Optional[int] = None,
23
+ tokenizer = None,
24
+ ) -> torch.Tensor:
25
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
26
+
27
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
28
+
29
+ Args:
30
+ model: The model to use.
31
+ idx: Tensor of shape (T) with indices of the prompt sequence.
32
+ max_new_tokens: The number of new tokens to generate.
33
+ max_seq_length: The maximum sequence length allowed.
34
+ temperature: Scales the predicted logits by 1 / temperature
35
+ top_k: If specified, only sample among the tokens with the k highest probabilities
36
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered
37
+ """
38
+ # create an empty tensor of the expected final shape and fill in the current tokens
39
+ # import pdb; pdb.set_trace()
40
+ if type(idx) == tuple:
41
+ # import pdb; pdb.set_trace()
42
+ T = idx[0].shape[-1] + idx[2].shape[-1] + len(idx[1])
43
+ before_len = idx[0].shape[-1]
44
+ catted = torch.cat((idx[0], torch.zeros((1, len(idx[1]))).cuda(), idx[2]), dim=1).long()
45
+ idx = (catted, idx[1], before_len)
46
+ T_new = T + max_new_tokens
47
+ # import pdb; pdb.set_trace()
48
+ empty = torch.empty(T_new, dtype=idx[0].dtype, device=idx[0].device)
49
+ empty = torch.empty(T_new, dtype=idx[0].dtype, device=idx[0].device)
50
+ empty[:T] = idx[0]
51
+ idx = (empty, idx[1], [before_len])
52
+ # import pdb; pdb.set_trace()
53
+ else:
54
+ # import pdb; pdb.set_trace()
55
+ T = idx.size(0)
56
+ T_new = T + max_new_tokens
57
+ empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
58
+ empty[:T] = idx
59
+ idx = empty
60
+
61
+ # generate max_new_tokens tokens
62
+ # import pdb; pdb.set_trace()
63
+ for t in range(T, T_new):
64
+ if type(idx) == tuple:
65
+ idx_cond = idx[0][:t]
66
+ tmp = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]
67
+ # import pdb; pdb.set_trace()
68
+ idx_cond = (tmp.view(1, -1), idx[1].unsqueeze(0), idx[2])
69
+ else:
70
+ # ignore the not-filled-yet tokens
71
+ idx_cond = idx[:t]
72
+ # if the sequence context is growing too long we must crop it at max_seq_length
73
+ idx_cond = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]
74
+
75
+ # forward
76
+ if type(idx) == tuple:
77
+ logits = model(idx_cond, maxlen=idx_cond[0].size(1))
78
+ else:
79
+ logits = model(idx_cond.view(1, -1))
80
+ logits = logits[0, -1] / temperature
81
+
82
+ # import pdb; pdb.set_trace()
83
+ # optionally crop the logits to only the top k options
84
+ if top_k is not None:
85
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
86
+ logits[logits < v[[-1]]] = -float("Inf")
87
+
88
+ probs = torch.nn.functional.softmax(logits, dim=-1)
89
+ idx_next = torch.multinomial(probs, num_samples=1)
90
+
91
+ # concatenate the new generation
92
+ if type(idx) == tuple:
93
+ seq = idx[0]
94
+ seq[t] = idx_next
95
+ idx = (seq, idx[1], idx[2])
96
+ else:
97
+ idx[t] = idx_next
98
+
99
+ # if <eos> token is triggered, return the output (stop generation)
100
+ if idx_next == eos_id:
101
+ if type(idx) == tuple:
102
+ return idx[0][:t+1]
103
+ else:
104
+ return idx[:t + 1] # include the EOS token
105
+ if type(idx) == tuple:
106
+ return idx[0]
107
+ else:
108
+ return idx
109
+
110
+
111
+ def main(
112
+ prompt: str = "Hello, my name is",
113
+ *,
114
+ num_samples: int = 1,
115
+ max_new_tokens: int = 50,
116
+ top_k: int = 200,
117
+ temperature: float = 0.8,
118
+ checkpoint_path: Optional[Path] = None,
119
+ tokenizer_path: Optional[Path] = None,
120
+ model_size: str = "7B",
121
+ quantize: Optional[str] = None,
122
+ ) -> None:
123
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
124
+
125
+ Args:
126
+ prompt: The prompt string to use for generating the samples.
127
+ num_samples: The number of text samples to generate.
128
+ max_new_tokens: The number of generation steps to take.
129
+ top_k: The number of top most probable tokens to consider in the sampling process.
130
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
131
+ samples.
132
+ checkpoint_path: The checkpoint path to load.
133
+ tokenizer_path: The tokenizer path to load.
134
+ model_size: The model size to load.
135
+ quantize: Whether to quantize the model and using which method:
136
+ ``"llm.int8"``: LLM.int8() mode,
137
+ ``"gptq.int4"``: GPTQ 4-bit mode.
138
+ """
139
+ if not checkpoint_path:
140
+ checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
141
+ if not tokenizer_path:
142
+ tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
143
+ assert checkpoint_path.is_file(), checkpoint_path
144
+ assert tokenizer_path.is_file(), tokenizer_path
145
+
146
+ fabric = L.Fabric(accelerator="cuda", devices=1)
147
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
148
+
149
+ print("Loading model ...", file=sys.stderr)
150
+ t0 = time.time()
151
+ with EmptyInitOnDevice(
152
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
153
+ ):
154
+ model = LLaMA.from_name(model_size)
155
+
156
+ checkpoint = lazy_load(checkpoint_path)
157
+ model.load_state_dict(checkpoint)
158
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
159
+
160
+ model.eval()
161
+ model = fabric.setup_module(model)
162
+
163
+ tokenizer = Tokenizer(tokenizer_path)
164
+ encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
165
+
166
+ L.seed_everything(1234)
167
+ t0 = time.perf_counter()
168
+
169
+ for _ in range(num_samples):
170
+ y = generate(
171
+ model,
172
+ encoded_prompt,
173
+ max_new_tokens,
174
+ model.config.block_size, # type: ignore[union-attr,arg-type]
175
+ temperature=temperature,
176
+ top_k=top_k,
177
+ )
178
+ print(tokenizer.decode(y))
179
+
180
+ t = time.perf_counter() - t0
181
+ print(f"\n\nTime for inference: {t:.02f} sec total, {num_samples * max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
182
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ from jsonargparse import CLI
187
+
188
+ torch.set_float32_matmul_precision("high")
189
+ warnings.filterwarnings(
190
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
191
+ "ignore",
192
+ message="ComplexHalf support is experimental and many operators don't support it yet"
193
+ )
194
+ warnings.filterwarnings(
195
+ # Triggered in bitsandbytes/autograd/_functions.py:298
196
+ "ignore",
197
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
198
+ )
199
+ CLI(main)
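Since `main` is wrapped with `jsonargparse.CLI`, the script can be invoked directly from the command line. A minimal sketch is below; the checkpoint and tokenizer paths shown are the defaults assumed in `main`.

```bash
python generate.py --prompt "Hello, my name is" \
    --checkpoint_path ./checkpoints/lit-llama/7B/lit-llama.pth \
    --tokenizer_path ./checkpoints/lit-llama/tokenizer.model \
    --max_new_tokens 50 --top_k 200 --temperature 0.8
```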
lit_gpt/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from lit_gpt.model import GPT
+ from lit_gpt.config import Config
+ from lit_gpt.tokenizer import Tokenizer
+
+ from lightning_utilities.core.imports import RequirementCache
+
+ _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
+ # if not bool(_LIGHTNING_AVAILABLE):
+ #     raise ImportError(
+ #         "Lit-GPT requires lightning==2.1. Please run:\n"
+ #         f"    pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
+ #     )
+
+
+ __all__ = ["GPT", "Config", "Tokenizer"]
lit_gpt/adapter.py ADDED
@@ -0,0 +1,165 @@
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
4
+ https://arxiv.org/abs/2303.16199
5
+
6
+ Port for Lit-GPT
7
+ """
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from typing_extensions import Self
14
+
15
+ from lit_gpt.config import Config as BaseConfig
16
+ from lit_gpt.model import GPT as BaseModel
17
+ from lit_gpt.model import Block as BaseBlock
18
+ from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
19
+
20
+
21
+ @dataclass
22
+ class Config(BaseConfig):
23
+ adapter_prompt_length: int = 10
24
+ adapter_start_layer: int = 2
25
+
26
+
27
+ class GPT(BaseModel):
28
+ """The implementation is identical to `lit_gpt.model.GPT` with the exception that
29
+ the `Block` saves the layer index and passes it down to the attention layer."""
30
+
31
+ def __init__(self, config: Config) -> None:
32
+ nn.Module.__init__(self)
33
+ assert config.padded_vocab_size is not None
34
+ self.config = config
35
+
36
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
37
+ self.transformer = nn.ModuleDict(
38
+ dict(
39
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
40
+ h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
41
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
42
+ )
43
+ )
44
+ self.max_seq_length = self.config.block_size
45
+ self.mask_cache: Optional[torch.Tensor] = None
46
+
47
+ def forward(
48
+ self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
49
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
50
+ T = idx.size(1)
51
+ if self.max_seq_length < T:
52
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
53
+
54
+ if input_pos is not None: # use the kv cache
55
+ cos = self.cos.index_select(0, input_pos)
56
+ sin = self.sin.index_select(0, input_pos)
57
+ if self.mask_cache is None:
58
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
59
+ mask = self.mask_cache.index_select(2, input_pos)
60
+ else:
61
+ cos = self.cos[:T]
62
+ sin = self.sin[:T]
63
+ mask = None
64
+
65
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
66
+ for block in self.transformer.h:
67
+ x = block(x, cos, sin, mask, input_pos)
68
+ x = self.transformer.ln_f(x)
69
+ if lm_head_chunk_size > 0:
70
+ # chunk the lm head logits to reduce the peak memory used by autograd
71
+ return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
72
+ return self.lm_head(x) # (b, t, vocab_size)
73
+
74
+ @classmethod
75
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
76
+ return cls(Config.from_name(name, **kwargs))
77
+
78
+ def _init_weights(self, module: nn.Module) -> None:
79
+ """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
80
+ super()._init_weights(module)
81
+ if isinstance(module, CausalSelfAttention):
82
+ module.reset_parameters()
83
+
84
+
85
+ class Block(BaseBlock):
86
+ """The implementation is identical to `lit_gpt.model.Block` with the exception that
87
+ we replace the attention layer where adaption is implemented."""
88
+
89
+ def __init__(self, config: Config, block_idx: int) -> None:
90
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
91
+ nn.Module.__init__(self)
92
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
93
+ self.attn = CausalSelfAttention(config, block_idx)
94
+ if not config.shared_attention_norm:
95
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
96
+ self.mlp = config.mlp_class(config)
97
+
98
+ self.config = config
99
+
100
+
101
+ class CausalSelfAttention(BaseCausalSelfAttention):
102
+ """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
103
+ over the adaption prompt."""
104
+
105
+ def __init__(self, config: Config, block_idx: int) -> None:
106
+ super().__init__(config)
107
+ if block_idx >= config.adapter_start_layer:
108
+ # adapter embedding layer
109
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
110
+ # gate for adaption
111
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
112
+ # kv cache for inference
113
+ self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
114
+ self.block_idx = block_idx
115
+
116
+ def scaled_dot_product_attention(
117
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
118
+ ) -> torch.Tensor:
119
+ y = super().scaled_dot_product_attention(q, k, v, mask)
120
+ if self.block_idx < self.config.adapter_start_layer:
121
+ return y
122
+
123
+ aT = self.config.adapter_prompt_length
124
+ if self.adapter_kv_cache is not None:
125
+ # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
126
+ # are the same every call
127
+ ak, av = self.adapter_kv_cache
128
+ else:
129
+ prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
130
+ aqkv = self.attn(prefix)
131
+ q_per_kv = self.config.n_head // self.config.n_query_groups
132
+ aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
133
+ aqkv = aqkv.permute(0, 2, 3, 1, 4)
134
+ _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
135
+ if self.config.n_query_groups != 1:
136
+ # for MHA this is a no-op
137
+ ak = ak.repeat_interleave(q_per_kv, dim=2)
138
+ av = av.repeat_interleave(q_per_kv, dim=2)
139
+ ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
140
+ av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
141
+ self.adapter_kv_cache = (ak, av)
142
+
143
+ T = q.size(2)
144
+ amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
145
+ ay = super().scaled_dot_product_attention(q, ak, av, amask)
146
+ return y + self.gating_factor * ay
147
+
148
+ def reset_parameters(self) -> None:
149
+ torch.nn.init.zeros_(self.gating_factor)
150
+
151
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
152
+ """For compatibility with older checkpoints."""
153
+ if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
154
+ state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
155
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
156
+
157
+
158
+ def mark_only_adapter_as_trainable(model: GPT) -> None:
159
+ """Sets `requires_grad=False` for all non-adapter weights."""
160
+ for name, param in model.named_parameters():
161
+ param.requires_grad = adapter_filter(name, param)
162
+
163
+
164
+ def adapter_filter(key: str, value: Any) -> bool:
165
+ return "adapter_wte" in key or "gating_factor" in key
lit_gpt/adapter_v2.py ADDED
@@ -0,0 +1,197 @@
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
4
+ https://arxiv.org/abs/2304.15010
5
+
6
+ Port for Lit-GPT
7
+ """
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, Optional, Tuple, Type
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from typing_extensions import Self
14
+
15
+ import lit_gpt
16
+ from lit_gpt.adapter import GPT as BaseModel
17
+ from lit_gpt.adapter import Block as BaseBlock
18
+ from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
19
+ from lit_gpt.adapter import Config as BaseConfig
20
+ from lit_gpt.model import KVCache
21
+ from lit_gpt.utils import map_old_state_dict_weights
22
+
23
+
24
+ @dataclass
25
+ class Config(BaseConfig):
26
+ @property
27
+ def mlp_class(self) -> Type:
28
+ return getattr(lit_gpt.adapter_v2, self._mlp_class)
29
+
30
+
31
+ def adapter_filter(key: str, value: Any) -> bool:
32
+ adapter_substrings = (
33
+ # regular adapter v1 parameters
34
+ "adapter_wte",
35
+ "gating_factor",
36
+ # adapter v2: new bias and scale used in Linear
37
+ "adapter_scale",
38
+ "adapter_bias",
39
+ # adapter v2: Norm parameters are now trainable
40
+ "norm_1",
41
+ "norm_2",
42
+ "ln_f",
43
+ )
44
+ return any(s in key for s in adapter_substrings)
45
+
46
+
47
+ class AdapterV2Linear(torch.nn.Module):
48
+ def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
49
+ super().__init__()
50
+ self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
51
+ self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
52
+ self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ return self.adapter_scale * (self.linear(x) + self.adapter_bias)
56
+
57
+ def reset_parameters(self) -> None:
58
+ nn.init.zeros_(self.adapter_bias)
59
+ nn.init.ones_(self.adapter_scale)
60
+
61
+
62
+ class GPT(BaseModel):
63
+ def __init__(self, config: Config) -> None:
64
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
65
+ nn.Module.__init__(self)
66
+ assert config.padded_vocab_size is not None
67
+ self.config = config
68
+
69
+ self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
70
+ self.transformer = nn.ModuleDict(
71
+ dict(
72
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
73
+ h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
74
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
75
+ )
76
+ )
77
+ self.max_seq_length = self.config.block_size
78
+ self.mask_cache: Optional[torch.Tensor] = None
79
+
80
+ @classmethod
81
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
82
+ return cls(Config.from_name(name, **kwargs))
83
+
84
+ def _init_weights(self, module: nn.Module) -> None:
85
+ """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
86
+ super()._init_weights(module)
87
+ if isinstance(module, AdapterV2Linear):
88
+ module.reset_parameters()
89
+
90
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
91
+ """For compatibility with base checkpoints."""
92
+ mapping = {"lm_head.weight": "lm_head.linear.weight"}
93
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
94
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
95
+
96
+
97
+ class Block(BaseBlock):
98
+ """The implementation is identical to `lit_gpt.model.Block` with the exception that
99
+ we replace the attention layer where adaption is implemented."""
100
+
101
+ def __init__(self, config: Config, block_idx: int) -> None:
102
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
103
+ nn.Module.__init__(self)
104
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
105
+ self.attn = CausalSelfAttention(config, block_idx)
106
+ if not config.shared_attention_norm:
107
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
108
+ self.mlp = config.mlp_class(config)
109
+
110
+ self.config = config
111
+
112
+
113
+ class CausalSelfAttention(BaseCausalSelfAttention):
114
+ """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
115
+
116
+ def __init__(self, config: Config, block_idx: int) -> None:
117
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
118
+ nn.Module.__init__(self)
119
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
120
+ # key, query, value projections for all heads, but in a batch
121
+ self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
122
+ # output projection
123
+ self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
124
+ # disabled by default
125
+ self.kv_cache: Optional[KVCache] = None
126
+
127
+ if block_idx >= config.adapter_start_layer:
128
+ # adapter embedding layer
129
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
130
+ # gate for adaptation
131
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
132
+ # kv cache for inference
133
+ self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
134
+ self.block_idx = block_idx
135
+
136
+ self.config = config
137
+
138
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
139
+ """For compatibility with base checkpoints."""
140
+ mapping = {
141
+ "attn.weight": "attn.linear.weight",
142
+ "attn.bias": "attn.linear.bias",
143
+ "proj.weight": "proj.linear.weight",
144
+ "proj.bias": "proj.linear.bias",
145
+ }
146
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
147
+ # For compatibility with older checkpoints
148
+ if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
149
+ state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
150
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
151
+
152
+
153
+ class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
154
+ def __init__(self, config: Config) -> None:
155
+ nn.Module.__init__(self)
156
+ self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
157
+ self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
158
+
159
+ self.config = config
160
+
161
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
162
+ """For compatibility with base checkpoints."""
163
+ mapping = {
164
+ "fc.weight": "fc.linear.weight",
165
+ "fc.bias": "fc.linear.bias",
166
+ "proj.weight": "proj.linear.weight",
167
+ "proj.bias": "proj.linear.bias",
168
+ }
169
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
170
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
171
+
172
+
173
+ class LLaMAMLP(lit_gpt.model.LLaMAMLP):
174
+ def __init__(self, config: Config) -> None:
175
+ nn.Module.__init__(self)
176
+ self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
177
+ self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
178
+ self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
179
+
180
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
181
+ """For compatibility with base checkpoints."""
182
+ mapping = {
183
+ "fc_1.weight": "fc_1.linear.weight",
184
+ "fc_1.bias": "fc_1.linear.bias",
185
+ "fc_2.weight": "fc_2.linear.weight",
186
+ "fc_2.bias": "fc_2.linear.bias",
187
+ "proj.weight": "proj.linear.weight",
188
+ "proj.bias": "proj.linear.bias",
189
+ }
190
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
191
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
192
+
193
+
194
+ def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
195
+ """Sets requires_grad=False for all non-adapter weights"""
196
+ for name, param in model.named_parameters():
197
+ param.requires_grad = adapter_filter(name, param)
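A minimal sketch of how the pieces above fit together, assuming the `lit_gpt.adapter_v2` module from this commit is importable: `AdapterV2Linear` computes `adapter_scale * (linear(x) + adapter_bias)`, so with its default scale of one and bias of zero it matches the wrapped linear layer, and `adapter_filter` picks out only the adapter-specific entries of a state dict (for example to checkpoint just the fine-tuned weights).

import torch
from lit_gpt.adapter_v2 import AdapterV2Linear, adapter_filter

layer = AdapterV2Linear(in_features=8, out_features=4, bias=False)
x = torch.randn(2, 8)
# With the default adapter_scale=1 and adapter_bias=0 the output equals the wrapped nn.Linear.
assert torch.allclose(layer(x), layer.linear(x))

# Keep only adapter parameters, e.g. when saving a fine-tuned checkpoint.
adapter_state = {k: v for k, v in layer.state_dict().items() if adapter_filter(k, v)}
print(sorted(adapter_state))  # ['adapter_bias', 'adapter_scale']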
lit_gpt/config.py ADDED
@@ -0,0 +1,1040 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Any, Literal, Optional, Type, Union
5
+
6
+ import torch
7
+ from typing_extensions import Self
8
+
9
+ import lit_gpt.model
10
+ from lit_gpt.utils import find_multiple
11
+
12
+
13
+ @dataclass
14
+ class Config:
15
+ org: str = "Lightning-AI"
16
+ name: str = "lit-GPT"
17
+ block_size: int = 4096
18
+ vocab_size: int = 50254
19
+ padding_multiple: int = 512
20
+ padded_vocab_size: Optional[int] = None
21
+ n_layer: int = 16
22
+ n_head: int = 32
23
+ n_embd: int = 4096
24
+ rotary_percentage: float = 0.25
25
+ parallel_residual: bool = True
26
+ bias: bool = True
27
+ lm_head_bias: bool = False
28
+ # to use multi-head attention (MHA), set this to `n_head` (default)
29
+ # to use multi-query attention (MQA), set this to 1
30
+ # to use grouped-query attention (GQA), set this to a value in between
31
+ # Example with `n_head=4`
32
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
33
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
34
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
35
+ # │ │ │ │ │ │ │
36
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
37
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
38
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
39
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
40
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
41
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
42
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
43
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
44
+ # MHA GQA MQA
45
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
46
+ #
47
+ # credit https://arxiv.org/pdf/2305.13245.pdf
48
+ n_query_groups: Optional[int] = None
49
+ shared_attention_norm: bool = False
50
+ _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
51
+ norm_eps: float = 1e-5
52
+ _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
53
+ gelu_approximate: str = "none"
54
+ intermediate_size: Optional[int] = None
55
+ rope_condense_ratio: int = 1
56
+ rope_base: int = 10000
57
+
58
+ def __post_init__(self):
59
+ assert self.n_embd % self.n_head == 0
60
+ self.head_size = self.n_embd // self.n_head
61
+
62
+ # pad the vocab size up to a multiple of `padding_multiple` so it is optimal on hardware
63
+ if self.padded_vocab_size is None:
64
+ self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
65
+ else:
66
+ # vocab size shouldn't be larger than padded vocab size
67
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
68
+
69
+ # compute the number of query groups
70
+ if self.n_query_groups is not None:
71
+ assert self.n_head % self.n_query_groups == 0
72
+ else:
73
+ self.n_query_groups = self.n_head
74
+
75
+ # compute the intermediate size for MLP if not set
76
+ if self.intermediate_size is None:
77
+ if self._mlp_class == "LLaMAMLP":
78
+ raise ValueError("The config needs to set the `intermediate_size`")
79
+ self.intermediate_size = 4 * self.n_embd
80
+
81
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
82
+
83
+ @classmethod
84
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
85
+ conf_dict = name_to_config[name].copy()
86
+ if "condense_ratio" in kwargs: # legacy name
87
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
88
+ conf_dict.update(kwargs)
89
+ return cls(**conf_dict)
90
+
91
+ @classmethod
92
+ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
93
+ with open(path, encoding="utf-8") as fp:
94
+ json_kwargs = json.load(fp)
95
+ if "condense_ratio" in json_kwargs: # legacy name
96
+ json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
97
+ if "condense_ratio" in kwargs: # legacy name
98
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
99
+ json_kwargs.update(kwargs)
100
+ return cls(**json_kwargs)
101
+
102
+ @property
103
+ def mlp_class(self) -> Type:
104
+ # `self._mlp_class` cannot be the type to keep the config json serializable
105
+ return getattr(lit_gpt.model, self._mlp_class)
106
+
107
+ @property
108
+ def norm_class(self) -> Type:
109
+ # `self._norm_class` cannot be the type to keep the config json serializable
110
+ if self._norm_class == "RMSNorm":
111
+ from lit_gpt.rmsnorm import RMSNorm
112
+
113
+ return RMSNorm
114
+ return getattr(torch.nn, self._norm_class)
115
+
116
+
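To make the `__post_init__` bookkeeping above concrete, here is a small sketch (assuming `lit_gpt.config` from this commit is importable, and that `find_multiple` rounds up to the next multiple): `head_size` is `n_embd // n_head`, the vocabulary is padded to a multiple of `padding_multiple`, `n_query_groups` falls back to `n_head` (the MHA column in the diagram), and the GptNeoxMLP intermediate size defaults to `4 * n_embd`.

from lit_gpt.config import Config

cfg = Config(n_embd=1024, n_head=16, vocab_size=50254, padding_multiple=512)
print(cfg.head_size)          # 64, i.e. n_embd // n_head
print(cfg.padded_vocab_size)  # 50688, i.e. 50254 rounded up to a multiple of 512
print(cfg.n_query_groups)     # 16, defaults to n_head (plain multi-head attention)
print(cfg.intermediate_size)  # 4096, i.e. 4 * n_embd for GptNeoxMLP

# Grouped-query attention: here 8 query heads share each key/value head.
gqa = Config(n_embd=8192, n_head=64, n_query_groups=8)
print(gqa.n_head // gqa.n_query_groups)  # 8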
117
+ ########################
118
+ # Stability AI StableLM
119
+ ########################
120
+ configs = [
121
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
122
+ dict(org="stabilityai", name="stablelm-base-alpha-3b"),
123
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
124
+ dict(org="stabilityai", name="stablelm-base-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
125
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
126
+ dict(org="stabilityai", name="stablelm-tuned-alpha-3b", n_head=32),
127
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
128
+ dict(org="stabilityai", name="stablelm-tuned-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
129
+ ]
130
+
131
+ ####################
132
+ # EleutherAI Pythia
133
+ ####################
134
+ pythia = [
135
+ # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
136
+ dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128),
137
+ # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
138
+ dict(
139
+ org="EleutherAI", name="pythia-160m", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128
140
+ ),
141
+ # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
142
+ dict(
143
+ org="EleutherAI", name="pythia-410m", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=128
144
+ ),
145
+ # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
146
+ dict(org="EleutherAI", name="pythia-1b", block_size=2048, n_embd=2048, n_head=8, padding_multiple=128),
147
+ # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
148
+ dict(
149
+ org="EleutherAI", name="pythia-1.4b", block_size=2048, n_layer=24, n_embd=2048, n_head=16, padding_multiple=128
150
+ ),
151
+ # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
152
+ dict(org="EleutherAI", name="pythia-2.8b", block_size=2048, n_layer=32, n_embd=2560, padding_multiple=128),
153
+ # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
154
+ dict(org="EleutherAI", name="pythia-6.9b", block_size=2048, n_layer=32, padding_multiple=256),
155
+ # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
156
+ dict(org="EleutherAI", name="pythia-12b", block_size=2048, n_layer=36, n_embd=5120, n_head=40),
157
+ ]
158
+ configs.extend(pythia)
159
+ for c in pythia:
160
+ copy = c.copy()
161
+ copy["name"] = f"{c['name']}-deduped"
162
+ configs.append(copy)
163
+
164
+
165
+ ####################################
166
+ # togethercomputer RedPajama INCITE
167
+ ####################################
168
+ redpajama_incite = [
169
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
170
+ dict(
171
+ org="togethercomputer",
172
+ name="RedPajama-INCITE-{}-3B-v1",
173
+ block_size=2048,
174
+ n_layer=32,
175
+ n_embd=2560,
176
+ padding_multiple=256,
177
+ rotary_percentage=1.0,
178
+ parallel_residual=False,
179
+ ),
180
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
181
+ dict(
182
+ org="togethercomputer",
183
+ name="RedPajama-INCITE-7B-{}",
184
+ block_size=2048,
185
+ n_layer=32,
186
+ padding_multiple=256,
187
+ rotary_percentage=1.0,
188
+ parallel_residual=False,
189
+ ),
190
+ # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
191
+ dict(
192
+ org="togethercomputer",
193
+ name="RedPajama-INCITE-{}-7B-v0.1",
194
+ block_size=2048,
195
+ n_layer=32,
196
+ padding_multiple=256,
197
+ rotary_percentage=1.0,
198
+ parallel_residual=False,
199
+ ),
200
+ ]
201
+ for c in redpajama_incite:
202
+ for kind in ("Base", "Chat", "Instruct"):
203
+ copy = c.copy()
204
+ copy["name"] = c["name"].format(kind)
205
+ configs.append(copy)
206
+
207
+
208
+ #################
209
+ # TII UAE Falcon
210
+ #################
211
+ falcon = [
212
+ # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
213
+ dict(
214
+ org="tiiuae",
215
+ name="falcon-7b{}",
216
+ block_size=2048,
217
+ vocab_size=65024,
218
+ padded_vocab_size=65024,
219
+ n_layer=32,
220
+ n_head=71,
221
+ n_embd=4544,
222
+ rotary_percentage=1.0,
223
+ n_query_groups=1,
224
+ bias=False,
225
+ # this is not in the config, but in the original model implementation, only for this config
226
+ shared_attention_norm=True,
227
+ ),
228
+ # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
229
+ dict(
230
+ org="tiiuae",
231
+ name="falcon-40b{}",
232
+ block_size=2048,
233
+ vocab_size=65024,
234
+ padded_vocab_size=65024,
235
+ n_layer=60,
236
+ n_head=128,
237
+ n_embd=8192,
238
+ rotary_percentage=1.0,
239
+ n_query_groups=8,
240
+ bias=False,
241
+ ),
242
+ ]
243
+ for c in falcon:
244
+ for kind in ("", "-instruct"):
245
+ copy = c.copy()
246
+ copy["name"] = c["name"].format(kind)
247
+ configs.append(copy)
248
+
249
+ # https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
250
+ falcon180b = dict(
251
+ org="tiiuae",
252
+ name="falcon-180B{}",
253
+ block_size=2048,
254
+ vocab_size=65024,
255
+ padded_vocab_size=65024,
256
+ n_layer=80,
257
+ n_head=232,
258
+ n_embd=14848,
259
+ rotary_percentage=1.0,
260
+ n_query_groups=8,
261
+ bias=False,
262
+ )
263
+
264
+ for kind in ("", "-chat"):
265
+ copy = falcon180b.copy()
266
+ copy["name"] = falcon180b["name"].format(kind)
267
+ configs.append(copy)
268
+
269
+
270
+ #############################
271
+ # OpenLM Research Open LLaMA
272
+ #############################
273
+ open_LLaMA = [
274
+ # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
275
+ dict(
276
+ org="openlm-research",
277
+ name="open_llama_3b",
278
+ block_size=2048,
279
+ vocab_size=32000,
280
+ padding_multiple=64,
281
+ n_layer=26,
282
+ n_embd=3200,
283
+ rotary_percentage=1.0,
284
+ parallel_residual=False,
285
+ bias=False,
286
+ _norm_class="RMSNorm",
287
+ norm_eps=1e-6,
288
+ _mlp_class="LLaMAMLP",
289
+ intermediate_size=8640,
290
+ ),
291
+ # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
292
+ dict(
293
+ org="openlm-research",
294
+ name="open_llama_7b",
295
+ block_size=2048,
296
+ vocab_size=32000,
297
+ padding_multiple=64,
298
+ n_layer=32,
299
+ rotary_percentage=1.0,
300
+ parallel_residual=False,
301
+ bias=False,
302
+ _norm_class="RMSNorm",
303
+ norm_eps=1e-6,
304
+ _mlp_class="LLaMAMLP",
305
+ intermediate_size=11008,
306
+ ),
307
+ # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
308
+ dict(
309
+ org="openlm-research",
310
+ name="open_llama_13b",
311
+ block_size=2048,
312
+ vocab_size=32000,
313
+ padding_multiple=64,
314
+ n_layer=40,
315
+ n_head=40,
316
+ n_embd=5120,
317
+ rotary_percentage=1.0,
318
+ parallel_residual=False,
319
+ bias=False,
320
+ _norm_class="RMSNorm",
321
+ norm_eps=1e-6,
322
+ _mlp_class="LLaMAMLP",
323
+ intermediate_size=13824,
324
+ ),
325
+ ]
326
+ configs.extend(open_LLaMA)
327
+
328
+
329
+ ###############
330
+ # LMSYS Vicuna
331
+ ###############
332
+ vicuna = [
333
+ # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
334
+ dict(
335
+ org="lmsys",
336
+ name="vicuna-7b-v1.3",
337
+ block_size=2048,
338
+ vocab_size=32000,
339
+ padding_multiple=64,
340
+ n_layer=32,
341
+ rotary_percentage=1.0,
342
+ parallel_residual=False,
343
+ bias=False,
344
+ _norm_class="RMSNorm",
345
+ norm_eps=1e-6,
346
+ _mlp_class="LLaMAMLP",
347
+ intermediate_size=11008,
348
+ ),
349
+ # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
350
+ dict(
351
+ org="lmsys",
352
+ name="vicuna-13b-v1.3",
353
+ block_size=2048,
354
+ vocab_size=32000,
355
+ padding_multiple=64,
356
+ n_layer=40,
357
+ n_head=40,
358
+ n_embd=5120,
359
+ rotary_percentage=1.0,
360
+ parallel_residual=False,
361
+ bias=False,
362
+ _norm_class="RMSNorm",
363
+ norm_eps=1e-6,
364
+ _mlp_class="LLaMAMLP",
365
+ intermediate_size=13824,
366
+ ),
367
+ # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
368
+ dict(
369
+ org="lmsys",
370
+ name="vicuna-33b-v1.3",
371
+ block_size=2048,
372
+ vocab_size=32000,
373
+ padding_multiple=64,
374
+ n_layer=60,
375
+ n_head=52,
376
+ n_embd=6656,
377
+ rotary_percentage=1.0,
378
+ parallel_residual=False,
379
+ bias=False,
380
+ _norm_class="RMSNorm",
381
+ norm_eps=1e-6,
382
+ _mlp_class="LLaMAMLP",
383
+ intermediate_size=17920,
384
+ ),
385
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
386
+ dict(
387
+ org="lmsys",
388
+ name="vicuna-7b-v1.5",
389
+ vocab_size=32000,
390
+ padding_multiple=64,
391
+ n_layer=32,
392
+ rotary_percentage=1.0,
393
+ parallel_residual=False,
394
+ bias=False,
395
+ _norm_class="RMSNorm",
396
+ _mlp_class="LLaMAMLP",
397
+ intermediate_size=11008,
398
+ ),
399
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
400
+ dict(
401
+ org="lmsys",
402
+ name="vicuna-7b-v1.5-16k",
403
+ block_size=16384,
404
+ vocab_size=32000,
405
+ padding_multiple=64,
406
+ n_layer=32,
407
+ rotary_percentage=1.0,
408
+ parallel_residual=False,
409
+ bias=False,
410
+ _norm_class="RMSNorm",
411
+ _mlp_class="LLaMAMLP",
412
+ intermediate_size=11008,
413
+ rope_condense_ratio=4,
414
+ ),
415
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
416
+ dict(
417
+ org="lmsys",
418
+ name="vicuna-13b-v1.5",
419
+ vocab_size=32000,
420
+ padding_multiple=64,
421
+ n_layer=40,
422
+ n_head=40,
423
+ n_embd=5120,
424
+ rotary_percentage=1.0,
425
+ parallel_residual=False,
426
+ bias=False,
427
+ _norm_class="RMSNorm",
428
+ _mlp_class="LLaMAMLP",
429
+ intermediate_size=13824,
430
+ ),
431
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
432
+ dict(
433
+ org="lmsys",
434
+ name="vicuna-13b-v1.5-16k",
435
+ block_size=16384,
436
+ vocab_size=32000,
437
+ padding_multiple=64,
438
+ n_layer=40,
439
+ n_head=40,
440
+ n_embd=5120,
441
+ rotary_percentage=1.0,
442
+ parallel_residual=False,
443
+ bias=False,
444
+ _norm_class="RMSNorm",
445
+ _mlp_class="LLaMAMLP",
446
+ intermediate_size=13824,
447
+ rope_condense_ratio=4,
448
+ ),
449
+ ]
450
+ configs.extend(vicuna)
451
+
452
+
453
+ #################
454
+ # LMSYS LongChat
455
+ #################
456
+ long_chat = [
457
+ # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
458
+ dict(
459
+ org="lmsys",
460
+ name="longchat-7b-16k",
461
+ block_size=16384,
462
+ vocab_size=32000,
463
+ padding_multiple=64,
464
+ n_layer=32,
465
+ rotary_percentage=1.0,
466
+ parallel_residual=False,
467
+ bias=False,
468
+ _norm_class="RMSNorm",
469
+ norm_eps=1e-6,
470
+ _mlp_class="LLaMAMLP",
471
+ intermediate_size=11008,
472
+ rope_condense_ratio=8,
473
+ ),
474
+ # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
475
+ dict(
476
+ org="lmsys",
477
+ name="longchat-13b-16k",
478
+ block_size=16384,
479
+ vocab_size=32000,
480
+ padding_multiple=64,
481
+ n_layer=40,
482
+ n_head=40,
483
+ n_embd=5120,
484
+ rotary_percentage=1.0,
485
+ parallel_residual=False,
486
+ bias=False,
487
+ _norm_class="RMSNorm",
488
+ norm_eps=1e-6,
489
+ _mlp_class="LLaMAMLP",
490
+ intermediate_size=13824,
491
+ rope_condense_ratio=8,
492
+ ),
493
+ ]
494
+ configs.extend(long_chat)
495
+
496
+
497
+ ######################
498
+ # NousResearch Hermes
499
+ ######################
500
+ nous_research = [
501
+ # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
502
+ dict(
503
+ org="NousResearch",
504
+ name="Nous-Hermes-llama-2-7b",
505
+ padded_vocab_size=32000,
506
+ n_layer=32,
507
+ rotary_percentage=1.0,
508
+ parallel_residual=False,
509
+ bias=False,
510
+ _norm_class="RMSNorm",
511
+ norm_eps=1e-05,
512
+ _mlp_class="LLaMAMLP",
513
+ intermediate_size=11008,
514
+ ),
515
+ # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
516
+ dict(
517
+ org="NousResearch",
518
+ name="Nous-Hermes-13b",
519
+ block_size=2048,
520
+ vocab_size=32000,
521
+ padded_vocab_size=32001,
522
+ n_layer=40,
523
+ n_head=40,
524
+ n_embd=5120,
525
+ rotary_percentage=1.0,
526
+ parallel_residual=False,
527
+ bias=False,
528
+ _norm_class="RMSNorm",
529
+ norm_eps=1e-6,
530
+ _mlp_class="LLaMAMLP",
531
+ intermediate_size=13824,
532
+ ),
533
+ # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
534
+ dict(
535
+ org="NousResearch",
536
+ name="Nous-Hermes-Llama2-13b",
537
+ vocab_size=32000,
538
+ padded_vocab_size=32032,
539
+ n_layer=40,
540
+ n_head=40,
541
+ n_embd=5120,
542
+ rotary_percentage=1.0,
543
+ parallel_residual=False,
544
+ bias=False,
545
+ _norm_class="RMSNorm",
546
+ norm_eps=1e-05,
547
+ _mlp_class="LLaMAMLP",
548
+ intermediate_size=13824,
549
+ ),
550
+ ]
551
+ configs.extend(nous_research)
552
+
553
+
554
+ ###############
555
+ # Meta LLaMA 2
556
+ ###############
557
+ llama_2 = [
558
+ # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
559
+ dict(
560
+ org="meta-llama",
561
+ name="Llama-2-7b{}-hf",
562
+ vocab_size=32000,
563
+ padding_multiple=64,
564
+ n_layer=32,
565
+ rotary_percentage=1.0,
566
+ parallel_residual=False,
567
+ bias=False,
568
+ _norm_class="RMSNorm",
569
+ _mlp_class="LLaMAMLP",
570
+ intermediate_size=11008,
571
+ ),
572
+ # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
573
+ dict(
574
+ org="meta-llama",
575
+ name="Llama-2-13b{}-hf",
576
+ vocab_size=32000,
577
+ padding_multiple=64,
578
+ n_layer=40,
579
+ n_head=40,
580
+ n_embd=5120,
581
+ rotary_percentage=1.0,
582
+ parallel_residual=False,
583
+ bias=False,
584
+ _norm_class="RMSNorm",
585
+ _mlp_class="LLaMAMLP",
586
+ intermediate_size=13824,
587
+ ),
588
+ # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
589
+ dict(
590
+ org="meta-llama",
591
+ name="Llama-2-70b{}-hf",
592
+ vocab_size=32000,
593
+ padding_multiple=64,
594
+ n_layer=80,
595
+ n_head=64,
596
+ n_embd=8192,
597
+ n_query_groups=8,
598
+ rotary_percentage=1.0,
599
+ parallel_residual=False,
600
+ bias=False,
601
+ _norm_class="RMSNorm",
602
+ _mlp_class="LLaMAMLP",
603
+ intermediate_size=28672,
604
+ ),
605
+ ]
606
+ for c in llama_2:
607
+ for kind in ("", "-chat"):
608
+ copy = c.copy()
609
+ copy["name"] = c["name"].format(kind)
610
+ configs.append(copy)
611
+
612
+
613
+ ##########################
614
+ # Stability AI FreeWilly2
615
+ ##########################
616
+ freewilly_2 = [
617
+ # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
618
+ dict(
619
+ org="stabilityai",
620
+ name="FreeWilly2",
621
+ vocab_size=32000,
622
+ padding_multiple=64,
623
+ n_layer=80,
624
+ n_head=64,
625
+ n_embd=8192,
626
+ n_query_groups=8,
627
+ rotary_percentage=1.0,
628
+ parallel_residual=False,
629
+ bias=False,
630
+ _norm_class="RMSNorm",
631
+ _mlp_class="LLaMAMLP",
632
+ intermediate_size=28672,
633
+ )
634
+ ]
635
+ configs.extend(freewilly_2)
636
+
637
+
638
+ ##################
639
+ # Meta Code Llama
640
+ ##################
641
+ code_llama = [
642
+ # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
643
+ dict(
644
+ org="codellama",
645
+ name="CodeLlama-7b-hf",
646
+ block_size=16384,
647
+ vocab_size=32016,
648
+ padding_multiple=16,
649
+ n_layer=32,
650
+ rotary_percentage=1.0,
651
+ parallel_residual=False,
652
+ bias=False,
653
+ _norm_class="RMSNorm",
654
+ norm_eps=1e-05,
655
+ _mlp_class="LLaMAMLP",
656
+ intermediate_size=11008,
657
+ rope_base=1000000,
658
+ ),
659
+ # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
660
+ dict(
661
+ org="codellama",
662
+ name="CodeLlama-13b-hf",
663
+ block_size=16384,
664
+ vocab_size=32016,
665
+ padding_multiple=16,
666
+ n_layer=40,
667
+ n_head=40,
668
+ n_embd=5120,
669
+ rotary_percentage=1.0,
670
+ parallel_residual=False,
671
+ bias=False,
672
+ _norm_class="RMSNorm",
673
+ norm_eps=1e-05,
674
+ _mlp_class="LLaMAMLP",
675
+ intermediate_size=13824,
676
+ rope_base=1000000,
677
+ ),
678
+ # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
679
+ dict(
680
+ org="codellama",
681
+ name="CodeLlama-34b-hf",
682
+ block_size=16384,
683
+ vocab_size=32000,
684
+ padding_multiple=64,
685
+ n_layer=48,
686
+ n_head=64,
687
+ n_embd=8192,
688
+ n_query_groups=8,
689
+ rotary_percentage=1.0,
690
+ parallel_residual=False,
691
+ bias=False,
692
+ _norm_class="RMSNorm",
693
+ norm_eps=1e-05,
694
+ _mlp_class="LLaMAMLP",
695
+ intermediate_size=22016,
696
+ rope_base=1000000,
697
+ ),
698
+ # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
699
+ dict(
700
+ org="codellama",
701
+ name="CodeLlama-7b-Python-hf",
702
+ block_size=16384,
703
+ vocab_size=32000,
704
+ padding_multiple=64,
705
+ n_layer=32,
706
+ rotary_percentage=1.0,
707
+ parallel_residual=False,
708
+ bias=False,
709
+ _norm_class="RMSNorm",
710
+ norm_eps=1e-05,
711
+ _mlp_class="LLaMAMLP",
712
+ intermediate_size=11008,
713
+ rope_base=1000000,
714
+ ),
715
+ # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
716
+ dict(
717
+ org="codellama",
718
+ name="CodeLlama-13b-Python-hf",
719
+ block_size=16384,
720
+ vocab_size=32000,
721
+ padding_multiple=64,
722
+ n_layer=40,
723
+ n_head=40,
724
+ n_embd=5120,
725
+ rotary_percentage=1.0,
726
+ parallel_residual=False,
727
+ bias=False,
728
+ _norm_class="RMSNorm",
729
+ norm_eps=1e-05,
730
+ _mlp_class="LLaMAMLP",
731
+ intermediate_size=13824,
732
+ rope_base=1000000,
733
+ ),
734
+ # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
735
+ dict(
736
+ org="codellama",
737
+ name="CodeLlama-34b-Python-hf",
738
+ block_size=16384,
739
+ vocab_size=32000,
740
+ padding_multiple=64,
741
+ n_layer=48,
742
+ n_head=64,
743
+ n_embd=8192,
744
+ n_query_groups=8,
745
+ rotary_percentage=1.0,
746
+ parallel_residual=False,
747
+ bias=False,
748
+ _norm_class="RMSNorm",
749
+ norm_eps=1e-05,
750
+ _mlp_class="LLaMAMLP",
751
+ intermediate_size=22016,
752
+ rope_base=1000000,
753
+ ),
754
+ # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
755
+ dict(
756
+ org="codellama",
757
+ name="CodeLlama-7b-Instruct-hf",
758
+ block_size=16384,
759
+ vocab_size=32016,
760
+ padding_multiple=16,
761
+ n_layer=32,
762
+ rotary_percentage=1.0,
763
+ parallel_residual=False,
764
+ bias=False,
765
+ _norm_class="RMSNorm",
766
+ norm_eps=1e-05,
767
+ _mlp_class="LLaMAMLP",
768
+ intermediate_size=11008,
769
+ rope_base=1000000,
770
+ ),
771
+ # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
772
+ dict(
773
+ org="codellama",
774
+ name="CodeLlama-13b-Instruct-hf",
775
+ block_size=2048,
776
+ vocab_size=32016,
777
+ padding_multiple=16,
778
+ n_layer=40,
779
+ n_head=40,
780
+ n_embd=5120,
781
+ rotary_percentage=1.0,
782
+ parallel_residual=False,
783
+ bias=False,
784
+ _norm_class="RMSNorm",
785
+ norm_eps=1e-05,
786
+ _mlp_class="LLaMAMLP",
787
+ intermediate_size=13824,
788
+ rope_base=1000000,
789
+ ),
790
+ # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
791
+ dict(
792
+ org="codellama",
793
+ name="CodeLlama-34b-Instruct-hf",
794
+ block_size=16384,
795
+ vocab_size=32000,
796
+ padding_multiple=64,
797
+ n_layer=48,
798
+ n_head=64,
799
+ n_embd=8192,
800
+ n_query_groups=8,
801
+ rotary_percentage=1.0,
802
+ parallel_residual=False,
803
+ bias=False,
804
+ _norm_class="RMSNorm",
805
+ norm_eps=1e-05,
806
+ _mlp_class="LLaMAMLP",
807
+ intermediate_size=22016,
808
+ rope_base=1000000,
809
+ ),
810
+ ]
811
+ configs.extend(code_llama)
812
+
813
+
814
+ ########################
815
+ # garage-bAInd Platypus
816
+ ########################
817
+ platypus = [
818
+ # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
819
+ dict(
820
+ org="garage-bAInd",
821
+ name="Platypus-30B",
822
+ block_size=2048,
823
+ padded_vocab_size=32000,
824
+ n_layer=60,
825
+ n_head=52,
826
+ n_embd=6656,
827
+ rotary_percentage=1.0,
828
+ parallel_residual=False,
829
+ bias=False,
830
+ _norm_class="RMSNorm",
831
+ norm_eps=1e-06,
832
+ _mlp_class="LLaMAMLP",
833
+ intermediate_size=17920,
834
+ ),
835
+ # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
836
+ dict(
837
+ org="garage-bAInd",
838
+ name="Platypus2-7B",
839
+ padded_vocab_size=32000,
840
+ n_layer=32,
841
+ rotary_percentage=1.0,
842
+ parallel_residual=False,
843
+ bias=False,
844
+ _norm_class="RMSNorm",
845
+ norm_eps=1e-05,
846
+ _mlp_class="LLaMAMLP",
847
+ intermediate_size=11008,
848
+ ),
849
+ # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
850
+ dict(
851
+ org="garage-bAInd",
852
+ name="Platypus2-13B",
853
+ padded_vocab_size=32000,
854
+ n_layer=40,
855
+ n_head=40,
856
+ n_embd=5120,
857
+ rotary_percentage=1.0,
858
+ parallel_residual=False,
859
+ bias=False,
860
+ _norm_class="RMSNorm",
861
+ norm_eps=1e-05,
862
+ _mlp_class="LLaMAMLP",
863
+ intermediate_size=13824,
864
+ ),
865
+ # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
866
+ dict(
867
+ org="garage-bAInd",
868
+ name="Platypus2-70B",
869
+ padded_vocab_size=32000,
870
+ n_layer=80,
871
+ n_head=64,
872
+ n_embd=8192,
873
+ rotary_percentage=1.0,
874
+ parallel_residual=False,
875
+ bias=False,
876
+ _norm_class="RMSNorm",
877
+ _mlp_class="LLaMAMLP",
878
+ intermediate_size=28672,
879
+ ),
880
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
881
+ dict(
882
+ org="garage-bAInd",
883
+ name="Camel-Platypus2-13B",
884
+ padded_vocab_size=32000,
885
+ n_layer=40,
886
+ n_head=40,
887
+ n_embd=5120,
888
+ rotary_percentage=1.0,
889
+ parallel_residual=False,
890
+ bias=False,
891
+ _norm_class="RMSNorm",
892
+ _mlp_class="LLaMAMLP",
893
+ intermediate_size=13824,
894
+ ),
895
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
896
+ dict(
897
+ org="garage-bAInd",
898
+ name="Camel-Platypus2-70B",
899
+ padded_vocab_size=32000,
900
+ n_layer=80,
901
+ n_head=64,
902
+ n_embd=8192,
903
+ n_query_groups=8,
904
+ rotary_percentage=1.0,
905
+ parallel_residual=False,
906
+ bias=False,
907
+ _norm_class="RMSNorm",
908
+ _mlp_class="LLaMAMLP",
909
+ intermediate_size=28672,
910
+ ),
911
+ # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
912
+ dict(
913
+ org="garage-bAInd",
914
+ name="Stable-Platypus2-13B",
915
+ padded_vocab_size=32000,
916
+ n_layer=40,
917
+ n_head=40,
918
+ n_embd=5120,
919
+ rotary_percentage=1.0,
920
+ parallel_residual=False,
921
+ bias=False,
922
+ _norm_class="RMSNorm",
923
+ _mlp_class="LLaMAMLP",
924
+ intermediate_size=13824,
925
+ ),
926
+ # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
927
+ dict(
928
+ org="garage-bAInd",
929
+ name="Platypus2-70B-instruct",
930
+ padded_vocab_size=32000,
931
+ n_layer=80,
932
+ n_head=64,
933
+ n_embd=8192,
934
+ n_query_groups=8,
935
+ rotary_percentage=1.0,
936
+ parallel_residual=False,
937
+ bias=False,
938
+ _norm_class="RMSNorm",
939
+ _mlp_class="LLaMAMLP",
940
+ intermediate_size=28672,
941
+ ),
942
+ ]
943
+ configs.extend(platypus)
944
+
945
+
946
+ ##########################
947
+ # Stability AI StableCode
948
+ ##########################
949
+ stablecode = [
950
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
951
+ dict(
952
+ org="stabilityai",
953
+ name="stablecode-completion-alpha-3b",
954
+ block_size=16384,
955
+ vocab_size=49152,
956
+ n_layer=32,
957
+ n_embd=2560,
958
+ ),
959
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
960
+ dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k", vocab_size=49152, n_layer=32, n_embd=2560),
961
+ # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
962
+ dict(org="stabilityai", name="stablecode-instruct-alpha-3b", vocab_size=49152, n_layer=32, n_embd=2560),
963
+ ]
964
+ configs.extend(stablecode)
965
+
966
+
967
+ ##################################
968
+ # togethercomputer LLaMA-2-7B-32K
969
+ ##################################
970
+ together_llama2_32k = [
971
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
972
+ dict(
973
+ org="togethercomputer",
974
+ name="LLaMA-2-7B-32K",
975
+ vocab_size=32000,
976
+ padding_multiple=64,
977
+ n_layer=32,
978
+ rotary_percentage=1.0,
979
+ parallel_residual=False,
980
+ bias=False,
981
+ _norm_class="RMSNorm",
982
+ _mlp_class="LLaMAMLP",
983
+ intermediate_size=11008,
984
+ rope_condense_ratio=8,
985
+ )
986
+ ]
987
+ configs.extend(together_llama2_32k)
988
+
989
+
990
+ ################
991
+ # Microsoft Phi
992
+ ################
993
+ phi = [
994
+ # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
995
+ dict(
996
+ org="microsoft",
997
+ name="phi-1_5",
998
+ vocab_size=50257,
999
+ padded_vocab_size=51200,
1000
+ block_size=2048,
1001
+ n_embd=2048,
1002
+ n_layer=24,
1003
+ rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
1004
+ shared_attention_norm=True,
1005
+ lm_head_bias=True,
1006
+ gelu_approximate="tanh",
1007
+ )
1008
+ ]
1009
+ configs.extend(phi)
1010
+
1011
+
1012
+ #############
1013
+ # Mistral AI
1014
+ #############
1015
+ mistral = [
1016
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
1017
+ dict(
1018
+ org="mistralai",
1019
+ name="Mistral-7B-{}v0.1",
1020
+ padded_vocab_size=32000,
1021
+ block_size=4096, # should be 32768 but sliding window attention is not implemented
1022
+ n_layer=32,
1023
+ n_query_groups=8,
1024
+ rotary_percentage=1.0,
1025
+ parallel_residual=False,
1026
+ bias=False,
1027
+ _norm_class="RMSNorm",
1028
+ norm_eps=1e-05,
1029
+ _mlp_class="LLaMAMLP",
1030
+ intermediate_size=14336,
1031
+ )
1032
+ ]
1033
+ for c in mistral:
1034
+ for kind in ("", "Instruct-"):
1035
+ copy = c.copy()
1036
+ copy["name"] = c["name"].format(kind)
1037
+ configs.append(copy)
1038
+
1039
+
1040
+ name_to_config = {config["name"]: config for config in configs}
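A short usage sketch for the registry above (assuming the module is importable): `name_to_config` maps every generated checkpoint name, including templated variants such as the `-chat` ones, to its keyword dict, while `Config.from_name` looks a name up, applies overrides, and translates the legacy `condense_ratio` argument to `rope_condense_ratio`.

from lit_gpt.config import Config, name_to_config

assert "Llama-2-7b-chat-hf" in name_to_config  # produced by the ("", "-chat") loop above

cfg = Config.from_name("Llama-2-7b-hf", block_size=2048)  # override the default context length
print(cfg.n_layer, cfg.intermediate_size)                 # 32 11008
print(Config.from_name("Llama-2-7b-hf", condense_ratio=2).rope_condense_ratio)  # 2 (legacy alias)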
lit_gpt/lora.py ADDED
@@ -0,0 +1,671 @@
1
+ # Derived from https://github.com/microsoft/LoRA
2
+ # ------------------------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
5
+ # ------------------------------------------------------------------------------------------
6
+
7
+ r"""
8
+ Low-Rank Adaptation (LoRA) scheme for LLMs.
9
+
10
+ ┌───────────────────┐
11
+ ┆ h ┆
12
+ └───────────────────┘
13
+
14
+ |
15
+ +
16
+ / \
17
+ ┌─────────────────┐ ╭───────────────╮ Matrix initialization:
18
+ ┆ ┆ \ B / B = 0
19
+ ┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
20
+ ┆ weights ┆ ╰─────────╯
21
+ ┆ ┆ | r | r - rank
22
+ ┆ W e R^(d*d) ┆ | ◀─────▶ |
23
+ ┆ ┆ ╭─────────╮
24
+ └─────────────────┘ / A \
25
+ ▲ / d*r \
26
+ \ ╰───────────────╯
27
+ \ ▲
28
+ \ /
29
+ \ /
30
+ ┌───────────────────┐
31
+ ┆ x ┆
32
+ └───────────────────┘
33
+
34
+ With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning a d*d weight update
35
+ we freeze the pretrained weights and learn two matrices of size d*r and r*d that store the update for the
36
+ pretrained weights. The number of trainable parameters is reduced drastically (depending on the rank), yet
37
+ the product of the d*r and r*d matrices is again a d*d matrix that can be added to the frozen pretrained
38
+ weights, which is how the model is fine-tuned.
39
+
40
+ The goal of this approach is to move weight updates into a separate matrix which is decomposed with
41
+ two matrices of a lower rank.
42
+ """
43
+
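A quick worked count of what that decomposition buys (plain arithmetic, not part of lit-gpt): for a 4096-dimensional weight and rank r = 8, the LoRA branch trains well under one percent of the parameters of a full d*d update.

d, r = 4096, 8
full_update = d * d               # 16_777_216 entries in a dense d*d update
lora_update = d * r + r * d       # 65_536 entries for A (r*d) and B (d*r)
print(lora_update / full_update)  # 0.00390625, i.e. about 0.4% of the dense update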
44
+ import math
45
+ from dataclasses import dataclass
46
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
47
+
48
+ import torch
49
+ import torch.nn as nn
50
+ from torch.nn import functional as F
51
+ from typing_extensions import Self
52
+
53
+ import lit_gpt
54
+ from lit_gpt.config import Config as BaseConfig
55
+ from lit_gpt.model import GPT as BaseModel
56
+ from lit_gpt.model import Block as BaseBlock
57
+ from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
58
+ from lit_gpt.model import KVCache
59
+ from lit_gpt.utils import map_old_state_dict_weights
60
+
61
+
62
+ class LoRALayer(nn.Module):
63
+ def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
64
+ """Store LoRA specific attributes in a class.
65
+
66
+ Args:
67
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
68
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
69
+ lora_alpha: alpha is needed for scaling updates as alpha/r
70
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
71
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
72
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
73
+ """
74
+ super().__init__()
75
+ assert r >= 0
76
+ self.r = r
77
+ self.lora_alpha = lora_alpha
78
+ # Optional dropout
79
+ if lora_dropout > 0.0:
80
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
81
+ else:
82
+ self.lora_dropout = lambda x: x
83
+ # Mark the weight as unmerged
84
+ self.merged = False
85
+
86
+
87
+ class LoRALinear(LoRALayer):
88
+ # LoRA implemented in a dense layer
89
+ def __init__(
90
+ self,
91
+ # ↓ this part is for pretrained weights
92
+ in_features: int,
93
+ out_features: int,
94
+ # ↓ the remaining part is for LoRA
95
+ r: int = 0,
96
+ lora_alpha: int = 1,
97
+ lora_dropout: float = 0.0,
98
+ **kwargs,
99
+ ):
100
+ """LoRA wrapper around linear class.
101
+
102
+ This class has three weight matrices:
103
+ 1. Pretrained weights are stored as `self.linear.weight`
104
+ 2. LoRA A matrix as `self.lora_A`
105
+ 3. LoRA B matrix as `self.lora_B`
106
+ Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
107
+
108
+ Args:
109
+ in_features: number of input features of the pretrained weights
110
+ out_features: number of output features of the pretrained weights
111
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
112
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
113
+ lora_alpha: alpha is needed for scaling updates as alpha/r
114
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
115
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
116
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
117
+ """
118
+ super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
119
+ self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
120
+
121
+ # Actual trainable parameters
122
+ if r > 0:
123
+ self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
124
+ self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
125
+ self.scaling = self.lora_alpha / self.r
126
+ self.reset_parameters()
127
+
128
+ def reset_parameters(self) -> None:
129
+ """Reset all the weights, even including pretrained ones."""
130
+ if hasattr(self, "lora_A"):
131
+ # initialize A the same way as the default for nn.Linear and B to zero
132
+ # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
133
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
134
+ nn.init.zeros_(self.lora_B)
135
+
136
+ def merge(self) -> None:
137
+ """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
138
+ if self.r > 0 and not self.merged:
139
+ # Merge the weights and mark it
140
+ self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
141
+ self.merged = True
142
+
143
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
144
+ # if the weights are merged or the rank is less than or equal to zero (LoRA is disabled), this is just a regular nn.Linear forward pass;
145
+ # otherwise, additionally run the forward pass with the LoRA weights and add its output to the output from the pretrained weights
146
+ pretrained = self.linear(x)
147
+ if self.r == 0 or self.merged:
148
+ return pretrained
149
+ lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
150
+ return pretrained + lora
151
+
152
+
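A minimal sketch of `LoRALinear` in action (assuming `lit_gpt.lora` from this commit is importable): after `merge()` folds `B @ A * alpha/r` into the frozen weight, the plain linear forward reproduces the unmerged two-branch forward.

import torch
from lit_gpt.lora import LoRALinear

torch.manual_seed(0)
layer = LoRALinear(128, 128, r=8, lora_alpha=16, bias=False)
torch.nn.init.normal_(layer.lora_B)  # B is zero-initialized; perturb it so the LoRA branch is non-trivial

x = torch.randn(4, 128)
unmerged = layer(x)  # linear(x) + (dropout(x) @ A^T @ B^T) * alpha/r
layer.merge()        # W <- W + (B @ A) * alpha/r
merged = layer(x)    # now just the plain nn.Linear forward
assert torch.allclose(unmerged, merged, atol=1e-5)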
153
+ class LoRAQKVLinear(LoRALinear):
154
+ # LoRA implemented in a dense layer
155
+ def __init__(
156
+ self,
157
+ # ↓ this part is for pretrained weights
158
+ in_features: int,
159
+ out_features: int,
160
+ # ↓ the remaining part is for LoRA
161
+ n_head: int,
162
+ n_query_groups: int,
163
+ r: int = 0,
164
+ lora_alpha: int = 1,
165
+ lora_dropout: float = 0.0,
166
+ enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
167
+ **kwargs,
168
+ ):
169
+ """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
170
+
171
+ This class has three weight matrices:
172
+ 1. Pretrained weights are stored as `self.linear.weight`
173
+ 2. LoRA A matrix as `self.lora_A`
174
+ 3. LoRA B matrix as `self.lora_B`
175
+ Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
176
+
177
+ Args:
178
+ in_features: number of input features of the pretrained weights
179
+ out_features: number of output features of the pretrained weights
180
+ n_head: number of attention heads
181
+ n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
182
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
183
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
184
+ lora_alpha: alpha is needed for scaling updates as alpha/r
185
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
186
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
187
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
188
+ enable_lora: this layer computes q, k and v with a single fused weight matrix. If we
189
+ don't want to apply LoRA to some of them we can set the corresponding flag to False. For example, to apply LoRA only to `query`
190
+ and `value` but keep `key` without weight updates we should pass `[True, False, True]`
191
+ """
192
+ super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
193
+ self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
194
+ self.n_head = n_head
195
+ self.n_query_groups = n_query_groups
196
+ if isinstance(enable_lora, bool):
197
+ enable_lora = [enable_lora] * 3
198
+ assert len(enable_lora) == 3
199
+ self.enable_lora = enable_lora
200
+
201
+ # Actual trainable parameters
202
+ # To better understand initialization let's imagine that we have such parameters:
203
+ # ⚬ in_features: 128 (embeddings_size)
204
+ # ⚬ out_features: 384 (3 * embedding_size)
205
+ # ⚬ r: 2
206
+ # ⚬ enable_lora: [True, False, True]
207
+ if r > 0 and any(enable_lora):
208
+ self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features))) # (4, 128)
209
+ enable_q, enable_k, enable_v = enable_lora
210
+ self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
211
+ # qkv_shapes will be used to split a tensor with weights correctly
212
+ qkv_shapes = (
213
+ self.linear.in_features * enable_q,
214
+ self.kv_embd_size * enable_k,
215
+ self.kv_embd_size * enable_v,
216
+ )
217
+ self.qkv_shapes = [s for s in qkv_shapes if s]
218
+ self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r)) # (256, 2))
219
+ # Notes about shapes above
220
+ # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
221
+ # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
222
+ # F.linear function weights are automatically transposed. In addition conv1d requires channels to
223
+ # be before seq length
224
+ # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
225
+ # 128*2; the 2 is the rank r, which becomes the number of input channels per group in the grouped convolution
226
+
227
+ # Scaling:
228
+ # This balances the pretrained model's knowledge and the new task-specific adaptation
229
+ # https://lightning.ai/pages/community/tutorial/lora-llm/
230
+ # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
231
+ # alpha to a lower value. If the LoRA seems to have too little effect, set alpha to a value higher than 1.0. You can
232
+ # tune these values to your needs. This value can be even slightly greater than 1.0!
233
+ # https://github.com/cloneofsimo/lora
234
+ self.scaling = self.lora_alpha / self.r
235
+
236
+ # Compute the indices
237
+ # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
238
+ # but not keys, then the weights update should be:
239
+ #
240
+ # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
241
+ # [....................................],
242
+ # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
243
+ # ↑ ↑ ↑
244
+ # ________________________________________
245
+ # | query | key | value |
246
+ # ----------------------------------------
247
+ self.lora_ind = []
248
+ if enable_q:
249
+ self.lora_ind.extend(range(0, self.linear.in_features))
250
+ if enable_k:
251
+ self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
252
+ if enable_v:
253
+ self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
254
+ self.reset_parameters()
255
+
256
+ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
257
+ """Properly pad weight updates with zeros.
258
+
259
+ If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
260
+ then the weights update should be:
261
+
262
+ [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
263
+ [....................................],
264
+ [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
265
+ ↑ ↑ ↑
266
+ ________________________________________
267
+ | query | key | value |
268
+ ----------------------------------------
269
+
270
+ Args:
271
+ x: tensor with weights update that will be padded with zeros if necessary
272
+
273
+ Returns:
274
+ A tensor with weight updates and zeros for deselected q, k or v
275
+ """
276
+ # we need to do zero padding only if LoRA is disabled for one of QKV matrices
277
+ if all(self.enable_lora):
278
+ return x
279
+
280
+ # Let's imagine that:
281
+ # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
282
+ # ⚬ embeddings_size: 128
283
+ # ⚬ self.linear.out_features: 384 (3 * embeddings_size)
284
+ # ⚬ enable_lora: [True, False, True]
285
+ # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
286
+ # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
287
+ # only for key updates (this is where self.lora_ind comes in handy)
288
+ # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
289
+ # for example when we want to merge/unmerge LoRA weights and pretrained weights
290
+ x = x.transpose(0, 1)
291
+ result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
292
+ result = result.view(-1, self.linear.out_features) # (4096, 384)
293
+ result = result.index_copy(
294
+ 1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
295
+ ) # (4096, 256)
296
+ return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)
297
+
298
+ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
299
+ """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
300
+
301
+ If the number of heads is equal to the number of query groups - grouped queries are disabled
302
+ (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
303
+ query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the
304
+ input and weight matrices will be split into equally sized parts and applied separately (like having multiple
305
+ conv layers side by side).
306
+
307
+ Otherwise the QKV matrix consists of unequally sized parts, so we have to split the input and weight matrices manually,
308
+ apply each part of the weight matrix to the corresponding input's part and concatenate the result.
309
+
310
+ Args:
311
+ input: input matrix of shape (B, C, T)
312
+ weight: weight matrix of shape (C_output, rank, 1).
313
+ "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
314
+
315
+ Returns:
316
+ A tensor with a shape (B, C_output, T)
317
+
318
+ """
319
+ if self.n_head == self.n_query_groups:
320
+ return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T)
321
+
322
+ # Notation:
323
+ # ⚬ N: number of enabled LoRA layers (self.enable_lora)
324
+ # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
325
+ # ⚬ r: rank of all LoRA layers (equal in size)
326
+
327
+ input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
328
+ weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
329
+ return torch.cat(
330
+ [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T)
331
+ ) # (B, C_output, T)
332
+
333
+ def merge(self) -> None:
334
+ """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
335
+
336
+ # Let's assume that:
337
+ # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
338
+ # ⚬ self.lora_A.data: (4, 128)
339
+ # ⚬ self.lora_B.data: (256, 2)
340
+ if self.r > 0 and any(self.enable_lora) and not self.merged:
341
+ delta_w = self.conv1d(
342
+ self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
343
+ self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
344
+ ).squeeze(
345
+ 0
346
+ ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
347
+ # W = W + delta_W (merge)
348
+ self.linear.weight.data += self.zero_pad(delta_w * self.scaling) # (256, 128) after zero_pad (384, 128)
349
+ self.merged = True
350
+
351
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
352
+ """Do the forward pass.
353
+
354
+ If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
355
+ If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
356
+
357
+ Args:
358
+ x: input tensor of shape (batch_size, context_length, embedding_size)
359
+
360
+ Returns:
361
+ Output tensor of shape (batch_size, context_length, 3 * embedding_size)
362
+ """
363
+
364
+ # Let's assume that:
365
+ # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
366
+ # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
367
+ # ⚬ self.lora_A.data: (4, 128)
368
+ # ⚬ self.lora_B.data: (256, 2)
369
+
370
+ # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass;
371
+ # otherwise, additionally run the forward pass with the LoRA weights and add its output to the output from the pretrained weights
372
+ pretrained = self.linear(x)
373
+ if self.r == 0 or not any(self.enable_lora) or self.merged:
374
+ return pretrained
375
+ after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
376
+ # For F.conv1d:
377
+ # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
378
+ # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
379
+ after_B = self.conv1d(
380
+ after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
381
+ self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
382
+ ).transpose(
383
+ -2, -1
384
+ ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
385
+ lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
386
+ return pretrained + lora
387
+
388
+
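A sketch of the shape bookkeeping in `LoRAQKVLinear`, reusing the example numbers from the comments above (embedding size 128, fused QKV output 384, r = 2, LoRA enabled for query and value only); this assumes the module from this commit is importable.

import torch
from lit_gpt.lora import LoRAQKVLinear

qkv = LoRAQKVLinear(
    in_features=128, out_features=384,
    n_head=4, n_query_groups=4,  # MHA, so the q, k and v slices are equally sized
    r=2, lora_alpha=4, enable_lora=(True, False, True), bias=False,
)
print(qkv.lora_A.shape)  # torch.Size([4, 128]): r * (number of enabled matrices) rows
print(qkv.lora_B.shape)  # torch.Size([256, 2]): 128 (query) + 128 (value) output rows
print(qkv.qkv_shapes)    # [128, 128]: the disabled key slice is skipped

x = torch.randn(2, 16, 128)  # (batch, sequence, embedding)
print(qkv(x).shape)          # torch.Size([2, 16, 384]): zero_pad fills the key slice with zeros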
389
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
390
+ """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
391
+
392
+ Args:
393
+ model: model with LoRA layers
394
+ bias:
395
+ ``"none"``: all bias weights will be frozen,
396
+ ``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
397
+ ``"all"``: all bias weights will be unfrozen.
398
+
399
+ Raises:
400
+ NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
401
+ """
402
+ # freeze all layers except LoRA's
403
+ for n, p in model.named_parameters():
404
+ if "lora_" not in n:
405
+ p.requires_grad = False
406
+
407
+ # depending on the `bias` value unfreeze bias weights
408
+ if bias == "none":
409
+ return
410
+ if bias == "all":
411
+ for n, p in model.named_parameters():
412
+ if "bias" in n:
413
+ p.requires_grad = True
414
+ elif bias == "lora_only":
415
+ for m in model.modules():
416
+ if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
417
+ m.bias.requires_grad = True
418
+ else:
419
+ raise NotImplementedError
420
+
421
+
422
+ def lora_filter(key: str, value: Any) -> bool:
423
+ return "lora_" in key
424
+
425
+
426
+ @dataclass
427
+ class Config(BaseConfig):
428
+ """
429
+ Args:
430
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
431
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
432
+ alpha: alpha is needed for scaling updates as alpha/r
433
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
434
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
435
+ dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
436
+ to_*: either apply LoRA to the specified weights or not
437
+ """
438
+
439
+ r: int = 0
440
+ alpha: int = 1
441
+ dropout: float = 0.0
442
+ to_query: bool = False
443
+ to_key: bool = False
444
+ to_value: bool = False
445
+ to_projection: bool = False
446
+ to_mlp: bool = False
447
+ to_head: bool = False
448
+
449
+ @property
450
+ def mlp_class(self) -> Type:
451
+ return getattr(lit_gpt.lora, self._mlp_class)
452
+
453
+
454
+ class GPT(BaseModel):
455
+ def __init__(self, config: Config) -> None:
456
+ nn.Module.__init__(self)
457
+ assert config.padded_vocab_size is not None
458
+ self.config = config
459
+
460
+ self.lm_head = LoRALinear(
461
+ config.n_embd,
462
+ config.padded_vocab_size,
463
+ bias=config.lm_head_bias,
464
+ r=(config.r if config.to_head else 0),
465
+ lora_alpha=config.alpha,
466
+ lora_dropout=config.dropout,
467
+ )
468
+ self.transformer = nn.ModuleDict(
469
+ dict(
470
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
471
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
472
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
473
+ )
474
+ )
475
+ self.max_seq_length = self.config.block_size
476
+ self.mask_cache: Optional[torch.Tensor] = None
477
+
478
+ def forward(
479
+ self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0, maxlen: Optional[int] = None
480
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
481
+ T = idx.size(1) if maxlen is None else maxlen
482
+ if self.max_seq_length < T:
483
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
484
+ # import pdb; pdb.set_trace()
485
+ if input_pos is not None: # use the kv cache
486
+ cos = self.cos.index_select(0, input_pos)
487
+ sin = self.sin.index_select(0, input_pos)
488
+ if self.mask_cache is None:
489
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
490
+ mask = self.mask_cache.index_select(2, input_pos)
491
+ else:
492
+ cos = self.cos[:T]
493
+ sin = self.sin[:T]
494
+ mask = None
495
+
496
+ if type(idx) is tuple:
497
+ # import pdb; pdb.set_trace()
498
+ stack_before_tokens_x, motion_tokens, before_len = idx
499
+ # stack_before_tokens_x = stack_before_tokens_x.unsqueeze(0)
500
+ # motion_tokens = motion_tokens.unsqueeze(0)
501
+ # stack_before_tokens_x[0][before_len[0]: before_len[0] + len(motion_tokens[0])] = 1
502
+ # import pdb; pdb.set_trace()
503
+ x = self.transformer.wte(stack_before_tokens_x)
504
+ # import pdb; pdb.set_trace()
505
+ for i in range(len(x)):
506
+ x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i]
507
+ else:
508
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
509
+ for block in self.transformer.h:
510
+ x = block(x, cos, sin, mask, input_pos)
511
+ x = self.transformer.ln_f(x)
512
+ if lm_head_chunk_size > 0:
513
+ # chunk the lm head logits to reduce the peak memory used by autograd
514
+ return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
515
+ return self.lm_head(x) # (B, T, vocab_size)
516
+
517
+ @classmethod
518
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
519
+ return cls(Config.from_name(name, **kwargs))
520
+
521
+ def _init_weights(self, module: nn.Module) -> None:
522
+ """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
523
+ super()._init_weights(module)
524
+ if isinstance(module, LoRALinear):
525
+ module.reset_parameters()
526
+
527
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
528
+ """For compatibility with base checkpoints."""
529
+ mapping = {"lm_head.weight": "lm_head.linear.weight"}
530
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
531
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
532
+
533
+
534
+ class Block(BaseBlock):
535
+ def __init__(self, config: Config) -> None:
536
+ nn.Module.__init__(self)
537
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
538
+ self.attn = CausalSelfAttention(config)
539
+ if not config.shared_attention_norm:
540
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
541
+ self.mlp = config.mlp_class(config)
542
+
543
+ self.config = config
544
+
545
+
546
+ class CausalSelfAttention(BaseCausalSelfAttention):
547
+ def __init__(self, config: Config) -> None:
548
+ # Skip the parent class __init__ altogether and replace it to avoid
549
+ # useless allocations
550
+ nn.Module.__init__(self)
551
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
552
+ # key, query, value projections for all heads, but in a batch
553
+ self.attn = LoRAQKVLinear(
554
+ in_features=config.n_embd,
555
+ out_features=shape,
556
+ r=config.r,
557
+ lora_alpha=config.alpha,
558
+ lora_dropout=config.dropout,
559
+ enable_lora=(config.to_query, config.to_key, config.to_value),
560
+ bias=config.bias,
561
+ # for MQA/GQA support
562
+ n_head=config.n_head,
563
+ n_query_groups=config.n_query_groups,
564
+ )
565
+ # output projection
566
+ self.proj = LoRALinear(
567
+ config.n_embd,
568
+ config.n_embd,
569
+ bias=config.bias,
570
+ r=(config.r if config.to_projection else 0),
571
+ lora_alpha=config.alpha,
572
+ lora_dropout=config.dropout,
573
+ )
574
+ # disabled by default
575
+ self.kv_cache: Optional[KVCache] = None
576
+
577
+ self.config = config
578
+
579
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
580
+ """For compatibility with base checkpoints."""
581
+ mapping = {
582
+ "attn.weight": "attn.linear.weight",
583
+ "attn.bias": "attn.linear.bias",
584
+ "proj.weight": "proj.linear.weight",
585
+ "proj.bias": "proj.linear.bias",
586
+ }
587
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
588
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
589
+
590
+
591
+ class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
592
+ def __init__(self, config: Config) -> None:
593
+ nn.Module.__init__(self)
594
+ self.fc = LoRALinear(
595
+ config.n_embd,
596
+ config.intermediate_size,
597
+ bias=config.bias,
598
+ r=(config.r if config.to_mlp else 0),
599
+ lora_alpha=config.alpha,
600
+ lora_dropout=config.dropout,
601
+ )
602
+ self.proj = LoRALinear(
603
+ config.intermediate_size,
604
+ config.n_embd,
605
+ bias=config.bias,
606
+ r=(config.r if config.to_mlp else 0),
607
+ lora_alpha=config.alpha,
608
+ lora_dropout=config.dropout,
609
+ )
610
+
611
+ self.config = config
612
+
613
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
614
+ """For compatibility with base checkpoints."""
615
+ mapping = {
616
+ "fc.weight": "fc.linear.weight",
617
+ "fc.bias": "fc.linear.bias",
618
+ "proj.weight": "proj.linear.weight",
619
+ "proj.bias": "proj.linear.bias",
620
+ }
621
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
622
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
623
+
624
+
625
+ class LLaMAMLP(lit_gpt.model.LLaMAMLP):
626
+ def __init__(self, config: Config) -> None:
627
+ nn.Module.__init__(self)
628
+ self.fc_1 = LoRALinear(
629
+ config.n_embd,
630
+ config.intermediate_size,
631
+ bias=config.bias,
632
+ r=(config.r if config.to_mlp else 0),
633
+ lora_alpha=config.alpha,
634
+ lora_dropout=config.dropout,
635
+ )
636
+ self.fc_2 = LoRALinear(
637
+ config.n_embd,
638
+ config.intermediate_size,
639
+ bias=config.bias,
640
+ r=(config.r if config.to_mlp else 0),
641
+ lora_alpha=config.alpha,
642
+ lora_dropout=config.dropout,
643
+ )
644
+ self.proj = LoRALinear(
645
+ config.intermediate_size,
646
+ config.n_embd,
647
+ bias=config.bias,
648
+ r=(config.r if config.to_mlp else 0),
649
+ lora_alpha=config.alpha,
650
+ lora_dropout=config.dropout,
651
+ )
652
+
653
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
654
+ """For compatibility with base checkpoints."""
655
+ mapping = {
656
+ "fc_1.weight": "fc_1.linear.weight",
657
+ "fc_1.bias": "fc_1.linear.bias",
658
+ "fc_2.weight": "fc_2.linear.weight",
659
+ "fc_2.bias": "fc_2.linear.bias",
660
+ "proj.weight": "proj.linear.weight",
661
+ "proj.bias": "proj.linear.bias",
662
+ }
663
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
664
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
665
+
666
+
667
+ def merge_lora_weights(model: GPT) -> None:
668
+ """Merge LoRA weights into the full-rank weights to speed up inference."""
669
+ for module in model.modules():
670
+ if isinstance(module, LoRALinear):
671
+ module.merge()
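To make the merge step above concrete, here is a minimal standalone sketch (not the repository's API; the layer sizes are illustrative assumptions) showing that folding delta_W = (lora_B @ lora_A) * alpha/r into the frozen weight reproduces the un-merged pretrained-plus-LoRA forward pass of a plain LoRA linear layer:

import torch

torch.manual_seed(0)
n_embd, out_features, r, alpha = 128, 256, 2, 4
scaling = alpha / r

W = torch.randn(out_features, n_embd)          # frozen pretrained weight
lora_A = torch.randn(r, n_embd) * 0.01         # (r, n_embd), as in the comments above
lora_B = torch.randn(out_features, r) * 0.01   # (out_features, r); zero-initialized in the real layer

x = torch.randn(8, 16, n_embd)                 # (batch_size, context_length, embedding_size)

# un-merged path: pretrained output + scaled LoRA branch
unmerged = x @ W.T + ((x @ lora_A.T) @ lora_B.T) * scaling

# merged path: W = W + delta_W, computed once, then a single matmul
W_merged = W + (lora_B @ lora_A) * scaling
merged = x @ W_merged.T

print(torch.allclose(unmerged, merged, atol=1e-4))  # True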
lit_gpt/model.py ADDED
@@ -0,0 +1,355 @@
1
+ """Full definition of a GPT NeoX Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
4
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
5
+ """
6
+ import math
7
+ from typing import Any, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing_extensions import Self
12
+
13
+ from lit_gpt.config import Config
14
+
15
+
16
+ class GPT(nn.Module):
17
+ def __init__(self, config: Config) -> None:
18
+ super().__init__()
19
+ assert config.padded_vocab_size is not None
20
+ self.config = config
21
+
22
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
23
+ self.transformer = nn.ModuleDict(
24
+ dict(
25
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
26
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
27
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
28
+ )
29
+ )
30
+ self.max_seq_length = self.config.block_size
31
+ self.mask_cache: Optional[torch.Tensor] = None
32
+
33
+ @property
34
+ def max_seq_length(self) -> int:
35
+ return self._max_seq_length
36
+
37
+ @max_seq_length.setter
38
+ def max_seq_length(self, value: int) -> None:
39
+ """
40
+ When doing inference, the sequences used might be shorter than the model's context length.
41
+ This allows setting a smaller number to avoid allocating unused memory
42
+ """
43
+ if value > self.config.block_size:
44
+ raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
45
+ self._max_seq_length = value
46
+ if not hasattr(self, "cos"):
47
+ # first call
48
+ cos, sin = self.rope_cache()
49
+ self.register_buffer("cos", cos, persistent=False)
50
+ self.register_buffer("sin", sin, persistent=False)
51
+ elif value != self.cos.size(0):
52
+ # override
53
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
54
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
55
+ # if the kv cache is expected
56
+
57
+ def reset_parameters(self) -> None:
58
+ # Trigger resetting the rope-cache
59
+ self.max_seq_length = self.config.block_size
60
+
61
+ def _init_weights(self, module: nn.Module) -> None:
62
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
63
+ if isinstance(module, nn.Linear):
64
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
65
+ if module.bias is not None:
66
+ torch.nn.init.zeros_(module.bias)
67
+ elif isinstance(module, nn.Embedding):
68
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
69
+
70
+ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, maxlen: Optional[int] = None) -> torch.Tensor:
71
+ T = idx.size(1) if maxlen is None else maxlen
72
+ # print(T, end=', ')
73
+ if self.max_seq_length < T:
74
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
75
+
76
+ # import pdb; pdb.set_trace()
77
+ if input_pos is not None: # use the kv cache
78
+ cos = self.cos.index_select(0, input_pos)
79
+ sin = self.sin.index_select(0, input_pos)
80
+ if self.mask_cache is None:
81
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
82
+ mask = self.mask_cache.index_select(2, input_pos)
83
+ else:
84
+ cos = self.cos[:T]
85
+ sin = self.sin[:T]
86
+ mask = None
87
+
88
+ if type(idx) is tuple:
89
+ stack_before_tokens_x, motion_tokens, before_len = idx
90
+ # stack_before_tokens_x = stack_before_tokens_x.unsqueeze(0)
91
+ # motion_tokens = motion_tokens.unsqueeze(0)
92
+ # stack_before_tokens_x[0][before_len[0]: before_len[0] + len(motion_tokens[0])] = 1
93
+ x = self.transformer.wte(stack_before_tokens_x.cuda())
94
+ # import pdb; pdb.set_trace()
95
+ for i in range(len(x)):
96
+ x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i].cuda()
97
+ else:
98
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
99
+ for block in self.transformer.h:
100
+ x = block(x, cos, sin, mask, input_pos)
101
+ x = self.transformer.ln_f(x)
102
+ return self.lm_head(x) # (b, t, vocab_size)
103
+
104
+ @classmethod
105
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
106
+ return cls(Config.from_name(name, **kwargs))
107
+
108
+ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
109
+ return build_rope_cache(
110
+ seq_len=self.max_seq_length,
111
+ n_elem=self.config.rope_n_elem,
112
+ device=device,
113
+ condense_ratio=self.config.rope_condense_ratio,
114
+ base=self.config.rope_base,
115
+ )
116
+
117
+ def set_kv_cache(
118
+ self,
119
+ batch_size: int,
120
+ rope_cache_length: Optional[int] = None,
121
+ device: Optional[torch.device] = None,
122
+ dtype: Optional[torch.dtype] = None,
123
+ ) -> None:
124
+ if rope_cache_length is None:
125
+ rope_cache_length = self.cos.size(-1)
126
+ max_seq_length = self.max_seq_length
127
+
128
+ # initialize the kv cache for all blocks
129
+ for block in self.transformer.h:
130
+ block.attn.kv_cache = block.attn.build_kv_cache(
131
+ batch_size, max_seq_length, rope_cache_length, device, dtype
132
+ )
133
+
134
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
135
+ # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
136
+ # for the kv-cache support (only during inference), we only create it in that situation
137
+ # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
138
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
139
+ self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
140
+
141
+ def clear_kv_cache(self) -> None:
142
+ self.mask_cache = None
143
+ for block in self.transformer.h:
144
+ block.attn.kv_cache = None
145
+
146
+
147
+ class Block(nn.Module):
148
+ def __init__(self, config: Config) -> None:
149
+ super().__init__()
150
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
151
+ self.attn = CausalSelfAttention(config)
152
+ self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
153
+ self.mlp = config.mlp_class(config)
154
+
155
+ self.config = config
156
+
157
+ def forward(
158
+ self,
159
+ x: torch.Tensor,
160
+ cos: torch.Tensor,
161
+ sin: torch.Tensor,
162
+ mask: Optional[torch.Tensor] = None,
163
+ input_pos: Optional[torch.Tensor] = None,
164
+ ) -> torch.Tensor:
165
+ n_1 = self.norm_1(x)
166
+ h = self.attn(n_1, cos, sin, mask, input_pos)
167
+ if self.config.parallel_residual:
168
+ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
169
+ x = self.mlp(n_2) + h + x
170
+ else:
171
+ if self.config.shared_attention_norm:
172
+ raise NotImplementedError(
173
+ "No checkpoint amongst the ones we support uses this configuration"
174
+ " (non-parallel residual and shared attention norm)."
175
+ )
176
+ x = h + x
177
+ x = self.mlp(self.norm_2(x)) + x
178
+ return x
179
+
180
+
181
+ class CausalSelfAttention(nn.Module):
182
+ def __init__(self, config: Config) -> None:
183
+ super().__init__()
184
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
185
+ # key, query, value projections for all heads, but in a batch
186
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
187
+ # output projection
188
+ self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
189
+ # disabled by default
190
+ self.kv_cache: Optional[KVCache] = None
191
+
192
+ self.config = config
193
+
194
+ def forward(
195
+ self,
196
+ x: torch.Tensor,
197
+ cos: torch.Tensor,
198
+ sin: torch.Tensor,
199
+ mask: Optional[torch.Tensor] = None,
200
+ input_pos: Optional[torch.Tensor] = None,
201
+ ) -> torch.Tensor:
202
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
203
+
204
+ qkv = self.attn(x)
205
+
206
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
207
+ q_per_kv = self.config.n_head // self.config.n_query_groups
208
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
209
+ qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
210
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
211
+
212
+ # split batched computation into three
213
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
214
+
215
+ # maybe repeat k and v for the non multi-head attention cases (MQA/GQA)
216
+ # training: flash attention requires it
217
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
218
+ if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
219
+ k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
220
+ v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
221
+
222
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
223
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
224
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
225
+
226
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
227
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
228
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
229
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
230
+
231
+ if input_pos is not None:
232
+ if not isinstance(self.kv_cache, KVCache):
233
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
234
+ k, v = self.kv_cache(input_pos, k, v)
235
+
236
+ y = self.scaled_dot_product_attention(q, k, v, mask)
237
+
238
+ y = y.reshape(B, T, C) # re-assemble all head outputs side by side
239
+
240
+ # output projection
241
+ return self.proj(y)
242
+
243
+ def scaled_dot_product_attention(
244
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
245
+ ) -> torch.Tensor:
246
+ scale = 1.0 / math.sqrt(self.config.head_size)
247
+ y = torch.nn.functional.scaled_dot_product_attention(
248
+ q, k, v, attn_mask=mask, dropout_p=0.0,
249
+ # scale=scale,
250
+ is_causal=mask is None
251
+ )
252
+ return y.transpose(1, 2)
253
+
254
+ def build_kv_cache(
255
+ self,
256
+ batch_size: int,
257
+ max_seq_length: int,
258
+ rope_cache_length: Optional[int] = None,
259
+ device: Optional[torch.device] = None,
260
+ dtype: Optional[torch.dtype] = None,
261
+ ) -> "KVCache":
262
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
263
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
264
+ if rope_cache_length is None:
265
+ if self.config.rotary_percentage != 1.0:
266
+ raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
267
+ k_shape = v_shape
268
+ else:
269
+ k_shape = (
270
+ batch_size,
271
+ heads,
272
+ max_seq_length,
273
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
274
+ )
275
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
276
+
277
+
278
+ class GptNeoxMLP(nn.Module):
279
+ def __init__(self, config: Config) -> None:
280
+ super().__init__()
281
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
282
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
283
+
284
+ self.config = config
285
+
286
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
287
+ x = self.fc(x)
288
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
289
+ return self.proj(x)
290
+
291
+
292
+ class LLaMAMLP(nn.Module):
293
+ def __init__(self, config: Config) -> None:
294
+ super().__init__()
295
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
296
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
297
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
298
+
299
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
300
+ x_fc_1 = self.fc_1(x)
301
+ x_fc_2 = self.fc_2(x)
302
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
303
+ return self.proj(x)
304
+
305
+
306
+ def build_rope_cache(
307
+ seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
308
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
309
+ """Enhanced Transformer with Rotary Position Embedding.
310
+
311
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
312
+ transformers/rope/__init__.py. MIT License:
313
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
314
+ """
315
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
316
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
317
+
318
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
319
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
320
+
321
+ # Calculate the product of position index and $\theta_i$
322
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
323
+
324
+ return torch.cos(idx_theta), torch.sin(idx_theta)
325
+
326
+
327
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
328
+ head_size = x.size(-1)
329
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
330
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
331
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
332
+ roped = (x * cos) + (rotated * sin)
333
+ return roped.type_as(x)
334
+
335
+
336
+ class KVCache(nn.Module):
337
+ def __init__(
338
+ self,
339
+ k_shape: Tuple[int, int, int, int],
340
+ v_shape: Tuple[int, int, int, int],
341
+ device: Optional[torch.device] = None,
342
+ dtype: Optional[torch.dtype] = None,
343
+ ) -> None:
344
+ super().__init__()
345
+ self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
346
+ self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
347
+
348
+ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
349
+ # move the buffer to the activation dtype for when AMP is used
350
+ self.k = self.k.to(k.dtype)
351
+ self.v = self.v.to(v.dtype)
352
+ # update the cache
353
+ k = self.k.index_copy_(2, input_pos, k)
354
+ v = self.v.index_copy_(2, input_pos, v)
355
+ return k, v
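For intuition about `build_rope_cache` and `apply_rope` above, the following is a self-contained sketch of the same math with toy, arbitrary sizes: each position applies a 2D rotation to (x1, x2) pairs of the head dimension, so the per-position vector norms are preserved.

import torch

def rope_cache(seq_len, n_elem, base=10000):
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    idx_theta = torch.outer(torch.arange(seq_len).float(), theta).repeat(1, 2)
    return torch.cos(idx_theta), torch.sin(idx_theta)

def apply_rope(x, cos, sin):
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

B, n_head, T, head_size = 2, 4, 8, 16
cos, sin = rope_cache(T, head_size)           # (T, head_size)
q = torch.randn(B, n_head, T, head_size)
q_roped = apply_rope(q, cos, sin)

# Rotations preserve the norm of each query vector.
print(torch.allclose(q_roped.norm(dim=-1), q.norm(dim=-1), atol=1e-5))  # True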
lit_gpt/packed_dataset.py ADDED
@@ -0,0 +1,235 @@
1
+ # Very loosely inspired by indexed_dataset in Fairseq, Megatron
2
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
3
+
4
+
5
+ import os
6
+ import random
7
+ import struct
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import IterableDataset, get_worker_info
12
+
13
+ dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}
14
+
15
+
16
+ def code(dtype):
17
+ for k in dtypes:
18
+ if dtypes[k] == dtype:
19
+ return k
20
+ raise ValueError(dtype)
21
+
22
+
23
+ HDR_MAGIC = b"LITPKDS"
24
+ HDR_SIZE = 24 # bytes
25
+
26
+
27
+ class PackedDataset(IterableDataset):
28
+ def __init__(
29
+ self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
30
+ ):
31
+ self._filenames = filenames
32
+ self._n_chunks = n_chunks
33
+ self._block_size = block_size
34
+ self._seed = seed
35
+ self._shuffle = shuffle
36
+ self._wrap = wrap
37
+ self._num_processes = num_processes
38
+ self._process_rank = process_rank
39
+
40
+ def __iter__(self):
41
+ worker_info = get_worker_info()
42
+ num_workers = worker_info.num_workers if worker_info is not None else 1
43
+ worker_id = worker_info.id if worker_info is not None else 0
44
+ num_shards = num_workers * self._num_processes
45
+ shard_id = self._process_rank * num_workers + worker_id
46
+
47
+ max_num_files = len(self._filenames) // num_shards * num_shards
48
+ filenames = self._filenames[shard_id:max_num_files:num_shards]
49
+
50
+ return PackedDatasetIterator(
51
+ filenames=filenames,
52
+ n_chunks=self._n_chunks,
53
+ block_size=self._block_size,
54
+ seed=self._seed,
55
+ shuffle=self._shuffle,
56
+ wrap=self._wrap,
57
+ )
58
+
59
+
60
+ class PackedDatasetBuilder(object):
61
+ def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
62
+ if dtype == "auto":
63
+ if vocab_size is None:
64
+ raise ValueError("vocab_size cannot be None when dtype='auto'")
65
+ if vocab_size is not None and vocab_size < 65500:
66
+ self._dtype = np.uint16
67
+ else:
68
+ self._dtype = np.int32
69
+ else:
70
+ self._dtype = dtype
71
+ self._counter = 0
72
+ self._chunk_size = chunk_size
73
+ self._outdir = outdir
74
+ self._prefix = prefix
75
+ self._sep_token = sep_token
76
+ self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
77
+ self._arr.fill(self._sep_token)
78
+ self._idx = 0
79
+ self._version = 1
80
+ self._filenames = []
81
+
82
+ def _write_chunk(self):
83
+ filename = f"{self._prefix}_{self._counter:010d}.bin"
84
+ filename = os.path.join(self._outdir, filename)
85
+
86
+ with open(filename, "wb") as f:
87
+ f.write(HDR_MAGIC)
88
+ f.write(struct.pack("<Q", self._version))
89
+ f.write(struct.pack("<B", code(self._dtype)))
90
+ f.write(struct.pack("<Q", self._chunk_size))
91
+ f.write(self._arr.tobytes(order="C"))
92
+
93
+ self._filenames.append(filename)
94
+ self._counter += 1
95
+ self._arr.fill(self._sep_token)
96
+ self._idx = 0
97
+
98
+ @property
99
+ def dtype(self):
100
+ return self._dtype
101
+
102
+ @property
103
+ def filenames(self):
104
+ return self._filenames.copy()
105
+
106
+ def add_array(self, arr):
107
+ while self._idx + arr.shape[0] > self._chunk_size:
108
+ part_len = self._chunk_size - self._idx
109
+ self._arr[self._idx : self._idx + part_len] = arr[:part_len]
110
+ self._write_chunk()
111
+ arr = arr[part_len:]
112
+
113
+ arr_len = arr.shape[0]
114
+ self._arr[self._idx : self._idx + arr_len] = arr
115
+ self._idx += arr_len
116
+
117
+ def write_reminder(self):
118
+ self._write_chunk()
119
+
120
+
121
+ class PackedDatasetIterator:
122
+ def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
123
+ self._seed = seed
124
+ self._shuffle = shuffle
125
+ self._rng = np.random.default_rng(seed) if shuffle else None
126
+ self._block_idxs = None
127
+
128
+ self._wrap = wrap
129
+
130
+ # TODO: instead of filenames, we could have a single text stream
131
+ # (or text file) with the sequence of all files to be
132
+ # fetched/loaded.
133
+ self._filenames = filenames
134
+ self._file_idx = 0
135
+
136
+ self._n_chunks = n_chunks
137
+
138
+ self._dtype = None
139
+ self._block_size = block_size
140
+ self._n_blocks = None
141
+
142
+ self._mmaps = []
143
+ self._buffers = []
144
+
145
+ self._block_idxs = []
146
+ self._curr_idx = 0
147
+
148
+ self._load_n_chunks()
149
+
150
+ def _read_header(self, path):
151
+ with open(path, "rb") as f:
152
+ magic = f.read(len(HDR_MAGIC))
153
+ assert magic == HDR_MAGIC, "File doesn't match expected format."
154
+ version = struct.unpack("<Q", f.read(8))
155
+ assert version == (1,)
156
+ (dtype_code,) = struct.unpack("<B", f.read(1))
157
+ dtype = dtypes[dtype_code]
158
+ (chunk_size,) = struct.unpack("<Q", f.read(8))
159
+ return dtype, chunk_size
160
+
161
+ def _close_mmaps(self):
162
+ for mmap in self._mmaps:
163
+ mmap._mmap.close()
164
+
165
+ def _load_n_chunks(self):
166
+ self._close_mmaps()
167
+ self._mmaps = []
168
+ self._buffers = []
169
+
170
+ if self._n_chunks > len(self._filenames[self._file_idx :]):
171
+ if not self._wrap:
172
+ raise StopIteration
173
+ self._file_idx = 0
174
+
175
+ for i in range(self._n_chunks):
176
+ filename = self._filenames[self._file_idx + i]
177
+ if self._dtype is None:
178
+ self._dtype, self._chunk_size = self._read_header(filename)
179
+ self._n_blocks = self._chunk_size // self._block_size
180
+ # TODO: check header matches with previous files
181
+ mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
182
+ self._mmaps.append(mmap)
183
+ self._buffers.append(memoryview(mmap))
184
+
185
+ self._file_idx += self._n_chunks
186
+ n_all_blocks = self._n_chunks * self._n_blocks
187
+
188
+ self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
189
+
190
+ self._curr_idx = 0
191
+
192
+ def __del__(self):
193
+ self._close_mmaps()
194
+ del self._mmaps
195
+ del self._buffers
196
+
197
+ def __iter__(self):
198
+ return self
199
+
200
+ def __next__(self):
201
+ if self._curr_idx >= len(self._block_idxs):
202
+ self._load_n_chunks()
203
+ # TODO: trigger fetching next next n_chunks if remote
204
+ block_idx = self._block_idxs[self._curr_idx]
205
+ chunk_id = block_idx // self._n_blocks
206
+ buffer = self._buffers[chunk_id]
207
+ elem_id = (block_idx % self._n_blocks) * self._block_size
208
+ offset = np.dtype(self._dtype).itemsize * elem_id
209
+ arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
210
+ self._curr_idx += 1
211
+ return torch.from_numpy(arr.astype(np.int64))
212
+
213
+
214
+ class CombinedDataset(IterableDataset):
215
+ def __init__(self, datasets, seed, weights=None):
216
+ self._seed = seed
217
+ self._datasets = datasets
218
+ self._weights = weights
219
+ n_datasets = len(datasets)
220
+ if weights is None:
221
+ self._weights = [1 / n_datasets] * n_datasets
222
+
223
+ def __iter__(self):
224
+ return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
225
+
226
+
227
+ class CombinedDatasetIterator:
228
+ def __init__(self, datasets, seed, weights):
229
+ self._datasets = [iter(el) for el in datasets]
230
+ self._weights = weights
231
+ self._rng = random.Random(seed)
232
+
233
+ def __next__(self):
234
+ (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
235
+ return next(dataset)
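The on-disk chunk format used by PackedDatasetBuilder above is a 24-byte header (magic, version, dtype code, chunk size) followed by the raw token buffer. The sketch below writes and re-reads one such chunk by hand; the file name and sizes are illustrative assumptions, and dtype code 8 corresponds to np.uint16 in the `dtypes` table above.

import os
import struct
import tempfile
import numpy as np

HDR_MAGIC = b"LITPKDS"
chunk_size, sep_token = 16, 0
arr = np.full(chunk_size, sep_token, dtype=np.uint16)
arr[:5] = np.array([11, 12, 13, 14, 15], dtype=np.uint16)

path = os.path.join(tempfile.mkdtemp(), "toy_0000000000.bin")
with open(path, "wb") as f:
    f.write(HDR_MAGIC)                     # 7 bytes of magic
    f.write(struct.pack("<Q", 1))          # version
    f.write(struct.pack("<B", 8))          # dtype code 8 -> np.uint16
    f.write(struct.pack("<Q", chunk_size)) # number of tokens in the chunk
    f.write(arr.tobytes(order="C"))

with open(path, "rb") as f:
    assert f.read(len(HDR_MAGIC)) == HDR_MAGIC
    version = struct.unpack("<Q", f.read(8))[0]
    dtype_code = struct.unpack("<B", f.read(1))[0]
    n = struct.unpack("<Q", f.read(8))[0]
    data = np.frombuffer(f.read(), dtype=np.uint16, count=n)

print(version, dtype_code, data[:5])       # 1 8 [11 12 13 14 15]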
lit_gpt/rmsnorm.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+
3
+
4
+ class RMSNorm(torch.nn.Module):
5
+ """Root Mean Square Layer Normalization.
6
+
7
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
8
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
9
+ """
10
+
11
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
12
+ super().__init__()
13
+ self.weight = torch.nn.Parameter(torch.ones(size))
14
+ self.eps = eps
15
+ self.dim = dim
16
+
17
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
+ dtype = x.dtype
19
+ x = x.float()
20
+ # NOTE: the original RMSNorm paper implementation is not equivalent
21
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
22
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
23
+ return (self.weight * x_normed).to(dtype=dtype)
24
+
25
+ def reset_parameters(self) -> None:
26
+ torch.nn.init.ones_(self.weight)
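As a quick numeric illustration of the forward above (and of how RMSNorm differs from LayerNorm), the sketch below uses arbitrary shapes: RMSNorm rescales each row to unit root-mean-square but, unlike LayerNorm, does not subtract the mean.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
eps = 1e-5

rms = x * torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)   # RMSNorm core (weight omitted)
ln = (x - x.mean(dim=-1, keepdim=True)) * torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)

print(rms.pow(2).mean(dim=-1))        # ~1.0 per row: unit root-mean-square
print(rms.mean(dim=-1))               # generally non-zero: no centering
print(ln.mean(dim=-1).abs().max())    # ~0: LayerNorm centers as well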
lit_gpt/speed_monitor.py ADDED
@@ -0,0 +1,425 @@
1
+ import time
2
+ from collections import deque
3
+ from contextlib import nullcontext
4
+ from typing import Any, Callable, Deque, Dict, Optional
5
+
6
+ import torch
7
+ from lightning import Callback, Fabric, LightningModule, Trainer
8
+ from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
9
+ from lightning.fabric.plugins import (
10
+ BitsandbytesPrecision,
11
+ DoublePrecision,
12
+ FSDPPrecision,
13
+ HalfPrecision,
14
+ MixedPrecision,
15
+ Precision,
16
+ TransformerEnginePrecision,
17
+ XLAPrecision,
18
+ )
19
+ from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
20
+ from lightning.pytorch.plugins import (
21
+ DoublePrecisionPlugin,
22
+ FSDPPrecisionPlugin,
23
+ HalfPrecisionPlugin,
24
+ MixedPrecisionPlugin,
25
+ XLAPrecisionPlugin,
26
+ )
27
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
28
+ from torch.utils.flop_counter import FlopCounterMode
29
+
30
+ from lit_gpt import GPT
31
+ from lit_gpt.utils import num_parameters
32
+
33
+ GPU_AVAILABLE_FLOPS = {
34
+ # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
35
+ # nvidia publishes spec sheet with a 2x sparsity factor
36
+ "h100-sxm": {
37
+ torch.float64: 67e12,
38
+ torch.float32: 67e12,
39
+ torch.bfloat16: 1.979e15 / 2,
40
+ torch.float16: 1.979e15 / 2,
41
+ torch.int8: 3.958e15 / 2,
42
+ },
43
+ "h100-pcie": {
44
+ torch.float64: 51e12,
45
+ torch.float32: 51e12,
46
+ torch.bfloat16: 1.513e15 / 2,
47
+ torch.float16: 1.513e15 / 2,
48
+ torch.int8: 3.026e15 / 2,
49
+ },
50
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
51
+ # sxm and pcie have same flop counts
52
+ "a100": {torch.float64: 19.5e12, torch.float32: 19.5e12, torch.bfloat16: 312e12, torch.float16: 312e12},
53
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
54
+ "a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12},
55
+ # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
56
+ "v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12},
57
+ "v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12},
58
+ "v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12},
59
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
60
+ # sxm and pcie have same flop counts
61
+ "t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12},
62
+ # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
63
+ "quadro rtx 5000": {torch.float32: 11.2e12, torch.float16: 89.2e12},
64
+ }
65
+
66
+ TPU_AVAILABLE_FLOPS = {
67
+ # flop count for each TPU generation is the same for all precisions
68
+ # since bfloat16 precision is always used for performing matrix operations
69
+ # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
70
+ # source: https://arxiv.org/pdf/1907.10701.pdf
71
+ "v2": 45e12,
72
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
73
+ "v3": 123e12,
74
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
75
+ "v4": 275e12,
76
+ # source: https://cloud.google.com/tpu/docs/v5e-training
77
+ "v5litepod": 197e12,
78
+ }
79
+
80
+
81
+ def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]:
82
+ if device.type == "cuda":
83
+ device_name = torch.cuda.get_device_name(device).lower()
84
+ if "h100" in device_name and "hbm3" in device_name:
85
+ device_name = "h100-sxm"
86
+ elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
87
+ device_name = "h100-pcie"
88
+ elif "a100" in device_name:
89
+ device_name = "a100"
90
+ elif "a10g" in device_name:
91
+ device_name = "a10g"
92
+ elif "v100-sxm" in device_name:
93
+ device_name = "v100-sxm"
94
+ elif "v100-pcie" in device_name:
95
+ device_name = "v100-pcie"
96
+ elif "t4" in device_name:
97
+ device_name = "t4"
98
+ elif "quadro rtx 5000" in device_name:
99
+ device_name = "quadro rtx 5000"
100
+ else:
101
+ device_name = None
102
+
103
+ if device_name is not None:
104
+ try:
105
+ return int(GPU_AVAILABLE_FLOPS[device_name][dtype])
106
+ except KeyError:
107
+ raise KeyError(
108
+ f"flop count not found for {device_name} with dtype: {dtype}; "
109
+ "MFU cannot be calculated and reported."
110
+ )
111
+ elif device.type == "xla":
112
+ if _XLA_GREATER_EQUAL_2_1:
113
+ from torch_xla._internal import tpu
114
+ else:
115
+ from torch_xla.experimental import tpu
116
+
117
+ device_name = tpu.get_tpu_env()["TYPE"].lower()
118
+ try:
119
+ return int(TPU_AVAILABLE_FLOPS[device_name])
120
+ except KeyError:
121
+ raise KeyError(
122
+ f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported."
123
+ )
124
+
125
+ return None
126
+
127
+
128
+ # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py
129
+
130
+
131
+ class SpeedMonitorBase:
132
+ """Logs the training throughput and utilization.
133
+
134
+ +-------------------------------------+-----------------------------------------------------------+
135
+ | Key | Logged data |
136
+ +=====================================+===========================================================+
137
+ | | Rolling average (over `window_size` most recent |
138
+ | `throughput/batches_per_sec` | batches) of the number of batches processed per second |
139
+ | | |
140
+ +-------------------------------------+-----------------------------------------------------------+
141
+ | | Rolling average (over `window_size` most recent |
142
+ | `throughput/samples_per_sec` | batches) of the number of samples processed per second |
143
+ | | |
144
+ +-------------------------------------+-----------------------------------------------------------+
145
+ | | Rolling average (over `window_size` most recent |
146
+ | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
147
+ | | This may include padding depending on dataset |
148
+ +-------------------------------------+-----------------------------------------------------------+
149
+ | | Estimates flops by `flops_per_batch * batches_per_sec` |
150
+ | `throughput/flops_per_sec` | |
151
+ | | |
152
+ +-------------------------------------+-----------------------------------------------------------+
153
+ | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
154
+ +-------------------------------------+-----------------------------------------------------------+
155
+ | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
156
+ +-------------------------------------+-----------------------------------------------------------+
157
+ | | `throughput/tokens_per_sec` divided by world size. This |
158
+ | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset |
159
+ | | |
160
+ +-------------------------------------+-----------------------------------------------------------+
161
+ | | `throughput/flops_per_sec` divided by world size. Only |
162
+ | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
163
+ | | |
164
+ +-------------------------------------+-----------------------------------------------------------+
165
+ | | `throughput/device/flops_per_sec` divided by world size. |
166
+ | `throughput/device/mfu` | |
167
+ | | |
168
+ +-------------------------------------+-----------------------------------------------------------+
169
+ | `time/train` | Total elapsed training time |
170
+ +-------------------------------------+-----------------------------------------------------------+
171
+ | `time/val` | Total elapsed validation time |
172
+ +-------------------------------------+-----------------------------------------------------------+
173
+ | `time/total` | Total elapsed time (time/train + time/val) |
174
+ +-------------------------------------+-----------------------------------------------------------+
175
+
176
+ Notes:
177
+ - The implementation assumes that devices are homogeneous as it normalizes by the world size.
178
+ - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
179
+ batches/sec to measure throughput under this circumstance.
180
+ - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
181
+ There is no widespread, realistic, and reliable implementation to compute them.
182
+ We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which
183
+ will almost always be an overestimate when compared to the true value.
184
+
185
+ Args:
186
+ window_size (int, optional): Number of batches to use for a rolling average of throughput.
187
+ Defaults to 100.
188
+ time_unit (str, optional): Time unit to use for `time` logging. Can be one of
189
+ 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ flops_available: float,
195
+ log_dict: Callable[[Dict, int], None],
196
+ window_size: int = 100,
197
+ time_unit: str = "hours",
198
+ ):
199
+ self.flops_available = flops_available
200
+ self.log_dict = log_dict
201
+
202
+ # Track the batch num samples and wct to compute throughput over a window of batches
203
+ self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
204
+ self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
205
+ self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
206
+ self.history_flops: Deque[int] = deque(maxlen=window_size + 1)
207
+
208
+ self.divider = 1
209
+ if time_unit == "seconds":
210
+ self.divider = 1
211
+ elif time_unit == "minutes":
212
+ self.divider = 60
213
+ elif time_unit == "hours":
214
+ self.divider = 60 * 60
215
+ elif time_unit == "days":
216
+ self.divider = 60 * 60 * 24
217
+ else:
218
+ raise ValueError(
219
+ f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
220
+ )
221
+
222
+ # Keep track of time spent evaluating
223
+ self.total_eval_wct = 0.0
224
+ self.step = -1
225
+
226
+ def on_train_batch_end(
227
+ self,
228
+ samples: int, # total samples seen (per device)
229
+ train_elapsed: float, # total training time (seconds)
230
+ world_size: int,
231
+ flops_per_batch: Optional[int] = None, # (per device)
232
+ lengths: Optional[int] = None, # total length of the samples seen (per device)
233
+ ) -> None:
234
+ self.step += 1
235
+ step = self.step
236
+ metrics = {}
237
+
238
+ self.history_samples.append(samples)
239
+ if lengths is not None:
240
+ self.history_lengths.append(lengths)
241
+ # if lengths are passed, there should be as many values as samples
242
+ assert len(self.history_samples) == len(self.history_lengths)
243
+ self.history_wct.append(train_elapsed)
244
+ if len(self.history_wct) == self.history_wct.maxlen:
245
+ elapsed_batches = len(self.history_samples) - 1
246
+ elapsed_samples = self.history_samples[-1] - self.history_samples[0]
247
+ elapsed_wct = self.history_wct[-1] - self.history_wct[0]
248
+ samples_per_sec = elapsed_samples * world_size / elapsed_wct
249
+ dev_samples_per_sec = elapsed_samples / elapsed_wct
250
+ metrics.update(
251
+ {
252
+ "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
253
+ "throughput/samples_per_sec": samples_per_sec,
254
+ "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
255
+ "throughput/device/samples_per_sec": dev_samples_per_sec,
256
+ }
257
+ )
258
+ if lengths is not None:
259
+ elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
260
+ avg_length = elapsed_lengths / elapsed_batches
261
+ metrics.update(
262
+ {
263
+ "throughput/tokens_per_sec": samples_per_sec * avg_length,
264
+ "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
265
+ }
266
+ )
267
+
268
+ if flops_per_batch is not None:
269
+ # sum of flops per batch across ranks
270
+ self.history_flops.append(flops_per_batch * world_size)
271
+ if len(self.history_flops) == self.history_flops.maxlen:
272
+ elapsed_flops = sum(self.history_flops) - self.history_flops[0]
273
+ elapsed_wct = self.history_wct[-1] - self.history_wct[0]
274
+ flops_per_sec = elapsed_flops / elapsed_wct
275
+ device_flops_per_sec = flops_per_sec / world_size
276
+ metrics.update(
277
+ {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
278
+ )
279
+ if self.flops_available:
280
+ metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available
281
+
282
+ metrics.update(
283
+ {
284
+ "time/train": train_elapsed / self.divider,
285
+ "time/val": self.total_eval_wct / self.divider,
286
+ "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
287
+ "samples": samples,
288
+ }
289
+ )
290
+
291
+ self.log_dict(metrics, step)
292
+
293
+ def eval_end(self, eval_elapsed: float) -> None:
294
+ self.total_eval_wct += eval_elapsed # seconds
295
+
296
+
297
+ def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
298
+ if isinstance(plugin, BitsandbytesPrecision):
299
+ return plugin.dtype
300
+ if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)):
301
+ return plugin._desired_input_dtype
302
+ if isinstance(plugin, MixedPrecisionPlugin):
303
+ return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
304
+ if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)):
305
+ return torch.double
306
+ if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)):
307
+ return plugin._desired_dtype
308
+ if isinstance(plugin, TransformerEnginePrecision):
309
+ return torch.int8
310
+ if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)):
311
+ return plugin.mixed_precision_config.reduce_dtype
312
+ if isinstance(plugin, Precision):
313
+ return torch.float32
314
+ raise NotImplementedError(plugin)
315
+
316
+
317
+ class SpeedMonitorFabric(SpeedMonitorBase):
318
+ def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
319
+ dtype = plugin_to_compute_dtype(fabric.strategy.precision)
320
+ flops_available = get_flops_available(fabric.device, dtype)
321
+ super().__init__(flops_available, fabric.log_dict, *args, **kwargs)
322
+
323
+ @fabric_rank_zero_only
324
+ def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
325
+ super().on_train_batch_end(*args, **kwargs)
326
+
327
+
328
+ class SpeedMonitorCallback(Callback):
329
+ def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
330
+ super().__init__()
331
+ self.speed_monitor: Optional[SpeedMonitorBase] = None
332
+ self.speed_monitor_kwargs = kwargs
333
+ self.length_fn = length_fn
334
+ self.batch_size = batch_size
335
+ self.eval_t0: int = 0
336
+ self.train_t0: int = 0
337
+ self.total_lengths: int = 0
338
+
339
+ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
340
+ if self.speed_monitor is not None:
341
+ return # already setup
342
+ dtype = plugin_to_compute_dtype(trainer.precision_plugin)
343
+ flops_available = get_flops_available(trainer.strategy.root_device, dtype)
344
+ self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)
345
+
346
+ @trainer_rank_zero_only
347
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
348
+ if trainer.fit_loop._should_accumulate():
349
+ return
350
+
351
+ self.train_t0 = time.perf_counter()
352
+
353
+ @trainer_rank_zero_only
354
+ def on_train_batch_end(
355
+ self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
356
+ ) -> None:
357
+ self.total_lengths += self.length_fn(batch)
358
+ if trainer.fit_loop._should_accumulate():
359
+ return
360
+ train_elapsed = time.perf_counter() - self.train_t0
361
+ assert self.speed_monitor is not None
362
+ iter_num = trainer.fit_loop.total_batch_idx
363
+ assert (measured_flops := pl_module.measured_flops) is not None
364
+ self.speed_monitor.on_train_batch_end(
365
+ (iter_num + 1) * self.batch_size,
366
+ train_elapsed,
367
+ # this assumes that device FLOPs are the same and that all devices have the same batch size
368
+ trainer.world_size,
369
+ flops_per_batch=measured_flops,
370
+ lengths=self.total_lengths,
371
+ )
372
+
373
+ @trainer_rank_zero_only
374
+ def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
375
+ self.eval_t0 = time.perf_counter()
376
+
377
+ @trainer_rank_zero_only
378
+ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
379
+ eval_elapsed = time.perf_counter() - self.eval_t0
380
+ assert self.speed_monitor is not None
381
+ self.speed_monitor.eval_end(eval_elapsed)
382
+
383
+
384
+ def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
385
+ flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
386
+ # this assumes that all samples have a fixed length equal to the block size
387
+ # which is most likely false during finetuning
388
+ flops_per_seq = flops_per_token * max_seq_length
389
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
390
+ return flops_per_seq + attn_flops_per_seq
391
+
392
+
393
+ def estimate_flops(model: GPT) -> int:
394
+ """Measures estimated FLOPs for MFU.
395
+
396
+ Refs:
397
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
398
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
399
+ """
400
+ # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
401
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
402
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
403
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
404
+ n_trainable_params = num_parameters(model, requires_grad=True)
405
+ trainable_flops = flops_per_param(
406
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
407
+ )
408
+ # forward + backward + gradients (assumes no gradient accumulation)
409
+ ops_per_step = 3 if model.training else 1
410
+ n_frozen_params = num_parameters(model, requires_grad=False)
411
+ frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
412
+ # forward + backward
413
+ frozen_ops_per_step = 2 if model.training else 1
414
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
415
+
416
+
417
+ def measure_flops(model: GPT, x: torch.Tensor) -> int:
418
+ """Measures real FLOPs for HFU"""
419
+ flop_counter = FlopCounterMode(model, display=False)
420
+ ctx = nullcontext() if model.training else torch.no_grad()
421
+ with ctx, flop_counter:
422
+ y = model(x)
423
+ if model.training:
424
+ y.sum().backward()
425
+ return flop_counter.get_total_flops()
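The `throughput/device/mfu` metric above is simply the achieved FLOPs per second per device divided by the hardware's peak. Below is a back-of-the-envelope sketch using the `flops_per_param` formula from this file and the A100 bfloat16 entry of `GPU_AVAILABLE_FLOPS`; the model size, micro-batch size, and step time are made-up illustration values, not measurements.

def flops_per_param(max_seq_length, n_layer, n_embd, n_params):
    flops_per_token = 2 * n_params
    flops_per_seq = flops_per_token * max_seq_length
    attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length ** 2))
    return flops_per_seq + attn_flops_per_seq

n_params, n_layer, n_embd, seq_len = int(7e9), 32, 4096, 2048
fwd_flops = flops_per_param(seq_len, n_layer, n_embd, n_params)
train_flops_per_seq = 3 * fwd_flops            # forward + backward + grads, as in estimate_flops

micro_batch_size, step_time_s = 4, 2.5          # assumed per-device numbers
achieved_flops_per_sec = train_flops_per_seq * micro_batch_size / step_time_s

a100_bf16_peak = 312e12                         # from GPU_AVAILABLE_FLOPS["a100"]
print(f"MFU ~ {achieved_flops_per_sec / a100_bf16_peak:.2%}")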
lit_gpt/tokenizer.py ADDED
@@ -0,0 +1,103 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ class Tokenizer:
9
+ def __init__(self, checkpoint_dir: Path) -> None:
10
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
11
+ self.bos_id = None
12
+ self.eos_id = None
13
+
14
+ # some checkpoints have both files, `.model` takes precedence
15
+ if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
16
+ from sentencepiece import SentencePieceProcessor
17
+
18
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
19
+ self.backend = "sentencepiece"
20
+ self.bos_id = self.processor.bos_id()
21
+ self.eos_id = self.processor.eos_id()
22
+
23
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
30
+ with open(special_tokens_path) as fp:
31
+ config = json.load(fp)
32
+ bos_token = config.get("bos_token")
33
+ self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
34
+ eos_token = config.get("eos_token")
35
+ self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
36
+ if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
37
+ with open(special_tokens_path) as fp:
38
+ config = json.load(fp)
39
+ if self.bos_id is None:
40
+ self.bos_id = config.get("bos_token_id")
41
+ if self.eos_id is None:
42
+ self.eos_id = config.get("eos_token_id")
43
+ else:
44
+ raise NotImplementedError
45
+
46
+ @property
47
+ def vocab_size(self) -> int:
48
+ if self.backend == "huggingface":
49
+ return self.processor.get_vocab_size(with_added_tokens=False)
50
+ if self.backend == "sentencepiece":
51
+ return self.processor.vocab_size()
52
+ raise RuntimeError
53
+
54
+ def token_to_id(self, token: str) -> int:
55
+ if self.backend == "huggingface":
56
+ id_ = self.processor.token_to_id(token)
57
+ elif self.backend == "sentencepiece":
58
+ id_ = self.processor.piece_to_id(token)
59
+ else:
60
+ raise RuntimeError
61
+ if id_ is None:
62
+ raise ValueError(f"token {token!r} not found in the collection.")
63
+ return id_
64
+
65
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
66
+ if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
67
+ return False
68
+ with open(tokenizer_config_path) as fp:
69
+ config = json.load(fp)
70
+ if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")):
71
+ return True
72
+ # some checkpoints use the Llama tokenizer but omit the add_bos_token field entirely; assume a BOS token in that case.
73
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
74
+ return config.get("add_bos_token") is None and config.get("tokenizer_class") == "LlamaTokenizer"
75
+
76
+ def encode(
77
+ self,
78
+ string: str,
79
+ device: Optional[torch.device] = None,
80
+ bos: Optional[bool] = None,
81
+ eos: bool = False,
82
+ max_length: int = -1,
83
+ ) -> torch.Tensor:
84
+ if self.backend == "huggingface":
85
+ tokens = self.processor.encode(string).ids
86
+ elif self.backend == "sentencepiece":
87
+ tokens = self.processor.encode(string)
88
+ else:
89
+ raise RuntimeError
90
+ if bos or (bos is None and self.use_bos):
91
+ bos_id = self.bos_id
92
+ if bos_id is None:
93
+ raise NotImplementedError("This tokenizer does not have a defined a bos token")
94
+ tokens = [bos_id] + tokens
95
+ if eos:
96
+ tokens = tokens + [self.eos_id]
97
+ if max_length > 0:
98
+ tokens = tokens[:max_length]
99
+ return torch.tensor(tokens, dtype=torch.int, device=device)
100
+
101
+ def decode(self, tensor: torch.Tensor) -> str:
102
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
103
+ return self.processor.decode(tokens)
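
A small usage sketch of the class above; the checkpoint directory is a placeholder, and the eos flag assumes the checkpoint actually defines an EOS token.

from pathlib import Path

from lit_gpt.tokenizer import Tokenizer

tokenizer = Tokenizer(Path("checkpoints/some-model"))      # placeholder directory
ids = tokenizer.encode("Hello world", bos=True, eos=True)  # 1-D tensor of token ids
print(tokenizer.vocab_size, ids)
print(tokenizer.decode(ids))                               # decodes back to text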
lit_gpt/utils.py ADDED
@@ -0,0 +1,311 @@
1
+ """Utility functions for training and inference."""
2
+ import math
3
+ import pickle
4
+ import sys
5
+ from contextlib import nullcontext
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from typing import ContextManager, Dict, List, Mapping, Optional, TypeVar, Union
9
+
10
+ import lightning as L
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.utils._device
14
+ from lightning.fabric.strategies import FSDPStrategy
15
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
16
+ from torch.serialization import normalize_storage_type
17
+
18
+
19
+ def find_multiple(n: int, k: int) -> int:
20
+ assert k > 0
21
+ if n % k == 0:
22
+ return n
23
+ return n + k - (n % k)
24
+
25
+
26
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
27
+ total = 0
28
+ for p in module.parameters():
29
+ if requires_grad is None or p.requires_grad == requires_grad:
30
+ if hasattr(p, "quant_state"):
31
+ # bitsandbytes 4bit layer support
32
+ total += math.prod(p.quant_state[1])
33
+ else:
34
+ total += p.numel()
35
+ return total
36
+
37
+
38
+ def gptq_quantization(enabled: bool = False) -> ContextManager:
39
+ if not enabled:
40
+ return nullcontext()
41
+
42
+ from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager
43
+
44
+ from quantize.gptq import ColBlockQuantizedLinear
45
+
46
+ class QuantizedLinear(ColBlockQuantizedLinear):
47
+ def __init__(self, *args, **kwargs):
48
+ super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
49
+
50
+ return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
51
+
52
+
53
+ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
54
+ files = {
55
+ "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
56
+ "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
57
+ "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
58
+ checkpoint_dir / "tokenizer.model"
59
+ ).is_file(),
60
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
61
+ }
62
+ if checkpoint_dir.is_dir():
63
+ if all(files.values()):
64
+ # we're good
65
+ return
66
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
67
+ else:
68
+ problem = " is not a checkpoint directory"
69
+
70
+ # list locally available checkpoints
71
+ available = list(Path("checkpoints").glob("*/*"))
72
+ if available:
73
+ options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
74
+ extra = f"\nYou have downloaded locally:{options}\n"
75
+ else:
76
+ extra = ""
77
+
78
+ error_message = (
79
+ f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
80
+ "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
81
+ f"{extra}\nSee all download options by running:\n python scripts/download.py"
82
+ )
83
+ print(error_message, file=sys.stderr)
84
+ raise SystemExit(1)
85
+
86
+
87
+ class SavingProxyForStorage:
88
+ def __init__(self, obj, saver, protocol_version=5):
89
+ self.protocol_version = protocol_version
90
+ self.saver = saver
91
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
92
+ raise TypeError(f"expected storage, not {type(obj)}")
93
+
94
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
95
+ if isinstance(obj, torch.storage.TypedStorage):
96
+ # PT upstream wants to deprecate this eventually...
97
+ storage = obj._untyped_storage
98
+ storage_type_str = obj._pickle_storage_type()
99
+ storage_type = getattr(torch, storage_type_str)
100
+ storage_numel = obj._size()
101
+ else:
102
+ storage = obj
103
+ storage_type = normalize_storage_type(type(obj))
104
+ storage_numel = storage.nbytes()
105
+
106
+ storage_key = saver._write_storage_and_return_key(storage)
107
+ location = torch.serialization.location_tag(storage)
108
+
109
+ self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
110
+
111
+ def __reduce_ex__(self, protocol_version):
112
+ assert False, "this should be handled with out of band"
113
+
114
+
115
+ class SavingProxyForTensor:
116
+ def __init__(self, tensor, saver, protocol_version=5):
117
+ self.protocol_version = protocol_version
118
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
119
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
120
+ # for Tensors with Python attributes
121
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
122
+ assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
123
+ storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
124
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
125
+ else:
126
+ (storage, *other_reduce_args) = reduce_args
127
+ assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
128
+ storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
129
+ self.reduce_args = (storage_proxy, *other_reduce_args)
130
+
131
+ def __reduce_ex__(self, protocol_version):
132
+ if protocol_version != self.protocol_version:
133
+ raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
134
+ return self.reduce_ret_fn, self.reduce_args
135
+
136
+
137
+ class IncrementalPyTorchPickler(pickle.Pickler):
138
+ def __init__(self, saver, *args, **kwargs):
139
+ super().__init__(*args, **kwargs)
140
+ self.storage_dtypes = {}
141
+ self.saver = saver
142
+ self.id_map = {}
143
+
144
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
145
+ def persistent_id(self, obj):
146
+ # FIXME: the docs say that persistent_id should only return a string
147
+ # but torch store returns tuples. This works only in the binary protocol
148
+ # see
149
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
150
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
151
+ if isinstance(obj, SavingProxyForStorage):
152
+ return obj.storage_info
153
+
154
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
155
+ if isinstance(obj, torch.storage.TypedStorage):
156
+ # TODO: Once we decide to break serialization FC, this case
157
+ # can be deleted
158
+ storage = obj._untyped_storage
159
+ storage_dtype = obj.dtype
160
+ storage_type_str = obj._pickle_storage_type()
161
+ storage_type = getattr(torch, storage_type_str)
162
+ storage_numel = obj._size()
163
+
164
+ else:
165
+ storage = obj
166
+ storage_dtype = torch.uint8
167
+ storage_type = normalize_storage_type(type(obj))
168
+ storage_numel = storage.nbytes()
169
+
170
+ # If storage is allocated, ensure that any other saved storages
171
+ # pointing to the same data all have the same dtype. If storage is
172
+ # not allocated, don't perform this check
173
+ if storage.data_ptr() != 0:
174
+ if storage.data_ptr() in self.storage_dtypes:
175
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
176
+ raise RuntimeError(
177
+ "Cannot save multiple tensors or storages that view the same data as different types"
178
+ )
179
+ else:
180
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
181
+
182
+ storage_key = self.id_map.get(storage._cdata)
183
+ if storage_key is None:
184
+ storage_key = self.saver._write_storage_and_return_key(storage)
185
+ self.id_map[storage._cdata] = storage_key
186
+ location = torch.serialization.location_tag(storage)
187
+
188
+ return ("storage", storage_type, storage_key, location, storage_numel)
189
+
190
+ return None
191
+
192
+
193
+ class incremental_save:
194
+ def __init__(self, name):
195
+ self.name = name
196
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
197
+ self.has_saved = False
198
+ self.next_key = 0
199
+
200
+ def __enter__(self):
201
+ return self
202
+
203
+ def store_early(self, tensor):
204
+ if isinstance(tensor, torch.Tensor):
205
+ return SavingProxyForTensor(tensor, self)
206
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
207
+
208
+ def save(self, obj):
209
+ if self.has_saved:
210
+ raise RuntimeError("have already saved")
211
+ # Write the pickle data for `obj`
212
+ data_buf = BytesIO()
213
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
214
+ pickler.dump(obj)
215
+ data_value = data_buf.getvalue()
216
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
217
+ self.has_saved = True
218
+
219
+ def _write_storage_and_return_key(self, storage):
220
+ if self.has_saved:
221
+ raise RuntimeError("have already saved")
222
+ key = self.next_key
223
+ self.next_key += 1
224
+ name = f"data/{key}"
225
+ if storage.device.type != "cpu":
226
+ storage = storage.cpu()
227
+ num_bytes = storage.nbytes()
228
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
229
+ return key
230
+
231
+ def __exit__(self, type, value, traceback):
232
+ self.zipfile.write_end_of_file()
233
+
234
+
235
+ T = TypeVar("T")
236
+
237
+
238
+ def chunked_cross_entropy(
239
+ logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
240
+ ) -> torch.Tensor:
241
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
242
+ # the memory usage in fine-tuning settings with a low number of parameters.
243
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
244
+ # the memory spike's magnitude
245
+
246
+ # lm_head was chunked (we are fine-tuning)
247
+ if isinstance(logits, list):
248
+ # don't want to chunk cross entropy
249
+ if chunk_size == 0:
250
+ logits = torch.cat(logits, dim=1)
251
+ logits = logits.reshape(-1, logits.size(-1))
252
+ targets = targets.reshape(-1)
253
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
254
+
255
+ # chunk cross entropy
256
+ logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
257
+ target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
258
+ loss_chunks = [
259
+ torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
260
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
261
+ ]
262
+ return torch.cat(loss_chunks).mean()
263
+
264
+ # no chunking at all
265
+ logits = logits.reshape(-1, logits.size(-1))
266
+ targets = targets.reshape(-1)
267
+ if chunk_size == 0:
268
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
269
+
270
+ # lm_head wasn't chunked, chunk cross entropy
271
+ logit_chunks = logits.split(chunk_size)
272
+ target_chunks = targets.split(chunk_size)
273
+ loss_chunks = [
274
+ torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
275
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
276
+ ]
277
+ return torch.cat(loss_chunks).mean()
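
A quick illustrative check of the chunked path above, assuming no ignored (-1) targets; the tensor shapes and chunk size are arbitrary demo values.

import torch

from lit_gpt.utils import chunked_cross_entropy

logits = torch.randn(2, 8, 16)           # (batch, time, vocab)
targets = torch.randint(0, 16, (2, 8))   # no -1 entries, so every position contributes
chunked = chunked_cross_entropy(logits, targets, chunk_size=4)
plain = torch.nn.functional.cross_entropy(logits.reshape(-1, 16), targets.reshape(-1))
assert torch.allclose(chunked, plain, atol=1e-6)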
278
+
279
+
280
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
281
+ for checkpoint_name, attribute_name in mapping.items():
282
+ full_checkpoint_name = prefix + checkpoint_name
283
+ if full_checkpoint_name in state_dict:
284
+ full_attribute_name = prefix + attribute_name
285
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
286
+ return state_dict
287
+
288
+
289
+ def get_default_supported_precision(training: bool) -> str:
290
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
291
+
292
+ Args:
293
+ training: `-mixed` or `-true` version of the precision to use
294
+
295
+ Returns:
296
+ default precision that is suitable for the task and is supported by the hardware
297
+ """
298
+ from lightning.fabric.accelerators import MPSAccelerator
299
+
300
+ if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
301
+ return "16-mixed" if training else "16-true"
302
+ return "bf16-mixed" if training else "bf16-true"
303
+
304
+
305
+ def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
306
+ if isinstance(fabric.strategy, FSDPStrategy):
307
+ fabric.load_raw(checkpoint_path, model, strict=strict)
308
+ else:
309
+ state_dict = lazy_load(checkpoint_path)
310
+ state_dict = state_dict.get("model", state_dict)
311
+ model.load_state_dict(state_dict, strict=strict)
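
A hedged sketch of how these utilities are typically wired together for inference; the checkpoint directory is a placeholder, and the `GPT`/`Config` classes are assumed to follow the upstream lit-gpt API (`Config.from_json`) used elsewhere in this repository.

from pathlib import Path

import lightning as L

from lit_gpt import GPT, Config
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint

checkpoint_dir = Path("checkpoints/some-model")   # placeholder
check_valid_checkpoint_dir(checkpoint_dir)        # fails fast with download instructions if files are missing
fabric = L.Fabric(devices=1, precision=get_default_supported_precision(training=False))
config = Config.from_json(checkpoint_dir / "lit_config.json")
model = fabric.setup(GPT(config))
load_checkpoint(fabric, model, checkpoint_dir / "lit_model.pth")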
lit_llama/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
2
+ from lit_llama.tokenizer import Tokenizer
lit_llama/adapter.py ADDED
@@ -0,0 +1,151 @@
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
4
+ https://arxiv.org/abs/2303.16199
5
+ """
6
+ # mypy: ignore-errors
7
+ import math
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+ import lit_llama.model as llama
14
+ from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP
15
+
16
+
17
+ @dataclass
18
+ class LLaMAConfig(llama.LLaMAConfig):
19
+ adapter_prompt_length: int = 10
20
+ adapter_start_layer: int = 2
21
+
22
+
23
+ class CausalSelfAttention(nn.Module):
24
+ """A modification of `lit_llama.model.CausalSelfAttention` that adds the attention
25
+ over the adaption prompt."""
26
+
27
+ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
28
+ super().__init__()
29
+ assert config.n_embd % config.n_head == 0
30
+
31
+ # key, query, value projections for all heads, but in a batch
32
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
33
+ # output projection
34
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
35
+
36
+ if block_idx >= config.adapter_start_layer:
37
+ # adapter embedding layer
38
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
39
+ # gate for adaption
40
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1))
41
+
42
+ self.n_head = config.n_head
43
+ self.n_embd = config.n_embd
44
+ self.block_size = config.block_size
45
+ self.block_idx = block_idx
46
+ self.adapter_prompt_length = config.adapter_prompt_length
47
+ self.adapter_start_layer = config.adapter_start_layer
48
+ self.rope_cache = None
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
52
+
53
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
54
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
55
+
56
+ head_size = C // self.n_head
57
+ k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
58
+ q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
59
+ v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
60
+
61
+ if self.rope_cache is None:
62
+ # cache for future forward calls
63
+ self.rope_cache = build_rope_cache(
64
+ seq_len=self.block_size,
65
+ n_elem=self.n_embd // self.n_head,
66
+ dtype=x.dtype,
67
+ device=x.device,
68
+ )
69
+
70
+ q = apply_rope(q, self.rope_cache)
71
+ k = apply_rope(k, self.rope_cache)
72
+
73
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
74
+ # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
75
+ # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
76
+ # att = F.softmax(att, dim=-1)
77
+ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
78
+
79
+ # efficient attention using Flash Attention CUDA kernels
80
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
81
+
82
+ if self.block_idx >= self.adapter_start_layer:
83
+ prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
84
+
85
+ aT = prefix.size(1)
86
+ _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)
87
+ ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
88
+ av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
89
+
90
+ amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
91
+ ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
92
+ y = y + self.gating_factor * ay
93
+
94
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
95
+
96
+ # output projection
97
+ y = self.c_proj(y)
98
+
99
+ return y
100
+
101
+
102
+ class Block(nn.Module):
103
+ """The implementation is identical to `lit_llama.model.Block` with the exception that
104
+ we replace the attention layer where adaption is implemented."""
105
+
106
+ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
107
+ super().__init__()
108
+ self.rms_1 = RMSNorm(config.n_embd)
109
+ self.attn = CausalSelfAttention(config, block_idx)
110
+ self.rms_2 = RMSNorm(config.n_embd)
111
+ self.mlp = MLP(config)
112
+
113
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
114
+ x = x + self.attn(self.rms_1(x))
115
+ x = x + self.mlp(self.rms_2(x))
116
+ return x
117
+
118
+
119
+ class LLaMA(llama.LLaMA):
120
+ """The implementation is identical to `lit_llama.model.LLaMA` with the exception that
121
+ the `Block` saves the layer index and passes it down to the attention layer."""
122
+
123
+ def __init__(self, config: LLaMAConfig) -> None:
124
+ nn.Module.__init__(self)
125
+ assert config.vocab_size is not None
126
+ assert config.block_size is not None
127
+ self.config = config
128
+
129
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
130
+ self.transformer = nn.ModuleDict(
131
+ dict(
132
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
133
+ h=nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
134
+ ln_f=RMSNorm(config.n_embd),
135
+ )
136
+ )
137
+
138
+ @classmethod
139
+ def from_name(cls, name: str):
140
+ return cls(LLaMAConfig.from_name(name))
141
+
142
+
143
+ def mark_only_adapter_as_trainable(model: LLaMA) -> None:
144
+ """Sets `requires_grad=False` for all non-adapter weights."""
145
+ for name, param in model.named_parameters():
146
+ param.requires_grad = "adapter_wte" in name or "gating_factor" in name
147
+
148
+
149
+ def adapter_state_from_state_dict(state_dict: dict) -> dict:
150
+ """Returns the model state dict with only the adapter weights for saving."""
151
+ return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name}
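
A hedged sketch of how the helpers above are meant to be used together: build the adapter-augmented model, freeze everything except the adapter weights, and save only those weights afterwards. The "7B" config name comes from lit_llama.model; the optimizer settings and output path are illustrative.

import torch

from lit_llama.adapter import LLaMA, adapter_state_from_state_dict, mark_only_adapter_as_trainable

model = LLaMA.from_name("7B")          # adapter-augmented model (randomly initialized here)
mark_only_adapter_as_trainable(model)  # only adapter_wte and gating_factor remain trainable
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-3)
# ... fine-tuning loop goes here ...
torch.save(adapter_state_from_state_dict(model.state_dict()), "adapter_weights.pth")  # placeholder path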
lit_llama/indexed_dataset.py ADDED
@@ -0,0 +1,588 @@
1
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of the FairSeq source tree.
7
+
8
+ # copied from fairseq/fairseq/data/indexed_dataset.py
9
+ # Removed IndexedRawTextDataset since it relied on Fairseq dictionary
10
+ # other slight modifications to remove fairseq dependencies
11
+ # Added document index to index file and made it accessible.
12
+ # An empty sentence no longer separates documents.
13
+
14
+ from functools import lru_cache
15
+ import os
16
+ import shutil
17
+ import struct
18
+ from itertools import accumulate
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+
24
+ def __best_fitting_dtype(vocab_size=None):
25
+ if vocab_size is not None and vocab_size < 65500:
26
+ return np.uint16
27
+ else:
28
+ return np.int32
29
+
30
+
31
+ def get_available_dataset_impl():
32
+ return ['lazy', 'cached', 'mmap']
33
+
34
+
35
+ def infer_dataset_impl(path):
36
+ if IndexedDataset.exists(path):
37
+ with open(index_file_path(path), 'rb') as f:
38
+ magic = f.read(8)
39
+ if magic == IndexedDataset._HDR_MAGIC:
40
+ return 'cached'
41
+ elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
42
+ return 'mmap'
43
+ else:
44
+ return None
45
+ else:
46
+ print(f"Dataset does not exist: {path}")
47
+ print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
48
+ return None
49
+
50
+
51
+ def make_builder(out_file, impl, vocab_size=None):
52
+ if impl == 'mmap':
53
+ return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
54
+ else:
55
+ return IndexedDatasetBuilder(out_file)
56
+
57
+
58
+ def make_dataset(path, impl, skip_warmup=False):
59
+ if not IndexedDataset.exists(path):
60
+ print(f"Dataset does not exist: {path}")
61
+ print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
62
+ return None
63
+ if impl == 'infer':
64
+ impl = infer_dataset_impl(path)
65
+ if impl == 'lazy' and IndexedDataset.exists(path):
66
+ return IndexedDataset(path)
67
+ elif impl == 'cached' and IndexedDataset.exists(path):
68
+ return IndexedCachedDataset(path)
69
+ elif impl == 'mmap' and MMapIndexedDataset.exists(path):
70
+ return MMapIndexedDataset(path, skip_warmup)
71
+ print(f"Unknown dataset implementation: {impl}")
72
+ return None
73
+
74
+
75
+ def dataset_exists(path, impl):
76
+ if impl == 'mmap':
77
+ return MMapIndexedDataset.exists(path)
78
+ else:
79
+ return IndexedDataset.exists(path)
80
+
81
+
82
+ def read_longs(f, n):
83
+ a = np.empty(n, dtype=np.int64)
84
+ f.readinto(a)
85
+ return a
86
+
87
+
88
+ def write_longs(f, a):
89
+ f.write(np.array(a, dtype=np.int64))
90
+
91
+
92
+ dtypes = {
93
+ 1: np.uint8,
94
+ 2: np.int8,
95
+ 3: np.int16,
96
+ 4: np.int32,
97
+ 5: np.int64,
98
+ 6: np.float32,
99
+ 7: np.float64,
100
+ 8: np.uint16
101
+ }
102
+
103
+
104
+ def code(dtype):
105
+ for k in dtypes.keys():
106
+ if dtypes[k] == dtype:
107
+ return k
108
+ raise ValueError(dtype)
109
+
110
+
111
+ def index_file_path(prefix_path):
112
+ return prefix_path + '.idx'
113
+
114
+
115
+ def data_file_path(prefix_path):
116
+ return prefix_path + '.bin'
117
+
118
+
119
+ def create_doc_idx(sizes):
120
+ doc_idx = [0]
121
+ for i, s in enumerate(sizes):
122
+ if s == 0:
123
+ doc_idx.append(i + 1)
124
+ return doc_idx
125
+
126
+
127
+ class IndexedDataset(torch.utils.data.Dataset):
128
+ """Loader for IndexedDataset"""
129
+ _HDR_MAGIC = b'TNTIDX\x00\x00'
130
+
131
+ def __init__(self, path):
132
+ super().__init__()
133
+ self.path = path
134
+ self.data_file = None
135
+ self.read_index(path)
136
+
137
+ def read_index(self, path):
138
+ with open(index_file_path(path), 'rb') as f:
139
+ magic = f.read(8)
140
+ assert magic == self._HDR_MAGIC, (
141
+ 'Index file doesn\'t match expected format. '
142
+ 'Make sure that --dataset-impl is configured properly.'
143
+ )
144
+ version = f.read(8)
145
+ assert struct.unpack('<Q', version) == (1,)
146
+ code, self.element_size = struct.unpack('<QQ', f.read(16))
147
+ self.dtype = dtypes[code]
148
+ self._len, self.s = struct.unpack('<QQ', f.read(16))
149
+ self.doc_count = struct.unpack('<Q', f.read(8))
150
+ self.dim_offsets = read_longs(f, self._len + 1)
151
+ self.data_offsets = read_longs(f, self._len + 1)
152
+ self.sizes = read_longs(f, self.s)
153
+ self.doc_idx = read_longs(f, self.doc_count)
154
+
155
+ def read_data(self, path):
156
+ self.data_file = open(data_file_path(path), 'rb', buffering=0)
157
+
158
+ def check_index(self, i):
159
+ if i < 0 or i >= self._len:
160
+ raise IndexError('index out of range')
161
+
162
+ def __del__(self):
163
+ if self.data_file:
164
+ self.data_file.close()
165
+
166
+ # @lru_cache(maxsize=8)
167
+ def __getitem__(self, idx):
168
+ if not self.data_file:
169
+ self.read_data(self.path)
170
+ if isinstance(idx, int):
171
+ i = idx
172
+ self.check_index(i)
173
+ tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
174
+ a = np.empty(tensor_size, dtype=self.dtype)
175
+ self.data_file.seek(self.data_offsets[i] * self.element_size)
176
+ self.data_file.readinto(a)
177
+ return a
178
+ elif isinstance(idx, slice):
179
+ start, stop, step = idx.indices(len(self))
180
+ if step != 1:
181
+ raise ValueError("Slices into indexed_dataset must be contiguous")
182
+ sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
183
+ size = sum(sizes)
184
+ a = np.empty(size, dtype=self.dtype)
185
+ self.data_file.seek(self.data_offsets[start] * self.element_size)
186
+ self.data_file.readinto(a)
187
+ offsets = list(accumulate(sizes))
188
+ sents = np.split(a, offsets[:-1])
189
+ return sents
190
+
191
+ def __len__(self):
192
+ return self._len
193
+
194
+ def num_tokens(self, index):
195
+ return self.sizes[index]
196
+
197
+ def size(self, index):
198
+ return self.sizes[index]
199
+
200
+ @staticmethod
201
+ def exists(path):
202
+ return (
203
+ os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
204
+ )
205
+
206
+ @property
207
+ def supports_prefetch(self):
208
+ return False # avoid prefetching to save memory
209
+
210
+
211
+ class IndexedCachedDataset(IndexedDataset):
212
+
213
+ def __init__(self, path):
214
+ super().__init__(path)
215
+ self.cache = None
216
+ self.cache_index = {}
217
+
218
+ @property
219
+ def supports_prefetch(self):
220
+ return True
221
+
222
+ def prefetch(self, indices):
223
+ if all(i in self.cache_index for i in indices):
224
+ return
225
+ if not self.data_file:
226
+ self.read_data(self.path)
227
+ indices = sorted(set(indices))
228
+ total_size = 0
229
+ for i in indices:
230
+ total_size += self.data_offsets[i + 1] - self.data_offsets[i]
231
+ self.cache = np.empty(total_size, dtype=self.dtype)
232
+ ptx = 0
233
+ self.cache_index.clear()
234
+ for i in indices:
235
+ self.cache_index[i] = ptx
236
+ size = self.data_offsets[i + 1] - self.data_offsets[i]
237
+ a = self.cache[ptx: ptx + size]
238
+ self.data_file.seek(self.data_offsets[i] * self.element_size)
239
+ self.data_file.readinto(a)
240
+ ptx += size
241
+ if self.data_file:
242
+ # close and delete data file after prefetch so we can pickle
243
+ self.data_file.close()
244
+ self.data_file = None
245
+
246
+ # @lru_cache(maxsize=8)
247
+ def __getitem__(self, idx):
248
+ if isinstance(idx, int):
249
+ i = idx
250
+ self.check_index(i)
251
+ tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
252
+ a = np.empty(tensor_size, dtype=self.dtype)
253
+ ptx = self.cache_index[i]
254
+ np.copyto(a, self.cache[ptx: ptx + a.size])
255
+ return a
256
+ elif isinstance(idx, slice):
257
+ # Hack just to make this work; can optimize later if necessary
258
+ sents = []
259
+ for i in range(*idx.indices(len(self))):
260
+ sents.append(self[i])
261
+ return sents
262
+
263
+
264
+ class IndexedDatasetBuilder(object):
265
+ element_sizes = {
266
+ np.uint8: 1,
267
+ np.int8: 1,
268
+ np.int16: 2,
269
+ np.int32: 4,
270
+ np.int64: 8,
271
+ np.float32: 4,
272
+ np.float64: 8
273
+ }
274
+
275
+ def __init__(self, out_file, dtype=np.int32):
276
+ self.out_file = open(out_file, 'wb')
277
+ self.dtype = dtype
278
+ self.data_offsets = [0]
279
+ self.dim_offsets = [0]
280
+ self.sizes = []
281
+ self.element_size = self.element_sizes[self.dtype]
282
+ self.doc_idx = [0]
283
+
284
+ def add_item(self, tensor):
285
+ bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
286
+ self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
287
+ for s in tensor.size():
288
+ self.sizes.append(s)
289
+ self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
290
+
291
+ def end_document(self):
292
+ self.doc_idx.append(len(self.sizes))
293
+
294
+ def merge_file_(self, another_file):
295
+ index = IndexedDataset(another_file)
296
+ assert index.dtype == self.dtype
297
+
298
+ doc_offset = len(self.sizes)
299
+
300
+ begin = self.data_offsets[-1]
301
+ for data_offset in index.data_offsets[1:]:
302
+ self.data_offsets.append(begin + data_offset)
303
+ self.sizes.extend(index.sizes)
304
+
305
+ begin = self.dim_offsets[-1]
306
+ for dim_offset in index.dim_offsets[1:]:
307
+ self.dim_offsets.append(begin + dim_offset)
308
+
309
+ self.doc_idx.extend((doc_offset + index.doc_idx)[1:])
310
+
311
+ with open(data_file_path(another_file), 'rb') as f:
312
+ while True:
313
+ data = f.read(1024)
314
+ if data:
315
+ self.out_file.write(data)
316
+ else:
317
+ break
318
+
319
+ def finalize(self, index_file):
320
+ self.out_file.close()
321
+ index = open(index_file, 'wb')
322
+ index.write(b'TNTIDX\x00\x00')
323
+ index.write(struct.pack('<Q', 1))
324
+ index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
325
+ index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
326
+ index.write(struct.pack('<Q', len(self.doc_idx)))
327
+ write_longs(index, self.dim_offsets)
328
+ write_longs(index, self.data_offsets)
329
+ write_longs(index, self.sizes)
330
+ write_longs(index, self.doc_idx)
331
+ index.close()
332
+
333
+
334
+ def _warmup_mmap_file(path):
335
+ with open(path, 'rb') as stream:
336
+ while stream.read(100 * 1024 * 1024):
337
+ pass
338
+
339
+
340
+ class MMapIndexedDataset(torch.utils.data.Dataset):
341
+ class Index(object):
342
+ _HDR_MAGIC = b'MMIDIDX\x00\x00'
343
+
344
+ @classmethod
345
+ def writer(cls, path, dtype):
346
+ class _Writer(object):
347
+ def __enter__(self):
348
+ self._file = open(path, 'wb')
349
+
350
+ self._file.write(cls._HDR_MAGIC)
351
+ self._file.write(struct.pack('<Q', 1))
352
+ self._file.write(struct.pack('<B', code(dtype)))
353
+
354
+ return self
355
+
356
+ @staticmethod
357
+ def _get_pointers(sizes):
358
+ dtype_size = dtype().itemsize
359
+ address = 0
360
+ pointers = []
361
+
362
+ for size in sizes:
363
+ pointers.append(address)
364
+ address += size * dtype_size
365
+
366
+ return pointers
367
+
368
+ def write(self, sizes, doc_idx):
369
+ pointers = self._get_pointers(sizes)
370
+
371
+ self._file.write(struct.pack('<Q', len(sizes)))
372
+ self._file.write(struct.pack('<Q', len(doc_idx)))
373
+
374
+ sizes = np.array(sizes, dtype=np.int32)
375
+ self._file.write(sizes.tobytes(order='C'))
376
+ del sizes
377
+
378
+ pointers = np.array(pointers, dtype=np.int64)
379
+ self._file.write(pointers.tobytes(order='C'))
380
+ del pointers
381
+
382
+ doc_idx = np.array(doc_idx, dtype=np.int64)
383
+ self._file.write(doc_idx.tobytes(order='C'))
384
+
385
+ def __exit__(self, exc_type, exc_val, exc_tb):
386
+ self._file.close()
387
+
388
+ return _Writer()
389
+
390
+ def __init__(self, path, skip_warmup=False):
391
+ with open(path, 'rb') as stream:
392
+ magic_test = stream.read(9)
393
+ assert self._HDR_MAGIC == magic_test, (
394
+ 'Index file doesn\'t match expected format. '
395
+ 'Make sure that --dataset-impl is configured properly.'
396
+ )
397
+ version = struct.unpack('<Q', stream.read(8))
398
+ assert (1,) == version
399
+
400
+ dtype_code, = struct.unpack('<B', stream.read(1))
401
+ self._dtype = dtypes[dtype_code]
402
+ self._dtype_size = self._dtype().itemsize
403
+
404
+ self._len = struct.unpack('<Q', stream.read(8))[0]
405
+ self._doc_count = struct.unpack('<Q', stream.read(8))[0]
406
+ offset = stream.tell()
407
+
408
+ if not skip_warmup:
409
+ print(" warming up index mmap file...")
410
+ _warmup_mmap_file(path)
411
+
412
+ self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
413
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
414
+ print(" reading sizes...")
415
+ self._sizes = np.frombuffer(
416
+ self._bin_buffer,
417
+ dtype=np.int32,
418
+ count=self._len,
419
+ offset=offset)
420
+ print(" reading pointers...")
421
+ self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
422
+ offset=offset + self._sizes.nbytes)
423
+ print(" reading document index...")
424
+ self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
425
+ offset=offset + self._sizes.nbytes + self._pointers.nbytes)
426
+
427
+ def __del__(self):
428
+ self._bin_buffer_mmap._mmap.close()
429
+ del self._bin_buffer_mmap
430
+
431
+ @property
432
+ def dtype(self):
433
+ return self._dtype
434
+
435
+ @property
436
+ def sizes(self):
437
+ return self._sizes
438
+
439
+ @property
440
+ def doc_idx(self):
441
+ return self._doc_idx
442
+
443
+ @lru_cache(maxsize=8)
444
+ def __getitem__(self, i):
445
+ return self._pointers[i], self._sizes[i]
446
+
447
+ def __len__(self):
448
+ return self._len
449
+
450
+ def __init__(self, path, skip_warmup=False):
451
+ super().__init__()
452
+
453
+ self._path = None
454
+ self._index = None
455
+ self._bin_buffer = None
456
+
457
+ self._do_init(path, skip_warmup)
458
+
459
+ def __getstate__(self):
460
+ return self._path
461
+
462
+ def __setstate__(self, state):
463
+ self._do_init(state, skip_warmup=True)
464
+
465
+ def _do_init(self, path, skip_warmup):
466
+ self._path = path
467
+ self._index = self.Index(index_file_path(self._path), skip_warmup)
468
+
469
+ if not skip_warmup:
470
+ print(" warming up data mmap file...")
471
+ _warmup_mmap_file(data_file_path(self._path))
472
+ print(" creating numpy buffer of mmap...")
473
+ self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
474
+ print(" creating memory view of numpy buffer...")
475
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
476
+
477
+ def __del__(self):
478
+ self._bin_buffer_mmap._mmap.close()
479
+ del self._bin_buffer_mmap
480
+ del self._index
481
+
482
+ def __len__(self):
483
+ return len(self._index)
484
+
485
+ # @lru_cache(maxsize=8)
486
+ def __getitem__(self, idx):
487
+ if isinstance(idx, (int, np.integer)):
488
+ ptr, size = self._index[idx]
489
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
490
+ count=size, offset=ptr)
491
+ return np_array
492
+ elif isinstance(idx, slice):
493
+ start, stop, step = idx.indices(len(self))
494
+ if step != 1:
495
+ raise ValueError("Slices into indexed_dataset must be contiguous")
496
+ ptr = self._index._pointers[start]
497
+ sizes = self._index._sizes[idx]
498
+ offsets = list(accumulate(sizes))
499
+ total_size = sum(sizes)
500
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
501
+ count=total_size, offset=ptr)
502
+ sents = np.split(np_array, offsets[:-1])
503
+ return sents
504
+ else:
505
+ raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
506
+
507
+ def get(self, idx, offset=0, length=None):
508
+ """ Retrieves a single item from the dataset with the option to only
509
+ return a portion of the item.
510
+
511
+ get(idx) is the same as [idx] but get() does not support slicing.
512
+ """
513
+ ptr, size = self._index[idx]
514
+ if length is None:
515
+ length = size - offset
516
+ ptr += offset * np.dtype(self._index.dtype).itemsize
517
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
518
+ count=length, offset=ptr)
519
+ return np_array
520
+
521
+ @property
522
+ def sizes(self):
523
+ return self._index.sizes
524
+
525
+ @property
526
+ def doc_idx(self):
527
+ return self._index.doc_idx
528
+
529
+ def get_doc_idx(self):
530
+ return self._index._doc_idx
531
+
532
+ def set_doc_idx(self, doc_idx_):
533
+ self._index._doc_idx = doc_idx_
534
+
535
+ @property
536
+ def supports_prefetch(self):
537
+ return False
538
+
539
+ @staticmethod
540
+ def exists(path):
541
+ return (
542
+ os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
543
+ )
544
+
545
+
546
+ class MMapIndexedDatasetBuilder(object):
547
+ def __init__(self, out_file, dtype=np.int64):
548
+ self._data_file = open(out_file, 'wb')
549
+ self._dtype = dtype
550
+ self._sizes = []
551
+ self._doc_idx = [0]
552
+
553
+ @property
554
+ def dtype(self):
555
+ return self._dtype
556
+
557
+ def add_item(self, np_array):
558
+ # np_array = np.array(tensor.numpy(), dtype=self._dtype)
559
+ self._data_file.write(np_array.tobytes(order='C'))
560
+ self._sizes.append(np_array.size)
561
+
562
+ def add_doc(self, np_array, sizes):
563
+ # np_array = np.array(tensor, dtype=self._dtype)
564
+ self._data_file.write(np_array.tobytes(order='C'))
565
+ self._sizes.extend(sizes)
566
+ self._doc_idx.append(len(self._sizes))
567
+
568
+ def end_document(self):
569
+ self._doc_idx.append(len(self._sizes))
570
+
571
+ def merge_file_(self, another_file):
572
+ # Concatenate index
573
+ index = MMapIndexedDataset.Index(index_file_path(another_file))
574
+ assert index.dtype == self._dtype
575
+
576
+ offset = len(self._sizes)
577
+ self._sizes.extend(index.sizes)
578
+ self._doc_idx.extend((offset + index.doc_idx)[1:])
579
+
580
+ # Concatenate data
581
+ with open(data_file_path(another_file), 'rb') as f:
582
+ shutil.copyfileobj(f, self._data_file)
583
+
584
+ def finalize(self, index_file):
585
+ self._data_file.close()
586
+
587
+ with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
588
+ index.write(self._sizes, self._doc_idx)
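
A short round-trip sketch for the mmap format above; the /tmp prefix is a placeholder basename to which .bin and .idx are appended.

import numpy as np

from lit_llama.indexed_dataset import (
    MMapIndexedDataset, MMapIndexedDatasetBuilder, data_file_path, index_file_path,
)

prefix = "/tmp/demo_dataset"  # placeholder basename
builder = MMapIndexedDatasetBuilder(data_file_path(prefix), dtype=np.uint16)
for doc in ([1, 2, 3, 4], [5, 6]):
    builder.add_item(np.array(doc, dtype=np.uint16))
    builder.end_document()
builder.finalize(index_file_path(prefix))

dataset = MMapIndexedDataset(prefix, skip_warmup=True)
print(len(dataset), dataset[0], dataset.doc_idx)  # 2 [1 2 3 4] [0 1 2]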
lit_llama/lora.py ADDED
@@ -0,0 +1,232 @@
1
+ # Derived from https://github.com/microsoft/LoRA
2
+ # ------------------------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
5
+ # ------------------------------------------------------------------------------------------
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ import math
11
+ from typing import Dict, List
12
+
13
+ import lit_llama.model as llama
14
+
15
+ from contextlib import contextmanager
16
+ from dataclasses import dataclass
17
+
18
+ class LoRALayer():
19
+ def __init__(
20
+ self,
21
+ r: int,
22
+ lora_alpha: int,
23
+ lora_dropout: float,
24
+ merge_weights: bool,
25
+ ):
26
+ self.r = r
27
+ self.lora_alpha = lora_alpha
28
+ # Optional dropout
29
+ if lora_dropout > 0.:
30
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
31
+ else:
32
+ self.lora_dropout = lambda x: x
33
+ # Mark the weight as unmerged
34
+ self.merged = False
35
+ self.merge_weights = merge_weights
36
+
37
+
38
+ class MergedLinear(nn.Linear, LoRALayer):
39
+ # LoRA implemented in a dense layer
40
+ def __init__(
41
+ self,
42
+ in_features: int,
43
+ out_features: int,
44
+ r: int = 0,
45
+ lora_alpha: int = 1,
46
+ lora_dropout: float = 0.,
47
+ enable_lora: List[bool] = [False],
48
+ fan_in_fan_out: bool = False,
49
+ merge_weights: bool = True,
50
+ **kwargs
51
+ ):
52
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
53
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
54
+ merge_weights=merge_weights)
55
+ assert out_features % len(enable_lora) == 0, \
56
+ 'The length of enable_lora must divide out_features'
57
+ self.enable_lora = enable_lora
58
+ self.fan_in_fan_out = fan_in_fan_out
59
+ # Actual trainable parameters
60
+ if r > 0 and any(enable_lora):
61
+ self.lora_A = nn.Parameter(
62
+ self.weight.new_zeros((r * sum(enable_lora), in_features)))
63
+ self.lora_B = nn.Parameter(
64
+ self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
65
+ ) # weights for Conv1D with groups=sum(enable_lora)
66
+ self.scaling = self.lora_alpha / self.r
67
+ # Freezing the pre-trained weight matrix
68
+ self.weight.requires_grad = False
69
+ # Compute the indices
70
+ self.lora_ind = self.weight.new_zeros(
71
+ (out_features, ), dtype=torch.bool
72
+ ).view(len(enable_lora), -1)
73
+ self.lora_ind[enable_lora, :] = True
74
+ self.lora_ind = self.lora_ind.view(-1)
75
+ self.reset_parameters()
76
+ if fan_in_fan_out:
77
+ self.weight.data = self.weight.data.T
78
+
79
+ def reset_parameters(self):
80
+ nn.Linear.reset_parameters(self)
81
+ if hasattr(self, 'lora_A'):
82
+ # initialize A the same way as the default for nn.Linear and B to zero
83
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
84
+ nn.init.zeros_(self.lora_B)
85
+
86
+ def zero_pad(self, x):
87
+ result = x.new_zeros((*x.shape[:-1], self.out_features))
88
+ result = result.view(-1, self.out_features)
89
+ result[:, self.lora_ind] = x.reshape(
90
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
91
+ )
92
+ return result.view((*x.shape[:-1], self.out_features))
93
+
94
+ def train(self, mode: bool = True):
95
+ def T(w):
96
+ return w.T if self.fan_in_fan_out else w
97
+ nn.Linear.train(self, mode)
98
+ if self.merge_weights and self.merged:
99
+ # Make sure that the weights are not merged
100
+ if self.r > 0 and any(self.enable_lora):
101
+ delta_w = F.conv1d(
102
+ self.lora_A.data.unsqueeze(0),
103
+ self.lora_B.data.unsqueeze(-1),
104
+ groups=sum(self.enable_lora)
105
+ ).squeeze(0)
106
+ self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
107
+ self.merged = False
108
+
109
+ def eval(self):
110
+ def T(w):
111
+ return w.T if self.fan_in_fan_out else w
112
+ nn.Linear.eval(self)
113
+ if self.merge_weights and not self.merged:
114
+ # Merge the weights and mark it
115
+ if self.r > 0 and any(self.enable_lora):
116
+ delta_w = F.conv1d(
117
+ self.lora_A.data.unsqueeze(0),
118
+ self.lora_B.data.unsqueeze(-1),
119
+ groups=sum(self.enable_lora)
120
+ ).squeeze(0)
121
+ self.weight.data += self.zero_pad(T(delta_w * self.scaling))
122
+ self.merged = True
123
+
124
+ def forward(self, x: torch.Tensor):
125
+ def T(w):
126
+ return w.T if self.fan_in_fan_out else w
127
+ if self.merged:
128
+ return F.linear(x, T(self.weight), bias=self.bias)
129
+ else:
130
+ result = F.linear(x, T(self.weight), bias=self.bias)
131
+ if self.r > 0:
132
+ after_A = F.linear(self.lora_dropout(x), self.lora_A)
133
+ after_B = F.conv1d(
134
+ after_A.transpose(-2, -1),
135
+ self.lora_B.unsqueeze(-1),
136
+ groups=sum(self.enable_lora)
137
+ ).transpose(-2, -1)
138
+ result += self.zero_pad(after_B) * self.scaling
139
+ return result
140
+
141
+
142
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
143
+ # import pdb; pdb.set_trace()
144
+ for n, p in model.named_parameters():
145
+ if 'lora_' not in n and 'motion_proj' not in n and 'llama_proj' not in n:
146
+ p.requires_grad = False
147
+ if bias == 'none':
148
+ return
149
+ elif bias == 'all':
150
+ for n, p in model.named_parameters():
151
+ if 'bias' in n:
152
+ p.requires_grad = True
153
+ elif bias == 'lora_only':
154
+ for m in model.modules():
155
+ if isinstance(m, LoRALayer) and \
156
+ hasattr(m, 'bias') and \
157
+ m.bias is not None:
158
+ m.bias.requires_grad = True
159
+ else:
160
+ raise NotImplementedError
161
+
162
+
163
+ def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
164
+ my_state_dict = model.state_dict()
165
+ if bias == 'none':
166
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'llama_proj' in k or 'motion_proj' in k}
167
+ elif bias == 'all':
168
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k or 'llama_proj' in k or 'motion_proj' in k}
169
+ elif bias == 'lora_only':
170
+ to_return = {}
171
+ for k in my_state_dict:
172
+ if 'lora_' in k:
173
+ to_return[k] = my_state_dict[k]
174
+ bias_name = k.split('lora_')[0]+'bias'
175
+ if bias_name in my_state_dict:
176
+ to_return[bias_name] = my_state_dict[bias_name]
177
+ return to_return
178
+ else:
179
+ raise NotImplementedError
180
+
181
+
182
+ @dataclass
183
+ class LoRAConfig:
184
+ r: float = 0.0
185
+ alpha: float = 1.0
186
+ dropout: float = 0.0
187
+
188
+
189
+ class CausalSelfAttention(llama.CausalSelfAttention):
190
+ lora_config = None
191
+
192
+ def __init__(self, config: llama.LLaMAConfig) -> None:
193
+ # Skip the parent class __init__ altogether and replace it to avoid
194
+ # useless allocations
195
+ nn.Module.__init__(self)
196
+ assert config.n_embd % config.n_head == 0
197
+
198
+ # key, query, value projections for all heads, but in a batch
199
+ self.c_attn = MergedLinear(
200
+ in_features=config.n_embd,
201
+ out_features=3 * config.n_embd,
202
+ r=self.lora_config.r,
203
+ lora_alpha=self.lora_config.alpha,
204
+ lora_dropout=self.lora_config.dropout,
205
+ enable_lora=[True, False, True],
206
+ fan_in_fan_out = False,
207
+ merge_weights=True,
208
+ bias=False)
209
+ # output projection
210
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
211
+ # regularization
212
+ self.n_head = config.n_head
213
+ self.n_embd = config.n_embd
214
+ self.block_size = config.block_size
215
+ self.rope_cache = None
216
+
217
+
218
+ @contextmanager
219
+ def lora(r, alpha, dropout, enabled: bool = True):
220
+ """A context manager under which you can instantiate the model with LoRA."""
221
+ if not enabled:
222
+ yield
223
+ return
224
+
225
+ CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
226
+
227
+ causal_self_attention = llama.CausalSelfAttention
228
+ llama.CausalSelfAttention = CausalSelfAttention
229
+ yield
230
+ llama.CausalSelfAttention = causal_self_attention
231
+
232
+ CausalSelfAttention.lora_config = None
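
A hedged usage sketch for the context manager above; the rank/alpha/dropout values, "7B" config name, and output path are illustrative.

import torch

from lit_llama import LLaMA
from lit_llama.lora import lora, lora_state_dict, mark_only_lora_as_trainable

with lora(r=8, alpha=16, dropout=0.05, enabled=True):
    model = LLaMA.from_name("7B")    # attention built with MergedLinear (LoRA) projections
mark_only_lora_as_trainable(model)   # freezes everything except lora_* and the projection layers
# ... fine-tuning loop goes here ...
torch.save(lora_state_dict(model), "lora_weights.pth")  # placeholder path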
lit_llama/model.py ADDED
@@ -0,0 +1,246 @@
1
+ """Full definition of a LLaMA Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
4
+ """
5
+ # mypy: ignore-errors
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from typing_extensions import Self
13
+
14
+
15
+ @dataclass
16
+ class LLaMAConfig:
17
+ block_size: int = 4096
18
+ vocab_size: int = 32000
19
+ n_layer: int = 32
20
+ n_head: int = 32
21
+ n_embd: int = 4096
22
+
23
+ @classmethod
24
+ def from_name(cls, name: str) -> Self:
25
+ return cls(**llama_configs[name])
26
+
27
+
28
+ llama_configs = {
29
+ "7B": dict(n_layer=32, n_head=32, n_embd=4096),
30
+ "13B": dict(n_layer=40, n_head=40, n_embd=5120),
31
+ "30B": dict(n_layer=60, n_head=52, n_embd=6656),
32
+ "65B": dict(n_layer=80, n_head=64, n_embd=8192),
33
+ }
34
+
35
+
36
+ class LLaMA(nn.Module):
37
+ def __init__(self, config: LLaMAConfig) -> None:
38
+ super().__init__()
39
+ assert config.vocab_size is not None
40
+ assert config.block_size is not None
41
+ self.config = config
42
+
43
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
44
+ self.transformer = nn.ModuleDict(
45
+ dict(
46
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
47
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
48
+ ln_f=RMSNorm(config.n_embd),
49
+ )
50
+ )
51
+ # self.llama_proj = nn.Sequential(
52
+ # nn.Linear(256, 1024),
53
+ # nn.ReLU(),
54
+ # nn.Linear(1024, config.n_embd)
55
+ # )
56
+ self.llama_proj = nn.Linear(512, config.n_embd)
57
+ # self.motion_proj = nn.Sequential(
58
+ # nn.Linear(config.n_embd, 1024),
59
+ # nn.ReLU(),
60
+ # nn.Linear(1024, 256)
61
+ # )
62
+ self.motion_proj = nn.Linear(config.n_embd, 512)
63
+
64
+ def _init_weights(self, module: nn.Module) -> None:
65
+ if isinstance(module, nn.Linear):
66
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
67
+ elif isinstance(module, nn.Embedding):
68
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
69
+
70
+ def forward(self, idx: torch.Tensor) -> torch.Tensor:
71
+ # import pdb; pdb.set_trace()
72
+ _, t = idx.size()
73
+ assert (
74
+ t <= self.config.block_size
75
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
76
+
77
+ # forward the LLaMA model itself
78
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
79
+
80
+ for block in self.transformer.h:
81
+ x = block(x)
82
+ x = self.transformer.ln_f(x)
83
+
84
+ logits = self.lm_head(x) # (b, t, vocab_size)
85
+
86
+ return logits
87
+
88
+ @classmethod
89
+ def from_name(cls, name: str) -> Self:
90
+ return cls(LLaMAConfig.from_name(name))
91
+
92
+
93
+ class Block(nn.Module):
94
+ def __init__(self, config: LLaMAConfig) -> None:
95
+ super().__init__()
96
+ self.rms_1 = RMSNorm(config.n_embd)
97
+ self.attn = CausalSelfAttention(config)
98
+ self.rms_2 = RMSNorm(config.n_embd)
99
+ self.mlp = MLP(config)
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ x = x + self.attn(self.rms_1(x))
103
+ x = x + self.mlp(self.rms_2(x))
104
+ return x
105
+
106
+
107
+ class CausalSelfAttention(nn.Module):
108
+ def __init__(self, config: LLaMAConfig) -> None:
109
+ super().__init__()
110
+ assert config.n_embd % config.n_head == 0
111
+
112
+ # key, query, value projections for all heads, but in a batch
113
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
114
+ # output projection
115
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
116
+
117
+ self.n_head = config.n_head
118
+ self.n_embd = config.n_embd
119
+ self.block_size = config.block_size
120
+ self.rope_cache = None
121
+
122
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
123
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
124
+
125
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
126
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
127
+
128
+ head_size = C // self.n_head
129
+ k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
130
+ q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
131
+ v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
132
+
133
+ if self.rope_cache is None:
134
+ # cache for future forward calls
135
+ self.rope_cache = build_rope_cache(
136
+ seq_len=self.block_size,
137
+ n_elem=self.n_embd // self.n_head,
138
+ dtype=x.dtype,
139
+ device=x.device,
140
+ )
141
+
142
+ q = apply_rope(q, self.rope_cache)
143
+ k = apply_rope(k, self.rope_cache)
144
+
145
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
146
+ # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
147
+ # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
148
+ # att = F.softmax(att, dim=-1)
149
+ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
150
+
151
+ # efficient attention using Flash Attention CUDA kernels
152
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
153
+
154
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
155
+
156
+ # output projection
157
+ y = self.c_proj(y)
158
+
159
+ return y
160
+
161
+
162
+ class MLP(nn.Module):
163
+ def __init__(self, config: LLaMAConfig) -> None:
164
+ super().__init__()
165
+ hidden_dim = 4 * config.n_embd
166
+ n_hidden = int(2 * hidden_dim / 3)
167
+ N = 256
168
+ # ensure n_hidden is multiple of N
169
+ n_hidden = ((n_hidden - 1) // N) * N + N
170
+
171
+ self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
172
+ self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
173
+ self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
174
+
175
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
176
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
177
+ x = self.c_proj(x)
178
+ return x
179
+
180
+
181
+ class RMSNorm(nn.Module):
182
+ """Root Mean Square Layer Normalization.
183
+
184
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
185
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
186
+ """
187
+
188
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
189
+ super().__init__()
190
+ self.scale = nn.Parameter(torch.ones(size))
191
+ self.eps = eps
192
+ self.dim = dim
193
+
194
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
195
+ # NOTE: the original RMSNorm paper implementation is not equivalent
196
+ # norm_x = x.norm(2, dim=self.dim, keepdim=True)
197
+ # rms_x = norm_x * d_x ** (-1. / 2)
198
+ # x_normed = x / (rms_x + self.eps)
199
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
200
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
201
+ return self.scale * x_normed
202
+
203
+
204
+ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
205
+ """Enhanced Transformer with Rotary Position Embedding.
206
+
207
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
208
+ transformers/rope/__init__.py. MIT License:
209
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
210
+ """
211
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
212
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
213
+
214
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
215
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
216
+
217
+ # Calculate the product of position index and $\theta_i$
218
+ idx_theta = torch.outer(seq_idx, theta)
219
+
220
+ # Compute cache. Because polar only takes float32 or float64, we need to cast
221
+ # when working with 16 bit floats (float16 or bfloat16)
222
+ dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
223
+ working_dtype = (
224
+ torch.float32 if dtype in dtypes_requiring_casting else dtype
225
+ )
226
+ complex_dtype = (
227
+ torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
228
+ )
229
+ cache = torch.polar(
230
+ torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
231
+ ).to(complex_dtype)
232
+ return cache
233
+
234
+
235
+ def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
236
+ x = x.transpose(1, 2)
237
+
238
+ # truncate to support variable sizes
239
+ T = x.size(1)
240
+ rope_cache = rope_cache[:T]
241
+
242
+ # cast because `view_as_complex` does not support 16 bit tensors
243
+ xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
244
+ rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
245
+ x_out = torch.view_as_real(xc * rope_cache).flatten(3)
246
+ return x_out.transpose(1, 2).type_as(x)
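+
+
+ # Minimal shape-check sketch (added for illustration; not part of the original file). It builds
+ # a RoPE cache for a short dummy sequence and rotates a (B, n_head, T, head_size) query tensor,
+ # mirroring how `CausalSelfAttention.forward` uses the two helpers above.
+ if __name__ == "__main__":
+     _q = torch.randn(1, 4, 8, 16)  # (B, nh, T, hs)
+     _cache = build_rope_cache(seq_len=8, n_elem=16, dtype=_q.dtype, device=_q.device)
+     _q_rot = apply_rope(_q, _cache)
+     assert _q_rot.shape == _q.shape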
lit_llama/quantization.py ADDED
@@ -0,0 +1,281 @@
1
+ import os
2
+ from contextlib import contextmanager
3
+ import warnings
4
+ import math
5
+
6
+ import torch
7
+
8
+ # configuration for bitsandbytes before import
9
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
10
+ warnings.filterwarnings(
11
+ "ignore",
12
+ message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
13
+ )
14
+ warnings.filterwarnings(
15
+ "ignore",
16
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
17
+ )
18
+ warnings.filterwarnings(
19
+ "ignore",
20
+ message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
21
+ )
22
+
23
+ try:
24
+ import bitsandbytes as bnb # noqa: E402
25
+ except Exception:  # bitsandbytes is optional; treat any import failure as "not available"
26
+ bnb = None
27
+
28
+ if bnb is not None:
29
+ class Linear8bitLt(bnb.nn.Linear8bitLt):
30
+ """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
31
+ re-quantization when loading the state dict.
32
+
33
+
34
+ This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
35
+ """
36
+ def __init__(self, *args, **kwargs):
37
+ super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
38
+ # We quantize the initial weight here so we don't end up filling the device
39
+ # memory with float32 weights which could lead to OOM.
40
+ self._quantize_weight(self.weight.data)
41
+
42
+ def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
43
+ # There is only one key that ends with `*.weight`, the other one is the bias
44
+ weight_key = next((name for name in local_state_dict.keys() if name.endswith("weight")), None)
45
+ if weight_key is None:
46
+ return
47
+
48
+ # Load the weight from the state dict and re-quantize it
49
+ weight = local_state_dict.pop(weight_key)
50
+ self._quantize_weight(weight)
51
+
52
+ # If there is a bias, let nn.Module load it
53
+ if local_state_dict:
54
+ super()._load_from_state_dict(local_state_dict, *args, **kwargs)
55
+
56
+ def _quantize_weight(self, weight: torch.Tensor) -> None:
57
+ # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
58
+ B = weight.contiguous().half().cuda()
59
+ CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
60
+ del CBt
61
+ del SCBt
62
+ self.weight.data = CB
63
+ setattr(self.weight, "CB", CB)
64
+ setattr(self.weight, "SCB", SCB)
65
+
66
+
67
+ # reference implementation: correct but slow, since the weight is dequantized on every forward
68
+ class ColBlockQuantizedLinear(torch.nn.Module):
69
+ def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols):
70
+ super().__init__()
71
+ self.in_features = in_features
72
+ self.out_features = out_features
73
+ self.tile_cols = tile_cols if tile_cols != -1 else self.in_features
74
+ self.bits = bits
75
+ self.entries_per_byte = 8 // bits
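+ # e.g. bits=4 packs two 4-bit entries into each byte of `quant_weight`; bits=8 stores one entry per byte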
76
+ assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8
77
+ assert in_features % self.entries_per_byte == 0
78
+ self.register_buffer("quant_weight", torch.empty((self.out_features, self.in_features // self.entries_per_byte), dtype=torch.uint8))
79
+ self.register_buffer("scales", torch.empty((self.out_features, (self.in_features + self.tile_cols - 1) // self.tile_cols)))
80
+ self.register_buffer("zeros", torch.empty_like(self.scales))
81
+ assert isinstance(bias, bool)
82
+ if bias:
83
+ self.register_buffer("bias", torch.empty((self.out_features,)))
84
+ else:
85
+ self.register_buffer("bias", None)
86
+
87
+ def pack_weight(self, weight):
88
+ weight = weight.to(device=self.quant_weight.device, copy=True)
89
+ for j in range(self.scales.size(1)):
90
+ weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] /= self.scales[: , j: j+1]
91
+ weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] += self.zeros[: , j: j+1]
92
+ weight = weight.clamp_(min=0, max=2 ** self.bits - 1).to(dtype=torch.uint8)
93
+ self.quant_weight.zero_()
94
+ for nr in range(self.entries_per_byte):
95
+ self.quant_weight += weight[:, nr::self.entries_per_byte] << (nr * self.bits)
96
+
97
+ def get_weight(self, dtype=torch.float):
98
+ weight = torch.empty((self.out_features, self.in_features), device=self.quant_weight.device, dtype=dtype)
99
+ mask = (1<<self.bits) - 1
100
+ for nr in range(self.entries_per_byte):
101
+ weight[:, nr::self.entries_per_byte] = ((self.quant_weight >> (nr * self.bits)) & mask).float()
102
+ self.quant_weight.to(dtype)
103
+ for j in range(self.scales.size(1)):
104
+ weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] -= self.zeros[: , j: j+1]
105
+ weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] *= self.scales[: , j: j+1]
106
+ return weight
107
+
108
+ def forward(self, inp):
109
+ weight = self.get_weight(dtype=inp.dtype)
110
+ return torch.nn.functional.linear(inp, weight, self.bias)
111
+
112
+
113
+
114
+
115
+ class GPTQQuantizer:
116
+ # The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/
117
+ # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
118
+ # portions copyright by the authors licensed under the Apache License 2.0
119
+ # All errors are our own.
120
+
121
+ def __init__(self, linear_module, *, bits, perchannel=True, sym=False, blocksize=128, percdamp=.01, groupsize=-1, actorder=False):
122
+ assert isinstance(linear_module, torch.nn.Linear)
123
+
124
+ self.linear_module = linear_module
125
+ self.dev = self.linear_module.weight.device
126
+ self.rows = linear_module.weight.shape[0]
127
+ self.columns = linear_module.weight.shape[1]
128
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
129
+ self.nsamples = 0
130
+ self.bits = bits
131
+ self.maxq = 2 ** bits - 1
132
+ self.perchannel = perchannel
133
+ self.sym = sym
134
+ self.blocksize = blocksize
135
+ self.percdamp = percdamp
136
+ self.groupsize = groupsize
137
+ self.actorder = actorder
138
+ self.tile_cols = self.columns if groupsize == -1 else groupsize
139
+ self.scales = torch.zeros((self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols), dtype=self.linear_module.weight.dtype, device = self.dev)
140
+ self.zeros = torch.zeros_like(self.scales)
141
+ assert not (self.actorder and self.groupsize != -1), "The permutation trick does not work for grouped quantization"
142
+
143
+ @staticmethod
144
+ def quantize_weight(x, scale, zero, maxq):
145
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
146
+ x_rec = scale * (q - zero)
147
+ return x_rec
148
+
149
+ def find_params_weight(self, x):
150
+ dev = x.device
151
+
152
+ shape = x.shape
153
+ if self.perchannel:
154
+ x = x.flatten(1)
155
+ else:
156
+ x = x.flatten().unsqueeze(0)
157
+
158
+ tmp = torch.zeros(x.shape[0], device=dev)
159
+ xmin = torch.minimum(x.min(1)[0], tmp)
160
+ xmax = torch.maximum(x.max(1)[0], tmp)
161
+
162
+ if self.sym:
163
+ xmax = torch.maximum(torch.abs(xmin), xmax)
164
+ tmp = xmin < 0
165
+ if torch.any(tmp):
166
+ xmin[tmp] = -xmax[tmp]
167
+ tmp = (xmin == 0) & (xmax == 0)
168
+ xmin[tmp] = -1
169
+ xmax[tmp] = +1
170
+
171
+ scale = (xmax - xmin) / self.maxq
172
+ if self.sym:
173
+ zero = torch.full_like(scale, (self.maxq + 1) / 2)
174
+ else:
175
+ zero = torch.round(-xmin / scale)
176
+
177
+ if not self.perchannel:
178
+ tmp = shape[0]
179
+ scale = scale.repeat(tmp)
180
+ zero = zero.repeat(tmp)
181
+
182
+ shape = [-1] + [1] * (len(shape) - 1)
183
+ scale = scale.reshape(shape)
184
+ zero = zero.reshape(shape)
185
+ return scale, zero
186
+
187
+ def collect_input_stats(self, _1, inp, _2):
188
+ inp = inp[0].detach()
189
+ self.last_inp = inp
190
+ if len(inp.shape) == 2:
191
+ inp = inp.unsqueeze(0)
192
+ tmp = inp.shape[0]
193
+ if len(inp.shape) == 3:
194
+ inp = inp.reshape((-1, inp.shape[-1]))
195
+ inp = inp.t()
196
+ self.H *= self.nsamples / (self.nsamples + tmp)
197
+ self.nsamples += tmp
198
+ # inp = inp.float()
199
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
200
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
201
+ self.H += inp.matmul(inp.t())
202
+
203
+ def quantize(self):
204
+ W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
205
+
206
+ scale, zero = self.find_params_weight(W)
207
+ self.scales[:] = scale
208
+ self.zeros[:] = zero
209
+
210
+ H = self.H
211
+ del self.H
212
+ dead = torch.diag(H) == 0
213
+ H[dead, dead] = 1
214
+ W[:, dead] = 0
215
+ if self.actorder:
216
+ perm = torch.argsort(torch.diag(H), descending=True)
217
+ W = W[:, perm]
218
+ H = H[perm][:, perm]
219
+
220
+ Losses = torch.zeros_like(W)
221
+ Q = torch.zeros_like(W)
222
+
223
+ damp = self.percdamp * torch.mean(torch.diag(H))
224
+ diag = torch.arange(self.columns, device=self.dev)
225
+ H[diag, diag] += damp
226
+ H = torch.linalg.cholesky(H)
227
+ H = torch.cholesky_inverse(H)
228
+ H = torch.linalg.cholesky(H, upper=True)
229
+ Hinv = H
230
+
231
+ for i1 in range(0, self.columns, self.blocksize):
232
+ i2 = min(i1 + self.blocksize, self.columns)
233
+ count = i2 - i1
234
+
235
+ W1 = W[:, i1:i2].clone()
236
+ Q1 = torch.zeros_like(W1)
237
+ Err1 = torch.zeros_like(W1)
238
+ Losses1 = torch.zeros_like(W1)
239
+ Hinv1 = Hinv[i1:i2, i1:i2]
240
+
241
+ for i in range(count):
242
+ w = W1[:, i]
243
+ d = Hinv1[i, i]
244
+
245
+ if self.groupsize != -1:
246
+ if (i1 + i) % self.groupsize == 0:
247
+ scale, zero = self.find_params_weight(W[:, (i1 + i):(i1 + i + self.groupsize)])
248
+ self.scales[:, (i1 + i) // self.groupsize] = scale
249
+ self.zeros[:, (i1 + i) // self.groupsize] = zero
250
+
251
+ q = self.quantize_weight(
252
+ w.unsqueeze(1), scale, zero, self.maxq
253
+ )
254
+ q = q.squeeze(1)
255
+ assert q.dim() == 1
256
+ Q1[:, i] = q
257
+ Losses1[:, i] = (w - q) ** 2 / d ** 2
258
+
259
+ err1 = (w - q) / d
260
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
261
+ Err1[:, i] = err1
262
+
263
+ Q[:, i1:i2] = Q1
264
+ Losses[:, i1:i2] = Losses1 / 2
265
+
266
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
267
+
268
+ if self.actorder:
269
+ invperm = torch.argsort(perm)
270
+ Q = Q[:, invperm]
271
+
272
+ weight = Q.reshape(self.linear_module.weight.shape).to(self.linear_module.weight.data.dtype)
273
+ error = torch.sum(Losses).item()
274
+
275
+ q_module = ColBlockQuantizedLinear(self.linear_module.in_features, self.linear_module.out_features, self.linear_module.bias is not None,
276
+ bits=self.bits, tile_cols=self.groupsize).to(self.dev)
277
+ q_module.scales = self.scales
278
+ q_module.zeros = self.zeros
279
+ q_module.pack_weight(weight)
280
+ q_module.bias = self.linear_module.bias
281
+ return q_module, error
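+
+
+ # Calibration sketch (added for illustration; not part of the original file). The linear layer
+ # and data below are dummies: input statistics are gathered through a forward hook bound to
+ # `collect_input_stats`, after which `quantize()` returns a drop-in ColBlockQuantizedLinear.
+ if __name__ == "__main__":
+     layer = torch.nn.Linear(16, 8)
+     gptq = GPTQQuantizer(layer, bits=4)
+     handle = layer.register_forward_hook(gptq.collect_input_stats)
+     layer(torch.randn(32, 16))  # one or more calibration batches
+     handle.remove()
+     quantized_layer, quantization_error = gptq.quantize()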
lit_llama/tokenizer.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
7
+
8
+
9
+ class Tokenizer:
10
+ """Tokenizer for LLaMA."""
11
+
12
+ def __init__(self, model_path: Path) -> None:
13
+ self.processor = SentencePieceProcessor(model_file=str(model_path))
14
+ self.bos_id = self.processor.bos_id()
15
+ self.eos_id = self.processor.eos_id()
16
+ self.pad_id = self.processor.pad_id()
17
+
18
+ @property
19
+ def vocab_size(self) -> int:
20
+ return self.processor.vocab_size()
21
+
22
+ def encode(
23
+ self,
24
+ string: str,
25
+ bos: bool = True,
26
+ eos: bool = False,
27
+ max_length: int = -1,
28
+ pad: bool = False,
29
+ device: Optional[torch.device] = None
30
+ ) -> torch.Tensor:
31
+ tokens = self.processor.encode(string)
32
+ if bos:
33
+ tokens = [self.bos_id] + tokens
34
+ if eos:
35
+ tokens = tokens + [self.eos_id]
36
+ if max_length > 0:
37
+ tokens = tokens[:max_length]
38
+ if pad and len(tokens) < max_length:
39
+ tokens += [self.pad_id] * (max_length - len(tokens))
40
+
41
+ return torch.tensor(tokens, dtype=torch.int, device=device)
42
+
43
+ def decode(self, tokens: torch.Tensor) -> str:
44
+ return self.processor.decode(tokens.tolist())
45
+
46
+ @staticmethod
47
+ def train(input: str, destination: str, vocab_size=32000) -> None:
48
+ model_prefix = os.path.join(destination, "tokenizer")
49
+ SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
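+
+
+ # Usage sketch (added for illustration; not part of the original file). The checkpoint path is a
+ # placeholder and must point to an existing SentencePiece model, e.g. the one shipped with the
+ # LLaMA weights or one produced by `Tokenizer.train`.
+ if __name__ == "__main__":
+     tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))
+     ids = tokenizer.encode("A person walks forward.", bos=True, eos=True)
+     print(ids.shape, tokenizer.decode(ids))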
lit_llama/utils.py ADDED
@@ -0,0 +1,244 @@
1
+ """Utility functions for training and inference."""
2
+
3
+ import functools
4
+ from pathlib import Path
5
+ import pickle
6
+ import warnings
7
+ from io import BytesIO
8
+
9
+ import torch
10
+ import torch.utils._device
11
+ from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
12
+ from torch.distributed.fsdp import FullStateDictConfig
13
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
14
+ from torch.distributed.fsdp import StateDictType
15
+
16
+
17
+ def save_model_checkpoint(fabric, model, file_path):
18
+ """Handles boilerplate logic for retrieving and saving the state_dict.
19
+
20
+ This will be upstreamed to Fabric soon.
21
+ """
22
+ file_path = Path(file_path)
23
+
24
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
25
+ from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
26
+
27
+ fabric.save(file_path, {"model": model})
28
+ fabric.barrier()
29
+ if fabric.global_rank == 0:
30
+ # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
31
+ convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
32
+ return
33
+
34
+ if isinstance(fabric.strategy, FSDPStrategy):
35
+ save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
36
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
37
+ state_dict = model._forward_module.state_dict()
38
+ else:
39
+ state_dict = model.state_dict()
40
+
41
+ if fabric.global_rank == 0:
42
+ torch.save(state_dict, file_path)
43
+ fabric.barrier()
44
+
45
+
46
+ class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
47
+ def __init__(self, device=None, dtype=None, quantization_mode=None):
48
+ """
49
+ Create tensors with given device and dtype and don't run initialization
50
+ (but instead use "empty tensors", i.e. uninitialized memory).
51
+
52
+ device: `torch.device` to work with
53
+ dtype: `torch.dtype` to work with
54
+ quantization_mode: optional string, quantization mode to work with, default `None`.
55
+ Available modes: `llm.int8` bitsandbytes LLM.int8 quantization (only on GPU)
56
+ `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models
57
+
58
+ Example::
59
+ with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
60
+ model = LLaMA.from_name('7B')
61
+ model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
62
+
63
+ self.quantization_mode = quantization_mode
64
+ self.quantized_linear_cls = None
65
+ if self.quantization_mode == 'llm.int8':
66
+ if device.type != "cuda":
67
+ raise ValueError("Quantization is only supported on the GPU.")
68
+ from .quantization import Linear8bitLt
69
+ self.quantized_linear_cls = Linear8bitLt
70
+ elif self.quantization_mode == 'gptq.int4':
71
+ from .quantization import ColBlockQuantizedLinear
72
+ self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
73
+ elif self.quantization_mode == 'gptq.int8':
74
+ from .quantization import ColBlockQuantizedLinear
75
+ self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
76
+ elif self.quantization_mode is not None:
77
+ raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
78
+ self.device = device
79
+ self.dtype = dtype
80
+
81
+ def __enter__(self):
82
+ if self.quantized_linear_cls is not None:
83
+ self.torch_linear_cls = torch.nn.Linear
84
+ torch.nn.Linear = self.quantized_linear_cls
85
+ return super().__enter__()
86
+
87
+ def __exit__(self, exc_type, exc_val, exc_tb):
88
+ if self.quantized_linear_cls is not None:
89
+ torch.nn.Linear = self.torch_linear_cls
90
+ return super().__exit__(exc_type, exc_val, exc_tb)
91
+
92
+ def __torch_function__(self, func, types, args=(), kwargs=None):
93
+ kwargs = kwargs or {}
94
+ if getattr(func, "__module__", None) == "torch.nn.init":
95
+ if "tensor" in kwargs:
96
+ return kwargs["tensor"]
97
+ else:
98
+ return args[0]
99
+ if (
100
+ self.device is not None
101
+ and func in torch.utils._device._device_constructors()
102
+ and kwargs.get("device") is None
103
+ ):
104
+ kwargs["device"] = self.device
105
+ if (
106
+ self.dtype is not None
107
+ and func in torch.utils._device._device_constructors()
108
+ and kwargs.get("dtype") is None
109
+ ):
110
+ kwargs["dtype"] = self.dtype
111
+ return func(*args, **kwargs)
112
+
113
+
114
+ # this is taken from torchhacks https://github.com/lernapparat/torchhacks
115
+
116
+
117
+ class NotYetLoadedTensor:
118
+ def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
119
+ self.metatensor = metatensor
120
+ self.archiveinfo = archiveinfo
121
+ self.storageinfo = storageinfo
122
+ self.rebuild_args = rebuild_args
123
+
124
+ @classmethod
125
+ def rebuild(
126
+ cls,
127
+ storage,
128
+ storage_offset,
129
+ size,
130
+ stride,
131
+ requires_grad,
132
+ backward_hooks,
133
+ metadata=None,
134
+ archiveinfo=None,
135
+ ):
136
+ rebuild_args = (
137
+ storage_offset,
138
+ size,
139
+ stride,
140
+ requires_grad,
141
+ backward_hooks,
142
+ metadata,
143
+ )
144
+ metatensor = torch._utils._rebuild_tensor_v2(
145
+ storage,
146
+ storage_offset,
147
+ size,
148
+ stride,
149
+ requires_grad,
150
+ backward_hooks,
151
+ metadata,
152
+ )
153
+ storageinfo = storage.archiveinfo
154
+ return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
155
+
156
+ def _load_tensor(self):
157
+ name, storage_cls, fn, device, size = self.storageinfo
158
+ dtype = self.metatensor.dtype
159
+
160
+ uts = (
161
+ self.archiveinfo.zipfile.get_storage_from_record(
162
+ f"data/{fn}",
163
+ size * torch._utils._element_size(dtype),
164
+ torch.UntypedStorage,
165
+ )
166
+ ._typed_storage()
167
+ ._untyped_storage
168
+ )
169
+ with warnings.catch_warnings():
170
+ warnings.simplefilter("ignore")
171
+ storage = torch.storage.TypedStorage(
172
+ wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
173
+ )
174
+ tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
175
+ return tensor
176
+
177
+ @classmethod
178
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
179
+ if kwargs is None:
180
+ kwargs = {}
181
+ loaded_args = [
182
+ (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
183
+ ]
184
+ res = func(*loaded_args, **kwargs)
185
+ # gc.collect would be costly here, maybe do it optionally
186
+ return res
187
+
188
+ def __getattr__(self, name):
189
+ # properties
190
+ ## TODO: device, is_...??
191
+ ## TODO: mH, mT, H, T, data, imag, real
192
+ ## name ???
193
+ if name in {
194
+ "dtype",
195
+ "grad",
196
+ "grad_fn",
197
+ "layout",
198
+ "names",
199
+ "ndim",
200
+ "output_nr",
201
+ "requires_grad",
202
+ "retains_grad",
203
+ "shape",
204
+ "volatile",
205
+ }:
206
+ return getattr(self.metatensor, name)
207
+ if name in {"size"}:
208
+ return getattr(self.metatensor, name)
209
+ # materializing with contiguous is needed for quantization
210
+ if name in {"contiguous"}:
211
+ return getattr(self._load_tensor(), name)
212
+
213
+ raise AttributeError(f"{type(self)} does not have {name}")
214
+
215
+ def __repr__(self):
216
+ return f"NotYetLoadedTensor({repr(self.metatensor)})"
217
+
218
+
219
+ class LazyLoadingUnpickler(pickle.Unpickler):
220
+ def __init__(self, file, zipfile):
221
+ super().__init__(file)
222
+ self.zipfile = zipfile
223
+
224
+ def find_class(self, module, name):
225
+ if module == "torch._utils" and name == "_rebuild_tensor_v2":
226
+ res = super().find_class(module, name)
227
+ return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
228
+ return super().find_class(module, name)
229
+
230
+ def persistent_load(self, pid):
231
+ name, cls, fn, device, size = pid
232
+ with warnings.catch_warnings():
233
+ warnings.simplefilter("ignore")
234
+ s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
235
+ s.archiveinfo = pid
236
+ return s
237
+
238
+
239
+ def lazy_load(fn):
240
+ zf = torch._C.PyTorchFileReader(str(fn))
241
+ with BytesIO(zf.get_record("data.pkl")) as pkl:
242
+ mup = LazyLoadingUnpickler(pkl, zf)
243
+ sd = mup.load()
244
+ return sd
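+
+
+ # Usage sketch (added for illustration; not part of the original file; the checkpoint path is a
+ # placeholder). `lazy_load` returns a state dict of `NotYetLoadedTensor`s that are only read from
+ # disk once an operation touches them, which keeps peak host memory low while loading.
+ if __name__ == "__main__":
+     state_dict = lazy_load("checkpoints/lit-llama/7B/lit-llama.pth")
+     name, tensor = next(iter(state_dict.items()))
+     print(name, tensor.shape)  # the shape is available without materializing the tensor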
models/__init__.py ADDED
File without changes
models/constants.py ADDED
@@ -0,0 +1,18 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ X_TOKEN_INDEX = {'IMAGE': -200, 'VIDEO': -201, 'AUDIO': -202, 'THERMAL': -203, 'DEPTH': -204}
9
+ X_INDEX_TOKEN = {v: k for k, v in X_TOKEN_INDEX.items()}
10
+ # IMAGE_TOKEN_INDEX = -200
11
+ DEFAULT_X_TOKEN = {'IMAGE': "<image>", 'VIDEO': "<video>", 'AUDIO': "<audio>", 'THERMAL': "<thermal>", 'DEPTH': "<depth>"}
12
+ # DEFAULT_IMAGE_TOKEN = "<image>"
13
+ DEFAULT_X_PATCH_TOKEN = {'IMAGE': "<im_patch>", 'VIDEO': "<vi_patch>", 'AUDIO': "<au_patch>", 'THERMAL': "<th_patch>", 'DEPTH': "<de_patch>"}
14
+ # DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
15
+ DEFAULT_X_START_TOKEN = {'IMAGE': "<im_start>", 'VIDEO': "<vi_start>", 'AUDIO': "<au_start>", 'THERMAL': "<th_start>", 'DEPTH': "<de_start>"}
16
+ # DEFAULT_IM_START_TOKEN = "<im_start>"
17
+ DEFAULT_X_END_TOKEN = {'IMAGE': "<im_end>", 'VIDEO': "<vi_end>", 'AUDIO': "<au_end>", 'THERMAL': "<th_end>", 'DEPTH': "<de_end>"}
18
+ # DEFAULT_IM_END_TOKEN = "<im_end>"
models/encdec.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch.nn as nn
2
+ from models.resnet import Resnet1D
3
+
4
+ class Encoder(nn.Module):
5
+ def __init__(self,
6
+ input_emb_width = 3,
7
+ output_emb_width = 512,
8
+ down_t = 3,
9
+ stride_t = 2,
10
+ width = 512,
11
+ depth = 3,
12
+ dilation_growth_rate = 3,
13
+ activation='relu',
14
+ norm=None):
15
+ super().__init__()
16
+
17
+ blocks = []
18
+ filter_t, pad_t = stride_t * 2, stride_t // 2
19
+ blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
20
+ blocks.append(nn.ReLU())
21
+
22
+ for i in range(down_t):
23
+ input_dim = width
24
+ block = nn.Sequential(
25
+ nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
26
+ Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm),
27
+ )
28
+ blocks.append(block)
29
+ blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
30
+ self.model = nn.Sequential(*blocks)
31
+
32
+ def forward(self, x):
33
+ return self.model(x)
34
+
35
+ class Decoder(nn.Module):
36
+ def __init__(self,
37
+ input_emb_width = 3,
38
+ output_emb_width = 512,
39
+ down_t = 3,
40
+ stride_t = 2,
41
+ width = 512,
42
+ depth = 3,
43
+ dilation_growth_rate = 3,
44
+ activation='relu',
45
+ norm=None):
46
+ super().__init__()
47
+ blocks = []
48
+
49
+ filter_t, pad_t = stride_t * 2, stride_t // 2
50
+ blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
51
+ blocks.append(nn.ReLU())
52
+ for i in range(down_t):
53
+ out_dim = width
54
+ block = nn.Sequential(
55
+ Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm),
56
+ nn.Upsample(scale_factor=2, mode='nearest'),
57
+ nn.Conv1d(width, out_dim, 3, 1, 1)
58
+ )
59
+ blocks.append(block)
60
+ blocks.append(nn.Conv1d(width, width, 3, 1, 1))
61
+ blocks.append(nn.ReLU())
62
+ blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
63
+ self.model = nn.Sequential(*blocks)
64
+
65
+ def forward(self, x):
66
+ return self.model(x)
67
+
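+ # Shape sketch (added for illustration; not part of the original file): with the defaults
+ # down_t=3 and stride_t=2, the Encoder shrinks the time axis by 2**3 = 8 and the Decoder
+ # upsamples it back, so a (B, 3, 64) motion sequence maps to roughly (B, 512, 8) and
+ # round-trips to (B, 3, 64), assuming `Resnet1D` preserves the temporal length.
+ if __name__ == "__main__":
+     import torch
+     enc, dec = Encoder(), Decoder()
+     z = enc(torch.randn(2, 3, 64))
+     x_rec = dec(z)
+     print(z.shape, x_rec.shape)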
models/evaluator_wrapper.py ADDED
@@ -0,0 +1,92 @@
1
+
2
+ import torch
3
+ from os.path import join as pjoin
4
+ import numpy as np
5
+ from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
6
+ from utils.word_vectorizer import POS_enumerator
7
+
8
+ def build_models(opt):
9
+ movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
10
+ text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
11
+ pos_size=opt.dim_pos_ohot,
12
+ hidden_size=opt.dim_text_hidden,
13
+ output_size=opt.dim_coemb_hidden,
14
+ device=opt.device)
15
+
16
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
17
+ hidden_size=opt.dim_motion_hidden,
18
+ output_size=opt.dim_coemb_hidden,
19
+ device=opt.device)
20
+
21
+ checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
22
+ map_location=opt.device)
23
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
24
+ text_enc.load_state_dict(checkpoint['text_encoder'])
25
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
26
+ print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
27
+ return text_enc, motion_enc, movement_enc
28
+
29
+
30
+ class EvaluatorModelWrapper(object):
31
+
32
+ def __init__(self, opt):
33
+
34
+ if opt.dataset_name == 't2m':
35
+ opt.dim_pose = 263
36
+ elif opt.dataset_name == 'kit':
37
+ opt.dim_pose = 251
38
+ else:
39
+ raise KeyError('Dataset not Recognized!!!')
40
+
41
+ opt.dim_word = 300
42
+ opt.max_motion_length = 196
43
+ opt.dim_pos_ohot = len(POS_enumerator)
44
+ opt.dim_motion_hidden = 1024
45
+ opt.max_text_len = 20
46
+ opt.dim_text_hidden = 512
47
+ opt.dim_coemb_hidden = 512
48
+
49
+ # print(opt)
50
+
51
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
52
+ self.opt = opt
53
+ self.device = opt.device
54
+
55
+ self.text_encoder.to(opt.device)
56
+ self.motion_encoder.to(opt.device)
57
+ self.movement_encoder.to(opt.device)
58
+
59
+ self.text_encoder.eval()
60
+ self.motion_encoder.eval()
61
+ self.movement_encoder.eval()
62
+
63
+ # Please note that the results do not follow the order of the inputs
64
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
65
+ with torch.no_grad():
66
+ word_embs = word_embs.detach().to(self.device).float()
67
+ pos_ohot = pos_ohot.detach().to(self.device).float()
68
+ motions = motions.detach().to(self.device).float()
69
+
70
+ '''Movement Encoding'''
71
+ movements = self.movement_encoder(motions[..., :-4]).detach()
72
+ m_lens = m_lens // self.opt.unit_length
73
+ motion_embedding = self.motion_encoder(movements, m_lens)
74
+
75
+ '''Text Encoding'''
76
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
77
+ return text_embedding, motion_embedding
78
+
79
+ # Please note that the results do not follow the order of the inputs
80
+ def get_motion_embeddings(self, motions, m_lens):
81
+ with torch.no_grad():
82
+ motions = motions.detach().to(self.device).float()
83
+
84
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
85
+ motions = motions[align_idx]
86
+ m_lens = m_lens[align_idx]
87
+
88
+ '''Movement Encoding'''
89
+ movements = self.movement_encoder(motions[..., :-4]).detach()
90
+ m_lens = m_lens // self.opt.unit_length
91
+ motion_embedding = self.motion_encoder(movements, m_lens)
92
+ return motion_embedding
models/modules.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+
5
+ def init_weight(m):
6
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
7
+ nn.init.xavier_normal_(m.weight)
8
+ # m.bias.data.fill_(0.01)
9
+ if m.bias is not None:
10
+ nn.init.constant_(m.bias, 0)
11
+
12
+
13
+ class MovementConvEncoder(nn.Module):
14
+ def __init__(self, input_size, hidden_size, output_size):
15
+ super(MovementConvEncoder, self).__init__()
16
+ self.main = nn.Sequential(
17
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
18
+ nn.Dropout(0.2, inplace=True),
19
+ nn.LeakyReLU(0.2, inplace=True),
20
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
21
+ nn.Dropout(0.2, inplace=True),
22
+ nn.LeakyReLU(0.2, inplace=True),
23
+ )
24
+ self.out_net = nn.Linear(output_size, output_size)
25
+ self.main.apply(init_weight)
26
+ self.out_net.apply(init_weight)
27
+
28
+ def forward(self, inputs):
29
+ inputs = inputs.permute(0, 2, 1)
30
+ outputs = self.main(inputs).permute(0, 2, 1)
31
+ # print(outputs.shape)
32
+ return self.out_net(outputs)
33
+
34
+
35
+
36
+ class TextEncoderBiGRUCo(nn.Module):
37
+ def __init__(self, word_size, pos_size, hidden_size, output_size, device):
38
+ super(TextEncoderBiGRUCo, self).__init__()
39
+ self.device = device
40
+
41
+ self.pos_emb = nn.Linear(pos_size, word_size)
42
+ self.input_emb = nn.Linear(word_size, hidden_size)
43
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
44
+ self.output_net = nn.Sequential(
45
+ nn.Linear(hidden_size * 2, hidden_size),
46
+ nn.LayerNorm(hidden_size),
47
+ nn.LeakyReLU(0.2, inplace=True),
48
+ nn.Linear(hidden_size, output_size)
49
+ )
50
+
51
+ self.input_emb.apply(init_weight)
52
+ self.pos_emb.apply(init_weight)
53
+ self.output_net.apply(init_weight)
54
+ self.hidden_size = hidden_size
55
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
56
+
57
+ # input(batch_size, seq_len, dim)
58
+ def forward(self, word_embs, pos_onehot, cap_lens):
59
+ num_samples = word_embs.shape[0]
60
+
61
+ pos_embs = self.pos_emb(pos_onehot)
62
+ inputs = word_embs + pos_embs
63
+ input_embs = self.input_emb(inputs)
64
+ hidden = self.hidden.repeat(1, num_samples, 1)
65
+
66
+ cap_lens = cap_lens.data.tolist()
67
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
68
+
69
+ gru_seq, gru_last = self.gru(emb, hidden)
70
+
71
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
72
+
73
+ return self.output_net(gru_last)
74
+
75
+
76
+ class MotionEncoderBiGRUCo(nn.Module):
77
+ def __init__(self, input_size, hidden_size, output_size, device):
78
+ super(MotionEncoderBiGRUCo, self).__init__()
79
+ self.device = device
80
+
81
+ self.input_emb = nn.Linear(input_size, hidden_size)
82
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
83
+ self.output_net = nn.Sequential(
84
+ nn.Linear(hidden_size*2, hidden_size),
85
+ nn.LayerNorm(hidden_size),
86
+ nn.LeakyReLU(0.2, inplace=True),
87
+ nn.Linear(hidden_size, output_size)
88
+ )
89
+
90
+ self.input_emb.apply(init_weight)
91
+ self.output_net.apply(init_weight)
92
+ self.hidden_size = hidden_size
93
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
94
+
95
+ # input(batch_size, seq_len, dim)
96
+ def forward(self, inputs, m_lens):
97
+ num_samples = inputs.shape[0]
98
+
99
+ input_embs = self.input_emb(inputs)
100
+ hidden = self.hidden.repeat(1, num_samples, 1)
101
+
102
+ cap_lens = m_lens.data.tolist()
103
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False)
104
+
105
+ gru_seq, gru_last = self.gru(emb, hidden)
106
+
107
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
108
+
109
+ return self.output_net(gru_last)
models/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .languagebind import LanguageBindImageTower, LanguageBindVideoTower
4
+ from .mae_encoder import MAEVisionTower
5
+ from transformers import CLIPModel
6
+
7
+ def build_image_tower(image_tower_cfg, **kwargs):
8
+ image_tower = getattr(image_tower_cfg, 'mm_image_tower', getattr(image_tower_cfg, 'image_tower', None))
9
+ is_absolute_path_exists = os.path.exists(image_tower)
10
+ if is_absolute_path_exists or image_tower.startswith("openai") or image_tower.startswith("laion"):
11
+ return CLIPVisionTower(image_tower, args=image_tower_cfg, **kwargs)
12
+ if image_tower.endswith('LanguageBind_Image'):
13
+ return LanguageBindImageTower(image_tower, args=image_tower_cfg, cache_dir='./cache_dir', **kwargs)
14
+ if 'mae' in image_tower:
15
+ print(f'Building MAE vision tower from {image_tower}')
20
+ return MAEVisionTower(image_tower, args=image_tower_cfg, cache_dir='./cache_dir', **kwargs)
21
+ raise ValueError(f'Unknown image tower: {image_tower}')
22
+
23
+ def build_video_tower(video_tower_cfg, **kwargs):
24
+ video_tower = getattr(video_tower_cfg, 'mm_video_tower', getattr(video_tower_cfg, 'video_tower', None))
25
+ if video_tower.endswith('LanguageBind_Video_merge'):
26
+ return LanguageBindVideoTower(video_tower, args=video_tower_cfg, cache_dir='./cache_dir', **kwargs)
27
+ raise ValueError(f'Unknown video tower: {video_tower}')
28
+
29
+
30
+
31
+ # import os
32
+ # from .clip_encoder import CLIPVisionTower
33
+ # from .languagebind import LanguageBindImageTower, LanguageBindVideoTower
34
+ # from transformers import CLIPModel
35
+
36
+ # def build_image_tower(image_tower_cfg, **kwargs):
37
+ # image_tower = getattr(image_tower_cfg, 'mm_image_tower', getattr(image_tower_cfg, 'image_tower', None))
38
+ # is_absolute_path_exists = os.path.exists(image_tower)
39
+ # if is_absolute_path_exists or image_tower.startswith("openai") or image_tower.startswith("laion"):
40
+ # return CLIPVisionTower(image_tower, args=image_tower_cfg, **kwargs)
41
+ # if image_tower.endswith('LanguageBind_Image'):
42
+ # return LanguageBindImageTower(image_tower, args=image_tower_cfg, cache_dir='./cache_dir', **kwargs)
43
+ # raise ValueError(f'Unknown image tower: {image_tower}')
44
+
45
+ # def build_video_tower(video_tower_cfg, **kwargs):
46
+ # video_tower = getattr(video_tower_cfg, 'mm_video_tower', getattr(video_tower_cfg, 'video_tower', None))
47
+ # if video_tower.endswith('LanguageBind_Video'):
48
+ # return LanguageBindVideoTower(video_tower, args=video_tower_cfg, cache_dir='./cache_dir', **kwargs)
49
+ # raise ValueError(f'Unknown video tower: {video_tower}')
models/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ else:
20
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
21
+
22
+ def load_model(self):
23
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
24
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
25
+ self.vision_tower.requires_grad_(False)
26
+
27
+ self.is_loaded = True
28
+
29
+ def feature_select(self, image_forward_outs):
30
+ image_features = image_forward_outs.hidden_states[self.select_layer]
31
+ if self.select_feature == 'patch':
32
+ image_features = image_features[:, 1:]
33
+ elif self.select_feature == 'cls_patch':
34
+ image_features = image_features
35
+ else:
36
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
37
+ return image_features
38
+
39
+ @torch.no_grad()
40
+ def forward(self, images):
41
+ if type(images) is list:
42
+ image_features = []
43
+ for image in images:
44
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
45
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
46
+ image_features.append(image_feature)
47
+ else:
48
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
49
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
50
+
51
+ return image_features
52
+
53
+ @property
54
+ def dummy_feature(self):
55
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56
+
57
+ @property
58
+ def dtype(self):
59
+ return self.vision_tower.dtype
60
+
61
+ @property
62
+ def device(self):
63
+ return self.vision_tower.device
64
+
65
+ @property
66
+ def config(self):
67
+ if self.is_loaded:
68
+ return self.vision_tower.config
69
+ else:
70
+ return self.cfg_only
71
+
72
+ @property
73
+ def hidden_size(self):
74
+ return self.config.hidden_size
75
+
76
+ @property
77
+ def num_patches(self):
78
+ return (self.config.image_size // self.config.patch_size) ** 2
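+ # e.g. for openai/clip-vit-large-patch14-336 this is (336 // 14) ** 2 = 576 patches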
models/multimodal_encoder/languagebind/__init__.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers import AutoConfig
4
+
5
+ from .image.configuration_image import LanguageBindImageConfig
6
+ from .image.modeling_image import LanguageBindImage
7
+ from .image.tokenization_image import LanguageBindImageTokenizer
8
+ from .image.processing_image import LanguageBindImageProcessor
9
+
10
+ from .video.configuration_video import LanguageBindVideoConfig
11
+ from .video.modeling_video import LanguageBindVideo
12
+ from .video.tokenization_video import LanguageBindVideoTokenizer
13
+ from .video.processing_video import LanguageBindVideoProcessor
14
+
15
+ from .depth.configuration_depth import LanguageBindDepthConfig
16
+ from .depth.modeling_depth import LanguageBindDepth
17
+ from .depth.tokenization_depth import LanguageBindDepthTokenizer
18
+ from .depth.processing_depth import LanguageBindDepthProcessor
19
+
20
+ from .audio.configuration_audio import LanguageBindAudioConfig
21
+ from .audio.modeling_audio import LanguageBindAudio
22
+ from .audio.tokenization_audio import LanguageBindAudioTokenizer
23
+ from .audio.processing_audio import LanguageBindAudioProcessor
24
+
25
+ from .thermal.configuration_thermal import LanguageBindThermalConfig
26
+ from .thermal.modeling_thermal import LanguageBindThermal
27
+ from .thermal.tokenization_thermal import LanguageBindThermalTokenizer
28
+ from .thermal.processing_thermal import LanguageBindThermalProcessor
29
+
30
+
31
+
32
+ config_dict = {
33
+ 'thermal': LanguageBindThermalConfig,
34
+ 'image': LanguageBindImageConfig,
35
+ 'video': LanguageBindVideoConfig,
36
+ 'depth': LanguageBindDepthConfig,
37
+ 'audio': LanguageBindAudioConfig
38
+ }
39
+ model_dict = {
40
+ 'thermal': LanguageBindThermal,
41
+ 'image': LanguageBindImage,
42
+ 'video': LanguageBindVideo,
43
+ 'depth': LanguageBindDepth,
44
+ 'audio': LanguageBindAudio
45
+ }
46
+ transform_dict = {
47
+ 'video': LanguageBindVideoProcessor,
48
+ 'audio': LanguageBindAudioProcessor,
49
+ 'depth': LanguageBindDepthProcessor,
50
+ 'thermal': LanguageBindThermalProcessor,
51
+ 'image': LanguageBindImageProcessor,
52
+ }
53
+
54
+ class LanguageBind(nn.Module):
55
+ def __init__(self, clip_type=('thermal', 'image', 'video', 'depth', 'audio'), use_temp=True, cache_dir='./cache_dir'):
56
+ super(LanguageBind, self).__init__()
57
+ self.use_temp = use_temp
58
+ self.modality_encoder = {}
59
+ self.modality_proj = {}
60
+ self.modality_scale = {}
61
+ self.modality_config = {}
62
+ for c in clip_type:
63
+ pretrained_ckpt = f'LanguageBind/LanguageBind_{c.capitalize()}'
64
+ model = model_dict[c].from_pretrained(pretrained_ckpt, cache_dir=cache_dir)
65
+ self.modality_encoder[c] = model.vision_model
66
+ self.modality_proj[c] = model.visual_projection
67
+ self.modality_scale[c] = model.logit_scale
68
+ self.modality_config[c] = model.config
69
+ self.modality_encoder['language'] = model.text_model
70
+ self.modality_proj['language'] = model.text_projection
71
+
72
+ self.modality_encoder = nn.ModuleDict(self.modality_encoder)
73
+ self.modality_proj = nn.ModuleDict(self.modality_proj)
74
+
75
+ def forward(self, inputs):
76
+ outputs = {}
77
+ for key, value in inputs.items():
78
+ value = self.modality_encoder[key](**value)[1]
79
+ value = self.modality_proj[key](value)
80
+ value = value / value.norm(p=2, dim=-1, keepdim=True)
81
+ if self.use_temp:
82
+ if key != 'language':
83
+ value = value * self.modality_scale[key].exp()
84
+ outputs[key] = value
85
+ return outputs
86
+
87
+ def to_device(x, device):
88
+ out_dict = {k: v.to(device) for k, v in x.items()}
89
+ return out_dict
90
+
91
+
92
+
93
+
94
+ class LanguageBindImageTower(nn.Module):
95
+ def __init__(self, image_tower, args, delay_load=False, cache_dir='./cache_dir'):
96
+ super().__init__()
97
+ # import pdb; pdb.set_trace()
98
+ self.is_loaded = False
99
+
100
+ self.image_tower_name = image_tower
101
+ self.select_layer = args.mm_vision_select_layer
102
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
103
+
104
+ self.cache_dir = cache_dir
105
+
106
+ if not delay_load:
107
+ self.load_model()
108
+ else:
109
+ # import pdb; pdb.set_trace()
110
+ self.cfg_only = LanguageBindImageConfig.from_pretrained(self.image_tower_name, cache_dir=self.cache_dir)
111
+
112
+ ############################################################
113
+ def load_model(self):
114
+ model = LanguageBindImage.from_pretrained(self.image_tower_name, cache_dir=self.cache_dir)
115
+ self.image_tower = model.vision_model
116
+ self.image_tower.requires_grad_(False)
117
+
118
+ self.image_processor = LanguageBindImageProcessor(model.config)
119
+
120
+ self.is_loaded = True
121
+
122
+ def feature_select(self, image_forward_outs):
123
+ image_features = image_forward_outs.hidden_states[self.select_layer]
124
+ if self.select_feature == 'patch':
125
+ image_features = image_features[:, 1:]
126
+ elif self.select_feature == 'cls_patch':
127
+ image_features = image_features
128
+ else:
129
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
130
+ return image_features
131
+
132
+ @torch.no_grad()
133
+ def forward(self, images):
134
+ if type(images) is list:
135
+ image_features = []
136
+ for image in images:
137
+ image_forward_out = self.image_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
138
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
139
+ image_features.append(image_feature)
140
+ else:
141
+ # print('images', images.shape)
142
+ image_forward_outs = self.image_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
143
+ # print('image_forward_outs', len(image_forward_outs), image_forward_outs[0].shape)
144
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
145
+ # print('image_features', image_features.shape)
146
+
147
+ return image_features
148
+
149
+ @property
150
+ def dummy_feature(self):
151
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
152
+
153
+ @property
154
+ def dtype(self):
155
+ return self.image_tower.embeddings.class_embedding.dtype #############
156
+
157
+ @property
158
+ def device(self):
159
+ return self.image_tower.embeddings.class_embedding.device ##############
160
+
161
+ @property
162
+ def config(self):
163
+ if self.is_loaded:
164
+ return self.image_tower.config
165
+ else:
166
+ return self.cfg_only
167
+
168
+ @property
169
+ def hidden_size(self):
170
+ return self.config.hidden_size
171
+
172
+ @property
173
+ def num_patches(self):
174
+ return (self.config.image_size // self.config.patch_size) ** 2
175
+
176
+ class temp_model(nn.Module):
177
+ def __init__(self):
178
+ super(temp_model, self).__init__()
179
+ def forward(self, **kwargs):
180
+ return torch.randn(25, 1, 256, 1024)
181
+
182
+
183
+ class LanguageBindVideoTower(nn.Module):
184
+ def __init__(self, video_tower, args, delay_load=False, cache_dir='./cache_dir'):
185
+ super().__init__()
186
+
187
+ self.is_loaded = False
188
+
189
+ self.video_tower_name = video_tower
190
+ self.select_layer = args.mm_vision_select_layer
191
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
192
+
193
+ self.cache_dir = cache_dir
194
+
195
+ if not delay_load:
196
+ self.load_model()
197
+ else:
198
+ self.cfg_only = LanguageBindVideoConfig.from_pretrained(self.video_tower_name, cache_dir=self.cache_dir)
199
+
200
+ ## With delay load, self.is_loaded is still False after from_pretrained
201
+ # import pdb; pdb.set_trace()
202
+
203
+ ############################################################
204
+ def load_model(self):
205
+ model = LanguageBindVideo.from_pretrained(self.video_tower_name, cache_dir=self.cache_dir)
206
+ self.video_processor = LanguageBindVideoProcessor(model.config)
207
+
208
+
209
+ # model = LanguageBindImage.from_pretrained('LanguageBind/LanguageBind_Image', cache_dir=self.cache_dir)
210
+ self.video_tower = model.vision_model
211
+ self.video_tower.requires_grad_(False)
212
+
213
+
214
+ self.is_loaded = True
215
+
216
+ # def feature_select(self, image_forward_outs):
217
+ # image_features = image_forward_outs.hidden_states[self.select_layer]
218
+ # if self.select_feature == 'patch':
219
+ # image_features = image_features[:, 1:]
220
+ # elif self.select_feature == 'cls_patch':
221
+ # image_features = image_features
222
+ # else:
223
+ # raise ValueError(f'Unexpected select feature: {self.select_feature}')
224
+ # return image_features
225
+
226
+ def feature_select(self, video_forward_outs):
227
+ # print('len(video_forward_outs.hidden_states)', len(video_forward_outs.hidden_states))
228
+ video_features = video_forward_outs.hidden_states[self.select_layer] # b t n c
229
+ b, t, n, c = video_features.shape
230
+ # print('video_features', video_features.shape)
231
+ if self.select_feature == 'patch':
232
+ # video_features = video_features[:, 1:]
233
+ video_features = video_features[:, :, 1:]
234
+ video_features = video_features.reshape(b, -1, c)
235
+ elif self.select_feature == 'cls_patch':
236
+ # video_features = video_features
237
+ video_features = video_features.reshape(b, -1, c)
238
+ else:
239
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
240
+ return video_features
241
+
242
+ @torch.no_grad()
243
+ def forward(self, videos):
244
+ # import pdb; pdb.set_trace()
245
+ if type(videos) is list:
246
+ video_features = []
247
+ for video in videos:
248
+ video_forward_out = self.video_tower(video.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
249
+ video_feature = self.feature_select(video_forward_out).to(video.dtype)
250
+ video_features.append(video_feature)
251
+ else:
252
+ # print(11111111111, videos.shape)
253
+ video_forward_outs = self.video_tower(videos.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
254
+ video_features = self.feature_select(video_forward_outs).to(videos.dtype)
255
+
256
+ return video_features
257
+
258
+ @property
259
+ def dummy_feature(self):
260
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
261
+
262
+ @property
263
+ def dtype(self):
264
+ return self.video_tower.embeddings.class_embedding.dtype #############
265
+ # return torch.randn(1).cuda().dtype
266
+
267
+ @property
268
+ def device(self):
269
+ return self.video_tower.embeddings.class_embedding.device ##############
270
+ # return torch.randn(1).cuda().device
271
+
272
+ @property
273
+ def config(self):
274
+ if self.is_loaded:
275
+ return self.video_tower.config
276
+ else:
277
+ return self.cfg_only
278
+
279
+ @property
280
+ def hidden_size(self):
281
+ return self.config.hidden_size
282
+
283
+ @property
284
+ def num_patches(self):
285
+ return (self.config.image_size // self.config.patch_size) ** 2
models/multimodal_encoder/languagebind/audio/configuration_audio.py ADDED
@@ -0,0 +1,430 @@
1
+ import copy
2
+ import os
3
+ from typing import Union
4
+
5
+ from transformers import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
+
12
+
13
+
14
+
15
+
16
+ class CLIPTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
19
+ text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
20
+ with the defaults will yield a similar configuration to that of the text encoder of the CLIP
21
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 49408):
28
+ Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
29
+ the `inputs_ids` passed when calling [`CLIPModel`].
30
+ hidden_size (`int`, *optional*, defaults to 512):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 2048):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 8):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 77):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ initializer_factor (`float`, *optional*, defaults to 1):
51
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
52
+ testing).
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import CLIPTextConfig, CLIPTextModel
58
+
59
+ >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
60
+ >>> configuration = CLIPTextConfig()
61
+
62
+ >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
63
+ >>> model = CLIPTextModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+ model_type = "clip_text_model"
69
+
70
+ def __init__(
71
+ self,
72
+ vocab_size=49408,
73
+ hidden_size=512,
74
+ intermediate_size=2048,
75
+ projection_dim=512,
76
+ num_hidden_layers=12,
77
+ num_attention_heads=8,
78
+ max_position_embeddings=77,
79
+ hidden_act="quick_gelu",
80
+ layer_norm_eps=1e-5,
81
+ attention_dropout=0.0,
82
+ initializer_range=0.02,
83
+ initializer_factor=1.0,
84
+ # This differs from `CLIPTokenizer`'s default and from openai/clip
85
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
86
+ pad_token_id=1,
87
+ bos_token_id=49406,
88
+ eos_token_id=49407,
89
+ **kwargs,
90
+ ):
91
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
92
+
93
+ self.vocab_size = vocab_size
94
+ self.hidden_size = hidden_size
95
+ self.intermediate_size = intermediate_size
96
+ self.projection_dim = projection_dim
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.max_position_embeddings = max_position_embeddings
100
+ self.layer_norm_eps = layer_norm_eps
101
+ self.hidden_act = hidden_act
102
+ self.initializer_range = initializer_range
103
+ self.initializer_factor = initializer_factor
104
+ self.attention_dropout = attention_dropout
105
+ self.add_time_attn = False  # LanguageBind addition; not present in the upstream CLIPTextConfig
106
+
107
+ @classmethod
108
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
109
+ cls._set_token_in_kwargs(kwargs)
110
+
111
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
112
+
113
+ # get the text config dict if we are loading from CLIPConfig
114
+ if config_dict.get("model_type") == "clip":
115
+ config_dict = config_dict["text_config"]
116
+
117
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
118
+ logger.warning(
119
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
120
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
121
+ )
122
+
123
+ return cls.from_dict(config_dict, **kwargs)
124
+
125
+
126
+
127
+
128
+ class CLIPVisionConfig(PretrainedConfig):
129
+ r"""
130
+ This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
131
+ CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
132
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
133
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
134
+
135
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
136
+ documentation from [`PretrainedConfig`] for more information.
137
+
138
+ Args:
139
+ hidden_size (`int`, *optional*, defaults to 768):
140
+ Dimensionality of the encoder layers and the pooler layer.
141
+ intermediate_size (`int`, *optional*, defaults to 3072):
142
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
143
+ num_hidden_layers (`int`, *optional*, defaults to 12):
144
+ Number of hidden layers in the Transformer encoder.
145
+ num_attention_heads (`int`, *optional*, defaults to 12):
146
+ Number of attention heads for each attention layer in the Transformer encoder.
147
+ image_size (`int`, *optional*, defaults to 224):
148
+ The size (resolution) of each image.
149
+ patch_size (`int`, *optional*, defaults to 32):
150
+ The size (resolution) of each patch.
151
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
154
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
155
+ The epsilon used by the layer normalization layers.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ initializer_range (`float`, *optional*, defaults to 0.02):
159
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
160
+ initializer_factor (`float`, *optional*, defaults to 1):
161
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
162
+ testing).
163
+
164
+ Example:
165
+
166
+ ```python
167
+ >>> from transformers import CLIPVisionConfig, CLIPVisionModel
168
+
169
+ >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
170
+ >>> configuration = CLIPVisionConfig()
171
+
172
+ >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
173
+ >>> model = CLIPVisionModel(configuration)
174
+
175
+ >>> # Accessing the model configuration
176
+ >>> configuration = model.config
177
+ ```"""
178
+
179
+ model_type = "clip_vision_model"
180
+
181
+ def __init__(
182
+ self,
183
+ hidden_size=768,
184
+ intermediate_size=3072,
185
+ projection_dim=512,
186
+ num_hidden_layers=12,
187
+ num_attention_heads=12,
188
+ num_channels=3,
189
+ image_size=224,
190
+ patch_size=32,
191
+ hidden_act="quick_gelu",
192
+ layer_norm_eps=1e-5,
193
+ attention_dropout=0.0,
194
+ initializer_range=0.02,
195
+ initializer_factor=1.0,
196
+
197
+ add_time_attn=False,  # LanguageBind addition: enable the temporal-attention branch in the encoder layers
198
+ num_frames=1,  # LanguageBind addition: frames per clip used by the temporal position embedding
199
+ force_patch_dropout=0.0,  # LanguageBind addition: patch-dropout probability during training
200
+ lora_r=2,  # LanguageBind addition: LoRA rank (0 disables LoRA)
201
+ lora_alpha=16,  # LanguageBind addition: LoRA scaling factor
202
+ lora_dropout=0.0,  # LanguageBind addition: LoRA dropout probability
203
+ num_mel_bins=0.0,  # LanguageBind addition: spectrogram height; with target_length != 0 it overrides image_size
204
+ target_length=0.0,  # LanguageBind addition: spectrogram length; with num_mel_bins != 0 it overrides image_size
205
+ video_decode_backend='decord',  # LanguageBind addition: backend used to decode media inputs
206
+ audio_sample_rate=16000,
207
+ audio_mean=0.5,
208
+ audio_std=0.5,
209
+ **kwargs,
210
+ ):
211
+ super().__init__(**kwargs)
212
+
213
+ self.hidden_size = hidden_size
214
+ self.intermediate_size = intermediate_size
215
+ self.projection_dim = projection_dim
216
+ self.num_hidden_layers = num_hidden_layers
217
+ self.num_attention_heads = num_attention_heads
218
+ self.num_channels = num_channels
219
+ self.patch_size = patch_size
220
+ self.image_size = image_size
221
+ self.initializer_range = initializer_range
222
+ self.initializer_factor = initializer_factor
223
+ self.attention_dropout = attention_dropout
224
+ self.layer_norm_eps = layer_norm_eps
225
+ self.hidden_act = hidden_act
226
+
227
+ self.add_time_attn = add_time_attn ################
228
+ self.num_frames = num_frames ################
229
+ self.force_patch_dropout = force_patch_dropout ################
230
+ self.lora_r = lora_r ################
231
+ self.lora_alpha = lora_alpha ################
232
+ self.lora_dropout = lora_dropout ################
233
+ self.num_mel_bins = num_mel_bins ################
234
+ self.target_length = target_length ################
235
+ self.video_decode_backend = video_decode_backend ################
236
+
237
+ self.audio_sample_rate = audio_sample_rate
238
+ self.audio_mean = audio_mean
239
+ self.audio_std = audio_std
240
+
241
+ @classmethod
242
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
243
+ cls._set_token_in_kwargs(kwargs)
244
+
245
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
246
+
247
+ # get the vision config dict if we are loading from CLIPConfig
248
+ if config_dict.get("model_type") == "clip":
249
+ config_dict = config_dict["vision_config"]
250
+
251
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
252
+ logger.warning(
253
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
254
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
255
+ )
256
+
257
+ return cls.from_dict(config_dict, **kwargs)
258
+
259
+
260
+ class LanguageBindAudioConfig(PretrainedConfig):
261
+ r"""
262
+ [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
263
+ a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
264
+ a configuration with the defaults will yield a similar configuration to that of the CLIP
265
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
266
+
267
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
268
+ documentation from [`PretrainedConfig`] for more information.
269
+
270
+ Args:
271
+ text_config (`dict`, *optional*):
272
+ Dictionary of configuration options used to initialize [`CLIPTextConfig`].
273
+ vision_config (`dict`, *optional*):
274
+ Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
275
+ projection_dim (`int`, *optional*, defaults to 512):
276
+ Dimensionality of text and vision projection layers.
277
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
278
+ The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation.
279
+ kwargs (*optional*):
280
+ Dictionary of keyword arguments.
281
+
282
+ Example:
283
+
284
+ ```python
285
+ >>> from transformers import CLIPConfig, CLIPModel
286
+
287
+ >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
288
+ >>> configuration = CLIPConfig()
289
+
290
+ >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
291
+ >>> model = CLIPModel(configuration)
292
+
293
+ >>> # Accessing the model configuration
294
+ >>> configuration = model.config
295
+
296
+ >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
297
+ >>> from transformers import CLIPTextConfig, CLIPVisionConfig
298
+
299
+ >>> # Initializing a CLIPText and CLIPVision configuration
300
+ >>> config_text = CLIPTextConfig()
301
+ >>> config_vision = CLIPVisionConfig()
302
+
303
+ >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
304
+ ```"""
305
+
306
+ model_type = "LanguageBindAudio"
307
+ is_composition = True
308
+
309
+ def __init__(
310
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
311
+ ):
312
+ # If `_config_dict` exist, we use them for the backward compatibility.
313
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
314
+ # of confusion!).
315
+ text_config_dict = kwargs.pop("text_config_dict", None)
316
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
317
+
318
+ super().__init__(**kwargs)
319
+
320
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
321
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be the same in most
322
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
323
+ if text_config_dict is not None:
324
+ if text_config is None:
325
+ text_config = {}
326
+
327
+ # This is the complete result when using `text_config_dict`.
328
+ _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
329
+
330
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
331
+ for key, value in _text_config_dict.items():
332
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
333
+ # If specified in `text_config_dict`
334
+ if key in text_config_dict:
335
+ message = (
336
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
337
+ f'The value `text_config_dict["{key}"]` will be used instead.'
338
+ )
339
+ # If inferred from default argument values (just to be super careful)
340
+ else:
341
+ message = (
342
+ f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
343
+ f'value `text_config["{key}"]` will be overridden.'
344
+ )
345
+ logger.warning(message)
346
+
347
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
348
+ text_config.update(_text_config_dict)
349
+
350
+ if vision_config_dict is not None:
351
+ if vision_config is None:
352
+ vision_config = {}
353
+
354
+ # This is the complete result when using `vision_config_dict`.
355
+ _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
356
+ # convert keys to string instead of integer
357
+ if "id2label" in _vision_config_dict:
358
+ _vision_config_dict["id2label"] = {
359
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
360
+ }
361
+
362
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
363
+ for key, value in _vision_config_dict.items():
364
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
365
+ # If specified in `vision_config_dict`
366
+ if key in vision_config_dict:
367
+ message = (
368
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
369
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
370
+ )
371
+ # If inferred from default argument values (just to be super careful)
372
+ else:
373
+ message = (
374
+ f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
375
+ f'The value `vision_config["{key}"]` will be overridden.'
376
+ )
377
+ logger.warning(message)
378
+
379
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
380
+ vision_config.update(_vision_config_dict)
381
+
382
+ if text_config is None:
383
+ text_config = {}
384
+ logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
385
+
386
+ if vision_config is None:
387
+ vision_config = {}
388
+ logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
389
+
390
+ self.text_config = CLIPTextConfig(**text_config)
391
+ self.vision_config = CLIPVisionConfig(**vision_config)
392
+
393
+ self.projection_dim = projection_dim
394
+ self.logit_scale_init_value = logit_scale_init_value
395
+ self.initializer_factor = 1.0
396
+
397
+ @classmethod
398
+ def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
399
+ r"""
400
+ Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
401
+ configuration.
402
+
403
+ Returns:
404
+ [`CLIPConfig`]: An instance of a configuration object
405
+ """
406
+
407
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
408
+
409
+ def to_dict(self):
410
+ """
411
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
412
+
413
+ Returns:
414
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
415
+ """
416
+ output = copy.deepcopy(self.__dict__)
417
+ output["text_config"] = self.text_config.to_dict()
418
+ output["vision_config"] = self.vision_config.to_dict()
419
+ output["model_type"] = self.__class__.model_type
420
+ return output
421
+
422
+
423
+
424
+
425
+
426
+
427
+
428
+
429
+
430
+
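A minimal usage sketch for the configuration classes above, assuming the repository layout from this commit; the spectrogram sizes are illustrative values, not project defaults:

```python
# Sketch: composing a LanguageBindAudioConfig from its sub-configs via
# `from_text_vision_configs` (defined above). Import path and values are assumptions.
from models.multimodal_encoder.languagebind.audio.configuration_audio import (
    CLIPTextConfig,
    CLIPVisionConfig,
    LanguageBindAudioConfig,
)

text_config = CLIPTextConfig()          # defaults: 512-dim hidden size, 12 layers, 8 heads
vision_config = CLIPVisionConfig(
    num_mel_bins=112,                   # illustrative spectrogram height
    target_length=1036,                 # illustrative spectrogram length
    add_time_attn=False,
    lora_r=2,
)

config = LanguageBindAudioConfig.from_text_vision_configs(text_config, vision_config)
assert config.model_type == "LanguageBindAudio"
assert config.vision_config.num_mel_bins == 112
```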
models/multimodal_encoder/languagebind/audio/modeling_audio.py ADDED
@@ -0,0 +1,1030 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from peft import LoraConfig, get_peft_model
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from transformers import PreTrainedModel, add_start_docstrings
10
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
11
+ from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
12
+ CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss
13
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
14
+
15
+ from .configuration_audio import LanguageBindAudioConfig, CLIPVisionConfig, CLIPTextConfig
16
+
17
+
18
+
19
+ class PatchDropout(nn.Module):
20
+ """
21
+ https://arxiv.org/abs/2212.00794
22
+ """
23
+
24
+ def __init__(self, prob, exclude_first_token=True):
25
+ super().__init__()
26
+ assert 0 <= prob < 1.
27
+ self.prob = prob
28
+ self.exclude_first_token = exclude_first_token # exclude CLS token
29
+
30
+ def forward(self, x, B, T):
31
+ if not self.training or self.prob == 0.:
32
+ return x
33
+
34
+ if self.exclude_first_token:
35
+ cls_tokens, x = x[:, :1], x[:, 1:]
36
+ else:
37
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
38
+
39
+ batch = x.size()[0]
40
+ num_tokens = x.size()[1]
41
+
42
+ batch_indices = torch.arange(batch)
43
+ batch_indices = batch_indices[..., None]
44
+
45
+ keep_prob = 1 - self.prob
46
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
47
+
48
+ if T == 1:
49
+ rand = torch.randn(batch, num_tokens)
50
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
51
+ else:
52
+ rand = torch.randn(B, num_tokens)
53
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
54
+ patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1)
55
+ patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n')
56
+
57
+
58
+ x = x[batch_indices, patch_indices_keep]
59
+
60
+ if self.exclude_first_token:
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ return x
64
+
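A shape-only sketch of how the `PatchDropout` module defined above behaves; the token counts are illustrative:

```python
# Sketch: PatchDropout keeps the CLS token and a random subset of patch tokens
# while training; at eval time (or with prob == 0) the input is returned unchanged.
import torch

drop = PatchDropout(prob=0.5, exclude_first_token=True)
x = torch.randn(4, 1 + 196, 768)        # (batch * T, 1 CLS + 196 patch tokens, dim)

drop.train()
kept = drop(x, B=4, T=1)
assert kept.shape == (4, 1 + 98, 768)   # 1 CLS + max(1, int(196 * 0.5)) patches kept

drop.eval()
assert drop(x, B=4, T=1).shape == x.shape
```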
65
+ class CLIPEncoderLayer(nn.Module):
66
+ def __init__(self, config: LanguageBindAudioConfig):
67
+ super().__init__()
68
+ self.embed_dim = config.hidden_size
69
+ self.self_attn = CLIPAttention(config)
70
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
71
+ self.mlp = CLIPMLP(config)
72
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
73
+
74
+ self.add_time_attn = config.add_time_attn
75
+ if self.add_time_attn:
76
+ self.t = config.num_frames
77
+ self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
78
+ nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
79
+
80
+ self.embed_dim = config.hidden_size
81
+ self.temporal_attn = CLIPAttention(config)
82
+ self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
83
+ self.temporal_mlp = CLIPMLP(config)
84
+ self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
85
+
86
+ def forward(
87
+ self,
88
+ hidden_states: torch.Tensor,
89
+ attention_mask: torch.Tensor,
90
+ causal_attention_mask: torch.Tensor,
91
+ output_attentions: Optional[bool] = False,
92
+ ) -> Tuple[torch.FloatTensor]:
93
+ """
94
+ Args:
95
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
96
+ attention_mask (`torch.FloatTensor`): attention mask of size
97
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
98
+ causal_attention_mask (`torch.FloatTensor`): causal attention mask of size `(batch, 1, tgt_len, src_len)`.
99
+ output_attentions (`bool`, *optional*):
100
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
101
+ returned tensors for more detail.
102
+ """
103
+
104
+
105
+ if self.add_time_attn:
106
+ bt, n, d = hidden_states.shape
107
+ t = self.t
108
+
109
+ # time embed
110
+ if t != 1:
111
+ n = hidden_states.shape[1]
112
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
113
+ hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
114
+ hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
115
+
116
+ # time attn
117
+ residual = hidden_states
118
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
119
+ # hidden_states = self.layer_norm1(hidden_states) # share layernorm
120
+ hidden_states = self.temporal_layer_norm1(hidden_states)
121
+ hidden_states, attn_weights = self.temporal_attn(
122
+ hidden_states=hidden_states,
123
+ attention_mask=attention_mask,
124
+ causal_attention_mask=causal_attention_mask,
125
+ output_attentions=output_attentions,
126
+ )
127
+ hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
128
+
129
+ residual = hidden_states
130
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
131
+ # hidden_states = self.layer_norm2(hidden_states) # share layernorm
132
+ hidden_states = self.temporal_layer_norm2(hidden_states)
133
+ hidden_states = self.temporal_mlp(hidden_states)
134
+ hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
135
+
136
+ # spatial attn
137
+ residual = hidden_states
138
+
139
+ hidden_states = self.layer_norm1(hidden_states)
140
+ hidden_states, attn_weights = self.self_attn(
141
+ hidden_states=hidden_states,
142
+ attention_mask=attention_mask,
143
+ causal_attention_mask=causal_attention_mask,
144
+ output_attentions=output_attentions,
145
+ )
146
+ hidden_states = residual + hidden_states
147
+
148
+ residual = hidden_states
149
+ hidden_states = self.layer_norm2(hidden_states)
150
+ hidden_states = self.mlp(hidden_states)
151
+ hidden_states = residual + hidden_states
152
+
153
+ outputs = (hidden_states,)
154
+
155
+ if output_attentions:
156
+ outputs += (attn_weights,)
157
+
158
+ return outputs
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
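A shape-only sketch of the `(b t) n d` / `(b n) t d` reshuffling that the temporal-attention branch in `CLIPEncoderLayer.forward` above relies on; the dimensions are illustrative:

```python
# Sketch: frames are folded into the batch for spatial attention, and patch
# positions are folded into the batch for temporal attention; the round trip
# used in CLIPEncoderLayer.forward is lossless.
import torch
from einops import rearrange

b, t, n, d = 2, 8, 50, 768
hidden_states = torch.randn(b * t, n, d)             # layout entering the layer

temporal = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
assert temporal.shape == (b * n, t, d)                # attention now runs across frames

restored = rearrange(temporal, '(b n) t d -> (b t) n d', n=n)
assert torch.equal(restored, hidden_states)
```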
168
+ class CLIPPreTrainedModel(PreTrainedModel):
169
+ """
170
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
171
+ models.
172
+ """
173
+
174
+ config_class = LanguageBindAudioConfig
175
+ base_model_prefix = "clip"
176
+ supports_gradient_checkpointing = True
177
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
178
+
179
+ def _init_weights(self, module):
180
+ """Initialize the weights"""
181
+ factor = self.config.initializer_factor
182
+ if isinstance(module, CLIPTextEmbeddings):
183
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
184
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
185
+ elif isinstance(module, CLIPVisionEmbeddings):
186
+ factor = self.config.initializer_factor
187
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
188
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
189
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
190
+ elif isinstance(module, CLIPAttention):
191
+ factor = self.config.initializer_factor
192
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
193
+ out_proj_std = (module.embed_dim**-0.5) * factor
194
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
195
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
196
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
197
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
198
+ elif isinstance(module, CLIPMLP):
199
+ factor = self.config.initializer_factor
200
+ in_proj_std = (
201
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
202
+ )
203
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
204
+ nn.init.normal_(module.fc1.weight, std=fc_std)
205
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
206
+ elif isinstance(module, LanguageBindAudio):
207
+ nn.init.normal_(
208
+ module.text_projection.weight,
209
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
210
+ )
211
+ nn.init.normal_(
212
+ module.visual_projection.weight,
213
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
214
+ )
215
+ elif isinstance(module, CLIPVisionModelWithProjection):
216
+ nn.init.normal_(
217
+ module.visual_projection.weight,
218
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
219
+ )
220
+ elif isinstance(module, CLIPTextModelWithProjection):
221
+ nn.init.normal_(
222
+ module.text_projection.weight,
223
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
224
+ )
225
+
226
+ if isinstance(module, nn.LayerNorm):
227
+ module.bias.data.zero_()
228
+ module.weight.data.fill_(1.0)
229
+ if isinstance(module, nn.Linear) and module.bias is not None:
230
+ module.bias.data.zero_()
231
+
232
+ def _set_gradient_checkpointing(self, module, value=False):
233
+ if isinstance(module, CLIPEncoder):
234
+ module.gradient_checkpointing = value
235
+
236
+
237
+ CLIP_START_DOCSTRING = r"""
238
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
239
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
240
+ etc.)
241
+
242
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
243
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
244
+ and behavior.
245
+
246
+ Parameters:
247
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
248
+ Initializing with a config file does not load the weights associated with the model, only the
249
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
250
+ """
251
+
252
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
253
+ Args:
254
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
255
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
256
+ it.
257
+
258
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
259
+ [`PreTrainedTokenizer.__call__`] for details.
260
+
261
+ [What are input IDs?](../glossary#input-ids)
262
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
263
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
264
+
265
+ - 1 for tokens that are **not masked**,
266
+ - 0 for tokens that are **masked**.
267
+
268
+ [What are attention masks?](../glossary#attention-mask)
269
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
270
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
271
+ config.max_position_embeddings - 1]`.
272
+
273
+ [What are position IDs?](../glossary#position-ids)
274
+ output_attentions (`bool`, *optional*):
275
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
276
+ tensors for more detail.
277
+ output_hidden_states (`bool`, *optional*):
278
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
279
+ more detail.
280
+ return_dict (`bool`, *optional*):
281
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
282
+ """
283
+
284
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
285
+ Args:
286
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
287
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
288
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
289
+ output_attentions (`bool`, *optional*):
290
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
291
+ tensors for more detail.
292
+ output_hidden_states (`bool`, *optional*):
293
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
294
+ more detail.
295
+ return_dict (`bool`, *optional*):
296
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
297
+ """
298
+
299
+ CLIP_INPUTS_DOCSTRING = r"""
300
+ Args:
301
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
302
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
303
+ it.
304
+
305
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
306
+ [`PreTrainedTokenizer.__call__`] for details.
307
+
308
+ [What are input IDs?](../glossary#input-ids)
309
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
310
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
311
+
312
+ - 1 for tokens that are **not masked**,
313
+ - 0 for tokens that are **masked**.
314
+
315
+ [What are attention masks?](../glossary#attention-mask)
316
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
317
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
318
+ config.max_position_embeddings - 1]`.
319
+
320
+ [What are position IDs?](../glossary#position-ids)
321
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
322
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
323
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
324
+ return_loss (`bool`, *optional*):
325
+ Whether or not to return the contrastive loss.
326
+ output_attentions (`bool`, *optional*):
327
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
328
+ tensors for more detail.
329
+ output_hidden_states (`bool`, *optional*):
330
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
331
+ more detail.
332
+ return_dict (`bool`, *optional*):
333
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
334
+ """
335
+
336
+
337
+ class CLIPEncoder(nn.Module):
338
+ """
339
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
340
+ [`CLIPEncoderLayer`].
341
+
342
+ Args:
343
+ config: CLIPConfig
344
+ """
345
+
346
+ def __init__(self, config: LanguageBindAudioConfig):
347
+ super().__init__()
348
+ self.config = config
349
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
350
+ self.gradient_checkpointing = False
351
+
352
+ def forward(
353
+ self,
354
+ inputs_embeds,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ causal_attention_mask: Optional[torch.Tensor] = None,
357
+ output_attentions: Optional[bool] = None,
358
+ output_hidden_states: Optional[bool] = None,
359
+ return_dict: Optional[bool] = None,
360
+ ) -> Union[Tuple, BaseModelOutput]:
361
+ r"""
362
+ Args:
363
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
364
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
365
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
366
+ than the model's internal embedding lookup matrix.
367
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
368
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
369
+
370
+ - 1 for tokens that are **not masked**,
371
+ - 0 for tokens that are **masked**.
372
+
373
+ [What are attention masks?](../glossary#attention-mask)
374
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
375
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
376
+
377
+ - 1 for tokens that are **not masked**,
378
+ - 0 for tokens that are **masked**.
379
+
380
+ [What are attention masks?](../glossary#attention-mask)
381
+ output_attentions (`bool`, *optional*):
382
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
383
+ returned tensors for more detail.
384
+ output_hidden_states (`bool`, *optional*):
385
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
386
+ for more detail.
387
+ return_dict (`bool`, *optional*):
388
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
389
+ """
390
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
391
+ output_hidden_states = (
392
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
393
+ )
394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
395
+
396
+ encoder_states = () if output_hidden_states else None
397
+ all_attentions = () if output_attentions else None
398
+
399
+ hidden_states = inputs_embeds
400
+ for idx, encoder_layer in enumerate(self.layers):
401
+ if output_hidden_states:
402
+ encoder_states = encoder_states + (hidden_states,)
403
+ if self.gradient_checkpointing and self.training:
404
+
405
+ def create_custom_forward(module):
406
+ def custom_forward(*inputs):
407
+ return module(*inputs, output_attentions)
408
+
409
+ return custom_forward
410
+
411
+ layer_outputs = torch.utils.checkpoint.checkpoint(
412
+ create_custom_forward(encoder_layer),
413
+ hidden_states,
414
+ attention_mask,
415
+ causal_attention_mask,
416
+ )
417
+ else:
418
+ layer_outputs = encoder_layer(
419
+ hidden_states,
420
+ attention_mask,
421
+ causal_attention_mask,
422
+ output_attentions=output_attentions,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+
427
+ if output_attentions:
428
+ all_attentions = all_attentions + (layer_outputs[1],)
429
+
430
+ if output_hidden_states:
431
+ encoder_states = encoder_states + (hidden_states,)
432
+
433
+ if not return_dict:
434
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
435
+ return BaseModelOutput(
436
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
437
+ )
438
+
439
+
440
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
441
+ def _make_causal_mask(
442
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
443
+ ):
444
+ """
445
+ Make causal mask used for bi-directional self-attention.
446
+ """
447
+ bsz, tgt_len = input_ids_shape
448
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
449
+ mask_cond = torch.arange(mask.size(-1), device=device)
450
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
451
+ mask = mask.to(dtype)
452
+
453
+ if past_key_values_length > 0:
454
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
455
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
456
+
457
+
458
+ class CLIPTextTransformer(nn.Module):
459
+ def __init__(self, config: CLIPTextConfig):
460
+ super().__init__()
461
+ self.config = config
462
+ embed_dim = config.hidden_size
463
+ self.embeddings = CLIPTextEmbeddings(config)
464
+ self.encoder = CLIPEncoder(config)
465
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
466
+
467
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
468
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
469
+ def forward(
470
+ self,
471
+ input_ids: Optional[torch.Tensor] = None,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ position_ids: Optional[torch.Tensor] = None,
474
+ output_attentions: Optional[bool] = None,
475
+ output_hidden_states: Optional[bool] = None,
476
+ return_dict: Optional[bool] = None,
477
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
478
+ r"""
479
+ Returns:
480
+
481
+ """
482
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
483
+ output_hidden_states = (
484
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
485
+ )
486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
487
+
488
+ if input_ids is None:
489
+ raise ValueError("You have to specify input_ids")
490
+
491
+ input_shape = input_ids.size()
492
+ input_ids = input_ids.view(-1, input_shape[-1])
493
+
494
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
495
+
496
+ # CLIP's text model uses causal mask, prepare it here.
497
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
498
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
499
+ # expand attention_mask
500
+ if attention_mask is not None:
501
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
502
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
503
+
504
+ encoder_outputs = self.encoder(
505
+ inputs_embeds=hidden_states,
506
+ attention_mask=attention_mask,
507
+ causal_attention_mask=causal_attention_mask,
508
+ output_attentions=output_attentions,
509
+ output_hidden_states=output_hidden_states,
510
+ return_dict=return_dict,
511
+ )
512
+
513
+ last_hidden_state = encoder_outputs[0]
514
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
515
+
516
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
517
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
518
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
519
+ pooled_output = last_hidden_state[
520
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
521
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
522
+ ]
523
+
524
+ if not return_dict:
525
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
526
+
527
+ return BaseModelOutputWithPooling(
528
+ last_hidden_state=last_hidden_state,
529
+ pooler_output=pooled_output,
530
+ hidden_states=encoder_outputs.hidden_states,
531
+ attentions=encoder_outputs.attentions,
532
+ )
533
+
534
+
535
+ @add_start_docstrings(
536
+ """The text model from CLIP without any head or projection on top.""",
537
+ CLIP_START_DOCSTRING,
538
+ )
539
+ class CLIPTextModel(CLIPPreTrainedModel):
540
+ config_class = CLIPTextConfig
541
+
542
+ _no_split_modules = ["CLIPEncoderLayer"]
543
+
544
+ def __init__(self, config: CLIPTextConfig):
545
+ super().__init__(config)
546
+ self.text_model = CLIPTextTransformer(config)
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_input_embeddings(self) -> nn.Module:
551
+ return self.text_model.embeddings.token_embedding
552
+
553
+ def set_input_embeddings(self, value):
554
+ self.text_model.embeddings.token_embedding = value
555
+
556
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
557
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
558
+ def forward(
559
+ self,
560
+ input_ids: Optional[torch.Tensor] = None,
561
+ attention_mask: Optional[torch.Tensor] = None,
562
+ position_ids: Optional[torch.Tensor] = None,
563
+ output_attentions: Optional[bool] = None,
564
+ output_hidden_states: Optional[bool] = None,
565
+ return_dict: Optional[bool] = None,
566
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
567
+ r"""
568
+ Returns:
569
+
570
+ Examples:
571
+
572
+ ```python
573
+ >>> from transformers import AutoTokenizer, CLIPTextModel
574
+
575
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
576
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
577
+
578
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
579
+
580
+ >>> outputs = model(**inputs)
581
+ >>> last_hidden_state = outputs.last_hidden_state
582
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
583
+ ```"""
584
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
585
+
586
+ return self.text_model(
587
+ input_ids=input_ids,
588
+ attention_mask=attention_mask,
589
+ position_ids=position_ids,
590
+ output_attentions=output_attentions,
591
+ output_hidden_states=output_hidden_states,
592
+ return_dict=return_dict,
593
+ )
594
+
595
+
596
+ class CLIPVisionTransformer(nn.Module):
597
+ def __init__(self, config: CLIPVisionConfig):
598
+ super().__init__()
599
+ self.config = config
600
+ embed_dim = config.hidden_size
601
+
602
+ self.embeddings = CLIPVisionEmbeddings(config)
603
+ self.patch_dropout = PatchDropout(config.force_patch_dropout)
604
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
605
+ self.encoder = CLIPEncoder(config)
606
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
607
+
608
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
609
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
610
+ def forward(
611
+ self,
612
+ pixel_values: Optional[torch.FloatTensor] = None,
613
+ output_attentions: Optional[bool] = None,
614
+ output_hidden_states: Optional[bool] = None,
615
+ return_dict: Optional[bool] = None,
616
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
617
+ r"""
618
+ Returns:
619
+
620
+ """
621
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
622
+ output_hidden_states = (
623
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
624
+ )
625
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
626
+
627
+ if pixel_values is None:
628
+ raise ValueError("You have to specify pixel_values")
629
+ # LanguageBind addition: fold extra clip/frame dimensions into the batch dimension
630
+ if len(pixel_values.shape) == 7:
631
+ b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
632
+ # print(pixel_values.shape)
633
+ B = b_new * pair_new * bs_new
634
+ pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new)
635
+
636
+ elif len(pixel_values.shape) == 5:
637
+ B, _, T, _, _ = pixel_values.shape
638
+ # print(pixel_values.shape)
639
+ pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
640
+ else:
641
+ # print(pixel_values.shape)
642
+ B, _, _, _ = pixel_values.shape
643
+ T = 1
644
+ # end of LanguageBind shape handling
645
+ hidden_states = self.embeddings(pixel_values)
646
+
647
+ hidden_states = self.patch_dropout(hidden_states, B, T) ##############################################
648
+
649
+ hidden_states = self.pre_layrnorm(hidden_states)
650
+
651
+ encoder_outputs = self.encoder(
652
+ inputs_embeds=hidden_states,
653
+ output_attentions=output_attentions,
654
+ output_hidden_states=output_hidden_states,
655
+ return_dict=return_dict,
656
+ )
657
+
658
+ last_hidden_state = encoder_outputs[0]
659
+ pooled_output = last_hidden_state[:, 0, :]
660
+ pooled_output = self.post_layernorm(pooled_output)
661
+
662
+ pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################
663
+
664
+ if not return_dict:
665
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
666
+
667
+ return BaseModelOutputWithPooling(
668
+ last_hidden_state=last_hidden_state,
669
+ pooler_output=pooled_output,
670
+ hidden_states=encoder_outputs.hidden_states,
671
+ attentions=encoder_outputs.attentions,
672
+ )
673
+
674
+
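A shape-only sketch of the input flattening and frame pooling performed in `CLIPVisionTransformer.forward` above; the sizes are illustrative:

```python
# Sketch: a 5-D clip tensor (b, c, t, h, w) is folded to (b*t, c, h, w) before the
# patch embedding; after the encoder, the per-frame CLS features are averaged over t.
import torch
from einops import rearrange

b, c, t, h, w, d = 2, 3, 8, 224, 224, 768
pixel_values = torch.randn(b, c, t, h, w)

flat = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
assert flat.shape == (b * t, c, h, w)

per_frame_cls = torch.randn(b * t, d)                 # stand-in for pooled CLS features
clip_features = per_frame_cls.reshape(b, t, -1).mean(1)
assert clip_features.shape == (b, d)
```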
675
+ @add_start_docstrings(
676
+ """The vision model from CLIP without any head or projection on top.""",
677
+ CLIP_START_DOCSTRING,
678
+ )
679
+ class CLIPVisionModel(CLIPPreTrainedModel):
680
+ config_class = CLIPVisionConfig
681
+ main_input_name = "pixel_values"
682
+
683
+ def __init__(self, config: CLIPVisionConfig):
684
+ super().__init__(config)
685
+ self.vision_model = CLIPVisionTransformer(config)
686
+ # Initialize weights and apply final processing
687
+ self.post_init()
688
+
689
+ def get_input_embeddings(self) -> nn.Module:
690
+ return self.vision_model.embeddings.patch_embedding
691
+
692
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
693
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
694
+ def forward(
695
+ self,
696
+ pixel_values: Optional[torch.FloatTensor] = None,
697
+ output_attentions: Optional[bool] = None,
698
+ output_hidden_states: Optional[bool] = None,
699
+ return_dict: Optional[bool] = None,
700
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
701
+ r"""
702
+ Returns:
703
+
704
+ Examples:
705
+
706
+ ```python
707
+ >>> from PIL import Image
708
+ >>> import requests
709
+ >>> from transformers import AutoProcessor, CLIPVisionModel
710
+
711
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
712
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
713
+
714
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
715
+ >>> image = Image.open(requests.get(url, stream=True).raw)
716
+
717
+ >>> inputs = processor(images=image, return_tensors="pt")
718
+
719
+ >>> outputs = model(**inputs)
720
+ >>> last_hidden_state = outputs.last_hidden_state
721
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
722
+ ```"""
723
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
724
+
725
+ return self.vision_model(
726
+ pixel_values=pixel_values,
727
+ output_attentions=output_attentions,
728
+ output_hidden_states=output_hidden_states,
729
+ return_dict=return_dict,
730
+ )
731
+
732
+
733
+ @add_start_docstrings(CLIP_START_DOCSTRING)
734
+ class LanguageBindAudio(CLIPPreTrainedModel):
735
+ config_class = LanguageBindAudioConfig
736
+
737
+ def __init__(self, config: LanguageBindAudioConfig):
738
+ super().__init__(config)
739
+
740
+ if not isinstance(config.text_config, CLIPTextConfig):
741
+ raise ValueError(
742
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
743
+ f" {type(config.text_config)}."
744
+ )
745
+
746
+ if not isinstance(config.vision_config, CLIPVisionConfig):
747
+ raise ValueError(
748
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
749
+ f" {type(config.vision_config)}."
750
+ )
751
+
752
+ text_config = config.text_config
753
+ vision_config = config.vision_config
754
+ self.add_time_attn = vision_config.add_time_attn
755
+ self.lora_r = vision_config.lora_r
756
+ self.lora_alpha = vision_config.lora_alpha
757
+ self.lora_dropout = vision_config.lora_dropout
758
+
759
+ self.projection_dim = config.projection_dim
760
+ self.text_embed_dim = text_config.hidden_size
761
+ self.vision_embed_dim = vision_config.hidden_size
762
+
763
+ self.text_model = CLIPTextTransformer(text_config)
764
+ self.vision_model = CLIPVisionTransformer(vision_config)
765
+
766
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
767
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
768
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
769
+
770
+ # Initialize weights and apply final processing
771
+ self.post_init()
772
+ self.convert_to_lora()
773
+ self.resize_pos(self.vision_model.embeddings, vision_config)
774
+
775
+ def convert_to_lora(self):
776
+ if self.lora_r == 0:
777
+ return
778
+ if self.add_time_attn:
779
+ target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
780
+ "temporal_attn.q_proj", "temporal_attn.out_proj",
781
+ "temporal_mlp.fc1", "temporal_mlp.fc2"]
782
+ else:
783
+ target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
784
+ config = LoraConfig(
785
+ r=self.lora_r, # 16
786
+ lora_alpha=self.lora_alpha, # 16
787
+ target_modules=target_modules, # self_attn.out_proj
788
+ lora_dropout=self.lora_dropout, # 0.1
789
+ bias="none",
790
+ modules_to_save=[],
791
+ )
792
+ self.vision_model.encoder.is_gradient_checkpointing = False
793
+ self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
794
+
795
+ def resize_pos(self, m, vision_config):
796
+ # convert embedding
797
+ if vision_config.num_mel_bins!=0 and vision_config.target_length!=0:
798
+ m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
799
+ m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
800
+ # pos resize
801
+ old_pos_embed_state_dict = m.position_embedding.state_dict()
802
+ old_pos_embed = old_pos_embed_state_dict['weight']
803
+ dtype = old_pos_embed.dtype
804
+ grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
805
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
806
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
807
+ if new_seq_len == old_pos_embed.shape[0]:
808
+ # m.to(args.device)
809
+ return
810
+
811
+ m.num_patches = grid_size[0] * grid_size[1]
812
+ m.num_positions = m.num_patches + 1
813
+ m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
814
+ new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
815
+
816
+ if extra_tokens:
817
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
818
+ else:
819
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
820
+ old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
821
+
822
+ # if is_master(args):
823
+ # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
824
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
825
+ pos_emb_img = F.interpolate(
826
+ pos_emb_img,
827
+ size=grid_size,
828
+ mode='bicubic',
829
+ antialias=True,
830
+ align_corners=False,
831
+ )
832
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
833
+ if pos_emb_tok is not None:
834
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
835
+ else:
836
+ new_pos_embed = pos_emb_img
837
+ old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
838
+ m.position_embedding = new_position_embedding
839
+ m.position_embedding.load_state_dict(old_pos_embed_state_dict)
840
+
841
+ # m.to(args.device)
842
+
843
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
844
+ def get_text_features(
845
+ self,
846
+ input_ids: Optional[torch.Tensor] = None,
847
+ attention_mask: Optional[torch.Tensor] = None,
848
+ position_ids: Optional[torch.Tensor] = None,
849
+ output_attentions: Optional[bool] = None,
850
+ output_hidden_states: Optional[bool] = None,
851
+ return_dict: Optional[bool] = None,
852
+ ) -> torch.FloatTensor:
853
+ r"""
854
+ Returns:
855
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
856
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
857
+
858
+ Examples:
859
+
860
+ ```python
861
+ >>> from transformers import AutoTokenizer, CLIPModel
862
+
863
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
864
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
865
+
866
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
867
+ >>> text_features = model.get_text_features(**inputs)
868
+ ```"""
869
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ text_outputs = self.text_model(
877
+ input_ids=input_ids,
878
+ attention_mask=attention_mask,
879
+ position_ids=position_ids,
880
+ output_attentions=output_attentions,
881
+ output_hidden_states=output_hidden_states,
882
+ return_dict=return_dict,
883
+ )
884
+
885
+ pooled_output = text_outputs[1]
886
+ text_features = self.text_projection(pooled_output)
887
+
888
+ return text_features
889
+
890
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
891
+ def get_image_features(
892
+ self,
893
+ pixel_values: Optional[torch.FloatTensor] = None,
894
+ output_attentions: Optional[bool] = None,
895
+ output_hidden_states: Optional[bool] = None,
896
+ return_dict: Optional[bool] = None,
897
+ ) -> torch.FloatTensor:
898
+ r"""
899
+ Returns:
900
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
901
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
902
+
903
+ Examples:
904
+
905
+ ```python
906
+ >>> from PIL import Image
907
+ >>> import requests
908
+ >>> from transformers import AutoProcessor, CLIPModel
909
+
910
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
911
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
912
+
913
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
914
+ >>> image = Image.open(requests.get(url, stream=True).raw)
915
+
916
+ >>> inputs = processor(images=image, return_tensors="pt")
917
+
918
+ >>> image_features = model.get_image_features(**inputs)
919
+ ```"""
920
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
921
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
922
+ output_hidden_states = (
923
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
924
+ )
925
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
926
+
927
+ vision_outputs = self.vision_model(
928
+ pixel_values=pixel_values,
929
+ output_attentions=output_attentions,
930
+ output_hidden_states=output_hidden_states,
931
+ return_dict=return_dict,
932
+ )
933
+
934
+ pooled_output = vision_outputs[1] # pooled_output
935
+ image_features = self.visual_projection(pooled_output)
936
+
937
+ return image_features
938
+
939
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
940
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindAudioConfig)
941
+ def forward(
942
+ self,
943
+ input_ids: Optional[torch.LongTensor] = None,
944
+ pixel_values: Optional[torch.FloatTensor] = None,
945
+ attention_mask: Optional[torch.Tensor] = None,
946
+ position_ids: Optional[torch.LongTensor] = None,
947
+ return_loss: Optional[bool] = None,
948
+ output_attentions: Optional[bool] = None,
949
+ output_hidden_states: Optional[bool] = None,
950
+ return_dict: Optional[bool] = None,
951
+ ) -> Union[Tuple, CLIPOutput]:
952
+ r"""
953
+ Returns:
954
+
955
+ Examples:
956
+
957
+ ```python
958
+ >>> from PIL import Image
959
+ >>> import requests
960
+ >>> from transformers import AutoProcessor, CLIPModel
961
+
962
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
963
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
964
+
965
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
966
+ >>> image = Image.open(requests.get(url, stream=True).raw)
967
+
968
+ >>> inputs = processor(
969
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
970
+ ... )
971
+
972
+ >>> outputs = model(**inputs)
973
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
974
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
975
+ ```"""
976
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
977
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
978
+ output_hidden_states = (
979
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
980
+ )
981
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
982
+
983
+ vision_outputs = self.vision_model(
984
+ pixel_values=pixel_values,
985
+ output_attentions=output_attentions,
986
+ output_hidden_states=output_hidden_states,
987
+ return_dict=return_dict,
988
+ )
989
+
990
+ text_outputs = self.text_model(
991
+ input_ids=input_ids,
992
+ attention_mask=attention_mask,
993
+ position_ids=position_ids,
994
+ output_attentions=output_attentions,
995
+ output_hidden_states=output_hidden_states,
996
+ return_dict=return_dict,
997
+ )
998
+
999
+ image_embeds = vision_outputs[1]
1000
+ image_embeds = self.visual_projection(image_embeds)
1001
+
1002
+ text_embeds = text_outputs[1]
1003
+ text_embeds = self.text_projection(text_embeds)
1004
+
1005
+ # normalized features
1006
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1007
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1008
+
1009
+ # cosine similarity as logits
1010
+ logit_scale = self.logit_scale.exp()
1011
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1012
+ logits_per_image = logits_per_text.t()
1013
+
1014
+ loss = None
1015
+ if return_loss:
1016
+ loss = clip_loss(logits_per_text)
1017
+
1018
+ if not return_dict:
1019
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1020
+ return ((loss,) + output) if loss is not None else output
1021
+
1022
+ return CLIPOutput(
1023
+ loss=loss,
1024
+ logits_per_image=logits_per_image,
1025
+ logits_per_text=logits_per_text,
1026
+ text_embeds=text_embeds,
1027
+ image_embeds=image_embeds,
1028
+ text_model_output=text_outputs,
1029
+ vision_model_output=vision_outputs,
1030
+ )
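The `forward` above is the standard CLIP-style contrastive head, with the audio spectrogram standing in for `pixel_values`. A minimal usage sketch follows; `model` and `processor` are assumed to be `LanguageBindAudio` / `LanguageBindAudioProcessor` instances built from this repository's classes, and the call pattern is an assumption based on the code above, not a verified public API.

```python
# Hedged sketch: scoring audio clips against text prompts with the CLIP-style forward above.
import torch

def rank_audio_against_text(model, processor, audio_paths, prompts):
    # processor packs mel "pixel_values" for the audio files and token ids for the prompts
    inputs = processor(images=audio_paths, text=prompts, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image holds audio-vs-text similarities; softmax over prompts gives probabilities
    return outputs.logits_per_image.softmax(dim=-1)
```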
models/multimodal_encoder/languagebind/audio/processing_audio.py ADDED
@@ -0,0 +1,190 @@
+import cv2
+import numpy as np
+import torch
+import torchaudio  # restored: torchaudio_loader() and AudioTransform below call into torchaudio
+from torchvision import transforms
+from transformers import ProcessorMixin, BatchEncoding
+from transformers.image_processing_utils import BatchFeature
+from torch.nn import functional as F
+
+
+def make_list_of_images(x):
+    if not isinstance(x, list):
+        return [x]
+    return x
+
+
+# torchaudio.set_audio_backend("soundfile")
+
+def torchaudio_loader(path):
+    return torchaudio.load(path)
+
+def int16_to_float32_torch(x):
+    return (x / 32767.0).type(torch.float32)
+
+def float32_to_int16_torch(x):
+    x = torch.clamp(x, min=-1., max=1.)
+    return (x * 32767.).type(torch.int16)
+
+DEFAULT_AUDIO_FRAME_SHIFT_MS = 10
+
+class AudioTransform:
+    def __init__(self, config):
+        self.sample_rate = config.audio_sample_rate
+        self.num_mel_bins = config.num_mel_bins
+        self.target_length = config.target_length
+        self.audio_mean = config.audio_mean
+        self.audio_std = config.audio_std
+        # mean=-4.2677393
+        # std=4.5689974
+        self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)
+
+    def __call__(self, audio_data_and_origin_sr):
+        audio_data, origin_sr = audio_data_and_origin_sr
+        if self.sample_rate != origin_sr:
+            # print(audio_data.shape, origin_sr)
+            audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate)
+        waveform_melspec = self.waveform2melspec(audio_data[0])
+        return self.norm(waveform_melspec)
+
+    def waveform2melspec(self, audio_data):
+        max_len = self.target_length * self.sample_rate // 100
+        if audio_data.shape[-1] > max_len:
+            mel = self.get_mel(audio_data)
+            # split to three parts
+            chunk_frames = self.target_length
+            total_frames = mel.shape[0]
+            ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
+            # print('total_frames-chunk_frames:', total_frames-chunk_frames,
+            #       'len(audio_data):', len(audio_data),
+            #       'chunk_frames:', chunk_frames,
+            #       'total_frames:', total_frames)
+            if len(ranges[1]) == 0:  # if the audio is too short, we just use the first chunk
+                ranges[1] = [0]
+            if len(ranges[2]) == 0:  # if the audio is too short, we just use the first chunk
+                ranges[2] = [0]
+            # randomly choose index for each part
+            # idx_front = np.random.choice(ranges[0])
+            # idx_middle = np.random.choice(ranges[1])
+            # idx_back = np.random.choice(ranges[2])
+            idx_front = ranges[0][0]  # fixed
+            idx_middle = ranges[1][0]
+            idx_back = ranges[2][0]
+            # select mel
+            mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
+            mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
+            mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
+            # stack
+            mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
+        elif audio_data.shape[-1] < max_len:  # padding if too short
+            n_repeat = int(max_len / len(audio_data))
+            audio_data = audio_data.repeat(n_repeat)
+            audio_data = F.pad(
+                audio_data,
+                (0, max_len - len(audio_data)),
+                mode="constant",
+                value=0,
+            )
+            mel = self.get_mel(audio_data)
+            mel_fusion = torch.stack([mel, mel, mel], dim=0)
+        else:  # if equal
+            mel = self.get_mel(audio_data)
+            mel_fusion = torch.stack([mel, mel, mel], dim=0)
+
+        # twice check
+        p = self.target_length - mel_fusion.shape[1]
+
+        # if abs(p) / self.target_length > 0.2:
+        #     logging.warning(
+        #         "Large gap between audio n_frames(%d) and "
+        #         "target_length (%d). Is the audio_target_length "
+        #         "setting correct?",
+        #         mel_fusion.shape[1],
+        #         self.target_length,
+        #     )
+
+        # cut and pad
+        if p > 0:
+            m = torch.nn.ZeroPad2d((0, 0, 0, p))
+            mel_fusion = m(mel_fusion)
+        elif p < 0:
+            mel_fusion = mel_fusion[:, 0: self.target_length, :]
+
+        mel_fusion = mel_fusion.transpose(1, 2)  # [3, target_length, mel_bins] -> [3, mel_bins, target_length]
+        return mel_fusion
+
+    def get_mel(self, audio_data):
+        # mel shape: (n_mels, T)
+        audio_data -= audio_data.mean()
+        mel = torchaudio.compliance.kaldi.fbank(
+            audio_data.unsqueeze(0),
+            htk_compat=True,
+            sample_frequency=self.sample_rate,
+            use_energy=False,
+            window_type="hanning",
+            num_mel_bins=self.num_mel_bins,
+            dither=0.0,
+            frame_length=25,
+            frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
+        )
+        return mel  # (T, n_mels)
+
+def get_audio_transform(config):
+    config = config.vision_config
+    return AudioTransform(config)
+
+
+def load_and_transform_audio(
+    audio_path,
+    transform,
+):
+    waveform_and_sr = torchaudio_loader(audio_path)
+    audio_outputs = transform(waveform_and_sr)
+
+    return audio_outputs
+
+class LanguageBindAudioProcessor(ProcessorMixin):
+    attributes = []
+    tokenizer_class = ("LanguageBindAudioTokenizer")
+
+    def __init__(self, config, tokenizer=None, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.transform = get_audio_transform(config)
+        self.image_processor = load_and_transform_audio
+        self.tokenizer = tokenizer
+
+    def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
+        if text is None and images is None:
+            raise ValueError("You have to specify either text or images. Both cannot be none.")
+
+        if text is not None:
+            encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
+                                      truncation=True, return_tensors=return_tensors, **kwargs)
+
+        if images is not None:
+            images = make_list_of_images(images)
+            image_features = [self.image_processor(image, self.transform) for image in images]
+            image_features = torch.stack(image_features)
+
+        if text is not None and images is not None:
+            encoding["pixel_values"] = image_features
+            return encoding
+        elif text is not None:
+            return encoding
+        else:
+            return {"pixel_values": image_features}
+
+    def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
+        """
+        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
+
+    def decode(self, skip_special_tokens=True, *args, **kwargs):
+        """
+        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
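`AudioTransform` resamples the waveform, converts it to a Kaldi fbank spectrogram, and fuses three chunks into a fixed `[3, num_mel_bins, target_length]` tensor that plays the role of RGB channels. A minimal sketch of driving it directly is below, assuming `AudioTransform` from this file is in scope; only `audio_mean`/`audio_std` come from the comments above, the other config values are illustrative assumptions.

```python
# Hedged sketch: running AudioTransform on synthetic audio.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(
    audio_sample_rate=16000,   # assumed
    num_mel_bins=112,          # assumed
    target_length=1036,        # assumed
    audio_mean=-4.2677393,     # from the comment in __init__
    audio_std=4.5689974,       # from the comment in __init__
)
transform = AudioTransform(cfg)
waveform = torch.randn(1, 16000 * 5)   # 5 seconds of fake audio at 16 kHz
mel = transform((waveform, 16000))     # short clip -> repeated/padded, mel stacked 3x
print(mel.shape)                       # [3, num_mel_bins, target_length]
```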
models/multimodal_encoder/languagebind/audio/tokenization_audio.py ADDED
@@ -0,0 +1,77 @@
+from transformers import CLIPTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "lb203/LanguageBind-Audio": "https://huggingface.co/lb203/LanguageBind-Audio/resolve/main/vocab.json",
+    },
+    "merges_file": {
+        "lb203/LanguageBind-Audio": "https://huggingface.co/lb203/LanguageBind-Audio/resolve/main/merges.txt",
+    },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "lb203/LanguageBind-Audio": 77,
+}
+
+
+PRETRAINED_INIT_CONFIGURATION = {
+    "lb203/LanguageBind-Audio": {},
+}
+
+class LanguageBindAudioTokenizer(CLIPTokenizer):
+    """
+    Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*, defaults to `<|startoftext|>`):
+            The beginning of sequence token.
+        eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+            The end of sequence token.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        unk_token="<|endoftext|>",
+        bos_token="<|startoftext|>",
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",  # hack to enable padding
+        **kwargs,
+    ):
+        super(LanguageBindAudioTokenizer, self).__init__(
+            vocab_file,
+            merges_file,
+            errors,
+            unk_token,
+            bos_token,
+            eos_token,
+            pad_token,  # hack to enable padding
+            **kwargs,)
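The subclass above only re-points the vocabulary metadata; tokenization behaviour is inherited unchanged from `CLIPTokenizer`. A quick sanity check is sketched below, assuming the class is importable and that the `lb203/LanguageBind-Audio` files referenced above are reachable (they may not resolve offline).

```python
# Hedged sketch: the tokenizer loads and pads like a stock CLIPTokenizer.
tokenizer = LanguageBindAudioTokenizer.from_pretrained("lb203/LanguageBind-Audio")
batch = tokenizer(
    ["a dog barking", "rain on a window"],
    max_length=77, padding="max_length", truncation=True, return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 77])
```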
models/multimodal_encoder/languagebind/depth/configuration_depth.py ADDED
@@ -0,0 +1,425 @@
1
+ import copy
2
+ import os
3
+ from typing import Union
4
+
5
+ from transformers import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
+
12
+
13
+
14
+
15
+
16
+ class CLIPTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
19
+ text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
20
+ with the defaults will yield a similar configuration to that of the text encoder of the CLIP
21
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 49408):
28
+ Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
29
+ the `inputs_ids` passed when calling [`CLIPModel`].
30
+ hidden_size (`int`, *optional*, defaults to 512):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 2048):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 8):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 77):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ initializer_factor (`float`, *optional*, defaults to 1):
51
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
52
+ testing).
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import CLIPTextConfig, CLIPTextModel
58
+
59
+ >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
60
+ >>> configuration = CLIPTextConfig()
61
+
62
+ >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
63
+ >>> model = CLIPTextModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+ model_type = "clip_text_model"
69
+
70
+ def __init__(
71
+ self,
72
+ vocab_size=49408,
73
+ hidden_size=512,
74
+ intermediate_size=2048,
75
+ projection_dim=512,
76
+ num_hidden_layers=12,
77
+ num_attention_heads=8,
78
+ max_position_embeddings=77,
79
+ hidden_act="quick_gelu",
80
+ layer_norm_eps=1e-5,
81
+ attention_dropout=0.0,
82
+ initializer_range=0.02,
83
+ initializer_factor=1.0,
84
+ # This differs from `CLIPTokenizer`'s default and from openai/clip
85
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
86
+ pad_token_id=1,
87
+ bos_token_id=49406,
88
+ eos_token_id=49407,
89
+ **kwargs,
90
+ ):
91
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
92
+
93
+ self.vocab_size = vocab_size
94
+ self.hidden_size = hidden_size
95
+ self.intermediate_size = intermediate_size
96
+ self.projection_dim = projection_dim
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.max_position_embeddings = max_position_embeddings
100
+ self.layer_norm_eps = layer_norm_eps
101
+ self.hidden_act = hidden_act
102
+ self.initializer_range = initializer_range
103
+ self.initializer_factor = initializer_factor
104
+ self.attention_dropout = attention_dropout
105
+ self.add_time_attn = False ######################################
106
+
107
+ @classmethod
108
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
109
+ cls._set_token_in_kwargs(kwargs)
110
+
111
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
112
+
113
+ # get the text config dict if we are loading from CLIPConfig
114
+ if config_dict.get("model_type") == "clip":
115
+ config_dict = config_dict["text_config"]
116
+
117
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
118
+ logger.warning(
119
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
120
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
121
+ )
122
+
123
+ return cls.from_dict(config_dict, **kwargs)
124
+
125
+
126
+
127
+
128
+ class CLIPVisionConfig(PretrainedConfig):
129
+ r"""
130
+ This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
131
+ CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
132
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
133
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
134
+
135
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
136
+ documentation from [`PretrainedConfig`] for more information.
137
+
138
+ Args:
139
+ hidden_size (`int`, *optional*, defaults to 768):
140
+ Dimensionality of the encoder layers and the pooler layer.
141
+ intermediate_size (`int`, *optional*, defaults to 3072):
142
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
143
+ num_hidden_layers (`int`, *optional*, defaults to 12):
144
+ Number of hidden layers in the Transformer encoder.
145
+ num_attention_heads (`int`, *optional*, defaults to 12):
146
+ Number of attention heads for each attention layer in the Transformer encoder.
147
+ image_size (`int`, *optional*, defaults to 224):
148
+ The size (resolution) of each image.
149
+ patch_size (`int`, *optional*, defaults to 32):
150
+ The size (resolution) of each patch.
151
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
154
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
155
+ The epsilon used by the layer normalization layers.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ initializer_range (`float`, *optional*, defaults to 0.02):
159
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
160
+ initializer_factor (`float`, *optional*, defaults to 1):
161
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
162
+ testing).
163
+
164
+ Example:
165
+
166
+ ```python
167
+ >>> from transformers import CLIPVisionConfig, CLIPVisionModel
168
+
169
+ >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
170
+ >>> configuration = CLIPVisionConfig()
171
+
172
+ >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
173
+ >>> model = CLIPVisionModel(configuration)
174
+
175
+ >>> # Accessing the model configuration
176
+ >>> configuration = model.config
177
+ ```"""
178
+
179
+ model_type = "clip_vision_model"
180
+
181
+ def __init__(
182
+ self,
183
+ hidden_size=768,
184
+ intermediate_size=3072,
185
+ projection_dim=512,
186
+ num_hidden_layers=12,
187
+ num_attention_heads=12,
188
+ num_channels=3,
189
+ image_size=224,
190
+ patch_size=32,
191
+ hidden_act="quick_gelu",
192
+ layer_norm_eps=1e-5,
193
+ attention_dropout=0.0,
194
+ initializer_range=0.02,
195
+ initializer_factor=1.0,
196
+
197
+ add_time_attn=False, ################################
198
+ num_frames=1, ################################
199
+ force_patch_dropout=0.0, ################################
200
+ lora_r=2, ################################
201
+ lora_alpha=16, ################################
202
+ lora_dropout=0.0, ################################
203
+ num_mel_bins=0.0, ################################
204
+ target_length=0.0, ################################
205
+ max_depth=10,
206
+ video_decode_backend='decord', #########################
207
+ **kwargs,
208
+ ):
209
+ super().__init__(**kwargs)
210
+
211
+ self.hidden_size = hidden_size
212
+ self.intermediate_size = intermediate_size
213
+ self.projection_dim = projection_dim
214
+ self.num_hidden_layers = num_hidden_layers
215
+ self.num_attention_heads = num_attention_heads
216
+ self.num_channels = num_channels
217
+ self.patch_size = patch_size
218
+ self.image_size = image_size
219
+ self.initializer_range = initializer_range
220
+ self.initializer_factor = initializer_factor
221
+ self.attention_dropout = attention_dropout
222
+ self.layer_norm_eps = layer_norm_eps
223
+ self.hidden_act = hidden_act
224
+
225
+ self.add_time_attn = add_time_attn ################
226
+ self.num_frames = num_frames ################
227
+ self.force_patch_dropout = force_patch_dropout ################
228
+ self.lora_r = lora_r ################
229
+ self.lora_alpha = lora_alpha ################
230
+ self.lora_dropout = lora_dropout ################
231
+ self.num_mel_bins = num_mel_bins ################
232
+ self.target_length = target_length ################
233
+ self.max_depth = max_depth ################
234
+ self.video_decode_backend = video_decode_backend ################
235
+
236
+ @classmethod
237
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
238
+ cls._set_token_in_kwargs(kwargs)
239
+
240
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
241
+
242
+ # get the vision config dict if we are loading from CLIPConfig
243
+ if config_dict.get("model_type") == "clip":
244
+ config_dict = config_dict["vision_config"]
245
+
246
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
247
+ logger.warning(
248
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
249
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
250
+ )
251
+
252
+ return cls.from_dict(config_dict, **kwargs)
253
+
254
+
255
+ class LanguageBindDepthConfig(PretrainedConfig):
256
+ r"""
257
+ [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
258
+ a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
259
+ a configuration with the defaults will yield a similar configuration to that of the CLIP
260
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
261
+
262
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
263
+ documentation from [`PretrainedConfig`] for more information.
264
+
265
+ Args:
266
+ text_config (`dict`, *optional*):
267
+ Dictionary of configuration options used to initialize [`CLIPTextConfig`].
268
+ vision_config (`dict`, *optional*):
269
+ Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
270
+ projection_dim (`int`, *optional*, defaults to 512):
271
+ Dimensionality of text and vision projection layers.
272
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
273
+ The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation.
274
+ kwargs (*optional*):
275
+ Dictionary of keyword arguments.
276
+
277
+ Example:
278
+
279
+ ```python
280
+ >>> from transformers import CLIPConfig, CLIPModel
281
+
282
+ >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
283
+ >>> configuration = CLIPConfig()
284
+
285
+ >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
286
+ >>> model = CLIPModel(configuration)
287
+
288
+ >>> # Accessing the model configuration
289
+ >>> configuration = model.config
290
+
291
+ >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
292
+ >>> from transformers import CLIPTextConfig, CLIPVisionConfig
293
+
294
+ >>> # Initializing a CLIPText and CLIPVision configuration
295
+ >>> config_text = CLIPTextConfig()
296
+ >>> config_vision = CLIPVisionConfig()
297
+
298
+ >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
299
+ ```"""
300
+
301
+ model_type = "LanguageBindDepth"
302
+ is_composition = True
303
+
304
+ def __init__(
305
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
306
+ ):
307
+ # If `_config_dict` exist, we use them for the backward compatibility.
308
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
309
+ # of confusion!).
310
+ text_config_dict = kwargs.pop("text_config_dict", None)
311
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
312
+
313
+ super().__init__(**kwargs)
314
+
315
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
316
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
317
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
318
+ if text_config_dict is not None:
319
+ if text_config is None:
320
+ text_config = {}
321
+
322
+ # This is the complete result when using `text_config_dict`.
323
+ _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
324
+
325
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
326
+ for key, value in _text_config_dict.items():
327
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
328
+ # If specified in `text_config_dict`
329
+ if key in text_config_dict:
330
+ message = (
331
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
332
+ f'The value `text_config_dict["{key}"]` will be used instead.'
333
+ )
334
+ # If inferred from default argument values (just to be super careful)
335
+ else:
336
+ message = (
337
+ f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
338
+ f'value `text_config["{key}"]` will be overridden.'
339
+ )
340
+ logger.warning(message)
341
+
342
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
343
+ text_config.update(_text_config_dict)
344
+
345
+ if vision_config_dict is not None:
346
+ if vision_config is None:
347
+ vision_config = {}
348
+
349
+ # This is the complete result when using `vision_config_dict`.
350
+ _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
351
+ # convert keys to string instead of integer
352
+ if "id2label" in _vision_config_dict:
353
+ _vision_config_dict["id2label"] = {
354
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
355
+ }
356
+
357
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
358
+ for key, value in _vision_config_dict.items():
359
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
360
+ # If specified in `vision_config_dict`
361
+ if key in vision_config_dict:
362
+ message = (
363
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
364
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
365
+ )
366
+ # If inferred from default argument values (just to be super careful)
367
+ else:
368
+ message = (
369
+ f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
370
+ f'The value `vision_config["{key}"]` will be overridden.'
371
+ )
372
+ logger.warning(message)
373
+
374
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
375
+ vision_config.update(_vision_config_dict)
376
+
377
+ if text_config is None:
378
+ text_config = {}
379
+ logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
380
+
381
+ if vision_config is None:
382
+ vision_config = {}
383
+ logger.info("`vision_config` is `None`. Initializing the `CLIPVisionConfig` with default values.")
384
+
385
+ self.text_config = CLIPTextConfig(**text_config)
386
+ self.vision_config = CLIPVisionConfig(**vision_config)
387
+
388
+ self.projection_dim = projection_dim
389
+ self.logit_scale_init_value = logit_scale_init_value
390
+ self.initializer_factor = 1.0
391
+
392
+ @classmethod
393
+ def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
394
+ r"""
395
+ Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
396
+ configuration.
397
+
398
+ Returns:
399
+ [`CLIPConfig`]: An instance of a configuration object
400
+ """
401
+
402
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
403
+
404
+ def to_dict(self):
405
+ """
406
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
407
+
408
+ Returns:
409
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
410
+ """
411
+ output = copy.deepcopy(self.__dict__)
412
+ output["text_config"] = self.text_config.to_dict()
413
+ output["vision_config"] = self.vision_config.to_dict()
414
+ output["model_type"] = self.__class__.model_type
415
+ return output
416
+
417
+
418
+
419
+
420
+
421
+
422
+
423
+
424
+
425
+
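`LanguageBindDepthConfig` is a thin CLIP-style composite: it owns a `CLIPTextConfig`, a `CLIPVisionConfig` extended with the temporal/LoRA/depth fields marked above, and the projection/logit-scale hyperparameters. A minimal sketch of composing it via the helper defined above follows, assuming the three classes from this module are in scope; the values shown are the defaults declared in this file.

```python
# Hedged sketch: building the composite depth config from its two towers.
text_cfg = CLIPTextConfig()                                  # 512-d text tower, 12 layers
vision_cfg = CLIPVisionConfig(num_frames=1, max_depth=10)    # defaults from this file
config = LanguageBindDepthConfig.from_text_vision_configs(text_cfg, vision_cfg)

assert config.vision_config.max_depth == 10
print(config.to_dict()["model_type"])                        # "LanguageBindDepth"
```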
models/multimodal_encoder/languagebind/depth/modeling_depth.py ADDED
@@ -0,0 +1,1030 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from peft import LoraConfig, get_peft_model
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from transformers import PreTrainedModel, add_start_docstrings
10
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
11
+ from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
12
+ CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss
13
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
14
+
15
+ from .configuration_depth import LanguageBindDepthConfig, CLIPVisionConfig, CLIPTextConfig
16
+
17
+
18
+
19
+ class PatchDropout(nn.Module):
20
+ """
21
+ https://arxiv.org/abs/2212.00794
22
+ """
23
+
24
+ def __init__(self, prob, exclude_first_token=True):
25
+ super().__init__()
26
+ assert 0 <= prob < 1.
27
+ self.prob = prob
28
+ self.exclude_first_token = exclude_first_token # exclude CLS token
29
+
30
+ def forward(self, x, B, T):
31
+ if not self.training or self.prob == 0.:
32
+ return x
33
+
34
+ if self.exclude_first_token:
35
+ cls_tokens, x = x[:, :1], x[:, 1:]
36
+ else:
37
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
38
+
39
+ batch = x.size()[0]
40
+ num_tokens = x.size()[1]
41
+
42
+ batch_indices = torch.arange(batch)
43
+ batch_indices = batch_indices[..., None]
44
+
45
+ keep_prob = 1 - self.prob
46
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
47
+
48
+ if T == 1:
49
+ rand = torch.randn(batch, num_tokens)
50
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
51
+ else:
52
+ rand = torch.randn(B, num_tokens)
53
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
54
+ patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1)
55
+ patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n')
56
+
57
+
58
+ x = x[batch_indices, patch_indices_keep]
59
+
60
+ if self.exclude_first_token:
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ return x
64
+
65
+ class CLIPEncoderLayer(nn.Module):
66
+ def __init__(self, config: LanguageBindDepthConfig):
67
+ super().__init__()
68
+ self.embed_dim = config.hidden_size
69
+ self.self_attn = CLIPAttention(config)
70
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
71
+ self.mlp = CLIPMLP(config)
72
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
73
+
74
+ self.add_time_attn = config.add_time_attn
75
+ if self.add_time_attn:
76
+ self.t = config.num_frames
77
+ self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
78
+ nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
79
+
80
+ self.embed_dim = config.hidden_size
81
+ self.temporal_attn = CLIPAttention(config)
82
+ self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
83
+ self.temporal_mlp = CLIPMLP(config)
84
+ self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
85
+
86
+ def forward(
87
+ self,
88
+ hidden_states: torch.Tensor,
89
+ attention_mask: torch.Tensor,
90
+ causal_attention_mask: torch.Tensor,
91
+ output_attentions: Optional[bool] = False,
92
+ ) -> Tuple[torch.FloatTensor]:
93
+ """
94
+ Args:
95
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
96
+ attention_mask (`torch.FloatTensor`): attention mask of size
97
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
98
+ `(config.encoder_attention_heads,)`.
99
+ output_attentions (`bool`, *optional*):
100
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
101
+ returned tensors for more detail.
102
+ """
103
+
104
+
105
+ if self.add_time_attn:
106
+ bt, n, d = hidden_states.shape
107
+ t = self.t
108
+
109
+ # time embed
110
+ if t != 1:
111
+ n = hidden_states.shape[1]
112
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
113
+ hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
114
+ hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
115
+
116
+ # time attn
117
+ residual = hidden_states
118
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
119
+ # hidden_states = self.layer_norm1(hidden_states) # share layernorm
120
+ hidden_states = self.temporal_layer_norm1(hidden_states)
121
+ hidden_states, attn_weights = self.temporal_attn(
122
+ hidden_states=hidden_states,
123
+ attention_mask=attention_mask,
124
+ causal_attention_mask=causal_attention_mask,
125
+ output_attentions=output_attentions,
126
+ )
127
+ hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
128
+
129
+ residual = hidden_states
130
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
131
+ # hidden_states = self.layer_norm2(hidden_states) # share layernorm
132
+ hidden_states = self.temporal_layer_norm2(hidden_states)
133
+ hidden_states = self.temporal_mlp(hidden_states)
134
+ hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
135
+
136
+ # spatial attn
137
+ residual = hidden_states
138
+
139
+ hidden_states = self.layer_norm1(hidden_states)
140
+ hidden_states, attn_weights = self.self_attn(
141
+ hidden_states=hidden_states,
142
+ attention_mask=attention_mask,
143
+ causal_attention_mask=causal_attention_mask,
144
+ output_attentions=output_attentions,
145
+ )
146
+ hidden_states = residual + hidden_states
147
+
148
+ residual = hidden_states
149
+ hidden_states = self.layer_norm2(hidden_states)
150
+ hidden_states = self.mlp(hidden_states)
151
+ hidden_states = residual + hidden_states
152
+
153
+ outputs = (hidden_states,)
154
+
155
+ if output_attentions:
156
+ outputs += (attn_weights,)
157
+
158
+ return outputs
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+ class CLIPPreTrainedModel(PreTrainedModel):
169
+ """
170
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
171
+ models.
172
+ """
173
+
174
+ config_class = LanguageBindDepthConfig
175
+ base_model_prefix = "clip"
176
+ supports_gradient_checkpointing = True
177
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
178
+
179
+ def _init_weights(self, module):
180
+ """Initialize the weights"""
181
+ factor = self.config.initializer_factor
182
+ if isinstance(module, CLIPTextEmbeddings):
183
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
184
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
185
+ elif isinstance(module, CLIPVisionEmbeddings):
186
+ factor = self.config.initializer_factor
187
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
188
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
189
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
190
+ elif isinstance(module, CLIPAttention):
191
+ factor = self.config.initializer_factor
192
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
193
+ out_proj_std = (module.embed_dim**-0.5) * factor
194
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
195
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
196
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
197
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
198
+ elif isinstance(module, CLIPMLP):
199
+ factor = self.config.initializer_factor
200
+ in_proj_std = (
201
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
202
+ )
203
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
204
+ nn.init.normal_(module.fc1.weight, std=fc_std)
205
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
206
+ elif isinstance(module, LanguageBindDepth):
207
+ nn.init.normal_(
208
+ module.text_projection.weight,
209
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
210
+ )
211
+ nn.init.normal_(
212
+ module.visual_projection.weight,
213
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
214
+ )
215
+ elif isinstance(module, CLIPVisionModelWithProjection):
216
+ nn.init.normal_(
217
+ module.visual_projection.weight,
218
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
219
+ )
220
+ elif isinstance(module, CLIPTextModelWithProjection):
221
+ nn.init.normal_(
222
+ module.text_projection.weight,
223
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
224
+ )
225
+
226
+ if isinstance(module, nn.LayerNorm):
227
+ module.bias.data.zero_()
228
+ module.weight.data.fill_(1.0)
229
+ if isinstance(module, nn.Linear) and module.bias is not None:
230
+ module.bias.data.zero_()
231
+
232
+ def _set_gradient_checkpointing(self, module, value=False):
233
+ if isinstance(module, CLIPEncoder):
234
+ module.gradient_checkpointing = value
235
+
236
+
237
+ CLIP_START_DOCSTRING = r"""
238
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
239
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
240
+ etc.)
241
+
242
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
243
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
244
+ and behavior.
245
+
246
+ Parameters:
247
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
248
+ Initializing with a config file does not load the weights associated with the model, only the
249
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
250
+ """
251
+
252
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
253
+ Args:
254
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
255
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
256
+ it.
257
+
258
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
259
+ [`PreTrainedTokenizer.__call__`] for details.
260
+
261
+ [What are input IDs?](../glossary#input-ids)
262
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
263
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
264
+
265
+ - 1 for tokens that are **not masked**,
266
+ - 0 for tokens that are **masked**.
267
+
268
+ [What are attention masks?](../glossary#attention-mask)
269
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
270
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
271
+ config.max_position_embeddings - 1]`.
272
+
273
+ [What are position IDs?](../glossary#position-ids)
274
+ output_attentions (`bool`, *optional*):
275
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
276
+ tensors for more detail.
277
+ output_hidden_states (`bool`, *optional*):
278
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
279
+ more detail.
280
+ return_dict (`bool`, *optional*):
281
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
282
+ """
283
+
284
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
285
+ Args:
286
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
287
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
288
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
289
+ output_attentions (`bool`, *optional*):
290
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
291
+ tensors for more detail.
292
+ output_hidden_states (`bool`, *optional*):
293
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
294
+ more detail.
295
+ return_dict (`bool`, *optional*):
296
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
297
+ """
298
+
299
+ CLIP_INPUTS_DOCSTRING = r"""
300
+ Args:
301
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
302
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
303
+ it.
304
+
305
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
306
+ [`PreTrainedTokenizer.__call__`] for details.
307
+
308
+ [What are input IDs?](../glossary#input-ids)
309
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
310
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
311
+
312
+ - 1 for tokens that are **not masked**,
313
+ - 0 for tokens that are **masked**.
314
+
315
+ [What are attention masks?](../glossary#attention-mask)
316
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
317
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
318
+ config.max_position_embeddings - 1]`.
319
+
320
+ [What are position IDs?](../glossary#position-ids)
321
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
322
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
323
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
324
+ return_loss (`bool`, *optional*):
325
+ Whether or not to return the contrastive loss.
326
+ output_attentions (`bool`, *optional*):
327
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
328
+ tensors for more detail.
329
+ output_hidden_states (`bool`, *optional*):
330
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
331
+ more detail.
332
+ return_dict (`bool`, *optional*):
333
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
334
+ """
335
+
336
+
337
+ class CLIPEncoder(nn.Module):
338
+ """
339
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
340
+ [`CLIPEncoderLayer`].
341
+
342
+ Args:
343
+ config: CLIPConfig
344
+ """
345
+
346
+ def __init__(self, config: LanguageBindDepthConfig):
347
+ super().__init__()
348
+ self.config = config
349
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
350
+ self.gradient_checkpointing = False
351
+
352
+ def forward(
353
+ self,
354
+ inputs_embeds,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ causal_attention_mask: Optional[torch.Tensor] = None,
357
+ output_attentions: Optional[bool] = None,
358
+ output_hidden_states: Optional[bool] = None,
359
+ return_dict: Optional[bool] = None,
360
+ ) -> Union[Tuple, BaseModelOutput]:
361
+ r"""
362
+ Args:
363
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
364
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
365
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
366
+ than the model's internal embedding lookup matrix.
367
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
368
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
369
+
370
+ - 1 for tokens that are **not masked**,
371
+ - 0 for tokens that are **masked**.
372
+
373
+ [What are attention masks?](../glossary#attention-mask)
374
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
375
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
376
+
377
+ - 1 for tokens that are **not masked**,
378
+ - 0 for tokens that are **masked**.
379
+
380
+ [What are attention masks?](../glossary#attention-mask)
381
+ output_attentions (`bool`, *optional*):
382
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
383
+ returned tensors for more detail.
384
+ output_hidden_states (`bool`, *optional*):
385
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
386
+ for more detail.
387
+ return_dict (`bool`, *optional*):
388
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
389
+ """
390
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
391
+ output_hidden_states = (
392
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
393
+ )
394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
395
+
396
+ encoder_states = () if output_hidden_states else None
397
+ all_attentions = () if output_attentions else None
398
+
399
+ hidden_states = inputs_embeds
400
+ for idx, encoder_layer in enumerate(self.layers):
401
+ if output_hidden_states:
402
+ encoder_states = encoder_states + (hidden_states,)
403
+ if self.gradient_checkpointing and self.training:
404
+
405
+ def create_custom_forward(module):
406
+ def custom_forward(*inputs):
407
+ return module(*inputs, output_attentions)
408
+
409
+ return custom_forward
410
+
411
+ layer_outputs = torch.utils.checkpoint.checkpoint(
412
+ create_custom_forward(encoder_layer),
413
+ hidden_states,
414
+ attention_mask,
415
+ causal_attention_mask,
416
+ )
417
+ else:
418
+ layer_outputs = encoder_layer(
419
+ hidden_states,
420
+ attention_mask,
421
+ causal_attention_mask,
422
+ output_attentions=output_attentions,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+
427
+ if output_attentions:
428
+ all_attentions = all_attentions + (layer_outputs[1],)
429
+
430
+ if output_hidden_states:
431
+ encoder_states = encoder_states + (hidden_states,)
432
+
433
+ if not return_dict:
434
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
435
+ return BaseModelOutput(
436
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
437
+ )
438
+
439
+
440
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
441
+ def _make_causal_mask(
442
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
443
+ ):
444
+ """
445
+ Make causal mask used for bi-directional self-attention.
446
+ """
447
+ bsz, tgt_len = input_ids_shape
448
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
449
+ mask_cond = torch.arange(mask.size(-1), device=device)
450
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
451
+ mask = mask.to(dtype)
452
+
453
+ if past_key_values_length > 0:
454
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
455
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
456
+
457
+
458
+ class CLIPTextTransformer(nn.Module):
459
+ def __init__(self, config: CLIPTextConfig):
460
+ super().__init__()
461
+ self.config = config
462
+ embed_dim = config.hidden_size
463
+ self.embeddings = CLIPTextEmbeddings(config)
464
+ self.encoder = CLIPEncoder(config)
465
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
466
+
467
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
468
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
469
+ def forward(
470
+ self,
471
+ input_ids: Optional[torch.Tensor] = None,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ position_ids: Optional[torch.Tensor] = None,
474
+ output_attentions: Optional[bool] = None,
475
+ output_hidden_states: Optional[bool] = None,
476
+ return_dict: Optional[bool] = None,
477
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
478
+ r"""
479
+ Returns:
480
+
481
+ """
482
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
483
+ output_hidden_states = (
484
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
485
+ )
486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
487
+
488
+ if input_ids is None:
489
+ raise ValueError("You have to specify input_ids")
490
+
491
+ input_shape = input_ids.size()
492
+ input_ids = input_ids.view(-1, input_shape[-1])
493
+
494
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
495
+
496
+ # CLIP's text model uses causal mask, prepare it here.
497
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
498
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
499
+ # expand attention_mask
500
+ if attention_mask is not None:
501
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
502
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
503
+
504
+ encoder_outputs = self.encoder(
505
+ inputs_embeds=hidden_states,
506
+ attention_mask=attention_mask,
507
+ causal_attention_mask=causal_attention_mask,
508
+ output_attentions=output_attentions,
509
+ output_hidden_states=output_hidden_states,
510
+ return_dict=return_dict,
511
+ )
512
+
513
+ last_hidden_state = encoder_outputs[0]
514
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
515
+
516
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
517
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
518
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
519
+ pooled_output = last_hidden_state[
520
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
521
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
522
+ ]
523
+
524
+ if not return_dict:
525
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
526
+
527
+ return BaseModelOutputWithPooling(
528
+ last_hidden_state=last_hidden_state,
529
+ pooler_output=pooled_output,
530
+ hidden_states=encoder_outputs.hidden_states,
531
+ attentions=encoder_outputs.attentions,
532
+ )
533
+
534
+
535
+ @add_start_docstrings(
536
+ """The text model from CLIP without any head or projection on top.""",
537
+ CLIP_START_DOCSTRING,
538
+ )
539
+ class CLIPTextModel(CLIPPreTrainedModel):
540
+ config_class = CLIPTextConfig
541
+
542
+ _no_split_modules = ["CLIPEncoderLayer"]
543
+
544
+ def __init__(self, config: CLIPTextConfig):
545
+ super().__init__(config)
546
+ self.text_model = CLIPTextTransformer(config)
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_input_embeddings(self) -> nn.Module:
551
+ return self.text_model.embeddings.token_embedding
552
+
553
+ def set_input_embeddings(self, value):
554
+ self.text_model.embeddings.token_embedding = value
555
+
556
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
557
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
558
+ def forward(
559
+ self,
560
+ input_ids: Optional[torch.Tensor] = None,
561
+ attention_mask: Optional[torch.Tensor] = None,
562
+ position_ids: Optional[torch.Tensor] = None,
563
+ output_attentions: Optional[bool] = None,
564
+ output_hidden_states: Optional[bool] = None,
565
+ return_dict: Optional[bool] = None,
566
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
567
+ r"""
568
+ Returns:
569
+
570
+ Examples:
571
+
572
+ ```python
573
+ >>> from transformers import AutoTokenizer, CLIPTextModel
574
+
575
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
576
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
577
+
578
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
579
+
580
+ >>> outputs = model(**inputs)
581
+ >>> last_hidden_state = outputs.last_hidden_state
582
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
583
+ ```"""
584
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
585
+
586
+ return self.text_model(
587
+ input_ids=input_ids,
588
+ attention_mask=attention_mask,
589
+ position_ids=position_ids,
590
+ output_attentions=output_attentions,
591
+ output_hidden_states=output_hidden_states,
592
+ return_dict=return_dict,
593
+ )
594
+
595
+
596
+ class CLIPVisionTransformer(nn.Module):
597
+ def __init__(self, config: CLIPVisionConfig):
598
+ super().__init__()
599
+ self.config = config
600
+ embed_dim = config.hidden_size
601
+
602
+ self.embeddings = CLIPVisionEmbeddings(config)
603
+ self.patch_dropout = PatchDropout(config.force_patch_dropout)
604
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
605
+ self.encoder = CLIPEncoder(config)
606
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
607
+
608
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
609
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
610
+ def forward(
611
+ self,
612
+ pixel_values: Optional[torch.FloatTensor] = None,
613
+ output_attentions: Optional[bool] = None,
614
+ output_hidden_states: Optional[bool] = None,
615
+ return_dict: Optional[bool] = None,
616
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
617
+ r"""
618
+ Returns:
619
+
620
+ """
621
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
622
+ output_hidden_states = (
623
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
624
+ )
625
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
626
+
627
+ if pixel_values is None:
628
+ raise ValueError("You have to specify pixel_values")
629
+ ######################################
630
+ if len(pixel_values.shape) == 7:
631
+ b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
632
+ # print(pixel_values.shape)
633
+ B = b_new * pair_new * bs_new
634
+ pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new)
635
+
636
+ elif len(pixel_values.shape) == 5:
637
+ B, _, T, _, _ = pixel_values.shape
638
+ # print(pixel_values.shape)
639
+ pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
640
+ else:
641
+ # print(pixel_values.shape)
642
+ B, _, _, _ = pixel_values.shape
643
+ T = 1
644
+ ###########################
645
+ hidden_states = self.embeddings(pixel_values)
646
+
647
+ hidden_states = self.patch_dropout(hidden_states, B, T) ##############################################
648
+
649
+ hidden_states = self.pre_layrnorm(hidden_states)
650
+
651
+ encoder_outputs = self.encoder(
652
+ inputs_embeds=hidden_states,
653
+ output_attentions=output_attentions,
654
+ output_hidden_states=output_hidden_states,
655
+ return_dict=return_dict,
656
+ )
657
+
658
+ last_hidden_state = encoder_outputs[0]
659
+ pooled_output = last_hidden_state[:, 0, :]
660
+ pooled_output = self.post_layernorm(pooled_output)
661
+
662
+ pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################
663
+
664
+ if not return_dict:
665
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
666
+
667
+ return BaseModelOutputWithPooling(
668
+ last_hidden_state=last_hidden_state,
669
+ pooler_output=pooled_output,
670
+ hidden_states=encoder_outputs.hidden_states,
671
+ attentions=encoder_outputs.attentions,
672
+ )
673
+
674
+
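The vision forward above folds video frames into the batch dimension and then averages the per-frame pooled features over time. A shape-only sketch of that bookkeeping on dummy tensors (the hidden size of 1024 is an illustrative assumption, not a value read from any config):

```python
import torch
from einops import rearrange

# Dummy 5-D video batch (b, c, t, h, w), as handled by CLIPVisionTransformer.forward.
B, C, T, H, W = 2, 3, 8, 224, 224
video = torch.randn(B, C, T, H, W)

# Frames are flattened into the batch before the patch embedding...
frames = rearrange(video, 'b c t h w -> (b t) c h w')
print(frames.shape)            # torch.Size([16, 3, 224, 224])

# ...and the per-frame pooled outputs are mean-pooled over time afterwards.
pooled_per_frame = torch.randn(B * T, 1024)   # stand-in for the encoder's pooled output
video_feature = pooled_per_frame.reshape(B, T, -1).mean(1)
print(video_feature.shape)     # torch.Size([2, 1024])
```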
675
+ @add_start_docstrings(
676
+ """The vision model from CLIP without any head or projection on top.""",
677
+ CLIP_START_DOCSTRING,
678
+ )
679
+ class CLIPVisionModel(CLIPPreTrainedModel):
680
+ config_class = CLIPVisionConfig
681
+ main_input_name = "pixel_values"
682
+
683
+ def __init__(self, config: CLIPVisionConfig):
684
+ super().__init__(config)
685
+ self.vision_model = CLIPVisionTransformer(config)
686
+ # Initialize weights and apply final processing
687
+ self.post_init()
688
+
689
+ def get_input_embeddings(self) -> nn.Module:
690
+ return self.vision_model.embeddings.patch_embedding
691
+
692
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
693
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
694
+ def forward(
695
+ self,
696
+ pixel_values: Optional[torch.FloatTensor] = None,
697
+ output_attentions: Optional[bool] = None,
698
+ output_hidden_states: Optional[bool] = None,
699
+ return_dict: Optional[bool] = None,
700
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
701
+ r"""
702
+ Returns:
703
+
704
+ Examples:
705
+
706
+ ```python
707
+ >>> from PIL import Image
708
+ >>> import requests
709
+ >>> from transformers import AutoProcessor, CLIPVisionModel
710
+
711
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
712
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
713
+
714
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
715
+ >>> image = Image.open(requests.get(url, stream=True).raw)
716
+
717
+ >>> inputs = processor(images=image, return_tensors="pt")
718
+
719
+ >>> outputs = model(**inputs)
720
+ >>> last_hidden_state = outputs.last_hidden_state
721
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
722
+ ```"""
723
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
724
+
725
+ return self.vision_model(
726
+ pixel_values=pixel_values,
727
+ output_attentions=output_attentions,
728
+ output_hidden_states=output_hidden_states,
729
+ return_dict=return_dict,
730
+ )
731
+
732
+
733
+ @add_start_docstrings(CLIP_START_DOCSTRING)
734
+ class LanguageBindDepth(CLIPPreTrainedModel):
735
+ config_class = LanguageBindDepthConfig
736
+
737
+ def __init__(self, config: LanguageBindDepthConfig):
738
+ super().__init__(config)
739
+
740
+ if not isinstance(config.text_config, CLIPTextConfig):
741
+ raise ValueError(
742
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
743
+ f" {type(config.text_config)}."
744
+ )
745
+
746
+ if not isinstance(config.vision_config, CLIPVisionConfig):
747
+ raise ValueError(
748
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
749
+ f" {type(config.vision_config)}."
750
+ )
751
+
752
+ text_config = config.text_config
753
+ vision_config = config.vision_config
754
+ self.add_time_attn = vision_config.add_time_attn
755
+ self.lora_r = vision_config.lora_r
756
+ self.lora_alpha = vision_config.lora_alpha
757
+ self.lora_dropout = vision_config.lora_dropout
758
+
759
+ self.projection_dim = config.projection_dim
760
+ self.text_embed_dim = text_config.hidden_size
761
+ self.vision_embed_dim = vision_config.hidden_size
762
+
763
+ self.text_model = CLIPTextTransformer(text_config)
764
+ self.vision_model = CLIPVisionTransformer(vision_config)
765
+
766
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
767
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
768
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
769
+
770
+ # Initialize weights and apply final processing
771
+ self.post_init()
772
+ self.convert_to_lora()
773
+ self.resize_pos(self.vision_model.embeddings, vision_config)
774
+
775
+ def convert_to_lora(self):
776
+ if self.lora_r == 0:
777
+ return
778
+ if self.add_time_attn:
779
+ target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
780
+ "temporal_attn.q_proj", "temporal_attn.out_proj",
781
+ "temporal_mlp.fc1", "temporal_mlp.fc2"]
782
+ else:
783
+ target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
784
+ config = LoraConfig(
785
+ r=self.lora_r, # 16
786
+ lora_alpha=self.lora_alpha, # 16
787
+ target_modules=target_modules, # self_attn.out_proj
788
+ lora_dropout=self.lora_dropout, # 0.1
789
+ bias="none",
790
+ modules_to_save=[],
791
+ )
792
+ self.vision_model.encoder.is_gradient_checkpointing = False
793
+ self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
794
+
795
+ def resize_pos(self, m, vision_config):
796
+ # convert embedding
797
+ if vision_config.num_mel_bins != 0 and vision_config.target_length != 0:
798
+ m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
799
+ m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
800
+ # pos resize
801
+ old_pos_embed_state_dict = m.position_embedding.state_dict()
802
+ old_pos_embed = old_pos_embed_state_dict['weight']
803
+ dtype = old_pos_embed.dtype
804
+ grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
805
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
806
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
807
+ if new_seq_len == old_pos_embed.shape[0]:
808
+ # m.to(args.device)
809
+ return
810
+
811
+ m.num_patches = grid_size[0] * grid_size[1]
812
+ m.num_positions = m.num_patches + 1
813
+ m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
814
+ new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
815
+
816
+ if extra_tokens:
817
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
818
+ else:
819
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
820
+ old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
821
+
822
+ # if is_master(args):
823
+ # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
824
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
825
+ pos_emb_img = F.interpolate(
826
+ pos_emb_img,
827
+ size=grid_size,
828
+ mode='bicubic',
829
+ antialias=True,
830
+ align_corners=False,
831
+ )
832
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
833
+ if pos_emb_tok is not None:
834
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
835
+ else:
836
+ new_pos_embed = pos_emb_img
837
+ old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
838
+ m.position_embedding = new_position_embedding
839
+ m.position_embedding.load_state_dict(old_pos_embed_state_dict)
840
+
841
+ # m.to(args.device)
842
+
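`resize_pos` above rescales the patch position-embedding grid with bicubic interpolation while keeping the class-token embedding untouched. A standalone sketch of the same operation on dummy weights (the 7x7 -> 14x14 grid and the embedding width of 64 are illustrative only):

```python
import torch
import torch.nn.functional as F

old_grid, new_grid, dim = (7, 7), (14, 14), 64
old_pos_embed = torch.randn(1 + old_grid[0] * old_grid[1], dim)   # [CLS] + 7*7 patches

cls_tok, patch_pos = old_pos_embed[:1], old_pos_embed[1:]
patch_pos = patch_pos.reshape(1, old_grid[0], old_grid[1], dim).permute(0, 3, 1, 2)
patch_pos = F.interpolate(patch_pos, size=new_grid, mode='bicubic',
                          antialias=True, align_corners=False)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], dim)[0]

new_pos_embed = torch.cat([cls_tok, patch_pos], dim=0)
print(new_pos_embed.shape)   # torch.Size([197, 64]): 14*14 patches + 1 class token
```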
843
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
844
+ def get_text_features(
845
+ self,
846
+ input_ids: Optional[torch.Tensor] = None,
847
+ attention_mask: Optional[torch.Tensor] = None,
848
+ position_ids: Optional[torch.Tensor] = None,
849
+ output_attentions: Optional[bool] = None,
850
+ output_hidden_states: Optional[bool] = None,
851
+ return_dict: Optional[bool] = None,
852
+ ) -> torch.FloatTensor:
853
+ r"""
854
+ Returns:
855
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
856
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
857
+
858
+ Examples:
859
+
860
+ ```python
861
+ >>> from transformers import AutoTokenizer, CLIPModel
862
+
863
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
864
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
865
+
866
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
867
+ >>> text_features = model.get_text_features(**inputs)
868
+ ```"""
869
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ text_outputs = self.text_model(
877
+ input_ids=input_ids,
878
+ attention_mask=attention_mask,
879
+ position_ids=position_ids,
880
+ output_attentions=output_attentions,
881
+ output_hidden_states=output_hidden_states,
882
+ return_dict=return_dict,
883
+ )
884
+
885
+ pooled_output = text_outputs[1]
886
+ text_features = self.text_projection(pooled_output)
887
+
888
+ return text_features
889
+
890
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
891
+ def get_image_features(
892
+ self,
893
+ pixel_values: Optional[torch.FloatTensor] = None,
894
+ output_attentions: Optional[bool] = None,
895
+ output_hidden_states: Optional[bool] = None,
896
+ return_dict: Optional[bool] = None,
897
+ ) -> torch.FloatTensor:
898
+ r"""
899
+ Returns:
900
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
901
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
902
+
903
+ Examples:
904
+
905
+ ```python
906
+ >>> from PIL import Image
907
+ >>> import requests
908
+ >>> from transformers import AutoProcessor, CLIPModel
909
+
910
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
911
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
912
+
913
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
914
+ >>> image = Image.open(requests.get(url, stream=True).raw)
915
+
916
+ >>> inputs = processor(images=image, return_tensors="pt")
917
+
918
+ >>> image_features = model.get_image_features(**inputs)
919
+ ```"""
920
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
921
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
922
+ output_hidden_states = (
923
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
924
+ )
925
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
926
+
927
+ vision_outputs = self.vision_model(
928
+ pixel_values=pixel_values,
929
+ output_attentions=output_attentions,
930
+ output_hidden_states=output_hidden_states,
931
+ return_dict=return_dict,
932
+ )
933
+
934
+ pooled_output = vision_outputs[1] # pooled_output
935
+ image_features = self.visual_projection(pooled_output)
936
+
937
+ return image_features
938
+
939
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
940
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindDepthConfig)
941
+ def forward(
942
+ self,
943
+ input_ids: Optional[torch.LongTensor] = None,
944
+ pixel_values: Optional[torch.FloatTensor] = None,
945
+ attention_mask: Optional[torch.Tensor] = None,
946
+ position_ids: Optional[torch.LongTensor] = None,
947
+ return_loss: Optional[bool] = None,
948
+ output_attentions: Optional[bool] = None,
949
+ output_hidden_states: Optional[bool] = None,
950
+ return_dict: Optional[bool] = None,
951
+ ) -> Union[Tuple, CLIPOutput]:
952
+ r"""
953
+ Returns:
954
+
955
+ Examples:
956
+
957
+ ```python
958
+ >>> from PIL import Image
959
+ >>> import requests
960
+ >>> from transformers import AutoProcessor, CLIPModel
961
+
962
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
963
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
964
+
965
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
966
+ >>> image = Image.open(requests.get(url, stream=True).raw)
967
+
968
+ >>> inputs = processor(
969
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
970
+ ... )
971
+
972
+ >>> outputs = model(**inputs)
973
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
974
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
975
+ ```"""
976
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
977
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
978
+ output_hidden_states = (
979
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
980
+ )
981
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
982
+
983
+ vision_outputs = self.vision_model(
984
+ pixel_values=pixel_values,
985
+ output_attentions=output_attentions,
986
+ output_hidden_states=output_hidden_states,
987
+ return_dict=return_dict,
988
+ )
989
+
990
+ text_outputs = self.text_model(
991
+ input_ids=input_ids,
992
+ attention_mask=attention_mask,
993
+ position_ids=position_ids,
994
+ output_attentions=output_attentions,
995
+ output_hidden_states=output_hidden_states,
996
+ return_dict=return_dict,
997
+ )
998
+
999
+ image_embeds = vision_outputs[1]
1000
+ image_embeds = self.visual_projection(image_embeds)
1001
+
1002
+ text_embeds = text_outputs[1]
1003
+ text_embeds = self.text_projection(text_embeds)
1004
+
1005
+ # normalized features
1006
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1007
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1008
+
1009
+ # cosine similarity as logits
1010
+ logit_scale = self.logit_scale.exp()
1011
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1012
+ logits_per_image = logits_per_text.t()
1013
+
1014
+ loss = None
1015
+ if return_loss:
1016
+ loss = clip_loss(logits_per_text)
1017
+
1018
+ if not return_dict:
1019
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1020
+ return ((loss,) + output) if loss is not None else output
1021
+
1022
+ return CLIPOutput(
1023
+ loss=loss,
1024
+ logits_per_image=logits_per_image,
1025
+ logits_per_text=logits_per_text,
1026
+ text_embeds=text_embeds,
1027
+ image_embeds=image_embeds,
1028
+ text_model_output=text_outputs,
1029
+ vision_model_output=vision_outputs,
1030
+ )
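Taken together, `LanguageBindDepth` above is a CLIP-style dual encoder: pooled text and depth features are projected into a shared space, L2-normalised, and compared with a temperature-scaled dot product. A minimal sketch of that final step on random dummy embeddings (the projection dimension of 512 and the batch sizes are illustrative; no checkpoint is involved):

```python
import torch

text_embeds = torch.randn(4, 512)     # 4 captions, already projected
depth_embeds = torch.randn(2, 512)    # 2 depth maps, already projected
logit_scale = torch.tensor(2.6592).exp()   # exp of the default logit_scale_init_value

text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
depth_embeds = depth_embeds / depth_embeds.norm(p=2, dim=-1, keepdim=True)

logits_per_text = logit_scale * text_embeds @ depth_embeds.t()   # (4, 2)
logits_per_depth = logits_per_text.t()                           # (2, 4)
probs = logits_per_depth.softmax(dim=-1)    # each depth map scored against the 4 captions
print(probs.shape)                          # torch.Size([2, 4])
```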
models/multimodal_encoder/languagebind/depth/processing_depth.py ADDED
@@ -0,0 +1,108 @@
1
+ import cv2
2
+ import torch
3
+ from PIL import Image
4
+ from torch import nn
5
+ from torchvision import transforms
6
+ from transformers import ProcessorMixin, BatchEncoding
7
+ from transformers.image_processing_utils import BatchFeature
8
+
9
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
10
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
11
+
12
+ def make_list_of_images(x):
13
+ if not isinstance(x, list):
14
+ return [x]
15
+ return x
16
+
17
+ def opencv_loader(path):
18
+ return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32')
19
+
20
+
21
+ class DepthNorm(nn.Module):
22
+ def __init__(
23
+ self,
24
+ max_depth=0,
25
+ min_depth=0.01,
26
+ ):
27
+ super().__init__()
28
+ self.max_depth = max_depth
29
+ self.min_depth = min_depth
30
+ self.scale = 1000.0 # nyuv2 abs.depth
31
+
32
+ def forward(self, image):
33
+ # image = np.array(image)
34
+ depth_img = image / self.scale # (H, W) in meters
35
+ depth_img = depth_img.clip(min=self.min_depth)
36
+ if self.max_depth != 0:
37
+ depth_img = depth_img.clip(max=self.max_depth)
38
+ depth_img /= self.max_depth # 0-1
39
+ else:
40
+ depth_img /= depth_img.max()
41
+ depth_img = torch.from_numpy(depth_img).unsqueeze(0).repeat(3, 1, 1) # assume image
42
+ return depth_img.to(torch.get_default_dtype())
43
+
44
+ def get_depth_transform(config):
45
+ config = config.vision_config
46
+ transform = transforms.Compose(
47
+ [
48
+ DepthNorm(max_depth=config.max_depth),
49
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
50
+ transforms.CenterCrop(224),
51
+ transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD), # assume image
52
+ # transforms.Normalize((0.5, ), (0.5, )) # 0-1 to norm distribution
53
+ # transforms.Normalize((0.0418, ), (0.0295, )) # sun rgb-d imagebind
54
+ # transforms.Normalize((0.02, ), (0.00295, )) # nyuv2
55
+ ]
56
+ )
57
+ return transform
58
+
59
+ def load_and_transform_depth(depth_path, transform):
60
+ depth = opencv_loader(depth_path)
61
+ depth_outputs = transform(depth)
62
+ return depth_outputs
63
+
64
+ class LanguageBindDepthProcessor(ProcessorMixin):
65
+ attributes = []
66
+ tokenizer_class = ("LanguageBindDepthTokenizer")
67
+
68
+ def __init__(self, config, tokenizer=None, **kwargs):
69
+ super().__init__(**kwargs)
70
+ self.config = config
71
+ self.transform = get_depth_transform(config)
72
+ self.image_processor = load_and_transform_depth
73
+ self.tokenizer = tokenizer
74
+
75
+ def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
76
+ if text is None and images is None:
77
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
78
+
79
+ if text is not None:
80
+ encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
81
+ truncation=True, return_tensors=return_tensors, **kwargs)
82
+
83
+ if images is not None:
84
+ images = make_list_of_images(images)
85
+ image_features = [self.image_processor(image, self.transform) for image in images]
86
+ image_features = torch.stack(image_features)
87
+
88
+ if text is not None and images is not None:
89
+ encoding["pixel_values"] = image_features
90
+ return encoding
91
+ elif text is not None:
92
+ return encoding
93
+ else:
94
+ return {"pixel_values": image_features}
95
+
96
+ def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
97
+ """
98
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
99
+ refer to the docstring of this method for more information.
100
+ """
101
+ return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
102
+
103
+ def decode(self, skip_special_tokens=True, *args, **kwargs):
104
+ """
105
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
106
+ the docstring of this method for more information.
107
+ """
108
+ return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
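A hedged usage sketch of the preprocessing path defined above: `get_depth_transform` only reads `vision_config.max_depth`, so a `SimpleNamespace` stands in for the real config, and a random array stands in for the 16-bit depth map that `opencv_loader` would load from disk (a real call would go through `load_and_transform_depth(path, transform)`):

```python
import numpy as np
from types import SimpleNamespace

# Assumes get_depth_transform (defined above) is importable from this module.
cfg = SimpleNamespace(vision_config=SimpleNamespace(max_depth=10.0))
transform = get_depth_transform(cfg)

# Fake NYUv2-style depth in millimetres; DepthNorm divides by 1000, clips,
# rescales to [0, 1] and repeats the map to 3 channels.
depth_mm = (np.random.rand(480, 640) * 10000).astype('float32')
pixel_values = transform(depth_mm)
print(pixel_values.shape)   # torch.Size([3, 224, 224])
```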
models/multimodal_encoder/languagebind/depth/tokenization_depth.py ADDED
@@ -0,0 +1,77 @@
1
+ from transformers import CLIPTokenizer
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+ VOCAB_FILES_NAMES = {
7
+ "vocab_file": "vocab.json",
8
+ "merges_file": "merges.txt",
9
+ }
10
+
11
+ PRETRAINED_VOCAB_FILES_MAP = {
12
+ "vocab_file": {
13
+ "lb203/LanguageBind-Depth": "https://huggingface.co/lb203/LanguageBind-Depth/resolve/main/vocab.json",
14
+ },
15
+ "merges_file": {
16
+ "lb203/LanguageBind-Depth": "https://huggingface.co/lb203/LanguageBind-Depth/resolve/main/merges.txt",
17
+ },
18
+ }
19
+
20
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
21
+ "lb203/LanguageBind-Depth": 77,
22
+ }
23
+
24
+
25
+ PRETRAINED_INIT_CONFIGURATION = {
26
+ "lb203/LanguageBind-Thermal": {},
27
+ }
28
+
29
+ class LanguageBindDepthTokenizer(CLIPTokenizer):
30
+ """
31
+ Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
32
+
33
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
34
+ this superclass for more information regarding those methods.
35
+
36
+ Args:
37
+ vocab_file (`str`):
38
+ Path to the vocabulary file.
39
+ merges_file (`str`):
40
+ Path to the merges file.
41
+ errors (`str`, *optional*, defaults to `"replace"`):
42
+ Paradigm to follow when decoding bytes to UTF-8. See
43
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
44
+ unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
45
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
46
+ token instead.
47
+ bos_token (`str`, *optional*, defaults to `<|startoftext|>`):
48
+ The beginning of sequence token.
49
+ eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
50
+ The end of sequence token.
51
+ """
52
+
53
+ vocab_files_names = VOCAB_FILES_NAMES
54
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
55
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
56
+ model_input_names = ["input_ids", "attention_mask"]
57
+
58
+ def __init__(
59
+ self,
60
+ vocab_file,
61
+ merges_file,
62
+ errors="replace",
63
+ unk_token="<|endoftext|>",
64
+ bos_token="<|startoftext|>",
65
+ eos_token="<|endoftext|>",
66
+ pad_token="<|endoftext|>", # hack to enable padding
67
+ **kwargs,
68
+ ):
69
+ super(LanguageBindDepthTokenizer, self).__init__(
70
+ vocab_file,
71
+ merges_file,
72
+ errors,
73
+ unk_token,
74
+ bos_token,
75
+ eos_token,
76
+ pad_token, # hack to enable padding
77
+ **kwargs,)
models/multimodal_encoder/languagebind/image/configuration_image.py ADDED
@@ -0,0 +1,423 @@
1
+ import copy
2
+ import os
3
+ from typing import Union
4
+
5
+ from transformers import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
+
12
+
13
+
14
+
15
+
16
+ class CLIPTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
19
+ text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
20
+ with the defaults will yield a similar configuration to that of the text encoder of the CLIP
21
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 49408):
28
+ Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
29
+ the `inputs_ids` passed when calling [`CLIPModel`].
30
+ hidden_size (`int`, *optional*, defaults to 512):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 2048):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 8):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 77):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ initializer_factor (`float`, *optional*, defaults to 1):
51
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
52
+ testing).
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import CLIPTextConfig, CLIPTextModel
58
+
59
+ >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
60
+ >>> configuration = CLIPTextConfig()
61
+
62
+ >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
63
+ >>> model = CLIPTextModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+ model_type = "clip_text_model"
69
+
70
+ def __init__(
71
+ self,
72
+ vocab_size=49408,
73
+ hidden_size=512,
74
+ intermediate_size=2048,
75
+ projection_dim=512,
76
+ num_hidden_layers=12,
77
+ num_attention_heads=8,
78
+ max_position_embeddings=77,
79
+ hidden_act="quick_gelu",
80
+ layer_norm_eps=1e-5,
81
+ attention_dropout=0.0,
82
+ initializer_range=0.02,
83
+ initializer_factor=1.0,
84
+ # This differs from `CLIPTokenizer`'s default and from openai/clip
85
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
86
+ pad_token_id=1,
87
+ bos_token_id=49406,
88
+ eos_token_id=49407,
89
+ **kwargs,
90
+ ):
91
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
92
+
93
+ self.vocab_size = vocab_size
94
+ self.hidden_size = hidden_size
95
+ self.intermediate_size = intermediate_size
96
+ self.projection_dim = projection_dim
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.max_position_embeddings = max_position_embeddings
100
+ self.layer_norm_eps = layer_norm_eps
101
+ self.hidden_act = hidden_act
102
+ self.initializer_range = initializer_range
103
+ self.initializer_factor = initializer_factor
104
+ self.attention_dropout = attention_dropout
105
+ self.add_time_attn = False ######################################
106
+
107
+ @classmethod
108
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
109
+ cls._set_token_in_kwargs(kwargs)
110
+
111
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
112
+
113
+ # get the text config dict if we are loading from CLIPConfig
114
+ if config_dict.get("model_type") == "clip":
115
+ config_dict = config_dict["text_config"]
116
+
117
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
118
+ logger.warning(
119
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
120
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
121
+ )
122
+
123
+ return cls.from_dict(config_dict, **kwargs)
124
+
125
+
126
+
127
+
128
+ class CLIPVisionConfig(PretrainedConfig):
129
+ r"""
130
+ This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
131
+ CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
132
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
133
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
134
+
135
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
136
+ documentation from [`PretrainedConfig`] for more information.
137
+
138
+ Args:
139
+ hidden_size (`int`, *optional*, defaults to 768):
140
+ Dimensionality of the encoder layers and the pooler layer.
141
+ intermediate_size (`int`, *optional*, defaults to 3072):
142
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
143
+ num_hidden_layers (`int`, *optional*, defaults to 12):
144
+ Number of hidden layers in the Transformer encoder.
145
+ num_attention_heads (`int`, *optional*, defaults to 12):
146
+ Number of attention heads for each attention layer in the Transformer encoder.
147
+ image_size (`int`, *optional*, defaults to 224):
148
+ The size (resolution) of each image.
149
+ patch_size (`int`, *optional*, defaults to 32):
150
+ The size (resolution) of each patch.
151
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
154
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
155
+ The epsilon used by the layer normalization layers.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ initializer_range (`float`, *optional*, defaults to 0.02):
159
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
160
+ initializer_factor (`float`, *optional*, defaults to 1):
161
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
162
+ testing).
163
+
164
+ Example:
165
+
166
+ ```python
167
+ >>> from transformers import CLIPVisionConfig, CLIPVisionModel
168
+
169
+ >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
170
+ >>> configuration = CLIPVisionConfig()
171
+
172
+ >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
173
+ >>> model = CLIPVisionModel(configuration)
174
+
175
+ >>> # Accessing the model configuration
176
+ >>> configuration = model.config
177
+ ```"""
178
+
179
+ model_type = "clip_vision_model"
180
+
181
+ def __init__(
182
+ self,
183
+ hidden_size=768,
184
+ intermediate_size=3072,
185
+ projection_dim=512,
186
+ num_hidden_layers=12,
187
+ num_attention_heads=12,
188
+ num_channels=3,
189
+ image_size=224,
190
+ patch_size=32,
191
+ hidden_act="quick_gelu",
192
+ layer_norm_eps=1e-5,
193
+ attention_dropout=0.0,
194
+ initializer_range=0.02,
195
+ initializer_factor=1.0,
196
+
197
+ add_time_attn=False, ################################
198
+ num_frames=1, ################################
199
+ force_patch_dropout=0.0, ################################
200
+ lora_r=2, ################################
201
+ lora_alpha=16, ################################
202
+ lora_dropout=0.0, ################################
203
+ num_mel_bins=0.0, ################################
204
+ target_length=0.0, ################################
205
+ video_decode_backend='decord', #########################
206
+ **kwargs,
207
+ ):
208
+ super().__init__(**kwargs)
209
+
210
+ self.hidden_size = hidden_size
211
+ self.intermediate_size = intermediate_size
212
+ self.projection_dim = projection_dim
213
+ self.num_hidden_layers = num_hidden_layers
214
+ self.num_attention_heads = num_attention_heads
215
+ self.num_channels = num_channels
216
+ self.patch_size = patch_size
217
+ self.image_size = image_size
218
+ self.initializer_range = initializer_range
219
+ self.initializer_factor = initializer_factor
220
+ self.attention_dropout = attention_dropout
221
+ self.layer_norm_eps = layer_norm_eps
222
+ self.hidden_act = hidden_act
223
+
224
+ self.add_time_attn = add_time_attn ################
225
+ self.num_frames = num_frames ################
226
+ self.force_patch_dropout = force_patch_dropout ################
227
+ self.lora_r = lora_r ################
228
+ self.lora_alpha = lora_alpha ################
229
+ self.lora_dropout = lora_dropout ################
230
+ self.num_mel_bins = num_mel_bins ################
231
+ self.target_length = target_length ################
232
+ self.video_decode_backend = video_decode_backend ################
233
+
234
+ @classmethod
235
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
236
+ cls._set_token_in_kwargs(kwargs)
237
+
238
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
239
+
240
+ # get the vision config dict if we are loading from CLIPConfig
241
+ if config_dict.get("model_type") == "clip":
242
+ config_dict = config_dict["vision_config"]
243
+
244
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
245
+ logger.warning(
246
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
247
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
248
+ )
249
+
250
+ return cls.from_dict(config_dict, **kwargs)
251
+
252
+
253
+ class LanguageBindImageConfig(PretrainedConfig):
254
+ r"""
255
+ [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
256
+ a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
257
+ a configuration with the defaults will yield a similar configuration to that of the CLIP
258
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
259
+
260
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
261
+ documentation from [`PretrainedConfig`] for more information.
262
+
263
+ Args:
264
+ text_config (`dict`, *optional*):
265
+ Dictionary of configuration options used to initialize [`CLIPTextConfig`].
266
+ vision_config (`dict`, *optional*):
267
+ Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
268
+ projection_dim (`int`, *optional*, defaults to 512):
269
+ Dimensionality of text and vision projection layers.
270
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
271
+ The initial value of the *logit_scale* parameter. Default is used as per the original CLIP implementation.
272
+ kwargs (*optional*):
273
+ Dictionary of keyword arguments.
274
+
275
+ Example:
276
+
277
+ ```python
278
+ >>> from transformers import CLIPConfig, CLIPModel
279
+
280
+ >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
281
+ >>> configuration = CLIPConfig()
282
+
283
+ >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
284
+ >>> model = CLIPModel(configuration)
285
+
286
+ >>> # Accessing the model configuration
287
+ >>> configuration = model.config
288
+
289
+ >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
290
+ >>> from transformers import CLIPTextConfig, CLIPVisionConfig
291
+
292
+ >>> # Initializing a CLIPText and CLIPVision configuration
293
+ >>> config_text = CLIPTextConfig()
294
+ >>> config_vision = CLIPVisionConfig()
295
+
296
+ >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
297
+ ```"""
298
+
299
+ model_type = "LanguageBindImage"
300
+ is_composition = True
301
+
302
+ def __init__(
303
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
304
+ ):
305
+ # If `_config_dict` exist, we use them for the backward compatibility.
306
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
307
+ # of confusion!).
308
+ text_config_dict = kwargs.pop("text_config_dict", None)
309
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
310
+
311
+ super().__init__(**kwargs)
312
+
313
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
314
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
315
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
316
+ if text_config_dict is not None:
317
+ if text_config is None:
318
+ text_config = {}
319
+
320
+ # This is the complete result when using `text_config_dict`.
321
+ _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
322
+
323
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
324
+ for key, value in _text_config_dict.items():
325
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
326
+ # If specified in `text_config_dict`
327
+ if key in text_config_dict:
328
+ message = (
329
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
330
+ f'The value `text_config_dict["{key}"]` will be used instead.'
331
+ )
332
+ # If inferred from default argument values (just to be super careful)
333
+ else:
334
+ message = (
335
+ f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
336
+ f'value `text_config["{key}"]` will be overridden.'
337
+ )
338
+ logger.warning(message)
339
+
340
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
341
+ text_config.update(_text_config_dict)
342
+
343
+ if vision_config_dict is not None:
344
+ if vision_config is None:
345
+ vision_config = {}
346
+
347
+ # This is the complete result when using `vision_config_dict`.
348
+ _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
349
+ # convert keys to string instead of integer
350
+ if "id2label" in _vision_config_dict:
351
+ _vision_config_dict["id2label"] = {
352
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
353
+ }
354
+
355
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
356
+ for key, value in _vision_config_dict.items():
357
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
358
+ # If specified in `vision_config_dict`
359
+ if key in vision_config_dict:
360
+ message = (
361
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
362
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
363
+ )
364
+ # If inferred from default argument values (just to be super careful)
365
+ else:
366
+ message = (
367
+ f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
368
+ f'The value `vision_config["{key}"]` will be overridden.'
369
+ )
370
+ logger.warning(message)
371
+
372
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
373
+ vision_config.update(_vision_config_dict)
374
+
375
+ if text_config is None:
376
+ text_config = {}
377
+ logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
378
+
379
+ if vision_config is None:
380
+ vision_config = {}
381
+ logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.")
382
+
383
+ self.text_config = CLIPTextConfig(**text_config)
384
+ self.vision_config = CLIPVisionConfig(**vision_config)
385
+
386
+ self.projection_dim = projection_dim
387
+ self.logit_scale_init_value = logit_scale_init_value
388
+ self.initializer_factor = 1.0
389
+
390
+ @classmethod
391
+ def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
392
+ r"""
393
+ Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
394
+ configuration.
395
+
396
+ Returns:
397
+ [`CLIPConfig`]: An instance of a configuration object
398
+ """
399
+
400
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
401
+
402
+ def to_dict(self):
403
+ """
404
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
405
+
406
+ Returns:
407
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
408
+ """
409
+ output = copy.deepcopy(self.__dict__)
410
+ output["text_config"] = self.text_config.to_dict()
411
+ output["vision_config"] = self.vision_config.to_dict()
412
+ output["model_type"] = self.__class__.model_type
413
+ return output
414
+
415
+
416
+
417
+
418
+
419
+
420
+
421
+
422
+
423
+
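A hypothetical composition of the config above from its two sub-configs, mirroring `from_text_vision_configs`; the values are illustrative defaults, not those of any released LanguageBind checkpoint, and `lora_r=0` simply skips the LoRA wrapping in the matching modeling code:

```python
# Assumes the classes defined in this file are importable.
text_config = CLIPTextConfig(hidden_size=512, num_hidden_layers=12)
vision_config = CLIPVisionConfig(
    hidden_size=768,
    patch_size=32,
    add_time_attn=False,   # plain image model: no temporal attention blocks
    num_frames=1,
    lora_r=0,              # 0 disables the LoRA conversion at model init
)
config = LanguageBindImageConfig.from_text_vision_configs(text_config, vision_config)

print(config.model_type)                    # LanguageBindImage
print(config.vision_config.lora_r)          # 0
print(config.to_dict()["projection_dim"])   # 512
```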
models/multimodal_encoder/languagebind/image/modeling_image.py ADDED
@@ -0,0 +1,1030 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from peft import LoraConfig, get_peft_model
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from transformers import PreTrainedModel, add_start_docstrings
10
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
11
+ from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
12
+ CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss
13
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
14
+
15
+ from .configuration_image import LanguageBindImageConfig, CLIPVisionConfig, CLIPTextConfig
16
+
17
+
18
+
19
+ class PatchDropout(nn.Module):
20
+ """
21
+ https://arxiv.org/abs/2212.00794
22
+ """
23
+
24
+ def __init__(self, prob, exclude_first_token=True):
25
+ super().__init__()
26
+ assert 0 <= prob < 1.
27
+ self.prob = prob
28
+ self.exclude_first_token = exclude_first_token # exclude CLS token
29
+
30
+ def forward(self, x, B, T):
31
+ if not self.training or self.prob == 0.:
32
+ return x
33
+
34
+ if self.exclude_first_token:
35
+ cls_tokens, x = x[:, :1], x[:, 1:]
36
+ else:
37
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
38
+
39
+ batch = x.size()[0]
40
+ num_tokens = x.size()[1]
41
+
42
+ batch_indices = torch.arange(batch)
43
+ batch_indices = batch_indices[..., None]
44
+
45
+ keep_prob = 1 - self.prob
46
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
47
+
48
+ if T == 1:
49
+ rand = torch.randn(batch, num_tokens)
50
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
51
+ else:
52
+ rand = torch.randn(B, num_tokens)
53
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
54
+ patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1)
55
+ patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n')
56
+
57
+
58
+ x = x[batch_indices, patch_indices_keep]
59
+
60
+ if self.exclude_first_token:
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ return x
64
+
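A quick sanity check of `PatchDropout` on dummy tokens: in eval mode (or with `prob=0`) it is the identity, while in training it keeps the CLS token plus a random subset of patches, sharing the same subset across the `T` frames of each clip. All sizes below are illustrative:

```python
import torch

# Assumes PatchDropout as defined above.
drop = PatchDropout(prob=0.5)
drop.train()

B, T, N, D = 2, 4, 1 + 49, 64            # CLS + 7x7 patches, toy hidden size
x = torch.randn(B * T, N, D)             # frames already folded into the batch

out = drop(x, B, T)
print(out.shape)         # torch.Size([8, 25, 64]): CLS + int(49 * 0.5) = 24 kept patches

drop.eval()
print(drop(x, B, T).shape)   # identity in eval mode: torch.Size([8, 50, 64])
```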
65
+ class CLIPEncoderLayer(nn.Module):
66
+ def __init__(self, config: LanguageBindImageConfig):
67
+ super().__init__()
68
+ self.embed_dim = config.hidden_size
69
+ self.self_attn = CLIPAttention(config)
70
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
71
+ self.mlp = CLIPMLP(config)
72
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
73
+
74
+ self.add_time_attn = config.add_time_attn
75
+ if self.add_time_attn:
76
+ self.t = config.num_frames
77
+ self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
78
+ nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
79
+
80
+ self.embed_dim = config.hidden_size
81
+ self.temporal_attn = CLIPAttention(config)
82
+ self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
83
+ self.temporal_mlp = CLIPMLP(config)
84
+ self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
85
+
86
+ def forward(
87
+ self,
88
+ hidden_states: torch.Tensor,
89
+ attention_mask: torch.Tensor,
90
+ causal_attention_mask: torch.Tensor,
91
+ output_attentions: Optional[bool] = False,
92
+ ) -> Tuple[torch.FloatTensor]:
93
+ """
94
+ Args:
95
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
96
+ attention_mask (`torch.FloatTensor`): attention mask of size
97
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
98
+ causal_attention_mask (`torch.FloatTensor`): causal attention mask of size `(batch, 1, tgt_len, tgt_len)`.
99
+ output_attentions (`bool`, *optional*):
100
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
101
+ returned tensors for more detail.
102
+ """
103
+
104
+
105
+ if self.add_time_attn:
106
+ bt, n, d = hidden_states.shape
107
+ t = self.t
108
+
109
+ # time embed
110
+ if t != 1:
111
+ n = hidden_states.shape[1]
112
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
113
+ hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
114
+ hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
115
+
116
+ # time attn
117
+ residual = hidden_states
118
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
119
+ # hidden_states = self.layer_norm1(hidden_states) # share layernorm
120
+ hidden_states = self.temporal_layer_norm1(hidden_states)
121
+ hidden_states, attn_weights = self.temporal_attn(
122
+ hidden_states=hidden_states,
123
+ attention_mask=attention_mask,
124
+ causal_attention_mask=causal_attention_mask,
125
+ output_attentions=output_attentions,
126
+ )
127
+ hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
128
+
129
+ residual = hidden_states
130
+ hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
131
+ # hidden_states = self.layer_norm2(hidden_states) # share layernorm
132
+ hidden_states = self.temporal_layer_norm2(hidden_states)
133
+ hidden_states = self.temporal_mlp(hidden_states)
134
+ hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)
135
+
136
+ # spatial attn
137
+ residual = hidden_states
138
+
139
+ hidden_states = self.layer_norm1(hidden_states)
140
+ hidden_states, attn_weights = self.self_attn(
141
+ hidden_states=hidden_states,
142
+ attention_mask=attention_mask,
143
+ causal_attention_mask=causal_attention_mask,
144
+ output_attentions=output_attentions,
145
+ )
146
+ hidden_states = residual + hidden_states
147
+
148
+ residual = hidden_states
149
+ hidden_states = self.layer_norm2(hidden_states)
150
+ hidden_states = self.mlp(hidden_states)
151
+ hidden_states = residual + hidden_states
152
+
153
+ outputs = (hidden_states,)
154
+
155
+ if output_attentions:
156
+ outputs += (attn_weights,)
157
+
158
+ return outputs
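For the temporal branch, the layer folds the patch axis into the batch so self-attention runs over the `t` frame positions, then folds back before the ordinary spatial attention. Below is a minimal sketch of that `(b t) n d <-> (b n) t d` round trip, with illustrative sizes (einops assumed available, as in the code above).

```python
# Minimal sketch of the (b t) n d <-> (b n) t d round trip used by the temporal branch above.
import torch
from einops import rearrange

b, t, n, d = 2, 8, 50, 16                      # illustrative sizes (assumptions)
hidden_states = torch.randn(b * t, n, d)       # frames already folded into the batch

# fold the patch axis into the batch so attention mixes information across the t frame positions
tokens_over_time = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)    # (b*n, t, d)

# ... temporal attention / MLP run on the length-t axis here ...

# fold back to the layout expected by the ordinary spatial attention
restored = rearrange(tokens_over_time, '(b n) t d -> (b t) n d', n=n)         # (b*t, n, d)
assert torch.equal(restored, hidden_states)
```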
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+ class CLIPPreTrainedModel(PreTrainedModel):
169
+ """
170
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
171
+ models.
172
+ """
173
+
174
+ config_class = LanguageBindImageConfig
175
+ base_model_prefix = "clip"
176
+ supports_gradient_checkpointing = True
177
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
178
+
179
+ def _init_weights(self, module):
180
+ """Initialize the weights"""
181
+ factor = self.config.initializer_factor
182
+ if isinstance(module, CLIPTextEmbeddings):
183
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
184
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
185
+ elif isinstance(module, CLIPVisionEmbeddings):
186
+ factor = self.config.initializer_factor
187
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
188
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
189
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
190
+ elif isinstance(module, CLIPAttention):
191
+ factor = self.config.initializer_factor
192
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
193
+ out_proj_std = (module.embed_dim**-0.5) * factor
194
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
195
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
196
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
197
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
198
+ elif isinstance(module, CLIPMLP):
199
+ factor = self.config.initializer_factor
200
+ in_proj_std = (
201
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
202
+ )
203
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
204
+ nn.init.normal_(module.fc1.weight, std=fc_std)
205
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
206
+ elif isinstance(module, LanguageBindImage):
207
+ nn.init.normal_(
208
+ module.text_projection.weight,
209
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
210
+ )
211
+ nn.init.normal_(
212
+ module.visual_projection.weight,
213
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
214
+ )
215
+ elif isinstance(module, CLIPVisionModelWithProjection):
216
+ nn.init.normal_(
217
+ module.visual_projection.weight,
218
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
219
+ )
220
+ elif isinstance(module, CLIPTextModelWithProjection):
221
+ nn.init.normal_(
222
+ module.text_projection.weight,
223
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
224
+ )
225
+
226
+ if isinstance(module, nn.LayerNorm):
227
+ module.bias.data.zero_()
228
+ module.weight.data.fill_(1.0)
229
+ if isinstance(module, nn.Linear) and module.bias is not None:
230
+ module.bias.data.zero_()
231
+
232
+ def _set_gradient_checkpointing(self, module, value=False):
233
+ if isinstance(module, CLIPEncoder):
234
+ module.gradient_checkpointing = value
235
+
236
+
237
+ CLIP_START_DOCSTRING = r"""
238
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
239
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
240
+ etc.)
241
+
242
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
243
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
244
+ and behavior.
245
+
246
+ Parameters:
247
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
248
+ Initializing with a config file does not load the weights associated with the model, only the
249
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
250
+ """
251
+
252
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
253
+ Args:
254
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
255
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
256
+ it.
257
+
258
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
259
+ [`PreTrainedTokenizer.__call__`] for details.
260
+
261
+ [What are input IDs?](../glossary#input-ids)
262
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
263
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
264
+
265
+ - 1 for tokens that are **not masked**,
266
+ - 0 for tokens that are **masked**.
267
+
268
+ [What are attention masks?](../glossary#attention-mask)
269
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
270
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
271
+ config.max_position_embeddings - 1]`.
272
+
273
+ [What are position IDs?](../glossary#position-ids)
274
+ output_attentions (`bool`, *optional*):
275
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
276
+ tensors for more detail.
277
+ output_hidden_states (`bool`, *optional*):
278
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
279
+ more detail.
280
+ return_dict (`bool`, *optional*):
281
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
282
+ """
283
+
284
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
285
+ Args:
286
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
287
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
288
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
289
+ output_attentions (`bool`, *optional*):
290
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
291
+ tensors for more detail.
292
+ output_hidden_states (`bool`, *optional*):
293
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
294
+ more detail.
295
+ return_dict (`bool`, *optional*):
296
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
297
+ """
298
+
299
+ CLIP_INPUTS_DOCSTRING = r"""
300
+ Args:
301
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
302
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
303
+ it.
304
+
305
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
306
+ [`PreTrainedTokenizer.__call__`] for details.
307
+
308
+ [What are input IDs?](../glossary#input-ids)
309
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
310
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
311
+
312
+ - 1 for tokens that are **not masked**,
313
+ - 0 for tokens that are **masked**.
314
+
315
+ [What are attention masks?](../glossary#attention-mask)
316
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
317
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
318
+ config.max_position_embeddings - 1]`.
319
+
320
+ [What are position IDs?](../glossary#position-ids)
321
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
322
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
323
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
324
+ return_loss (`bool`, *optional*):
325
+ Whether or not to return the contrastive loss.
326
+ output_attentions (`bool`, *optional*):
327
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
328
+ tensors for more detail.
329
+ output_hidden_states (`bool`, *optional*):
330
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
331
+ more detail.
332
+ return_dict (`bool`, *optional*):
333
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
334
+ """
335
+
336
+
337
+ class CLIPEncoder(nn.Module):
338
+ """
339
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
340
+ [`CLIPEncoderLayer`].
341
+
342
+ Args:
343
+ config: CLIPConfig
344
+ """
345
+
346
+ def __init__(self, config: LanguageBindImageConfig):
347
+ super().__init__()
348
+ self.config = config
349
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
350
+ self.gradient_checkpointing = False
351
+
352
+ def forward(
353
+ self,
354
+ inputs_embeds,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ causal_attention_mask: Optional[torch.Tensor] = None,
357
+ output_attentions: Optional[bool] = None,
358
+ output_hidden_states: Optional[bool] = None,
359
+ return_dict: Optional[bool] = None,
360
+ ) -> Union[Tuple, BaseModelOutput]:
361
+ r"""
362
+ Args:
363
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
364
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
365
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
366
+ than the model's internal embedding lookup matrix.
367
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
368
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
369
+
370
+ - 1 for tokens that are **not masked**,
371
+ - 0 for tokens that are **masked**.
372
+
373
+ [What are attention masks?](../glossary#attention-mask)
374
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
375
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
376
+
377
+ - 1 for tokens that are **not masked**,
378
+ - 0 for tokens that are **masked**.
379
+
380
+ [What are attention masks?](../glossary#attention-mask)
381
+ output_attentions (`bool`, *optional*):
382
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
383
+ returned tensors for more detail.
384
+ output_hidden_states (`bool`, *optional*):
385
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
386
+ for more detail.
387
+ return_dict (`bool`, *optional*):
388
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
389
+ """
390
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
391
+ output_hidden_states = (
392
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
393
+ )
394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
395
+
396
+ encoder_states = () if output_hidden_states else None
397
+ all_attentions = () if output_attentions else None
398
+
399
+ hidden_states = inputs_embeds
400
+ for idx, encoder_layer in enumerate(self.layers):
401
+ if output_hidden_states:
402
+ encoder_states = encoder_states + (hidden_states,)
403
+ if self.gradient_checkpointing and self.training:
404
+
405
+ def create_custom_forward(module):
406
+ def custom_forward(*inputs):
407
+ return module(*inputs, output_attentions)
408
+
409
+ return custom_forward
410
+
411
+ layer_outputs = torch.utils.checkpoint.checkpoint(
412
+ create_custom_forward(encoder_layer),
413
+ hidden_states,
414
+ attention_mask,
415
+ causal_attention_mask,
416
+ )
417
+ else:
418
+ layer_outputs = encoder_layer(
419
+ hidden_states,
420
+ attention_mask,
421
+ causal_attention_mask,
422
+ output_attentions=output_attentions,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+
427
+ if output_attentions:
428
+ all_attentions = all_attentions + (layer_outputs[1],)
429
+
430
+ if output_hidden_states:
431
+ encoder_states = encoder_states + (hidden_states,)
432
+
433
+ if not return_dict:
434
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
435
+ return BaseModelOutput(
436
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
437
+ )
438
+
439
+
440
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
441
+ def _make_causal_mask(
442
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
443
+ ):
444
+ """
445
+ Make causal mask used for bi-directional self-attention.
446
+ """
447
+ bsz, tgt_len = input_ids_shape
448
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
449
+ mask_cond = torch.arange(mask.size(-1), device=device)
450
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
451
+ mask = mask.to(dtype)
452
+
453
+ if past_key_values_length > 0:
454
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
455
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
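A tiny worked example of the mask this helper produces: the strict upper triangle is filled with the dtype's most negative value, so softmax assigns (near-)zero weight to future positions.

```python
# Toy call to the helper defined above (tgt_len=3, no cached keys).
import torch

mask = _make_causal_mask(torch.Size([1, 3]), dtype=torch.float32, device=torch.device("cpu"))
print(mask.shape)   # torch.Size([1, 1, 3, 3])
# Row i may attend to positions <= i; future positions hold torch.finfo(torch.float32).min:
# [[0.0, -3.4e38, -3.4e38],
#  [0.0,     0.0, -3.4e38],
#  [0.0,     0.0,     0.0]]
```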
456
+
457
+
458
+ class CLIPTextTransformer(nn.Module):
459
+ def __init__(self, config: CLIPTextConfig):
460
+ super().__init__()
461
+ self.config = config
462
+ embed_dim = config.hidden_size
463
+ self.embeddings = CLIPTextEmbeddings(config)
464
+ self.encoder = CLIPEncoder(config)
465
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
466
+
467
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
468
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
469
+ def forward(
470
+ self,
471
+ input_ids: Optional[torch.Tensor] = None,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ position_ids: Optional[torch.Tensor] = None,
474
+ output_attentions: Optional[bool] = None,
475
+ output_hidden_states: Optional[bool] = None,
476
+ return_dict: Optional[bool] = None,
477
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
478
+ r"""
479
+ Returns:
480
+
481
+ """
482
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
483
+ output_hidden_states = (
484
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
485
+ )
486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
487
+
488
+ if input_ids is None:
489
+ raise ValueError("You have to specify input_ids")
490
+
491
+ input_shape = input_ids.size()
492
+ input_ids = input_ids.view(-1, input_shape[-1])
493
+
494
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
495
+
496
+ # CLIP's text model uses causal mask, prepare it here.
497
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
498
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
499
+ # expand attention_mask
500
+ if attention_mask is not None:
501
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
502
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
503
+
504
+ encoder_outputs = self.encoder(
505
+ inputs_embeds=hidden_states,
506
+ attention_mask=attention_mask,
507
+ causal_attention_mask=causal_attention_mask,
508
+ output_attentions=output_attentions,
509
+ output_hidden_states=output_hidden_states,
510
+ return_dict=return_dict,
511
+ )
512
+
513
+ last_hidden_state = encoder_outputs[0]
514
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
515
+
516
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
517
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
518
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
519
+ pooled_output = last_hidden_state[
520
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
521
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
522
+ ]
523
+
524
+ if not return_dict:
525
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
526
+
527
+ return BaseModelOutputWithPooling(
528
+ last_hidden_state=last_hidden_state,
529
+ pooler_output=pooled_output,
530
+ hidden_states=encoder_outputs.hidden_states,
531
+ attentions=encoder_outputs.attentions,
532
+ )
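The pooled text feature is read off at the end-of-text (EOT) token, located with `argmax` because the EOT id is the largest id in each sequence. A tiny sketch with made-up token ids (assuming CLIP-style ids where 49407 is EOT):

```python
# Tiny sketch of the EOT pooling above (assumption: CLIP-style ids where 49407 = <|endoftext|>).
import torch

last_hidden_state = torch.randn(2, 5, 8)                        # (batch, seq_len, width)
input_ids = torch.tensor([[49406, 320, 1125, 49407, 0],         # <BOS> a photo <EOT> pad
                          [49406, 320, 49407, 0,     0]])       # <BOS> a <EOT> pad pad

eot_positions = input_ids.to(torch.int).argmax(dim=-1)          # tensor([3, 2]); EOT has the largest id
pooled_output = last_hidden_state[torch.arange(2), eot_positions]
print(pooled_output.shape)                                      # torch.Size([2, 8])
```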
533
+
534
+
535
+ @add_start_docstrings(
536
+ """The text model from CLIP without any head or projection on top.""",
537
+ CLIP_START_DOCSTRING,
538
+ )
539
+ class CLIPTextModel(CLIPPreTrainedModel):
540
+ config_class = CLIPTextConfig
541
+
542
+ _no_split_modules = ["CLIPEncoderLayer"]
543
+
544
+ def __init__(self, config: CLIPTextConfig):
545
+ super().__init__(config)
546
+ self.text_model = CLIPTextTransformer(config)
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_input_embeddings(self) -> nn.Module:
551
+ return self.text_model.embeddings.token_embedding
552
+
553
+ def set_input_embeddings(self, value):
554
+ self.text_model.embeddings.token_embedding = value
555
+
556
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
557
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
558
+ def forward(
559
+ self,
560
+ input_ids: Optional[torch.Tensor] = None,
561
+ attention_mask: Optional[torch.Tensor] = None,
562
+ position_ids: Optional[torch.Tensor] = None,
563
+ output_attentions: Optional[bool] = None,
564
+ output_hidden_states: Optional[bool] = None,
565
+ return_dict: Optional[bool] = None,
566
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
567
+ r"""
568
+ Returns:
569
+
570
+ Examples:
571
+
572
+ ```python
573
+ >>> from transformers import AutoTokenizer, CLIPTextModel
574
+
575
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
576
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
577
+
578
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
579
+
580
+ >>> outputs = model(**inputs)
581
+ >>> last_hidden_state = outputs.last_hidden_state
582
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
583
+ ```"""
584
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
585
+
586
+ return self.text_model(
587
+ input_ids=input_ids,
588
+ attention_mask=attention_mask,
589
+ position_ids=position_ids,
590
+ output_attentions=output_attentions,
591
+ output_hidden_states=output_hidden_states,
592
+ return_dict=return_dict,
593
+ )
594
+
595
+
596
+ class CLIPVisionTransformer(nn.Module):
597
+ def __init__(self, config: CLIPVisionConfig):
598
+ super().__init__()
599
+ self.config = config
600
+ embed_dim = config.hidden_size
601
+
602
+ self.embeddings = CLIPVisionEmbeddings(config)
603
+ self.patch_dropout = PatchDropout(config.force_patch_dropout)
604
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
605
+ self.encoder = CLIPEncoder(config)
606
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
607
+
608
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
609
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
610
+ def forward(
611
+ self,
612
+ pixel_values: Optional[torch.FloatTensor] = None,
613
+ output_attentions: Optional[bool] = None,
614
+ output_hidden_states: Optional[bool] = None,
615
+ return_dict: Optional[bool] = None,
616
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
617
+ r"""
618
+ Returns:
619
+
620
+ """
621
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
622
+ output_hidden_states = (
623
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
624
+ )
625
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
626
+
627
+ if pixel_values is None:
628
+ raise ValueError("You have to specify pixel_values")
629
+ ######################################
630
+ if len(pixel_values.shape) == 7:
631
+ b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
632
+ # print(pixel_values.shape)
633
+ B = b_new * pair_new * bs_new
634
+ pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new)
635
+
636
+ elif len(pixel_values.shape) == 5:
637
+ B, _, T, _, _ = pixel_values.shape
638
+ # print(pixel_values.shape)
639
+ pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
640
+ else:
641
+ # print(pixel_values.shape)
642
+ B, _, _, _ = pixel_values.shape
643
+ T = 1
644
+ ###########################
645
+ hidden_states = self.embeddings(pixel_values)
646
+
647
+ hidden_states = self.patch_dropout(hidden_states, B, T) ##############################################
648
+
649
+ hidden_states = self.pre_layrnorm(hidden_states)
650
+
651
+ encoder_outputs = self.encoder(
652
+ inputs_embeds=hidden_states,
653
+ output_attentions=output_attentions,
654
+ output_hidden_states=output_hidden_states,
655
+ return_dict=return_dict,
656
+ )
657
+
658
+ last_hidden_state = encoder_outputs[0]
659
+ pooled_output = last_hidden_state[:, 0, :]
660
+ pooled_output = self.post_layernorm(pooled_output)
661
+
662
+ pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################
663
+
664
+ if not return_dict:
665
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
666
+
667
+ return BaseModelOutputWithPooling(
668
+ last_hidden_state=last_hidden_state,
669
+ pooler_output=pooled_output,
670
+ hidden_states=encoder_outputs.hidden_states,
671
+ attentions=encoder_outputs.attentions,
672
+ )
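For 5-D video input the frames are folded into the batch before the ViT, and the per-frame pooled features are averaged back over `T` at the end. A minimal numeric sketch of that path (sizes and the stand-in encoder output are assumptions):

```python
# Minimal sketch of the video path above: fold frames into the batch, then mean-pool per-frame features.
import torch
from einops import rearrange

B, C, T, H, W = 2, 3, 8, 224, 224
pixel_values = torch.randn(B, C, T, H, W)                      # 5-D video input

frames = rearrange(pixel_values, 'b c t h w -> (b t) c h w')   # (B*T, C, H, W) is what the ViT sees

# stand-in for embeddings -> encoder -> post_layernorm (hidden size 768 is an assumption)
pooled_per_frame = torch.randn(B * T, 768)

video_feature = pooled_per_frame.reshape(B, T, -1).mean(1)     # (B, 768): average over the T frames
print(video_feature.shape)                                     # torch.Size([2, 768])
```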
673
+
674
+
675
+ @add_start_docstrings(
676
+ """The vision model from CLIP without any head or projection on top.""",
677
+ CLIP_START_DOCSTRING,
678
+ )
679
+ class CLIPVisionModel(CLIPPreTrainedModel):
680
+ config_class = CLIPVisionConfig
681
+ main_input_name = "pixel_values"
682
+
683
+ def __init__(self, config: CLIPVisionConfig):
684
+ super().__init__(config)
685
+ self.vision_model = CLIPVisionTransformer(config)
686
+ # Initialize weights and apply final processing
687
+ self.post_init()
688
+
689
+ def get_input_embeddings(self) -> nn.Module:
690
+ return self.vision_model.embeddings.patch_embedding
691
+
692
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
693
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
694
+ def forward(
695
+ self,
696
+ pixel_values: Optional[torch.FloatTensor] = None,
697
+ output_attentions: Optional[bool] = None,
698
+ output_hidden_states: Optional[bool] = None,
699
+ return_dict: Optional[bool] = None,
700
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
701
+ r"""
702
+ Returns:
703
+
704
+ Examples:
705
+
706
+ ```python
707
+ >>> from PIL import Image
708
+ >>> import requests
709
+ >>> from transformers import AutoProcessor, CLIPVisionModel
710
+
711
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
712
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
713
+
714
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
715
+ >>> image = Image.open(requests.get(url, stream=True).raw)
716
+
717
+ >>> inputs = processor(images=image, return_tensors="pt")
718
+
719
+ >>> outputs = model(**inputs)
720
+ >>> last_hidden_state = outputs.last_hidden_state
721
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
722
+ ```"""
723
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
724
+
725
+ return self.vision_model(
726
+ pixel_values=pixel_values,
727
+ output_attentions=output_attentions,
728
+ output_hidden_states=output_hidden_states,
729
+ return_dict=return_dict,
730
+ )
731
+
732
+
733
+ @add_start_docstrings(CLIP_START_DOCSTRING)
734
+ class LanguageBindImage(CLIPPreTrainedModel):
735
+ config_class = LanguageBindImageConfig
736
+
737
+ def __init__(self, config: LanguageBindImageConfig):
738
+ super().__init__(config)
739
+
740
+ if not isinstance(config.text_config, CLIPTextConfig):
741
+ raise ValueError(
742
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
743
+ f" {type(config.text_config)}."
744
+ )
745
+
746
+ if not isinstance(config.vision_config, CLIPVisionConfig):
747
+ raise ValueError(
748
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
749
+ f" {type(config.vision_config)}."
750
+ )
751
+
752
+ text_config = config.text_config
753
+ vision_config = config.vision_config
754
+ self.add_time_attn = vision_config.add_time_attn
755
+ self.lora_r = vision_config.lora_r
756
+ self.lora_alpha = vision_config.lora_alpha
757
+ self.lora_dropout = vision_config.lora_dropout
758
+
759
+ self.projection_dim = config.projection_dim
760
+ self.text_embed_dim = text_config.hidden_size
761
+ self.vision_embed_dim = vision_config.hidden_size
762
+
763
+ self.text_model = CLIPTextTransformer(text_config)
764
+ self.vision_model = CLIPVisionTransformer(vision_config)
765
+
766
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
767
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
768
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
769
+
770
+ # Initialize weights and apply final processing
771
+ self.post_init()
772
+ self.convert_to_lora()
773
+ # self.resize_pos(self.vision_model.embeddings, vision_config)
774
+
775
+ def convert_to_lora(self):
776
+ if self.lora_r == 0:
777
+ return
778
+ if self.add_time_attn:
779
+ target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
780
+ "temporal_attn.q_proj", "temporal_attn.out_proj",
781
+ "temporal_mlp.fc1", "temporal_mlp.fc2"]
782
+ else:
783
+ target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
784
+ config = LoraConfig(
785
+ r=self.lora_r, # 16
786
+ lora_alpha=self.lora_alpha, # 16
787
+ target_modules=target_modules, # self_attn.out_proj
788
+ lora_dropout=self.lora_dropout, # 0.1
789
+ bias="none",
790
+ modules_to_save=[],
791
+ )
792
+ self.vision_model.encoder.is_gradient_checkpointing = False
793
+ self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
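`convert_to_lora` hands the vision encoder to PEFT so that only low-rank adapters on the listed modules are trainable. Below is a hedged toy sketch of the same wrapping on a small module that reuses the `temporal_attn.*` naming seen above; the toy module itself is an assumption, not this repo's code.

```python
# Hedged toy sketch (not this repo's code): wrap a module's temporal_attn projections with LoRA via peft.
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyTemporalAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.out_proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.out_proj(self.q_proj(x))

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.temporal_attn = ToyTemporalAttn()

    def forward(self, x):
        return self.temporal_attn(x)

lora_config = LoraConfig(
    r=2, lora_alpha=16, lora_dropout=0.0, bias="none",
    target_modules=["temporal_attn.q_proj", "temporal_attn.out_proj"],  # dotted-suffix matching, as above
)
lora_block = get_peft_model(ToyBlock(), lora_config)
lora_block.print_trainable_parameters()   # only the injected LoRA A/B matrices are trainable
```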
794
+
795
+ def resize_pos(self, m, vision_config):
796
+ # convert embedding
797
+ if vision_config.num_mel_bins!=0 and vision_config.target_length!=0:
798
+ m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
799
+ m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
800
+ # pos resize
801
+ old_pos_embed_state_dict = m.position_embedding.state_dict()
802
+ old_pos_embed = old_pos_embed_state_dict['weight']
803
+ dtype = old_pos_embed.dtype
804
+ grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
805
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
806
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
807
+ if new_seq_len == old_pos_embed.shape[0]:
808
+ # m.to(args.device)
809
+ return
810
+
811
+ m.num_patches = grid_size[0] * grid_size[1]
812
+ m.num_positions = m.num_patches + 1
813
+ m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
814
+ new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
815
+
816
+ if extra_tokens:
817
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
818
+ else:
819
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
820
+ old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
821
+
822
+ # if is_master(args):
823
+ # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
824
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
825
+ pos_emb_img = F.interpolate(
826
+ pos_emb_img,
827
+ size=grid_size,
828
+ mode='bicubic',
829
+ antialias=True,
830
+ align_corners=False,
831
+ )
832
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
833
+ if pos_emb_tok is not None:
834
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
835
+ else:
836
+ new_pos_embed = pos_emb_img
837
+ old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
838
+ m.position_embedding = new_position_embedding
839
+ m.position_embedding.load_state_dict(old_pos_embed_state_dict)
840
+
841
+ # m.to(args.device)
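The position-embedding resize above is a standard 2-D bicubic interpolation of the patch grid. A minimal sketch of just that step, with illustrative grid sizes:

```python
# Minimal sketch of the bicubic resize above: 7x7 grid of embeddings -> 14x14 (sizes illustrative).
import torch
import torch.nn.functional as F

old_pos = torch.randn(7 * 7, 64)                                      # (L, D) patch position embeddings
grid = old_pos.reshape(1, 7, 7, -1).permute(0, 3, 1, 2)               # (1, D, 7, 7)
grid = F.interpolate(grid, size=(14, 14), mode='bicubic', antialias=True, align_corners=False)
new_pos = grid.permute(0, 2, 3, 1).reshape(1, 14 * 14, -1)[0]         # (196, D)
print(new_pos.shape)                                                  # torch.Size([196, 64])
```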
842
+
843
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
844
+ def get_text_features(
845
+ self,
846
+ input_ids: Optional[torch.Tensor] = None,
847
+ attention_mask: Optional[torch.Tensor] = None,
848
+ position_ids: Optional[torch.Tensor] = None,
849
+ output_attentions: Optional[bool] = None,
850
+ output_hidden_states: Optional[bool] = None,
851
+ return_dict: Optional[bool] = None,
852
+ ) -> torch.FloatTensor:
853
+ r"""
854
+ Returns:
855
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
856
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
857
+
858
+ Examples:
859
+
860
+ ```python
861
+ >>> from transformers import AutoTokenizer, CLIPModel
862
+
863
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
864
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
865
+
866
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
867
+ >>> text_features = model.get_text_features(**inputs)
868
+ ```"""
869
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ text_outputs = self.text_model(
877
+ input_ids=input_ids,
878
+ attention_mask=attention_mask,
879
+ position_ids=position_ids,
880
+ output_attentions=output_attentions,
881
+ output_hidden_states=output_hidden_states,
882
+ return_dict=return_dict,
883
+ )
884
+
885
+ pooled_output = text_outputs[1]
886
+ text_features = self.text_projection(pooled_output)
887
+
888
+ return text_features
889
+
890
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
891
+ def get_image_features(
892
+ self,
893
+ pixel_values: Optional[torch.FloatTensor] = None,
894
+ output_attentions: Optional[bool] = None,
895
+ output_hidden_states: Optional[bool] = None,
896
+ return_dict: Optional[bool] = None,
897
+ ) -> torch.FloatTensor:
898
+ r"""
899
+ Returns:
900
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
901
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
902
+
903
+ Examples:
904
+
905
+ ```python
906
+ >>> from PIL import Image
907
+ >>> import requests
908
+ >>> from transformers import AutoProcessor, CLIPModel
909
+
910
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
911
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
912
+
913
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
914
+ >>> image = Image.open(requests.get(url, stream=True).raw)
915
+
916
+ >>> inputs = processor(images=image, return_tensors="pt")
917
+
918
+ >>> image_features = model.get_image_features(**inputs)
919
+ ```"""
920
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
921
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
922
+ output_hidden_states = (
923
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
924
+ )
925
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
926
+
927
+ vision_outputs = self.vision_model(
928
+ pixel_values=pixel_values,
929
+ output_attentions=output_attentions,
930
+ output_hidden_states=output_hidden_states,
931
+ return_dict=return_dict,
932
+ )
933
+
934
+ pooled_output = vision_outputs[1] # pooled_output
935
+ image_features = self.visual_projection(pooled_output)
936
+
937
+ return image_features
938
+
939
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
940
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindImageConfig)
941
+ def forward(
942
+ self,
943
+ input_ids: Optional[torch.LongTensor] = None,
944
+ pixel_values: Optional[torch.FloatTensor] = None,
945
+ attention_mask: Optional[torch.Tensor] = None,
946
+ position_ids: Optional[torch.LongTensor] = None,
947
+ return_loss: Optional[bool] = None,
948
+ output_attentions: Optional[bool] = None,
949
+ output_hidden_states: Optional[bool] = None,
950
+ return_dict: Optional[bool] = None,
951
+ ) -> Union[Tuple, CLIPOutput]:
952
+ r"""
953
+ Returns:
954
+
955
+ Examples:
956
+
957
+ ```python
958
+ >>> from PIL import Image
959
+ >>> import requests
960
+ >>> from transformers import AutoProcessor, CLIPModel
961
+
962
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
963
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
964
+
965
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
966
+ >>> image = Image.open(requests.get(url, stream=True).raw)
967
+
968
+ >>> inputs = processor(
969
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
970
+ ... )
971
+
972
+ >>> outputs = model(**inputs)
973
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
974
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
975
+ ```"""
976
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
977
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
978
+ output_hidden_states = (
979
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
980
+ )
981
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
982
+
983
+ vision_outputs = self.vision_model(
984
+ pixel_values=pixel_values,
985
+ output_attentions=output_attentions,
986
+ output_hidden_states=output_hidden_states,
987
+ return_dict=return_dict,
988
+ )
989
+
990
+ text_outputs = self.text_model(
991
+ input_ids=input_ids,
992
+ attention_mask=attention_mask,
993
+ position_ids=position_ids,
994
+ output_attentions=output_attentions,
995
+ output_hidden_states=output_hidden_states,
996
+ return_dict=return_dict,
997
+ )
998
+
999
+ image_embeds = vision_outputs[1]
1000
+ image_embeds = self.visual_projection(image_embeds)
1001
+
1002
+ text_embeds = text_outputs[1]
1003
+ text_embeds = self.text_projection(text_embeds)
1004
+
1005
+ # normalized features
1006
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1007
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1008
+
1009
+ # cosine similarity as logits
1010
+ logit_scale = self.logit_scale.exp()
1011
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1012
+ logits_per_image = logits_per_text.t()
1013
+
1014
+ loss = None
1015
+ if return_loss:
1016
+ loss = clip_loss(logits_per_text)
1017
+
1018
+ if not return_dict:
1019
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1020
+ return ((loss,) + output) if loss is not None else output
1021
+
1022
+ return CLIPOutput(
1023
+ loss=loss,
1024
+ logits_per_image=logits_per_image,
1025
+ logits_per_text=logits_per_text,
1026
+ text_embeds=text_embeds,
1027
+ image_embeds=image_embeds,
1028
+ text_model_output=text_outputs,
1029
+ vision_model_output=vision_outputs,
1030
+ )
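`clip_loss` is imported from elsewhere in the package and does not appear in this diff. For reference, here is a hedged sketch of the symmetric contrastive loss it is expected to compute; this mirrors the standard CLIP formulation and is an assumption about this repo's implementation.

```python
# Hedged sketch of the symmetric contrastive loss `clip_loss` is expected to compute
# (mirrors the standard CLIP formulation; treat as an assumption, not this repo's exact code).
import torch
import torch.nn.functional as F

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

def clip_loss_sketch(logits_per_text: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(logits_per_text)       # text -> image direction
    image_loss = contrastive_loss(logits_per_text.t())     # image -> text direction
    return (caption_loss + image_loss) / 2.0
```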
models/multimodal_encoder/languagebind/image/processing_image.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from transformers import ProcessorMixin, BatchEncoding
5
+ from transformers.image_processing_utils import BatchFeature
6
+
7
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
8
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
9
+
10
+ def make_list_of_images(x):
11
+ if not isinstance(x, list):
12
+ return [x]
13
+ return x
14
+
15
+ def get_image_transform(config):
16
+ config = config.vision_config
17
+ transform = transforms.Compose(
18
+ [
19
+ transforms.ToTensor(),
20
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
21
+ transforms.CenterCrop(224),
22
+ transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) # assume image
23
+ ]
24
+ )
25
+ return transform
26
+
27
+
28
+ def load_and_transform_image(image_path, transform):
29
+ image = Image.open(image_path).convert('RGB') if isinstance(image_path, str) else image_path
30
+ image_outputs = transform(image)
31
+ return image_outputs
32
+
33
+ class LanguageBindImageProcessor(ProcessorMixin):
34
+ attributes = []
35
+ tokenizer_class = ("LanguageBindImageTokenizer")
36
+
37
+ def __init__(self, config, tokenizer=None, **kwargs):
38
+ super().__init__(**kwargs)
39
+ self.config = config
40
+ self.transform = get_image_transform(config)
41
+ self.image_processor = load_and_transform_image
42
+ self.tokenizer = tokenizer
43
+ self.image_mean = OPENAI_DATASET_MEAN
44
+ self.crop_size = {'height': 224, 'width': 224}
45
+
46
+ def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
47
+ if text is None and images is None:
48
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
49
+
50
+ if text is not None:
51
+ encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
52
+ truncation=True, return_tensors=return_tensors, **kwargs)
53
+
54
+ if images is not None:
55
+ images = make_list_of_images(images)
56
+ image_features = [self.image_processor(image, self.transform) for image in images]
57
+ image_features = torch.stack(image_features)
58
+
59
+ if text is not None and images is not None:
60
+ encoding["pixel_values"] = image_features
61
+ return encoding
62
+ elif text is not None:
63
+ return encoding
64
+ else:
65
+ return {"pixel_values": image_features}
66
+
67
+ def preprocess(self, images, return_tensors):
68
+ return self.__call__(images=images, return_tensors=return_tensors)
69
+
70
+ def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
71
+ """
72
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
73
+ refer to the docstring of this method for more information.
74
+ """
75
+ return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
76
+
77
+ def decode(self, skip_special_tokens=True, *args, **kwargs):
78
+ """
79
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
80
+ the docstring of this method for more information.
81
+ """
82
+ return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
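A hedged usage sketch for the processor defined above. The checkpoint id and image path are placeholders, and the sibling `LanguageBindImageConfig` / `LanguageBindImageTokenizer` classes are assumed to be importable from this package.

```python
# Placeholders / assumptions: checkpoint id, image path, and the sibling config & tokenizer classes.
from PIL import Image

config = LanguageBindImageConfig.from_pretrained("LanguageBind/LanguageBind_Image")       # hypothetical id
tokenizer = LanguageBindImageTokenizer.from_pretrained("LanguageBind/LanguageBind_Image") # hypothetical id
processor = LanguageBindImageProcessor(config, tokenizer=tokenizer)

image = Image.open("example.jpg")   # placeholder path
batch = processor(images=[image], text=["a photo of a cat"], return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
print(batch["input_ids"].shape)     # torch.Size([1, 77])
```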
models/multimodal_encoder/languagebind/image/tokenization_image.py ADDED
@@ -0,0 +1,77 @@
1
+ from transformers import CLIPTokenizer
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+ VOCAB_FILES_NAMES = {
7
+ "vocab_file": "vocab.json",
8
+ "merges_file": "merges.txt",
9
+ }
10
+
11
+ PRETRAINED_VOCAB_FILES_MAP = {
12
+ "vocab_file": {
13
+ "lb203/LanguageBind-Image": "https://huggingface.co/lb203/LanguageBind-Image/resolve/main/vocab.json",
14
+ },
15
+ "merges_file": {
16
+ "lb203/LanguageBind-Image": "https://huggingface.co/lb203/LanguageBind-Image/resolve/main/merges.txt",
17
+ },
18
+ }
19
+
20
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
21
+ "lb203/LanguageBind-Image": 77,
22
+ }
23
+
24
+
25
+ PRETRAINED_INIT_CONFIGURATION = {
26
+ "lb203/LanguageBind-Image": {},
27
+ }
28
+
29
+ class LanguageBindImageTokenizer(CLIPTokenizer):
30
+ """
31
+ Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
32
+
33
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
34
+ this superclass for more information regarding those methods.
35
+
36
+ Args:
37
+ vocab_file (`str`):
38
+ Path to the vocabulary file.
39
+ merges_file (`str`):
40
+ Path to the merges file.
41
+ errors (`str`, *optional*, defaults to `"replace"`):
42
+ Paradigm to follow when decoding bytes to UTF-8. See
43
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
44
+ unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
45
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
46
+ token instead.
47
+ bos_token (`str`, *optional*, defaults to `<|startoftext|>`):
48
+ The beginning of sequence token.
49
+ eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
50
+ The end of sequence token.
51
+ """
52
+
53
+ vocab_files_names = VOCAB_FILES_NAMES
54
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
55
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
56
+ model_input_names = ["input_ids", "attention_mask"]
57
+
58
+ def __init__(
59
+ self,
60
+ vocab_file,
61
+ merges_file,
62
+ errors="replace",
63
+ unk_token="<|endoftext|>",
64
+ bos_token="<|startoftext|>",
65
+ eos_token="<|endoftext|>",
66
+ pad_token="<|endoftext|>", # hack to enable padding
67
+ **kwargs,
68
+ ):
69
+ super(LanguageBindImageTokenizer, self).__init__(
70
+ vocab_file,
71
+ merges_file,
72
+ errors,
73
+ unk_token,
74
+ bos_token,
75
+ eos_token,
76
+ pad_token, # hack to enable padding
77
+ **kwargs,)
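A hedged usage sketch of the tokenizer above; the vocab/merges paths are placeholders. Note the `<|endoftext|>` pad-token hack, which makes `padding='max_length'` work out of the box.

```python
# Placeholder paths; a real checkpoint would ship vocab.json / merges.txt.
tok = LanguageBindImageTokenizer("vocab.json", "merges.txt")
enc = tok("a photo of a cat", max_length=77, padding="max_length", truncation=True)
print(len(enc["input_ids"]))   # 77, right-padded with the <|endoftext|> pad token
```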
models/multimodal_encoder/languagebind/thermal/configuration_thermal.py ADDED
@@ -0,0 +1,423 @@
1
+ import copy
2
+ import os
3
+ from typing import Union
4
+
5
+ from transformers import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
+
12
+
13
+
14
+
15
+
16
+ class CLIPTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP
19
+ text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
20
+ with the defaults will yield a similar configuration to that of the text encoder of the CLIP
21
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 49408):
28
+ Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by
29
+ the `inputs_ids` passed when calling [`CLIPModel`].
30
+ hidden_size (`int`, *optional*, defaults to 512):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 2048):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 8):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 77):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ initializer_factor (`float`, *optional*, defaults to 1):
51
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
52
+ testing).
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import CLIPTextConfig, CLIPTextModel
58
+
59
+ >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
60
+ >>> configuration = CLIPTextConfig()
61
+
62
+ >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
63
+ >>> model = CLIPTextModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+ model_type = "clip_text_model"
69
+
70
+ def __init__(
71
+ self,
72
+ vocab_size=49408,
73
+ hidden_size=512,
74
+ intermediate_size=2048,
75
+ projection_dim=512,
76
+ num_hidden_layers=12,
77
+ num_attention_heads=8,
78
+ max_position_embeddings=77,
79
+ hidden_act="quick_gelu",
80
+ layer_norm_eps=1e-5,
81
+ attention_dropout=0.0,
82
+ initializer_range=0.02,
83
+ initializer_factor=1.0,
84
+ # This differs from `CLIPTokenizer`'s default and from openai/clip
85
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
86
+ pad_token_id=1,
87
+ bos_token_id=49406,
88
+ eos_token_id=49407,
89
+ **kwargs,
90
+ ):
91
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
92
+
93
+ self.vocab_size = vocab_size
94
+ self.hidden_size = hidden_size
95
+ self.intermediate_size = intermediate_size
96
+ self.projection_dim = projection_dim
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.max_position_embeddings = max_position_embeddings
100
+ self.layer_norm_eps = layer_norm_eps
101
+ self.hidden_act = hidden_act
102
+ self.initializer_range = initializer_range
103
+ self.initializer_factor = initializer_factor
104
+ self.attention_dropout = attention_dropout
105
+ self.add_time_attn = False ######################################
106
+
107
+ @classmethod
108
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
109
+ cls._set_token_in_kwargs(kwargs)
110
+
111
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
112
+
113
+ # get the text config dict if we are loading from CLIPConfig
114
+ if config_dict.get("model_type") == "clip":
115
+ config_dict = config_dict["text_config"]
116
+
117
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
118
+ logger.warning(
119
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
120
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
121
+ )
122
+
123
+ return cls.from_dict(config_dict, **kwargs)
124
+
125
+
126
+
127
+
128
+ class CLIPVisionConfig(PretrainedConfig):
129
+ r"""
130
+ This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
131
+ CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
132
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
133
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
134
+
135
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
136
+ documentation from [`PretrainedConfig`] for more information.
137
+
138
+ Args:
139
+ hidden_size (`int`, *optional*, defaults to 768):
140
+ Dimensionality of the encoder layers and the pooler layer.
141
+ intermediate_size (`int`, *optional*, defaults to 3072):
142
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
143
+ num_hidden_layers (`int`, *optional*, defaults to 12):
144
+ Number of hidden layers in the Transformer encoder.
145
+ num_attention_heads (`int`, *optional*, defaults to 12):
146
+ Number of attention heads for each attention layer in the Transformer encoder.
147
+ image_size (`int`, *optional*, defaults to 224):
148
+ The size (resolution) of each image.
149
+ patch_size (`int`, *optional*, defaults to 32):
150
+ The size (resolution) of each patch.
151
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
154
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
155
+ The epsilon used by the layer normalization layers.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ initializer_range (`float`, *optional*, defaults to 0.02):
159
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
160
+ initializer_factor (`float`, *optional*, defaults to 1):
161
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
162
+ testing).
163
+
164
+ Example:
165
+
166
+ ```python
167
+ >>> from transformers import CLIPVisionConfig, CLIPVisionModel
168
+
169
+ >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
170
+ >>> configuration = CLIPVisionConfig()
171
+
172
+ >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
173
+ >>> model = CLIPVisionModel(configuration)
174
+
175
+ >>> # Accessing the model configuration
176
+ >>> configuration = model.config
177
+ ```"""
178
+
179
+ model_type = "clip_vision_model"
180
+
181
+ def __init__(
182
+ self,
183
+ hidden_size=768,
184
+ intermediate_size=3072,
185
+ projection_dim=512,
186
+ num_hidden_layers=12,
187
+ num_attention_heads=12,
188
+ num_channels=3,
189
+ image_size=224,
190
+ patch_size=32,
191
+ hidden_act="quick_gelu",
192
+ layer_norm_eps=1e-5,
193
+ attention_dropout=0.0,
194
+ initializer_range=0.02,
195
+ initializer_factor=1.0,
196
+
197
+ add_time_attn=False, ################################
198
+ num_frames=1, ################################
199
+ force_patch_dropout=0.0, ################################
200
+ lora_r=2, ################################
201
+ lora_alpha=16, ################################
202
+ lora_dropout=0.0, ################################
203
+ num_mel_bins=0.0, ################################
204
+ target_length=0.0, ################################
205
+ video_decode_backend='decord', #########################
206
+ **kwargs,
207
+ ):
208
+ super().__init__(**kwargs)
209
+
210
+ self.hidden_size = hidden_size
211
+ self.intermediate_size = intermediate_size
212
+ self.projection_dim = projection_dim
213
+ self.num_hidden_layers = num_hidden_layers
214
+ self.num_attention_heads = num_attention_heads
215
+ self.num_channels = num_channels
216
+ self.patch_size = patch_size
217
+ self.image_size = image_size
218
+ self.initializer_range = initializer_range
219
+ self.initializer_factor = initializer_factor
220
+ self.attention_dropout = attention_dropout
221
+ self.layer_norm_eps = layer_norm_eps
222
+ self.hidden_act = hidden_act
223
+
224
+ self.add_time_attn = add_time_attn ################
225
+ self.num_frames = num_frames ################
226
+ self.force_patch_dropout = force_patch_dropout ################
227
+ self.lora_r = lora_r ################
228
+ self.lora_alpha = lora_alpha ################
229
+ self.lora_dropout = lora_dropout ################
230
+ self.num_mel_bins = num_mel_bins ################
231
+ self.target_length = target_length ################
232
+ self.video_decode_backend = video_decode_backend ################
233
+
234
+ @classmethod
235
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
236
+ cls._set_token_in_kwargs(kwargs)
237
+
238
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
239
+
240
+ # get the vision config dict if we are loading from CLIPConfig
241
+ if config_dict.get("model_type") == "clip":
242
+ config_dict = config_dict["vision_config"]
243
+
244
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
245
+ logger.warning(
246
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
247
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
248
+ )
249
+
250
+ return cls.from_dict(config_dict, **kwargs)
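The extra keyword arguments above are what distinguish this local `CLIPVisionConfig` from the stock transformers one. A small sketch instantiating it with the LanguageBind-specific fields (values are illustrative):

```python
# Illustrative values; only the non-standard fields are shown explicitly.
config = CLIPVisionConfig(
    add_time_attn=True,       # enable the temporal attention branch
    num_frames=8,             # frames per clip
    force_patch_dropout=0.3,  # patch dropout rate applied in the vision transformer
    lora_r=2, lora_alpha=16, lora_dropout=0.0,
)
print(config.add_time_attn, config.num_frames)   # True 8
```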
251
+
+
+ class LanguageBindThermalConfig(PretrainedConfig):
+     r"""
+     [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
+     a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating
+     a configuration with the defaults will yield a similar configuration to that of the CLIP
+     [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         text_config (`dict`, *optional*):
+             Dictionary of configuration options used to initialize [`CLIPTextConfig`].
+         vision_config (`dict`, *optional*):
+             Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
+         projection_dim (`int`, *optional*, defaults to 512):
+             Dimensionality of the text and vision projection layers.
+         logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+             The initial value of the *logit_scale* parameter. The default follows the original CLIP implementation.
+         kwargs (*optional*):
+             Dictionary of keyword arguments.
+
+     Example:
+
+     ```python
+     >>> from transformers import CLIPConfig, CLIPModel
+
+     >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
+     >>> configuration = CLIPConfig()
+
+     >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
+     >>> model = CLIPModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+
+     >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
+     >>> from transformers import CLIPTextConfig, CLIPVisionConfig
+
+     >>> # Initializing a CLIPText and CLIPVision configuration
+     >>> config_text = CLIPTextConfig()
+     >>> config_vision = CLIPVisionConfig()
+
+     >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
+     ```"""
+
+     model_type = "LanguageBindThermal"
+     is_composition = True
+
+     def __init__(
+         self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
+     ):
+         # If `_config_dict` kwargs exist, we use them for backward compatibility.
+         # We pop these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
+         # of confusion!).
+         text_config_dict = kwargs.pop("text_config_dict", None)
+         vision_config_dict = kwargs.pop("vision_config_dict", None)
+
+         super().__init__(**kwargs)
+
+         # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
+         # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be the same in
+         # most cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
+         if text_config_dict is not None:
+             if text_config is None:
+                 text_config = {}
+
+             # This is the complete result when using `text_config_dict`.
+             _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict()
+
+             # Warn if a key exists in both `_text_config_dict` and `text_config` with different values.
+             for key, value in _text_config_dict.items():
+                 if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
+                     # If specified in `text_config_dict`
+                     if key in text_config_dict:
+                         message = (
+                             f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
+                             f'The value `text_config_dict["{key}"]` will be used instead.'
+                         )
+                     # If inferred from default argument values (just to be super careful)
+                     else:
+                         message = (
+                             f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The "
+                             f'value `text_config["{key}"]` will be overridden.'
+                         )
+                     logger.warning(message)
+
+             # Update all values in `text_config` with the ones in `_text_config_dict`.
+             text_config.update(_text_config_dict)
+
+         if vision_config_dict is not None:
+             if vision_config is None:
+                 vision_config = {}
+
+             # This is the complete result when using `vision_config_dict`.
+             _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict()
+             # convert keys to string instead of integer
+             if "id2label" in _vision_config_dict:
+                 _vision_config_dict["id2label"] = {
+                     str(key): value for key, value in _vision_config_dict["id2label"].items()
+                 }
+
+             # Warn if a key exists in both `_vision_config_dict` and `vision_config` with different values.
+             for key, value in _vision_config_dict.items():
+                 if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
+                     # If specified in `vision_config_dict`
+                     if key in vision_config_dict:
+                         message = (
+                             f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
+                             f'values. The value `vision_config_dict["{key}"]` will be used instead.'
+                         )
+                     # If inferred from default argument values (just to be super careful)
+                     else:
+                         message = (
+                             f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. "
+                             f'The value `vision_config["{key}"]` will be overridden.'
+                         )
+                     logger.warning(message)
+
+             # Update all values in `vision_config` with the ones in `_vision_config_dict`.
+             vision_config.update(_vision_config_dict)
+
+         if text_config is None:
+             text_config = {}
+             logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.")
+
+         if vision_config is None:
+             vision_config = {}
+             logger.info("`vision_config` is `None`. Initializing the `CLIPVisionConfig` with default values.")
+
+         self.text_config = CLIPTextConfig(**text_config)
+         self.vision_config = CLIPVisionConfig(**vision_config)
+
+         self.projection_dim = projection_dim
+         self.logit_scale_init_value = logit_scale_init_value
+         self.initializer_factor = 1.0
+
+     @classmethod
+     def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
+         r"""
+         Instantiate a [`CLIPConfig`] (or a derived class) from a CLIP text model configuration and a CLIP vision
+         model configuration.
+
+         Returns:
+             [`CLIPConfig`]: An instance of a configuration object
+         """
+
+         return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+         Returns:
+             `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["text_config"] = self.text_config.to_dict()
+         output["vision_config"] = self.vision_config.to_dict()
+         output["model_type"] = self.__class__.model_type
+         return output
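A short usage sketch for the composite class, based only on the methods shown above; the import path is an assumption drawn from this commit's file layout, and `from_text_vision_configs` / `to_dict` mirror their CLIP counterparts.

```python
# Sketch only: build the composite thermal config from explicit sub-configs and
# round-trip it through to_dict(), which re-serializes both sub-configs.
from models.multimodal_encoder.languagebind.thermal.configuration_thermal import (
    CLIPTextConfig,
    CLIPVisionConfig,
    LanguageBindThermalConfig,
)

text_config = CLIPTextConfig()
vision_config = CLIPVisionConfig(add_time_attn=False)
config = LanguageBindThermalConfig.from_text_vision_configs(text_config, vision_config)

as_dict = config.to_dict()
print(as_dict["model_type"])                   # "LanguageBindThermal"
print(as_dict["vision_config"]["patch_size"])  # taken from the vision sub-config
```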