jwcho committed on
Commit 208ab35 · 1 Parent(s): cbeaae9

first commit

app.py ADDED
@@ -0,0 +1,224 @@
+ from pathlib import Path
+
+ import gradio as gr
+ import lightning as L
+ import torch
+
+ from lit_llama import LLaMA, Tokenizer
+ from lit_llama.utils import EmptyInitOnDevice
+
+
+ class ChatDoctor:
+     def __init__(self, model, tokenizer, fabric):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.fabric = fabric
+
+     def generate_prompt(self, example):
+         if example["input"]:
+             return (
+                 "아래는 작업을 설명하는 명령어와 추가적 맥락을 제공하는 입력이 짝을 이루는 예제입니다.\n\n"
+                 "요청을 적절히 완료하는 응답을 작성하세요.\n\n"
+                 f"### 명령어:\n{example['instruction']}\n\n### 입력:\n{example['input']}\n\n### 응답:"
+             )
+         return (
+             "환자가 의사에게 아픈 곳에 대해 문의합니다.\n\n"
+             "환자의 문의 내용에 대해 답변하세요. 환자의 질병을 진단하고, 가능하면 처방을 하세요. \n\n"
+             f"### 문의:\n{example['instruction']}\n\n### 응답:"
+         )
+
+     # This method generates the chatbot's responses.
+     @torch.no_grad()
+     def generate(
+         self,
+         idx,
+         max_new_tokens,
+         max_seq_length=None,
+         temperature=0.8,
+         top_k=None,
+         eos_id=None
+     ):
+         T = idx.size(0)
+         T_new = T + max_new_tokens
+         if max_seq_length is None:
+             max_seq_length = min(T_new, self.model.config.block_size)
+
+         device, dtype = idx.device, idx.dtype
+         # create an empty tensor of the expected final shape and fill in the current tokens
+         empty = torch.empty(T_new, dtype=dtype, device=device)
+         empty[:T] = idx
+         idx = empty
+         input_pos = torch.arange(0, T, device=device)
+
+         if idx.device.type == "xla":
+             import torch_xla.core.xla_model as xm
+
+             xm.mark_step()
+
+         # generate max_new_tokens tokens
+         for _ in range(max_new_tokens):
+             x = idx.index_select(0, input_pos).view(1, -1)
+
+             # forward
+             logits = self.model(x, max_seq_length, input_pos)
+             logits = logits[0, -1] / temperature
+
+             # optionally crop the logits to only the top k options
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
+
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
+
+             # advance
+             input_pos = input_pos[-1:] + 1
+
+             if idx.device.type == "xla":
+                 xm.mark_step()
+
+             # concatenate the new generation
+             idx = idx.index_copy(0, input_pos, idx_next)
+
+             # if <eos> token is triggered, return the output (stop generation)
+             if idx_next == eos_id:
+                 return idx[:input_pos]  # include the EOS token
+
+         return idx
+
+
+     # This method handles user's messages and updates the conversation history.
+     def user(self, user_message, history):
+         # The user's message is added to the history with None as the bot's response.
+         return "", history + [[user_message, None]]
+
+     # This method generates and handles bot's responses.
+     def bot(self, history, max_new_tokens, top_k, temperature):
+         instruction = history[-1][0].strip()
+         sample = {"instruction": instruction, "input": None}
+         prompt = self.generate_prompt(sample)
+         encoded_prompt = self.tokenizer.encode(prompt, bos=True, eos=False, device=self.fabric.device)
+
+         y = self.generate(
+             idx=encoded_prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             eos_id=self.tokenizer.eos_id
+         )
+
+         self.model.reset_cache()
+
+         response = self.tokenizer.decode(y)
+         response = response.split('응답:')[1].strip()
+
+         # The history is updated with the bot's response.
+         history[-1][1] = response
+
+         return history
+
+
+ def load_model():
+     # Settings for inference
+     # Precision setting for float32 matmul operations. It's important for some CUDA devices.
+     torch.set_float32_matmul_precision("high")
+
+     checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
+     tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
+     quantize = None  # "gptq.int4" or "llm.int8"
+
+     fabric = L.Fabric(devices=1)
+     dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
+
+     with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
+         model = LLaMA.from_name("7B")
+
+     checkpoint = torch.load(checkpoint_path)
+     model.load_state_dict(checkpoint)
+
+     model.eval()
+     model = fabric.setup_module(model)
+
+     tokenizer = Tokenizer(tokenizer_path)
+
+     return model, tokenizer, fabric
+
+
+ def setup_gradio_ui(chat_doctor):
+     with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
+         gr.Markdown(
+             """
+             # ChatDoctor-KR Demo
+
+             last modified : 23.05.18
+             """)
+
+         chatbot = gr.Chatbot(label="ChatDoctor-KR")
+         msg = gr.Textbox(lines=1, placeholder="질문 입력 후 엔터를 누르세요.", label="질문")
+         clear = gr.Button("클리어")
+
+         gr.Markdown(
+             """
+             ## Parameters
+             """)
+
+         max_new_tokens = gr.Slider(
+             minimum=1,
+             maximum=512,
+             step=1,
+             value=512,
+             label="max_new_tokens",
+             info="The number of new tokens to generate",
+             interactive=True
+         )
+
+         top_k = gr.Slider(
+             minimum=1,
+             maximum=300,
+             step=1,
+             value=200,
+             label="top_k",
+             info="If specified, only sample among the tokens with the k highest probabilities",
+             interactive=True
+         )
+
+         temperature = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             step=0.1,
+             value=0.8,
+             label="temperature",
+             info="Scales the predicted logits by 1 / temperature",
+             interactive=True
+         )
+
+         with gr.Accordion(label="Open for More!", open=False):
+             gr.Markdown("Blah Blah ...")
+
+         submit_result = msg.submit(
+             chat_doctor.user, [msg, chatbot], [msg, chatbot], queue=False
+         )
+         submit_result.then(
+             chat_doctor.bot, [chatbot, max_new_tokens, top_k, temperature], chatbot
+         )
+
+         # This part clears the chatbot history when the clear button is clicked.
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+     demo.queue()
+
+     demo.launch(share=True, server_name="0.0.0.0")
+
+
+ def main():
+     # Load model and tokenizer
+     model, tokenizer, fabric = load_model()
+
+     # ChatDoctor instance
+     chat_doctor = ChatDoctor(model, tokenizer, fabric)
+
+     # Gradio UI setup and launch
+     setup_gradio_ui(chat_doctor)
+
+ if __name__ == "__main__":
+     main()
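
The sampling step inside `ChatDoctor.generate` above reduces to temperature scaling, optional top-k filtering, and a multinomial draw. A minimal sketch of just that step (the function name and the 1-D `logits` shape are illustrative assumptions, not part of app.py):

import torch

def sample_next_token(logits, temperature=0.8, top_k=None):
    # scale the predicted logits by 1 / temperature
    logits = logits / temperature
    if top_k is not None:
        # keep only the k highest logits; everything else becomes -inf
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < v[-1], -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)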
checkpoints/lit-llama/7B/lit-llama.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ea5abe49d33b50c000c1107907db19ef293dd61fceab8b451fe883f5fd8a919
+ size 13476954436
checkpoints/lit-llama/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
lit_llama/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
+ from lit_llama.tokenizer import Tokenizer
lit_llama/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (415 Bytes).
 
lit_llama/__pycache__/model.cpython-311.pyc ADDED
Binary file (20 kB).
 
lit_llama/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (3.37 kB).
 
lit_llama/__pycache__/utils.cpython-311.pyc ADDED
Binary file (25.3 kB).
 
lit_llama/adapter.py ADDED
@@ -0,0 +1,313 @@
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
4
+ https://arxiv.org/abs/2303.16199
5
+
6
+ | Prefix cross-attention
7
+ |
8
+ ┌─────────────────┐ | ┌──────────────────┐
9
+ ┆ x ┆ | ┆ prefix ┆
10
+ └─────────────────┘ | └──────────────────┘
11
+ | | |
12
+ ▼ | ▼
13
+ ┌──────────────────┐ | ┌─────────────────────┐
14
+ ┆ self-attention ┆ --------------------------------------------------------------┐ ┆ linear projection ┆
15
+ └──────────────────┘ | ┆ └─────────────────────┘
16
+ | | ┆ | \
17
+ ▼ | ▼ ▼ ▼
18
+ ╭───╮ ┌────────────────┐ ╭───╮ ┌──────────────────────────┐ | ┌─────────┐ ┌──────────────┐ ┌────────────────┐
19
+ ┆ + ┆ ◀── ┆ gating factor ┆-┆ x ┆-┆ prefix cross-attention ┆ | ┆ query ┆ ┆ prefix key ┆ ┆ prefix value ┆
20
+ ╰───╯ └────────────────┘ ╰───╯ └──────────────────────────┘ | └─────────┘ └──────────────┘ └────────────────┘
21
+ | | \ | /
22
+ ▼ | ▼ ▼ ▼
23
+ | ┌────────────────────────────────┐
24
+ | ┆ scaled dot-product attention ┆
25
+ | └────────────────────────────────┘
26
+
27
+
28
+ In order to inject learnable information from the prefix to pretrained weights we need to sum outputs from
29
+ self-attention and prefix cross-attention (times gating factor). For prefix cross-attention we need `query` (from
30
+ self-attention as a result of linear projection), `prefix key` and `prefix value` (from cross-attention as a result of
31
+ linear projection).
32
+ The output of prefix cross-attention is multiplied by gating factor, which is a learnable parameter that is needed to
33
+ avoid potential disruption of pretrained weights caused by incorporating randomly initialized tensors. This factor is
34
+ initialized with zeros to avoid noise from the adaption prompts at the early training stage.
35
+ More about it: https://lightning.ai/pages/community/article/understanding-llama-adapters/
36
+
37
+ Notes about implementation: as per paper adapter's prefix is concatenated with the input, while here outputs of
38
+ self-attention and prefix cross-attention are summed. Both variants are mathematically equivalent:
39
+ https://github.com/ZrrSkywalker/LLaMA-Adapter/issues/47
40
+ """
41
+ # mypy: ignore-errors
42
+ from dataclasses import dataclass
43
+ from typing import Optional, Tuple, List, Union
44
+
45
+ import torch
46
+ import torch.nn as nn
47
+ from torch.nn import functional as F
48
+
49
+ import lit_llama.model as llama
50
+ from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP, KVCache, RoPECache
51
+
52
+
53
+ @dataclass
54
+ class LLaMAConfig(llama.LLaMAConfig):
55
+ adapter_prompt_length: int = 10
56
+ adapter_start_layer: int = 2
57
+
58
+
59
+ class CausalSelfAttention(nn.Module):
60
+ """A modification of `lit_llama.model.CausalSelfAttention` that adds the attention
61
+ over the adaption prompt."""
62
+
63
+ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
64
+ super().__init__()
65
+ assert config.n_embd % config.n_head == 0
66
+
67
+ # key, query, value projections for all heads, but in a batch
68
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
69
+ # output projection
70
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
71
+
72
+ if block_idx >= config.adapter_start_layer:
73
+ # adapter embedding layer
74
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
75
+ # a learnable gating factor (to avoid potential disruption of pretrained weights) initialized with zeros (to
76
+ # avoid noise from adaption prompts at the early training stage)
77
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1, config.n_head, 1, 1))
78
+
79
+ self.n_head = config.n_head
80
+ self.n_embd = config.n_embd
81
+ self.block_size = config.block_size
82
+ self.block_idx = block_idx
83
+ self.adapter_prompt_length = config.adapter_prompt_length
84
+ self.adapter_start_layer = config.adapter_start_layer
85
+
86
+ def forward(
87
+ self,
88
+ x: torch.Tensor,
89
+ rope: RoPECache,
90
+ mask: torch.Tensor,
91
+ max_seq_length: int,
92
+ input_pos: Optional[torch.Tensor] = None,
93
+ kv_cache: Optional[KVCache] = None,
94
+ adapter_kv_cache: Optional[KVCache] = None,
95
+ ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
96
+ # notation:
97
+ # - B | batch
98
+ # - T | time-step (sequence length)
99
+ # - C | embeddings size (n_embd) = head size * num heads
100
+ # - hs | head size
101
+ # - nh | number of heads
102
+
103
+ B, T, C = x.size()
104
+
105
+ # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
106
+ # weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
107
+ # weights for q, k, v) and then split the result along `embedding size` dimension
108
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # (B, T, 3 * C) --> 3 * (B, T, C)
109
+
110
+ # in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
111
+ # embedding size (C) dimension into num_heads (nh) and head_size (hs)
112
+ head_size = C // self.n_head
113
+ k = k.view(B, T, self.n_head, head_size)
114
+ q = q.view(B, T, self.n_head, head_size)
115
+ v = v.view(B, T, self.n_head, head_size)
116
+
117
+ # "Unlike standard positional embeddings rotary embeddings must be applied at every layer"
118
+ q = apply_rope(q, rope) # (B, T, nh, hs)
119
+ k = apply_rope(k, rope) # (B, T, nh, hs)
120
+
121
+ # now `key`, 'query` and `value` tensors are correctly represented: for each element in a batch (B)
122
+ # there is a number of heads (nh) and for each head there is a sequence of elements (T), each of them is
123
+ # represented by a vector of size `hs`
124
+ k = k.transpose(1, 2) # (B, nh, T, hs)
125
+ q = q.transpose(1, 2) # (B, nh, T, hs)
126
+ v = v.transpose(1, 2) # (B, nh, T, hs)
127
+
128
+ if kv_cache is not None:
129
+ cache_k, cache_v = kv_cache # 2 * (B, nh, max_seq_length, hs)
130
+ # check if reached token limit
131
+ if input_pos[-1] >= max_seq_length:
132
+ # if we reached token limit and thus there is no space to put newly calculated `key` and `value`
133
+ # right next to cached ones, we need to rotate cache tensor along `max_seq_length` dimension by one
134
+ # element to the left: this will free up space for new `key` and `value`
135
+ input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
136
+ # shift 1 position to the left
137
+ cache_k = torch.roll(cache_k, -1, dims=2)
138
+ cache_v = torch.roll(cache_v, -1, dims=2)
139
+ k = cache_k.index_copy(2, input_pos, k) # (B, nh, max_seq_length, hs)
140
+ v = cache_v.index_copy(2, input_pos, v) # (B, nh, max_seq_length, hs)
141
+ kv_cache = k, v
142
+
143
+ # efficient attention using Flash Attention CUDA kernels
144
+ # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
145
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) # (B, nh, T, hs)
146
+
147
+ # "Adapters are applied to the topmost layers to better tune the language
148
+ # representations with higher-level semantics".
149
+ if self.block_idx >= self.adapter_start_layer:
150
+ if adapter_kv_cache is not None:
151
+ ak, av = adapter_kv_cache # 2 * (B, nh, aT, hs)
152
+ else:
153
+ prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
154
+ aT = prefix.size(1)
155
+ _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2) # (1, aT, 3 * C) --> 3 * (1, aT, C)
156
+ ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) # (B, nh, aT, hs)
157
+ av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) # (B, nh, aT, hs)
158
+ adapter_kv_cache = (ak, av)
159
+
160
+ # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
161
+ # obtained from self-attention step. This is mathematically equivalent to concatenation of prefix and input as per paper.
162
+ amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device) # (T, aT)
163
+ # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
164
+ ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False) # (B, nh, T, hs)
165
+ y = y + self.gating_factor * ay
166
+
167
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
168
+
169
+ # output projection
170
+ y = self.c_proj(y) # (B, T, C)
171
+
172
+ return y, kv_cache, adapter_kv_cache
173
+
174
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
175
+ """For backward compatibility with old checkpoints that have a single gating value for all heads."""
176
+ name = prefix + "gating_factor"
177
+ if name in state_dict:
178
+ tensor = state_dict[name]
179
+ # in case we are loading with `utils.lazy_load()`
180
+ tensor = tensor._load_tensor() if hasattr(tensor, "_load_tensor") else tensor
181
+
182
+ if len(tensor.shape) < 4:
183
+ # For old checkpoints with unified gating value
184
+ state_dict[name] = tensor.reshape(1, 1, 1, 1).repeat(1, self.n_head, 1, 1)
185
+ else:
186
+ state_dict[name] = tensor
187
+
188
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
189
+
190
+
191
+ class Block(nn.Module):
192
+ """The implementation is identical to `lit_llama.model.Block` with the exception that
193
+ we replace the attention layer where adaption is implemented."""
194
+
195
+ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
196
+ super().__init__()
197
+ self.rms_1 = RMSNorm(config.n_embd)
198
+ self.attn = CausalSelfAttention(config, block_idx)
199
+ self.rms_2 = RMSNorm(config.n_embd)
200
+ self.mlp = MLP(config)
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ rope: RoPECache,
206
+ mask: torch.Tensor,
207
+ max_seq_length: int,
208
+ input_pos: Optional[torch.Tensor] = None,
209
+ kv_cache: Optional[KVCache] = None,
210
+ adapter_kv_cache: Optional[KVCache] = None,
211
+ ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
212
+ h, new_kv_cache, new_adapter_kv_cache = self.attn(
213
+ self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache, adapter_kv_cache
214
+ )
215
+ x = x + h
216
+ x = x + self.mlp(self.rms_2(x))
217
+ return x, new_kv_cache, new_adapter_kv_cache
218
+
219
+
220
+ class LLaMA(llama.LLaMA):
221
+ """The implementation is identical to `lit_llama.model.LLaMA` with the exception that
222
+ the `Block` saves the layer index and passes it down to the attention layer."""
223
+
224
+ def __init__(self, config: LLaMAConfig) -> None:
225
+ nn.Module.__init__(self)
226
+ assert config.vocab_size is not None
227
+ assert config.block_size is not None
228
+ self.config = config
229
+
230
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
231
+ self.transformer = nn.ModuleDict(
232
+ dict(
233
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
234
+ h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
235
+ ln_f=RMSNorm(config.n_embd),
236
+ )
237
+ )
238
+
239
+ self.rope_cache: Optional[RoPECache] = None
240
+ self.mask_cache: Optional[torch.Tensor] = None
241
+ self.kv_caches: List[KVCache] = []
242
+ self.adapter_kv_caches: List[KVCache] = []
243
+
244
+ @classmethod
245
+ def from_name(cls, name: str):
246
+ return cls(LLaMAConfig.from_name(name))
247
+
248
+ def reset_cache(self) -> None:
249
+ super().reset_cache()
250
+ self.adapter_kv_caches.clear()
251
+
252
+ def forward(
253
+ self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
254
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
255
+ B, T = idx.size()
256
+
257
+ block_size = self.config.block_size
258
+ if max_seq_length is None:
259
+ max_seq_length = block_size
260
+ assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
261
+ assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
262
+ assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
263
+
264
+ if self.rope_cache is None:
265
+ self.rope_cache = self.build_rope_cache(idx) # (block_size, head_size / 2, 2)
266
+ if self.mask_cache is None:
267
+ self.mask_cache = self.build_mask_cache(idx) # (1, 1, block_size, block_size)
268
+
269
+ if input_pos is not None:
270
+ rope = self.rope_cache.index_select(0, input_pos)
271
+ mask = self.mask_cache.index_select(2, input_pos)
272
+ mask = mask[:, :, :, :max_seq_length]
273
+ else:
274
+ rope = self.rope_cache[:T]
275
+ mask = self.mask_cache[:, :, :T, :T]
276
+
277
+ # forward the model itself
278
+ x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
279
+
280
+ if input_pos is None: # proxy for use_cache=False
281
+ for block in self.transformer.h:
282
+ x, *_ = block(x, rope, mask, max_seq_length)
283
+ else:
284
+ if not self.kv_caches:
285
+ head_size = self.config.n_embd // self.config.n_head
286
+ cache_shape = (B, self.config.n_head, max_seq_length, head_size)
287
+ self.kv_caches = [
288
+ (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
289
+ for _ in range(self.config.n_layer)
290
+ ]
291
+ if not self.adapter_kv_caches:
292
+ self.adapter_kv_caches = [None for _ in range(self.config.n_layer)]
293
+ for i, block in enumerate(self.transformer.h):
294
+ x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
295
+ x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
296
+ )
297
+
298
+ x = self.transformer.ln_f(x) # (B, T, n_embd)
299
+
300
+ logits = self.lm_head(x) # (B, T, vocab_size)
301
+
302
+ return logits
303
+
304
+
305
+ def mark_only_adapter_as_trainable(model: LLaMA) -> None:
306
+ """Sets `requires_grad=False` for all non-adapter weights."""
307
+ for name, param in model.named_parameters():
308
+ param.requires_grad = "adapter_wte" in name or "gating_factor" in name
309
+
310
+
311
+ def adapter_state_from_state_dict(state_dict: dict) -> dict:
312
+ """Returns the model state dict with only the adapter weights for saving."""
313
+ return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name}
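
The key idea in adapter.py above is that the prefix cross-attention branch is blended in through a zero-initialized, learnable gate, so at the start of training the block behaves exactly like the pretrained self-attention. A minimal sketch of that blend (the tensor shapes are illustrative only, not taken from this commit):

import torch

B, nh, T, hs = 1, 32, 8, 128            # batch, heads, tokens, head size (illustrative)
y_self = torch.randn(B, nh, T, hs)      # output of the regular self-attention
y_prefix = torch.randn(B, nh, T, hs)    # output of the prefix cross-attention
gating_factor = torch.zeros(1, nh, 1, 1, requires_grad=True)  # learnable, zero-init

y = y_self + gating_factor * y_prefix   # at initialization this equals y_self exactly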
lit_llama/adapter_v2.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ from torch import Tensor
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from lit_llama.adapter import LLaMA
+
+
+ def get_adapter_substrings():
+     substrings = ["adapter_wte", "gating_factor"]  # regular adapter v1 parameters
+     substrings.extend(["adapter_scale", "adapter_bias"])  # adapter v2: new bias and scale used in Linear
+     substrings.extend(["rms_1", "rms_2", "ln_f"])  # adapter v2: RMSNorm parameters are now trainable
+     return substrings
+
+
+ def mark_only_adapter_v2_as_trainable(model: LLaMA) -> None:
+     """Sets `requires_grad=False` for all non-adapter weights."""
+     for name, param in model.named_parameters():
+         param.requires_grad = any(s in name for s in get_adapter_substrings())
+
+
+ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
+     """Returns the model state dict with only the adapter weights for saving."""
+     return {name: param for name, param in state_dict.items()
+             if any(s in name for s in get_adapter_substrings())}
+
+
+ def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
+     return self.adapter_scale * (
+         F.linear(input, self.weight, self.bias) + self.adapter_bias
+     )
+
+
+ def adapter_v2_linear_with_bias_and_scale(layer):
+     layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
+     layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
+     setattr(layer, 'forward', bound_method)
+     return layer
+
+
+ def add_adapter_v2_parameters_to_linear_layers(model):
+     for module in model.modules():
+         if isinstance(module, nn.Linear):
+             adapter_v2_linear_with_bias_and_scale(module)
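
A hedged usage sketch of the adapter v2 helpers above; the call order is an assumption based on the function names, not a recipe documented in this commit:

from lit_llama.adapter import LLaMA
from lit_llama.adapter_v2 import (
    add_adapter_v2_parameters_to_linear_layers,
    mark_only_adapter_v2_as_trainable,
)

model = LLaMA.from_name("7B")
add_adapter_v2_parameters_to_linear_layers(model)  # patch every nn.Linear with adapter_bias / adapter_scale
mark_only_adapter_v2_as_trainable(model)           # freeze everything except the adapter parameters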
lit_llama/lora.py ADDED
@@ -0,0 +1,476 @@
1
+ # Derived from https://github.com/microsoft/LoRA
2
+ # ------------------------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
5
+ # ------------------------------------------------------------------------------------------
6
+
7
+ r"""
8
+ Low Ranking Adaptation for LLMs scheme.
9
+
10
+ ┌───────────────────┐
11
+ ┆ h ┆
12
+ └───────────────────┘
13
+
14
+ |
15
+ +
16
+ / \
17
+ ┌─────────────────┐ ╭───────────────╮ Matrix initialization:
18
+ ┆ ┆ \ B / B = 0
19
+ ┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
20
+ ┆ weights ┆ ╰─────────╯
21
+ ┆ ┆ | r | r - rank
22
+ ┆ W e R^(d*d) ┆ | ◀─────▶ |
23
+ ┆ ┆ ╭─────────╮
24
+ └─────────────────┘ / A \
25
+ ▲ / d*r \
26
+ \ ╰───────────────╯
27
+ \ ▲
28
+ \ /
29
+ \ /
30
+ ┌───────────────────┐
31
+ ┆ x ┆
32
+ └───────────────────┘
33
+
34
+ With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,
35
+ we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates
36
+ for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of
37
+ course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen
38
+ pretrained weights and thus fine-tune the model.
39
+
40
+ The goal of this approach is to move weight updates into a separate matrix which is decomposed with
41
+ two matrices of a lower rank.
42
+ """
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+
48
+ import math
49
+ from typing import Dict, List
50
+
51
+ import lit_llama.model as llama
52
+
53
+ from contextlib import contextmanager
54
+ from dataclasses import dataclass
55
+
56
+
57
+ class LoRALayer():
58
+ def __init__(
59
+ self,
60
+ r: int,
61
+ lora_alpha: int,
62
+ lora_dropout: float,
63
+ merge_weights: bool,
64
+ ):
65
+ """Store LoRA specific attributes in a class.
66
+
67
+ Args:
68
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
69
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
70
+ lora_alpha: alpha is needed for scaling updates as alpha/r
71
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
72
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
73
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
74
+ merge_weights: whether we want to merge pretrained weights and LoRA weight updates. This is useful if one wants to use
75
+ fine-tuned model as a standalone one (without storing LoRA weights separately) plus it helps to reduce
76
+ overhead during inference.
77
+ """
78
+ self.r = r
79
+ self.lora_alpha = lora_alpha
80
+ # Optional dropout
81
+ if lora_dropout > 0.:
82
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
83
+ else:
84
+ self.lora_dropout = lambda x: x
85
+ # Mark the weight as unmerged
86
+ self.merged = False
87
+ self.merge_weights = merge_weights
88
+
89
+
90
+ class MergedLinear(nn.Linear, LoRALayer):
91
+ # LoRA implemented in a dense layer
92
+ def __init__(
93
+ self,
94
+ # ↓ this part is for pretrained weights
95
+ in_features: int,
96
+ out_features: int,
97
+ # ↓ the remaining part is for LoRA
98
+ r: int = 0,
99
+ lora_alpha: int = 1,
100
+ lora_dropout: float = 0.,
101
+ enable_lora: List[bool] = [False],
102
+ fan_in_fan_out: bool = False,
103
+ merge_weights: bool = True,
104
+ **kwargs
105
+ ):
106
+ """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
107
+
108
+ This class has three weight matrices:
109
+ 1. Pretrained weights are stored as `self.weight` (because of the nn.Linear inheritance)
110
+ 2. LoRA A matrix as `self.lora_A`
111
+ 3. LoRA B matrix as `self.lora_B`
112
+ Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
113
+
114
+ Args:
115
+ in_features: number of input features of the pretrained weights
116
+ out_features: number of output features of the pretrained weights
117
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
118
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
119
+ lora_alpha: alpha is needed for scaling updates as alpha/r
120
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
121
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
122
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
123
+ enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
124
+ don't want to apply LoRA for all three (query, key and value) we can set it as False. For example if we want
125
+ to apply LoRA only to `query` and `value` but keep `key` without weight updates we should pass `[True,
126
+ False, True]`
127
+ fan_in_fan_out: set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
128
+ `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`
129
+ https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#LL53C9-L53C112
130
+ merge_weights: whether we want to merge pretrained weights and LoRA weight updates. This is useful if one wants to use
131
+ fine-tuned model as a standalone one (without storing LoRA weight separately) plus it helps to reduce
132
+ overhead during inference.
133
+ """
134
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
135
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
136
+ merge_weights=merge_weights)
137
+ assert out_features % len(enable_lora) == 0, \
138
+ 'The length of enable_lora must divide out_features'
139
+ self.enable_lora = enable_lora
140
+ self.fan_in_fan_out = fan_in_fan_out
141
+
142
+ # Actual trainable parameters
143
+ # To better understand initialization let's imagine that we have such parameters:
144
+ # ⚬ in_features: 128 (embeddings_size)
145
+ # ⚬ out_features: 384 (3 * embedding_size)
146
+ # ⚬ r: 2
147
+ # ⚬ enable_lora: [True, False, True]
148
+ if r > 0 and any(enable_lora):
149
+ self.lora_A = nn.Parameter(
150
+ self.weight.new_zeros((r * sum(enable_lora), in_features))) # (4, 128)
151
+ self.lora_B = nn.Parameter(
152
+ self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) # (256, 2)
153
+ ) # weights for Conv1D with groups=sum(enable_lora)
154
+ # Notes about shapes above
155
+ # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
156
+ # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
157
+ # F.linear function weights are automatically transposed. In addition conv1d requires channels to
158
+ # be before seq length
159
+ # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
160
+ # 128*2; 2 tells to have two channels per group for group convolution
161
+
162
+ # Scaling:
163
+ # This balances the pretrained model`s knowledge and the new task-specific adaptation
164
+ # https://lightning.ai/pages/community/tutorial/lora-llm/
165
+ # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
166
+ # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
167
+ # tune these values to your needs. This value can be even slightly greater than 1.0!
168
+ # https://github.com/cloneofsimo/lora
169
+ self.scaling = self.lora_alpha / self.r
170
+
171
+ # Freezing the pre-trained weight matrix
172
+ self.weight.requires_grad = False # (384, 128)
173
+
174
+ # Compute the indices
175
+ # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
176
+ # but not keys, then the weights update should be:
177
+ #
178
+ # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
179
+ # [....................................],
180
+ # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
181
+ # ↑ ↑ ↑
182
+ # ________________________________________
183
+ # | query | key | value |
184
+ # ----------------------------------------
185
+ self.lora_ind = self.weight.new_zeros(
186
+ (out_features, ), dtype=torch.bool
187
+ ).view(len(enable_lora), -1) # (3, 128)
188
+ self.lora_ind[enable_lora, :] = True # (3, 128)
189
+ self.lora_ind = self.lora_ind.view(-1) # (384,)
190
+ self.reset_parameters()
191
+ if fan_in_fan_out:
192
+ self.weight.data = self.weight.data.T
193
+
194
+ def reset_parameters(self):
195
+ """Reset all the weights, even including pretrained ones."""
196
+ nn.Linear.reset_parameters(self)
197
+ if hasattr(self, 'lora_A'):
198
+ # initialize A the same way as the default for nn.Linear and B to zero
199
+ # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
200
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
201
+ nn.init.zeros_(self.lora_B)
202
+
203
+ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
204
+ """Properly pad weight updates with zeros.
205
+
206
+ If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
207
+ then the weights update should be:
208
+
209
+ [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
210
+ [....................................],
211
+ [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
212
+ ↑ ↑ ↑
213
+ ________________________________________
214
+ | query | key | value |
215
+ ----------------------------------------
216
+
217
+ Args:
218
+ x: tensor with weights update that will be padded with zeros if necessary
219
+
220
+ Returns:
221
+ A tensor with weight updates and zeros for deselected q, k or v
222
+ """
223
+ # Let's image that:
224
+ # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
225
+ # ⚬ embeddings_size: 128
226
+ # ⚬ self.out_features: 384 (3 * embeddings_size)
227
+ # ⚬ enable_lora: [True, False, True]
228
+ # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
229
+ # embeddings_size is 384 (self.out_features), so that means that we need to pad from 256 to 384 with zeros, but
230
+ # only for key updates (this is where self.lora_ind comes in handy)
231
+ # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
232
+ # for example when we want to merge/unmerge LoRA weights and pretrained weights
233
+ x = x.transpose(0, 1)
234
+ result = x.new_zeros((*x.shape[:-1], self.out_features)) # (64, 64, 384)
235
+ result = result.view(-1, self.out_features) # (4096, 384)
236
+ result[:, self.lora_ind] = x.reshape(
237
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
238
+ ) # (4096, 256)
239
+ return result.view((*x.shape[:-1], self.out_features)).transpose(0, 1) # (64, 64, 384)
240
+
241
+ def train(self, mode: bool = True):
242
+ """Set the module into train or eval mode if `mode` is True of False respectively.
243
+
244
+ For train mode (train(True)) if weights are merged we need to subtract weights updates (LoRA_A @ LoRA_B) from
245
+ pretrained weights so we can continue training LoRA's matrices A and B and keep pretrained weights frozen.
246
+
247
+ For eval mode (train(False)) if weights are not merged we need to add weight updates to pretrained weights in
248
+ order to reduce computational overhead during inference.
249
+
250
+ Args:
251
+ mode: if True the module will be set into train mode (affects Dropout and BatchNorm), if False - eval mode.
252
+
253
+ """
254
+ def T(w):
255
+ return w.T if self.fan_in_fan_out else w
256
+ # despite being called from nn.Linear this method will put all layers into train mode, including nn.Dropout
257
+ # of course except parameters (such as self.lora_A, self.lora_B)
258
+ nn.Linear.train(self, mode)
259
+
260
+ # if train(True) -> unmerge unless we already have them unmerged
261
+ # if train(False) -> merge unless we already have them merged
262
+ should = self.merged if mode else not self.merged
263
+
264
+ # Let's assume that:
265
+ # ⚬ self.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
266
+ # ⚬ self.lora_A.data: (4, 128)
267
+ # ⚬ self.lora_B.data: (256, 2)
268
+ if self.merge_weights and should:
269
+ if self.r > 0 and any(self.enable_lora):
270
+ delta_w = F.conv1d(
271
+ self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
272
+ self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
273
+ groups=sum(self.enable_lora)
274
+ ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
275
+ # -1: W = W - delta_W (unmerge), +1: W = W + delta_W (merge)
276
+ sign = -1 if mode else 1
277
+ self.weight.data += sign * self.zero_pad(T(delta_w * self.scaling)) # (256, 128) after zero_pad (384, 128)
278
+ self.merged = not mode
279
+
280
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
281
+ """Do the forward pass.
282
+
283
+ If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
284
+ If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
285
+
286
+ Args:
287
+ x: input tensor of shape (batch_size, context_length, embedding_size)
288
+
289
+ Returns:
290
+ Output tensor of shape (batch_size, context_length, 3 * embedding_size)
291
+ """
292
+ def T(w):
293
+ return w.T if self.fan_in_fan_out else w
294
+
295
+ # Let's assume that:
296
+ # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
297
+ # ⚬ self.weight: (384, 128) or (3 * embedding_size, embedding_size)
298
+ # ⚬ self.lora_A.data: (4, 128)
299
+ # ⚬ self.lora_B.data: (256, 2)
300
+
301
+ # the logic here is that the weights are merged only during inference
302
+ # so if they are merged we don't need to do anything with LoRA's A and B matrices
303
+ # but if the weights are not merged that means that the forward method is called during
304
+ # training and we need to forward pass input through pretrained weights, LoRA A and B matrices
305
+ # and do the summation (as per scheme at the top of the file)
306
+ if self.merged:
307
+ return F.linear(x, T(self.weight), bias=self.bias)
308
+ else:
309
+ # `F.linear` automatically transposes the second argument (T(self.weight) in our case)
310
+ result = F.linear(x, T(self.weight), bias=self.bias) # (64, 64, 128) @ (384, 128) -> (64, 64, 384)
311
+ if self.r > 0:
312
+ after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
313
+ # For F.conv1d:
314
+ # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
315
+ # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
316
+ # ⚬ groups: split input into groups, in_channels should be divisible by the number of groups. Default: 1
317
+ # presumably iW - sequence width/length, kW - kernel width
318
+ after_B = F.conv1d(
319
+ after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
320
+ self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
321
+ groups=sum(self.enable_lora)
322
+ ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
323
+ result += self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
324
+ return result
325
+
326
+
327
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
328
+ """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
329
+
330
+ Args:
331
+ model: model with LoRA layers
332
+ bias:
333
+ ``"none"``: all bias weights will be frozen,
334
+ ``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
335
+ ``"all"``: all bias weights will be unfrozen.
336
+
337
+ Raises:
338
+ NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
339
+ """
340
+ # freeze all layers except LoRA's
341
+ for n, p in model.named_parameters():
342
+ if 'lora_' not in n:
343
+ p.requires_grad = False
344
+
345
+ # depending on the `bias` value unfreeze bias weights
346
+ if bias == 'none':
347
+ return
348
+ elif bias == 'all':
349
+ for n, p in model.named_parameters():
350
+ if 'bias' in n:
351
+ p.requires_grad = True
352
+ elif bias == 'lora_only':
353
+ for m in model.modules():
354
+ if isinstance(m, LoRALayer) and \
355
+ hasattr(m, 'bias') and \
356
+ m.bias is not None:
357
+ m.bias.requires_grad = True
358
+ else:
359
+ raise NotImplementedError
360
+
361
+
362
+ def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
363
+ """Return state_dict with weights of LoRA's A and B matrices and with biases depending on the `bias` value.
364
+
365
+ Args:
366
+ model: model with LoRA layers
367
+ bias:
368
+ ``"none"``: state dict will not store bias weights,
369
+ ``"lora_only"``: state dict will store bias weights only from LoRA layers,
370
+ ``"all"``: state dict will store all bias weights.
371
+
372
+ Returns:
373
+ Weights and biases of LoRA layers
374
+
375
+ Raises:
376
+ NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
377
+ """
378
+ my_state_dict = model.state_dict()
379
+ if bias == 'none':
380
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
381
+ elif bias == 'all':
382
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
383
+ elif bias == 'lora_only':
384
+ to_return = {}
385
+ for k in my_state_dict:
386
+ if 'lora_' in k:
387
+ to_return[k] = my_state_dict[k]
388
+ bias_name = k.split('lora_')[0]+'bias'
389
+ if bias_name in my_state_dict:
390
+ to_return[bias_name] = my_state_dict[bias_name]
391
+ return to_return
392
+ else:
393
+ raise NotImplementedError
394
+
395
+
396
+ @dataclass
397
+ class LoRAConfig:
398
+ r: float = 0.0
399
+ alpha: float = 1.0
400
+ dropout: float = 0.0
401
+
402
+
403
+ class CausalSelfAttention(llama.CausalSelfAttention):
404
+ lora_config = None
405
+
406
+ def __init__(self, config: llama.LLaMAConfig) -> None:
407
+ """Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for
408
+ parameter-efficient fine-tuning.
409
+
410
+ *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
411
+ query, key and value for each head) we can do this in a single pass with a single weight matrix.
412
+
413
+ Args:
414
+ config:
415
+ ``"block_size"``: size of the context of the model,
416
+ ``"vocab_size"``: number of unique tokens,
417
+ ``"padded_vocab_size"``: padded size of the vocabulary to the nearest multiple of 64 (leads to a greater performance),
418
+ ``"n_layer"``: number of transformer blocks (self-attention + MLP),
419
+ ``"n_head"``: number of heads in multi-head attention mechanism,
420
+ ``"n_embd"``: size of the embedding: vector representation of each token.
421
+ """
422
+ # Skip the parent class __init__ altogether and replace it to avoid
423
+ # useless allocations
424
+ nn.Module.__init__(self)
425
+ assert config.n_embd % config.n_head == 0
426
+
427
+ # key, query, value projections for all heads, but in a batch
428
+ self.c_attn = MergedLinear(
429
+ in_features=config.n_embd,
430
+ out_features=3 * config.n_embd,
431
+ r=self.lora_config.r,
432
+ lora_alpha=self.lora_config.alpha,
433
+ lora_dropout=self.lora_config.dropout,
434
+ enable_lora=[True, False, True],
435
+ fan_in_fan_out = False,
436
+ merge_weights=True,
437
+ bias=False)
438
+ # output projection
439
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
440
+ # regularization
441
+ self.n_head = config.n_head
442
+ self.n_embd = config.n_embd
443
+ self.block_size = config.block_size
444
+ self.rope_cache = None
445
+
446
+
447
+ @contextmanager
448
+ def lora(r, alpha, dropout, enabled: bool = True):
449
+ """Apply context manager under which you can instantiate the model with LoRA.
450
+
451
+ In a nutshell the code inside this function forces to use LoRA variant of causal self-attention
452
+ instead of the original one (without LoRA).
453
+
454
+ Args:
455
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
456
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
457
+ alpha: alpha is needed for scaling updates as alpha/r
458
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
459
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
460
+ dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
461
+ enabled: enables/disables LoRA
462
+ """
463
+ if not enabled:
464
+ yield
465
+ return
466
+
467
+ CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
468
+ # when entering context manager replace link to causal self-attention class from original
469
+ # to a variant with LoRA
470
+ causal_self_attention = llama.CausalSelfAttention
471
+ llama.CausalSelfAttention = CausalSelfAttention
472
+ yield
473
+ # when exiting context manager - restore link to original causal self-attention class
474
+ llama.CausalSelfAttention = causal_self_attention
475
+
476
+ CausalSelfAttention.lora_config = None
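
The weight-update math that `MergedLinear` above implements reduces to a low-rank product scaled by alpha/r and added to the frozen pretrained weight. A minimal sketch with plain tensors (the sizes are illustrative only):

import torch

d_in, d_out, r, alpha = 128, 128, 2, 1.0    # illustrative sizes
W = torch.randn(d_out, d_in)                # frozen pretrained weight
lora_A = torch.randn(r, d_in) * 0.01        # trainable, small random init
lora_B = torch.zeros(d_out, r)              # trainable, zero init -> no change at start

scaling = alpha / r
W_merged = W + scaling * (lora_B @ lora_A)  # what "merging" the LoRA update means
# Because lora_B starts at zero, W_merged == W at initialization, so training
# begins from the unmodified pretrained behavior.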
lit_llama/model.py ADDED
@@ -0,0 +1,321 @@
1
+ """Full definition of a LLaMA Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
4
+ """
5
+ # mypy: ignore-errors
6
+ import math
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+ from typing_extensions import Self
14
+
15
+ from lit_llama.utils import find_multiple
16
+
17
+
18
+ MaskCache = torch.Tensor
19
+ RoPECache = torch.Tensor
20
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
21
+
22
+
23
+ @dataclass
24
+ class LLaMAConfig:
25
+ block_size: int = 2048
26
+ vocab_size: int = 32000
27
+ padded_vocab_size: Optional[int] = None
28
+ n_layer: int = 32
29
+ n_head: int = 32
30
+ n_embd: int = 4096
31
+
32
+ def __post_init__(self):
33
+ if self.padded_vocab_size is None:
34
+ self.padded_vocab_size = find_multiple(self.vocab_size, 64)
35
+
36
+ @classmethod
37
+ def from_name(cls, name: str) -> Self:
38
+ return cls(**llama_configs[name])
39
+
40
+
41
+ llama_configs = {
42
+ "7B": dict(n_layer=32, n_head=32, n_embd=4096),
43
+ "13B": dict(n_layer=40, n_head=40, n_embd=5120),
44
+ "30B": dict(n_layer=60, n_head=52, n_embd=6656),
45
+ "65B": dict(n_layer=80, n_head=64, n_embd=8192),
46
+ }
47
+
48
+
49
+ class LLaMA(nn.Module):
50
+ def __init__(self, config: LLaMAConfig) -> None:
51
+ super().__init__()
52
+ assert config.padded_vocab_size is not None
53
+ self.config = config
54
+
55
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
56
+ self.transformer = nn.ModuleDict(
57
+ dict(
58
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
59
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
60
+ ln_f=RMSNorm(config.n_embd),
61
+ )
62
+ )
63
+
64
+ self.rope_cache: Optional[RoPECache] = None
65
+ self.mask_cache: Optional[MaskCache] = None
66
+ self.kv_caches: List[KVCache] = []
67
+
68
+ def _init_weights(self, module: nn.Module) -> None:
69
+ if isinstance(module, nn.Linear):
70
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
71
+ elif isinstance(module, nn.Embedding):
72
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
73
+
74
+ def forward(
75
+ self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
76
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
77
+ B, T = idx.size()
78
+
79
+ block_size = self.config.block_size
80
+ if max_seq_length is None:
81
+ max_seq_length = block_size
82
+ assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
83
+ assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
84
+ assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
85
+
86
+ if self.rope_cache is None:
87
+ self.rope_cache = self.build_rope_cache(idx)
88
+ if self.mask_cache is None:
89
+ self.mask_cache = self.build_mask_cache(idx)
90
+
91
+ if input_pos is not None:
92
+ rope = self.rope_cache.index_select(0, input_pos)
93
+ mask = self.mask_cache.index_select(2, input_pos)
94
+ mask = mask[:, :, :, :max_seq_length]
95
+ else:
96
+ rope = self.rope_cache[:T]
97
+ mask = self.mask_cache[:, :, :T, :T]
98
+
99
+ # forward the model itself
100
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
101
+
102
+ if input_pos is None: # proxy for use_cache=False
103
+ for block in self.transformer.h:
104
+ x, _ = block(x, rope, mask, max_seq_length)
105
+ else:
106
+ if not self.kv_caches:
107
+ head_size = self.config.n_embd // self.config.n_head
108
+ cache_shape = (B, self.config.n_head, max_seq_length, head_size)
109
+ self.kv_caches = [
110
+ (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
111
+ for _ in range(self.config.n_layer)
112
+ ]
113
+ for i, block in enumerate(self.transformer.h):
114
+ x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
115
+
116
+ x = self.transformer.ln_f(x)
117
+
118
+ logits = self.lm_head(x) # (b, t, vocab_size)
119
+
120
+ return logits
121
+
122
+ @classmethod
123
+ def from_name(cls, name: str) -> Self:
124
+ return cls(LLaMAConfig.from_name(name))
125
+
126
+ def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
127
+ return build_rope_cache(
128
+ seq_len=self.config.block_size,
129
+ n_elem=self.config.n_embd // self.config.n_head,
130
+ dtype=idx.dtype,
131
+ device=idx.device,
132
+ )
133
+
134
+ def build_mask_cache(self, idx: torch.Tensor) -> MaskCache:
135
+ ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
136
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
137
+
138
+ def reset_cache(self) -> None:
139
+ self.kv_caches.clear()
140
+ if self.mask_cache.device.type == "xla":
141
+ # https://github.com/Lightning-AI/lit-parrot/pull/83#issuecomment-1558150179
142
+ self.rope_cache = None
143
+ self.mask_cache = None
144
+
145
+
146
+ class Block(nn.Module):
147
+ def __init__(self, config: LLaMAConfig) -> None:
148
+ super().__init__()
149
+ self.rms_1 = RMSNorm(config.n_embd)
150
+ self.attn = CausalSelfAttention(config)
151
+ self.rms_2 = RMSNorm(config.n_embd)
152
+ self.mlp = MLP(config)
153
+
154
+ def forward(
155
+ self,
156
+ x: torch.Tensor,
157
+ rope: RoPECache,
158
+ mask: MaskCache,
159
+ max_seq_length: int,
160
+ input_pos: Optional[torch.Tensor] = None,
161
+ kv_cache: Optional[KVCache] = None,
162
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
163
+ h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
164
+ x = x + h
165
+ x = x + self.mlp(self.rms_2(x))
166
+ return x, new_kv_cache
167
+
168
+
169
+ class CausalSelfAttention(nn.Module):
170
+ def __init__(self, config: LLaMAConfig) -> None:
171
+ super().__init__()
172
+ assert config.n_embd % config.n_head == 0
173
+
174
+ # key, query, value projections for all heads, but in a batch
175
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
176
+ # output projection
177
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
178
+
179
+ self.n_head = config.n_head
180
+ self.n_embd = config.n_embd
181
+ self.block_size = config.block_size
182
+
183
+ def forward(
184
+ self,
185
+ x: torch.Tensor,
186
+ rope: RoPECache,
187
+ mask: MaskCache,
188
+ max_seq_length: int,
189
+ input_pos: Optional[torch.Tensor] = None,
190
+ kv_cache: Optional[KVCache] = None,
191
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
192
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
193
+
194
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
195
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
196
+
197
+ head_size = C // self.n_head
198
+ k = k.view(B, T, self.n_head, head_size)
199
+ q = q.view(B, T, self.n_head, head_size)
200
+ v = v.view(B, T, self.n_head, head_size)
201
+
202
+ q = apply_rope(q, rope)
203
+ k = apply_rope(k, rope)
204
+
205
+ k = k.transpose(1, 2) # (B, nh, T, hs)
206
+ q = q.transpose(1, 2) # (B, nh, T, hs)
207
+ v = v.transpose(1, 2) # (B, nh, T, hs)
208
+
209
+ if kv_cache is not None:
210
+ cache_k, cache_v = kv_cache
211
+ # check if reached token limit
212
+ if input_pos[-1] >= max_seq_length:
213
+ input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
214
+ # shift 1 position to the left
215
+ cache_k = torch.roll(cache_k, -1, dims=2)
216
+ cache_v = torch.roll(cache_v, -1, dims=2)
217
+ k = cache_k.index_copy(2, input_pos, k)
218
+ v = cache_v.index_copy(2, input_pos, v)
219
+ kv_cache = k, v
220
+
221
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
222
+ # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
223
+ # att = att.masked_fill(mask[:,:,:T,:T] == 0, float('-inf'))
224
+ # att = F.softmax(att, dim=-1)
225
+ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
226
+
227
+ # efficient attention using Flash Attention CUDA kernels
228
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
229
+
230
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
231
+
232
+ # output projection
233
+ y = self.c_proj(y)
234
+
235
+ return y, kv_cache
236
+
237
+
238
+ class MLP(nn.Module):
239
+ def __init__(self, config: LLaMAConfig) -> None:
240
+ super().__init__()
241
+ hidden_dim = 4 * config.n_embd
242
+ n_hidden = int(2 * hidden_dim / 3)
243
+ n_hidden = find_multiple(n_hidden, 256)
244
+
245
+ self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
246
+ self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
247
+ self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
248
+
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
251
+ x = self.c_proj(x)
252
+ return x
253
+
254
+
255
+ class RMSNorm(nn.Module):
256
+ """Root Mean Square Layer Normalization.
257
+
258
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
259
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
260
+ """
261
+
262
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
263
+ super().__init__()
264
+ self.scale = nn.Parameter(torch.ones(size))
265
+ self.eps = eps
266
+ self.dim = dim
267
+
268
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
269
+ # NOTE: the original RMSNorm paper implementation is not equivalent
270
+ # norm_x = x.norm(2, dim=self.dim, keepdim=True)
271
+ # rms_x = norm_x * d_x ** (-1. / 2)
272
+ # x_normed = x / (rms_x + self.eps)
273
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
274
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
275
+ return self.scale * x_normed
276
+
277
+
278
+ def build_rope_cache(
279
+ seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
280
+ ) -> RoPECache:
281
+ """Enhanced Transformer with Rotary Position Embedding.
282
+
283
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
284
+ transformers/rope/__init__.py. MIT License:
285
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
286
+ """
287
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
288
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
289
+
290
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
291
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
292
+
293
+ # Calculate the product of position index and $\theta_i$
294
+ idx_theta = torch.outer(seq_idx, theta).float()
295
+
296
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
297
+
298
+ # this is to mimic the behaviour of complex32, else we will get different results
299
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
300
+ cache = cache.half()
301
+ return cache
302
+
303
+
304
+ def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:
305
+ # truncate to support variable sizes
306
+ T = x.size(1)
307
+ rope_cache = rope_cache[:T]
308
+
309
+ # cast because the reference does
310
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
311
+ rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
312
+ x_out2 = torch.stack(
313
+ [
314
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
315
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
316
+ ],
317
+ -1,
318
+ )
319
+
320
+ x_out2 = x_out2.flatten(3)
321
+ return x_out2.type_as(x)
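A minimal standalone sketch of the rotary-embedding helpers above (not part of the commit; it assumes they live in lit_llama/model.py and uses small hypothetical shapes where n_elem equals the per-head size):
# Sketch only: exercise build_rope_cache/apply_rope on a tiny (B, T, n_head, head_size) tensor.
import torch
from lit_llama.model import build_rope_cache, apply_rope  # assumed module path

B, T, n_head, head_size = 1, 4, 2, 8
q = torch.randn(B, T, n_head, head_size)
rope = build_rope_cache(seq_len=T, n_elem=head_size, dtype=q.dtype, device=q.device)
q_rot = apply_rope(q, rope)   # rotates each (cos, sin) pair; shape is preserved
assert q_rot.shape == q.shape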
lit_llama/packed_dataset.py ADDED
@@ -0,0 +1,260 @@
1
+ # Very loosely inspired by indexed_dataset in Fairseq, Megatron
2
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
3
+
4
+
5
+ import os
6
+ import struct
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import IterableDataset, get_worker_info
12
+
13
+
14
+ dtypes = {
15
+ 1: np.uint8,
16
+ 2: np.int8,
17
+ 3: np.int16,
18
+ 4: np.int32,
19
+ 5: np.int64,
20
+ 6: np.float32,
21
+ 7: np.float64,
22
+ 8: np.uint16,
23
+ }
24
+
25
+
26
+ def code(dtype):
27
+ for k in dtypes.keys():
28
+ if dtypes[k] == dtype:
29
+ return k
30
+ raise ValueError(dtype)
31
+
32
+
33
+ HDR_MAGIC = b"LITPKDS"
34
+ HDR_SIZE = 24 # bytes
35
+
36
+
37
+ class PackedDataset(IterableDataset):
38
+ def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0):
39
+ self._filenames = filenames
40
+ self._n_chunks = n_chunks
41
+ self._block_size = block_size
42
+ self._seed = seed
43
+ self._shuffle = shuffle
44
+ self._wrap = wrap
45
+ self._num_processes = num_processes
46
+ self._process_rank = process_rank
47
+
48
+ def __iter__(self):
49
+ worker_info = get_worker_info()
50
+ num_workers = worker_info.num_workers if worker_info is not None else 1
51
+ worker_id = worker_info.id if worker_info is not None else 0
52
+ num_shards = num_workers * self._num_processes
53
+ shard_id = self._process_rank * num_workers + worker_id
54
+
55
+ max_num_files = len(self._filenames) // num_shards * num_shards
56
+ filenames = self._filenames[shard_id : max_num_files : num_shards]
57
+
58
+ return PackedDatasetIterator(
59
+ filenames=filenames,
60
+ n_chunks=self._n_chunks,
61
+ block_size=self._block_size,
62
+ seed=self._seed,
63
+ shuffle=self._shuffle,
64
+ wrap=self._wrap,
65
+ )
66
+
67
+
68
+ class PackedDatasetBuilder(object):
69
+ def __init__(
70
+ self,
71
+ outdir,
72
+ prefix,
73
+ chunk_size,
74
+ sep_token,
75
+ dtype="auto",
76
+ vocab_size=None,
77
+ ):
78
+ if dtype == "auto":
79
+ if vocab_size is None:
80
+ raise ValueError("vocab_size cannot be None when dtype='auto'")
81
+ if vocab_size is not None and vocab_size < 65500:
82
+ self._dtype = np.uint16
83
+ else:
84
+ self._dtype = np.int32
85
+ else:
86
+ self._dtype = dtype
87
+ self._counter = 0
88
+ self._chunk_size = chunk_size
89
+ self._outdir = outdir
90
+ self._prefix = prefix
91
+ self._sep_token = sep_token
92
+ self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
93
+ self._arr.fill(self._sep_token)
94
+ self._idx = 0
95
+ self._version = 1
96
+ self._filenames = []
97
+
98
+ def _write_chunk(self):
99
+ filename = f"{self._prefix}_{self._counter:010d}.bin"
100
+ filename = os.path.join(self._outdir, filename)
101
+
102
+ with open(filename, "wb") as f:
103
+ f.write(HDR_MAGIC)
104
+ f.write(struct.pack("<Q", self._version))
105
+ f.write(struct.pack("<B", code(self._dtype)))
106
+ f.write(struct.pack("<Q", self._chunk_size))
107
+ f.write(self._arr.tobytes(order="C"))
108
+
109
+ self._filenames.append(filename)
110
+ self._counter += 1
111
+ self._arr.fill(self._sep_token)
112
+ self._idx = 0
113
+
114
+ @property
115
+ def dtype(self):
116
+ return self._dtype
117
+
118
+ @property
119
+ def filenames(self):
120
+ return self._filenames.copy()
121
+
122
+ def add_array(self, arr):
123
+ while self._idx + arr.shape[0] > self._chunk_size:
124
+ part_len = self._chunk_size - self._idx
125
+ self._arr[self._idx : self._idx + part_len] = arr[:part_len]
126
+ self._write_chunk()
127
+ arr = arr[part_len:]
128
+
129
+ arr_len = arr.shape[0]
130
+ self._arr[self._idx : self._idx + arr_len] = arr
131
+ self._idx += arr_len
132
+
133
+ def write_reminder(self):
134
+ self._write_chunk()
135
+
136
+
137
+ class PackedDatasetIterator:
138
+ def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
139
+ self._seed = seed
140
+ self._shuffle = shuffle
141
+ self._rng = np.random.default_rng(seed) if shuffle else None
142
+ self._block_idxs = None
143
+
144
+ self._wrap = wrap
145
+
146
+ # TODO: instead of filenames, we could have a single text stream
147
+ # (or text file) with the sequence of all files to be
148
+ # fetched/loaded.
149
+ self._filenames = filenames
150
+ self._file_idx = 0
151
+
152
+ self._n_chunks = n_chunks
153
+
154
+ self._dtype = None
155
+ self._block_size = block_size
156
+ self._n_blocks = None
157
+
158
+ self._mmaps = []
159
+ self._buffers = []
160
+
161
+ self._block_idxs = []
162
+ self._curr_idx = 0
163
+
164
+ self._load_n_chunks()
165
+
166
+ def _read_header(self, path):
167
+ with open(path, "rb") as f:
168
+ magic = f.read(len(HDR_MAGIC))
169
+ assert magic == HDR_MAGIC, "File doesn't match expected format."
170
+ version = struct.unpack("<Q", f.read(8))
171
+ assert (1,) == version
172
+ (dtype_code,) = struct.unpack("<B", f.read(1))
173
+ dtype = dtypes[dtype_code]
174
+ (chunk_size,) = struct.unpack("<Q", f.read(8))
175
+ return dtype, chunk_size
176
+
177
+ def _close_mmaps(self):
178
+ for mmap in self._mmaps:
179
+ mmap._mmap.close()
180
+
181
+ def _load_n_chunks(self):
182
+ self._close_mmaps()
183
+ self._mmaps = []
184
+ self._buffers = []
185
+
186
+ if self._n_chunks > len(self._filenames[self._file_idx:]):
187
+ if not self._wrap:
188
+ raise StopIteration
189
+ else:
190
+ self._file_idx = 0
191
+
192
+ for i in range(self._n_chunks):
193
+ filename = self._filenames[self._file_idx + i]
194
+ if self._dtype is None:
195
+ self._dtype, self._chunk_size = self._read_header(
196
+ filename
197
+ )
198
+ self._n_blocks = self._chunk_size // self._block_size
199
+ # TODO: check header matches with previous files
200
+ mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
201
+ self._mmaps.append(mmap)
202
+ self._buffers.append(memoryview(mmap))
203
+
204
+ self._file_idx += self._n_chunks
205
+ n_all_blocks = self._n_chunks * self._n_blocks
206
+
207
+ self._block_idxs = (
208
+ self._rng.permutation(n_all_blocks)
209
+ if self._shuffle
210
+ else range(n_all_blocks)
211
+ )
212
+
213
+ self._curr_idx = 0
214
+
215
+ def __del__(self):
216
+ self._close_mmaps()
217
+ del self._mmaps
218
+ del self._buffers
219
+
220
+ def __iter__(self):
221
+ return self
222
+
223
+ def __next__(self):
224
+ if self._curr_idx >= len(self._block_idxs):
225
+ self._load_n_chunks()
226
+ # TODO: trigger fetching next next n_chunks if remote
227
+ block_idx = self._block_idxs[self._curr_idx]
228
+ chunk_id = block_idx // self._n_blocks
229
+ buffer = self._buffers[chunk_id]
230
+ elem_id = (block_idx % self._n_blocks) * self._block_size
231
+ offset = np.dtype(self._dtype).itemsize * elem_id
232
+ arr = np.frombuffer(
233
+ buffer, dtype=self._dtype, count=self._block_size, offset=offset
234
+ )
235
+ self._curr_idx += 1
236
+ return torch.from_numpy(arr.astype(np.int64))
237
+
238
+
239
+ class CombinedDataset(IterableDataset):
240
+ def __init__(self, datasets, seed, weights=None):
241
+ self._seed = seed
242
+ self._datasets = datasets
243
+ self._weights = weights
244
+ n_datasets = len(datasets)
245
+ if weights is None:
246
+ self._weights = [1 / n_datasets] * n_datasets
247
+
248
+ def __iter__(self):
249
+ return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
250
+
251
+
252
+ class CombinedDatasetIterator:
253
+ def __init__(self, datasets, seed, weights):
254
+ self._datasets = [iter(el) for el in datasets]
255
+ self._weights = weights
256
+ self._rng = random.Random(seed)
257
+
258
+ def __next__(self):
259
+ dataset, = self._rng.choices(self._datasets, weights=self._weights, k=1)
260
+ return next(dataset)
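A minimal sketch of the packed-dataset round trip (not part of the commit; the file prefix, chunk_size and block_size are hypothetical values chosen only for illustration):
# Sketch only: write one packed chunk to disk, then stream fixed-size blocks back.
import numpy as np
from lit_llama.packed_dataset import PackedDatasetBuilder, PackedDataset

builder = PackedDatasetBuilder(outdir=".", prefix="demo", chunk_size=16, sep_token=0, dtype=np.uint16)
builder.add_array(np.arange(10, dtype=np.uint16))  # shorter than chunk_size, padded with sep_token
builder.write_reminder()                           # flush the partially filled chunk

dataset = PackedDataset(builder.filenames, n_chunks=1, block_size=4, shuffle=False)
for block in dataset:                              # each block is a torch.int64 tensor of length 4
    print(block)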
lit_llama/quantization.py ADDED
@@ -0,0 +1,614 @@
1
+ import os
2
+ from contextlib import contextmanager
3
+ import warnings
4
+ import math
5
+
6
+ import torch
7
+
8
+ # configuration for bitsandbytes before import
9
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
10
+ warnings.filterwarnings(
11
+ "ignore",
12
+ message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization",
13
+ )
14
+ warnings.filterwarnings(
15
+ "ignore",
16
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
17
+ )
18
+ warnings.filterwarnings(
19
+ "ignore",
20
+ message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.",
21
+ )
22
+
23
+ try:
24
+ import bitsandbytes as bnb # noqa: E402
25
+ except ImportError:
26
+ bnb = None
27
+
28
+ try:
29
+ import triton # noqa: E402
30
+ import triton.language as tl # noqa: E402
31
+ except ImportError:
32
+ triton = None
33
+
34
+ if bnb is not None:
35
+
36
+ class Linear8bitLt(bnb.nn.Linear8bitLt):
37
+ """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
38
+ re-quantization when loading the state dict.
39
+
40
+
41
+ This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
42
+ """
43
+
44
+ def __init__(self, *args, **kwargs):
45
+ super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
46
+ # We quantize the initial weight here so we don't end up filling the device
47
+ # memory with float32 weights which could lead to OOM.
48
+ self._quantize_weight(self.weight.data)
49
+
50
+ def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
51
+ # There is only one key that ends with `*.weight`, the other one is the bias
52
+ weight_key = next(
53
+ (name for name in local_state_dict.keys() if name.endswith("weight")),
54
+ None,
55
+ )
56
+ if weight_key is None:
57
+ return
58
+
59
+ # Load the weight from the state dict and re-quantize it
60
+ weight = local_state_dict.pop(weight_key)
61
+ self._quantize_weight(weight)
62
+
63
+ # If there is a bias, let nn.Module load it
64
+ if local_state_dict:
65
+ super()._load_from_state_dict(local_state_dict, *args, **kwargs)
66
+
67
+ def _quantize_weight(self, weight: torch.Tensor) -> None:
68
+ # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
69
+ B = weight.contiguous().half().cuda()
70
+ CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
71
+ del CBt
72
+ del SCBt
73
+ self.weight.data = CB
74
+ setattr(self.weight, "CB", CB)
75
+ setattr(self.weight, "SCB", SCB)
76
+
77
+
78
+ if triton is not None:
79
+ # This is adapted from the OpenAI Triton matmul example.
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config(
83
+ {
84
+ "BLOCK_SIZE_M": 128,
85
+ "BLOCK_SIZE_N": 256,
86
+ "BLOCK_SIZE_K": 32,
87
+ "GROUP_SIZE_M": 8,
88
+ },
89
+ num_stages=3,
90
+ num_warps=8,
91
+ ),
92
+ triton.Config(
93
+ {
94
+ "BLOCK_SIZE_M": 256,
95
+ "BLOCK_SIZE_N": 128,
96
+ "BLOCK_SIZE_K": 32,
97
+ "GROUP_SIZE_M": 8,
98
+ },
99
+ num_stages=3,
100
+ num_warps=8,
101
+ ),
102
+ triton.Config(
103
+ {
104
+ "BLOCK_SIZE_M": 256,
105
+ "BLOCK_SIZE_N": 64,
106
+ "BLOCK_SIZE_K": 32,
107
+ "GROUP_SIZE_M": 8,
108
+ },
109
+ num_stages=4,
110
+ num_warps=4,
111
+ ),
112
+ triton.Config(
113
+ {
114
+ "BLOCK_SIZE_M": 64,
115
+ "BLOCK_SIZE_N": 256,
116
+ "BLOCK_SIZE_K": 32,
117
+ "GROUP_SIZE_M": 8,
118
+ },
119
+ num_stages=4,
120
+ num_warps=4,
121
+ ),
122
+ triton.Config(
123
+ {
124
+ "BLOCK_SIZE_M": 128,
125
+ "BLOCK_SIZE_N": 128,
126
+ "BLOCK_SIZE_K": 32,
127
+ "GROUP_SIZE_M": 8,
128
+ },
129
+ num_stages=4,
130
+ num_warps=4,
131
+ ),
132
+ triton.Config(
133
+ {
134
+ "BLOCK_SIZE_M": 128,
135
+ "BLOCK_SIZE_N": 64,
136
+ "BLOCK_SIZE_K": 32,
137
+ "GROUP_SIZE_M": 8,
138
+ },
139
+ num_stages=4,
140
+ num_warps=4,
141
+ ),
142
+ triton.Config(
143
+ {
144
+ "BLOCK_SIZE_M": 64,
145
+ "BLOCK_SIZE_N": 128,
146
+ "BLOCK_SIZE_K": 32,
147
+ "GROUP_SIZE_M": 8,
148
+ },
149
+ num_stages=4,
150
+ num_warps=4,
151
+ ),
152
+ triton.Config(
153
+ {
154
+ "BLOCK_SIZE_M": 128,
155
+ "BLOCK_SIZE_N": 32,
156
+ "BLOCK_SIZE_K": 32,
157
+ "GROUP_SIZE_M": 8,
158
+ },
159
+ num_stages=4,
160
+ num_warps=4,
161
+ ),
162
+ triton.Config(
163
+ {
164
+ "BLOCK_SIZE_M": 64,
165
+ "BLOCK_SIZE_N": 32,
166
+ "BLOCK_SIZE_K": 32,
167
+ "GROUP_SIZE_M": 8,
168
+ },
169
+ num_stages=5,
170
+ num_warps=2,
171
+ ),
172
+ triton.Config(
173
+ {
174
+ "BLOCK_SIZE_M": 32,
175
+ "BLOCK_SIZE_N": 64,
176
+ "BLOCK_SIZE_K": 32,
177
+ "GROUP_SIZE_M": 8,
178
+ },
179
+ num_stages=5,
180
+ num_warps=2,
181
+ ),
182
+ ],
183
+ key=["M", "N", "K"],
184
+ )
185
+ @triton.jit
186
+ def linear_kernel_4bit_weight(
187
+ # Pointers to matrices
188
+ a_ptr,
189
+ b_ptr,
190
+ c_ptr,
191
+ bscales_ptr,
192
+ bzeros_ptr,
193
+ # bdequant,
194
+ # Matrix dimensions
195
+ M,
196
+ N,
197
+ K,
198
+ # The stride variables represent how much to increase the ptr by when moving by 1
199
+ # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
200
+ # by to get the element one row down (A has M rows)
201
+ stride_am,
202
+ stride_ak,
203
+ stride_bk,
204
+ stride_bn,
205
+ stride_cm,
206
+ stride_cn,
207
+ # Meta-parameters
208
+ BLOCK_SIZE_M: tl.constexpr,
209
+ BLOCK_SIZE_N: tl.constexpr,
210
+ BLOCK_SIZE_K: tl.constexpr,
211
+ GROUP_SIZE_M: tl.constexpr,
212
+ ):
213
+ """Kernel for computing the matmul C = A x B.T.
214
+ A has shape (M, K), B has shape (N, K) and C has shape (M, N)
215
+ """
216
+ # -----------------------------------------------------------
217
+ # Map program ids `pid` to the block of C it should compute.
218
+ # This is done in a grouped ordering to promote L2 data reuse
219
+ # See above `L2 Cache Optimizations` section for details
220
+ pid = tl.program_id(axis=0)
221
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
222
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
223
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
224
+ group_id = pid // num_pid_in_group
225
+ first_pid_m = group_id * GROUP_SIZE_M
226
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
227
+ pid_m = first_pid_m + (pid % group_size_m)
228
+ pid_n = (pid % num_pid_in_group) // group_size_m
229
+
230
+ # ----------------------------------------------------------
231
+ # Create pointers for the first blocks of A and B.
232
+ # We will advance this pointer as we move in the K direction
233
+ # and accumulate
234
+ # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
235
+ # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
236
+ # see above `Pointer Arithmetics` section for details
237
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
238
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
239
+ a_mask = offs_am[:, None] < M
240
+ b_mask = offs_bn[None, :] < N
241
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
242
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
243
+ b_ptrs = b_ptr + (
244
+ (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
245
+ )
246
+
247
+ bscales_ptrs = bscales_ptr + offs_bn[None, :]
248
+ bzeros_ptrs = bzeros_ptr + offs_bn[None, :]
249
+
250
+ scale = tl.load(bscales_ptrs)
251
+ zero = tl.load(bzeros_ptrs)
252
+ # -----------------------------------------------------------
253
+ # Iterate to compute a block of the C matrix
254
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
255
+ # of fp32 values for higher accuracy.
256
+ # `accumulator` will be converted back to fp16 after the loop
257
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
258
+ for k in range(0, K, BLOCK_SIZE_K):
259
+ # wasteful as it is to load everything twice, my attempts at avoiding it led to slower code
260
+ b12 = tl.load(b_ptrs, mask=b_mask)
261
+ # Note that for simplicity, we don't apply a mask in K here.
262
+ a = tl.load(a_ptrs, mask=a_mask).to(tl.float32)
263
+ b = (
264
+ ((b12.to(tl.uint8) >> ((offs_k[:, None] % 2) * 4)) & 0xF).to(tl.float32)
265
+ - zero
266
+ ) * scale
267
+ accumulator += tl.dot(a, b)
268
+
269
+ # Advance the ptrs to the next K block
270
+ a_ptrs += BLOCK_SIZE_K * stride_ak
271
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
272
+ c = accumulator
273
+
274
+ # -----------------------------------------------------------
275
+ # Write back the block of the output matrix C
276
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
277
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
278
+ c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
279
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
280
+ tl.store(c_ptrs, c, mask=c_mask)
281
+
282
+ def qlinear_4bit_weight(inp, weight, scales, zeros):
283
+ weight = weight.t().contiguous()
284
+ c_shape = inp.shape[:-1] + weight.shape[-1:]
285
+ inp = inp.reshape(-1, inp.shape[-1]).contiguous()
286
+ # we pad the input to amortize triton compilation cost better
287
+ PAD_TO = 256
288
+ if inp.shape[0] % PAD_TO != 0:
289
+ c_crop = inp.shape[0]
290
+ new_inp_shape0 = inp.shape[0] + PAD_TO - inp.shape[0] % PAD_TO
291
+ inp2 = inp.new_empty((new_inp_shape0, inp.shape[1]))
292
+ inp2[: inp.shape[0]] = inp
293
+ inp2[inp.shape[0] :].zero_()
294
+ inp = inp2
295
+ else:
296
+ c_crop = None
297
+
298
+ assert inp.shape[1] == weight.shape[0] * 2, "incompatible dimensions"
299
+
300
+ assert scales.shape == (weight.shape[1], 1)
301
+ assert zeros.shape == (weight.shape[1], 1)
302
+ scales = scales.contiguous()
303
+ zeros = zeros.contiguous()
304
+ K, N = weight.shape
305
+ M, K = inp.shape
306
+ assert (
307
+ K % 32 == 0
308
+ ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
309
+ # allocates output
310
+ c = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
311
+ # 1D launch kernel where each block gets its own program.
312
+ grid = lambda META: (
313
+ triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
314
+ )
315
+ linear_kernel_4bit_weight[grid](
316
+ inp,
317
+ weight,
318
+ c,
319
+ scales,
320
+ zeros,
321
+ M,
322
+ N,
323
+ K,
324
+ inp.stride(0),
325
+ inp.stride(1),
326
+ weight.stride(0),
327
+ weight.stride(1),
328
+ c.stride(0),
329
+ c.stride(1),
330
+ )
331
+ return c[:c_crop].reshape(c_shape)
332
+
333
+ else:
334
+ qlinear_4bit_weight = None
335
+
336
+
337
+ # for correctness but with terrible perf
338
+ class ColBlockQuantizedLinear(torch.nn.Module):
339
+ def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols):
340
+ super().__init__()
341
+ self.in_features = in_features
342
+ self.out_features = out_features
343
+ self.tile_cols = tile_cols if tile_cols != -1 else self.in_features
344
+ self.bits = bits
345
+ self.entries_per_byte = 8 // bits
346
+ assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8
347
+ assert in_features % self.entries_per_byte == 0
348
+ self.register_buffer(
349
+ "quant_weight",
350
+ torch.empty(
351
+ (self.out_features, self.in_features // self.entries_per_byte),
352
+ dtype=torch.uint8,
353
+ )
354
+ .t()
355
+ .contiguous()
356
+ .t(),
357
+ )
358
+ self.register_buffer(
359
+ "scales",
360
+ torch.empty(
361
+ (
362
+ self.out_features,
363
+ (self.in_features + self.tile_cols - 1) // self.tile_cols,
364
+ )
365
+ ),
366
+ )
367
+ self.register_buffer("zeros", torch.empty_like(self.scales))
368
+ assert isinstance(bias, bool)
369
+ if bias:
370
+ self.register_buffer("bias", torch.empty((self.out_features,)))
371
+ else:
372
+ self.register_buffer("bias", None)
373
+
374
+ def pack_weight(self, weight):
375
+ weight = weight.to(device=self.quant_weight.device, copy=True)
376
+ for j in range(self.scales.size(1)):
377
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] /= self.scales[
378
+ :, j : j + 1
379
+ ]
380
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] += self.zeros[
381
+ :, j : j + 1
382
+ ]
383
+ weight = weight.clamp_(min=0, max=2**self.bits - 1).to(dtype=torch.uint8)
384
+ self.quant_weight.zero_()
385
+ for nr in range(self.entries_per_byte):
386
+ self.quant_weight += weight[:, nr :: self.entries_per_byte] << (
387
+ nr * self.bits
388
+ )
389
+
390
+ def get_weight(self, dtype=torch.float):
391
+ weight = torch.empty(
392
+ (self.out_features, self.in_features),
393
+ device=self.quant_weight.device,
394
+ dtype=dtype,
395
+ )
396
+ mask = (1 << self.bits) - 1
397
+ for nr in range(self.entries_per_byte):
398
+ weight[:, nr :: self.entries_per_byte] = (
399
+ (self.quant_weight >> (nr * self.bits)) & mask
400
+ ).float()
401
+ self.quant_weight.to(dtype)
402
+ for j in range(self.scales.size(1)):
403
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] -= self.zeros[
404
+ :, j : j + 1
405
+ ]
406
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] *= self.scales[
407
+ :, j : j + 1
408
+ ]
409
+ return weight
410
+
411
+ def forward(self, inp):
412
+ if (
413
+ triton is not None
414
+ and self.bits == 4
415
+ and self.quant_weight.device.type == "cuda"
416
+ and self.zeros.shape[1] == 1
417
+ and self.quant_weight.shape[1] % 32 == 0
418
+ ):
419
+ return qlinear_4bit_weight(inp, self.quant_weight, self.scales, self.zeros)
420
+ weight = self.get_weight(dtype=inp.dtype)
421
+ return torch.nn.functional.linear(inp, weight, self.bias)
422
+
423
+
424
+ class GPTQQuantizer:
425
+ # The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/
426
+ # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
427
+ # portions copyright by the authors licensed under the Apache License 2.0
428
+ # All errors are our own.
429
+
430
+ def __init__(
431
+ self,
432
+ linear_module,
433
+ *,
434
+ bits,
435
+ perchannel=True,
436
+ sym=False,
437
+ blocksize=128,
438
+ percdamp=0.01,
439
+ groupsize=-1,
440
+ actorder=False
441
+ ):
442
+ assert isinstance(linear_module, torch.nn.Linear)
443
+
444
+ self.linear_module = linear_module
445
+ self.dev = self.linear_module.weight.device
446
+ self.rows = linear_module.weight.shape[0]
447
+ self.columns = linear_module.weight.shape[1]
448
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
449
+ self.nsamples = 0
450
+ self.bits = bits
451
+ self.maxq = 2**bits - 1
452
+ self.perchannel = perchannel
453
+ self.sym = sym
454
+ self.blocksize = blocksize
455
+ self.percdamp = percdamp
456
+ self.groupsize = groupsize
457
+ self.actorder = actorder
458
+ self.tile_cols = self.columns if groupsize == -1 else groupsize
459
+ self.scales = torch.zeros(
460
+ (self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols),
461
+ dtype=self.linear_module.weight.dtype,
462
+ device=self.dev,
463
+ )
464
+ self.zeros = torch.zeros_like(self.scales)
465
+ assert not (
466
+ self.actorder and self.groupsize != -1
467
+ ), "The permutation trick does not work for grouped quantization"
468
+
469
+ @staticmethod
470
+ def quantize_weight(x, scale, zero, maxq):
471
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
472
+ x_rec = scale * (q - zero)
473
+ return x_rec
474
+
475
+ def find_params_weight(self, x):
476
+ dev = x.device
477
+
478
+ shape = x.shape
479
+ if self.perchannel:
480
+ x = x.flatten(1)
481
+ else:
482
+ x = x.flatten().unsqueeze(0)
483
+
484
+ tmp = torch.zeros(x.shape[0], device=dev)
485
+ xmin = torch.minimum(x.min(1)[0], tmp)
486
+ xmax = torch.maximum(x.max(1)[0], tmp)
487
+
488
+ if self.sym:
489
+ xmax = torch.maximum(torch.abs(xmin), xmax)
490
+ tmp = xmin < 0
491
+ if torch.any(tmp):
492
+ xmin[tmp] = -xmax[tmp]
493
+ tmp = (xmin == 0) & (xmax == 0)
494
+ xmin[tmp] = -1
495
+ xmax[tmp] = +1
496
+
497
+ scale = (xmax - xmin) / self.maxq
498
+ if self.sym:
499
+ zero = torch.full_like(scale, (self.maxq + 1) / 2)
500
+ else:
501
+ zero = torch.round(-xmin / scale)
502
+
503
+ if not self.perchannel:
504
+ tmp = shape[0]
505
+ scale = scale.repeat(tmp)
506
+ zero = zero.repeat(tmp)
507
+
508
+ shape = [-1] + [1] * (len(shape) - 1)
509
+ scale = scale.reshape(shape)
510
+ zero = zero.reshape(shape)
511
+ return scale, zero
512
+
513
+ def collect_input_stats(self, _1, inp, _2):
514
+ inp = inp[0].detach()
515
+ self.last_inp = inp
516
+ if len(inp.shape) == 2:
517
+ inp = inp.unsqueeze(0)
518
+ tmp = inp.shape[0]
519
+ if len(inp.shape) == 3:
520
+ inp = inp.reshape((-1, inp.shape[-1]))
521
+ inp = inp.t()
522
+ self.H *= self.nsamples / (self.nsamples + tmp)
523
+ self.nsamples += tmp
524
+ # inp = inp.float()
525
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
526
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
527
+ self.H += inp.matmul(inp.t())
528
+
529
+ def quantize(self):
530
+ W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
531
+
532
+ scale, zero = self.find_params_weight(W)
533
+ self.scales[:] = scale
534
+ self.zeros[:] = zero
535
+
536
+ H = self.H
537
+ del self.H
538
+ dead = torch.diag(H) == 0
539
+ H[dead, dead] = 1
540
+ W[:, dead] = 0
541
+ if self.actorder:
542
+ perm = torch.argsort(torch.diag(H), descending=True)
543
+ W = W[:, perm]
544
+ H = H[perm][:, perm]
545
+
546
+ Losses = torch.zeros_like(W)
547
+ Q = torch.zeros_like(W)
548
+
549
+ damp = self.percdamp * torch.mean(torch.diag(H))
550
+ diag = torch.arange(self.columns, device=self.dev)
551
+ H[diag, diag] += damp
552
+ H = torch.linalg.cholesky(H)
553
+ H = torch.cholesky_inverse(H)
554
+ H = torch.linalg.cholesky(H, upper=True)
555
+ Hinv = H
556
+
557
+ for i1 in range(0, self.columns, self.blocksize):
558
+ i2 = min(i1 + self.blocksize, self.columns)
559
+ count = i2 - i1
560
+
561
+ W1 = W[:, i1:i2].clone()
562
+ Q1 = torch.zeros_like(W1)
563
+ Err1 = torch.zeros_like(W1)
564
+ Losses1 = torch.zeros_like(W1)
565
+ Hinv1 = Hinv[i1:i2, i1:i2]
566
+
567
+ for i in range(count):
568
+ w = W1[:, i]
569
+ d = Hinv1[i, i]
570
+
571
+ if self.groupsize != -1:
572
+ if (i1 + i) % self.groupsize == 0:
573
+ scale, zero = self.find_params_weight(
574
+ W[:, (i1 + i) : (i1 + i + self.groupsize)]
575
+ )
576
+ self.scales[:, (i1 + i) // self.groupsize] = scale
577
+ self.zeros[:, (i1 + i) // self.groupsize] = zero
578
+
579
+ q = self.quantize_weight(w.unsqueeze(1), scale, zero, self.maxq)
580
+ q = q.squeeze(1)
581
+ assert q.dim() == 1
582
+ Q1[:, i] = q
583
+ Losses1[:, i] = (w - q) ** 2 / d**2
584
+
585
+ err1 = (w - q) / d
586
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
587
+ Err1[:, i] = err1
588
+
589
+ Q[:, i1:i2] = Q1
590
+ Losses[:, i1:i2] = Losses1 / 2
591
+
592
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
593
+
594
+ if self.actorder:
595
+ invperm = torch.argsort(perm)
596
+ Q = Q[:, invperm]
597
+
598
+ weight = Q.reshape(self.linear_module.weight.shape).to(
599
+ self.linear_module.weight.data.dtype
600
+ )
601
+ error = torch.sum(Losses).item()
602
+
603
+ q_module = ColBlockQuantizedLinear(
604
+ self.linear_module.in_features,
605
+ self.linear_module.out_features,
606
+ self.linear_module.bias is not None,
607
+ bits=self.bits,
608
+ tile_cols=self.groupsize,
609
+ ).to(self.dev)
610
+ q_module.scales = self.scales
611
+ q_module.zeros = self.zeros
612
+ q_module.pack_weight(weight)
613
+ q_module.bias = self.linear_module.bias
614
+ return q_module, error
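A minimal sketch of quantizing a single linear layer with GPTQQuantizer (not part of the commit; the layer sizes and the number of calibration batches are hypothetical):
# Sketch only: collect input statistics through a forward hook, then run GPTQ.
import torch
from lit_llama.quantization import GPTQQuantizer

linear = torch.nn.Linear(128, 64, bias=False)
quantizer = GPTQQuantizer(linear, bits=4, groupsize=-1)

handle = linear.register_forward_hook(quantizer.collect_input_stats)
with torch.no_grad():
    for _ in range(4):                  # a few random calibration batches
        linear(torch.randn(8, 128))
handle.remove()

q_linear, error = quantizer.quantize()  # returns a ColBlockQuantizedLinear and the total loss
print(type(q_linear).__name__, error)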
lit_llama/tokenizer.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
7
+
8
+
9
+ class Tokenizer:
10
+ """Tokenizer for LLaMA."""
11
+
12
+ def __init__(self, model_path: Path) -> None:
13
+ self.processor = SentencePieceProcessor(model_file=str(model_path))
14
+ self.bos_id = self.processor.bos_id()
15
+ self.eos_id = self.processor.eos_id()
16
+ self.pad_id = self.processor.pad_id()
17
+
18
+ @property
19
+ def vocab_size(self) -> int:
20
+ return self.processor.vocab_size()
21
+
22
+ def encode(
23
+ self,
24
+ string: str,
25
+ bos: bool = True,
26
+ eos: bool = False,
27
+ max_length: int = -1,
28
+ pad: bool = False,
29
+ device: Optional[torch.device] = None
30
+ ) -> torch.Tensor:
31
+ tokens = self.processor.encode(string)
32
+ if bos:
33
+ tokens = [self.bos_id] + tokens
34
+ if eos:
35
+ tokens = tokens + [self.eos_id]
36
+ if max_length > 0:
37
+ tokens = tokens[:max_length]
38
+ if pad and len(tokens) < max_length:
39
+ tokens += [self.pad_id] * (max_length - len(tokens))
40
+
41
+ return torch.tensor(tokens, dtype=torch.int, device=device)
42
+
43
+ def decode(self, tokens: torch.Tensor) -> str:
44
+ return self.processor.decode(tokens.tolist())
45
+
46
+ @staticmethod
47
+ def train(input: str, destination: str, vocab_size=32000) -> None:
48
+ model_prefix = os.path.join(destination, "tokenizer")
49
+ SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
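A minimal sketch of a tokenizer round trip (not part of the commit; the tokenizer.model path is a placeholder for a real SentencePiece checkpoint):
# Sketch only: encode a string to token ids and decode it back.
from pathlib import Path
from lit_llama import Tokenizer

tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))  # placeholder path
ids = tokenizer.encode("Hello, doctor", bos=True, eos=True)
print(ids.shape, tokenizer.decode(ids))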
lit_llama/utils.py ADDED
@@ -0,0 +1,471 @@
1
+ """Utility functions for training and inference."""
2
+
3
+ import functools
4
+ import pickle
5
+ import warnings
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.utils._device
11
+ from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
12
+ from torch.distributed.fsdp import FullStateDictConfig
13
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
14
+ from torch.distributed.fsdp import StateDictType
15
+ from torch.serialization import normalize_storage_type
16
+
17
+ llama_model_sizes = {
18
+ 4096: "7B", # 7B n_embd=4096
19
+ 5120: "13B", # 13B n_embd=5120
20
+ 6656: "30B", # 30B n_embd=6656
21
+ 8192: "65B", # 65B n_embd=8192
22
+ }
23
+
24
+
25
+ def llama_model_lookup(checkpoint: dict) -> str:
26
+ """Returns the LLaMA model name from the checkpoint.
27
+
28
+ Checks the width of the transformer.wte.weight embedding matrix, as these uniquely identify the model.
29
+ """
30
+ embedding_size = checkpoint['transformer.wte.weight'].shape[1]
31
+ return llama_model_sizes[embedding_size]
32
+
33
+
34
+ def find_multiple(n: int, k: int) -> int:
35
+ if n % k == 0:
36
+ return n
37
+ return n + k - (n % k)
38
+
39
+
40
+ def save_model_checkpoint(fabric, model, file_path):
41
+ """Handles boilerplate logic for retrieving and saving the state_dict.
42
+
43
+ This will be upstreamed to Fabric soon.
44
+ """
45
+ file_path = Path(file_path)
46
+
47
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
48
+ from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
49
+
50
+ fabric.save(file_path, {"model": model})
51
+ fabric.barrier()
52
+ if fabric.global_rank == 0:
53
+ # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
54
+ convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
55
+ return
56
+
57
+ if isinstance(fabric.strategy, FSDPStrategy):
58
+ save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
59
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
60
+ state_dict = model._forward_module.state_dict()
61
+ else:
62
+ state_dict = model.state_dict()
63
+
64
+ if fabric.global_rank == 0:
65
+ torch.save(state_dict, file_path)
66
+ fabric.barrier()
67
+
68
+
69
+ class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
70
+ def __init__(self, device=None, dtype=None, quantization_mode=None):
71
+ """
72
+ Create tensors with given device and dtype and don't run initialization
73
+ (but instead use "empty tensors", i.e. uninitialized memory).
74
+
75
+ device: `torch.device` to work with
76
+ dtype: `torch.dtype` to work with
77
+ quantization_mode: optional string, quantization mode to work with, default `None`.
78
+ Available modes: `llm.int8` bitsandbytes LLM.int8 quantization (only on GPU)
79
+ `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models
80
+
81
+ Example::
82
+ with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
83
+ model = LLaMA.from_name('7B')
84
+ model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
85
+
86
+ self.quantization_mode = quantization_mode
87
+ self.quantized_linear_cls = None
88
+ if self.quantization_mode == 'llm.int8':
89
+ if device.type != "cuda":
90
+ raise ValueError("Quantization is only supported on the GPU.")
91
+ from .quantization import Linear8bitLt
92
+ self.quantized_linear_cls = Linear8bitLt
93
+ elif self.quantization_mode == 'gptq.int4':
94
+ from .quantization import ColBlockQuantizedLinear
95
+ self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
96
+ elif self.quantization_mode == 'gptq.int8':
97
+ from .quantization import ColBlockQuantizedLinear
98
+ self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
99
+ elif self.quantization_mode is not None:
100
+ raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
101
+ self.device = device
102
+ self.dtype = dtype
103
+
104
+ def __enter__(self):
105
+ if self.quantized_linear_cls is not None:
106
+ self.torch_linear_cls = torch.nn.Linear
107
+ torch.nn.Linear = self.quantized_linear_cls
108
+ return super().__enter__()
109
+
110
+ def __exit__(self, exc_type, exc_val, exc_tb):
111
+ if self.quantized_linear_cls is not None:
112
+ torch.nn.Linear = self.torch_linear_cls
113
+ return super().__exit__(exc_type, exc_val, exc_tb)
114
+
115
+ def __torch_function__(self, func, types, args=(), kwargs=None):
116
+ kwargs = kwargs or {}
117
+ if getattr(func, "__module__", None) == "torch.nn.init":
118
+ if "tensor" in kwargs:
119
+ return kwargs["tensor"]
120
+ else:
121
+ return args[0]
122
+ if (
123
+ self.device is not None
124
+ and func in torch.utils._device._device_constructors()
125
+ and kwargs.get("device") is None
126
+ ):
127
+ kwargs["device"] = self.device
128
+ if (
129
+ self.dtype is not None
130
+ and func in torch.utils._device._device_constructors()
131
+ and kwargs.get("dtype") is None
132
+ ):
133
+ kwargs["dtype"] = self.dtype
134
+ return func(*args, **kwargs)
135
+
136
+
137
+ # this is taken from torchhacks https://github.com/lernapparat/torchhacks
138
+
139
+
140
+ class NotYetLoadedTensor:
141
+ def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
142
+ self.metatensor = metatensor
143
+ self.archiveinfo = archiveinfo
144
+ self.storageinfo = storageinfo
145
+ self.rebuild_args = rebuild_args
146
+
147
+ @classmethod
148
+ def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
149
+ ret = func(*args)
150
+ if isinstance(ret, NotYetLoadedTensor):
151
+ old_lt = ret._load_tensor
152
+
153
+ def _load_tensor():
154
+ t = old_lt()
155
+ return torch._tensor._rebuild_from_type_v2(
156
+ lambda: t, new_type, (), state
157
+ )
158
+
159
+ ret._load_tensor = _load_tensor
160
+ return ret
161
+ return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)
162
+
163
+ @classmethod
164
+ def rebuild_parameter(
165
+ cls, data, requires_grad, backward_hooks, *, archiveinfo=None
166
+ ):
167
+ if isinstance(data, NotYetLoadedTensor):
168
+ old_lt = data._load_tensor
169
+
170
+ def _load_tensor():
171
+ t = old_lt()
172
+ return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)
173
+
174
+ data._load_tensor = _load_tensor
175
+ return data
176
+ return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)
177
+
178
+ @classmethod
179
+ def rebuild_tensor_v2(
180
+ cls,
181
+ storage,
182
+ storage_offset,
183
+ size,
184
+ stride,
185
+ requires_grad,
186
+ backward_hooks,
187
+ metadata=None,
188
+ *,
189
+ archiveinfo=None,
190
+ ):
191
+ rebuild_args = (
192
+ storage_offset,
193
+ size,
194
+ stride,
195
+ requires_grad,
196
+ backward_hooks,
197
+ metadata,
198
+ )
199
+ metatensor = torch._utils._rebuild_tensor_v2(
200
+ storage,
201
+ storage_offset,
202
+ size,
203
+ stride,
204
+ requires_grad,
205
+ backward_hooks,
206
+ metadata,
207
+ )
208
+ storageinfo = storage.archiveinfo
209
+ return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
210
+
211
+ def _load_tensor(self):
212
+ name, storage_cls, fn, device, size = self.storageinfo
213
+ dtype = self.metatensor.dtype
214
+
215
+ uts = (
216
+ self.archiveinfo.zipfile_context.zf.get_storage_from_record(
217
+ f"data/{fn}",
218
+ size * torch._utils._element_size(dtype),
219
+ torch.UntypedStorage,
220
+ )
221
+ ._typed_storage()
222
+ ._untyped_storage
223
+ )
224
+ with warnings.catch_warnings():
225
+ warnings.simplefilter("ignore")
226
+ storage = torch.storage.TypedStorage(
227
+ wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
228
+ )
229
+ tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
230
+ return tensor
231
+
232
+ @classmethod
233
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
234
+ if kwargs is None:
235
+ kwargs = {}
236
+ loaded_args = [
237
+ (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
238
+ ]
239
+ res = func(*loaded_args, **kwargs)
240
+ # gc.collect would be costly here, maybe do it optionally
241
+ return res
242
+
243
+ def __getattr__(self, name):
244
+ # properties
245
+ ## TODO: device, is_...??
246
+ ## TODO: mH, mT, H, T, data, imag, real
247
+ ## name ???
248
+ if name in {
249
+ "dtype",
250
+ "grad",
251
+ "grad_fn",
252
+ "layout",
253
+ "names",
254
+ "ndim",
255
+ "output_nr",
256
+ "requires_grad",
257
+ "retains_grad",
258
+ "shape",
259
+ "volatile",
260
+ }:
261
+ return getattr(self.metatensor, name)
262
+ if name in {"size"}:
263
+ return getattr(self.metatensor, name)
264
+ # materializing with contiguous is needed for quantization
265
+ if name in {"contiguous"}:
266
+ return getattr(self._load_tensor(), name)
267
+
268
+ raise AttributeError(f"{type(self)} does not have {name}")
269
+
270
+ def __repr__(self):
271
+ return f"NotYetLoadedTensor({repr(self.metatensor)})"
272
+
273
+
274
+ class LazyLoadingUnpickler(pickle.Unpickler):
275
+ def __init__(self, file, zipfile_context):
276
+ super().__init__(file)
277
+ self.zipfile_context = zipfile_context
278
+
279
+ def find_class(self, module, name):
280
+ res = super().find_class(module, name)
281
+ if module == "torch._utils" and name == "_rebuild_tensor_v2":
282
+ return functools.partial(
283
+ NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self
284
+ )
285
+ elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
286
+ return functools.partial(
287
+ NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self
288
+ )
289
+ elif module == "torch._utils" and name == "_rebuild_parameter":
290
+ return functools.partial(
291
+ NotYetLoadedTensor.rebuild_parameter, archiveinfo=self
292
+ )
293
+ return res
294
+
295
+ def persistent_load(self, pid):
296
+ name, cls, fn, device, size = pid
297
+ with warnings.catch_warnings():
298
+ warnings.simplefilter("ignore")
299
+ s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
300
+ s.archiveinfo = pid
301
+ return s
302
+
303
+
304
+ class lazy_load:
305
+ def __init__(self, fn):
306
+ self.zf = torch._C.PyTorchFileReader(str(fn))
307
+ with BytesIO(self.zf.get_record("data.pkl")) as pkl:
308
+ mup = LazyLoadingUnpickler(pkl, self)
309
+ self.sd = mup.load()
310
+
311
+ def __enter__(self):
312
+ return self.sd
313
+
314
+ def __exit__(self, exc_type, exc_val, exc_tb):
315
+ del self.zf # I don't think there is a way to force closing...
316
+ self.zf = None
317
+
318
+
319
+ class SavingProxyForStorage:
320
+ def __init__(self, obj, saver, protocol_version=5):
321
+ self.protocol_version = protocol_version
322
+ self.saver = saver
323
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
324
+ raise TypeError(f"expected storage, not {type(obj)}")
325
+
326
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
327
+ if isinstance(obj, torch.storage.TypedStorage):
328
+ # PT upstream wants to deprecate this eventually...
329
+ storage = obj._untyped_storage
330
+ storage_type_str = obj._pickle_storage_type()
331
+ storage_type = getattr(torch, storage_type_str)
332
+ storage_numel = obj._size()
333
+ else:
334
+ storage = obj
335
+ storage_type = normalize_storage_type(type(obj))
336
+ storage_numel = storage.nbytes()
337
+
338
+ storage_key = saver._write_storage_and_return_key(storage)
339
+ location = torch.serialization.location_tag(storage)
340
+
341
+ self.storage_info = (
342
+ "storage",
343
+ storage_type,
344
+ storage_key,
345
+ location,
346
+ storage_numel,
347
+ )
348
+
349
+ def __reduce_ex__(self, protocol_version):
350
+ assert False, "this should be handled out of band"
351
+
352
+
353
+ class SavingProxyForTensor:
354
+ def __init__(self, tensor, saver, protocol_version=5):
355
+ self.protocol_version = protocol_version
356
+ self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(
357
+ protocol_version
358
+ )
359
+ assert isinstance(
360
+ storage, torch.storage.TypedStorage
361
+ ), "Please check for updates"
362
+ storage_proxy = SavingProxyForStorage(
363
+ storage, saver, protocol_version=protocol_version
364
+ )
365
+ self.reduce_args = (storage_proxy, *other_reduce_args)
366
+
367
+ def __reduce_ex__(self, protocol_version):
368
+ if protocol_version != self.protocol_version:
369
+ raise RuntimeError(
370
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
371
+ )
372
+ return self.reduce_ret_fn, self.reduce_args
373
+
374
+
375
+ class IncrementalPyTorchPickler(pickle.Pickler):
376
+ def __init__(self, saver, *args, **kwargs):
377
+ super().__init__(*args, **kwargs)
378
+ self.storage_dtypes = {}
379
+ self.saver = saver
380
+ self.id_map = {}
381
+
382
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
383
+ def persistent_id(self, obj):
384
+ # FIXME: the docs say that persistent_id should only return a string
385
+ # but torch store returns tuples. This works only in the binary protocol
386
+ # see
387
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
388
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
389
+ if isinstance(obj, SavingProxyForStorage):
390
+ return obj.storage_info
391
+
392
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
393
+ if isinstance(obj, torch.storage.TypedStorage):
394
+ # TODO: Once we decide to break serialization FC, this case
395
+ # can be deleted
396
+ storage = obj._untyped_storage
397
+ storage_dtype = obj.dtype
398
+ storage_type_str = obj._pickle_storage_type()
399
+ storage_type = getattr(torch, storage_type_str)
400
+ storage_numel = obj._size()
401
+
402
+ else:
403
+ storage = obj
404
+ storage_dtype = torch.uint8
405
+ storage_type = normalize_storage_type(type(obj))
406
+ storage_numel = storage.nbytes()
407
+
408
+ # If storage is allocated, ensure that any other saved storages
409
+ # pointing to the same data all have the same dtype. If storage is
410
+ # not allocated, don't perform this check
411
+ if storage.data_ptr() != 0:
412
+ if storage.data_ptr() in self.storage_dtypes:
413
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
414
+ raise RuntimeError(
415
+ "Cannot save multiple tensors or storages that "
416
+ "view the same data as different types"
417
+ )
418
+ else:
419
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
420
+
421
+ storage_key = self.id_map.get(storage._cdata)
422
+ if storage_key is None:
423
+ storage_key = self.saver._write_storage_and_return_key(storage)
424
+ self.id_map[storage._cdata] = storage_key
425
+ location = torch.serialization.location_tag(storage)
426
+
427
+ return ("storage", storage_type, storage_key, location, storage_numel)
428
+
429
+ return None
430
+
431
+
432
+ class incremental_save:
433
+ def __init__(self, name):
434
+ self.name = name
435
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
436
+ self.has_saved = False
437
+ self.next_key = 0
438
+
439
+ def __enter__(self):
440
+ return self
441
+
442
+ def store_early(self, tensor):
443
+ if isinstance(tensor, torch.Tensor):
444
+ return SavingProxyForTensor(tensor, self)
445
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
446
+
447
+ def save(self, obj):
448
+ if self.has_saved:
449
+ raise RuntimeError("have already saved")
450
+ # Write the pickle data for `obj`
451
+ data_buf = BytesIO()
452
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
453
+ pickler.dump(obj)
454
+ data_value = data_buf.getvalue()
455
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
456
+ self.has_saved = True
457
+
458
+ def _write_storage_and_return_key(self, storage):
459
+ if self.has_saved:
460
+ raise RuntimeError("have already saved")
461
+ key = self.next_key
462
+ self.next_key += 1
463
+ name = f"data/{key}"
464
+ if storage.device.type != "cpu":
465
+ storage = storage.cpu()
466
+ num_bytes = storage.nbytes()
467
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
468
+ return key
469
+
470
+ def __exit__(self, type, value, traceback):
471
+ self.zipfile.write_end_of_file()
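A minimal sketch of lazily inspecting a checkpoint with the utilities above (not part of the commit; the checkpoint path is a placeholder):
# Sketch only: lazy_load defers reading tensor storage, so the model size can be
# looked up from tensor metadata without materializing the weights.
from lit_llama.utils import lazy_load, llama_model_lookup

with lazy_load("checkpoints/lit-llama/7B/lit-llama.pth") as checkpoint:  # placeholder path
    print(llama_model_lookup(checkpoint))  # "7B", "13B", "30B" or "65B"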
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch>=2.0.0
2
+ lightning @ git+https://github.com/Lightning-AI/lightning@master
3
+ sentencepiece
4
+ tqdm # convert_checkpoint.py
5
+ numpy # train.py dataset memmap
6
+ jsonargparse[signatures] # generate.py, convert_checkpoint.py CLI
7
+ bitsandbytes # quantization.py
8
+ datasets # evaluate.py
9
+ zstandard # prepare_redpajama.py
10
+ gradio # app.py