xuxw98 commited on
Commit
7d52396
1 Parent(s): 74ba10f

Upload 58 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. evaluate/adapter.py +164 -0
  2. evaluate/adapter_v2.py +161 -0
  3. evaluate/full.py +147 -0
  4. evaluate/lora.py +172 -0
  5. finetune/adapter.py +262 -0
  6. finetune/adapter_v2.py +266 -0
  7. finetune/full.py +224 -0
  8. finetune/lora.py +218 -0
  9. generate.py +170 -0
  10. generate/adapter.py +106 -0
  11. generate/adapter_v2.py +108 -0
  12. generate/full.py +103 -0
  13. generate/lora.py +118 -0
  14. howto/convert_lora_weights.md +19 -0
  15. howto/customize_paths.md +33 -0
  16. howto/download_weights.md +130 -0
  17. howto/finetune_adapter.md +109 -0
  18. howto/finetune_adapter_v2.md +114 -0
  19. howto/finetune_full.md +106 -0
  20. howto/finetune_lora.md +90 -0
  21. howto/inference.md +43 -0
  22. howto/tpus.md +51 -0
  23. howto/train_redpajama.md +133 -0
  24. howto/unstructured_dataset.md +18 -0
  25. lit_llama/__init__.py +2 -0
  26. lit_llama/adapter.py +313 -0
  27. lit_llama/adapter_v2.py +45 -0
  28. lit_llama/lora.py +476 -0
  29. lit_llama/model.py +321 -0
  30. lit_llama/packed_dataset.py +260 -0
  31. lit_llama/quantization.py +614 -0
  32. lit_llama/tokenizer.py +49 -0
  33. lit_llama/utils.py +496 -0
  34. pretrain/redpajama.py +321 -0
  35. pretrain/shakespeare.py +166 -0
  36. quantize/gptq.py +238 -0
  37. scripts/convert_checkpoint.py +141 -0
  38. scripts/convert_hf_checkpoint.py +167 -0
  39. scripts/convert_lora_weights.py +95 -0
  40. scripts/download.py +34 -0
  41. scripts/prepare_alpaca.py +131 -0
  42. scripts/prepare_any_text.py +97 -0
  43. scripts/prepare_dolly.py +133 -0
  44. scripts/prepare_redpajama.py +181 -0
  45. scripts/prepare_shakespeare.py +69 -0
  46. setup.py +26 -0
  47. tests/conftest.py +42 -0
  48. tests/test_adapter.py +55 -0
  49. tests/test_adapter_v2.py +26 -0
  50. tests/test_generate.py +117 -0
evaluate/adapter.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ # support running without installing as a package
14
+ wd = Path(__file__).parent.parent.resolve()
15
+ sys.path.append(str(wd))
16
+
17
+ from lit_llama import Tokenizer
18
+ from lit_llama.adapter import LLaMA
19
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
20
+ from scripts.prepare_alpaca import generate_prompt
21
+
22
+ from datasets import load_dataset
23
+
24
+ instruction_tuning = True
25
+
26
+
27
+ def load_eval_data(dataset_name: str) -> str:
28
+ # this mimics gptq datautils
29
+ if dataset_name == "wikitext":
30
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
31
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
32
+ testdata = "\n\n".join(testdata["text"])
33
+ elif dataset_name == "ptb":
34
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
35
+ testdata = "\n\n".join(testdata["sentence"])
36
+ elif dataset_name == "c4":
37
+ testdata = load_dataset(
38
+ "allenai/c4",
39
+ "allenai--c4",
40
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
41
+ split="validation",
42
+ )
43
+ testdata = " ".join(testdata[:1100]["text"])
44
+
45
+ else:
46
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
47
+ return testdata
48
+
49
+
50
+ @torch.inference_mode()
51
+ def main(
52
+ datasets: str = "wikitext,ptb,c4",
53
+ *,
54
+ # compilation fails as it does not support torch.complex64 for RoPE
55
+ # compile: bool = False,
56
+ accelerator: str = "auto",
57
+ adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"),
58
+ checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
59
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
60
+ dtype: str = "float32",
61
+ quantize: Optional[str] = None,
62
+ ) -> None:
63
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
64
+
65
+ Args:
66
+ datasets: The datasets to use as a comma separated string
67
+ # compile: Whether to compile the model.
68
+ accelerator: The hardware to run on. Possible choices are:
69
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
70
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
71
+ `finetune_adapter.py`.
72
+ checkpoint_path: The checkpoint path to load.
73
+ tokenizer_path: The tokenizer path to load.
74
+ dtype: The tensor dtype for choosing the floating-point precision
75
+ quantize: Whether to quantize the model and using which method:
76
+ ``"llm.int8"``: LLM.int8() mode,
77
+ ``"gptq.int4"``: GPTQ 4-bit mode.
78
+ """
79
+ assert adapter_path.is_file()
80
+ assert checkpoint_path.is_file()
81
+ assert tokenizer_path.is_file()
82
+
83
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
84
+
85
+ dt = getattr(torch, dtype, None)
86
+ if not isinstance(dt, torch.dtype):
87
+ raise ValueError(f"{dtype} is not a valid dtype.")
88
+ dtype = dt
89
+
90
+ print("Loading model ...", file=sys.stderr)
91
+ t0 = time.time()
92
+ with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
93
+ name = llama_model_lookup(pretrained_checkpoint)
94
+
95
+ with EmptyInitOnDevice(
96
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
97
+ ):
98
+ model = LLaMA.from_name(name)
99
+
100
+ # 1. Load the pretrained weights
101
+ model.load_state_dict(pretrained_checkpoint, strict=False)
102
+ # 2. Load the fine-tuned adapter weights
103
+ model.load_state_dict(adapter_checkpoint, strict=False)
104
+
105
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
106
+
107
+ model.eval()
108
+
109
+ # if compile:
110
+ # model = torch.compile(model)
111
+
112
+ total_toks = 0
113
+ model = fabric.setup_module(model)
114
+
115
+ tokenizer = Tokenizer(tokenizer_path)
116
+
117
+ for dsname in datasets.split(","):
118
+ test_string = load_eval_data(dsname)
119
+
120
+ if instruction_tuning:
121
+ sample = {"instruction": test_string, "input": input}
122
+ test_string = generate_prompt(sample)
123
+
124
+ encoded_text = tokenizer.encode(
125
+ test_string, bos=True, eos=False, device=fabric.device
126
+ )
127
+ encoded_text = encoded_text[
128
+ None, : 256 * model.config.block_size
129
+ ] # add batch dimension, trim like gptq implementation
130
+ t0 = time.perf_counter()
131
+
132
+ nlls = 0
133
+ toks = 0
134
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
135
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
136
+ inp = encoded_text[:, i : i + block_size]
137
+ logits = model(inp)[0]
138
+ nll = torch.nn.functional.cross_entropy(
139
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
140
+ )
141
+ toks += inp.size(1) - 1
142
+ nlls += nll.item()
143
+
144
+ print(encoded_text.shape, logits.shape)
145
+ ppl = math.exp(nlls / toks)
146
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
147
+ total_toks += toks
148
+
149
+ t = time.perf_counter() - t0
150
+ print(
151
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
152
+ file=sys.stderr,
153
+ )
154
+ print(
155
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
156
+ file=sys.stderr,
157
+ )
158
+
159
+
160
+ if __name__ == "__main__":
161
+ from jsonargparse import CLI
162
+
163
+ torch.set_float32_matmul_precision("high")
164
+ CLI(main)
evaluate/adapter_v2.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ # support running without installing as a package
14
+ wd = Path(__file__).parent.parent.resolve()
15
+ sys.path.append(str(wd))
16
+
17
+ from lit_llama import Tokenizer
18
+ from lit_llama.adapter import LLaMA
19
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
20
+ from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
21
+ from scripts.prepare_alpaca import generate_prompt
22
+
23
+ from datasets import load_dataset
24
+
25
+
26
+ def load_eval_data(dataset_name: str) -> str:
27
+ # this mimics gptq datautils
28
+ if dataset_name == "wikitext":
29
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
30
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
31
+ testdata = "\n\n".join(testdata["text"])
32
+ elif dataset_name == "ptb":
33
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
34
+ testdata = "\n\n".join(testdata["sentence"])
35
+ elif dataset_name == "c4":
36
+ testdata = load_dataset(
37
+ "allenai/c4",
38
+ "allenai--c4",
39
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
40
+ split="validation",
41
+ )
42
+ testdata = " ".join(testdata[:1100]["text"])
43
+
44
+ else:
45
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
46
+ return testdata
47
+
48
+
49
+ @torch.inference_mode()
50
+ def main(
51
+ datasets: str = "wikitext,ptb,c4",
52
+ *,
53
+ accelerator: str = "auto",
54
+ adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"),
55
+ checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
56
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
57
+ dtype: str = "float32",
58
+ quantize: Optional[str] = None,
59
+ ) -> None:
60
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
61
+
62
+ Args:
63
+ datasets: The datasets to use as a comma separated string
64
+ accelerator: The hardware to run on. Possible choices are:
65
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
66
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
67
+ `finetune_adapter_v2.py`.
68
+ checkpoint_path: The checkpoint path to load.
69
+ tokenizer_path: The tokenizer path to load.
70
+ dtype: The tensor dtype for choosing the floating-point precision
71
+ quantize: Whether to quantize the model and using which method:
72
+ ``"llm.int8"``: LLM.int8() mode,
73
+ ``"gptq.int4"``: GPTQ 4-bit mode.
74
+ """
75
+ assert adapter_path.is_file()
76
+ assert checkpoint_path.is_file()
77
+ assert tokenizer_path.is_file()
78
+
79
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
80
+
81
+ dt = getattr(torch, dtype, None)
82
+ if not isinstance(dt, torch.dtype):
83
+ raise ValueError(f"{dtype} is not a valid dtype.")
84
+ dtype = dt
85
+
86
+ print("Loading model ...", file=sys.stderr)
87
+ t0 = time.time()
88
+ with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
89
+ name = llama_model_lookup(pretrained_checkpoint)
90
+
91
+ with EmptyInitOnDevice(
92
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
93
+ ):
94
+ model = LLaMA.from_name(name)
95
+ add_adapter_v2_parameters_to_linear_layers(model)
96
+
97
+ # 1. Load the pretrained weights
98
+ model.load_state_dict(pretrained_checkpoint, strict=False)
99
+ # 2. Load the fine-tuned adapter weights
100
+ model.load_state_dict(adapter_checkpoint, strict=False)
101
+
102
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
103
+
104
+ model.eval()
105
+
106
+ # if compile:
107
+ # model = torch.compile(model)
108
+
109
+ total_toks = 0
110
+ model = fabric.setup_module(model)
111
+
112
+ tokenizer = Tokenizer(tokenizer_path)
113
+
114
+ for dsname in datasets.split(","):
115
+ test_string = load_eval_data(dsname)
116
+
117
+ sample = {"instruction": test_string, "input": input}
118
+ test_string = generate_prompt(sample)
119
+
120
+ encoded_text = tokenizer.encode(
121
+ test_string, bos=True, eos=False, device=fabric.device
122
+ )
123
+ encoded_text = encoded_text[
124
+ None, : 256 * model.config.block_size
125
+ ] # add batch dimension, trim like gptq implementation
126
+ t0 = time.perf_counter()
127
+
128
+ nlls = 0
129
+ toks = 0
130
+
131
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
132
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
133
+ inp = encoded_text[:, i : i + block_size]
134
+ logits = model(inp)[0]
135
+ nll = torch.nn.functional.cross_entropy(
136
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
137
+ )
138
+ toks += inp.size(1) - 1
139
+ nlls += nll.item()
140
+
141
+ print(encoded_text.shape, logits.shape)
142
+ ppl = math.exp(nlls / toks)
143
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
144
+ total_toks += toks
145
+
146
+ t = time.perf_counter() - t0
147
+ print(
148
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
149
+ file=sys.stderr,
150
+ )
151
+ print(
152
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
153
+ file=sys.stderr,
154
+ )
155
+
156
+
157
+ if __name__ == "__main__":
158
+ from jsonargparse import CLI
159
+
160
+ torch.set_float32_matmul_precision("high")
161
+ CLI(main)
evaluate/full.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ # support running without installing as a package
14
+ wd = Path(__file__).parent.parent.resolve()
15
+ sys.path.append(str(wd))
16
+
17
+ from lit_llama import LLaMA, Tokenizer
18
+ from lit_llama.utils import EmptyInitOnDevice
19
+
20
+ from datasets import load_dataset
21
+
22
+
23
+ def load_eval_data(dataset_name: str) -> str:
24
+ # this mimics gptq datautils
25
+ if dataset_name == "wikitext":
26
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
27
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
28
+ testdata = "\n\n".join(testdata["text"])
29
+ elif dataset_name == "ptb":
30
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
31
+ testdata = "\n\n".join(testdata["sentence"])
32
+ elif dataset_name == "c4":
33
+ testdata = load_dataset(
34
+ "allenai/c4",
35
+ "allenai--c4",
36
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
37
+ split="validation",
38
+ )
39
+ testdata = " ".join(testdata[:1100]["text"])
40
+
41
+ else:
42
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
43
+ return testdata
44
+
45
+
46
+ def main(
47
+ datasets: str = "wikitext,ptb,c4",
48
+ *,
49
+ # compilation fails as it does not support torch.complex64 for RoPE
50
+ # compile: bool = False,
51
+ accelerator: str = "auto",
52
+ checkpoint_path: Optional[Path] = None,
53
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
54
+ model_size: str = "7B",
55
+ dtype: str = "float32",
56
+ quantize: Optional[str] = None,
57
+ ) -> None:
58
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
59
+
60
+ Args:
61
+ datasets: The datasets to use as a comma separated string
62
+ # compile: Whether to compile the model.
63
+ accelerator: The hardware to run on. Possible choices are:
64
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
65
+ checkpoint_path: The checkpoint path to load.
66
+ tokenizer_path: The tokenizer path to load.
67
+ dtype: The tensor dtype for choosing the floating-point precision
68
+ quantize: Whether to quantize the model and using which method:
69
+ ``"llm.int8"``: LLM.int8() mode,
70
+ ``"gptq.int4"``: GPTQ 4-bit mode.
71
+ """
72
+ if not checkpoint_path:
73
+ checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
74
+ assert checkpoint_path.is_file()
75
+ assert tokenizer_path.is_file()
76
+
77
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
78
+
79
+ dt = getattr(torch, dtype, None)
80
+ if not isinstance(dt, torch.dtype):
81
+ raise ValueError(f"{dtype} is not a valid dtype.")
82
+ dtype = dt
83
+
84
+ with EmptyInitOnDevice(
85
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
86
+ ):
87
+ print("Loading model ...", file=sys.stderr)
88
+ t0 = time.time()
89
+ model = LLaMA.from_name(model_size)
90
+ checkpoint = torch.load(checkpoint_path)
91
+ model.load_state_dict(checkpoint)
92
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
93
+
94
+ model.eval()
95
+
96
+ # if compile:
97
+ # model = torch.compile(model)
98
+
99
+ total_toks = 0
100
+ model = fabric.setup_module(model)
101
+
102
+ tokenizer = Tokenizer(tokenizer_path)
103
+
104
+ for dsname in datasets.split(","):
105
+ test_string = load_eval_data(dsname)
106
+ encoded_text = tokenizer.encode(
107
+ test_string, bos=True, eos=False, device=fabric.device
108
+ )
109
+ encoded_text = encoded_text[
110
+ None, : 256 * model.config.block_size
111
+ ] # add batch dimension, trim like gptq implementation
112
+ t0 = time.perf_counter()
113
+
114
+ nlls = 0
115
+ toks = 0
116
+ with torch.inference_mode():
117
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
118
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
119
+ inp = encoded_text[:, i : i + block_size]
120
+ logits = model(inp)[0]
121
+ nll = torch.nn.functional.cross_entropy(
122
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
123
+ )
124
+ toks += inp.size(1) - 1
125
+ nlls += nll.item()
126
+
127
+ print(encoded_text.shape, logits.shape)
128
+ ppl = math.exp(nlls / toks)
129
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
130
+ total_toks += toks
131
+
132
+ t = time.perf_counter() - t0
133
+ print(
134
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
135
+ file=sys.stderr,
136
+ )
137
+ print(
138
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
139
+ file=sys.stderr,
140
+ )
141
+
142
+
143
+ if __name__ == "__main__":
144
+ from jsonargparse import CLI
145
+
146
+ torch.set_float32_matmul_precision("high")
147
+ CLI(main)
evaluate/lora.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
2
+ # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
3
+ import math
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import lightning as L
10
+ import torch
11
+ import tqdm
12
+
13
+ # support running without installing as a package
14
+ wd = Path(__file__).parent.parent.resolve()
15
+ sys.path.append(str(wd))
16
+
17
+ from lit_llama import LLaMA, Tokenizer
18
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
19
+ from lit_llama.lora import lora
20
+ from scripts.prepare_alpaca import generate_prompt
21
+
22
+ from datasets import load_dataset
23
+
24
+ instruction_tuning = True
25
+ lora_r = 8
26
+ lora_alpha = 16
27
+ lora_dropout = 0.05
28
+
29
+
30
+ def load_eval_data(dataset_name: str) -> str:
31
+ # this mimics gptq datautils
32
+ if dataset_name == "wikitext":
33
+ # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
34
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
35
+ testdata = "\n\n".join(testdata["text"])
36
+ elif dataset_name == "ptb":
37
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
38
+ testdata = "\n\n".join(testdata["sentence"])
39
+ elif dataset_name == "c4":
40
+ testdata = load_dataset(
41
+ "allenai/c4",
42
+ "allenai--c4",
43
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
44
+ split="validation",
45
+ )
46
+ testdata = " ".join(testdata[:1100]["text"])
47
+
48
+ else:
49
+ raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
50
+ return testdata
51
+
52
+
53
+ def main(
54
+ datasets: str = "wikitext,ptb,c4",
55
+ *,
56
+ # compilation fails as it does not support torch.complex64 for RoPE
57
+ # compile: bool = False,
58
+ accelerator: str = "auto",
59
+ lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"),
60
+ checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
61
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
62
+ dtype: str = "float32",
63
+ quantize: Optional[str] = None,
64
+ ) -> None:
65
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer
66
+ finetuned with LoRA.
67
+
68
+ Args:
69
+ datasets: The datasets to use as a comma separated string
70
+ # compile: Whether to compile the model.
71
+ accelerator: The hardware to run on. Possible choices are:
72
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
73
+ lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
74
+ `finetune_lora.py`.
75
+ checkpoint_path: The checkpoint path to load.
76
+ tokenizer_path: The tokenizer path to load.
77
+ dtype: The tensor dtype for choosing the floating-point precision
78
+ quantize: Whether to quantize the model and using which method:
79
+ ``"llm.int8"``: LLM.int8() mode,
80
+ ``"gptq.int4"``: GPTQ 4-bit mode.
81
+ """
82
+ assert lora_path.is_file()
83
+ assert checkpoint_path.is_file()
84
+ assert tokenizer_path.is_file()
85
+
86
+ if quantize is not None:
87
+ raise NotImplementedError("Quantization in LoRA is not supported yet")
88
+
89
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
90
+
91
+ dt = getattr(torch, dtype, None)
92
+ if not isinstance(dt, torch.dtype):
93
+ raise ValueError(f"{dtype} is not a valid dtype.")
94
+ dtype = dt
95
+
96
+ print("Loading model ...", file=sys.stderr)
97
+ t0 = time.time()
98
+
99
+ with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
100
+ name = llama_model_lookup(pretrained_checkpoint)
101
+
102
+ with EmptyInitOnDevice(
103
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
104
+ ), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
105
+ model = LLaMA.from_name(name)
106
+
107
+ # 1. Load the pretrained weights
108
+ model.load_state_dict(pretrained_checkpoint, strict=False)
109
+ # 2. Load the fine-tuned lora weights
110
+ model.load_state_dict(lora_checkpoint, strict=False)
111
+
112
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
113
+
114
+ model.eval()
115
+
116
+ # if compile:
117
+ # model = torch.compile(model)
118
+
119
+ total_toks = 0
120
+ model = fabric.setup_module(model)
121
+
122
+ tokenizer = Tokenizer(tokenizer_path)
123
+
124
+ for dsname in datasets.split(","):
125
+ test_string = load_eval_data(dsname)
126
+
127
+ if instruction_tuning:
128
+ sample = {"instruction": test_string, "input": input}
129
+ test_string = generate_prompt(sample)
130
+
131
+ encoded_text = tokenizer.encode(
132
+ test_string, bos=True, eos=False, device=fabric.device
133
+ )
134
+ encoded_text = encoded_text[
135
+ None, : 256 * model.config.block_size
136
+ ] # add batch dimension, trim like gptq implementation
137
+ t0 = time.perf_counter()
138
+
139
+ nlls = 0
140
+ toks = 0
141
+ with torch.inference_mode():
142
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
143
+ for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
144
+ inp = encoded_text[:, i : i + block_size]
145
+ logits = model(inp)[0]
146
+ nll = torch.nn.functional.cross_entropy(
147
+ logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
148
+ )
149
+ toks += inp.size(1) - 1
150
+ nlls += nll.item()
151
+
152
+ print(encoded_text.shape, logits.shape)
153
+ ppl = math.exp(nlls / toks)
154
+ print(f"Perplexity on {dsname}: {ppl:.2f}")
155
+ total_toks += toks
156
+
157
+ t = time.perf_counter() - t0
158
+ print(
159
+ f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
160
+ file=sys.stderr,
161
+ )
162
+ print(
163
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
164
+ file=sys.stderr,
165
+ )
166
+
167
+
168
+ if __name__ == "__main__":
169
+ from jsonargparse import CLI
170
+
171
+ torch.set_float32_matmul_precision("high")
172
+ CLI(main)
finetune/adapter.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Instruction-tuning with LLaMA-Adapter on the Alpaca dataset following the paper
3
+
4
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
5
+ https://arxiv.org/abs/2303.16199
6
+
7
+ This script runs on a single GPU by default. You can adjust the `micro_batch_size` to fit your GPU memory.
8
+ You can finetune within 1 hour as done in the original paper using DeepSpeed Zero-2 on 8 A100 GPUs by setting the
9
+ devices variable to `devices = 8` and `micro_batch_size = 8` (or higher).
10
+
11
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
12
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
13
+ """
14
+ import os
15
+ import sys
16
+ import time
17
+ from pathlib import Path
18
+ import shutil
19
+
20
+ import lightning as L
21
+ import numpy as np
22
+ import torch
23
+
24
+ # support running without installing as a package
25
+ wd = Path(__file__).parent.parent.resolve()
26
+ sys.path.append(str(wd))
27
+
28
+ from generate import generate
29
+ from lit_llama.adapter import LLaMA, LLaMAConfig, mark_only_adapter_as_trainable, adapter_state_from_state_dict
30
+ from lit_llama.tokenizer import Tokenizer
31
+ from scripts.prepare_alpaca import generate_prompt
32
+ from lightning.fabric.strategies import DeepSpeedStrategy
33
+
34
+
35
+ instruction_tuning = True
36
+ eval_interval = 600
37
+ save_interval = 1000
38
+ eval_iters = 100
39
+ log_interval = 1
40
+ devices = 1
41
+
42
+ # Hyperparameters
43
+ learning_rate = 9e-3
44
+ batch_size = 64 / devices
45
+ micro_batch_size = 4
46
+ gradient_accumulation_iters = batch_size // micro_batch_size
47
+ assert gradient_accumulation_iters > 0
48
+ epoch_size = 50000 # train dataset size
49
+ num_epochs = 5
50
+ max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
51
+ weight_decay = 0.02
52
+ max_seq_length = 256 # see scripts/prepare_alpaca.py
53
+ warmup_iters = 2 * (epoch_size // micro_batch_size) // devices # 2 epochs
54
+
55
+ ds_config = {
56
+ "train_micro_batch_size_per_gpu": micro_batch_size,
57
+ "gradient_accumulation_steps": gradient_accumulation_iters,
58
+ "zero_optimization": {"stage": 2},
59
+ }
60
+
61
+
62
+ def main(
63
+ data_dir: str = "data/alpaca",
64
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
65
+ out_dir: str = "out/adapter/alpaca",
66
+ ):
67
+
68
+ fabric = L.Fabric(
69
+ accelerator="cuda",
70
+ devices=devices,
71
+ strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
72
+ precision="bf16-true",
73
+ )
74
+ fabric.launch()
75
+ fabric.seed_everything(1337 + fabric.global_rank)
76
+
77
+ if fabric.global_rank == 0:
78
+ os.makedirs(out_dir, exist_ok=True)
79
+
80
+ train_data, val_data = load_datasets(data_dir=data_dir)
81
+
82
+ config = LLaMAConfig(block_size=max_seq_length)
83
+
84
+ if not os.path.isfile(pretrained_path):
85
+ raise FileNotFoundError(
86
+ f"Can't find the pretrained weights at {pretrained_path}."
87
+ " Please follow the instructions in the README to download them."
88
+ )
89
+ checkpoint = torch.load(pretrained_path)
90
+
91
+ with fabric.init_module():
92
+ model = LLaMA(config)
93
+ # strict=False because missing keys due to adapter weights not containted in state dict
94
+ model.load_state_dict(checkpoint, strict=False)
95
+
96
+ mark_only_adapter_as_trainable(model)
97
+
98
+ num_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
99
+ print(f"Number of trainable parameters: {num_params}")
100
+
101
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
102
+ model, optimizer = fabric.setup(model, optimizer)
103
+ train(fabric, model, optimizer, train_data, val_data, out_dir)
104
+
105
+ # Save the final checkpoint at the end of training
106
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-adapter-finetuned.pth"))
107
+
108
+
109
+ def train(
110
+ fabric: L.Fabric,
111
+ model: torch.nn.Module,
112
+ optimizer: torch.optim.Optimizer,
113
+ train_data: np.ndarray,
114
+ val_data: np.ndarray,
115
+ out_dir: str,
116
+ ) -> None:
117
+ """The training loop.
118
+
119
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
120
+ """
121
+ step_count = 0
122
+
123
+ for iter_num in range(max_iters):
124
+
125
+ if step_count <= warmup_iters:
126
+ # linear warmup
127
+ lr = learning_rate * step_count / warmup_iters
128
+ for param_group in optimizer.param_groups:
129
+ param_group['lr'] = lr
130
+
131
+ t0 = time.time()
132
+
133
+ input_ids, targets = get_batch(fabric, train_data)
134
+ with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
135
+ logits = model(input_ids)
136
+ loss = loss_fn(logits, targets)
137
+ fabric.backward(loss / gradient_accumulation_iters)
138
+
139
+ if (iter_num + 1) % gradient_accumulation_iters == 0:
140
+ optimizer.step()
141
+ optimizer.zero_grad()
142
+ step_count += 1
143
+
144
+ if step_count % eval_interval == 0:
145
+ val_loss = validate(fabric, model, val_data)
146
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
147
+ fabric.barrier()
148
+
149
+ if step_count % save_interval == 0:
150
+ print(f"Saving adapter weights to {out_dir}")
151
+ # TODO: Provide a function/script to merge the adapter weights with pretrained weights
152
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))
153
+
154
+ dt = time.time() - t0
155
+ if iter_num % log_interval == 0:
156
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
157
+
158
+
159
+ def generate_response(model, instruction, input=""):
160
+ tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
161
+ sample = {"instruction": instruction, "input": input}
162
+ prompt = instruction
163
+ if instruction_tuning:
164
+ prompt = generate_prompt(sample)
165
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
166
+
167
+ output = generate(
168
+ model,
169
+ idx=encoded,
170
+ max_seq_length=max_seq_length,
171
+ max_new_tokens=100,
172
+ temperature=0.8,
173
+ )
174
+ output = tokenizer.decode(output)
175
+ return output # output.split("### Response:")[1].strip()
176
+
177
+
178
+ @torch.no_grad()
179
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
180
+ fabric.print("Validating ...")
181
+ model.eval()
182
+ losses = torch.zeros(eval_iters)
183
+ for k in range(eval_iters):
184
+ input_ids, targets = get_batch(fabric, val_data)
185
+ logits = model(input_ids)
186
+ loss = loss_fn(logits, targets)
187
+ losses[k] = loss.item()
188
+ val_loss = losses.mean()
189
+
190
+ # produce an example:
191
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
192
+ output = generate_response(model, instruction)
193
+ fabric.print(instruction)
194
+ fabric.print(output)
195
+
196
+ model.train()
197
+ return val_loss.item()
198
+
199
+ def loss_fn(logits, targets):
200
+ # shift the targets such that output n predicts token n+1
201
+ logits = logits[..., :-1, :].contiguous()
202
+ targets = targets[..., 1:].contiguous()
203
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
204
+ return loss
205
+
206
+
207
+ def get_batch(fabric: L.Fabric, data: list):
208
+ ix = torch.randint(len(data), (micro_batch_size,))
209
+
210
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
211
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
212
+
213
+ max_len = max(len(s) for s in input_ids)
214
+
215
+ def pad_right(x, pad_id):
216
+ # pad right based on the longest sequence
217
+ n = max_len - len(x)
218
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
219
+
220
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
221
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
222
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
223
+ return x, y
224
+
225
+
226
+ def load_datasets(data_dir):
227
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
228
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
229
+ return train_data, val_data
230
+
231
+
232
+ def save_model_checkpoint(fabric, model, file_path):
233
+ file_path = Path(file_path)
234
+
235
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
236
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
237
+
238
+ tmp_path = file_path.with_suffix(".tmp")
239
+ fabric.save(tmp_path, {"model": model})
240
+ fabric.barrier()
241
+ if fabric.global_rank == 0:
242
+ # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
243
+ # and only keep the adapter weights
244
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
245
+ state_dict = adapter_state_from_state_dict(state_dict)
246
+ torch.save(state_dict, file_path)
247
+ shutil.rmtree(tmp_path)
248
+ else:
249
+ state_dict = adapter_state_from_state_dict(model.state_dict())
250
+ if fabric.global_rank == 0:
251
+ torch.save(state_dict, file_path)
252
+ fabric.barrier()
253
+
254
+
255
+ if __name__ == "__main__":
256
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
257
+ # torch.backends.cuda.enable_flash_sdp(False)
258
+ torch.set_float32_matmul_precision("high")
259
+
260
+ from jsonargparse.cli import CLI
261
+
262
+ CLI(main)
finetune/adapter_v2.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Instruction-tuning with LLaMA-Adapter v2 on the Alpaca dataset following the paper
3
+
4
+ LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
5
+ https://arxiv.org/abs/2304.15010
6
+
7
+ This script runs on a single GPU by default. You can adjust the `micro_batch_size` to fit your GPU memory.
8
+ You can finetune within 1 hour as done in the original paper using DeepSpeed Zero-2 on 8 A100 GPUs by setting the
9
+ devices variable to `devices = 8` and `micro_batch_size = 8` (or higher).
10
+
11
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
12
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
13
+ """
14
+ import os
15
+ import sys
16
+ import time
17
+ from pathlib import Path
18
+ import shutil
19
+
20
+ import lightning as L
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ # support running without installing as a package
26
+ wd = Path(__file__).parent.parent.resolve()
27
+ sys.path.append(str(wd))
28
+
29
+ from generate import generate
30
+ from lit_llama.adapter import LLaMA, LLaMAConfig
31
+ from lit_llama.adapter_v2 import (
32
+ mark_only_adapter_v2_as_trainable,
33
+ add_adapter_v2_parameters_to_linear_layers,
34
+ adapter_v2_state_from_state_dict
35
+ )
36
+ from lit_llama.tokenizer import Tokenizer
37
+ from scripts.prepare_alpaca import generate_prompt
38
+ from lightning.fabric.strategies import DeepSpeedStrategy
39
+
40
+
41
+ eval_interval = 600
42
+ save_interval = 1000
43
+ eval_iters = 100
44
+ log_interval = 1
45
+ devices = 1
46
+
47
+ # Hyperparameters
48
+ learning_rate = 9e-3
49
+ batch_size = 64 / devices
50
+ micro_batch_size = 4
51
+ gradient_accumulation_iters = batch_size // micro_batch_size
52
+ assert gradient_accumulation_iters > 0
53
+ epoch_size = 50000 # train dataset size
54
+ num_epochs = 5
55
+ max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
56
+ weight_decay = 0.02
57
+ max_seq_length = 256 # see scripts/prepare_alpaca.py
58
+ warmup_iters = 2 * (epoch_size // micro_batch_size) // devices # 2 epoch
59
+
60
+ ds_config = {
61
+ "train_micro_batch_size_per_gpu": micro_batch_size,
62
+ "gradient_accumulation_steps": gradient_accumulation_iters,
63
+ "zero_optimization": {"stage": 2},
64
+ }
65
+
66
+
67
+ def main(
68
+ data_dir: str = "data/alpaca",
69
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
70
+ out_dir: str = "out/adapter_v2/alpaca",
71
+ ):
72
+
73
+ fabric = L.Fabric(
74
+ accelerator="cuda",
75
+ devices=1,
76
+ strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
77
+ precision="bf16-true",
78
+ )
79
+ fabric.launch()
80
+ fabric.seed_everything(1337 + fabric.global_rank)
81
+
82
+ if fabric.global_rank == 0:
83
+ os.makedirs(out_dir, exist_ok=True)
84
+
85
+ train_data, val_data = load_datasets(data_dir=data_dir)
86
+
87
+ config = LLaMAConfig(block_size=max_seq_length)
88
+
89
+ if not os.path.isfile(pretrained_path):
90
+ raise FileNotFoundError(
91
+ f"Can't find the pretrained weights at {pretrained_path}."
92
+ " Please follow the instructions in the README to download them."
93
+ )
94
+ checkpoint = torch.load(pretrained_path)
95
+
96
+ with fabric.init_module():
97
+ model = LLaMA(config)
98
+ # strict=False because missing keys due to adapter weights not contained in state dict
99
+ model.load_state_dict(checkpoint, strict=False)
100
+
101
+ add_adapter_v2_parameters_to_linear_layers(model)
102
+ mark_only_adapter_v2_as_trainable(model)
103
+
104
+ num_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
105
+ print(f"Number of trainable parameters: {num_params}")
106
+
107
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
108
+ model, optimizer = fabric.setup(model, optimizer)
109
+ train(fabric, model, optimizer, train_data, val_data, out_dir)
110
+
111
+ # Save the final checkpoint at the end of training
112
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-adapter-finetuned.pth"))
113
+
114
+
115
+ def train(
116
+ fabric: L.Fabric,
117
+ model: torch.nn.Module,
118
+ optimizer: torch.optim.Optimizer,
119
+ train_data: np.ndarray,
120
+ val_data: np.ndarray,
121
+ out_dir: str,
122
+ ) -> None:
123
+ """The training loop.
124
+
125
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
126
+ """
127
+ step_count = 0
128
+
129
+ for iter_num in range(max_iters):
130
+
131
+ if step_count <= warmup_iters:
132
+ # linear warmup
133
+ lr = learning_rate * step_count / warmup_iters
134
+ for param_group in optimizer.param_groups:
135
+ param_group['lr'] = lr
136
+
137
+ t0 = time.time()
138
+
139
+ input_ids, targets = get_batch(fabric, train_data)
140
+ with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
141
+ logits = model(input_ids)
142
+ loss = loss_fn(logits, targets)
143
+ fabric.backward(loss / gradient_accumulation_iters)
144
+
145
+ if (iter_num + 1) % gradient_accumulation_iters == 0:
146
+ optimizer.step()
147
+ optimizer.zero_grad()
148
+ step_count += 1
149
+
150
+ if step_count % eval_interval == 0:
151
+ val_loss = validate(fabric, model, val_data)
152
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
153
+ fabric.barrier()
154
+
155
+ if step_count % save_interval == 0:
156
+ print(f"Saving adapter weights to {out_dir}")
157
+ # TODO: Provide a function/script to merge the adapter weights with pretrained weights
158
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))
159
+
160
+ dt = time.time() - t0
161
+ if iter_num % log_interval == 0:
162
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
163
+
164
+
165
+ def generate_response(model, instruction, input=""):
166
+ tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
167
+ sample = {"instruction": instruction, "input": input}
168
+ prompt = generate_prompt(sample)
169
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
170
+
171
+ output = generate(
172
+ model,
173
+ idx=encoded,
174
+ max_seq_length=max_seq_length,
175
+ max_new_tokens=100,
176
+ temperature=0.8,
177
+ )
178
+ output = tokenizer.decode(output)
179
+ return output # output.split("### Response:")[1].strip()
180
+
181
+
182
+ @torch.no_grad()
183
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
184
+ fabric.print("Validating ...")
185
+ model.eval()
186
+ losses = torch.zeros(eval_iters)
187
+ for k in range(eval_iters):
188
+ input_ids, targets = get_batch(fabric, val_data)
189
+ logits = model(input_ids)
190
+ loss = loss_fn(logits, targets)
191
+ losses[k] = loss.item()
192
+ val_loss = losses.mean()
193
+
194
+ # produce an example:
195
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
196
+ output = generate_response(model, instruction)
197
+ fabric.print(instruction)
198
+ fabric.print(output)
199
+
200
+ model.train()
201
+ return val_loss.item()
202
+
203
+ def loss_fn(logits, targets):
204
+ # shift the targets such that output n predicts token n+1
205
+ logits = logits[..., :-1, :].contiguous()
206
+ targets = targets[..., 1:].contiguous()
207
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
208
+ return loss
209
+
210
+
211
+ def get_batch(fabric: L.Fabric, data: list):
212
+ ix = torch.randint(len(data), (micro_batch_size,))
213
+
214
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
215
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
216
+
217
+ max_len = max(len(s) for s in input_ids)
218
+
219
+ def pad_right(x, pad_id):
220
+ # pad right based on the longest sequence
221
+ n = max_len - len(x)
222
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
223
+
224
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
225
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
226
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
227
+ return x, y
228
+
229
+
230
+ def load_datasets(data_dir):
231
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
232
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
233
+ return train_data, val_data
234
+
235
+
236
+ def save_model_checkpoint(fabric, model, file_path):
237
+ file_path = Path(file_path)
238
+
239
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
240
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
241
+
242
+ tmp_path = file_path.with_suffix(".tmp")
243
+ fabric.save(tmp_path, {"model": model})
244
+ fabric.barrier()
245
+ if fabric.global_rank == 0:
246
+ # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
247
+ # and only keep the adapter weights
248
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
249
+ state_dict = adapter_v2_state_from_state_dict(state_dict)
250
+ torch.save(state_dict, file_path)
251
+ shutil.rmtree(tmp_path)
252
+ else:
253
+ state_dict = adapter_v2_state_from_state_dict(model.state_dict())
254
+ if fabric.global_rank == 0:
255
+ torch.save(state_dict, file_path)
256
+ fabric.barrier()
257
+
258
+
259
+ if __name__ == "__main__":
260
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
261
+ # torch.backends.cuda.enable_flash_sdp(False)
262
+ torch.set_float32_matmul_precision("high")
263
+
264
+ from jsonargparse.cli import CLI
265
+
266
+ CLI(main)
finetune/full.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Instruction-tuning on the Alpaca dataset using a regular finetuning procedure (updating all layers).
3
+
4
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
5
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
6
+ """
7
+ import sys
8
+ from pathlib import Path
9
+ import os
10
+ import time
11
+ from functools import partial
12
+
13
+ import lightning as L
14
+ from lightning.fabric.strategies import FSDPStrategy
15
+ import numpy as np
16
+ import torch
17
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
18
+
19
+ # support running without installing as a package
20
+ wd = Path(__file__).parent.parent.resolve()
21
+ sys.path.append(str(wd))
22
+
23
+ from generate import generate
24
+ from lit_llama.model import Block, LLaMA, LLaMAConfig
25
+ from lit_llama.tokenizer import Tokenizer
26
+ from lit_llama.utils import save_model_checkpoint
27
+ from scripts.prepare_alpaca import generate_prompt
28
+
29
+
30
+ instruction_tuning = True
31
+ eval_interval = 1000
32
+ save_interval = 1000
33
+ eval_iters = 100
34
+ log_interval = 100
35
+ devices = 4
36
+
37
+ # Hyperparameters
38
+ learning_rate = 3e-5
39
+ batch_size = 128 / devices
40
+ micro_batch_size = 4
41
+ gradient_accumulation_iters = batch_size // micro_batch_size
42
+ assert gradient_accumulation_iters > 0
43
+ epoch_size = 50000 # train dataset size
44
+ num_epochs = 5
45
+ max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
46
+ weight_decay = 0.0
47
+ block_size = 512
48
+ warmup_iters = 100
49
+
50
+
51
+ def main(
52
+ data_dir: str = "data/alpaca",
53
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
54
+ out_dir: str = "out/full/alpaca",
55
+ ):
56
+
57
+ auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
58
+ strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
59
+
60
+ fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
61
+ fabric.launch()
62
+ fabric.seed_everything(1337 + fabric.global_rank)
63
+
64
+ if fabric.global_rank == 0:
65
+ os.makedirs(out_dir, exist_ok=True)
66
+
67
+ train_data, val_data = load_datasets(data_dir=data_dir)
68
+
69
+ config = LLaMAConfig.from_name("7B")
70
+ config.block_size = block_size
71
+
72
+ checkpoint = torch.load(pretrained_path)
73
+
74
+ with fabric.device:
75
+ torch.set_default_tensor_type(torch.HalfTensor)
76
+ model = LLaMA(config).bfloat16()
77
+ torch.set_default_tensor_type(torch.FloatTensor)
78
+ model.load_state_dict(checkpoint, strict=False)
79
+
80
+ model = fabric.setup_module(model)
81
+
82
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
83
+ optimizer = fabric.setup_optimizers(optimizer)
84
+
85
+ train(fabric, model, optimizer, train_data, val_data, out_dir)
86
+
87
+ # Save the final checkpoint at the end of training
88
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
89
+
90
+
91
+ def train(
92
+ fabric: L.Fabric,
93
+ model: torch.nn.Module,
94
+ optimizer: torch.optim.Optimizer,
95
+ train_data: np.ndarray,
96
+ val_data: np.ndarray,
97
+ out_dir: str,
98
+ ) -> None:
99
+ """The training loop.
100
+
101
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
102
+ """
103
+ step_count = 0
104
+ model.train()
105
+
106
+ for iter_num in range(max_iters):
107
+
108
+ is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
109
+
110
+ if step_count <= warmup_iters:
111
+ # linear warmup
112
+ lr = learning_rate * step_count / warmup_iters
113
+ for param_group in optimizer.param_groups:
114
+ param_group['lr'] = lr
115
+
116
+ t0 = time.time()
117
+
118
+ input_ids, targets = get_batch(fabric, train_data)
119
+ with fabric.no_backward_sync(model, enabled=is_accumulating):
120
+ logits = model(input_ids)
121
+ loss = loss_fn(logits, targets)
122
+ fabric.backward(loss / gradient_accumulation_iters)
123
+
124
+ if not is_accumulating:
125
+ optimizer.step()
126
+ optimizer.zero_grad()
127
+ step_count += 1
128
+
129
+ if step_count % eval_interval == 0:
130
+ val_loss = validate(fabric, model, val_data)
131
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
132
+ fabric.barrier()
133
+
134
+ if step_count % save_interval == 0:
135
+ print(f"Saving weights to {out_dir}")
136
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
137
+
138
+ dt = time.time() - t0
139
+ if iter_num % log_interval == 0:
140
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
141
+
142
+
143
+ def generate_response(model, instruction):
144
+ tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
145
+ sample = {"instruction": instruction, "input": ""}
146
+ prompt = instruction
147
+ if instruction_tuning:
148
+ prompt = generate_prompt(sample)
149
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
150
+
151
+ output = generate(
152
+ model,
153
+ idx=encoded,
154
+ max_seq_length=block_size,
155
+ max_new_tokens=100,
156
+ )
157
+ output = tokenizer.decode(output)
158
+ return output # output.split("### Response:")[1].strip()
159
+
160
+
161
+ @torch.no_grad()
162
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
163
+ fabric.print("Validating ...")
164
+ model.eval()
165
+ losses = torch.zeros(eval_iters)
166
+ for k in range(eval_iters):
167
+ input_ids, targets = get_batch(fabric, val_data)
168
+ logits = model(input_ids)
169
+ loss = loss_fn(logits, targets)
170
+ losses[k] = loss.item()
171
+ out = losses.mean()
172
+
173
+ # produce an example:
174
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
175
+
176
+ output = generate_response(model, instruction)
177
+ fabric.print(instruction)
178
+ fabric.print(output)
179
+
180
+ model.train()
181
+ return out.item()
182
+
183
+
184
+ def loss_fn(logits, targets):
185
+ # shift the targets such that output n predicts token n+1
186
+ logits = logits[..., :-1, :].contiguous()
187
+ targets = targets[..., 1:].contiguous()
188
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
189
+ return loss
190
+
191
+
192
+ def get_batch(fabric: L.Fabric, data: list):
193
+ ix = torch.randint(len(data), (micro_batch_size,))
194
+
195
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
196
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
197
+
198
+ max_len = max(len(s) for s in input_ids)
199
+
200
+ def pad_right(x, pad_id):
201
+ # pad right based on the longest sequence
202
+ n = max_len - len(x)
203
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
204
+
205
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
206
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
207
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
208
+ return x, y
209
+
210
+
211
+ def load_datasets(data_dir):
212
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
213
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
214
+ return train_data, val_data
215
+
216
+
217
+ if __name__ == "__main__":
218
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
219
+ # torch.backends.cuda.enable_flash_sdp(False)
220
+ torch.set_float32_matmul_precision("high")
221
+
222
+ from jsonargparse.cli import CLI
223
+
224
+ CLI(main)
finetune/lora.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Instruction-tuning with LoRA on the Alpaca dataset.
3
+
4
+ Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
5
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
6
+ """
7
+ import sys
8
+ from pathlib import Path
9
+ import os
10
+ import time
11
+
12
+ import lightning as L
13
+ import numpy as np
14
+ import torch
15
+
16
+ # support running without installing as a package
17
+ wd = Path(__file__).parent.parent.resolve()
18
+ sys.path.append(str(wd))
19
+
20
+ from generate import generate
21
+ from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
22
+ from lit_llama.model import LLaMA, LLaMAConfig
23
+ from lit_llama.tokenizer import Tokenizer
24
+ from scripts.prepare_alpaca import generate_prompt
25
+
26
+
27
+ instruction_tuning = True
28
+ eval_interval = 100
29
+ save_interval = 100
30
+ eval_iters = 100
31
+ log_interval = 1
32
+
33
+ # Hyperparameters
34
+ learning_rate = 3e-4
35
+ batch_size = 128
36
+ micro_batch_size = 4
37
+ gradient_accumulation_iters = batch_size // micro_batch_size
38
+ assert gradient_accumulation_iters > 0
39
+ max_iters = 50000 * 3 // micro_batch_size
40
+ weight_decay = 0.0
41
+ max_seq_length = 256 # see scripts/prepare_alpaca.py
42
+ lora_r = 8
43
+ lora_alpha = 16
44
+ lora_dropout = 0.05
45
+ warmup_iters = 100
46
+
47
+
48
+ def main(
49
+ data_dir: str = "data/alpaca",
50
+ pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
51
+ tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model",
52
+ out_dir: str = "out/lora/alpaca",
53
+ ):
54
+
55
+ fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
56
+ fabric.launch()
57
+ fabric.seed_everything(1337 + fabric.global_rank)
58
+
59
+ if fabric.global_rank == 0:
60
+ os.makedirs(out_dir, exist_ok=True)
61
+
62
+ train_data, val_data = load_datasets(data_dir=data_dir)
63
+
64
+ config = LLaMAConfig.from_name("7B")
65
+ config.block_size = max_seq_length
66
+
67
+ checkpoint = torch.load(pretrained_path)
68
+
69
+ with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
70
+ model = LLaMA(config)
71
+ # strict=False because missing keys due to LoRA weights not contained in checkpoint state
72
+ model.load_state_dict(checkpoint, strict=False)
73
+
74
+ mark_only_lora_as_trainable(model)
75
+
76
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
77
+ model, optimizer = fabric.setup(model, optimizer)
78
+ train(fabric, model, optimizer, train_data, val_data, tokenizer_path, out_dir)
79
+
80
+ # Save the final LoRA checkpoint at the end of training
81
+ checkpoint = lora_state_dict(model)
82
+ fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
83
+
84
+
85
+ def train(
86
+ fabric: L.Fabric,
87
+ model: torch.nn.Module,
88
+ optimizer: torch.optim.Optimizer,
89
+ train_data: np.ndarray,
90
+ val_data: np.ndarray,
91
+ tokenizer_path: str,
92
+ out_dir: str,
93
+ ) -> None:
94
+ """The training loop.
95
+
96
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
97
+ """
98
+ step_count = 0
99
+
100
+ for iter_num in range(max_iters):
101
+
102
+ if step_count <= warmup_iters:
103
+ # linear warmup
104
+ lr = learning_rate * step_count / warmup_iters
105
+ for param_group in optimizer.param_groups:
106
+ param_group['lr'] = lr
107
+
108
+ t0 = time.time()
109
+
110
+ input_ids, targets = get_batch(fabric, train_data)
111
+ with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
112
+ logits = model(input_ids)
113
+ loss = loss_fn(logits, targets)
114
+ fabric.backward(loss / gradient_accumulation_iters)
115
+
116
+ if (iter_num + 1) % gradient_accumulation_iters == 0:
117
+ optimizer.step()
118
+ optimizer.zero_grad()
119
+ step_count += 1
120
+
121
+ if step_count % eval_interval == 0:
122
+ val_loss = validate(fabric, model, val_data, tokenizer_path)
123
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
124
+ fabric.barrier()
125
+
126
+ if step_count % save_interval == 0:
127
+ print(f"Saving LoRA weights to {out_dir}")
128
+ # We are only saving the LoRA weights
129
+ # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
130
+ checkpoint = lora_state_dict(model)
131
+ fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)
132
+
133
+ dt = time.time() - t0
134
+ if iter_num % log_interval == 0:
135
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
136
+
137
+
138
+ def generate_response(model, instruction, tokenizer_path):
139
+ tokenizer = Tokenizer(tokenizer_path)
140
+ sample = {"instruction": instruction, "input": ""}
141
+ prompt = instruction
142
+ if instruction_tuning:
143
+ prompt = generate_prompt(sample)
144
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
145
+
146
+ output = generate(
147
+ model,
148
+ idx=encoded,
149
+ max_seq_length=max_seq_length,
150
+ max_new_tokens=100,
151
+ )
152
+ output = tokenizer.decode(output)
153
+ return output # output.split("### Response:")[1].strip()
154
+
155
+
156
+ @torch.no_grad()
157
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray, tokenizer_path: str) -> torch.Tensor:
158
+ fabric.print("Validating ...")
159
+ model.eval()
160
+ losses = torch.zeros(eval_iters)
161
+ for k in range(eval_iters):
162
+ input_ids, targets = get_batch(fabric, val_data)
163
+ logits = model(input_ids)
164
+ loss = loss_fn(logits, targets)
165
+ losses[k] = loss.item()
166
+ out = losses.mean()
167
+
168
+ # produce an example:
169
+ instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
170
+
171
+ output = generate_response(model, instruction, tokenizer_path)
172
+ fabric.print(instruction)
173
+ fabric.print(output)
174
+
175
+ model.train()
176
+ return out.item()
177
+
178
+ def loss_fn(logits, targets):
179
+ # shift the targets such that output n predicts token n+1
180
+ logits = logits[..., :-1, :].contiguous()
181
+ targets = targets[..., 1:].contiguous()
182
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
183
+ return loss
184
+
185
+
186
+ def get_batch(fabric: L.Fabric, data: list):
187
+ ix = torch.randint(len(data), (micro_batch_size,))
188
+
189
+ input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
190
+ labels = [data[i]["labels"].type(torch.int64) for i in ix]
191
+
192
+ max_len = max(len(s) for s in input_ids)
193
+
194
+ def pad_right(x, pad_id):
195
+ # pad right based on the longest sequence
196
+ n = max_len - len(x)
197
+ return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
198
+
199
+ x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
200
+ y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
201
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
202
+ return x, y
203
+
204
+
205
+ def load_datasets(data_dir):
206
+ train_data = torch.load(os.path.join(data_dir, "train.pt"))
207
+ val_data = torch.load(os.path.join(data_dir, "test.pt"))
208
+ return train_data, val_data
209
+
210
+
211
+ if __name__ == "__main__":
212
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
213
+ # torch.backends.cuda.enable_flash_sdp(False)
214
+ torch.set_float32_matmul_precision("high")
215
+
216
+ from jsonargparse.cli import CLI
217
+
218
+ CLI(main)
generate.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ # support running without installing as a package
11
+ wd = Path(__file__).parent.parent.resolve()
12
+ sys.path.append(str(wd))
13
+
14
+ from lit_llama import LLaMA, Tokenizer
15
+ from lit_llama.utils import lazy_load, llama_model_lookup, quantization
16
+
17
+
18
+ @torch.no_grad()
19
+ def generate(
20
+ model: LLaMA,
21
+ idx: torch.Tensor,
22
+ max_new_tokens: int,
23
+ *,
24
+ max_seq_length: Optional[int] = None,
25
+ temperature: float = 1.0,
26
+ top_k: Optional[int] = None,
27
+ eos_id: Optional[int] = None,
28
+ ) -> torch.Tensor:
29
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
30
+
31
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
32
+
33
+ Args:
34
+ model: The model to use.
35
+ idx: Tensor of shape (T) with indices of the prompt sequence.
36
+ max_new_tokens: The number of new tokens to generate.
37
+ max_seq_length: The maximum sequence length allowed.
38
+ temperature: Scales the predicted logits by 1 / temperature
39
+ top_k: If specified, only sample among the tokens with the k highest probabilities
40
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered
41
+ """
42
+ # create an empty tensor of the expected final shape and fill in the current tokens
43
+ T = idx.size(0)
44
+ T_new = T + max_new_tokens
45
+ if max_seq_length is None:
46
+ max_seq_length = min(T_new, model.config.block_size)
47
+
48
+ device, dtype = idx.device, idx.dtype
49
+ # create an empty tensor of the expected final shape and fill in the current tokens
50
+ empty = torch.empty(T_new, dtype=dtype, device=device)
51
+ empty[:T] = idx
52
+ idx = empty
53
+ input_pos = torch.arange(0, T, device=device)
54
+
55
+ if idx.device.type == "xla":
56
+ import torch_xla.core.xla_model as xm
57
+
58
+ xm.mark_step()
59
+
60
+ # generate max_new_tokens tokens
61
+ for _ in range(max_new_tokens):
62
+ x = idx.index_select(0, input_pos).view(1, -1)
63
+
64
+ # forward
65
+ logits = model(x, max_seq_length, input_pos)
66
+ logits = logits[0, -1] / temperature
67
+
68
+ # optionally crop the logits to only the top k options
69
+ if top_k is not None:
70
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
71
+ logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
72
+
73
+ probs = torch.nn.functional.softmax(logits, dim=-1)
74
+ idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
75
+
76
+ # advance
77
+ input_pos = input_pos[-1:] + 1
78
+
79
+ if idx.device.type == "xla":
80
+ xm.mark_step()
81
+
82
+ # concatenate the new generation
83
+ idx = idx.index_copy(0, input_pos, idx_next)
84
+
85
+ # if <eos> token is triggered, return the output (stop generation)
86
+ if idx_next == eos_id:
87
+ return idx[:input_pos] # include the EOS token
88
+
89
+ return idx
90
+
91
+
92
+ def main(
93
+ prompt: str = "Hello, my name is",
94
+ *,
95
+ num_samples: int = 1,
96
+ max_new_tokens: int = 50,
97
+ top_k: int = 200,
98
+ temperature: float = 0.8,
99
+ checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
100
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
101
+ quantize: Optional[str] = None,
102
+ ) -> None:
103
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
104
+
105
+ Args:
106
+ prompt: The prompt string to use for generating the samples.
107
+ num_samples: The number of text samples to generate.
108
+ max_new_tokens: The number of generation steps to take.
109
+ top_k: The number of top most probable tokens to consider in the sampling process.
110
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
111
+ samples.
112
+ checkpoint_path: The checkpoint path to load.
113
+ tokenizer_path: The tokenizer path to load.
114
+ quantize: Whether to quantize the model and using which method:
115
+ ``"llm.int8"``: LLM.int8() mode,
116
+ ``"gptq.int4"``: GPTQ 4-bit mode.
117
+ """
118
+ assert checkpoint_path.is_file(), checkpoint_path
119
+ assert tokenizer_path.is_file(), tokenizer_path
120
+
121
+ precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
122
+ fabric = L.Fabric(devices=1, precision=precision)
123
+
124
+ print("Loading model ...", file=sys.stderr)
125
+ t0 = time.time()
126
+ with lazy_load(checkpoint_path) as checkpoint:
127
+ name = llama_model_lookup(checkpoint)
128
+
129
+ with fabric.init_module(empty_init=True), quantization(mode=quantize):
130
+ model = LLaMA.from_name(name)
131
+
132
+ model.load_state_dict(checkpoint)
133
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
134
+
135
+ model.eval()
136
+ model = fabric.setup(model)
137
+
138
+ tokenizer = Tokenizer(tokenizer_path)
139
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
140
+ prompt_length = encoded.size(0)
141
+
142
+ L.seed_everything(1234)
143
+ for i in range(num_samples):
144
+ t0 = time.perf_counter()
145
+ y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
146
+ t = time.perf_counter() - t0
147
+
148
+ model.reset_cache()
149
+ print(tokenizer.decode(y))
150
+ tokens_generated = y.size(0) - prompt_length
151
+ print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
152
+ if fabric.device.type == "cuda":
153
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
154
+
155
+
156
+ if __name__ == "__main__":
157
+ from jsonargparse import CLI
158
+
159
+ torch.set_float32_matmul_precision("high")
160
+ warnings.filterwarnings(
161
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
162
+ "ignore",
163
+ message="ComplexHalf support is experimental and many operators don't support it yet"
164
+ )
165
+ warnings.filterwarnings(
166
+ # Triggered in bitsandbytes/autograd/_functions.py:298
167
+ "ignore",
168
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
169
+ )
170
+ CLI(main)
generate/adapter.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ # support running without installing as a package
11
+ wd = Path(__file__).parent.parent.resolve()
12
+ sys.path.append(str(wd))
13
+
14
+ from generate import generate
15
+ from lit_llama import Tokenizer
16
+ from lit_llama.adapter import LLaMA
17
+ from lit_llama.utils import lazy_load, llama_model_lookup, quantization
18
+ from scripts.prepare_alpaca import generate_prompt
19
+
20
+
21
+ def main(
22
+ prompt: str = "What food do lamas eat?",
23
+ input: str = "",
24
+ adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"),
25
+ pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
26
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
27
+ quantize: Optional[str] = None,
28
+ max_new_tokens: int = 100,
29
+ top_k: int = 200,
30
+ temperature: float = 0.8,
31
+ ) -> None:
32
+ """Generates a response based on a given instruction and an optional input.
33
+ This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
34
+ See `finetune_adapter.py`.
35
+
36
+ Args:
37
+ prompt: The prompt/instruction (Alpaca style).
38
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
39
+ `finetune_adapter.py`.
40
+ input: Optional input (Alpaca style).
41
+ pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
42
+ tokenizer_path: The tokenizer path to load.
43
+ quantize: Whether to quantize the model and using which method:
44
+ ``"llm.int8"``: LLM.int8() mode,
45
+ ``"gptq.int4"``: GPTQ 4-bit mode.
46
+ max_new_tokens: The number of generation steps to take.
47
+ top_k: The number of top most probable tokens to consider in the sampling process.
48
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
49
+ samples.
50
+ """
51
+ assert adapter_path.is_file()
52
+ assert pretrained_path.is_file()
53
+ assert tokenizer_path.is_file()
54
+
55
+ precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
56
+ fabric = L.Fabric(devices=1, precision=precision)
57
+
58
+ print("Loading model ...", file=sys.stderr)
59
+ t0 = time.time()
60
+ with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
61
+ name = llama_model_lookup(pretrained_checkpoint)
62
+
63
+ with fabric.init_module(empty_init=True), quantization(mode=quantize):
64
+ model = LLaMA.from_name(name)
65
+
66
+ # 1. Load the pretrained weights
67
+ model.load_state_dict(pretrained_checkpoint, strict=False)
68
+ # 2. Load the fine-tuned adapter weights
69
+ model.load_state_dict(adapter_checkpoint, strict=False)
70
+
71
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
72
+
73
+ model.eval()
74
+ model = fabric.setup(model)
75
+
76
+ tokenizer = Tokenizer(tokenizer_path)
77
+ sample = {"instruction": prompt, "input": input}
78
+ prompt = generate_prompt(sample)
79
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
80
+ prompt_length = encoded.size(0)
81
+
82
+ t0 = time.perf_counter()
83
+ y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
84
+ t = time.perf_counter() - t0
85
+
86
+ model.reset_cache()
87
+ output = tokenizer.decode(y)
88
+ output = output.split("### Response:")[1].strip()
89
+ print(output)
90
+
91
+ tokens_generated = y.size(0) - prompt_length
92
+ print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
93
+ if fabric.device.type == "cuda":
94
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ from jsonargparse import CLI
99
+
100
+ torch.set_float32_matmul_precision("high")
101
+ warnings.filterwarnings(
102
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
103
+ "ignore",
104
+ message="ComplexHalf support is experimental and many operators don't support it yet"
105
+ )
106
+ CLI(main)
generate/adapter_v2.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ # support running without installing as a package
11
+ wd = Path(__file__).parent.parent.resolve()
12
+ sys.path.append(str(wd))
13
+
14
+ from generate import generate
15
+ from lit_llama import Tokenizer
16
+ from lit_llama.adapter import LLaMA
17
+ from lit_llama.utils import lazy_load, llama_model_lookup, quantization
18
+ from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
19
+ from scripts.prepare_alpaca import generate_prompt
20
+
21
+
22
+ def main(
23
+ prompt: str = "What food do lamas eat?",
24
+ input: str = "",
25
+ adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"),
26
+ pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
27
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
28
+ quantize: Optional[str] = None,
29
+ max_new_tokens: int = 100,
30
+ top_k: int = 200,
31
+ temperature: float = 0.8,
32
+ ) -> None:
33
+ """Generates a response based on a given instruction and an optional input.
34
+ This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
35
+ See `finetune_adapter_v2.py`.
36
+
37
+ Args:
38
+ prompt: The prompt/instruction (Alpaca style).
39
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
40
+ `finetune_adapter_v2.py`.
41
+ input: Optional input (Alpaca style).
42
+ pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
43
+ tokenizer_path: The tokenizer path to load.
44
+ quantize: Whether to quantize the model and using which method:
45
+ ``"llm.int8"``: LLM.int8() mode,
46
+ ``"gptq.int4"``: GPTQ 4-bit mode.
47
+ max_new_tokens: The number of generation steps to take.
48
+ top_k: The number of top most probable tokens to consider in the sampling process.
49
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
50
+ samples.
51
+ """
52
+ assert adapter_path.is_file()
53
+ assert pretrained_path.is_file()
54
+ assert tokenizer_path.is_file()
55
+
56
+ precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
57
+ fabric = L.Fabric(devices=1, precision=precision)
58
+
59
+ print("Loading model ...", file=sys.stderr)
60
+ t0 = time.time()
61
+ with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
62
+ name = llama_model_lookup(pretrained_checkpoint)
63
+
64
+ with fabric.init_module(empty_init=True), quantization(mode=quantize):
65
+ model = LLaMA.from_name(name)
66
+ add_adapter_v2_parameters_to_linear_layers(model)
67
+
68
+ # 1. Load the pretrained weights
69
+ model.load_state_dict(pretrained_checkpoint, strict=False)
70
+ # 2. Load the fine-tuned adapter weights
71
+ model.load_state_dict(adapter_checkpoint, strict=False)
72
+
73
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
74
+
75
+ model.eval()
76
+ model = fabric.setup(model)
77
+
78
+ tokenizer = Tokenizer(tokenizer_path)
79
+ sample = {"instruction": prompt, "input": input}
80
+ prompt = generate_prompt(sample)
81
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
82
+ prompt_length = encoded.size(0)
83
+
84
+ t0 = time.perf_counter()
85
+ y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
86
+ t = time.perf_counter() - t0
87
+
88
+ model.reset_cache()
89
+ output = tokenizer.decode(y)
90
+ output = output.split("### Response:")[1].strip()
91
+ print(output)
92
+
93
+ tokens_generated = y.size(0) - prompt_length
94
+ print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
95
+ if fabric.device.type == "cuda":
96
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
97
+
98
+
99
+ if __name__ == "__main__":
100
+ from jsonargparse import CLI
101
+
102
+ torch.set_float32_matmul_precision("high")
103
+ warnings.filterwarnings(
104
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
105
+ "ignore",
106
+ message="ComplexHalf support is experimental and many operators don't support it yet"
107
+ )
108
+ CLI(main)
generate/full.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ # support running without installing as a package
11
+ wd = Path(__file__).absolute().parent.parent
12
+ sys.path.append(str(wd))
13
+
14
+ from lit_llama import LLaMA, Tokenizer
15
+ from lit_llama.utils import quantization
16
+ from scripts.prepare_alpaca import generate_prompt
17
+ from generate import generate
18
+
19
+
20
+ def main(
21
+ prompt: str = "Hello, my name is",
22
+ *,
23
+ num_samples: int = 1,
24
+ max_new_tokens: int = 50,
25
+ top_k: int = 200,
26
+ temperature: float = 0.8,
27
+ checkpoint_path: Optional[Path] = None,
28
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
29
+ model_size: str = "7B",
30
+ quantize: Optional[str] = None,
31
+ ) -> None:
32
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
33
+
34
+ Args:
35
+ prompt: The prompt string to use for generating the samples.
36
+ num_samples: The number of text samples to generate.
37
+ max_new_tokens: The number of generation steps to take.
38
+ top_k: The number of top most probable tokens to consider in the sampling process.
39
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
40
+ samples.
41
+ checkpoint_path: The checkpoint path to load.
42
+ tokenizer_path: The tokenizer path to load.
43
+ model_size: The model size to load.
44
+ quantize: Whether to quantize the model and using which method:
45
+ ``"llm.int8"``: LLM.int8() mode,
46
+ ``"gptq.int4"``: GPTQ 4-bit mode.
47
+ """
48
+ if not checkpoint_path:
49
+ checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
50
+ assert checkpoint_path.is_file(), checkpoint_path
51
+ assert tokenizer_path.is_file(), tokenizer_path
52
+
53
+ precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
54
+ fabric = L.Fabric(devices=1, precision=precision)
55
+
56
+ print("Loading model ...", file=sys.stderr)
57
+ t0 = time.time()
58
+
59
+ with fabric.init_module(empty_init=True), quantization(mode=quantize):
60
+ model = LLaMA.from_name(model_size)
61
+
62
+ checkpoint = torch.load(checkpoint_path)
63
+ model.load_state_dict(checkpoint)
64
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
65
+
66
+ model.eval()
67
+ model = fabric.setup(model)
68
+
69
+ tokenizer = Tokenizer(tokenizer_path)
70
+ sample = {"instruction": prompt, "input": input}
71
+ prompt = generate_prompt(sample)
72
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
73
+ prompt_length = encoded.size(0)
74
+
75
+ L.seed_everything(1234)
76
+ for i in range(num_samples):
77
+ t0 = time.perf_counter()
78
+ y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
79
+ t = time.perf_counter() - t0
80
+
81
+ model.reset_cache()
82
+ print(tokenizer.decode(y))
83
+ tokens_generated = y.size(0) - prompt_length
84
+ print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
85
+ if fabric.device.type == "cuda":
86
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ from jsonargparse import CLI
91
+
92
+ torch.set_float32_matmul_precision("high")
93
+ warnings.filterwarnings(
94
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
95
+ "ignore",
96
+ message="ComplexHalf support is experimental and many operators don't support it yet"
97
+ )
98
+ warnings.filterwarnings(
99
+ # Triggered in bitsandbytes/autograd/_functions.py:298
100
+ "ignore",
101
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
102
+ )
103
+ CLI(main)
generate/lora.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import lightning as L
8
+ import torch
9
+
10
+ # support running without installing as a package
11
+ wd = Path(__file__).parent.parent.resolve()
12
+ sys.path.append(str(wd))
13
+
14
+ from generate import generate
15
+ from lit_llama import Tokenizer, LLaMA
16
+ from lit_llama.lora import lora
17
+ from lit_llama.utils import lazy_load, llama_model_lookup
18
+ from scripts.prepare_alpaca import generate_prompt
19
+
20
+ lora_r = 8
21
+ lora_alpha = 16
22
+ lora_dropout = 0.05
23
+
24
+
25
+ def main(
26
+ prompt: str = "What food do lamas eat?",
27
+ input: str = "",
28
+ lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"),
29
+ pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
30
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
31
+ quantize: Optional[str] = None,
32
+ max_new_tokens: int = 100,
33
+ top_k: int = 200,
34
+ temperature: float = 0.8,
35
+ ) -> None:
36
+ """Generates a response based on a given instruction and an optional input.
37
+ This script will only work with checkpoints from the instruction-tuned LoRA model.
38
+ See `finetune_lora.py`.
39
+
40
+ Args:
41
+ prompt: The prompt/instruction (Alpaca style).
42
+ lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
43
+ `finetune_lora.py`.
44
+ input: Optional input (Alpaca style).
45
+ pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
46
+ tokenizer_path: The tokenizer path to load.
47
+ quantize: Whether to quantize the model and using which method:
48
+ ``"llm.int8"``: LLM.int8() mode,
49
+ ``"gptq.int4"``: GPTQ 4-bit mode.
50
+ max_new_tokens: The number of generation steps to take.
51
+ top_k: The number of top most probable tokens to consider in the sampling process.
52
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
53
+ samples.
54
+ """
55
+ assert lora_path.is_file()
56
+ assert pretrained_path.is_file()
57
+ assert tokenizer_path.is_file()
58
+
59
+ if quantize is not None:
60
+ raise NotImplementedError("Quantization in LoRA is not supported yet")
61
+
62
+ precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
63
+ fabric = L.Fabric(devices=1, precision=precision)
64
+
65
+ print("Loading model ...", file=sys.stderr)
66
+ t0 = time.time()
67
+
68
+ with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
69
+ name = llama_model_lookup(pretrained_checkpoint)
70
+
71
+ with fabric.init_module(empty_init=True), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
72
+ model = LLaMA.from_name(name)
73
+
74
+ # 1. Load the pretrained weights
75
+ model.load_state_dict(pretrained_checkpoint, strict=False)
76
+ # 2. Load the fine-tuned lora weights
77
+ model.load_state_dict(lora_checkpoint, strict=False)
78
+
79
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
80
+
81
+ model.eval()
82
+ model = fabric.setup(model)
83
+
84
+ tokenizer = Tokenizer(tokenizer_path)
85
+ sample = {"instruction": prompt, "input": input}
86
+ prompt = generate_prompt(sample)
87
+ encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
88
+
89
+ t0 = time.perf_counter()
90
+ output = generate(
91
+ model,
92
+ idx=encoded,
93
+ max_new_tokens=max_new_tokens,
94
+ temperature=temperature,
95
+ top_k=top_k,
96
+ eos_id=tokenizer.eos_id
97
+ )
98
+ t = time.perf_counter() - t0
99
+
100
+ output = tokenizer.decode(output)
101
+ output = output.split("### Response:")[1].strip()
102
+ print(output)
103
+
104
+ print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
105
+ if fabric.device.type == "cuda":
106
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ from jsonargparse import CLI
111
+
112
+ torch.set_float32_matmul_precision("high")
113
+ warnings.filterwarnings(
114
+ # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
115
+ "ignore",
116
+ message="ComplexHalf support is experimental and many operators don't support it yet"
117
+ )
118
+ CLI(main)
howto/convert_lora_weights.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Merging LoRA weights into base model weights
2
+
3
+ Purpose: By merging our selected LoRA weights into the base model weights, we can benefit from all base model optimisation such as quantisation (available in this repo), pruning, caching, etc.
4
+
5
+
6
+ ## How to run?
7
+
8
+ After you have finish finetuning using LoRA, select your weight and run the converter script:
9
+
10
+ ```bash
11
+ python scripts/convert_lora_weights.py --lora_path out/lora/your-folder/your-weight-name.pth
12
+ ```
13
+
14
+ The converted base weight file will be saved into the same folder with the name `{your-weight-name}-lora-merged-weights.pth`. Now you can run `generate.py` with the merged weights and apply quantisation:
15
+
16
+ ```bash
17
+ python generate.py --checkpoint_path out/lora/your-folder/your-weight-name-lora-merged-weights.pth --quantize llm.int8
18
+ ```
19
+
howto/customize_paths.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Customize paths
2
+
3
+ The project is setup to use specific paths to read the original weights and save checkpoints etc.
4
+
5
+ For all scripts, you can run
6
+
7
+ ```shell
8
+ python script.py -h
9
+ ```
10
+
11
+ to get a list of available options. For instance, here's how you would modify the checkpoint dir:
12
+
13
+ ```shell
14
+ python scripts/convert_checkpoint.py --checkpoint_dir "data/checkpoints/foo"
15
+ ```
16
+
17
+ Note that this change will need to be passed along to subsequent steps, for example:
18
+
19
+ ```shell
20
+ python generate.py \
21
+ --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
22
+ --tokenizer_path "data/checkpoints/foo/tokenizer.model"
23
+ ```
24
+
25
+ and
26
+
27
+ ```shell
28
+ python quantize/gptq.py \
29
+ --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
30
+ --tokenizer_path "data/checkpoints/foo/tokenizer.model"
31
+ ```
32
+
33
+ To avoid this, you can use symbolic links to create shortcuts and avoid passing different paths.
howto/download_weights.md ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Downloading pretrained weights
2
+
3
+ Except for when you are training from scratch, you will need the pretrained weights from Meta.
4
+
5
+ ### Original Meta weights
6
+
7
+ Download the model weights following the instructions on the official [LLaMA repository](https://github.com/facebookresearch/llama).
8
+
9
+ Once downloaded, you should have a folder like this:
10
+
11
+ ```text
12
+ checkpoints/llama
13
+ ├── 7B
14
+ │ ├── ...
15
+ │ └── consolidated.00.pth
16
+ ├── 13B
17
+ │ ...
18
+ └── tokenizer.model
19
+ ```
20
+
21
+ Convert the weights to the Lit-LLaMA format:
22
+
23
+ ```bash
24
+ python scripts/convert_checkpoint.py --model_size 7B
25
+ ```
26
+
27
+ > **Note**
28
+ > All scripts support argument [customization](customize_paths.md)
29
+
30
+ ### OpenLLaMA
31
+
32
+ OpenLM Research has released **Apache 2.0 licensed** weights obtained by training LLaMA on the 1.2 trillion token open-source [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset.
33
+
34
+ Weights were released in preview on intermediate number of tokens (1T at the time of writing). In order to get them do:
35
+
36
+ ```bash
37
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
38
+ git clone https://huggingface.co/openlm-research/open_llama_7b checkpoints/open-llama/7B
39
+ ```
40
+
41
+ Or if you don't have `git-lfs` installed:
42
+
43
+ ```bash
44
+ python scripts/download.py --repo_id openlm-research/open_llama_7b --local_dir checkpoints/open-llama/7B
45
+ ```
46
+
47
+ Once downloaded, you should have a folder like this:
48
+
49
+ ```text
50
+ checkpoints/open-llama/
51
+ └── 7B
52
+ ├── ...
53
+ ├── pytorch_model-00001-of-00002.bin
54
+ ├── pytorch_model-00002-of-00002.bin
55
+ ├── pytorch_model.bin.index.json
56
+ └── tokenizer.model
57
+ ```
58
+
59
+ Convert the weights to the Lit-LLaMA format:
60
+
61
+ ```bash
62
+ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B
63
+ ```
64
+
65
+ > **Note**
66
+ > All scripts support argument [customization](customize_paths.md)
67
+
68
+ Once converted, you should have a folder like this:
69
+
70
+ ```text
71
+ checkpoints/lit-llama/
72
+ ├── 7B
73
+ │ └── lit-llama.pth
74
+ └── tokenizer.model
75
+ ```
76
+
77
+ You are all set. Now you can continue with inference or finetuning.
78
+
79
+ Try running [`generate.py` to test the imported weights](inference.md).
80
+
81
+
82
+ ### Alternative sources
83
+
84
+ You might find LLaMA weights hosted online in the HuggingFace hub. Beware that this infringes the original weight's license.
85
+ You could try downloading them by running the following command with a specific repo id:
86
+
87
+ ```bash
88
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
89
+ git clone REPO_ID checkpoints/hf-llama/7B
90
+ ```
91
+
92
+ Or if you don't have `git-lfs` installed:
93
+
94
+ ```bash
95
+ python scripts/download.py --repo_id REPO_ID --local_dir checkpoints/hf-llama/7B
96
+ ```
97
+
98
+ Once downloaded, you should have a folder like this:
99
+
100
+ ```text
101
+ checkpoints/hf-llama/
102
+ └── 7B
103
+ ├── ...
104
+ ├── pytorch_model-00001-of-00002.bin
105
+ ├── pytorch_model-00002-of-00002.bin
106
+ ├── pytorch_model.bin.index.json
107
+ └── tokenizer.model
108
+ ```
109
+
110
+ Convert the weights to the Lit-LLaMA format:
111
+
112
+ ```bash
113
+ python scripts/convert_hf_checkpoint.py --model_size 7B
114
+ ```
115
+
116
+ > **Note**
117
+ > All scripts support argument [customization](customize_paths.md)
118
+
119
+ Once converted, you should have a folder like this:
120
+
121
+ ```text
122
+ checkpoints/lit-llama/
123
+ ├── 7B
124
+ │ └── lit-llama.pth
125
+ └── tokenizer.model
126
+ ```
127
+
128
+ You are all set. Now you can continue with inference or finetuning.
129
+
130
+ Try running [`generate.py` to test the imported weights](inference.md).
howto/finetune_adapter.md ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Finetuning with Adapter
2
+
3
+ [LLaMA-Adapter](https://arxiv.org/abs/2303.16199) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only 1.2M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
4
+
5
+ We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
6
+
7
+ If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
8
+
9
+ ## LLaMA-Adapter v2
10
+
11
+ The LLaMA-Adapter authors developed a newer adapter method called LLaMA-Adapter v2, which is related to this LLaMA-Adapter method but includes more trainable parameters. LLaMA-Adapter v2 is also available via Lit-LLaMA; you can read more about it in [the related how-to doc here](./finetune_adapter_v2.md).
12
+
13
+ ## Preparation
14
+
15
+ The steps here only need to be done once:
16
+
17
+ 1. Follow the instructions in the [README](README.md) to install the dependencies.
18
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
19
+ 3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
20
+ 4. Download the data and generate the Alpaca instruction tuning dataset:
21
+
22
+ ```bash
23
+ python scripts/prepare_alpaca.py
24
+ ```
25
+
26
+ or [prepare your own dataset](#tune-on-your-dataset).
27
+
28
+ See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
29
+
30
+ ## Running the finetuning
31
+
32
+ ```bash
33
+ python finetune/adapter.py
34
+ ```
35
+
36
+ The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
37
+ You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
38
+ Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
39
+
40
+ For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
41
+
42
+ ```python
43
+ devices = 8
44
+ micro_batch_size = 8
45
+ ```
46
+
47
+ This script will save checkpoints periodically to the folder `out/`.
48
+
49
+ > **Note**
50
+ > All scripts support argument [customization](customize_paths.md)
51
+
52
+ ## Test the model
53
+
54
+ You can test the finetuned model with your own instructions by running:
55
+
56
+ ```bash
57
+ python generate/adapter.py \
58
+ --prompt "Recommend a movie to watch on the weekend." \
59
+ --quantize llm.int8
60
+ ```
61
+ Output:
62
+ ```
63
+ A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
64
+ ```
65
+ If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
66
+
67
+ ## Tune on your dataset
68
+
69
+ With only a few modifications, you can prepare and train on your own instruction dataset.
70
+
71
+ 1. Create a json file in which each row holds one instruction-response pair.
72
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
73
+ the empty string if the instruction doesn't require a context. Below is an example json file:
74
+
75
+ ```
76
+ [
77
+ {
78
+ "instruction": "Arrange the given numbers in ascending order.",
79
+ "input": "2, 4, 0, 8, 3",
80
+ "output": "0, 2, 3, 4, 8"
81
+ },
82
+ ...
83
+ ]
84
+ ```
85
+
86
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
87
+
88
+ ```bash
89
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
90
+ ```
91
+
92
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file.
93
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
94
+
95
+ ```bash
96
+ python scripts/prepare_mydata.py --destination_path data/mydata/
97
+ ```
98
+
99
+ 5. Run `finetune/adapter.py` by passing in the location of your data (and optionally other parameters):
100
+
101
+ ```bash
102
+ python finetune/adapter.py --data_dir data/mydata/ --out_dir out/myexperiment
103
+ ```
104
+
105
+
106
+ ## Troubleshooting
107
+
108
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
109
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/finetune_adapter_v2.md ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Finetuning with Adapter v2
2
+
3
+ [LLaMA-Adapter v2](https://arxiv.org/abs/2304.15010) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only ~4 M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
4
+
5
+ We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
6
+
7
+ If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
8
+
9
+ ## LLaMA-Adapter v1 versus LLaMA-Adapter v2
10
+
11
+ LLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Where the 7B LLaMA model has 1.2M trainable parameters with LLaMA v1, LLaMA-Adapter v2 adds 2.8 M trainable parameters for the bias and scale parameters and ~300k trainable parameters for the normalization layers. So, adapter v2 has ~4.3 M trainable parameters in total.
12
+
13
+ If you are interested in using the more lightweight LLaMA-Adapter v1 approach, see [the related LLaMA Adapter how-to doc here](./finetune_adapter.md).
14
+
15
+ While LLaMA-Adapter v2 increases the number of trainable parameters from 1.2 M (from LLaMA-Apdapter v1) to 4.3 M, the inference cost is not significantly impacted. This is because the additional bias and scale parameters are cheap to compute in the forward pass, and the RMSNorm parameters are already included in the base model. In LLaMA-Adapter v1, the RMSNorm parameters are not trainable.
16
+
17
+
18
+ ## Preparation
19
+
20
+ The steps here only need to be done once:
21
+
22
+ 1. Follow the instructions in the [README](README.md) to install the dependencies.
23
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
24
+ 3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
25
+ 4. Download the data and generate the Alpaca instruction tuning dataset:
26
+
27
+ ```bash
28
+ python scripts/prepare_alpaca.py
29
+ ```
30
+
31
+ or [prepare your own dataset](#tune-on-your-dataset).
32
+
33
+ See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
34
+
35
+ ## Running the finetuning
36
+
37
+ ```bash
38
+ python finetune/adapter_v2.py
39
+ ```
40
+
41
+ The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
42
+ You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
43
+ Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
44
+
45
+ For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
46
+
47
+ ```python
48
+ devices = 8
49
+ micro_batch_size = 8
50
+ ```
51
+
52
+ This script will save checkpoints periodically to the folder `out/`.
53
+
54
+ > **Note**
55
+ > All scripts support argument [customization](customize_paths.md)
56
+
57
+ ## Test the model
58
+
59
+ You can test the finetuned model with your own instructions by running:
60
+
61
+ ```bash
62
+ python generate/adapter_v2.py \
63
+ --prompt "Recommend a movie to watch on the weekend." \
64
+ --quantize llm.int8
65
+ ```
66
+ Output:
67
+ ```
68
+ A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
69
+ ```
70
+ If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
71
+
72
+ ## Tune on your dataset
73
+
74
+ With only a few modifications, you can prepare and train on your own instruction dataset.
75
+
76
+ 1. Create a json file in which each row holds one instruction-response pair.
77
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
78
+ the empty string if the instruction doesn't require a context. Below is an example json file:
79
+
80
+ ```
81
+ [
82
+ {
83
+ "instruction": "Arrange the given numbers in ascending order.",
84
+ "input": "2, 4, 0, 8, 3",
85
+ "output": "0, 2, 3, 4, 8"
86
+ },
87
+ ...
88
+ ]
89
+ ```
90
+
91
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
92
+
93
+ ```bash
94
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
95
+ ```
96
+
97
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file.
98
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
99
+
100
+ ```bash
101
+ python scripts/prepare_mydata.py --destination_path data/mydata/
102
+ ```
103
+
104
+ 5. Run `finetune/adapter_v2.py` by passing in the location of your data (and optionally other parameters):
105
+
106
+ ```bash
107
+ python finetune/adapter_v2.py --data_dir data/mydata/ --out_dir out/myexperiment
108
+ ```
109
+
110
+
111
+ ## Troubleshooting
112
+
113
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
114
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/finetune_full.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Full Finetuning
2
+
3
+ Full finetuning updates all layers in the pretrained LLaMA model. This *regular* finetuning procedure is typically considered as the baseline for parameter-efficient alternatives such as Low-Rank Adaptation (LoRA) or LLaMA-Adapter.
4
+
5
+ The current [finetune/full.py](../finetune/full.py) we provide uses 4 A100 GPUs with a fully-sharded data parallel strategy to finetune Lit-LLaMA 7B on [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. The A100 GPUs have 40 GB each, but it may require less memory to finetune this model.
6
+
7
+
8
+
9
+ ## Preparation
10
+
11
+ The steps here only need to be done once:
12
+
13
+ 1. Follow the instructions in the [README](README.md) to install the dependencies.
14
+
15
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
16
+
17
+ 4. Download the data and generate the Alpaca instruction tuning dataset:
18
+
19
+ ```bash
20
+ python scripts/prepare_alpaca.py
21
+ ```
22
+
23
+ or [prepare your own dataset](#tune-on-your-own-dataset).
24
+
25
+ See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
26
+
27
+ ## Running the finetuning
28
+
29
+ ```bash
30
+ python finetune/full.py
31
+ ```
32
+
33
+
34
+ You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available or increase the `batch_size`.
35
+ Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
36
+
37
+ For example, the following settings will let you finetune the model in 32 hours using a fully-sharded data parallel strategy:
38
+ ```python
39
+ devices = 4
40
+ batch_size = 128 // devices
41
+ micro_batch_size = 4
42
+ ```
43
+
44
+ This script will save checkpoints periodically to the folder `out/`.
45
+
46
+ > **Note**
47
+ > All scripts support argument [customization](customize_paths.md)
48
+
49
+ ## Test the model
50
+
51
+ You can test the finetuned model with your own instructions by running:
52
+
53
+ ```bash
54
+ python generate/full.py \
55
+ --prompt "Recommend a movie to watch on the weekend." \
56
+ --quantize llm.int8
57
+ ```
58
+ Output:
59
+ ```
60
+ A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
61
+ ```
62
+ If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
63
+
64
+ ## Tune on your dataset
65
+
66
+ With only a few modifications, you can prepare and train on your own instruction dataset.
67
+
68
+ 1. Create a json file in which each row holds one instruction-response pair.
69
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
70
+ the empty string if the instruction doesn't require a context. Below is an example json file:
71
+
72
+ ```
73
+ [
74
+ {
75
+ "instruction": "Arrange the given numbers in ascending order.",
76
+ "input": "2, 4, 0, 8, 3",
77
+ "output": "0, 2, 3, 4, 8"
78
+ },
79
+ ...
80
+ ]
81
+ ```
82
+
83
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
84
+
85
+ ```bash
86
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
87
+ ```
88
+
89
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file.
90
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
91
+
92
+ ```bash
93
+ python scripts/prepare_mydata.py --destination_path data/mydata/
94
+ ```
95
+
96
+ 5. Run `finetune/full.py` by passing in the location of your data (and optionally other parameters):
97
+
98
+ ```bash
99
+ python finetune/full.py --data_dir data/mydata/ --out_dir out/myexperiment
100
+ ```
101
+
102
+
103
+ ## Troubleshooting
104
+
105
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
106
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/finetune_lora.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Finetuning with LoRA
2
+
3
+ [Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
4
+ We demonstrate this method by instruction-finetuning LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**.
5
+
6
+ ## Preparation
7
+
8
+ The steps here only need to be done once:
9
+
10
+ 1. Follow the instructions in the [README](../README.md) to install the dependencies.
11
+ 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
12
+ 3. Download the data and generate the instruction tuning dataset:
13
+
14
+ ```bash
15
+ python scripts/prepare_alpaca.py
16
+ ```
17
+
18
+ See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
19
+
20
+ ## Running the finetuning
21
+
22
+ ```bash
23
+ python finetune/lora.py
24
+ ```
25
+
26
+ The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
27
+
28
+ This script will save checkpoints periodically to the folder `out/`.
29
+
30
+ > **Note**
31
+ > All scripts support argument [customization](customize_paths.md)
32
+
33
+
34
+ ## Test the model
35
+
36
+ You can test the finetuned model with your own instructions by running:
37
+
38
+ ```bash
39
+ python generate/lora.py --prompt "Recommend a movie to watch on the weekend."
40
+ ```
41
+ Output:
42
+ ```
43
+ I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
44
+ ```
45
+
46
+ If your GPU supports `bfloat16`, you can additionally pass `--dtype bfloat16` to bring the memory consumption down to ~14 GB.
47
+
48
+ ## Tune on your dataset
49
+
50
+ With only a few modifications, you can prepare and train on your own instruction dataset.
51
+
52
+ 1. Create a json file in which each row holds one instruction-response pair.
53
+ A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
54
+ the empty string if the instruction doesn't require a context. Below is an example json file:
55
+
56
+ ```
57
+ [
58
+ {
59
+ "instruction": "Arrange the given numbers in ascending order.",
60
+ "input": "2, 4, 0, 8, 3",
61
+ "output": "0, 2, 3, 4, 8"
62
+ },
63
+ ...
64
+ ]
65
+ ```
66
+
67
+ 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
68
+
69
+ ```bash
70
+ cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
71
+ ```
72
+
73
+ 3. Modify `scripts/prepare_mydata.py` to read the json data file.
74
+ 4. Run the script to generate the preprocessed, tokenized train-val split:
75
+
76
+ ```bash
77
+ python scripts/prepare_mydata.py --destination_path data/mydata/
78
+ ```
79
+
80
+ 5. Run `finetune/lora.py` by passing in the location of your data (and optionally other parameters):
81
+
82
+ ```bash
83
+ python finetune/lora.py --data_dir data/mydata/ --out_dir out/myexperiment
84
+ ```
85
+
86
+
87
+ ## Troubleshooting
88
+
89
+ If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
90
+ `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
howto/inference.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference
2
+
3
+ We demonstrate how to run inference (next token prediction) with the LLaMA base model in the [`generate.py`](generate.py) script:
4
+
5
+ ```bash
6
+ python generate.py --prompt "Hello, my name is"
7
+ ```
8
+ Output:
9
+ ```
10
+ Hello my name is TJ. I have a passion for the outdoors, love hiking and exploring. I also enjoy traveling and learning new things. I especially enjoy long walks, good conversation and a friendly smile.
11
+ ```
12
+
13
+ The script assumes you have downloaded and converted the weights and saved them in the `./checkpoints` folder as described [here](download_weights.md).
14
+
15
+ > **Note**
16
+ > All scripts support argument [customization](customize_paths.md)
17
+
18
+ With the default settings, this will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
19
+
20
+ ## Run Lit-LLaMA on consumer devices
21
+
22
+ On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
23
+ For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
24
+
25
+ ```bash
26
+ python generate.py --quantize llm.int8 --prompt "Hello, my name is"
27
+ ```
28
+ This will consume about ~10 GB of GPU memory or ~8 GB if also using `bfloat16`.
29
+ See `python generate.py --help` for more options.
30
+
31
+ You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:
32
+
33
+ ```bash
34
+ python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
35
+ ```
36
+
37
+ GPTQ-style int4 quantization brings GPU usage down to about ~5GB. As only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with the quantization enabled.
38
+
39
+ With the generated quantized checkpoint generation quantization then works as usual with `--quantize gptq.int4` and the newly generated checkpoint file:
40
+
41
+ ```bash
42
+ python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth
43
+ ```
howto/tpus.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TPU support
2
+
3
+ Lit-LLaMA used `lightning.Fabric` under the hood, which itself supports TPUs (via [PyTorch XLA](https://github.com/pytorch/xla)).
4
+
5
+ The following commands will allow you to set up a `Google Cloud` instance with a [TPU v4](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) VM:
6
+
7
+ ```shell
8
+ gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
9
+ gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
10
+ ```
11
+
12
+ Now that you are in the machine, let's clone the repository and install the dependencies
13
+
14
+ ```shell
15
+ git clone https://github.com/Lightning-AI/lit-llama
16
+ cd lit-llama
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables
21
+
22
+ ```shell
23
+ export PJRT_DEVICE=TPU
24
+ export ALLOW_MULTIPLE_LIBTPU_LOAD=1
25
+ ```
26
+
27
+ > **Note**
28
+ > You can find an extensive guide on how to get set-up and all the available options [here](https://cloud.google.com/tpu/docs/v4-users-guide).
29
+
30
+ Since you created a new machine, you'll probably need to download the weights. You could scp them into the machine with `gcloud compute tpus tpu-vm scp` or you can follow the steps described in our [downloading guide](download_weights.md).
31
+
32
+ ## Inference
33
+
34
+ Generation works out-of-the-box with TPUs:
35
+
36
+ ```shell
37
+ python3 generate.py --prompt "Hello, my name is" --num_samples 3
38
+ ```
39
+
40
+ This command will take take ~20s for the first generation time as XLA needs to compile the graph.
41
+ You'll notice that afterwards, generation times drop to ~5s.
42
+
43
+ ## Finetuning
44
+
45
+ Coming soon.
46
+
47
+ > **Warning**
48
+ > When you are done, remember to delete your instance
49
+ > ```shell
50
+ > gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b
51
+ > ```
howto/train_redpajama.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pre-train LLaMA on RedPajama
2
+
3
+ This howto will walk you through setting up the RedPajama dataset and launching the pre-training script.
4
+
5
+ ## What's RedPajama
6
+
7
+ [RedPajama](https://github.com/togethercomputer/RedPajama-Data) is an open-source reproduction of the original LLaMA training dataset.
8
+
9
+ It contains a total of 1.2 trillion tokens, divided into
10
+
11
+ ```text
12
+ Commoncrawl 878B
13
+ C4 175B
14
+ GitHub 59B
15
+ Books 26B
16
+ ArXiv 28B
17
+ Wikipedia 24B
18
+ StackExchange 20B
19
+ ```
20
+
21
+ The [RedPajama repo](https://github.com/togethercomputer/RedPajama-Data) contains the source code for collecting and preparing
22
+ the dataset, and it is Apache 2.0 licensed.
23
+
24
+ The data itself is licensed according to the original licenses with which its invidivdual parts were released.
25
+ The GitHub datasets are limited to MIT, BSD, or Apache 2.0 repositories.
26
+
27
+ Along with the full [RedPajama-1T dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T),
28
+ the [RedPajama-1T-Sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) 1B sample dataset
29
+ is also available for development.
30
+
31
+ You can download the data using git lfs:
32
+
33
+ ```bash
34
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
35
+ git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T data/RedPajama-Data-1T
36
+ ```
37
+
38
+ ```bash
39
+ # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
40
+ git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample data/RedPajama-Data-1T-Sample
41
+ ```
42
+
43
+ ## Prepare RedPajama for training
44
+
45
+ The dataset consists of 2084 `jsonl` files (the sample dataset contains 11). In order to start pre-training lit-llama
46
+ on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `PackedDataset`
47
+ streaming dataset that comes with lit-llama.
48
+
49
+ Do to so, run
50
+
51
+ ```bash
52
+ python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama
53
+ ```
54
+
55
+ or
56
+
57
+ ```bash
58
+ python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T-Sample --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama-sample --sample True
59
+ ```
60
+
61
+ for the sample dataset.
62
+
63
+ In the above we are assuming that you will be using the same tokenizer as used in LLaMA, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here.
64
+
65
+ The script will take a while to run, so time for :tea:
66
+
67
+ ## Pre-training
68
+
69
+ Running the pre-training script requires at least 4 GPUs with 40GB+ each (A100).
70
+
71
+ ```bash
72
+ python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama
73
+ ```
74
+
75
+ For running on the sample dataset:
76
+
77
+ ```bash
78
+ python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama-sample
79
+ ```
80
+
81
+ The script will save checkpoints periodically to the folder `out/`.
82
+
83
+ The `train_redpajama.py` script will pre-train the LLaMA 7B model with FSDP in
84
+ `bfloat16` precision and gradient accumulation.
85
+
86
+ You can easily change the size of the model by passing a different string to
87
+
88
+ ```python
89
+ config = LLaMAConfig.from_name("7B")
90
+ ```
91
+
92
+ in the `main` function.
93
+
94
+ Keep in mind that the original LLaMA training for the 7B model required 83k A100 80GB
95
+ hours, so you'll need access to a cluster.
96
+
97
+ Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
98
+ to launch the script across machines:
99
+
100
+ - [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
101
+ - [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
102
+ - [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
103
+
104
+ The script contains several configurations and hyperparameters you can tweak:
105
+
106
+ ```python
107
+ out_dir = "out/training"
108
+ save_interval = 1000
109
+ eval_interval = 1000
110
+ eval_iters = 100
111
+ log_interval = 1
112
+
113
+ # Hyperparameters
114
+ learning_rate = 6e-4
115
+ batch_size = 125
116
+ micro_batch_size = 5
117
+ max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices
118
+ weight_decay = 1e-1
119
+ beta1 = 0.9
120
+ beta2 = 0.95
121
+ grad_clip = 1.0
122
+ decay_lr = True
123
+ warmup_iters = 2000
124
+ lr_decay_iters = max_iters
125
+ min_lr = 6e-5
126
+ ```
127
+
128
+ In particular, `micro_batch_size` should be adjusted so the process will use the available
129
+ GPU memory.
130
+
131
+ Last, logging is kept minimal in the script. In order to use a particular logger
132
+ please refer to <https://lightning.ai/docs/fabric/stable/api/loggers.html> or
133
+ call a logging client library like `wandb` directly.
howto/unstructured_dataset.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Finetuning on an unstructured dataset
2
+
3
+ While most scripts were made to finetune on instruction datasets, it is possible to finetune on any dataset. This is useful for experimentation while not being as expensive as training a full model.
4
+
5
+ This guide is only to prepare the finetuning, as either LoRA or Adapter-v1 methods support this dataset type!
6
+
7
+ ## Preparation
8
+
9
+ 1. Gather your text into an input file named `input.txt`
10
+ 2. Divide the data into training and validation sets using the following script:
11
+
12
+ ```bash
13
+ python scripts/prepare_any_text.py
14
+ ```
15
+
16
+ 3. Modify relevant scripts for your finetuning method under `finetune/` and `evaluate/`, setting the `instruction_tuning` variable to `False`
17
+
18
+ And then you're set! Proceed to run the [LoRA guide](./finetune_lora.md) or [Adapter v1 guide](./finetune_adapter.md).
lit_llama/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
2
+ from lit_llama.tokenizer import Tokenizer
lit_llama/adapter.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
4
+ https://arxiv.org/abs/2303.16199
5
+
6
+ | Prefix cross-attention
7
+ |
8
+ ┌─────────────────┐ | ┌──────────────────┐
9
+ ┆ x ┆ | ┆ prefix ┆
10
+ └─────────────────┘ | └──────────────────┘
11
+ | | |
12
+ ▼ | ▼
13
+ ┌──────────────────┐ | ┌─────────────────────┐
14
+ ┆ self-attention ┆ --------------------------------------------------------------┐ ┆ linear projection ┆
15
+ └──────────────────┘ | ┆ └─────────────────────┘
16
+ | | ┆ | \
17
+ ▼ | ▼ ▼ ▼
18
+ ╭───╮ ┌────────────────┐ ╭───╮ ┌──────────────────────────┐ | ┌─────────┐ ┌──────────────┐ ┌────────────────┐
19
+ ┆ + ┆ ◀── ┆ gating factor ┆-┆ x ┆-┆ prefix cross-attention ┆ | ┆ query ┆ ┆ prefix key ┆ ┆ prefix value ┆
20
+ ╰───╯ └────────────────┘ ╰───╯ └──────────────────────────┘ | └─────────┘ └──────────────┘ └────────────────┘
21
+ | | \ | /
22
+ ▼ | ▼ ▼ ▼
23
+ | ┌────────────────────────────────┐
24
+ | ┆ scaled dot-product attention ┆
25
+ | └────────────────────────────────┘
26
+
27
+
28
+ In order to inject learnable information from the prefix to pretrained weights we need to sum outputs from
29
+ self-attention and prefix cross-attention (times gating factor). For prefix cross-attention we need `query` (from
30
+ self-attention as a result of linear projection), `prefix key` and `prefix value` (from cross-attention as a result of
31
+ linear projection).
32
+ The output of prefix cross-attention is multiplied by gating factor, which is a learnable parameter that is needed to
33
+ avoid potential disruption of pretrained weights caused by incorporating randomly initialized tensors. This factor is
34
+ initialized with zeros to avoid noise from the adaption prompts at the early training stage.
35
+ More about it: https://lightning.ai/pages/community/article/understanding-llama-adapters/
36
+
37
+ Notes about implementation: as per paper adapter's prefix is concatenated with the input, while here outputs of
38
+ self-attention and prefix cross-attention are summed. Both variants are mathematically equivalent:
39
+ https://github.com/ZrrSkywalker/LLaMA-Adapter/issues/47
40
+ """
41
+ # mypy: ignore-errors
42
+ from dataclasses import dataclass
43
+ from typing import Optional, Tuple, List, Union
44
+
45
+ import torch
46
+ import torch.nn as nn
47
+ from torch.nn import functional as F
48
+
49
+ import lit_llama.model as llama
50
+ from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP, KVCache, RoPECache
51
+
52
+
53
+ @dataclass
54
+ class LLaMAConfig(llama.LLaMAConfig):
55
+ adapter_prompt_length: int = 10
56
+ adapter_start_layer: int = 2
57
+
58
+
59
+ class CausalSelfAttention(nn.Module):
60
+ """A modification of `lit_llama.model.CausalSelfAttention` that adds the attention
61
+ over the adaption prompt."""
62
+
63
+ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
64
+ super().__init__()
65
+ assert config.n_embd % config.n_head == 0
66
+
67
+ # key, query, value projections for all heads, but in a batch
68
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
69
+ # output projection
70
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
71
+
72
+ if block_idx >= config.adapter_start_layer:
73
+ # adapter embedding layer
74
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
75
+ # a learnable gating factor (to avoid potential disruption of pretrained weights) initialized with zeros (to
76
+ # avoid noise from adaption prompts at the early training stage)
77
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1, config.n_head, 1, 1))
78
+
79
+ self.n_head = config.n_head
80
+ self.n_embd = config.n_embd
81
+ self.block_size = config.block_size
82
+ self.block_idx = block_idx
83
+ self.adapter_prompt_length = config.adapter_prompt_length
84
+ self.adapter_start_layer = config.adapter_start_layer
85
+
86
+ def forward(
87
+ self,
88
+ x: torch.Tensor,
89
+ rope: RoPECache,
90
+ mask: torch.Tensor,
91
+ max_seq_length: int,
92
+ input_pos: Optional[torch.Tensor] = None,
93
+ kv_cache: Optional[KVCache] = None,
94
+ adapter_kv_cache: Optional[KVCache] = None,
95
+ ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
96
+ # notation:
97
+ # - B | batch
98
+ # - T | time-step (sequence length)
99
+ # - C | embeddings size (n_embd) = head size * num heads
100
+ # - hs | head size
101
+ # - nh | number of heads
102
+
103
+ B, T, C = x.size()
104
+
105
+ # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
106
+ # weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
107
+ # weights for q, k, v) and then split the result along `embedding size` dimension
108
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # (B, T, 3 * C) --> 3 * (B, T, C)
109
+
110
+ # in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
111
+ # embedding size (C) dimension into num_heads (nh) and head_size (hs)
112
+ head_size = C // self.n_head
113
+ k = k.view(B, T, self.n_head, head_size)
114
+ q = q.view(B, T, self.n_head, head_size)
115
+ v = v.view(B, T, self.n_head, head_size)
116
+
117
+ # "Unlike standard positional embeddings rotary embeddings must be applied at every layer"
118
+ q = apply_rope(q, rope) # (B, T, nh, hs)
119
+ k = apply_rope(k, rope) # (B, T, nh, hs)
120
+
121
+ # now `key`, 'query` and `value` tensors are correctly represented: for each element in a batch (B)
122
+ # there is a number of heads (nh) and for each head there is a sequence of elements (T), each of them is
123
+ # represented by a vector of size `hs`
124
+ k = k.transpose(1, 2) # (B, nh, T, hs)
125
+ q = q.transpose(1, 2) # (B, nh, T, hs)
126
+ v = v.transpose(1, 2) # (B, nh, T, hs)
127
+
128
+ if kv_cache is not None:
129
+ cache_k, cache_v = kv_cache # 2 * (B, nh, max_seq_length, hs)
130
+ # check if reached token limit
131
+ if input_pos[-1] >= max_seq_length:
132
+ # if we reached token limit and thus there is no space to put newly calculated `key` and `value`
133
+ # right next to cached ones, we need to rotate cache tensor along `max_seq_length` dimension by one
134
+ # element to the left: this will free up space for new `key` and `value`
135
+ input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
136
+ # shift 1 position to the left
137
+ cache_k = torch.roll(cache_k, -1, dims=2)
138
+ cache_v = torch.roll(cache_v, -1, dims=2)
139
+ k = cache_k.index_copy(2, input_pos, k) # (B, nh, max_seq_length, hs)
140
+ v = cache_v.index_copy(2, input_pos, v) # (B, nh, max_seq_length, hs)
141
+ kv_cache = k, v
142
+
143
+ # efficient attention using Flash Attention CUDA kernels
144
+ # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
145
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) # (B, nh, T, hs)
146
+
147
+ # "Adapters are applied to the topmost layers to better tune the language
148
+ # representations with higher-level semantics".
149
+ if self.block_idx >= self.adapter_start_layer:
150
+ if adapter_kv_cache is not None:
151
+ ak, av = adapter_kv_cache # 2 * (B, nh, aT, hs)
152
+ else:
153
+ prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
154
+ aT = prefix.size(1)
155
+ _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2) # (1, aT, 3 * C) --> 3 * (1, aT, C)
156
+ ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) # (B, nh, aT, hs)
157
+ av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) # (B, nh, aT, hs)
158
+ adapter_kv_cache = (ak, av)
159
+
160
+ # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
161
+ # obtained from self-attention step. This is mathematically equivalent to concatenation of prefix and input as per paper.
162
+ amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device) # (T, aT)
163
+ # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
164
+ ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False) # (B, nh, T, hs)
165
+ y = y + self.gating_factor * ay
166
+
167
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
168
+
169
+ # output projection
170
+ y = self.c_proj(y) # (B, T, C)
171
+
172
+ return y, kv_cache, adapter_kv_cache
173
+
174
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
175
+ """For backward compatibility with old checkpoints that have a single gating value for all heads."""
176
+ name = prefix + "gating_factor"
177
+ if name in state_dict:
178
+ tensor = state_dict[name]
179
+ # in case we are loading with `utils.lazy_load()`
180
+ tensor = tensor._load_tensor() if hasattr(tensor, "_load_tensor") else tensor
181
+
182
+ if len(tensor.shape) < 4:
183
+ # For old checkpoints with unified gating value
184
+ state_dict[name] = tensor.reshape(1, 1, 1, 1).repeat(1, self.n_head, 1, 1)
185
+ else:
186
+ state_dict[name] = tensor
187
+
188
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
189
+
190
+
191
+ class Block(nn.Module):
192
+ """The implementation is identical to `lit_llama.model.Block` with the exception that
193
+ we replace the attention layer where adaption is implemented."""
194
+
195
+ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
196
+ super().__init__()
197
+ self.rms_1 = RMSNorm(config.n_embd)
198
+ self.attn = CausalSelfAttention(config, block_idx)
199
+ self.rms_2 = RMSNorm(config.n_embd)
200
+ self.mlp = MLP(config)
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ rope: RoPECache,
206
+ mask: torch.Tensor,
207
+ max_seq_length: int,
208
+ input_pos: Optional[torch.Tensor] = None,
209
+ kv_cache: Optional[KVCache] = None,
210
+ adapter_kv_cache: Optional[KVCache] = None,
211
+ ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
212
+ h, new_kv_cache, new_adapter_kv_cache = self.attn(
213
+ self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache, adapter_kv_cache
214
+ )
215
+ x = x + h
216
+ x = x + self.mlp(self.rms_2(x))
217
+ return x, new_kv_cache, new_adapter_kv_cache
218
+
219
+
220
+ class LLaMA(llama.LLaMA):
221
+ """The implementation is identical to `lit_llama.model.LLaMA` with the exception that
222
+ the `Block` saves the layer index and passes it down to the attention layer."""
223
+
224
+ def __init__(self, config: LLaMAConfig) -> None:
225
+ nn.Module.__init__(self)
226
+ assert config.vocab_size is not None
227
+ assert config.block_size is not None
228
+ self.config = config
229
+
230
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
231
+ self.transformer = nn.ModuleDict(
232
+ dict(
233
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
234
+ h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
235
+ ln_f=RMSNorm(config.n_embd),
236
+ )
237
+ )
238
+
239
+ self.rope_cache: Optional[RoPECache] = None
240
+ self.mask_cache: Optional[torch.Tensor] = None
241
+ self.kv_caches: List[KVCache] = []
242
+ self.adapter_kv_caches: List[KVCache] = []
243
+
244
+ @classmethod
245
+ def from_name(cls, name: str):
246
+ return cls(LLaMAConfig.from_name(name))
247
+
248
+ def reset_cache(self) -> None:
249
+ super().reset_cache()
250
+ self.adapter_kv_caches.clear()
251
+
252
+ def forward(
253
+ self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
254
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
255
+ B, T = idx.size()
256
+
257
+ block_size = self.config.block_size
258
+ if max_seq_length is None:
259
+ max_seq_length = block_size
260
+ assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
261
+ assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
262
+ assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
263
+
264
+ if self.rope_cache is None:
265
+ self.rope_cache = self.build_rope_cache(idx) # (block_size, head_size / 2, 2)
266
+ if self.mask_cache is None:
267
+ self.mask_cache = self.build_mask_cache(idx) # (1, 1, block_size, block_size)
268
+
269
+ if input_pos is not None:
270
+ rope = self.rope_cache.index_select(0, input_pos)
271
+ mask = self.mask_cache.index_select(2, input_pos)
272
+ mask = mask[:, :, :, :max_seq_length]
273
+ else:
274
+ rope = self.rope_cache[:T]
275
+ mask = self.mask_cache[:, :, :T, :T]
276
+
277
+ # forward the model itself
278
+ x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
279
+
280
+ if input_pos is None: # proxy for use_cache=False
281
+ for block in self.transformer.h:
282
+ x, *_ = block(x, rope, mask, max_seq_length)
283
+ else:
284
+ if not self.kv_caches:
285
+ head_size = self.config.n_embd // self.config.n_head
286
+ cache_shape = (B, self.config.n_head, max_seq_length, head_size)
287
+ self.kv_caches = [
288
+ (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
289
+ for _ in range(self.config.n_layer)
290
+ ]
291
+ if not self.adapter_kv_caches:
292
+ self.adapter_kv_caches = [None for _ in range(self.config.n_layer)]
293
+ for i, block in enumerate(self.transformer.h):
294
+ x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
295
+ x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
296
+ )
297
+
298
+ x = self.transformer.ln_f(x) # (B, T, n_embd)
299
+
300
+ logits = self.lm_head(x) # (B, T, vocab_size)
301
+
302
+ return logits
303
+
304
+
305
+ def mark_only_adapter_as_trainable(model: LLaMA) -> None:
306
+ """Sets `requires_grad=False` for all non-adapter weights."""
307
+ for name, param in model.named_parameters():
308
+ param.requires_grad = "adapter_wte" in name or "gating_factor" in name
309
+
310
+
311
+ def adapter_state_from_state_dict(state_dict: dict) -> dict:
312
+ """Returns the model state dict with only the adapter weights for saving."""
313
+ return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name}
lit_llama/adapter_v2.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from lit_llama.adapter import LLaMA
7
+
8
+
9
+ def get_adapter_substrings():
10
+ substrings = ["adapter_wte", "gating_factor"] # regular adapter v1 parameters
11
+ substrings.extend(["adapter_scale", "adapter_bias"]) # adapter v2: new bias and scale used in Linear
12
+ substrings.extend(["rms_1", "rms_2", "ln_f"]) # adapter v2: RMSNorm parameters are now trainable
13
+ return substrings
14
+
15
+
16
+ def mark_only_adapter_v2_as_trainable(model: LLaMA) -> None:
17
+ """Sets `requires_grad=False` for all non-adapter weights."""
18
+ for name, param in model.named_parameters():
19
+ param.requires_grad = any(s in name for s in get_adapter_substrings())
20
+
21
+
22
+ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
23
+ """Returns the model state dict with only the adapter weights for saving."""
24
+ return {name: param for name, param in state_dict.items()
25
+ if any(s in name for s in get_adapter_substrings())}
26
+
27
+
28
+ def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
29
+ return self.adapter_scale * (
30
+ F.linear(input, self.weight, self.bias) + self.adapter_bias
31
+ )
32
+
33
+
34
+ def adapter_v2_linear_with_bias_and_scale(layer):
35
+ layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
36
+ layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
37
+ bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
38
+ setattr(layer, 'forward', bound_method)
39
+ return layer
40
+
41
+
42
+ def add_adapter_v2_parameters_to_linear_layers(model):
43
+ for module in model.modules():
44
+ if isinstance(module, nn.Linear):
45
+ adapter_v2_linear_with_bias_and_scale(module)
lit_llama/lora.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Derived from https://github.com/microsoft/LoRA
2
+ # ------------------------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
5
+ # ------------------------------------------------------------------------------------------
6
+
7
+ r"""
8
+ Low Ranking Adaptation for LLMs scheme.
9
+
10
+ ┌───────────────────┐
11
+ ┆ h ┆
12
+ └───────────────────┘
13
+
14
+ |
15
+ +
16
+ / \
17
+ ┌─────────────────┐ ╭───────────────╮ Matrix initialization:
18
+ ┆ ┆ \ B / B = 0
19
+ ┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
20
+ ┆ weights ┆ ╰─────────╯
21
+ ┆ ┆ | r | r - rank
22
+ ┆ W e R^(d*d) ┆ | ◀─────▶ |
23
+ ┆ ┆ ╭─────────╮
24
+ └─────────────────┘ / A \
25
+ ▲ / d*r \
26
+ \ ╰───────────────╯
27
+ \ ▲
28
+ \ /
29
+ \ /
30
+ ┌───────────────────┐
31
+ ┆ x ┆
32
+ └───────────────────┘
33
+
34
+ With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,
35
+ we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates
36
+ for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of
37
+ course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen
38
+ pretrained weights and thus fine-tune the model.
39
+
40
+ The goal of this approach is to move weight updates into a separate matrix which is decomposed with
41
+ two matrices of a lower rank.
42
+ """
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+
48
+ import math
49
+ from typing import Dict, List
50
+
51
+ import lit_llama.model as llama
52
+
53
+ from contextlib import contextmanager
54
+ from dataclasses import dataclass
55
+
56
+
57
+ class LoRALayer():
58
+ def __init__(
59
+ self,
60
+ r: int,
61
+ lora_alpha: int,
62
+ lora_dropout: float,
63
+ merge_weights: bool,
64
+ ):
65
+ """Store LoRA specific attributes in a class.
66
+
67
+ Args:
68
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
69
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
70
+ lora_alpha: alpha is needed for scaling updates as alpha/r
71
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
72
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
73
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
74
+ merge_weights: whether we want to merge pretrained weights and LoRA weight updates. This is useful if one wants to use
75
+ fine-tuned model as a standalone one (without storing LoRA weights separately) plus it helps to reduce
76
+ overhead during inference.
77
+ """
78
+ self.r = r
79
+ self.lora_alpha = lora_alpha
80
+ # Optional dropout
81
+ if lora_dropout > 0.:
82
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
83
+ else:
84
+ self.lora_dropout = lambda x: x
85
+ # Mark the weight as unmerged
86
+ self.merged = False
87
+ self.merge_weights = merge_weights
88
+
89
+
90
+ class MergedLinear(nn.Linear, LoRALayer):
91
+ # LoRA implemented in a dense layer
92
+ def __init__(
93
+ self,
94
+ # ↓ this part is for pretrained weights
95
+ in_features: int,
96
+ out_features: int,
97
+ # ↓ the remaining part is for LoRA
98
+ r: int = 0,
99
+ lora_alpha: int = 1,
100
+ lora_dropout: float = 0.,
101
+ enable_lora: List[bool] = [False],
102
+ fan_in_fan_out: bool = False,
103
+ merge_weights: bool = True,
104
+ **kwargs
105
+ ):
106
+ """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
107
+
108
+ This class has three weight matrices:
109
+ 1. Pretrained weights are stored as `self.weight` (because of the nn.Linear inheritance)
110
+ 2. LoRA A matrix as `self.lora_A`
111
+ 3. LoRA B matrix as `self.lora_B`
112
+ Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
113
+
114
+ Args:
115
+ in_features: number of input features of the pretrained weights
116
+ out_features: number of output features of the pretrained weights
117
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
118
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
119
+ lora_alpha: alpha is needed for scaling updates as alpha/r
120
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
121
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
122
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
123
+ enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
124
+ don't want to apply LoRA for all three (query, key and value) we can set it as False. For example if we want
125
+ to apply LoRA only to `query` and `value` but keep `key` without weight updates we should pass `[True,
126
+ False, True]`
127
+ fan_in_fan_out: set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
128
+ `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`
129
+ https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#LL53C9-L53C112
130
+ merge_weights: whether we want to merge pretrained weights and LoRA weight updates. This is useful if one wants to use
131
+ fine-tuned model as a standalone one (without storing LoRA weight separately) plus it helps to reduce
132
+ overhead during inference.
133
+ """
134
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
135
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
136
+ merge_weights=merge_weights)
137
+ assert out_features % len(enable_lora) == 0, \
138
+ 'The length of enable_lora must divide out_features'
139
+ self.enable_lora = enable_lora
140
+ self.fan_in_fan_out = fan_in_fan_out
141
+
142
+ # Actual trainable parameters
143
+ # To better understand initialization let's imagine that we have such parameters:
144
+ # ⚬ in_features: 128 (embeddings_size)
145
+ # ⚬ out_features: 384 (3 * embedding_size)
146
+ # ⚬ r: 2
147
+ # ⚬ enable_lora: [True, False, True]
148
+ if r > 0 and any(enable_lora):
149
+ self.lora_A = nn.Parameter(
150
+ self.weight.new_zeros((r * sum(enable_lora), in_features))) # (4, 128)
151
+ self.lora_B = nn.Parameter(
152
+ self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) # (256, 2)
153
+ ) # weights for Conv1D with groups=sum(enable_lora)
154
+ # Notes about shapes above
155
+ # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
156
+ # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
157
+ # F.linear function weights are automatically transposed. In addition conv1d requires channels to
158
+ # be before seq length
159
+ # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
160
+ # 128*2; 2 tells to have two channels per group for group convolution
161
+
162
+ # Scaling:
163
+ # This balances the pretrained model`s knowledge and the new task-specific adaptation
164
+ # https://lightning.ai/pages/community/tutorial/lora-llm/
165
+ # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
166
+ # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
167
+ # tune these values to your needs. This value can be even slightly greater than 1.0!
168
+ # https://github.com/cloneofsimo/lora
169
+ self.scaling = self.lora_alpha / self.r
170
+
171
+ # Freezing the pre-trained weight matrix
172
+ self.weight.requires_grad = False # (384, 128)
173
+
174
+ # Compute the indices
175
+ # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
176
+ # but not keys, then the weights update should be:
177
+ #
178
+ # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
179
+ # [....................................],
180
+ # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
181
+ # ↑ ↑ ↑
182
+ # ________________________________________
183
+ # | query | key | value |
184
+ # ----------------------------------------
185
+ self.lora_ind = self.weight.new_zeros(
186
+ (out_features, ), dtype=torch.bool
187
+ ).view(len(enable_lora), -1) # (3, 128)
188
+ self.lora_ind[enable_lora, :] = True # (3, 128)
189
+ self.lora_ind = self.lora_ind.view(-1) # (384,)
190
+ self.reset_parameters()
191
+ if fan_in_fan_out:
192
+ self.weight.data = self.weight.data.T
193
+
194
+ def reset_parameters(self):
195
+ """Reset all the weights, even including pretrained ones."""
196
+ nn.Linear.reset_parameters(self)
197
+ if hasattr(self, 'lora_A'):
198
+ # initialize A the same way as the default for nn.Linear and B to zero
199
+ # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
200
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
201
+ nn.init.zeros_(self.lora_B)
202
+
203
+ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
204
+ """Properly pad weight updates with zeros.
205
+
206
+ If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
207
+ then the weights update should be:
208
+
209
+ [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
210
+ [....................................],
211
+ [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
212
+ ↑ ↑ ↑
213
+ ________________________________________
214
+ | query | key | value |
215
+ ----------------------------------------
216
+
217
+ Args:
218
+ x: tensor with weights update that will be padded with zeros if necessary
219
+
220
+ Returns:
221
+ A tensor with weight updates and zeros for deselected q, k or v
222
+ """
223
+ # Let's image that:
224
+ # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
225
+ # ⚬ embeddings_size: 128
226
+ # ⚬ self.out_features: 384 (3 * embeddings_size)
227
+ # ⚬ enable_lora: [True, False, True]
228
+ # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
229
+ # embeddings_size is 384 (self.out_features), so that means that we need to pad from 256 to 384 with zeros, but
230
+ # only for key updates (this is where self.lora_ind comes in handy)
231
+ # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
232
+ # for example when we want to merge/unmerge LoRA weights and pretrained weights
233
+ x = x.transpose(0, 1)
234
+ result = x.new_zeros((*x.shape[:-1], self.out_features)) # (64, 64, 384)
235
+ result = result.view(-1, self.out_features) # (4096, 384)
236
+ result[:, self.lora_ind] = x.reshape(
237
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
238
+ ) # (4096, 256)
239
+ return result.view((*x.shape[:-1], self.out_features)).transpose(0, 1) # (64, 64, 384)
240
+
241
+ def train(self, mode: bool = True):
242
+ """Set the module into train or eval mode if `mode` is True of False respectively.
243
+
244
+ For train mode (train(True)) if weights are merged we need to subtract weights updates (LoRA_A @ LoRA_B) from
245
+ pretrained weights so we can continue training LoRA's matrices A and B and keep pretrained weights frozen.
246
+
247
+ For eval mode (train(False)) if weights are not merged we need to add weight updates to pretrained weights in
248
+ order to reduce computational overhead during inference.
249
+
250
+ Args:
251
+ mode: if True the module will be set into train mode (affects Dropout and BatchNorm), if False - eval mode.
252
+
253
+ """
254
+ def T(w):
255
+ return w.T if self.fan_in_fan_out else w
256
+ # despite being called from nn.Linear this method will put all layers into train mode, including nn.Dropout
257
+ # of course except parameters (such as self.lora_A, self.lora_B)
258
+ nn.Linear.train(self, mode)
259
+
260
+ # if train(True) -> unmerge unless we already have them unmerged
261
+ # if train(False) -> merge unless we already have them merged
262
+ should = self.merged if mode else not self.merged
263
+
264
+ # Let's assume that:
265
+ # ⚬ self.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
266
+ # ⚬ self.lora_A.data: (4, 128)
267
+ # ⚬ self.lora_B.data: (256, 2)
268
+ if self.merge_weights and should:
269
+ if self.r > 0 and any(self.enable_lora):
270
+ delta_w = F.conv1d(
271
+ self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
272
+ self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
273
+ groups=sum(self.enable_lora)
274
+ ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
275
+ # -1: W = W - delta_W (unmerge), +1: W = W + delta_W (merge)
276
+ sign = -1 if mode else 1
277
+ self.weight.data += sign * self.zero_pad(T(delta_w * self.scaling)) # (256, 128) after zero_pad (384, 128)
278
+ self.merged = not mode
279
+
280
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
281
+ """Do the forward pass.
282
+
283
+ If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
284
+ If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
285
+
286
+ Args:
287
+ x: input tensor of shape (batch_size, context_length, embedding_size)
288
+
289
+ Returns:
290
+ Output tensor of shape (batch_size, context_length, 3 * embedding_size)
291
+ """
292
+ def T(w):
293
+ return w.T if self.fan_in_fan_out else w
294
+
295
+ # Let's assume that:
296
+ # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
297
+ # ⚬ self.weight: (384, 128) or (3 * embedding_size, embedding_size)
298
+ # ⚬ self.lora_A.data: (4, 128)
299
+ # ⚬ self.lora_B.data: (256, 2)
300
+
301
+ # the logic here is that the weights are merged only during inference
302
+ # so if they are merged we don't need to do anything with LoRA's A and B matrices
303
+ # but if the weights are not merged that means that the forward method is called during
304
+ # training and we need to forward pass input through pretrained weights, LoRA A and B matrices
305
+ # and do the summation (as per scheme at the top of the file)
306
+ if self.merged:
307
+ return F.linear(x, T(self.weight), bias=self.bias)
308
+ else:
309
+ # `F.linear` automatically transposes the second argument (T(self.weight) in our case)
310
+ result = F.linear(x, T(self.weight), bias=self.bias) # (64, 64, 128) @ (384, 128) -> (64, 64, 384)
311
+ if self.r > 0:
312
+ after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
313
+ # For F.conv1d:
314
+ # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
315
+ # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
316
+ # ⚬ groups: split input into groups, in_channels should be divisible by the number of groups. Default: 1
317
+ # presumably iW - sequence width/length, kW - kernel width
318
+ after_B = F.conv1d(
319
+ after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
320
+ self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
321
+ groups=sum(self.enable_lora)
322
+ ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
323
+ result += self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
324
+ return result
325
+
326
+
327
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
328
+ """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
329
+
330
+ Args:
331
+ model: model with LoRA layers
332
+ bias:
333
+ ``"none"``: all bias weights will be frozen,
334
+ ``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
335
+ ``"all"``: all bias weights will be unfrozen.
336
+
337
+ Raises:
338
+ NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
339
+ """
340
+ # freeze all layers except LoRA's
341
+ for n, p in model.named_parameters():
342
+ if 'lora_' not in n:
343
+ p.requires_grad = False
344
+
345
+ # depending on the `bias` value unfreeze bias weights
346
+ if bias == 'none':
347
+ return
348
+ elif bias == 'all':
349
+ for n, p in model.named_parameters():
350
+ if 'bias' in n:
351
+ p.requires_grad = True
352
+ elif bias == 'lora_only':
353
+ for m in model.modules():
354
+ if isinstance(m, LoRALayer) and \
355
+ hasattr(m, 'bias') and \
356
+ m.bias is not None:
357
+ m.bias.requires_grad = True
358
+ else:
359
+ raise NotImplementedError
360
+
361
+
362
+ def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
363
+ """Return state_dict with weights of LoRA's A and B matrices and with biases depending on the `bias` value.
364
+
365
+ Args:
366
+ model: model with LoRA layers
367
+ bias:
368
+ ``"none"``: state dict will not store bias weights,
369
+ ``"lora_only"``: state dict will store bias weights only from LoRA layers,
370
+ ``"all"``: state dict will store all bias weights.
371
+
372
+ Returns:
373
+ Weights and biases of LoRA layers
374
+
375
+ Raises:
376
+ NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
377
+ """
378
+ my_state_dict = model.state_dict()
379
+ if bias == 'none':
380
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
381
+ elif bias == 'all':
382
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
383
+ elif bias == 'lora_only':
384
+ to_return = {}
385
+ for k in my_state_dict:
386
+ if 'lora_' in k:
387
+ to_return[k] = my_state_dict[k]
388
+ bias_name = k.split('lora_')[0]+'bias'
389
+ if bias_name in my_state_dict:
390
+ to_return[bias_name] = my_state_dict[bias_name]
391
+ return to_return
392
+ else:
393
+ raise NotImplementedError
394
+
395
+
396
+ @dataclass
397
+ class LoRAConfig:
398
+ r: float = 0.0
399
+ alpha: float = 1.0
400
+ dropout: float = 0.0
401
+
402
+
403
+ class CausalSelfAttention(llama.CausalSelfAttention):
404
+ lora_config = None
405
+
406
+ def __init__(self, config: llama.LLaMAConfig) -> None:
407
+ """Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for
408
+ parameter-efficient fine-tuning.
409
+
410
+ *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
411
+ query, key and value for each head) we can do this in a single pass with a single weight matrix.
412
+
413
+ Args:
414
+ config:
415
+ ``"block_size"``: size of the context of the model,
416
+ ``"vocab_size"``: number of unique tokens,
417
+ ``"padded_vocab_size"``: padded size of the vocabulary to the nearest multiple of 64 (leads to a greater performance),
418
+ ``"n_layer"``: number of transformer blocks (self-attention + MLP),
419
+ ``"n_head"``: number of heads in multi-head attention mechanism,
420
+ ``"n_embd"``: size of the embedding: vector representation of each token.
421
+ """
422
+ # Skip the parent class __init__ altogether and replace it to avoid
423
+ # useless allocations
424
+ nn.Module.__init__(self)
425
+ assert config.n_embd % config.n_head == 0
426
+
427
+ # key, query, value projections for all heads, but in a batch
428
+ self.c_attn = MergedLinear(
429
+ in_features=config.n_embd,
430
+ out_features=3 * config.n_embd,
431
+ r=self.lora_config.r,
432
+ lora_alpha=self.lora_config.alpha,
433
+ lora_dropout=self.lora_config.dropout,
434
+ enable_lora=[True, False, True],
435
+ fan_in_fan_out = False,
436
+ merge_weights=True,
437
+ bias=False)
438
+ # output projection
439
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
440
+ # regularization
441
+ self.n_head = config.n_head
442
+ self.n_embd = config.n_embd
443
+ self.block_size = config.block_size
444
+ self.rope_cache = None
445
+
446
+
447
+ @contextmanager
448
+ def lora(r, alpha, dropout, enabled: bool = True):
449
+ """Apply context manager under which you can instantiate the model with LoRA.
450
+
451
+ In a nutshell the code inside this function forces to use LoRA variant of causal self-attention
452
+ instead of the original one (without LoRA).
453
+
454
+ Args:
455
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
456
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
457
+ alpha: alpha is needed for scaling updates as alpha/r
458
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
459
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
460
+ dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
461
+ enabled: enables/disables LoRA
462
+ """
463
+ if not enabled:
464
+ yield
465
+ return
466
+
467
+ CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
468
+ # when entering context manager replace link to causal self-attention class from original
469
+ # to a variant with LoRA
470
+ causal_self_attention = llama.CausalSelfAttention
471
+ llama.CausalSelfAttention = CausalSelfAttention
472
+ yield
473
+ # when exiting context manager - restore link to original causal self-attention class
474
+ llama.CausalSelfAttention = causal_self_attention
475
+
476
+ CausalSelfAttention.lora_config = None
lit_llama/model.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Full definition of a LLaMA Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
4
+ """
5
+ # mypy: ignore-errors
6
+ import math
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+ from typing_extensions import Self
14
+
15
+ from lit_llama.utils import find_multiple
16
+
17
+
18
+ MaskCache = torch.Tensor
19
+ RoPECache = torch.Tensor
20
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
21
+
22
+
23
+ @dataclass
24
+ class LLaMAConfig:
25
+ block_size: int = 2048
26
+ vocab_size: int = 32000
27
+ padded_vocab_size: Optional[int] = None
28
+ n_layer: int = 32
29
+ n_head: int = 32
30
+ n_embd: int = 4096
31
+
32
+ def __post_init__(self):
33
+ if self.padded_vocab_size is None:
34
+ self.padded_vocab_size = find_multiple(self.vocab_size, 64)
35
+
36
+ @classmethod
37
+ def from_name(cls, name: str) -> Self:
38
+ return cls(**llama_configs[name])
39
+
40
+
41
+ llama_configs = {
42
+ "7B": dict(n_layer=32, n_head=32, n_embd=4096),
43
+ "13B": dict(n_layer=40, n_head=40, n_embd=5120),
44
+ "30B": dict(n_layer=60, n_head=52, n_embd=6656),
45
+ "65B": dict(n_layer=80, n_head=64, n_embd=8192),
46
+ }
47
+
48
+
49
+ class LLaMA(nn.Module):
50
+ def __init__(self, config: LLaMAConfig) -> None:
51
+ super().__init__()
52
+ assert config.padded_vocab_size is not None
53
+ self.config = config
54
+
55
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
56
+ self.transformer = nn.ModuleDict(
57
+ dict(
58
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
59
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
60
+ ln_f=RMSNorm(config.n_embd),
61
+ )
62
+ )
63
+
64
+ self.rope_cache: Optional[RoPECache] = None
65
+ self.mask_cache: Optional[MaskCache] = None
66
+ self.kv_caches: List[KVCache] = []
67
+
68
+ def _init_weights(self, module: nn.Module) -> None:
69
+ if isinstance(module, nn.Linear):
70
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
71
+ elif isinstance(module, nn.Embedding):
72
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
73
+
74
+ def forward(
75
+ self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
76
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
77
+ B, T = idx.size()
78
+
79
+ block_size = self.config.block_size
80
+ if max_seq_length is None:
81
+ max_seq_length = block_size
82
+ assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
83
+ assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
84
+ assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
85
+
86
+ if self.rope_cache is None:
87
+ self.rope_cache = self.build_rope_cache(idx)
88
+ if self.mask_cache is None:
89
+ self.mask_cache = self.build_mask_cache(idx)
90
+
91
+ if input_pos is not None:
92
+ rope = self.rope_cache.index_select(0, input_pos)
93
+ mask = self.mask_cache.index_select(2, input_pos)
94
+ mask = mask[:, :, :, :max_seq_length]
95
+ else:
96
+ rope = self.rope_cache[:T]
97
+ mask = self.mask_cache[:, :, :T, :T]
98
+
99
+ # forward the model itself
100
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
101
+
102
+ if input_pos is None: # proxy for use_cache=False
103
+ for block in self.transformer.h:
104
+ x, _ = block(x, rope, mask, max_seq_length)
105
+ else:
106
+ if not self.kv_caches:
107
+ head_size = self.config.n_embd // self.config.n_head
108
+ cache_shape = (B, self.config.n_head, max_seq_length, head_size)
109
+ self.kv_caches = [
110
+ (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
111
+ for _ in range(self.config.n_layer)
112
+ ]
113
+ for i, block in enumerate(self.transformer.h):
114
+ x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
115
+
116
+ x = self.transformer.ln_f(x)
117
+
118
+ logits = self.lm_head(x) # (b, t, vocab_size)
119
+
120
+ return logits
121
+
122
+ @classmethod
123
+ def from_name(cls, name: str) -> Self:
124
+ return cls(LLaMAConfig.from_name(name))
125
+
126
+ def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
127
+ return build_rope_cache(
128
+ seq_len=self.config.block_size,
129
+ n_elem=self.config.n_embd // self.config.n_head,
130
+ dtype=idx.dtype,
131
+ device=idx.device,
132
+ )
133
+
134
+ def build_mask_cache(self, idx: torch.Tensor) -> MaskCache:
135
+ ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
136
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
137
+
138
+ def reset_cache(self) -> None:
139
+ self.kv_caches.clear()
140
+ if self.mask_cache.device.type == "xla":
141
+ # https://github.com/Lightning-AI/lit-parrot/pull/83#issuecomment-1558150179
142
+ self.rope_cache = None
143
+ self.mask_cache = None
144
+
145
+
146
+ class Block(nn.Module):
147
+ def __init__(self, config: LLaMAConfig) -> None:
148
+ super().__init__()
149
+ self.rms_1 = RMSNorm(config.n_embd)
150
+ self.attn = CausalSelfAttention(config)
151
+ self.rms_2 = RMSNorm(config.n_embd)
152
+ self.mlp = MLP(config)
153
+
154
+ def forward(
155
+ self,
156
+ x: torch.Tensor,
157
+ rope: RoPECache,
158
+ mask: MaskCache,
159
+ max_seq_length: int,
160
+ input_pos: Optional[torch.Tensor] = None,
161
+ kv_cache: Optional[KVCache] = None,
162
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
163
+ h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
164
+ x = x + h
165
+ x = x + self.mlp(self.rms_2(x))
166
+ return x, new_kv_cache
167
+
168
+
169
+ class CausalSelfAttention(nn.Module):
170
+ def __init__(self, config: LLaMAConfig) -> None:
171
+ super().__init__()
172
+ assert config.n_embd % config.n_head == 0
173
+
174
+ # key, query, value projections for all heads, but in a batch
175
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
176
+ # output projection
177
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
178
+
179
+ self.n_head = config.n_head
180
+ self.n_embd = config.n_embd
181
+ self.block_size = config.block_size
182
+
183
+ def forward(
184
+ self,
185
+ x: torch.Tensor,
186
+ rope: RoPECache,
187
+ mask: MaskCache,
188
+ max_seq_length: int,
189
+ input_pos: Optional[torch.Tensor] = None,
190
+ kv_cache: Optional[KVCache] = None,
191
+ ) -> Tuple[torch.Tensor, Optional[KVCache]]:
192
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
193
+
194
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
195
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
196
+
197
+ head_size = C // self.n_head
198
+ k = k.view(B, T, self.n_head, head_size)
199
+ q = q.view(B, T, self.n_head, head_size)
200
+ v = v.view(B, T, self.n_head, head_size)
201
+
202
+ q = apply_rope(q, rope)
203
+ k = apply_rope(k, rope)
204
+
205
+ k = k.transpose(1, 2) # (B, nh, T, hs)
206
+ q = q.transpose(1, 2) # (B, nh, T, hs)
207
+ v = v.transpose(1, 2) # (B, nh, T, hs)
208
+
209
+ if kv_cache is not None:
210
+ cache_k, cache_v = kv_cache
211
+ # check if reached token limit
212
+ if input_pos[-1] >= max_seq_length:
213
+ input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
214
+ # shift 1 position to the left
215
+ cache_k = torch.roll(cache_k, -1, dims=2)
216
+ cache_v = torch.roll(cache_v, -1, dims=2)
217
+ k = cache_k.index_copy(2, input_pos, k)
218
+ v = cache_v.index_copy(2, input_pos, v)
219
+ kv_cache = k, v
220
+
221
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
222
+ # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
223
+ # att = att.masked_fill(mask[:,:,:T,:T] == 0, float('-inf'))
224
+ # att = F.softmax(att, dim=-1)
225
+ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
226
+
227
+ # efficient attention using Flash Attention CUDA kernels
228
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
229
+
230
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
231
+
232
+ # output projection
233
+ y = self.c_proj(y)
234
+
235
+ return y, kv_cache
236
+
237
+
238
+ class MLP(nn.Module):
239
+ def __init__(self, config: LLaMAConfig) -> None:
240
+ super().__init__()
241
+ hidden_dim = 4 * config.n_embd
242
+ n_hidden = int(2 * hidden_dim / 3)
243
+ n_hidden = find_multiple(n_hidden, 256)
244
+
245
+ self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
246
+ self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
247
+ self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
248
+
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
251
+ x = self.c_proj(x)
252
+ return x
253
+
254
+
255
+ class RMSNorm(nn.Module):
256
+ """Root Mean Square Layer Normalization.
257
+
258
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
259
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
260
+ """
261
+
262
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
263
+ super().__init__()
264
+ self.scale = nn.Parameter(torch.ones(size))
265
+ self.eps = eps
266
+ self.dim = dim
267
+
268
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
269
+ # NOTE: the original RMSNorm paper implementation is not equivalent
270
+ # norm_x = x.norm(2, dim=self.dim, keepdim=True)
271
+ # rms_x = norm_x * d_x ** (-1. / 2)
272
+ # x_normed = x / (rms_x + self.eps)
273
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
274
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
275
+ return self.scale * x_normed
276
+
277
+
278
+ def build_rope_cache(
279
+ seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
280
+ ) -> RoPECache:
281
+ """Enhanced Transformer with Rotary Position Embedding.
282
+
283
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
284
+ transformers/rope/__init__.py. MIT License:
285
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
286
+ """
287
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
288
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
289
+
290
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
291
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
292
+
293
+ # Calculate the product of position index and $\theta_i$
294
+ idx_theta = torch.outer(seq_idx, theta).float()
295
+
296
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
297
+
298
+ # this is to mimic the behaviour of complex32, else we will get different results
299
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
300
+ cache = cache.half()
301
+ return cache
302
+
303
+
304
+ def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:
305
+ # truncate to support variable sizes
306
+ T = x.size(1)
307
+ rope_cache = rope_cache[:T]
308
+
309
+ # cast because the reference does
310
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
311
+ rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
312
+ x_out2 = torch.stack(
313
+ [
314
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
315
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
316
+ ],
317
+ -1,
318
+ )
319
+
320
+ x_out2 = x_out2.flatten(3)
321
+ return x_out2.type_as(x)
lit_llama/packed_dataset.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Very loosely inspired by indexed_dataset in Fairseq, Megatron
2
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
3
+
4
+
5
+ import os
6
+ import struct
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import IterableDataset, get_worker_info
12
+
13
+
14
+ dtypes = {
15
+ 1: np.uint8,
16
+ 2: np.int8,
17
+ 3: np.int16,
18
+ 4: np.int32,
19
+ 5: np.int64,
20
+ 6: np.float32,
21
+ 7: np.float64,
22
+ 8: np.uint16,
23
+ }
24
+
25
+
26
+ def code(dtype):
27
+ for k in dtypes.keys():
28
+ if dtypes[k] == dtype:
29
+ return k
30
+ raise ValueError(dtype)
31
+
32
+
33
+ HDR_MAGIC = b"LITPKDS"
34
+ HDR_SIZE = 24 # bytes
35
+
36
+
37
+ class PackedDataset(IterableDataset):
38
+ def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0):
39
+ self._filenames = filenames
40
+ self._n_chunks = n_chunks
41
+ self._block_size = block_size
42
+ self._seed = seed
43
+ self._shuffle = shuffle
44
+ self._wrap = wrap
45
+ self._num_processes = num_processes
46
+ self._process_rank = process_rank
47
+
48
+ def __iter__(self):
49
+ worker_info = get_worker_info()
50
+ num_workers = worker_info.num_workers if worker_info is not None else 1
51
+ worker_id = worker_info.id if worker_info is not None else 0
52
+ num_shards = num_workers * self._num_processes
53
+ shard_id = self._process_rank * num_workers + worker_id
54
+
55
+ max_num_files = len(self._filenames) // num_shards * num_shards
56
+ filenames = self._filenames[shard_id : max_num_files : num_shards]
57
+
58
+ return PackedDatasetIterator(
59
+ filenames=filenames,
60
+ n_chunks=self._n_chunks,
61
+ block_size=self._block_size,
62
+ seed=self._seed,
63
+ shuffle=self._shuffle,
64
+ wrap=self._wrap,
65
+ )
66
+
67
+
68
+ class PackedDatasetBuilder(object):
69
+ def __init__(
70
+ self,
71
+ outdir,
72
+ prefix,
73
+ chunk_size,
74
+ sep_token,
75
+ dtype="auto",
76
+ vocab_size=None,
77
+ ):
78
+ if dtype == "auto":
79
+ if vocab_size is None:
80
+ raise ValueError("vocab_size cannot be None when dtype='auto'")
81
+ if vocab_size is not None and vocab_size < 65500:
82
+ self._dtype = np.uint16
83
+ else:
84
+ self._dtype = np.int32
85
+ else:
86
+ self._dtype = dtype
87
+ self._counter = 0
88
+ self._chunk_size = chunk_size
89
+ self._outdir = outdir
90
+ self._prefix = prefix
91
+ self._sep_token = sep_token
92
+ self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
93
+ self._arr.fill(self._sep_token)
94
+ self._idx = 0
95
+ self._version = 1
96
+ self._filenames = []
97
+
98
+ def _write_chunk(self):
99
+ filename = f"{self._prefix}_{self._counter:010d}.bin"
100
+ filename = os.path.join(self._outdir, filename)
101
+
102
+ with open(filename, "wb") as f:
103
+ f.write(HDR_MAGIC)
104
+ f.write(struct.pack("<Q", self._version))
105
+ f.write(struct.pack("<B", code(self._dtype)))
106
+ f.write(struct.pack("<Q", self._chunk_size))
107
+ f.write(self._arr.tobytes(order="C"))
108
+
109
+ self._filenames.append(filename)
110
+ self._counter += 1
111
+ self._arr.fill(self._sep_token)
112
+ self._idx = 0
113
+
114
+ @property
115
+ def dtype(self):
116
+ return self._dtype
117
+
118
+ @property
119
+ def filenames(self):
120
+ return self._filenames.copy()
121
+
122
+ def add_array(self, arr):
123
+ while self._idx + arr.shape[0] > self._chunk_size:
124
+ part_len = self._chunk_size - self._idx
125
+ self._arr[self._idx : self._idx + part_len] = arr[:part_len]
126
+ self._write_chunk()
127
+ arr = arr[part_len:]
128
+
129
+ arr_len = arr.shape[0]
130
+ self._arr[self._idx : self._idx + arr_len] = arr
131
+ self._idx += arr_len
132
+
133
+ def write_reminder(self):
134
+ self._write_chunk()
135
+
136
+
137
+ class PackedDatasetIterator:
138
+ def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
139
+ self._seed = seed
140
+ self._shuffle = shuffle
141
+ self._rng = np.random.default_rng(seed) if shuffle else None
142
+ self._block_idxs = None
143
+
144
+ self._wrap = wrap
145
+
146
+ # TODO: instead of filenames, we could have a single text stream
147
+ # (or text file) with the sequence of all files to be
148
+ # fetched/loaded.
149
+ self._filenames = filenames
150
+ self._file_idx = 0
151
+
152
+ self._n_chunks = n_chunks
153
+
154
+ self._dtype = None
155
+ self._block_size = block_size
156
+ self._n_blocks = None
157
+
158
+ self._mmaps = []
159
+ self._buffers = []
160
+
161
+ self._block_idxs = []
162
+ self._curr_idx = 0
163
+
164
+ self._load_n_chunks()
165
+
166
+ def _read_header(self, path):
167
+ with open(path, "rb") as f:
168
+ magic = f.read(len(HDR_MAGIC))
169
+ assert magic == HDR_MAGIC, "File doesn't match expected format."
170
+ version = struct.unpack("<Q", f.read(8))
171
+ assert (1,) == version
172
+ (dtype_code,) = struct.unpack("<B", f.read(1))
173
+ dtype = dtypes[dtype_code]
174
+ (chunk_size,) = struct.unpack("<Q", f.read(8))
175
+ return dtype, chunk_size
176
+
177
+ def _close_mmaps(self):
178
+ for mmap in self._mmaps:
179
+ mmap._mmap.close()
180
+
181
+ def _load_n_chunks(self):
182
+ self._close_mmaps()
183
+ self._mmaps = []
184
+ self._buffers = []
185
+
186
+ if self._n_chunks > len(self._filenames[self._file_idx:]):
187
+ if not self._wrap:
188
+ raise StopIteration
189
+ else:
190
+ self._file_idx = 0
191
+
192
+ for i in range(self._n_chunks):
193
+ filename = self._filenames[self._file_idx + i]
194
+ if self._dtype is None:
195
+ self._dtype, self._chunk_size = self._read_header(
196
+ filename
197
+ )
198
+ self._n_blocks = self._chunk_size // self._block_size
199
+ # TODO: check header matches with previous files
200
+ mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
201
+ self._mmaps.append(mmap)
202
+ self._buffers.append(memoryview(mmap))
203
+
204
+ self._file_idx += self._n_chunks
205
+ n_all_blocks = self._n_chunks * self._n_blocks
206
+
207
+ self._block_idxs = (
208
+ self._rng.permutation(n_all_blocks)
209
+ if self._shuffle
210
+ else range(n_all_blocks)
211
+ )
212
+
213
+ self._curr_idx = 0
214
+
215
+ def __del__(self):
216
+ self._close_mmaps()
217
+ del self._mmaps
218
+ del self._buffers
219
+
220
+ def __iter__(self):
221
+ return self
222
+
223
+ def __next__(self):
224
+ if self._curr_idx >= len(self._block_idxs):
225
+ self._load_n_chunks()
226
+ # TODO: trigger fetching next next n_chunks if remote
227
+ block_idx = self._block_idxs[self._curr_idx]
228
+ chunk_id = block_idx // self._n_blocks
229
+ buffer = self._buffers[chunk_id]
230
+ elem_id = (block_idx % self._n_blocks) * self._block_size
231
+ offset = np.dtype(self._dtype).itemsize * elem_id
232
+ arr = np.frombuffer(
233
+ buffer, dtype=self._dtype, count=self._block_size, offset=offset
234
+ )
235
+ self._curr_idx += 1
236
+ return torch.from_numpy(arr.astype(np.int64))
237
+
238
+
239
+ class CombinedDataset(IterableDataset):
240
+ def __init__(self, datasets, seed, weights=None):
241
+ self._seed = seed
242
+ self._datasets = datasets
243
+ self._weights = weights
244
+ n_datasets = len(datasets)
245
+ if weights is None:
246
+ self._weights = [1 / n_datasets] * n_datasets
247
+
248
+ def __iter__(self):
249
+ return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
250
+
251
+
252
+ class CombinedDatasetIterator:
253
+ def __init__(self, datasets, seed, weights):
254
+ self._datasets = [iter(el) for el in datasets]
255
+ self._weights = weights
256
+ self._rng = random.Random(seed)
257
+
258
+ def __next__(self):
259
+ dataset, = self._rng.choices(self._datasets, weights=self._weights, k=1)
260
+ return next(dataset)
lit_llama/quantization.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import contextmanager
3
+ import warnings
4
+ import math
5
+
6
+ import torch
7
+
8
+ # configuration for bitsandbytes before import
9
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
10
+ warnings.filterwarnings(
11
+ "ignore",
12
+ message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization",
13
+ )
14
+ warnings.filterwarnings(
15
+ "ignore",
16
+ message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
17
+ )
18
+ warnings.filterwarnings(
19
+ "ignore",
20
+ message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.",
21
+ )
22
+
23
+ try:
24
+ import bitsandbytes as bnb # noqa: E402
25
+ except:
26
+ bnb = None
27
+
28
+ try:
29
+ import triton # noqa: E402
30
+ import triton.language as tl # noqa: E402
31
+ except:
32
+ triton = None
33
+
34
+ if bnb is not None:
35
+
36
+ class Linear8bitLt(bnb.nn.Linear8bitLt):
37
+ """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
38
+ re-quantizaton when loading the state dict.
39
+
40
+
41
+ This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
42
+ """
43
+
44
+ def __init__(self, *args, **kwargs):
45
+ super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
46
+ # We quantize the initial weight here so we don't end up filling the device
47
+ # memory with float32 weights which could lead to OOM.
48
+ self._quantize_weight(self.weight.data)
49
+
50
+ def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
51
+ # There is only one key that ends with `*.weight`, the other one is the bias
52
+ weight_key = next(
53
+ (name for name in local_state_dict.keys() if name.endswith("weight")),
54
+ None,
55
+ )
56
+ if weight_key is None:
57
+ return
58
+
59
+ # Load the weight from the state dict and re-quantize it
60
+ weight = local_state_dict.pop(weight_key)
61
+ self._quantize_weight(weight)
62
+
63
+ # If there is a bias, let nn.Module load it
64
+ if local_state_dict:
65
+ super()._load_from_state_dict(local_state_dict, *args, **kwargs)
66
+
67
+ def _quantize_weight(self, weight: torch.Tensor) -> None:
68
+ # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
69
+ B = weight.contiguous().half().cuda()
70
+ CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
71
+ del CBt
72
+ del SCBt
73
+ self.weight.data = CB
74
+ setattr(self.weight, "CB", CB)
75
+ setattr(self.weight, "SCB", SCB)
76
+
77
+
78
+ if triton is not None:
79
+ # This is adapted from the OpenAI Triton matmul example.
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config(
83
+ {
84
+ "BLOCK_SIZE_M": 128,
85
+ "BLOCK_SIZE_N": 256,
86
+ "BLOCK_SIZE_K": 32,
87
+ "GROUP_SIZE_M": 8,
88
+ },
89
+ num_stages=3,
90
+ num_warps=8,
91
+ ),
92
+ triton.Config(
93
+ {
94
+ "BLOCK_SIZE_M": 256,
95
+ "BLOCK_SIZE_N": 128,
96
+ "BLOCK_SIZE_K": 32,
97
+ "GROUP_SIZE_M": 8,
98
+ },
99
+ num_stages=3,
100
+ num_warps=8,
101
+ ),
102
+ triton.Config(
103
+ {
104
+ "BLOCK_SIZE_M": 256,
105
+ "BLOCK_SIZE_N": 64,
106
+ "BLOCK_SIZE_K": 32,
107
+ "GROUP_SIZE_M": 8,
108
+ },
109
+ num_stages=4,
110
+ num_warps=4,
111
+ ),
112
+ triton.Config(
113
+ {
114
+ "BLOCK_SIZE_M": 64,
115
+ "BLOCK_SIZE_N": 256,
116
+ "BLOCK_SIZE_K": 32,
117
+ "GROUP_SIZE_M": 8,
118
+ },
119
+ num_stages=4,
120
+ num_warps=4,
121
+ ),
122
+ triton.Config(
123
+ {
124
+ "BLOCK_SIZE_M": 128,
125
+ "BLOCK_SIZE_N": 128,
126
+ "BLOCK_SIZE_K": 32,
127
+ "GROUP_SIZE_M": 8,
128
+ },
129
+ num_stages=4,
130
+ num_warps=4,
131
+ ),
132
+ triton.Config(
133
+ {
134
+ "BLOCK_SIZE_M": 128,
135
+ "BLOCK_SIZE_N": 64,
136
+ "BLOCK_SIZE_K": 32,
137
+ "GROUP_SIZE_M": 8,
138
+ },
139
+ num_stages=4,
140
+ num_warps=4,
141
+ ),
142
+ triton.Config(
143
+ {
144
+ "BLOCK_SIZE_M": 64,
145
+ "BLOCK_SIZE_N": 128,
146
+ "BLOCK_SIZE_K": 32,
147
+ "GROUP_SIZE_M": 8,
148
+ },
149
+ num_stages=4,
150
+ num_warps=4,
151
+ ),
152
+ triton.Config(
153
+ {
154
+ "BLOCK_SIZE_M": 128,
155
+ "BLOCK_SIZE_N": 32,
156
+ "BLOCK_SIZE_K": 32,
157
+ "GROUP_SIZE_M": 8,
158
+ },
159
+ num_stages=4,
160
+ num_warps=4,
161
+ ),
162
+ triton.Config(
163
+ {
164
+ "BLOCK_SIZE_M": 64,
165
+ "BLOCK_SIZE_N": 32,
166
+ "BLOCK_SIZE_K": 32,
167
+ "GROUP_SIZE_M": 8,
168
+ },
169
+ num_stages=5,
170
+ num_warps=2,
171
+ ),
172
+ triton.Config(
173
+ {
174
+ "BLOCK_SIZE_M": 32,
175
+ "BLOCK_SIZE_N": 64,
176
+ "BLOCK_SIZE_K": 32,
177
+ "GROUP_SIZE_M": 8,
178
+ },
179
+ num_stages=5,
180
+ num_warps=2,
181
+ ),
182
+ ],
183
+ key=["M", "N", "K"],
184
+ )
185
+ @triton.jit
186
+ def linear_kernel_4bit_weight(
187
+ # Pointers to matrices
188
+ a_ptr,
189
+ b_ptr,
190
+ c_ptr,
191
+ bscales_ptr,
192
+ bzeros_ptr,
193
+ # bdequant,
194
+ # Matrix dimensions
195
+ M,
196
+ N,
197
+ K,
198
+ # The stride variables represent how much to increase the ptr by when moving by 1
199
+ # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
200
+ # by to get the element one row down (A has M rows)
201
+ stride_am,
202
+ stride_ak,
203
+ stride_bk,
204
+ stride_bn,
205
+ stride_cm,
206
+ stride_cn,
207
+ # Meta-parameters
208
+ BLOCK_SIZE_M: tl.constexpr,
209
+ BLOCK_SIZE_N: tl.constexpr,
210
+ BLOCK_SIZE_K: tl.constexpr,
211
+ GROUP_SIZE_M: tl.constexpr,
212
+ ):
213
+ """Kernel for computing the matmul C = A x B.T.
214
+ A has shape (M, K), B has shape (N, K) and C has shape (M, N)
215
+ """
216
+ # -----------------------------------------------------------
217
+ # Map program ids `pid` to the block of C it should compute.
218
+ # This is done in a grouped ordering to promote L2 data reuse
219
+ # See above `L2 Cache Optimizations` section for details
220
+ pid = tl.program_id(axis=0)
221
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
222
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
223
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
224
+ group_id = pid // num_pid_in_group
225
+ first_pid_m = group_id * GROUP_SIZE_M
226
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
227
+ pid_m = first_pid_m + (pid % group_size_m)
228
+ pid_n = (pid % num_pid_in_group) // group_size_m
229
+
230
+ # ----------------------------------------------------------
231
+ # Create pointers for the first blocks of A and B.
232
+ # We will advance this pointer as we move in the K direction
233
+ # and accumulate
234
+ # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
235
+ # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
236
+ # see above `Pointer Arithmetics` section for details
237
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
238
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
239
+ a_mask = offs_am[:, None] < M
240
+ b_mask = offs_bn[None, :] < N
241
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
242
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
243
+ b_ptrs = b_ptr + (
244
+ (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
245
+ )
246
+
247
+ bscales_ptrs = bscales_ptr + offs_bn[None, :]
248
+ bzeros_ptrs = bzeros_ptr + offs_bn[None, :]
249
+
250
+ scale = tl.load(bscales_ptrs)
251
+ zero = tl.load(bzeros_ptrs)
252
+ # -----------------------------------------------------------
253
+ # Iterate to compute a block of the C matrix
254
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
255
+ # of fp32 values for higher accuracy.
256
+ # `accumulator` will be converted back to fp16 after the loop
257
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
258
+ for k in range(0, K, BLOCK_SIZE_K):
259
+ # wasteful as it is to load everything twice, my attempts at avoiding it lead to slower code
260
+ b12 = tl.load(b_ptrs, mask=b_mask)
261
+ # Note that for simplicity, we don't apply a mask in K here.
262
+ a = tl.load(a_ptrs, mask=a_mask).to(tl.float32)
263
+ b = (
264
+ ((b12.to(tl.uint8) >> ((offs_k[:, None] % 2) * 4)) & 0xF).to(tl.float32)
265
+ - zero
266
+ ) * scale
267
+ accumulator += tl.dot(a, b)
268
+
269
+ # Advance the ptrs to the next K block
270
+ a_ptrs += BLOCK_SIZE_K * stride_ak
271
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
272
+ c = accumulator
273
+
274
+ # -----------------------------------------------------------
275
+ # Write back the block of the output matrix C
276
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
277
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
278
+ c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
279
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
280
+ tl.store(c_ptrs, c, mask=c_mask)
281
+
282
+ def qlinear_4bit_weight(inp, weight, scales, zeros):
283
+ weight = weight.t().contiguous()
284
+ c_shape = inp.shape[:-1] + weight.shape[-1:]
285
+ inp = inp.reshape(-1, inp.shape[-1]).contiguous()
286
+ # we pad the input to amortize triton compilation cost better
287
+ PAD_TO = 256
288
+ if inp.shape[0] % PAD_TO != 0:
289
+ c_crop = inp.shape[0]
290
+ new_inp_shape0 = inp.shape[0] + PAD_TO - inp.shape[0] % PAD_TO
291
+ inp2 = inp.new_empty((new_inp_shape0, inp.shape[1]))
292
+ inp2[: inp.shape[0]] = inp
293
+ inp2[inp.shape[0] :].zero_()
294
+ inp = inp2
295
+ else:
296
+ c_crop = None
297
+
298
+ assert inp.shape[1] == weight.shape[0] * 2, "incompatible dimensions"
299
+
300
+ assert scales.shape == (weight.shape[1], 1)
301
+ assert zeros.shape == (weight.shape[1], 1)
302
+ scales = scales.contiguous()
303
+ zeros = zeros.contiguous()
304
+ K, N = weight.shape
305
+ M, K = inp.shape
306
+ assert (
307
+ K % 32 == 0
308
+ ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
309
+ # allocates output
310
+ c = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
311
+ # 1D launch kernel where each block gets its own program.
312
+ grid = lambda META: (
313
+ triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
314
+ )
315
+ linear_kernel_4bit_weight[grid](
316
+ inp,
317
+ weight,
318
+ c,
319
+ scales,
320
+ zeros,
321
+ M,
322
+ N,
323
+ K,
324
+ inp.stride(0),
325
+ inp.stride(1),
326
+ weight.stride(0),
327
+ weight.stride(1),
328
+ c.stride(0),
329
+ c.stride(1),
330
+ )
331
+ return c[:c_crop].reshape(c_shape)
332
+
333
+ else:
334
+ qlinear_4bit_weight = None
335
+
336
+
337
+ # for correctness but with terrible perf
338
+ class ColBlockQuantizedLinear(torch.nn.Module):
339
+ def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols):
340
+ super().__init__()
341
+ self.in_features = in_features
342
+ self.out_features = out_features
343
+ self.tile_cols = tile_cols if tile_cols != -1 else self.in_features
344
+ self.bits = bits
345
+ self.entries_per_byte = 8 // bits
346
+ assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8
347
+ assert in_features % self.entries_per_byte == 0
348
+ self.register_buffer(
349
+ "quant_weight",
350
+ torch.empty(
351
+ (self.out_features, self.in_features // self.entries_per_byte),
352
+ dtype=torch.uint8,
353
+ )
354
+ .t()
355
+ .contiguous()
356
+ .t(),
357
+ )
358
+ self.register_buffer(
359
+ "scales",
360
+ torch.empty(
361
+ (
362
+ self.out_features,
363
+ (self.in_features + self.tile_cols - 1) // self.tile_cols,
364
+ )
365
+ ),
366
+ )
367
+ self.register_buffer("zeros", torch.empty_like(self.scales))
368
+ assert isinstance(bias, bool)
369
+ if bias:
370
+ self.register_buffer("bias", torch.empty((self.out_features,)))
371
+ else:
372
+ self.register_buffer("bias", None)
373
+
374
+ def pack_weight(self, weight):
375
+ weight = weight.to(device=self.quant_weight.device, copy=True)
376
+ for j in range(self.scales.size(1)):
377
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] /= self.scales[
378
+ :, j : j + 1
379
+ ]
380
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] += self.zeros[
381
+ :, j : j + 1
382
+ ]
383
+ weight = weight.clamp_(min=0, max=2**self.bits - 1).to(dtype=torch.uint8)
384
+ self.quant_weight.zero_()
385
+ for nr in range(self.entries_per_byte):
386
+ self.quant_weight += weight[:, nr :: self.entries_per_byte] << (
387
+ nr * self.bits
388
+ )
389
+
390
+ def get_weight(self, dtype=torch.float):
391
+ weight = torch.empty(
392
+ (self.out_features, self.in_features),
393
+ device=self.quant_weight.device,
394
+ dtype=dtype,
395
+ )
396
+ mask = (1 << self.bits) - 1
397
+ for nr in range(self.entries_per_byte):
398
+ weight[:, nr :: self.entries_per_byte] = (
399
+ (self.quant_weight >> (nr * self.bits)) & mask
400
+ ).float()
401
+ self.quant_weight.to(dtype)
402
+ for j in range(self.scales.size(1)):
403
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] -= self.zeros[
404
+ :, j : j + 1
405
+ ]
406
+ weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] *= self.scales[
407
+ :, j : j + 1
408
+ ]
409
+ return weight
410
+
411
+ def forward(self, inp):
412
+ if (
413
+ triton is not None
414
+ and self.bits == 4
415
+ and self.quant_weight.device.type == "cuda"
416
+ and self.zeros.shape[1] == 1
417
+ and self.quant_weight.shape[1] % 32 == 0
418
+ ):
419
+ return qlinear_4bit_weight(inp, self.quant_weight, self.scales, self.zeros)
420
+ weight = self.get_weight(dtype=inp.dtype)
421
+ return torch.nn.functional.linear(inp, weight, self.bias)
422
+
423
+
424
+ class GPTQQuantizer:
425
+ # The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/
426
+ # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
427
+ # portions copyright by the authors licensed under the Apache License 2.0
428
+ # All errors are our own.
429
+
430
+ def __init__(
431
+ self,
432
+ linear_module,
433
+ *,
434
+ bits,
435
+ perchannel=True,
436
+ sym=False,
437
+ blocksize=128,
438
+ percdamp=0.01,
439
+ groupsize=-1,
440
+ actorder=False
441
+ ):
442
+ assert isinstance(linear_module, torch.nn.Linear)
443
+
444
+ self.linear_module = linear_module
445
+ self.dev = self.linear_module.weight.device
446
+ self.rows = linear_module.weight.shape[0]
447
+ self.columns = linear_module.weight.shape[1]
448
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
449
+ self.nsamples = 0
450
+ self.bits = bits
451
+ self.maxq = 2**bits - 1
452
+ self.perchannel = perchannel
453
+ self.sym = sym
454
+ self.blocksize = blocksize
455
+ self.percdamp = percdamp
456
+ self.groupsize = groupsize
457
+ self.actorder = actorder
458
+ self.tile_cols = self.columns if groupsize == -1 else groupsize
459
+ self.scales = torch.zeros(
460
+ (self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols),
461
+ dtype=self.linear_module.weight.dtype,
462
+ device=self.dev,
463
+ )
464
+ self.zeros = torch.zeros_like(self.scales)
465
+ assert not (
466
+ self.actorder and self.groupsize != -1
467
+ ), "The permutation trick does not work for grouped quantization"
468
+
469
+ @staticmethod
470
+ def quantize_weight(x, scale, zero, maxq):
471
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
472
+ x_rec = scale * (q - zero)
473
+ return x_rec
474
+
475
+ def find_params_weight(self, x):
476
+ dev = x.device
477
+
478
+ shape = x.shape
479
+ if self.perchannel:
480
+ x = x.flatten(1)
481
+ else:
482
+ x = x.flatten().unsqueeze(0)
483
+
484
+ tmp = torch.zeros(x.shape[0], device=dev)
485
+ xmin = torch.minimum(x.min(1)[0], tmp)
486
+ xmax = torch.maximum(x.max(1)[0], tmp)
487
+
488
+ if self.sym:
489
+ xmax = torch.maximum(torch.abs(xmin), xmax)
490
+ tmp = xmin < 0
491
+ if torch.any(tmp):
492
+ xmin[tmp] = -xmax[tmp]
493
+ tmp = (xmin == 0) & (xmax == 0)
494
+ xmin[tmp] = -1
495
+ xmax[tmp] = +1
496
+
497
+ scale = (xmax - xmin) / self.maxq
498
+ if self.sym:
499
+ zero = torch.full_like(scale, (self.maxq + 1) / 2)
500
+ else:
501
+ zero = torch.round(-xmin / scale)
502
+
503
+ if not self.perchannel:
504
+ tmp = shape[0]
505
+ scale = scale.repeat(tmp)
506
+ zero = zero.repeat(tmp)
507
+
508
+ shape = [-1] + [1] * (len(shape) - 1)
509
+ scale = scale.reshape(shape)
510
+ zero = zero.reshape(shape)
511
+ return scale, zero
512
+
513
+ def collect_input_stats(self, _1, inp, _2):
514
+ inp = inp[0].detach()
515
+ self.last_inp = inp
516
+ if len(inp.shape) == 2:
517
+ inp = inp.unsqueeze(0)
518
+ tmp = inp.shape[0]
519
+ if len(inp.shape) == 3:
520
+ inp = inp.reshape((-1, inp.shape[-1]))
521
+ inp = inp.t()
522
+ self.H *= self.nsamples / (self.nsamples + tmp)
523
+ self.nsamples += tmp
524
+ # inp = inp.float()
525
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
526
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
527
+ self.H += inp.matmul(inp.t())
528
+
529
+ def quantize(self):
530
+ W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
531
+
532
+ scale, zero = self.find_params_weight(W)
533
+ self.scales[:] = scale
534
+ self.zeros[:] = zero
535
+
536
+ H = self.H
537
+ del self.H
538
+ dead = torch.diag(H) == 0
539
+ H[dead, dead] = 1
540
+ W[:, dead] = 0
541
+ if self.actorder:
542
+ perm = torch.argsort(torch.diag(H), descending=True)
543
+ W = W[:, perm]
544
+ H = H[perm][:, perm]
545
+
546
+ Losses = torch.zeros_like(W)
547
+ Q = torch.zeros_like(W)
548
+
549
+ damp = self.percdamp * torch.mean(torch.diag(H))
550
+ diag = torch.arange(self.columns, device=self.dev)
551
+ H[diag, diag] += damp
552
+ H = torch.linalg.cholesky(H)
553
+ H = torch.cholesky_inverse(H)
554
+ H = torch.linalg.cholesky(H, upper=True)
555
+ Hinv = H
556
+
557
+ for i1 in range(0, self.columns, self.blocksize):
558
+ i2 = min(i1 + self.blocksize, self.columns)
559
+ count = i2 - i1
560
+
561
+ W1 = W[:, i1:i2].clone()
562
+ Q1 = torch.zeros_like(W1)
563
+ Err1 = torch.zeros_like(W1)
564
+ Losses1 = torch.zeros_like(W1)
565
+ Hinv1 = Hinv[i1:i2, i1:i2]
566
+
567
+ for i in range(count):
568
+ w = W1[:, i]
569
+ d = Hinv1[i, i]
570
+
571
+ if self.groupsize != -1:
572
+ if (i1 + i) % self.groupsize == 0:
573
+ scale, zero = self.find_params_weight(
574
+ W[:, (i1 + i) : (i1 + i + self.groupsize)]
575
+ )
576
+ self.scales[:, (i1 + i) // self.groupsize] = scale
577
+ self.zeros[:, (i1 + i) // self.groupsize] = zero
578
+
579
+ q = self.quantize_weight(w.unsqueeze(1), scale, zero, self.maxq)
580
+ q = q.squeeze(1)
581
+ assert q.dim() == 1
582
+ Q1[:, i] = q
583
+ Losses1[:, i] = (w - q) ** 2 / d**2
584
+
585
+ err1 = (w - q) / d
586
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
587
+ Err1[:, i] = err1
588
+
589
+ Q[:, i1:i2] = Q1
590
+ Losses[:, i1:i2] = Losses1 / 2
591
+
592
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
593
+
594
+ if self.actorder:
595
+ invperm = torch.argsort(perm)
596
+ Q = Q[:, invperm]
597
+
598
+ weight = Q.reshape(self.linear_module.weight.shape).to(
599
+ self.linear_module.weight.data.dtype
600
+ )
601
+ error = torch.sum(Losses).item()
602
+
603
+ q_module = ColBlockQuantizedLinear(
604
+ self.linear_module.in_features,
605
+ self.linear_module.out_features,
606
+ self.linear_module.bias is not None,
607
+ bits=self.bits,
608
+ tile_cols=self.groupsize,
609
+ ).to(self.dev)
610
+ q_module.scales = self.scales
611
+ q_module.zeros = self.zeros
612
+ q_module.pack_weight(weight)
613
+ q_module.bias = self.linear_module.bias
614
+ return q_module, error
lit_llama/tokenizer.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
7
+
8
+
9
+ class Tokenizer:
10
+ """Tokenizer for LLaMA."""
11
+
12
+ def __init__(self, model_path: Path) -> None:
13
+ self.processor = SentencePieceProcessor(model_file=str(model_path))
14
+ self.bos_id = self.processor.bos_id()
15
+ self.eos_id = self.processor.eos_id()
16
+ self.pad_id = self.processor.pad_id()
17
+
18
+ @property
19
+ def vocab_size(self) -> int:
20
+ return self.processor.vocab_size()
21
+
22
+ def encode(
23
+ self,
24
+ string: str,
25
+ bos: bool = True,
26
+ eos: bool = False,
27
+ max_length: int = -1,
28
+ pad: bool = False,
29
+ device: Optional[torch.device] = None
30
+ ) -> torch.Tensor:
31
+ tokens = self.processor.encode(string)
32
+ if bos:
33
+ tokens = [self.bos_id] + tokens
34
+ if eos:
35
+ tokens = tokens + [self.eos_id]
36
+ if max_length > 0:
37
+ tokens = tokens[:max_length]
38
+ if pad and len(tokens) < max_length:
39
+ tokens += [self.pad_id] * (max_length - len(tokens))
40
+
41
+ return torch.tensor(tokens, dtype=torch.int, device=device)
42
+
43
+ def decode(self, tokens: torch.Tensor) -> str:
44
+ return self.processor.decode(tokens.tolist())
45
+
46
+ @staticmethod
47
+ def train(input: str, destination: str, vocab_size=32000) -> None:
48
+ model_prefix = os.path.join(destination, "tokenizer")
49
+ SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
lit_llama/utils.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for training and inference."""
2
+
3
+ import functools
4
+ import pickle
5
+ import warnings
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from contextlib import contextmanager
9
+
10
+ import torch
11
+ import torch.utils._device
12
+ from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
13
+ from torch.distributed.fsdp import FullStateDictConfig
14
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
15
+ from torch.distributed.fsdp import StateDictType
16
+ from torch.serialization import normalize_storage_type
17
+
18
+ llama_model_sizes = {
19
+ 4096: "7B", # 7B n_embd=4096
20
+ 5120: "13B", # 13B n_embd=5120
21
+ 6656: "30B", # 30B n_embd=6656
22
+ 8192: "65B", # 65B n_embd=8192
23
+ }
24
+
25
+
26
+ def llama_model_lookup(checkpoint: dict) -> str:
27
+ """Returns the LLaMA model name from the checkpoint.
28
+
29
+ Checks the width of the lm_head.weight matrix, as these uniquely identify the model.
30
+ """
31
+ embedding_size = checkpoint['transformer.wte.weight'].shape[1]
32
+ return llama_model_sizes[embedding_size]
33
+
34
+
35
+ def find_multiple(n: int, k: int) -> int:
36
+ if n % k == 0:
37
+ return n
38
+ return n + k - (n % k)
39
+
40
+
41
+ def save_model_checkpoint(fabric, model, file_path):
42
+ """Handles boilerplate logic for retrieving and saving the state_dict.
43
+
44
+ This will be upstreamed to Fabric soon.
45
+ """
46
+ file_path = Path(file_path)
47
+
48
+ if isinstance(fabric.strategy, DeepSpeedStrategy):
49
+ from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
50
+
51
+ fabric.save(file_path, {"model": model})
52
+ fabric.barrier()
53
+ if fabric.global_rank == 0:
54
+ # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
55
+ convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
56
+ return
57
+
58
+ if isinstance(fabric.strategy, FSDPStrategy):
59
+ save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
60
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
61
+ state_dict = model._forward_module.state_dict()
62
+ else:
63
+ state_dict = model.state_dict()
64
+
65
+ if fabric.global_rank == 0:
66
+ torch.save(state_dict, file_path)
67
+ fabric.barrier()
68
+
69
+
70
+ class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
71
+ def __init__(self, device=None, dtype=None, quantization_mode=None):
72
+ """
73
+ Create tensors with given device and dtype and don't run initialization
74
+ (but instead use "empty tensors", i.e. uninitialized memory).
75
+
76
+ device: `torch.device` to work with
77
+ dtype: `torch.dtype` to work with
78
+ quantization_mode: optional string, quantization mode to work with, default `None`.
79
+ Available modes: `llm.int8` bitsnbytes LLM.int8 quantization (only on GPU)
80
+ `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models
81
+
82
+ Example::
83
+ with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
84
+ model = LLaMA.from_name('7B')
85
+ model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
86
+
87
+ self.quantization_mode = quantization_mode
88
+ self.quantized_linear_cls = None
89
+ if self.quantization_mode == 'llm.int8':
90
+ if device.type != "cuda":
91
+ raise ValueError("Quantization is only supported on the GPU.")
92
+ from .quantization import Linear8bitLt
93
+ self.quantized_linear_cls = Linear8bitLt
94
+ elif self.quantization_mode == 'gptq.int4':
95
+ from .quantization import ColBlockQuantizedLinear
96
+ self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
97
+ elif self.quantization_mode == 'gptq.int8':
98
+ from .quantization import ColBlockQuantizedLinear
99
+ self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
100
+ elif self.quantization_mode is not None:
101
+ raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
102
+ self.device = device
103
+ self.dtype = dtype
104
+
105
+ def __enter__(self):
106
+ if self.quantized_linear_cls != None:
107
+ self.torch_linear_cls = torch.nn.Linear
108
+ torch.nn.Linear = self.quantized_linear_cls
109
+ return super().__enter__()
110
+
111
+ def __exit__(self, exc_type, exc_val, exc_tb):
112
+ if self.quantized_linear_cls != None:
113
+ torch.nn.Linear = self.torch_linear_cls
114
+ return super().__exit__(exc_type, exc_val, exc_tb)
115
+
116
+ def __torch_function__(self, func, types, args=(), kwargs=None):
117
+ kwargs = kwargs or {}
118
+ if getattr(func, "__module__", None) == "torch.nn.init":
119
+ if "tensor" in kwargs:
120
+ return kwargs["tensor"]
121
+ else:
122
+ return args[0]
123
+ if (
124
+ self.device is not None
125
+ and func in torch.utils._device._device_constructors()
126
+ and kwargs.get("device") is None
127
+ ):
128
+ kwargs["device"] = self.device
129
+ if (
130
+ self.dtype is not None
131
+ and func in torch.utils._device._device_constructors()
132
+ and kwargs.get("dtype") is None
133
+ ):
134
+ kwargs["dtype"] = self.dtype
135
+ return func(*args, **kwargs)
136
+
137
+
138
+ @contextmanager
139
+ def quantization(mode: str = None):
140
+ quantized_linear_cls = None
141
+ if mode == 'llm.int8':
142
+ from .quantization import Linear8bitLt
143
+ quantized_linear_cls = Linear8bitLt
144
+ elif mode == 'gptq.int4':
145
+ from .quantization import ColBlockQuantizedLinear
146
+ quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
147
+ elif mode == 'gptq.int8':
148
+ from .quantization import ColBlockQuantizedLinear
149
+ quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
150
+ elif mode is not None:
151
+ raise ValueError(f"Unknown quantization mode: {mode}")
152
+
153
+ enabled = mode is not None
154
+ torch_linear_cls = torch.nn.Linear
155
+ if enabled:
156
+ torch.nn.Linear = quantized_linear_cls
157
+ yield
158
+ if enabled:
159
+ torch.nn.Linear = torch_linear_cls
160
+
161
+
162
+ # this is taken from torchhacks https://github.com/lernapparat/torchhacks
163
+
164
+
165
+ class NotYetLoadedTensor:
166
+ def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
167
+ self.metatensor = metatensor
168
+ self.archiveinfo = archiveinfo
169
+ self.storageinfo = storageinfo
170
+ self.rebuild_args = rebuild_args
171
+
172
+ @classmethod
173
+ def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
174
+ ret = func(*args)
175
+ if isinstance(ret, NotYetLoadedTensor):
176
+ old_lt = ret._load_tensor
177
+
178
+ def _load_tensor():
179
+ t = old_lt()
180
+ return torch._tensor._rebuild_from_type_v2(
181
+ lambda: t, new_type, (), state
182
+ )
183
+
184
+ ret._load_tensor = _load_tensor
185
+ return ret
186
+ return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)
187
+
188
+ @classmethod
189
+ def rebuild_parameter(
190
+ cls, data, requires_grad, backward_hooks, *, archiveinfo=None
191
+ ):
192
+ if isinstance(data, NotYetLoadedTensor):
193
+ old_lt = data._load_tensor
194
+
195
+ def _load_tensor():
196
+ t = old_lt()
197
+ return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)
198
+
199
+ data._load_tensor = _load_tensor
200
+ return data
201
+ return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)
202
+
203
+ @classmethod
204
+ def rebuild_tensor_v2(
205
+ cls,
206
+ storage,
207
+ storage_offset,
208
+ size,
209
+ stride,
210
+ requires_grad,
211
+ backward_hooks,
212
+ metadata=None,
213
+ *,
214
+ archiveinfo=None,
215
+ ):
216
+ rebuild_args = (
217
+ storage_offset,
218
+ size,
219
+ stride,
220
+ requires_grad,
221
+ backward_hooks,
222
+ metadata,
223
+ )
224
+ metatensor = torch._utils._rebuild_tensor_v2(
225
+ storage,
226
+ storage_offset,
227
+ size,
228
+ stride,
229
+ requires_grad,
230
+ backward_hooks,
231
+ metadata,
232
+ )
233
+ storageinfo = storage.archiveinfo
234
+ return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
235
+
236
+ def _load_tensor(self):
237
+ name, storage_cls, fn, device, size = self.storageinfo
238
+ dtype = self.metatensor.dtype
239
+
240
+ uts = (
241
+ self.archiveinfo.zipfile_context.zf.get_storage_from_record(
242
+ f"data/{fn}",
243
+ size * torch._utils._element_size(dtype),
244
+ torch.UntypedStorage,
245
+ )
246
+ ._typed_storage()
247
+ ._untyped_storage
248
+ )
249
+ with warnings.catch_warnings():
250
+ warnings.simplefilter("ignore")
251
+ storage = torch.storage.TypedStorage(
252
+ wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
253
+ )
254
+ tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
255
+ return tensor
256
+
257
+ @classmethod
258
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
259
+ if kwargs is None:
260
+ kwargs = {}
261
+ loaded_args = [
262
+ (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
263
+ ]
264
+ res = func(*loaded_args, **kwargs)
265
+ # gc.collect would be costly here, maybe do it optionally
266
+ return res
267
+
268
+ def __getattr__(self, name):
269
+ # properties
270
+ ## TODO: device, is_...??
271
+ ## TODO: mH, mT, H, T, data, imag, real
272
+ ## name ???
273
+ if name in {
274
+ "dtype",
275
+ "grad",
276
+ "grad_fn",
277
+ "layout",
278
+ "names",
279
+ "ndim",
280
+ "output_nr",
281
+ "requires_grad",
282
+ "retains_grad",
283
+ "shape",
284
+ "volatile",
285
+ }:
286
+ return getattr(self.metatensor, name)
287
+ if name in {"size"}:
288
+ return getattr(self.metatensor, name)
289
+ # materializing with contiguous is needed for quantization
290
+ if name in {"contiguous"}:
291
+ return getattr(self._load_tensor(), name)
292
+
293
+ raise AttributeError(f"{type(self)} does not have {name}")
294
+
295
+ def __repr__(self):
296
+ return f"NotYetLoadedTensor({repr(self.metatensor)})"
297
+
298
+
299
+ class LazyLoadingUnpickler(pickle.Unpickler):
300
+ def __init__(self, file, zipfile_context):
301
+ super().__init__(file)
302
+ self.zipfile_context = zipfile_context
303
+
304
+ def find_class(self, module, name):
305
+ res = super().find_class(module, name)
306
+ if module == "torch._utils" and name == "_rebuild_tensor_v2":
307
+ return functools.partial(
308
+ NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self
309
+ )
310
+ elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
311
+ return functools.partial(
312
+ NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self
313
+ )
314
+ elif module == "torch._utils" and name == "_rebuild_parameter":
315
+ return functools.partial(
316
+ NotYetLoadedTensor.rebuild_parameter, archiveinfo=self
317
+ )
318
+ return res
319
+
320
+ def persistent_load(self, pid):
321
+ name, cls, fn, device, size = pid
322
+ with warnings.catch_warnings():
323
+ warnings.simplefilter("ignore")
324
+ s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
325
+ s.archiveinfo = pid
326
+ return s
327
+
328
+
329
+ class lazy_load:
330
+ def __init__(self, fn):
331
+ self.zf = torch._C.PyTorchFileReader(str(fn))
332
+ with BytesIO(self.zf.get_record("data.pkl")) as pkl:
333
+ mup = LazyLoadingUnpickler(pkl, self)
334
+ self.sd = mup.load()
335
+
336
+ def __enter__(self):
337
+ return self.sd
338
+
339
+ def __exit__(self, exc_type, exc_val, exc_tb):
340
+ del self.zf # I don't think there is a way to force closing...
341
+ self.zf = None
342
+
343
+
344
+ class SavingProxyForStorage:
345
+ def __init__(self, obj, saver, protocol_version=5):
346
+ self.protocol_version = protocol_version
347
+ self.saver = saver
348
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
349
+ raise TypeError(f"expected storage, not {type(obj)}")
350
+
351
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
352
+ if isinstance(obj, torch.storage.TypedStorage):
353
+ # PT upstream wants to deprecate this eventually...
354
+ storage = obj._untyped_storage
355
+ storage_type_str = obj._pickle_storage_type()
356
+ storage_type = getattr(torch, storage_type_str)
357
+ storage_numel = obj._size()
358
+ else:
359
+ storage = obj
360
+ storage_type = normalize_storage_type(type(obj))
361
+ storage_numel = storage.nbytes()
362
+
363
+ storage_key = saver._write_storage_and_return_key(storage)
364
+ location = torch.serialization.location_tag(storage)
365
+
366
+ self.storage_info = (
367
+ "storage",
368
+ storage_type,
369
+ storage_key,
370
+ location,
371
+ storage_numel,
372
+ )
373
+
374
+ def __reduce_ex__(self, protocol_version):
375
+ assert False, "this should be handled with out of band"
376
+
377
+
378
+ class SavingProxyForTensor:
379
+ def __init__(self, tensor, saver, protocol_version=5):
380
+ self.protocol_version = protocol_version
381
+ self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(
382
+ protocol_version
383
+ )
384
+ assert isinstance(
385
+ storage, torch.storage.TypedStorage
386
+ ), "Please check for updates"
387
+ storage_proxy = SavingProxyForStorage(
388
+ storage, saver, protocol_version=protocol_version
389
+ )
390
+ self.reduce_args = (storage_proxy, *other_reduce_args)
391
+
392
+ def __reduce_ex__(self, protocol_version):
393
+ if protocol_version != self.protocol_version:
394
+ raise RuntimeError(
395
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
396
+ )
397
+ return self.reduce_ret_fn, self.reduce_args
398
+
399
+
400
+ class IncrementalPyTorchPickler(pickle.Pickler):
401
+ def __init__(self, saver, *args, **kwargs):
402
+ super().__init__(*args, **kwargs)
403
+ self.storage_dtypes = {}
404
+ self.saver = saver
405
+ self.id_map = {}
406
+
407
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
408
+ def persistent_id(self, obj):
409
+ # FIXME: the docs say that persistent_id should only return a string
410
+ # but torch store returns tuples. This works only in the binary protocol
411
+ # see
412
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
413
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
414
+ if isinstance(obj, SavingProxyForStorage):
415
+ return obj.storage_info
416
+
417
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
418
+ if isinstance(obj, torch.storage.TypedStorage):
419
+ # TODO: Once we decide to break serialization FC, this case
420
+ # can be deleted
421
+ storage = obj._untyped_storage
422
+ storage_dtype = obj.dtype
423
+ storage_type_str = obj._pickle_storage_type()
424
+ storage_type = getattr(torch, storage_type_str)
425
+ storage_numel = obj._size()
426
+
427
+ else:
428
+ storage = obj
429
+ storage_dtype = torch.uint8
430
+ storage_type = normalize_storage_type(type(obj))
431
+ storage_numel = storage.nbytes()
432
+
433
+ # If storage is allocated, ensure that any other saved storages
434
+ # pointing to the same data all have the same dtype. If storage is
435
+ # not allocated, don't perform this check
436
+ if storage.data_ptr() != 0:
437
+ if storage.data_ptr() in self.storage_dtypes:
438
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
439
+ raise RuntimeError(
440
+ "Cannot save multiple tensors or storages that "
441
+ "view the same data as different types"
442
+ )
443
+ else:
444
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
445
+
446
+ storage_key = self.id_map.get(storage._cdata)
447
+ if storage_key is None:
448
+ storage_key = self.saver._write_storage_and_return_key(storage)
449
+ self.id_map[storage._cdata] = storage_key
450
+ location = torch.serialization.location_tag(storage)
451
+
452
+ return ("storage", storage_type, storage_key, location, storage_numel)
453
+
454
+ return None
455
+
456
+
457
+ class incremental_save:
458
+ def __init__(self, name):
459
+ self.name = name
460
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
461
+ self.has_saved = False
462
+ self.next_key = 0
463
+
464
+ def __enter__(self):
465
+ return self
466
+
467
+ def store_early(self, tensor):
468
+ if isinstance(tensor, torch.Tensor):
469
+ return SavingProxyForTensor(tensor, self)
470
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
471
+
472
+ def save(self, obj):
473
+ if self.has_saved:
474
+ raise RuntimeError("have already saved")
475
+ # Write the pickle data for `obj`
476
+ data_buf = BytesIO()
477
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
478
+ pickler.dump(obj)
479
+ data_value = data_buf.getvalue()
480
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
481
+ self.has_saved = True
482
+
483
+ def _write_storage_and_return_key(self, storage):
484
+ if self.has_saved:
485
+ raise RuntimeError("have already saved")
486
+ key = self.next_key
487
+ self.next_key += 1
488
+ name = f"data/{key}"
489
+ if storage.device.type != "cpu":
490
+ storage = storage.cpu()
491
+ num_bytes = storage.nbytes()
492
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
493
+ return key
494
+
495
+ def __exit__(self, type, value, traceback):
496
+ self.zipfile.write_end_of_file()
pretrain/redpajama.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import glob
5
+ import time
6
+ from functools import partial
7
+ from pathlib import Path
8
+ from typing import Tuple, Optional
9
+
10
+ import lightning as L
11
+ from lightning.fabric.strategies import FSDPStrategy
12
+
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
16
+
17
+ import numpy as np
18
+
19
+ # support running without installing as a package
20
+ wd = Path(__file__).parent.parent.resolve()
21
+ sys.path.append(str(wd))
22
+
23
+ from lit_llama.model import Block, LLaMA, LLaMAConfig
24
+ from lit_llama.packed_dataset import PackedDataset, CombinedDataset
25
+ from lit_llama.utils import save_model_checkpoint
26
+
27
+
28
+ out_dir = "out/training"
29
+ save_interval = 1000
30
+ eval_interval = 1000
31
+ eval_iters = 100
32
+ log_interval = 1
33
+
34
+ # compile = False
35
+
36
+ # Hyperparameters
37
+ learning_rate = 6e-4
38
+ batch_size = 125
39
+ micro_batch_size = 5
40
+ max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices
41
+ weight_decay = 1e-1
42
+ beta1 = 0.9
43
+ beta2 = 0.95
44
+ grad_clip = 1.0
45
+ decay_lr = True
46
+ warmup_iters = 2000
47
+ lr_decay_iters = max_iters
48
+ min_lr = 6e-5
49
+
50
+
51
+ # Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1
52
+ data_config = [
53
+ ("arxiv", 2.5),
54
+ ("book", 4.5),
55
+ ("c4", 15.0),
56
+ ("cc", 67.0),
57
+ ("github", 4.5),
58
+ ("stackexchange", 2.0),
59
+ ("wikipedia", 4.5),
60
+ ]
61
+
62
+
63
+ def main(
64
+ devices: int = 4,
65
+ train_data_dir: Path = "data/lit-redpajama",
66
+ val_data_dir: Optional[Path] = None,
67
+ ) -> None:
68
+ auto_wrap_policy = partial(
69
+ transformer_auto_wrap_policy, transformer_layer_cls={Block}
70
+ )
71
+ strategy = FSDPStrategy(
72
+ auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True
73
+ )
74
+
75
+ fabric = L.Fabric(
76
+ accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy
77
+ )
78
+ fabric.launch()
79
+ fabric.seed_everything(1337)
80
+
81
+ if fabric.global_rank == 0:
82
+ os.makedirs(out_dir, exist_ok=True)
83
+
84
+ config = LLaMAConfig.from_name("7B")
85
+
86
+ train_dataloader, val_dataloader = create_dataloaders(
87
+ batch_size=micro_batch_size,
88
+ block_size=config.block_size,
89
+ fabric=fabric,
90
+ train_data_dir=train_data_dir,
91
+ val_data_dir=val_data_dir,
92
+ seed=1338,
93
+ )
94
+ if val_dataloader is None:
95
+ train_dataloader = fabric.setup_dataloaders(train_dataloader)
96
+ else:
97
+ train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
98
+
99
+ with fabric.device:
100
+ torch.set_default_dtype(torch.bfloat16)
101
+ model = LLaMA(config)
102
+ model.apply(model._init_weights)
103
+ torch.set_default_dtype(torch.float32)
104
+
105
+ # if compile:
106
+ # model = torch.compile(model)
107
+
108
+ optimizer = torch.optim.AdamW(
109
+ model.parameters(),
110
+ lr=learning_rate,
111
+ weight_decay=weight_decay,
112
+ betas=(beta1, beta2),
113
+ foreach=False,
114
+ )
115
+
116
+ model, optimizer = fabric.setup(model, optimizer)
117
+
118
+ process_batch_size = batch_size // devices
119
+ gradient_accumulation_iters = process_batch_size // micro_batch_size
120
+
121
+ train(fabric, model, optimizer, train_dataloader, val_dataloader, gradient_accumulation_iters, devices)
122
+
123
+
124
+ def train(
125
+ fabric: L.Fabric,
126
+ model: torch.nn.Module,
127
+ optimizer: torch.optim.Optimizer,
128
+ train_dataloader: DataLoader,
129
+ val_dataloader: Optional[DataLoader],
130
+ grad_accum_steps: int,
131
+ devices: int,
132
+ ) -> None:
133
+ """The training loop.
134
+
135
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
136
+ """
137
+
138
+ step_count = 0
139
+
140
+ step_time = 0.0
141
+ tokens = 0
142
+ tokens_sec = 0.0
143
+ prev_t1 = time.time()
144
+
145
+ for iter_num, train_data in enumerate(train_dataloader):
146
+ t0 = time.time()
147
+
148
+ # determine and set the learning rate for this iteration
149
+ lr = get_lr(iter_num) if decay_lr else learning_rate
150
+ for param_group in optimizer.param_groups:
151
+ param_group["lr"] = lr
152
+
153
+
154
+ input_ids = train_data[:, 0 : model.config.block_size].contiguous()
155
+ targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
156
+
157
+ is_accumulating = (iter_num + 1) % grad_accum_steps != 0
158
+
159
+ with fabric.no_backward_sync(model, enabled=is_accumulating):
160
+ logits = model(input_ids)
161
+ loss = torch.nn.functional.cross_entropy(
162
+ logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
163
+ )
164
+ fabric.backward(loss / grad_accum_steps)
165
+
166
+ t1 = time.time()
167
+
168
+ if not is_accumulating:
169
+ fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
170
+
171
+ optimizer.step()
172
+ optimizer.zero_grad()
173
+ step_count += 1
174
+
175
+ t1 = time.time()
176
+
177
+ if val_dataloader is not None and step_count % eval_interval == 0:
178
+ val_loss = validate(fabric, model, val_dataloader)
179
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
180
+ fabric.barrier()
181
+ fabric.log_dict(
182
+ {"iter": iter_num, "val_loss": val_loss, "step": step_count, "lr": lr}
183
+ )
184
+
185
+ if step_count % save_interval == 0:
186
+ fabric.print(f"Saving checkpoint to {out_dir}")
187
+ save_model_checkpoint(
188
+ fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")
189
+ )
190
+
191
+ dt = t1 - t0
192
+
193
+ tokens += micro_batch_size * model.config.block_size
194
+ step_time += t1 - prev_t1
195
+ prev_t1 = t1
196
+
197
+ if iter_num % log_interval == 0:
198
+ tokens_sec_str = f"{tokens / step_time:.0f}" if not is_accumulating else "-"
199
+
200
+ fabric.log_dict(
201
+ {"iter": iter_num, "train_loss": loss, "step": step_count, "lr": lr}
202
+ )
203
+ fabric.print(
204
+ f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms, speed: {tokens_sec_str} toks/s/device"
205
+ )
206
+
207
+ if not is_accumulating:
208
+ tokens = 0
209
+ step_time = 0.0
210
+
211
+ if iter_num > max_iters:
212
+ break
213
+
214
+
215
+ @torch.no_grad()
216
+ def validate(
217
+ fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader
218
+ ) -> torch.Tensor:
219
+ fabric.print("Validating ...")
220
+ model.eval()
221
+ losses = torch.zeros(eval_iters)
222
+ for k, val_data in enumerate(val_dataloader):
223
+ input_ids = val_data[:, 0 : model.config.block_size].contiguous()
224
+ targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
225
+ logits = model(input_ids)
226
+ loss = torch.nn.functional.cross_entropy(
227
+ logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
228
+ )
229
+ losses[k] = loss.item()
230
+ out = losses.mean()
231
+ model.train()
232
+ return out
233
+
234
+
235
+ def create_dataloader(
236
+ batch_size: int,
237
+ block_size: int,
238
+ data_dir: str,
239
+ fabric,
240
+ shuffle: bool = True,
241
+ seed: int = 12345,
242
+ ) -> DataLoader:
243
+ datasets = []
244
+ for prefix, _ in data_config:
245
+ filenames = glob.glob(os.path.join(data_dir, prefix + "*"))
246
+ dataset = PackedDataset(
247
+ filenames, n_chunks=4, block_size=block_size, shuffle=shuffle, seed=seed,
248
+ num_processes=fabric.world_size, process_rank=fabric.global_rank,
249
+ )
250
+ datasets.append(dataset)
251
+
252
+ if not datasets:
253
+ raise RuntimeError(
254
+ f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
255
+ )
256
+
257
+ weights = [weight for _, weight in data_config]
258
+ sum_weights = sum(weights)
259
+ weights = [el / sum_weights for el in weights]
260
+
261
+ combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)
262
+
263
+ return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
264
+
265
+
266
+ def create_dataloaders(
267
+ batch_size: int,
268
+ block_size: int,
269
+ fabric,
270
+ train_data_dir: str = "data/lit-redpajama",
271
+ val_data_dir: Optional[str] = None,
272
+ seed: int = 12345,
273
+ ) -> Tuple[DataLoader, DataLoader]:
274
+ # Increase by one because we need the next word as well
275
+ effective_block_size = block_size + 1
276
+ train_dataloader = create_dataloader(
277
+ batch_size=batch_size,
278
+ block_size=effective_block_size,
279
+ fabric=fabric,
280
+ data_dir=train_data_dir,
281
+ shuffle=True,
282
+ seed=seed,
283
+ )
284
+ val_dataloader = (
285
+ create_dataloader(
286
+ batch_size=batch_size,
287
+ block_size=effective_block_size,
288
+ fabric=fabric,
289
+ data_dir=val_data_dir,
290
+ shuffle=False,
291
+ seed=seed,
292
+ )
293
+ if val_data_dir
294
+ else None
295
+ )
296
+ return train_dataloader, val_dataloader
297
+
298
+
299
+ # learning rate decay scheduler (cosine with warmup)
300
+ def get_lr(it):
301
+ # 1) linear warmup for warmup_iters steps
302
+ if it < warmup_iters:
303
+ return learning_rate * it / warmup_iters
304
+ # 2) if it > lr_decay_iters, return min learning rate
305
+ if it > lr_decay_iters:
306
+ return min_lr
307
+ # 3) in between, use cosine decay down to min learning rate
308
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
309
+ assert 0 <= decay_ratio <= 1
310
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
311
+ return min_lr + coeff * (learning_rate - min_lr)
312
+
313
+
314
+ if __name__ == "__main__":
315
+ # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
316
+ # torch.backends.cuda.enable_flash_sdp(False)
317
+ torch.set_float32_matmul_precision("high")
318
+
319
+ from jsonargparse.cli import CLI
320
+
321
+ CLI(main)
pretrain/shakespeare.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script is a placeholder for training LLaMA from scratch.
3
+ Currently, it just trains on the Shakespeare dataset.
4
+ """
5
+ from pathlib import Path
6
+ import sys
7
+ import os
8
+ import time
9
+ from functools import partial
10
+ from typing import Tuple
11
+
12
+ import lightning as L
13
+ from lightning.fabric.strategies import FSDPStrategy
14
+
15
+ import torch
16
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
17
+
18
+ import numpy as np
19
+
20
+ # support running without installing as a package
21
+ wd = Path(__file__).parent.parent.resolve()
22
+ sys.path.append(str(wd))
23
+
24
+ from lit_llama.model import Block, LLaMA, LLaMAConfig
25
+ from lit_llama.utils import save_model_checkpoint
26
+
27
+
28
+ out_dir = "out/training"
29
+ eval_interval = 2000
30
+ eval_iters = 200
31
+ log_interval = 1
32
+ # compilation fails as it does not support torch.complex64 for RoPE
33
+ # compile = False
34
+
35
+ # Hyperparameters
36
+ learning_rate = 6e-4
37
+ batch_size = 2
38
+ max_iters = 600000
39
+ weight_decay = 1e-1
40
+ beta1 = 0.9
41
+ beta2 = 0.95
42
+ grad_clip = 1.0
43
+
44
+ # For shakespeare, choose smaller block size than vanilla LLaMA
45
+ block_size = 1024
46
+
47
+
48
+ def main() -> None:
49
+ auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
50
+ strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
51
+
52
+ fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
53
+ fabric.launch()
54
+ fabric.seed_everything(1337 + fabric.global_rank)
55
+
56
+ if fabric.global_rank == 0:
57
+ os.makedirs(out_dir, exist_ok=True)
58
+
59
+ train_data, val_data = load_datasets()
60
+
61
+ config = LLaMAConfig.from_name("7B")
62
+ config.block_size = block_size
63
+ config.vocab_size = 100 # from prepare_shakespeare.py
64
+
65
+ with fabric.device:
66
+ model = LLaMA(config)
67
+
68
+ # if compile:
69
+ # model = torch.compile(model)
70
+
71
+ model = fabric.setup_module(model)
72
+
73
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
74
+ optimizer = fabric.setup_optimizers(optimizer)
75
+
76
+ train(fabric, model, optimizer, train_data, val_data)
77
+
78
+
79
+ def train(
80
+ fabric: L.Fabric,
81
+ model: torch.nn.Module,
82
+ optimizer: torch.optim.Optimizer,
83
+ train_data: np.ndarray,
84
+ val_data: np.ndarray,
85
+ ) -> None:
86
+ """The training loop.
87
+
88
+ Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
89
+ """
90
+
91
+ iter_num = 0
92
+
93
+ while True:
94
+ # TODO: add learning rate scheduling
95
+
96
+ # evaluate the loss on train/val sets and write checkpoints
97
+ if iter_num > 0 and iter_num % eval_interval == 0:
98
+ val_loss = validate(fabric, model, val_data)
99
+ fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
100
+ fabric.print(f"Saving checkpoint to {out_dir}")
101
+ save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
102
+
103
+ t0 = time.time()
104
+
105
+ input_ids, targets = get_batch(
106
+ fabric,
107
+ train_data,
108
+ block_size=model.config.block_size, # type: ignore[union-attr,arg-type]
109
+ )
110
+ logits = model(input_ids)
111
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
112
+
113
+ fabric.backward(loss)
114
+
115
+ # TODO: Gradient clipping
116
+ # if grad_clip != 0.0:
117
+ # fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
118
+
119
+ optimizer.step()
120
+ optimizer.zero_grad()
121
+
122
+ dt = time.time() - t0
123
+ if iter_num % log_interval == 0:
124
+ fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
125
+ iter_num += 1
126
+
127
+ if iter_num > max_iters:
128
+ break
129
+
130
+
131
+ @torch.no_grad()
132
+ def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
133
+ fabric.print("Validating ...")
134
+ model.eval()
135
+ losses = torch.zeros(eval_iters)
136
+ for k in range(eval_iters):
137
+ input_ids, targets = get_batch(
138
+ fabric,
139
+ val_data,
140
+ block_size=model.config.block_size, # type: ignore[union-attr,arg-type]
141
+ )
142
+ logits = model(input_ids)
143
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
144
+ losses[k] = loss.item()
145
+ out = losses.mean()
146
+ model.train()
147
+ return out
148
+
149
+
150
+ def get_batch(fabric: L.Fabric, data: np.ndarray, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
151
+ ix = torch.randint(len(data) - block_size, (batch_size,))
152
+ x = torch.stack([torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix])
153
+ y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64)) for i in ix])
154
+ x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
155
+ return x, y
156
+
157
+
158
+ def load_datasets(data_dir: str = "data/shakespeare") -> Tuple[np.ndarray, np.ndarray]:
159
+ train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
160
+ val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
161
+ return train_data, val_data
162
+
163
+
164
+ if __name__ == "__main__":
165
+ torch.set_float32_matmul_precision("high")
166
+ main()
quantize/gptq.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This adapts GPTQ's quantization process: https://github.com/IST-DASLab/gptq/
2
+ # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
3
+ # portions copyright by the authors licensed under the Apache License 2.0
4
+ import gc
5
+ import sys
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from datasets import load_dataset
12
+
13
+ # support running without installing as a package
14
+ wd = Path(__file__).parent.parent.resolve()
15
+ sys.path.append(str(wd))
16
+
17
+ from lit_llama import LLaMA, Tokenizer
18
+ from lit_llama.quantization import GPTQQuantizer
19
+ from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup
20
+
21
+
22
+ def get_sample_data():
23
+ traindata = load_dataset(
24
+ "allenai/c4",
25
+ "allenai--c4",
26
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
27
+ split="train",
28
+ )
29
+ # heuristic for the data size?
30
+ txt = "\n".join(
31
+ traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
32
+ )
33
+ return txt
34
+
35
+
36
+ @torch.no_grad()
37
+ def llama_blockwise_quantization(
38
+ model, sample_inputs, working_device, *, bits=4, groupsize=-1
39
+ ):
40
+ """
41
+ This is the classic post-training quantization of all linear layers.
42
+ We quantize in order, i.e. when observing the inputs, we use the outputs of the previously quantized layers rather
43
+ than doing them all at once.
44
+ """
45
+ print(model)
46
+ print(model.config)
47
+
48
+ print("Getting inputs for first block")
49
+ model.transformer.wte.to(working_device)
50
+ sample_inputs = sample_inputs.to(working_device)
51
+ inps = model.transformer.wte(sample_inputs)
52
+ model.transformer.wte.to("cpu")
53
+ torch.cuda.empty_cache()
54
+
55
+ rope_cache = model.build_rope_cache(sample_inputs)
56
+ mask_cache = model.build_mask_cache(sample_inputs)
57
+
58
+ print("Starting to quantize blocks")
59
+ outs = torch.zeros_like(inps)
60
+
61
+ # better than relying on enumeration? originally the code bundled
62
+ # the two mlp fc layers
63
+ # we could automate this with a lot of hooks and another iteration
64
+ submodules_to_process = [
65
+ "attn.c_attn",
66
+ "attn.c_proj",
67
+ "mlp.c_fc1",
68
+ "mlp.c_fc2",
69
+ "mlp.c_proj",
70
+ ]
71
+
72
+ for i, block in enumerate(model.transformer.h):
73
+ block.to(working_device)
74
+
75
+ for name in submodules_to_process:
76
+ print(i, name, end=" ")
77
+ t0 = time.perf_counter()
78
+ print("collecting stats", end=" ")
79
+ sys.stdout.flush()
80
+ module = block.get_submodule(name)
81
+
82
+ gptq = GPTQQuantizer(
83
+ module,
84
+ bits=bits,
85
+ groupsize=groupsize,
86
+ actorder=(groupsize == -1),
87
+ )
88
+ handle = module.register_forward_hook(gptq.collect_input_stats)
89
+ for j in range(inps.size(0)):
90
+ outs[j : j + 1], _ = block(
91
+ inps[j : j + 1],
92
+ rope=rope_cache,
93
+ mask=mask_cache,
94
+ max_seq_length=model.config.block_size
95
+ )
96
+
97
+ handle.remove()
98
+
99
+ print("quantizing", end=" ")
100
+ sys.stdout.flush()
101
+ q_module, error = gptq.quantize()
102
+
103
+ # replace the linear module with the quantized module
104
+ pname, dname = name.rsplit(".", 1)
105
+ setattr(block.get_submodule(pname), dname, q_module)
106
+
107
+ # cleanup in an attempt to not run out of memory
108
+ del gptq
109
+ gc.collect()
110
+ torch.cuda.empty_cache()
111
+ t1 = time.perf_counter()
112
+ print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}")
113
+
114
+ for j in range(inps.size(0)):
115
+ outs[j : j + 1], _ = block(
116
+ inps[j : j + 1],
117
+ rope=rope_cache,
118
+ mask=mask_cache,
119
+ max_seq_length=model.config.block_size
120
+ )
121
+
122
+ block.cpu()
123
+ gc.collect()
124
+ torch.cuda.empty_cache()
125
+
126
+ # the outputs are the next block's inputs and we'll reuse the old inputs
127
+ inps, outs = outs, inps
128
+
129
+ model.transformer.ln_f.to(working_device)
130
+ for j in range(inps.size(0)):
131
+ outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
132
+ model.transformer.ln_f.to("cpu")
133
+ inps, outs = outs, inps
134
+
135
+ model.lm_head.to(working_device)
136
+ gptq = GPTQQuantizer(
137
+ model.lm_head,
138
+ bits=bits,
139
+ groupsize=groupsize,
140
+ actorder=(groupsize == -1),
141
+ )
142
+ handle = model.lm_head.register_forward_hook(gptq.collect_input_stats)
143
+ for j in range(inps.size(0)):
144
+ model.lm_head(inps[j : j + 1])
145
+ handle.remove()
146
+ q_module, error = gptq.quantize()
147
+ model.lm_head = q_module
148
+ model.lm_head.to("cpu")
149
+
150
+
151
+ def main(
152
+ *,
153
+ checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
154
+ output_path: Optional[Path] = None,
155
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
156
+ n_samples: int = 128,
157
+ dtype: str = "float32",
158
+ quantize: Optional[str] = None,
159
+ ) -> None:
160
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer.
161
+
162
+ Args:
163
+ checkpoint_path: The checkpoint path to load.
164
+ output_path: Path to write the quantized model's state dict to.
165
+ tokenizer_path: The tokenizer path to load.
166
+ n_samples: Number of example inputs to use for statistics (default: 128)
167
+ dtype: The dtype to use to load the model.
168
+ quantize: Mode to quantize the model to:
169
+ ``"gptq.int4"``: GPTQ 4-bit mode.
170
+ Note that ``"llm.int8"```does not need a quantization step.
171
+ """
172
+ assert checkpoint_path.is_file()
173
+ assert tokenizer_path.is_file()
174
+ if output_path is None:
175
+ output_path = checkpoint_path.parent / "llama-gptq.4bit.pth"
176
+ assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file())
177
+
178
+ device = "cuda"
179
+
180
+ dt = getattr(torch, dtype, None)
181
+ if not isinstance(dt, torch.dtype):
182
+ raise ValueError(f"{dtype} is not a valid dtype.")
183
+ dtype = dt
184
+
185
+ if quantize == "gptq.int4":
186
+ bits = 4
187
+ elif quantize == "gptq.int8":
188
+ bits = 8
189
+ else:
190
+ raise RuntimeError(f"unknown/unsupported quantization mode {quantize}")
191
+
192
+ # we avoid loading the entire model on the GPU and do this block by block
193
+ with EmptyInitOnDevice(
194
+ device="cpu",
195
+ dtype=dtype,
196
+ ):
197
+ print("Loading model ...", file=sys.stderr)
198
+ t0 = time.time()
199
+ checkpoint = torch.load(checkpoint_path)
200
+ name = llama_model_lookup(checkpoint)
201
+ model = LLaMA.from_name(name)
202
+ model.load_state_dict(checkpoint)
203
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
204
+
205
+ model.eval()
206
+
207
+ tokenizer = Tokenizer(tokenizer_path)
208
+
209
+ test_string = get_sample_data()
210
+ encoded_text = tokenizer.encode(
211
+ test_string,
212
+ bos=True,
213
+ eos=False,
214
+ )
215
+ block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
216
+ encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)
217
+
218
+ t0 = time.perf_counter()
219
+ llama_blockwise_quantization(model, encoded_text, device, bits=bits)
220
+ t = time.perf_counter() - t0
221
+
222
+ print(
223
+ f"\n\nTime for quantization: {t:.02f} sec total",
224
+ file=sys.stderr,
225
+ )
226
+ print(
227
+ f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
228
+ file=sys.stderr,
229
+ )
230
+
231
+ torch.save(model.state_dict(), output_path)
232
+
233
+
234
+ if __name__ == "__main__":
235
+ from jsonargparse import CLI
236
+
237
+ torch.set_float32_matmul_precision("high")
238
+ CLI(main)
scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import shutil
3
+ from pathlib import Path
4
+ from typing import Dict
5
+
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ """
10
+ Sample usage:
11
+
12
+ ```bash
13
+ python -m scripts.convert_checkpoint -h
14
+
15
+ python -m scripts.convert_checkpoint converted
16
+ ```
17
+ """
18
+
19
+
20
+ def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
21
+ converted = {}
22
+ converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
23
+ converted["lm_head.weight"] = state_dict["output.weight"].to(dtype)
24
+ converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype)
25
+
26
+ for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
27
+ # attention
28
+ # the wq, wk, wv from the FB model are stacked in our model as c_attn
29
+ converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
30
+ (
31
+ state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
32
+ state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
33
+ state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
34
+ )
35
+ )
36
+ converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[
37
+ f"layers.{layer_idx}.attention.wo.weight"
38
+ ].to(dtype)
39
+ # mlp
40
+ converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[
41
+ f"layers.{layer_idx}.feed_forward.w1.weight"
42
+ ].to(dtype)
43
+ converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[
44
+ f"layers.{layer_idx}.feed_forward.w2.weight"
45
+ ].to(dtype)
46
+ converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[
47
+ f"layers.{layer_idx}.feed_forward.w3.weight"
48
+ ].to(dtype)
49
+ # rms norm
50
+ converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype)
51
+ converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype)
52
+ return converted
53
+
54
+
55
+ shard_dims = {
56
+ "lm_head.weight": 0,
57
+ "wte.weight": 1,
58
+ "attn.c_attn.weight": 0,
59
+ "attn.c_proj.weight": 1,
60
+ "mlp.c_fc1.weight": 0,
61
+ "mlp.c_fc2.weight": 0,
62
+ "mlp.c_proj.weight": 1
63
+ }
64
+
65
+
66
+ def meta_weights_for_nano_model(
67
+ *,
68
+ output_dir: Path = Path("checkpoints/lit-llama"),
69
+ checkpoint_dir: Path = Path("checkpoints/llama/"),
70
+ model_size: str = "7B",
71
+ dtype: str = "float32",
72
+ ) -> None:
73
+ output_dir = output_dir / model_size
74
+ checkpoint_dir = checkpoint_dir / model_size
75
+ output_dir.mkdir(parents=True, exist_ok=True)
76
+
77
+ # the tokenizer is the same for all model sizes, so we store it in the parent dir
78
+ shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent)
79
+
80
+ dt = getattr(torch, dtype, None)
81
+ if not isinstance(dt, torch.dtype):
82
+ raise ValueError(f"{dtype} is not a valid dtype.")
83
+ dtype = dt
84
+
85
+ checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
86
+ checkpoint_files.sort()
87
+ n_checkpoints = len(checkpoint_files)
88
+
89
+ if n_checkpoints == 0:
90
+ raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. `consolidated.0*.pth` files expected at that location.")
91
+
92
+ # for the bigger models, there are multiple model-parallel checkpoints
93
+ # and we combine them into one single file
94
+ combined = None
95
+ for file in tqdm(checkpoint_files, total=n_checkpoints):
96
+ checkpoint = torch.load(file, map_location="cpu")
97
+ converted = convert_state_dict(checkpoint, dtype=dtype)
98
+ if combined is None:
99
+ combined = converted
100
+ continue
101
+ for name, param in converted.items():
102
+ dim = None
103
+ for k, d in shard_dims.items():
104
+ if k in name:
105
+ dim = d
106
+ break
107
+ if dim is None:
108
+ # Extra check: assert that tensors are the same if not sharded
109
+ # assert torch.allclose(combined[name], param)
110
+ continue
111
+ combined[name] = torch.cat((combined[name], param), dim=dim)
112
+
113
+ del checkpoint
114
+ del converted
115
+ gc.collect()
116
+
117
+ for name, param in combined.items():
118
+ if "c_attn" not in name:
119
+ continue
120
+
121
+ # Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...]
122
+
123
+ src_chunk_len = param.shape[0] // n_checkpoints
124
+ mat_len = src_chunk_len // 3
125
+ dst_chunk_len = mat_len * n_checkpoints
126
+ attn = torch.clone(param)
127
+ for i in range(n_checkpoints):
128
+ for j in range(3):
129
+ param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i+1) * mat_len] = \
130
+ attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j+1) * mat_len]
131
+
132
+ del attn
133
+ gc.collect()
134
+
135
+ torch.save(combined, output_dir / "lit-llama.pth")
136
+
137
+
138
+ if __name__ == "__main__":
139
+ from jsonargparse import CLI
140
+
141
+ CLI(meta_weights_for_nano_model)
scripts/convert_hf_checkpoint.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import contextlib
3
+ import gc
4
+ import json
5
+ import shutil
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ # support running without installing as a package
12
+ wd = Path(__file__).parent.parent.resolve()
13
+ sys.path.append(str(wd))
14
+
15
+ from lit_llama.model import LLaMA, LLaMAConfig
16
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, incremental_save
17
+
18
+
19
+ @torch.no_grad()
20
+ def convert_hf_checkpoint(
21
+ *,
22
+ output_dir: Path = Path("checkpoints/lit-llama/7B"),
23
+ checkpoint_dir: Path = Path("checkpoints/hf-llama/7B"),
24
+ model_size: str = "7B",
25
+ dtype: str = "float32",
26
+ verify: bool = False,
27
+ ) -> None:
28
+ """
29
+ Perform the reverse operation of: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
30
+ """
31
+ output_dir.mkdir(parents=True, exist_ok=True)
32
+
33
+ # the tokenizer is the same for all model sizes, so we store it in the parent dir
34
+ shutil.copy(checkpoint_dir / "tokenizer.model", output_dir.parent)
35
+
36
+ dt = getattr(torch, dtype, None)
37
+ if not isinstance(dt, torch.dtype):
38
+ raise ValueError(f"{dtype} is not a valid dtype.")
39
+ dtype = dt
40
+
41
+ print("Initializing lit-llama")
42
+ config = LLaMAConfig.from_name(model_size)
43
+
44
+ with EmptyInitOnDevice(device="meta", dtype=dtype):
45
+ model = LLaMA(config)
46
+
47
+ qkv_size = model.transformer.h[0].attn.c_attn.weight.shape[0] // 3
48
+
49
+ # initialize a new empty state dict to hold our new weights
50
+ sd_meta = model.state_dict()
51
+ sd = {}
52
+
53
+ # Load the json file containing weight mapping
54
+ pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
55
+ with open(pytorch_bin_map_json_path) as json_map:
56
+ bin_index = json.load(json_map)
57
+ bin_files = set(checkpoint_dir / bin for bin in bin_index["weight_map"].values())
58
+ if not bin_files:
59
+ raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")
60
+
61
+ def permute(w):
62
+ dim = config.n_embd
63
+ w = w._load_tensor().to(dtype)
64
+ return (
65
+ w.view(config.n_head, 2, dim // config.n_head // 2, dim)
66
+ .transpose(1, 2)
67
+ .reshape(dim, dim)
68
+ )
69
+
70
+ weight_map = {
71
+ "self_attn.o_proj.weight": "attn.c_proj.weight",
72
+ "self_attn.q_proj.weight": "attn.c_attn.weight",
73
+ "self_attn.k_proj.weight": "attn.c_attn.weight",
74
+ "self_attn.v_proj.weight": "attn.c_attn.weight",
75
+ "mlp.gate_proj.weight": "mlp.c_fc1.weight",
76
+ "mlp.up_proj.weight": "mlp.c_fc2.weight",
77
+ "mlp.down_proj.weight": "mlp.c_proj.weight",
78
+ "input_layernorm.weight": "rms_1.scale",
79
+ "post_attention_layernorm.weight": "rms_2.scale",
80
+ "model.embed_tokens.weight": "transformer.wte.weight",
81
+ "model.norm.weight": "transformer.ln_f.scale",
82
+ "lm_head.weight": "lm_head.weight",
83
+ }
84
+
85
+ print(f"Saving to disk at {output_dir}")
86
+ unprocessed_weights = collections.defaultdict(dict)
87
+
88
+ with incremental_save(output_dir / "lit-llama.pth") as saver:
89
+ # for checkpoints that split the QKV across several files, we need to keep all the bin files
90
+ # open, so we use `ExitStack` to close them all together at the end
91
+ with contextlib.ExitStack() as stack:
92
+ for bin_file in bin_files:
93
+ print("Processing", bin_file)
94
+ hf_weights = stack.enter_context(lazy_load(bin_file))
95
+ for name, param in hf_weights.items():
96
+ skip = False
97
+ if "rotary_emb.inv_freq" in name:
98
+ continue
99
+ if "model.layers" in name:
100
+ block_id = int(name.split(".")[2])
101
+ from_name = ".".join(name.split(".")[3:])
102
+ to_name = weight_map[from_name]
103
+ sd_key = f"transformer.h.{block_id}.{to_name}"
104
+
105
+ if "q_proj" in name:
106
+ unprocessed_weights[sd_key]["q_proj"] = param
107
+ skip = True
108
+ elif "k_proj" in name:
109
+ unprocessed_weights[sd_key]["k_proj"] = param
110
+ skip = True
111
+ elif "v_proj" in name:
112
+ unprocessed_weights[sd_key]["v_proj"] = param
113
+ skip = True
114
+ if skip and len(unprocessed_weights[sd_key]) == 3:
115
+ w = torch.empty(
116
+ sd_meta[sd_key].shape, dtype=sd_meta[sd_key].dtype
117
+ )
118
+ w[:qkv_size] = permute(unprocessed_weights[sd_key]["q_proj"])
119
+ w[qkv_size:-qkv_size] = permute(
120
+ unprocessed_weights[sd_key]["k_proj"]
121
+ )
122
+ w[-qkv_size:] = (
123
+ unprocessed_weights[sd_key]["v_proj"]
124
+ ._load_tensor()
125
+ .to(dtype)
126
+ )
127
+ sd[sd_key] = w
128
+ del unprocessed_weights[sd_key]
129
+ skip = False
130
+ else:
131
+ sd[sd_key] = param._load_tensor().to(dtype)
132
+ else:
133
+ sd_key = weight_map[name]
134
+ sd[sd_key] = param._load_tensor().to(dtype)
135
+ if not skip:
136
+ sd[sd_key] = saver.store_early(sd[sd_key])
137
+ gc.collect()
138
+ saver.save(sd)
139
+
140
+ assert len(unprocessed_weights) == 0, f"unexpected partial weights {list(unprocessed_weights)}"
141
+ if verify:
142
+ try:
143
+ from transformers import LlamaForCausalLM
144
+ except ImportError:
145
+ raise ImportError("verify=True requires transformers to be installed, please `pip install transformers`")
146
+ print("Verifying...")
147
+
148
+ token_sample = torch.randint(0, config.vocab_size, size=(1, config.block_size), dtype=torch.int64)
149
+ out = model(token_sample)
150
+ del model
151
+ gc.collect()
152
+
153
+ print("Loading original model for comparison")
154
+ model_hf = LlamaForCausalLM.from_pretrained(checkpoint_dir)
155
+ out_hf = model_hf(token_sample)["logits"]
156
+
157
+ print("Comparing outputs")
158
+ assert out.device.type == out_hf.device.type
159
+ assert out.dtype == out_hf.dtype
160
+ assert torch.testing.assert_close(out, out_hf)
161
+
162
+
163
+ if __name__ == "__main__":
164
+ from jsonargparse import CLI
165
+
166
+ CLI(convert_hf_checkpoint)
167
+
scripts/convert_lora_weights.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ # support running without installing as a package
11
+ wd = Path(__file__).parent.parent.resolve()
12
+ sys.path.append(str(wd))
13
+
14
+ from lit_llama import LLaMA
15
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
16
+ from lit_llama.lora import lora
17
+
18
+ def del_lora_state_dict(model: nn.Module):
19
+ base_model_dict = model.state_dict()
20
+ key_to_delete = [k for k in base_model_dict if "lora_" in k]
21
+ for del_key in key_to_delete:
22
+ del base_model_dict[del_key]
23
+ return base_model_dict
24
+
25
+
26
+ def lora_model_lookup(checkpoint: dict) -> int:
27
+ """Returns the LoRA rank from the adapter checkpoint.
28
+
29
+ """
30
+ return checkpoint["transformer.h.0.attn.c_attn.lora_B"].shape[1]
31
+
32
+
33
+ def main(
34
+ accelerator: str = "auto",
35
+ lora_path: Optional[Path] = None,
36
+ checkpoint_path: Optional[Path] = None,
37
+ dtype: str = "bfloat16",
38
+ ) -> None:
39
+ """Merges lora weights to base model.
40
+
41
+ Args:
42
+ accelerator: The hardware to run on. Possible choices are:
43
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
44
+ lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
45
+ `finetune_lora.py`.
46
+ checkpoint_path: The checkpoint path to load.
47
+ dtype: `torch.dtype` to work with
48
+ """
49
+ if not lora_path:
50
+ lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth")
51
+ if not checkpoint_path:
52
+ checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
53
+
54
+ assert lora_path.is_file()
55
+ assert checkpoint_path.is_file()
56
+
57
+ fabric = L.Fabric(accelerator=accelerator, devices=1)
58
+
59
+ dt = getattr(torch, dtype, None)
60
+ if not isinstance(dt, torch.dtype):
61
+ raise ValueError(f"{dtype} is not a valid dtype.")
62
+ dtype = dt
63
+
64
+ print("Loading model ...", file=sys.stderr)
65
+ t0 = time.time()
66
+
67
+ with (lazy_load(checkpoint_path) as pretrained_checkpoint,
68
+ lazy_load(lora_path) as lora_checkpoint):
69
+ name = llama_model_lookup(pretrained_checkpoint)
70
+ rank = lora_model_lookup(lora_checkpoint)
71
+
72
+ with EmptyInitOnDevice(
73
+ device=fabric.device, dtype=dtype
74
+ ), lora(r=rank, alpha=16, dropout=0.05, enabled=True):
75
+ model = LLaMA.from_name(name)
76
+
77
+ # 1. Load the pretrained weights
78
+ model.load_state_dict(pretrained_checkpoint, strict=False)
79
+ # 2. Load the fine-tuned lora weights
80
+ model.load_state_dict(lora_checkpoint, strict=False)
81
+
82
+ print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
83
+
84
+ model.eval()
85
+ base_model_dict = del_lora_state_dict(model)
86
+ save_path = lora_path.with_stem(f"{lora_path.stem}-lora-merged-weights")
87
+ print("Saving LoRA to base model weights ...")
88
+ torch.save(base_model_dict, save_path)
89
+ print(f"Model saved at {save_path}")
90
+
91
+
92
+ if __name__ == "__main__":
93
+ from jsonargparse import CLI
94
+
95
+ CLI(main)
scripts/download.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+ from urllib.request import urlretrieve
4
+
5
+ files = {
6
+ "original_model.py": "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py",
7
+ "original_adapter.py": "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py",
8
+ }
9
+
10
+
11
+ def download_original(wd: str) -> None:
12
+ for file, url in files.items():
13
+ filepath = os.path.join(wd, file)
14
+ if not os.path.isfile(filepath):
15
+ print(f"Downloading original implementation to {filepath!r}")
16
+ urlretrieve(url=url, filename=file)
17
+ print("Done")
18
+ else:
19
+ print("Original implementation found. Skipping download.")
20
+
21
+
22
+ def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None:
23
+ if repo_id is None:
24
+ raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.")
25
+
26
+ from huggingface_hub import snapshot_download
27
+
28
+ snapshot_download(repo_id, local_dir=local_dir, local_dir_use_symlinks=False)
29
+
30
+
31
+ if __name__ == "__main__":
32
+ from jsonargparse import CLI
33
+
34
+ CLI(download_from_hub)
scripts/prepare_alpaca.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation derived from https://github.com/tloen/alpaca-lora"""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # support running without installing as a package
6
+ wd = Path(__file__).parent.parent.resolve()
7
+ sys.path.append(str(wd))
8
+
9
+ import torch
10
+ import requests
11
+ import json
12
+ from torch.utils.data import random_split
13
+ from lit_llama.tokenizer import Tokenizer
14
+ from tqdm import tqdm
15
+
16
+
17
+ DATA_FILE = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"
18
+ DATA_FILE_NAME = "alpaca_data_cleaned_archive.json"
19
+ IGNORE_INDEX = -1
20
+
21
+
22
+ def prepare(
23
+ destination_path: Path = Path("data/alpaca"),
24
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
25
+ test_split_size: int = 2000,
26
+ max_seq_length: int = 256,
27
+ seed: int = 42,
28
+ mask_inputs: bool = False, # as in alpaca-lora
29
+ data_file_name: str = DATA_FILE_NAME
30
+ ) -> None:
31
+ """Prepare the Alpaca dataset for instruction tuning.
32
+
33
+ The output is a training and validation dataset saved as `train.pt` and `val.pt`,
34
+ which stores the preprocessed and tokenized prompts and labels.
35
+ """
36
+
37
+ destination_path.mkdir(parents=True, exist_ok=True)
38
+ file_path = destination_path / data_file_name
39
+ download(file_path)
40
+
41
+ # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
42
+ tokenizer = Tokenizer(tokenizer_path)
43
+
44
+ with open(file_path, "r") as file:
45
+ data = json.load(file)
46
+
47
+ # Partition the dataset into train and test
48
+ train_split_size = len(data) - test_split_size
49
+ train_set, test_set = random_split(
50
+ data,
51
+ lengths=(train_split_size, test_split_size),
52
+ generator=torch.Generator().manual_seed(seed),
53
+ )
54
+ train_set, test_set = list(train_set), list(test_set)
55
+
56
+ print(f"train has {len(train_set):,} samples")
57
+ print(f"val has {len(test_set):,} samples")
58
+
59
+ print("Processing train split ...")
60
+ train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
61
+ torch.save(train_set, file_path.parent / "train.pt")
62
+
63
+ print("Processing test split ...")
64
+ test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
65
+ torch.save(test_set, file_path.parent / "test.pt")
66
+
67
+
68
+ def download(file_path: Path):
69
+ """Downloads the raw json data file and saves it in the given destination."""
70
+ if file_path.exists():
71
+ return
72
+ with open(file_path, "w") as f:
73
+ f.write(requests.get(DATA_FILE).text)
74
+
75
+
76
+ def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
77
+ """Processes a single sample.
78
+
79
+ Each sample in the dataset consists of:
80
+ - instruction: A string describing the task
81
+ - input: A string holding a special input value for the instruction.
82
+ This only applies to some samples, and in others this is empty.
83
+ - output: The response string
84
+
85
+ This function processes this data to produce a prompt text and a label for
86
+ supervised training. The input text is formed as a single message including all
87
+ the instruction, the input (optional) and the response.
88
+ The label/target is the same message but can optionally have the instruction + input text
89
+ masked out (mask_inputs=True).
90
+
91
+ Finally, both the prompt and the label get tokenized. If desired, all tokens
92
+ in the label that correspond to the original input prompt get masked out (default).
93
+ """
94
+ full_prompt = generate_prompt(example)
95
+ full_prompt_and_response = full_prompt + example["output"]
96
+ encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
97
+ encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
98
+
99
+ # The labels are the full prompt with response, but with the prompt masked out
100
+ labels = encoded_full_prompt_and_response.clone()
101
+ if mask_inputs:
102
+ labels[:len(encoded_full_prompt)] = IGNORE_INDEX
103
+
104
+ return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
105
+
106
+
107
+ def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
108
+ return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
109
+
110
+
111
+ def generate_prompt(example):
112
+ """Generates a standardized message to prompt the model with an instruction, optional input and a
113
+ 'response' field."""
114
+
115
+ if example["input"]:
116
+ return (
117
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
118
+ "Write a response that appropriately completes the request.\n\n"
119
+ f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
120
+ )
121
+ return (
122
+ "Below is an instruction that describes a task. "
123
+ "Write a response that appropriately completes the request.\n\n"
124
+ f"### Instruction:\n{example['instruction']}\n\n### Response:"
125
+ )
126
+
127
+
128
+ if __name__ == "__main__":
129
+ from jsonargparse import CLI
130
+
131
+ CLI(prepare)
scripts/prepare_any_text.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation derived from https://github.com/tloen/alpaca-lora"""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # support running without installing as a package
6
+ wd = Path(__file__).parent.parent.resolve()
7
+ sys.path.append(str(wd))
8
+
9
+ import torch
10
+ import requests
11
+ import json
12
+ from torch.utils.data import random_split
13
+ from lit_llama.tokenizer import Tokenizer
14
+ from tqdm import tqdm
15
+
16
+
17
+ IGNORE_INDEX = -1
18
+
19
+ DATA_FILE_NAME = "input.txt"
20
+
21
+
22
+ def prepare(
23
+ destination_path: Path = Path("data/any"),
24
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
25
+ test_split_ratio: float = 0.9, # default 90% train, 10% validation
26
+ max_seq_length: int = 256,
27
+ seed: int = 42,
28
+ data_file_name: str = DATA_FILE_NAME,
29
+ ) -> None:
30
+ """Prepare any dataset for finetuning (akin to Shakespheare full tuning).
31
+
32
+ The output is a training and validation dataset saved as `train.pt` and `val.pt`,
33
+ which stores the preprocessed and tokenized prompts and labels.
34
+ """
35
+
36
+ destination_path.mkdir(parents=True, exist_ok=True)
37
+ file_path = destination_path / data_file_name
38
+ if not file_path.exists():
39
+ raise AssertionError(f"{data_file_name} is provided by the user")
40
+
41
+ # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
42
+ tokenizer = Tokenizer(tokenizer_path)
43
+
44
+ data = []
45
+
46
+ with open(file_path, "r") as input_file:
47
+ for line in input_file.readlines():
48
+ data.append(line)
49
+
50
+ # Partition the dataset into train and test
51
+ train_split_size = int(len(data) * test_split_ratio)
52
+ test_split_size = len(data) - train_split_size
53
+ train_set, test_set = random_split(
54
+ data,
55
+ lengths=(train_split_size, test_split_size),
56
+ generator=torch.Generator().manual_seed(seed),
57
+ )
58
+ train_set, test_set = list(train_set), list(test_set)
59
+
60
+ print(f"train has {len(train_set):,} samples")
61
+ print(f"val has {len(test_set):,} samples")
62
+
63
+ print("Processing train split ...")
64
+ train_set = [
65
+ prepare_line(line, tokenizer, max_seq_length) for line in tqdm(train_set)
66
+ ]
67
+ torch.save(train_set, file_path.parent / "train.pt")
68
+
69
+ print("Processing test split ...")
70
+ test_set = [
71
+ prepare_line(line, tokenizer, max_seq_length) for line in tqdm(test_set)
72
+ ]
73
+ torch.save(test_set, file_path.parent / "test.pt")
74
+
75
+
76
+ def prepare_line(line: str, tokenizer: Tokenizer, max_length: int):
77
+ """Processes a single sample.
78
+
79
+ This function processes the line to produce the tokenized version of it.
80
+ """
81
+ encoded_full_prompt = tokenize(tokenizer, line, max_length=max_length, eos=False)
82
+ return {
83
+ "input_ids": encoded_full_prompt,
84
+ "labels": encoded_full_prompt,
85
+ }
86
+
87
+
88
+ def tokenize(
89
+ tokenizer: Tokenizer, string: str, max_length: int, eos=True
90
+ ) -> torch.Tensor:
91
+ return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
92
+
93
+
94
+ if __name__ == "__main__":
95
+ from jsonargparse import CLI
96
+
97
+ CLI(prepare)
scripts/prepare_dolly.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation derived from https://github.com/tloen/alpaca-lora"""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # support running without installing as a package
6
+ wd = Path(__file__).parent.parent.resolve()
7
+ sys.path.append(str(wd))
8
+
9
+ import torch
10
+ import requests
11
+ import json
12
+ from torch.utils.data import random_split
13
+ from lit_llama.tokenizer import Tokenizer
14
+ from tqdm import tqdm
15
+
16
+
17
+ DATA_FILE = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
18
+ DATA_FILE_NAME = "dolly_data_cleaned.json"
19
+ IGNORE_INDEX = -1
20
+
21
+
22
+ def prepare(
23
+ destination_path: Path = Path("data/dolly"),
24
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
25
+ test_split_size: int = 2000,
26
+ max_seq_length: int = 1024,
27
+ seed: int = 42,
28
+ mask_inputs: bool = False, # as in alpaca-lora
29
+ ) -> None:
30
+ """Prepare the Dolly dataset for instruction tuning.
31
+
32
+ The output is a training and validation dataset saved as `train.pt` and `val.pt`,
33
+ which stores the preprocessed and tokenized prompts and labels.
34
+ """
35
+
36
+ destination_path.mkdir(parents=True, exist_ok=True)
37
+ file_path = destination_path / DATA_FILE_NAME
38
+ download(file_path)
39
+
40
+ # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
41
+ tokenizer = Tokenizer(tokenizer_path)
42
+
43
+ with open(file_path, "r") as file:
44
+ data = file.readlines()
45
+ data = [json.loads(line) for line in data]
46
+ for item in data:
47
+ item["input"] = item.pop("context")
48
+ item["output"] = item.pop("response")
49
+
50
+ # Partition the dataset into train and test
51
+ train_split_size = len(data) - test_split_size
52
+ train_set, test_set = random_split(
53
+ data,
54
+ lengths=(train_split_size, test_split_size),
55
+ generator=torch.Generator().manual_seed(seed),
56
+ )
57
+ train_set, test_set = list(train_set), list(test_set)
58
+
59
+ print(f"train has {len(train_set):,} samples")
60
+ print(f"val has {len(test_set):,} samples")
61
+
62
+ print("Processing train split ...")
63
+ train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
64
+ torch.save(train_set, file_path.parent / "train.pt")
65
+
66
+ print("Processing test split ...")
67
+ test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
68
+ torch.save(test_set, file_path.parent / "test.pt")
69
+
70
+
71
+ def download(file_path: Path):
72
+ """Downloads the raw json data file and saves it in the given destination."""
73
+ if file_path.exists():
74
+ return
75
+ with open(file_path, "w") as f:
76
+ f.write(requests.get(DATA_FILE).text)
77
+
78
+
79
+ def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
80
+ """Processes a single sample.
81
+
82
+ Each sample in the dataset consists of:
83
+ - instruction: A string describing the task
84
+ - input: A string holding a special input value for the instruction.
85
+ This only applies to some samples, and in others this is empty.
86
+ - output: The response string
87
+
88
+ This function processes this data to produce a prompt text and a label for
89
+ supervised training. The prompt text is formed as a single message including both
90
+ the instruction and the input. The label/target is the same message but with the
91
+ response attached.
92
+
93
+ Finally, both the prompt and the label get tokenized. If desired, all tokens
94
+ in the label that correspond to the original input prompt get masked out (default).
95
+ """
96
+ full_prompt = generate_prompt(example)
97
+ full_prompt_and_response = full_prompt + example["output"]
98
+ encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
99
+ encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
100
+
101
+ # The labels are the full prompt with response, but with the prompt masked out
102
+ labels = encoded_full_prompt_and_response.clone()
103
+ if mask_inputs:
104
+ labels[:len(encoded_full_prompt)] = IGNORE_INDEX
105
+
106
+ return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
107
+
108
+
109
+ def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
110
+ return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
111
+
112
+
113
+ def generate_prompt(example):
114
+ """Generates a standardized message to prompt the model with an instruction, optional input and a
115
+ 'response' field."""
116
+
117
+ if example["input"]:
118
+ return (
119
+ f"Below is an instruction that describes a task, paired with an input that provides further context. "
120
+ "Write a response that appropriately completes the request.\n\n"
121
+ f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
122
+ )
123
+ return (
124
+ f"Below is an instruction that describes a task. "
125
+ "Write a response that appropriately completes the request.\n\n"
126
+ f"### Instruction:\n{example['instruction']}\n\n### Response:"
127
+ )
128
+
129
+
130
+ if __name__ == "__main__":
131
+ from jsonargparse import CLI
132
+
133
+ CLI(prepare)
scripts/prepare_redpajama.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import glob
3
+ import os
4
+ from pathlib import Path
5
+ import sys
6
+
7
+ # support running without installing as a package
8
+ wd = Path(__file__).parent.parent.resolve()
9
+ sys.path.append(str(wd))
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from lit_llama import Tokenizer
15
+ import lit_llama.packed_dataset as packed_dataset
16
+
17
+
18
+ filenames_sample = [
19
+ "arxiv_sample.jsonl",
20
+ "book_sample.jsonl",
21
+ "c4_sample.jsonl",
22
+ "cc_2019-30_sample.jsonl",
23
+ "cc_2020-05_sample.jsonl",
24
+ "cc_2021-04_sample.jsonl",
25
+ "cc_2022-05_sample.jsonl",
26
+ "cc_2023-06_sample.jsonl",
27
+ "github_sample.jsonl",
28
+ "stackexchange_sample.jsonl",
29
+ "wikipedia_sample.jsonl",
30
+ ]
31
+
32
+ filename_sets = {
33
+ "arxiv": "arxiv/arxiv*",
34
+ "book": "book/book*",
35
+ "c4": "c4/c4-train*",
36
+ "common_crawl": "common_crawl/*",
37
+ "github": "github/filtered*",
38
+ "stackexchange": "stackexchange/stackexchange*",
39
+ "wikipedia": "wikipedia/wiki*",
40
+ }
41
+
42
+
43
+ def prepare_sample(
44
+ source_path: Path,
45
+ tokenizer_path: Path,
46
+ destination_path: Path,
47
+ chunk_size: int,
48
+ match = ""
49
+ ) -> None:
50
+ """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model)."""
51
+ destination_path.mkdir(parents=True, exist_ok=True)
52
+
53
+ tokenizer = Tokenizer(tokenizer_path)
54
+
55
+ for name in filenames_sample:
56
+ if match and match not in name:
57
+ continue
58
+
59
+ filepath = source_path / name
60
+
61
+ if not filepath.is_file():
62
+ raise RuntimeError(
63
+ f"Input file not found at {filepath}. \n"
64
+ "Make sure you download the data, e.g. wget -i https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through \n"
65
+ "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T \n"
66
+ "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n"
67
+ )
68
+
69
+ prefix, _ = os.path.splitext(name)
70
+
71
+ builder = packed_dataset.PackedDatasetBuilder(
72
+ outdir=destination_path,
73
+ prefix=prefix,
74
+ chunk_size=chunk_size,
75
+ sep_token=tokenizer.bos_id,
76
+ dtype="auto",
77
+ vocab_size=tokenizer.vocab_size,
78
+ )
79
+
80
+ print(f"Processing {name}")
81
+
82
+ with open(filepath, encoding="utf-8") as f:
83
+ for row in tqdm(f):
84
+ text = json.loads(row)["text"]
85
+ text_ids = tokenizer.encode(text)
86
+ builder.add_array(np.array(text_ids, dtype=builder.dtype))
87
+
88
+ builder.write_reminder()
89
+
90
+
91
+ def prepare_full(
92
+ source_path: Path,
93
+ tokenizer_path: Path,
94
+ destination_path: Path,
95
+ chunk_size: int,
96
+ match: str = ""
97
+ ) -> None:
98
+ """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model)."""
99
+ import zstandard as zstd
100
+
101
+ destination_path.mkdir(parents=True, exist_ok=True)
102
+
103
+ tokenizer = Tokenizer(tokenizer_path)
104
+
105
+ for set_name, pattern in filename_sets.items():
106
+ if match and match not in set_name:
107
+ continue
108
+
109
+ is_cc = set_name == "common_crawl"
110
+
111
+ filenames = glob.glob(os.path.join(source_path, pattern), recursive=True)
112
+
113
+ if not filenames:
114
+ raise RuntimeError(
115
+ f"No files matching {pattern} found at {source_path}. \n"
116
+ "Make sure you download the data, e.g. wget -i https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through \n"
117
+ "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T \n"
118
+ "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n"
119
+ )
120
+
121
+ builder = packed_dataset.PackedDatasetBuilder(
122
+ outdir=destination_path,
123
+ prefix=set_name,
124
+ chunk_size=chunk_size,
125
+ sep_token=tokenizer.bos_id,
126
+ dtype="auto",
127
+ vocab_size=tokenizer.vocab_size,
128
+ )
129
+
130
+ for name in filenames:
131
+ filepath = source_path / name
132
+
133
+ print(f"Processing {name}")
134
+
135
+ if is_cc:
136
+ with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
137
+ for row in tqdm(f):
138
+ text = json.loads(row)["text"]
139
+ text_ids = tokenizer.encode(text)
140
+ builder.add_array(np.array(text_ids, dtype=builder.dtype))
141
+ else:
142
+ with open(filepath, encoding="utf-8") as f:
143
+ for row in tqdm(f):
144
+ text = json.loads(row)["text"]
145
+ text_ids = tokenizer.encode(text)
146
+ builder.add_array(np.array(text_ids, dtype=builder.dtype))
147
+
148
+ builder.write_reminder()
149
+
150
+
151
+ def prepare(
152
+ source_path: Path = Path("data/RedPajama-Data-1T-Sample"),
153
+ tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
154
+ destination_path: Path = Path("data/red_pajama_sample"),
155
+ chunk_size: int = 2049 * 1024, # 2048 block size + 1 for causal (from LLama), 1024 blocks
156
+ sample: bool = False,
157
+ match: str = "",
158
+ ) -> None:
159
+ """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model)."""
160
+ if sample:
161
+ prepare_sample(
162
+ source_path=source_path,
163
+ tokenizer_path=tokenizer_path,
164
+ destination_path=destination_path,
165
+ chunk_size=chunk_size,
166
+ match=match,
167
+ )
168
+ else:
169
+ prepare_full(
170
+ source_path=source_path,
171
+ tokenizer_path=tokenizer_path,
172
+ destination_path=destination_path,
173
+ chunk_size=chunk_size,
174
+ match=match,
175
+ )
176
+
177
+
178
+ if __name__ == "__main__":
179
+ from jsonargparse import CLI
180
+
181
+ CLI(prepare)
scripts/prepare_shakespeare.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Andrej Karpathy
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ import sys
23
+ from pathlib import Path
24
+
25
+ # support running without installing as a package
26
+ wd = Path(__file__).parent.parent.resolve()
27
+ sys.path.append(str(wd))
28
+
29
+ import numpy as np
30
+ import requests
31
+
32
+
33
+ def prepare(destination_path: Path = Path("data/shakespeare")) -> None:
34
+ """Prepare the "Tiny Shakespeare" dataset."""
35
+ destination_path.mkdir(parents=True, exist_ok=True)
36
+
37
+ # download the tiny shakespeare dataset
38
+ input_file_path = destination_path / "input.txt"
39
+ if not input_file_path.exists():
40
+ data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
41
+ with open(input_file_path, "w") as f:
42
+ f.write(requests.get(data_url).text)
43
+
44
+ with open(input_file_path) as f:
45
+ data = f.read()
46
+ n = len(data)
47
+ train_data = data[: int(n * 0.9)]
48
+ val_data = data[int(n * 0.9) :]
49
+
50
+ from lit_llama import Tokenizer
51
+
52
+ Tokenizer.train(input=input_file_path, destination=destination_path, vocab_size=100)
53
+ tokenizer = Tokenizer(destination_path / "tokenizer.model")
54
+ train_ids = tokenizer.encode(train_data)
55
+ val_ids = tokenizer.encode(val_data)
56
+ print(f"train has {len(train_ids):,} tokens")
57
+ print(f"val has {len(val_ids):,} tokens")
58
+
59
+ # export to bin files
60
+ train_ids = np.array(train_ids, dtype=np.uint16)
61
+ val_ids = np.array(val_ids, dtype=np.uint16)
62
+ train_ids.tofile(destination_path / "train.bin")
63
+ val_ids.tofile(destination_path / "val.bin")
64
+
65
+
66
+ if __name__ == "__main__":
67
+ from jsonargparse import CLI
68
+
69
+ CLI(prepare)
setup.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from setuptools import setup, find_packages
4
+
5
+
6
+ _PATH_ROOT = os.path.dirname(__file__)
7
+
8
+ with open(os.path.join(_PATH_ROOT, "README.md"), encoding="utf-8") as fo:
9
+ readme = fo.read()
10
+
11
+ setup(
12
+ name='lit-llama',
13
+ version='0.1.0',
14
+ description='Implementation of the LLaMA language model',
15
+ author='Lightning AI',
16
+ url='https://github.com/lightning-AI/lit-llama',
17
+ install_requires=[
18
+ "torch>=2.0.0",
19
+ "lightning @ git+https://github.com/Lightning-AI/lightning@master",
20
+ "sentencepiece",
21
+ "bitsandbytes",
22
+ ],
23
+ packages=find_packages(),
24
+ long_description=readme,
25
+ long_description_content_type="text/markdown",
26
+ )
tests/conftest.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+
6
+ wd = Path(__file__).parent.parent.absolute()
7
+
8
+
9
+ @pytest.fixture()
10
+ def orig_llama():
11
+ sys.path.append(str(wd))
12
+
13
+ from scripts.download import download_original
14
+
15
+ download_original(wd)
16
+
17
+ import original_model
18
+
19
+ return original_model
20
+
21
+
22
+ @pytest.fixture()
23
+ def orig_llama_adapter():
24
+ sys.path.append(str(wd))
25
+
26
+ from scripts.download import download_original
27
+
28
+ download_original(wd)
29
+
30
+ import original_adapter
31
+
32
+ return original_adapter
33
+
34
+
35
+ @pytest.fixture()
36
+ def lit_llama():
37
+ # this adds support for running tests without the package installed
38
+ sys.path.append(str(wd))
39
+
40
+ import lit_llama
41
+
42
+ return lit_llama
tests/test_adapter.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict
2
+ import pytest
3
+ import sys
4
+ import torch
5
+
6
+
7
+ @pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.")
8
+ @pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"])
9
+ def test_config_identical(model_size, lit_llama):
10
+ import lit_llama.adapter as llama_adapter
11
+ import lit_llama.model as llama
12
+ from lit_llama.utils import EmptyInitOnDevice
13
+
14
+ llama_config = asdict(llama.LLaMAConfig.from_name(model_size))
15
+ adapter_config = asdict(llama_adapter.LLaMAConfig.from_name(model_size))
16
+
17
+ del adapter_config["adapter_prompt_length"]
18
+ del adapter_config["adapter_start_layer"]
19
+ assert adapter_config == llama_config
20
+
21
+ with EmptyInitOnDevice():
22
+ llama_model = llama.LLaMA.from_name(model_size)
23
+ adapter_model = llama_adapter.LLaMA.from_name(model_size)
24
+ assert llama_model.lm_head.weight.shape == adapter_model.lm_head.weight.shape
25
+
26
+
27
+ def test_adapter_load_gating_factor(lit_llama):
28
+ """Tests backward-compatible loading of checkpoints after the `gating_factor` was extended per-head
29
+ in PR #297.
30
+ """
31
+ import lit_llama.adapter as llama_adapter
32
+ from lit_llama.utils import lazy_load
33
+
34
+ config = llama_adapter.LLaMAConfig(n_head=4, block_size=100, n_embd=16)
35
+ attn = llama_adapter.CausalSelfAttention(config=config, block_idx=3)
36
+
37
+ # Old checkpoint format
38
+ state_dict={
39
+ "gating_factor": torch.tensor(0.42), # in old checkpoints, this was a scalar
40
+ "c_attn.weight": torch.zeros(3 * 16, 16),
41
+ "c_proj.weight": torch.zeros(16, 16),
42
+ "adapter_wte.weight": torch.zeros(10, 16),
43
+ }
44
+ attn.load_state_dict(state_dict=state_dict)
45
+ assert torch.equal(attn.gating_factor, torch.full((1, 4, 1, 1), 0.42))
46
+
47
+ # New checkpoint format
48
+ state_dict={
49
+ "gating_factor": torch.tensor([0.42, 0.42, 0.42, 0.42]).reshape(1, 4, 1, 1),
50
+ "c_attn.weight": torch.zeros(3 * 16, 16),
51
+ "c_proj.weight": torch.zeros(16, 16),
52
+ "adapter_wte.weight": torch.zeros(10, 16),
53
+ }
54
+ attn.load_state_dict(state_dict=state_dict)
55
+ assert torch.equal(attn.gating_factor, torch.full((1, 4, 1, 1), 0.42))
tests/test_adapter_v2.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import sys
3
+
4
+
5
+ @pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.")
6
+ @pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"])
7
+ def test_config_identical(model_size, lit_llama):
8
+ import torch.nn as nn
9
+ import lit_llama.adapter as llama_adapter
10
+ from lit_llama.adapter_v2 import adapter_v2_linear_with_bias_and_scale
11
+ import lit_llama.model as llama
12
+ from lit_llama.utils import EmptyInitOnDevice
13
+
14
+ with EmptyInitOnDevice():
15
+ llama_model = llama.LLaMA.from_name(model_size)
16
+ adapter_model = llama_adapter.LLaMA.from_name(model_size)
17
+
18
+ for module in adapter_model.modules():
19
+ if isinstance(module, nn.Linear):
20
+ adapter_v2_linear_with_bias_and_scale(module)
21
+
22
+ print(adapter_model.transformer.h[2].attn.c_attn.adapter_bias)
23
+ assert not hasattr(llama_model.transformer.h[2].attn.c_attn, 'adapter_bias')
24
+ assert not hasattr(llama_model.transformer.h[2].attn.c_attn, 'adapter_scale')
25
+ assert hasattr(adapter_model.transformer.h[2].attn.c_attn, 'adapter_bias')
26
+ assert hasattr(adapter_model.transformer.h[2].attn.c_attn, 'adapter_scale')
tests/test_generate.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import subprocess
3
+ import sys
4
+ from contextlib import contextmanager, redirect_stdout
5
+ from io import StringIO
6
+ from pathlib import Path
7
+ from unittest import mock
8
+ from unittest.mock import Mock, call, ANY
9
+
10
+ import torch
11
+
12
+ wd = Path(__file__).parent.parent.absolute()
13
+
14
+
15
+ @functools.lru_cache(maxsize=1)
16
+ def load_generate_script():
17
+ sys.path.append(str(wd))
18
+
19
+ import generate as generate
20
+
21
+ return generate
22
+
23
+
24
+ def test_generate():
25
+ generate = load_generate_script()
26
+
27
+ from lit_llama.model import LLaMA, LLaMAConfig
28
+
29
+ T, C = 5, 3
30
+ logits = torch.randn(T, C)
31
+ input_idx = torch.randint(10, size=(T,))
32
+
33
+ config = LLaMAConfig(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
34
+ model = LLaMA(config)
35
+ max_new_tokens = 20
36
+
37
+ multinomial_results = []
38
+ original_multinomial = torch.multinomial
39
+
40
+ def multinomial(*args, **kwargs):
41
+ out = original_multinomial(*args, **kwargs)
42
+ multinomial_results.append(out)
43
+ return out
44
+
45
+ with mock.patch("torch.multinomial", multinomial):
46
+ out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10, top_k=4)
47
+
48
+ assert out.size(0) == T + max_new_tokens
49
+ multinomial_results = torch.hstack(multinomial_results)
50
+ expected = torch.cat((input_idx, multinomial_results))
51
+ assert out.shape == expected.shape
52
+ torch.testing.assert_close(out, expected)
53
+
54
+
55
+ @mock.patch("torch.cuda.is_bf16_supported", return_value=False)
56
+ def test_main(tmp_path, monkeypatch):
57
+ generate = load_generate_script()
58
+
59
+ checkpoint_path = tmp_path / "ckpt"
60
+ checkpoint_path.touch()
61
+ tokenizer_path = tmp_path / "tokenizer"
62
+ tokenizer_path.touch()
63
+
64
+ class FabricMock(Mock):
65
+ @property
66
+ def device(self):
67
+ return torch.device("cpu")
68
+
69
+ @contextmanager
70
+ def init_module(self, empty_init):
71
+ yield
72
+
73
+ monkeypatch.setattr(generate.L, "Fabric", FabricMock)
74
+ model_mock = Mock()
75
+ monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
76
+ lookup_mock = Mock(return_value="1T")
77
+ monkeypatch.setattr(generate, "llama_model_lookup", lookup_mock)
78
+ load_mock = Mock()
79
+ load_mock.return_value = load_mock
80
+ load_mock.__enter__ = Mock()
81
+ load_mock.__exit__ = Mock()
82
+ monkeypatch.setattr(generate.torch, "load", load_mock)
83
+ monkeypatch.setattr(generate, "lazy_load", load_mock)
84
+ tokenizer_mock = Mock()
85
+ tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
86
+ tokenizer_mock.return_value.decode.return_value = "foo bar baz"
87
+ monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
88
+ generate_mock = Mock()
89
+ generate_mock.return_value = torch.tensor([[3, 2, 1]])
90
+ monkeypatch.setattr(generate, "generate", generate_mock)
91
+
92
+ num_samples = 2
93
+ out = StringIO()
94
+ with redirect_stdout(out):
95
+ generate.main(
96
+ checkpoint_path=checkpoint_path,
97
+ tokenizer_path=tokenizer_path,
98
+ temperature=2.0,
99
+ top_k=2,
100
+ num_samples=num_samples,
101
+ )
102
+
103
+ model_mock.assert_called_once_with("1T")
104
+ load_mock.assert_called_once_with(checkpoint_path)
105
+ tokenizer_mock.assert_called_once_with(tokenizer_path)
106
+ assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
107
+ assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
108
+ assert generate_mock.mock_calls == [call(ANY, ANY, 50, temperature=2.0, top_k=2)] * num_samples
109
+ # only the generated result is printed to stdout
110
+ assert out.getvalue() == "foo bar baz\n" * num_samples
111
+
112
+
113
+ def test_cli():
114
+ cli_path = wd / "generate.py"
115
+ output = subprocess.check_output([sys.executable, cli_path, "-h"])
116
+ output = str(output.decode())
117
+ assert "Generates text samples" in output