ravi.naik committed on
Commit
17a7426
1 Parent(s): ac0ad3c

Updated repository for gradio UI and model

.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
- title: ERA SESSION22
+ title: "ERA-SESSION22 Training PyThia-160M from scratch on AWS Sagemaker"
 emoji: 📈
 colorFrom: indigo
 colorTo: yellow
@@ -9,5 +9,3 @@ app_file: app.py
 pinned: false
 license: mit
 ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,69 @@
+ import gradio as gr
+ import torch
+ from pathlib import Path
+
+ torch.set_float32_matmul_precision("high")
+
+ from generate.base import main
+
+
+ def generate(prompt, max_new_tokens, temperature, num_samples):
+     prompt = prompt.strip()
+
+     responses = main(
+         prompt=prompt,
+         checkpoint_dir=Path("out/redpajama"),
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         num_samples=num_samples,
+     )
+     return {output: responses}
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("## ERA Session22 - Pythia-160M Pre-training with LitGPT")
+     gr.Markdown(
+         """This is an implementation of Pythia-160M using [LitGPT](https://github.com/Lightning-AI/lit-gpt) by LightningAI.
+
+ Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION22).
+
+ Dataset used to train: [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T).
+ """
+     )
+     with gr.Row():
+         with gr.Column():
+             prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
+             max_new_tokens = gr.Slider(
+                 minimum=10,
+                 maximum=200,
+                 value=50,
+                 step=10,
+                 label="Select Number of Tokens to be Generated",
+                 interactive=True,
+             )
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=1,
+                 value=0.7,
+                 step=0.1,
+                 label="Select Temperature",
+                 interactive=True,
+             )
+             num_samples = gr.Dropdown(
+                 choices=[1, 2, 5, 10],
+                 value=1,
+                 interactive=True,
+                 label="Select No. of outputs to be generated",
+             )
+             submit_btn = gr.Button(value="Generate")
+
+         with gr.Column():
+             output = gr.JSON(label="Generated Text")
+
+     submit_btn.click(
+         generate,
+         inputs=[prompt_box, max_new_tokens, temperature, num_samples],
+         outputs=[output],
+     )
+
+ app.launch()
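
The Gradio callback above is a thin wrapper around `generate.base.main`. A minimal sketch of calling it directly, assuming a converted checkpoint already exists under `out/redpajama` (with `lit_config.json`, `lit_model.pth`, and the tokenizer files) and using an illustrative prompt and sample count:

```python
from pathlib import Path

from generate.base import main

# Sketch only: checkpoint_dir mirrors the UI default above; prompt/values are illustrative.
responses = main(
    prompt="Once upon a time",
    checkpoint_dir=Path("out/redpajama"),
    max_new_tokens=50,
    temperature=0.7,
    num_samples=2,
)
# main() returns one dict per sample with "response", "latency" and
# "generation_rate" keys (see generate/base.py below), which the UI renders as JSON.
for r in responses:
    print(r["response"])
```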
generate/adapter.py ADDED
@@ -0,0 +1,141 @@
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Literal, Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ from lightning.fabric.plugins import BitsandbytesPrecision
9
+ from lightning.fabric.strategies import FSDPStrategy
10
+
11
+ # support running without installing as a package
12
+ wd = Path(__file__).parent.parent.resolve()
13
+ sys.path.append(str(wd))
14
+
15
+ from generate.base import generate
16
+ from lit_gpt import Tokenizer
17
+ from lit_gpt.adapter import GPT, Block, Config
18
+ from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
19
+ from scripts.prepare_alpaca import generate_prompt
20
+
21
+
22
+ def main(
23
+ prompt: str = "What food do llamas eat?",
24
+ input: str = "",
25
+ adapter_path: Path = Path("out/adapter/alpaca/lit_model_adapter_finetuned.pth"),
26
+ checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
27
+ quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
28
+ max_new_tokens: int = 100,
29
+ top_k: Optional[int] = 200,
30
+ temperature: float = 0.8,
31
+ strategy: str = "auto",
32
+ devices: int = 1,
33
+ precision: Optional[str] = None,
34
+ ) -> None:
35
+ """Generates a response based on a given instruction and an optional input.
36
+ This script will only work with checkpoints from the instruction-tuned GPT-Adapter model.
37
+ See `finetune/adapter.py`.
38
+
39
+ Args:
40
+ prompt: The prompt/instruction (Alpaca style).
41
+ input: Optional input (Alpaca style).
42
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
43
+ `finetune/adapter.py`.
44
+ checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
45
+ quantize: Whether to quantize the model and using which method:
46
+ - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
47
+ - bnb.int8: 8-bit quantization from bitsandbytes
48
+ - gptq.int4: 4-bit quantization from GPTQ
49
+ for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
50
+ max_new_tokens: The number of generation steps to take.
51
+ top_k: The number of top most probable tokens to consider in the sampling process.
52
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
53
+ samples.
54
+ strategy: Indicates the Fabric strategy setting to use.
55
+ devices: How many devices to use.
56
+ precision: Indicates the Fabric precision setting to use.
57
+ """
58
+ precision = precision or get_default_supported_precision(training=False)
59
+
60
+ plugins = None
61
+ if quantize is not None:
62
+ if devices > 1:
63
+ raise NotImplementedError(
64
+ "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
65
+ " --quantize flag."
66
+ )
67
+ if quantize.startswith("bnb."):
68
+ if "mixed" in precision:
69
+ raise ValueError("Quantization and mixed precision is not supported.")
70
+ dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
71
+ plugins = BitsandbytesPrecision(quantize[4:], dtype)
72
+ precision = None
73
+
74
+ if strategy == "fsdp":
75
+ strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
76
+
77
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
78
+ fabric.launch()
79
+
80
+ check_valid_checkpoint_dir(checkpoint_dir)
81
+
82
+ config = Config.from_json(checkpoint_dir / "lit_config.json")
83
+
84
+ if quantize is not None and devices > 1:
85
+ raise NotImplementedError
86
+ if quantize == "gptq.int4":
87
+ model_file = "lit_model_gptq.4bit.pth"
88
+ if not (checkpoint_dir / model_file).is_file():
89
+ raise ValueError("Please run `python quantize/gptq.py` first")
90
+ else:
91
+ model_file = "lit_model.pth"
92
+ checkpoint_path = checkpoint_dir / model_file
93
+
94
+ tokenizer = Tokenizer(checkpoint_dir)
95
+ sample = {"instruction": prompt, "input": input}
96
+ prompt = generate_prompt(sample)
97
+ encoded = tokenizer.encode(prompt, device=fabric.device)
98
+ prompt_length = encoded.size(0)
99
+ max_returned_tokens = prompt_length + max_new_tokens
100
+
101
+ fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
102
+ t0 = time.perf_counter()
103
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
104
+ model = GPT(config)
105
+ fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
106
+ with fabric.init_tensor():
107
+ # set the max_seq_length to limit the memory usage to what we need
108
+ model.max_seq_length = max_returned_tokens
109
+ # enable the kv cache
110
+ model.set_kv_cache(batch_size=1)
111
+ model.eval()
112
+
113
+ t0 = time.perf_counter()
114
+ checkpoint = lazy_load(checkpoint_path)
115
+ adapter_checkpoint = lazy_load(adapter_path)
116
+ checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
117
+ model.load_state_dict(checkpoint)
118
+ fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
119
+
120
+ model = fabric.setup(model)
121
+
122
+ L.seed_everything(1234)
123
+ t0 = time.perf_counter()
124
+ y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
125
+ t = time.perf_counter() - t0
126
+
127
+ output = tokenizer.decode(y)
128
+ output = output.split("### Response:")[1].strip()
129
+ fabric.print(output)
130
+
131
+ tokens_generated = y.size(0) - prompt_length
132
+ fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
133
+ if fabric.device.type == "cuda":
134
+ fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ from jsonargparse import CLI
139
+
140
+ torch.set_float32_matmul_precision("high")
141
+ CLI(main)
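
Because the script hands `main` to jsonargparse's `CLI`, each parameter documented above doubles as a command-line flag. A hedged usage sketch, assuming the lit-gpt modules are importable and the default StableLM checkpoint plus a finetuned adapter file exist at the paths shown (they are the script's own defaults):

```python
# Sketch only: paths below are the script defaults and must exist locally.
#
#   python generate/adapter.py \
#       --prompt "What food do llamas eat?" \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --adapter_path out/adapter/alpaca/lit_model_adapter_finetuned.pth \
#       --max_new_tokens 100 --temperature 0.8
#
# The same entry point can also be called from Python:
from pathlib import Path

from generate.adapter import main

main(
    prompt="What food do llamas eat?",
    checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    adapter_path=Path("out/adapter/alpaca/lit_model_adapter_finetuned.pth"),
    max_new_tokens=100,
)
```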
generate/adapter_v2.py ADDED
@@ -0,0 +1,141 @@
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Literal, Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ from lightning.fabric.plugins import BitsandbytesPrecision
9
+ from lightning.fabric.strategies import FSDPStrategy
10
+
11
+ # support running without installing as a package
12
+ wd = Path(__file__).parent.parent.resolve()
13
+ sys.path.append(str(wd))
14
+
15
+ from generate.base import generate
16
+ from lit_gpt import Tokenizer
17
+ from lit_gpt.adapter_v2 import GPT, Block, Config
18
+ from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
19
+ from scripts.prepare_alpaca import generate_prompt
20
+
21
+
22
+ def main(
23
+ prompt: str = "What food do llamas eat?",
24
+ input: str = "",
25
+ adapter_path: Path = Path("out/adapter_v2/alpaca/lit_model_adapter_finetuned.pth"),
26
+ checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
27
+ quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
28
+ max_new_tokens: int = 100,
29
+ top_k: Optional[int] = 200,
30
+ temperature: float = 0.8,
31
+ strategy: str = "auto",
32
+ devices: int = 1,
33
+ precision: Optional[str] = None,
34
+ ) -> None:
35
+ """Generates a response based on a given instruction and an optional input.
36
+ This script will only work with checkpoints from the instruction-tuned GPT-AdapterV2 model.
37
+ See `finetune/adapter_v2.py`.
38
+
39
+ Args:
40
+ prompt: The prompt/instruction (Alpaca style).
41
+ input: Optional input (Alpaca style).
42
+ adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
43
+ `finetune/adapter_v2.py`.
44
+ checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
45
+ quantize: Whether to quantize the model and using which method:
46
+ - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
47
+ - bnb.int8: 8-bit quantization from bitsandbytes
48
+ - gptq.int4: 4-bit quantization from GPTQ
49
+ for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
50
+ max_new_tokens: The number of generation steps to take.
51
+ top_k: The number of top most probable tokens to consider in the sampling process.
52
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
53
+ samples.
54
+ strategy: Indicates the Fabric strategy setting to use.
55
+ devices: How many devices to use.
56
+ precision: Indicates the Fabric precision setting to use.
57
+ """
58
+ precision = precision or get_default_supported_precision(training=False)
59
+
60
+ plugins = None
61
+ if quantize is not None:
62
+ if devices > 1:
63
+ raise NotImplementedError(
64
+ "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
65
+ " --quantize flag."
66
+ )
67
+ if quantize.startswith("bnb."):
68
+ if "mixed" in precision:
69
+ raise ValueError("Quantization and mixed precision is not supported.")
70
+ dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
71
+ plugins = BitsandbytesPrecision(quantize[4:], dtype)
72
+ precision = None
73
+
74
+ if strategy == "fsdp":
75
+ strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
76
+
77
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
78
+ fabric.launch()
79
+
80
+ check_valid_checkpoint_dir(checkpoint_dir)
81
+
82
+ config = Config.from_json(checkpoint_dir / "lit_config.json")
83
+
84
+ if quantize is not None and devices > 1:
85
+ raise NotImplementedError
86
+ if quantize == "gptq.int4":
87
+ model_file = "lit_model_gptq.4bit.pth"
88
+ if not (checkpoint_dir / model_file).is_file():
89
+ raise ValueError("Please run `python quantize/gptq.py` first")
90
+ else:
91
+ model_file = "lit_model.pth"
92
+ checkpoint_path = checkpoint_dir / model_file
93
+
94
+ tokenizer = Tokenizer(checkpoint_dir)
95
+ sample = {"instruction": prompt, "input": input}
96
+ prompt = generate_prompt(sample)
97
+ encoded = tokenizer.encode(prompt, device=fabric.device)
98
+ prompt_length = encoded.size(0)
99
+ max_returned_tokens = prompt_length + max_new_tokens
100
+
101
+ fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
102
+ t0 = time.perf_counter()
103
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
104
+ model = GPT(config)
105
+ fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
106
+ with fabric.init_tensor():
107
+ # set the max_seq_length to limit the memory usage to what we need
108
+ model.max_seq_length = max_returned_tokens
109
+ # enable the kv cache
110
+ model.set_kv_cache(batch_size=1)
111
+ model.eval()
112
+
113
+ t0 = time.perf_counter()
114
+ checkpoint = lazy_load(checkpoint_path)
115
+ adapter_checkpoint = lazy_load(adapter_path)
116
+ checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
117
+ model.load_state_dict(checkpoint)
118
+ fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
119
+
120
+ model = fabric.setup(model)
121
+
122
+ L.seed_everything(1234)
123
+ t0 = time.perf_counter()
124
+ y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
125
+ t = time.perf_counter() - t0
126
+
127
+ output = tokenizer.decode(y)
128
+ output = output.split("### Response:")[1].strip()
129
+ fabric.print(output)
130
+
131
+ tokens_generated = y.size(0) - prompt_length
132
+ fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
133
+ if fabric.device.type == "cuda":
134
+ fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ from jsonargparse import CLI
139
+
140
+ torch.set_float32_matmul_precision("high")
141
+ CLI(main)
generate/base.py ADDED
@@ -0,0 +1,268 @@
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Any, Literal, Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ import torch._dynamo.config
9
+ import torch._inductor.config
10
+ from lightning.fabric.plugins import BitsandbytesPrecision
11
+ from lightning.fabric.strategies import FSDPStrategy
12
+
13
+ # support running without installing as a package
14
+ wd = Path(__file__).parent.parent.resolve()
15
+ sys.path.append(str(wd))
16
+
17
+ from lit_gpt import GPT, Config, Tokenizer
18
+ from lit_gpt.model import Block
19
+ from lit_gpt.utils import (
20
+ check_valid_checkpoint_dir,
21
+ get_default_supported_precision,
22
+ gptq_quantization,
23
+ load_checkpoint,
24
+ )
25
+
26
+
27
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
28
+ if torch._dynamo.is_compiling():
29
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
30
+ distribution = torch.empty_like(probs).exponential_(1)
31
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
32
+ return torch.multinomial(probs, num_samples=1)
33
+
34
+
35
+ def sample(
36
+ logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
37
+ ) -> torch.Tensor:
38
+ logits = logits[0, -1]
39
+ # optionally crop the logits to only the top k options
40
+ if top_k is not None:
41
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
42
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
43
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
44
+ # optionally scale the logits and sample from a probability distribution
45
+ if temperature > 0.0:
46
+ probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
47
+ return multinomial_num_samples_1(probs)
48
+ return torch.argmax(logits, dim=-1, keepdim=True)
49
+
50
+
51
+ def next_token(
52
+ model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any
53
+ ) -> torch.Tensor:
54
+ logits = model(x, input_pos)
55
+ next = sample(logits, **kwargs)
56
+ return next.type_as(x)
57
+
58
+
59
+ @torch.inference_mode()
60
+ def generate(
61
+ model: GPT,
62
+ prompt: torch.Tensor,
63
+ max_returned_tokens: int,
64
+ *,
65
+ temperature: float = 1.0,
66
+ top_k: Optional[int] = None,
67
+ eos_id: Optional[int] = None,
68
+ ) -> torch.Tensor:
69
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
70
+
71
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
72
+
73
+ Args:
74
+ model: The model to use.
75
+ prompt: Tensor of shape (T) with indices of the prompt sequence.
76
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
77
+ temperature: Scales the predicted logits by 1 / temperature.
78
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
79
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered.
80
+ """
81
+ T = prompt.size(0)
82
+ assert max_returned_tokens > T
83
+ if model.max_seq_length < max_returned_tokens - 1:
84
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
85
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
86
+ # not support it to avoid negatively impacting the overall speed
87
+ raise NotImplementedError(
88
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
89
+ )
90
+
91
+ device = prompt.device
92
+ tokens = [prompt]
93
+ input_pos = torch.tensor([T], device=device)
94
+ token = next_token(
95
+ model,
96
+ torch.arange(0, T, device=device),
97
+ prompt.view(1, -1),
98
+ temperature=temperature,
99
+ top_k=top_k,
100
+ ).clone()
101
+ tokens.append(token)
102
+ for _ in range(2, max_returned_tokens - T + 1):
103
+ token = next_token(
104
+ model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k
105
+ ).clone()
106
+ tokens.append(token)
107
+ if token == eos_id:
108
+ break
109
+ input_pos = input_pos.add_(1)
110
+ return torch.cat(tokens)
111
+
112
+
113
+ def main(
114
+ prompt: str = "What food do llamas eat?",
115
+ *,
116
+ num_samples: int = 1,
117
+ max_new_tokens: int = 50,
118
+ top_k: Optional[int] = 200,
119
+ temperature: float = 0.8,
120
+ checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
121
+ quantize: Optional[
122
+ Literal[
123
+ "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"
124
+ ]
125
+ ] = None,
126
+ strategy: str = "auto",
127
+ devices: int = 1,
128
+ precision: Optional[str] = None,
129
+ compile: bool = False,
130
+ ) -> list:
131
+ """Generates text samples based on a pre-trained model and tokenizer.
132
+
133
+ Args:
134
+ prompt: The prompt string to use for generating the samples.
135
+ num_samples: The number of text samples to generate.
136
+ max_new_tokens: The number of generation steps to take.
137
+ top_k: The number of top most probable tokens to consider in the sampling process.
138
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
139
+ samples.
140
+ checkpoint_dir: The checkpoint directory to load.
141
+ quantize: Whether to quantize the model and using which method:
142
+ - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
143
+ - bnb.int8: 8-bit quantization from bitsandbytes
144
+ - gptq.int4: 4-bit quantization from GPTQ
145
+ for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
146
+ strategy: Indicates the Fabric strategy setting to use.
147
+ devices: How many devices to use.
148
+ precision: Indicates the Fabric precision setting to use.
149
+ compile: Whether to compile the model.
150
+ """
151
+ precision = precision or get_default_supported_precision(training=False)
152
+
153
+ plugins = None
154
+ if quantize is not None:
155
+ if devices > 1:
156
+ raise NotImplementedError(
157
+ "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
158
+ " --quantize flag."
159
+ )
160
+ if quantize.startswith("bnb."):
161
+ if "mixed" in precision:
162
+ raise ValueError("Quantization and mixed precision is not supported.")
163
+ dtype = {
164
+ "16-true": torch.float16,
165
+ "bf16-true": torch.bfloat16,
166
+ "32-true": torch.float32,
167
+ }[precision]
168
+ plugins = BitsandbytesPrecision(quantize[4:], dtype)
169
+ precision = None
170
+
171
+ if strategy == "fsdp":
172
+ strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
173
+
174
+ fabric = L.Fabric(
175
+ devices=devices, precision=precision, strategy=strategy, plugins=plugins
176
+ )
177
+ fabric.launch()
178
+
179
+ check_valid_checkpoint_dir(checkpoint_dir)
180
+
181
+ config = Config.from_json(checkpoint_dir / "lit_config.json")
182
+
183
+ if quantize == "gptq.int4":
184
+ model_file = "lit_model_gptq.4bit.pth"
185
+ if not (checkpoint_dir / model_file).is_file():
186
+ raise ValueError("Please run `python quantize/gptq.py` first")
187
+ else:
188
+ model_file = "lit_model.pth"
189
+ checkpoint_path = checkpoint_dir / model_file
190
+
191
+ tokenizer = Tokenizer(checkpoint_dir)
192
+ encoded = tokenizer.encode(prompt, device=fabric.device)
193
+ prompt_length = encoded.size(0)
194
+ max_returned_tokens = prompt_length + max_new_tokens
195
+
196
+ fabric.print(
197
+ f"Loading model {str(checkpoint_path)!r} with {config.__dict__}",
198
+ file=sys.stderr,
199
+ )
200
+ t0 = time.perf_counter()
201
+ with fabric.init_module(empty_init=True), gptq_quantization(
202
+ quantize == "gptq.int4"
203
+ ):
204
+ model = GPT(config)
205
+ fabric.print(
206
+ f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.",
207
+ file=sys.stderr,
208
+ )
209
+ with fabric.init_tensor():
210
+ # set the max_seq_length to limit the memory usage to what we need
211
+ model.max_seq_length = max_returned_tokens
212
+ # enable the kv cache
213
+ model.set_kv_cache(batch_size=1)
214
+ model.eval()
215
+
216
+ if compile:
217
+ torch._dynamo.config.automatic_dynamic_shapes = True
218
+ torch._inductor.config.triton.unique_kernel_names = True
219
+ torch._inductor.config.coordinate_descent_tuning = True
220
+ global next_token
221
+ next_token = torch.compile(next_token, mode="reduce-overhead")
222
+
223
+ model = fabric.setup_module(model)
224
+
225
+ t0 = time.perf_counter()
226
+ load_checkpoint(fabric, model, checkpoint_path)
227
+ fabric.print(
228
+ f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.",
229
+ file=sys.stderr,
230
+ )
231
+
232
+ L.seed_everything(1234)
233
+ responses = []
234
+ for i in range(num_samples):
235
+ t0 = time.perf_counter()
236
+ y = generate(
237
+ model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k
238
+ )
239
+ t = time.perf_counter() - t0
240
+ for block in model.transformer.h:
241
+ block.attn.kv_cache.reset_parameters()
242
+
243
+ fabric.print(tokenizer.decode(y))
244
+ tokens_generated = y.size(0) - prompt_length
245
+ fabric.print(
246
+ f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
247
+ file=sys.stderr,
248
+ )
249
+ responses.append(
250
+ {
251
+ "response": tokenizer.decode(y),
252
+ "latency": f"{round(t, 2)} seconds",
253
+ "generation_rate": f"{round(tokens_generated / t, 2)} tokens per sec",
254
+ }
255
+ )
256
+ if fabric.device.type == "cuda":
257
+ fabric.print(
258
+ f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB",
259
+ file=sys.stderr,
260
+ )
261
+ return responses
262
+
263
+
264
+ if __name__ == "__main__":
265
+ from jsonargparse import CLI
266
+
267
+ torch.set_float32_matmul_precision("high")
268
+ CLI(main)
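
For reference, the decoding rule used by `sample()` in generate/base.py above can be exercised in isolation. A minimal standalone sketch: crop the logits to the top-k entries, scale by 1/temperature, and draw a single token id. The vocabulary size and the top_k/temperature values below are illustrative assumptions, not repository settings.

```python
import torch

vocab_size = 50304
logits = torch.randn(1, 1, vocab_size)   # (batch, time, vocab), as produced by the model
last = logits[0, -1]                     # logits for the final position
top_k, temperature = 200, 0.8

# keep only the k largest logits, mask the rest to -inf
v, i = torch.topk(last, min(top_k, last.size(-1)))
last = torch.full_like(last, float("-inf")).scatter_(-1, i, v)

# temperature-scaled softmax, then sample one token id
probs = torch.softmax(last / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.item())
```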
generate/full.py ADDED
@@ -0,0 +1,137 @@
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Literal, Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ from lightning.fabric.plugins import BitsandbytesPrecision
9
+ from lightning.fabric.strategies import FSDPStrategy
10
+
11
+ # support running without installing as a package
12
+ wd = Path(__file__).parent.parent.resolve()
13
+ sys.path.append(str(wd))
14
+
15
+ from generate.base import generate
16
+ from lit_gpt import GPT, Config, Tokenizer
17
+ from lit_gpt.model import Block
18
+ from lit_gpt.utils import (
19
+ check_valid_checkpoint_dir,
20
+ get_default_supported_precision,
21
+ gptq_quantization,
22
+ load_checkpoint,
23
+ )
24
+ from scripts.prepare_alpaca import generate_prompt
25
+
26
+
27
+ def main(
28
+ prompt: str = "What food do llamas eat?",
29
+ input: str = "",
30
+ finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"),
31
+ checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
32
+ quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
33
+ max_new_tokens: int = 100,
34
+ top_k: Optional[int] = 200,
35
+ temperature: float = 0.8,
36
+ strategy: str = "auto",
37
+ devices: int = 1,
38
+ precision: Optional[str] = None,
39
+ ) -> None:
40
+ """Generates a response based on a given instruction and an optional input.
41
+ This script will only work with checkpoints from the instruction-tuned GPT model.
42
+ See `finetune/full.py`.
43
+
44
+ Args:
45
+ prompt: The prompt/instruction (Alpaca style).
46
+ input: Optional input (Alpaca style).
47
+ finetuned_path: Path to the checkpoint with trained weights, which are the output of
48
+ `finetune/full.py`.
49
+ checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
50
+ quantize: Whether to quantize the model and using which method:
51
+ - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
52
+ - bnb.int8: 8-bit quantization from bitsandbytes
53
+ - gptq.int4: 4-bit quantization from GPTQ
54
+ for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
55
+ max_new_tokens: The number of generation steps to take.
56
+ top_k: The number of top most probable tokens to consider in the sampling process.
57
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
58
+ samples.
59
+ strategy: Indicates the Fabric strategy setting to use.
60
+ devices: How many devices to use.
61
+ precision: Indicates the Fabric precision setting to use.
62
+ """
63
+ precision = precision or get_default_supported_precision(training=False)
64
+
65
+ plugins = None
66
+ if quantize is not None:
67
+ if devices > 1:
68
+ raise NotImplementedError(
69
+ "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
70
+ " --quantize flag."
71
+ )
72
+ if quantize.startswith("bnb."):
73
+ if "mixed" in precision:
74
+ raise ValueError("Quantization and mixed precision is not supported.")
75
+ dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
76
+ plugins = BitsandbytesPrecision(quantize[4:], dtype)
77
+ precision = None
78
+
79
+ if strategy == "fsdp":
80
+ strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
81
+
82
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
83
+ fabric.launch()
84
+
85
+ check_valid_checkpoint_dir(checkpoint_dir)
86
+
87
+ config = Config.from_json(checkpoint_dir / "lit_config.json")
88
+
89
+ if quantize is not None and devices > 1:
90
+ raise NotImplementedError
91
+ checkpoint_path = finetuned_path
92
+
93
+ tokenizer = Tokenizer(checkpoint_dir)
94
+ sample = {"instruction": prompt, "input": input}
95
+ prompt = generate_prompt(sample)
96
+ encoded = tokenizer.encode(prompt, device=fabric.device)
97
+ prompt_length = encoded.size(0)
98
+ max_returned_tokens = prompt_length + max_new_tokens
99
+
100
+ fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
101
+ t0 = time.perf_counter()
102
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
103
+ model = GPT(config)
104
+ fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
105
+ with fabric.init_tensor():
106
+ # set the max_seq_length to limit the memory usage to what we need
107
+ model.max_seq_length = max_returned_tokens
108
+ # enable the kv cache
109
+ model.set_kv_cache(batch_size=1)
110
+ model.eval()
111
+
112
+ model = fabric.setup(model)
113
+
114
+ t0 = time.perf_counter()
115
+ load_checkpoint(fabric, model, checkpoint_path)
116
+ fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
117
+
118
+ L.seed_everything(1234)
119
+ t0 = time.perf_counter()
120
+ y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
121
+ t = time.perf_counter() - t0
122
+
123
+ output = tokenizer.decode(y)
124
+ output = output.split("### Response:")[1].strip()
125
+ fabric.print(output)
126
+
127
+ tokens_generated = y.size(0) - prompt_length
128
+ fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
129
+ if fabric.device.type == "cuda":
130
+ fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
131
+
132
+
133
+ if __name__ == "__main__":
134
+ from jsonargparse import CLI
135
+
136
+ torch.set_float32_matmul_precision("high")
137
+ CLI(main)
generate/lora.py ADDED
@@ -0,0 +1,163 @@
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Literal, Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ from lightning.fabric.plugins import BitsandbytesPrecision
9
+ from lightning.fabric.strategies import FSDPStrategy
10
+
11
+ # support running without installing as a package
12
+ wd = Path(__file__).parent.parent.resolve()
13
+ sys.path.append(str(wd))
14
+
15
+ from generate.base import generate
16
+ from lit_gpt import Tokenizer
17
+ from lit_gpt.lora import GPT, Block, Config, merge_lora_weights
18
+ from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
19
+ from scripts.prepare_alpaca import generate_prompt
20
+
21
+ lora_r = 8
22
+ lora_alpha = 16
23
+ lora_dropout = 0.05
24
+ lora_query = True
25
+ lora_key = False
26
+ lora_value = True
27
+ lora_projection = False
28
+ lora_mlp = False
29
+ lora_head = False
30
+
31
+
32
+ def main(
33
+ prompt: str = "What food do llamas eat?",
34
+ input: str = "",
35
+ lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"),
36
+ checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
37
+ quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
38
+ max_new_tokens: int = 100,
39
+ top_k: Optional[int] = 200,
40
+ temperature: float = 0.8,
41
+ strategy: str = "auto",
42
+ devices: int = 1,
43
+ precision: Optional[str] = None,
44
+ ) -> None:
45
+ """Generates a response based on a given instruction and an optional input.
46
+ This script will only work with checkpoints from the instruction-tuned GPT-LoRA model.
47
+ See `finetune/lora.py`.
48
+
49
+ Args:
50
+ prompt: The prompt/instruction (Alpaca style).
51
+ input: Optional input (Alpaca style).
52
+ lora_path: Path to the checkpoint with trained adapter weights, which are the output of
53
+ `finetune/lora.py`.
54
+ checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
55
+ quantize: Whether to quantize the model and using which method:
56
+ - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
57
+ - bnb.int8: 8-bit quantization from bitsandbytes
58
+ - gptq.int4: 4-bit quantization from GPTQ
59
+ for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
60
+ max_new_tokens: The number of generation steps to take.
61
+ top_k: The number of top most probable tokens to consider in the sampling process.
62
+ temperature: A value controlling the randomness of the sampling process. Higher values result in more random
63
+ samples.
64
+ strategy: Indicates the Fabric strategy setting to use.
65
+ devices: How many devices to use.
66
+ precision: Indicates the Fabric precision setting to use.
67
+ """
68
+ precision = precision or get_default_supported_precision(training=False)
69
+
70
+ plugins = None
71
+ if quantize is not None:
72
+ if devices > 1:
73
+ raise NotImplementedError(
74
+ "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
75
+ " --quantize flag."
76
+ )
77
+ if quantize.startswith("bnb."):
78
+ if "mixed" in precision:
79
+ raise ValueError("Quantization and mixed precision is not supported.")
80
+ dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
81
+ plugins = BitsandbytesPrecision(quantize[4:], dtype)
82
+ precision = None
83
+
84
+ if strategy == "fsdp":
85
+ strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
86
+
87
+ fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
88
+ fabric.launch()
89
+
90
+ check_valid_checkpoint_dir(checkpoint_dir)
91
+
92
+ config = Config.from_json(
93
+ checkpoint_dir / "lit_config.json",
94
+ r=lora_r,
95
+ alpha=lora_alpha,
96
+ dropout=lora_dropout,
97
+ to_query=lora_query,
98
+ to_key=lora_key,
99
+ to_value=lora_value,
100
+ to_projection=lora_projection,
101
+ to_mlp=lora_mlp,
102
+ to_head=lora_head,
103
+ )
104
+
105
+ if quantize is not None and devices > 1:
106
+ raise NotImplementedError
107
+ if quantize == "gptq.int4":
108
+ model_file = "lit_model_gptq.4bit.pth"
109
+ if not (checkpoint_dir / model_file).is_file():
110
+ raise ValueError("Please run `python quantize/gptq.py` first")
111
+ else:
112
+ model_file = "lit_model.pth"
113
+ checkpoint_path = checkpoint_dir / model_file
114
+
115
+ tokenizer = Tokenizer(checkpoint_dir)
116
+ sample = {"instruction": prompt, "input": input}
117
+ prompt = generate_prompt(sample)
118
+ encoded = tokenizer.encode(prompt, device=fabric.device)
119
+ prompt_length = encoded.size(0)
120
+ max_returned_tokens = prompt_length + max_new_tokens
121
+
122
+ fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
123
+ t0 = time.perf_counter()
124
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
125
+ model = GPT(config)
126
+ fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
127
+ with fabric.init_tensor():
128
+ # set the max_seq_length to limit the memory usage to what we need
129
+ model.max_seq_length = max_returned_tokens
130
+ # enable the kv cache
131
+ model.set_kv_cache(batch_size=1)
132
+ model.eval()
133
+
134
+ t0 = time.perf_counter()
135
+ checkpoint = lazy_load(checkpoint_path)
136
+ lora_checkpoint = lazy_load(lora_path)
137
+ checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
138
+ model.load_state_dict(checkpoint)
139
+ fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
140
+
141
+ merge_lora_weights(model)
142
+ model = fabric.setup(model)
143
+
144
+ L.seed_everything(1234)
145
+ t0 = time.perf_counter()
146
+ y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
147
+ t = time.perf_counter() - t0
148
+
149
+ output = tokenizer.decode(y)
150
+ output = output.split("### Response:")[1].strip()
151
+ fabric.print(output)
152
+
153
+ tokens_generated = y.size(0) - prompt_length
154
+ fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
155
+ if fabric.device.type == "cuda":
156
+ fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ from jsonargparse import CLI
161
+
162
+ torch.set_float32_matmul_precision("high")
163
+ CLI(main)
generate_test.ipynb ADDED
@@ -0,0 +1,754 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "True"
12
+ ]
13
+ },
14
+ "execution_count": 1,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "import torch\n",
21
+ "\n",
22
+ "torch.cuda.is_available()"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "import glob\n",
32
+ "import math\n",
33
+ "import sys\n",
34
+ "import time\n",
35
+ "from pathlib import Path\n",
36
+ "from typing import Optional, Tuple, Union\n",
37
+ "\n",
38
+ "import lightning as L\n",
39
+ "import torch\n",
40
+ "from lightning.fabric.loggers import CSVLogger\n",
41
+ "from lightning.fabric.strategies import FSDPStrategy\n",
42
+ "from torch.utils.data import DataLoader\n",
43
+ "\n",
44
+ "# # support running without installing as a package\n",
45
+ "# wd = Path(__file__).parent.parent.resolve()\n",
46
+ "# sys.path.append(str(wd))\n",
47
+ "\n",
48
+ "from tsai_gpt.model import GPT, Block, Config\n",
49
+ "from tsai_gpt.packed_dataset import CombinedDataset, PackedDataset\n",
50
+ "from tsai_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops\n",
51
+ "from tsai_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor\n",
52
+ "from tsai_gpt.utils import (\n",
53
+ " chunked_cross_entropy,\n",
54
+ " get_default_supported_precision,\n",
55
+ " num_parameters,\n",
56
+ " load_checkpoint,\n",
57
+ ")"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "model_name = \"pythia-160m\"\n",
67
+ "name = \"redpajama\"\n",
68
+ "out_dir = Path(\"out\") / name\n",
69
+ "save_interval = 1000\n",
70
+ "eval_interval = 1000\n",
71
+ "eval_iters = 100\n",
72
+ "log_interval = 100"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 4,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "# Hyperparameters\n",
82
+ "learning_rate = 6e-3\n",
83
+ "batch_size = 32\n",
84
+ "micro_batch_size = 8\n",
85
+ "gradient_accumulation_steps = batch_size // micro_batch_size\n",
86
+ "assert gradient_accumulation_steps > 0\n",
87
+ "# max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices\n",
88
+ "max_iters = 15000\n",
89
+ "weight_decay = 1e-1\n",
90
+ "beta1 = 0.9\n",
91
+ "beta2 = 0.95\n",
92
+ "grad_clip = 1.0\n",
93
+ "decay_lr = True\n",
94
+ "warmup_iters = 2000\n",
95
+ "lr_decay_iters = max_iters\n",
96
+ "min_lr = 6e-6"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 5,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "# Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1\n",
106
+ "data_config = [\n",
107
+ " (\"arxiv\", 2.5),\n",
108
+ " (\"book\", 4.5),\n",
109
+ " (\"c4\", 15.0),\n",
110
+ " (\"cc\", 67.0),\n",
111
+ " (\"github\", 4.5),\n",
112
+ " (\"stackexchange\", 2.0),\n",
113
+ " (\"wikipedia\", 4.5),\n",
114
+ "]"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 6,
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "hparams = {\n",
124
+ " k: v\n",
125
+ " for k, v in locals().items()\n",
126
+ " if isinstance(v, (int, float, str)) and not k.startswith(\"_\")\n",
127
+ "}\n",
128
+ "logger = CSVLogger(\"out\", name, flush_logs_every_n_steps=log_interval)\n",
129
+ "\n",
130
+ "\n",
131
+ "def setup(\n",
132
+ " devices: int = 4,\n",
133
+ " train_data_dir: Path = Path(\"data/redpajama_sample\"),\n",
134
+ " val_data_dir: Optional[Path] = None,\n",
135
+ " precision: Optional[str] = None,\n",
136
+ " resume: Union[bool, Path] = False,\n",
137
+ ") -> None:\n",
138
+ " precision = precision or get_default_supported_precision(training=True)\n",
139
+ "\n",
140
+ " if devices > 1:\n",
141
+ " strategy = FSDPStrategy(\n",
142
+ " auto_wrap_policy={Block},\n",
143
+ " activation_checkpointing_policy={Block},\n",
144
+ " state_dict_type=\"full\",\n",
145
+ " limit_all_gathers=True,\n",
146
+ " cpu_offload=False,\n",
147
+ " )\n",
148
+ " else:\n",
149
+ " strategy = \"auto\"\n",
150
+ "\n",
151
+ " fabric = L.Fabric(\n",
152
+ " devices=devices, strategy=strategy, precision=precision, loggers=logger\n",
153
+ " )\n",
154
+ " fabric.print(hparams)\n",
155
+ " fabric.launch(main, train_data_dir, val_data_dir, resume)"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": 7,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "model_copy = None"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 8,
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "def main(\n",
174
+ " fabric: L.Fabric,\n",
175
+ " train_data_dir: Path,\n",
176
+ " val_data_dir: Path,\n",
177
+ " resume: Union[bool, Path],\n",
178
+ ") -> None:\n",
179
+ " global model_copy\n",
180
+ " speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit=\"seconds\")\n",
181
+ "\n",
182
+ " if fabric.global_rank == 0:\n",
183
+ " out_dir.mkdir(parents=True, exist_ok=True)\n",
184
+ "\n",
185
+ " config = Config.from_name(model_name)\n",
186
+ "\n",
187
+ " train_dataloader, val_dataloader = create_dataloaders(\n",
188
+ " batch_size=micro_batch_size,\n",
189
+ " block_size=config.block_size,\n",
190
+ " fabric=fabric,\n",
191
+ " train_data_dir=train_data_dir,\n",
192
+ " val_data_dir=val_data_dir,\n",
193
+ " seed=(1337 + fabric.global_rank),\n",
194
+ " )\n",
195
+ " if val_dataloader is None:\n",
196
+ " train_dataloader = fabric.setup_dataloaders(train_dataloader)\n",
197
+ " else:\n",
198
+ " train_dataloader, val_dataloader = fabric.setup_dataloaders(\n",
199
+ " train_dataloader, val_dataloader\n",
200
+ " )\n",
201
+ "\n",
202
+ " fabric.seed_everything(1337) # same seed for every process to init model (FSDP)\n",
203
+ "\n",
204
+ " fabric.print(f\"Loading model with {config.__dict__}\")\n",
205
+ " t0 = time.perf_counter()\n",
206
+ " import torch\n",
207
+ " import torch.nn as nn\n",
208
+ "\n",
209
+ " def _init_weights(module: nn.Module) -> None:\n",
210
+ " \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`.\"\"\"\n",
211
+ " if isinstance(module, nn.Linear):\n",
212
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
213
+ " if module.bias is not None:\n",
214
+ " torch.nn.init.zeros_(module.bias)\n",
215
+ " elif isinstance(module, nn.Embedding):\n",
216
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
217
+ "\n",
218
+ " with fabric.init_module(empty_init=True):\n",
219
+ " model = GPT(config)\n",
220
+ " model.apply(_init_weights)\n",
221
+ " model.apply(_init_weights)\n",
222
+ "\n",
223
+ " # checkpoint_path = Path(\"out/redpajama/iter-000999-ckpt.pth\")\n",
224
+ "\n",
225
+ " # load_checkpoint(fabric, model, checkpoint_path)\n",
226
+ "\n",
227
+ " # print(model.transformer.h[0].mlp.fc.weight)\n",
228
+ "\n",
229
+ " fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\")\n",
230
+ " fabric.print(f\"Total parameters {num_parameters(model):,}\")\n",
231
+ "\n",
232
+ " model = fabric.setup(model)\n",
233
+ " optimizer = torch.optim.AdamW(\n",
234
+ " model.parameters(),\n",
235
+ " lr=learning_rate,\n",
236
+ " weight_decay=weight_decay,\n",
237
+ " betas=(beta1, beta2),\n",
238
+ " foreach=False,\n",
239
+ " )\n",
240
+ "\n",
241
+ " # model_copy = model\n",
242
+ "\n",
243
+ " optimizer = fabric.setup_optimizers(optimizer)\n",
244
+ "\n",
245
+ " state = {\n",
246
+ " \"model\": model,\n",
247
+ " \"optimizer\": optimizer,\n",
248
+ " \"hparams\": hparams,\n",
249
+ " \"iter_num\": 0,\n",
250
+ " \"step_count\": 0,\n",
251
+ " }\n",
252
+ "\n",
253
+ " if resume is True:\n",
254
+ " resume = max(out_dir.glob(\"*.pth\"), key=lambda p: int(p.name.split(\"-\")[1]))\n",
255
+ " if resume:\n",
256
+ " fabric.print(f\"Resuming training from {resume}\")\n",
257
+ " fabric.load(resume, state)\n",
258
+ "\n",
259
+ " train_time = time.perf_counter()\n",
260
+ " train(fabric, state, train_dataloader, val_dataloader, speed_monitor)\n",
261
+ " fabric.print(f\"Training time: {(time.perf_counter()-train_time):.2f}s\")\n",
262
+ " if fabric.device.type == \"cuda\":\n",
263
+ " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\")"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": 9,
269
+ "metadata": {},
270
+ "outputs": [],
271
+ "source": [
272
+ "def train(\n",
273
+ " fabric: L.Fabric,\n",
274
+ " state: dict,\n",
275
+ " train_dataloader: DataLoader,\n",
276
+ " val_dataloader: DataLoader,\n",
277
+ " speed_monitor: SpeedMonitorBase,\n",
278
+ ") -> None:\n",
279
+ " model = state[\"model\"]\n",
280
+ " optimizer = state[\"optimizer\"]\n",
281
+ "\n",
282
+ " if val_dataloader is not None:\n",
283
+ " validate(fabric, model, val_dataloader) # sanity check\n",
284
+ "\n",
285
+ " with torch.device(\"meta\"):\n",
286
+ " meta_model = GPT(model.config)\n",
287
+ " # \"estimated\" is not as precise as \"measured\". Estimated is optimistic but widely used in the wild.\n",
288
+ " # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,\n",
289
+ " # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead\n",
290
+ " estimated_flops = estimate_flops(meta_model) * micro_batch_size\n",
291
+ " fabric.print(\n",
292
+ " f\"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}\"\n",
293
+ " )\n",
294
+ " x = torch.randint(0, 1, (micro_batch_size, model.max_seq_length))\n",
295
+ " measured_flops = measure_flops(meta_model, x)\n",
296
+ " fabric.print(\n",
297
+ " f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\"\n",
298
+ " )\n",
299
+ " del meta_model, x\n",
300
+ "\n",
301
+ " total_lengths = 0\n",
302
+ " total_t0 = time.perf_counter()\n",
303
+ "\n",
304
+ " for state[\"iter_num\"], train_data in enumerate(train_dataloader, state[\"iter_num\"]):\n",
305
+ " if state[\"iter_num\"] >= max_iters:\n",
306
+ " checkpoint_path = out_dir / f\"iter-{state['iter_num']:06d}-ckpt.pth\"\n",
307
+ " fabric.print(f\"Saving checkpoint to {str(checkpoint_path)!r}\")\n",
308
+ " fabric.save(checkpoint_path, state)\n",
309
+ " break\n",
310
+ "\n",
311
+ " # determine and set the learning rate for this iteration\n",
312
+ " lr = get_lr(state[\"iter_num\"]) if decay_lr else learning_rate\n",
313
+ " for param_group in optimizer.param_groups:\n",
314
+ " param_group[\"lr\"] = lr\n",
315
+ "\n",
316
+ " iter_t0 = time.perf_counter()\n",
317
+ "\n",
318
+ " input_ids = train_data[:, 0 : model.max_seq_length].contiguous()\n",
319
+ " targets = train_data[:, 1 : model.max_seq_length + 1].contiguous()\n",
320
+ "\n",
321
+ " is_accumulating = (state[\"iter_num\"] + 1) % gradient_accumulation_steps != 0\n",
322
+ " with fabric.no_backward_sync(model, enabled=is_accumulating):\n",
323
+ " logits = model(input_ids)\n",
324
+ " loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n",
325
+ " fabric.backward(loss / gradient_accumulation_steps)\n",
326
+ "\n",
327
+ " # return\n",
328
+ "\n",
329
+ " if not is_accumulating:\n",
330
+ " fabric.clip_gradients(model, optimizer, max_norm=grad_clip)\n",
331
+ " optimizer.step()\n",
332
+ " optimizer.zero_grad()\n",
333
+ " state[\"step_count\"] += 1\n",
334
+ "\n",
335
+ " t1 = time.perf_counter()\n",
336
+ " total_lengths += input_ids.size(1)\n",
337
+ " speed_monitor.on_train_batch_end(\n",
338
+ " (state[\"iter_num\"] + 1) * micro_batch_size,\n",
339
+ " t1 - total_t0,\n",
340
+ " # this assumes that device FLOPs are the same and that all devices have the same batch size\n",
341
+ " fabric.world_size,\n",
342
+ " flops_per_batch=measured_flops,\n",
343
+ " lengths=total_lengths,\n",
344
+ " )\n",
345
+ " if state[\"iter_num\"] % log_interval == 0:\n",
346
+ " fabric.print(\n",
347
+ " f\"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, LR: {lr:.6f}, iter time:\"\n",
348
+ " f\" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}\"\n",
349
+ " )\n",
350
+ "\n",
351
+ " if (\n",
352
+ " val_dataloader is not None\n",
353
+ " and not is_accumulating\n",
354
+ " and state[\"step_count\"] % eval_interval == 0\n",
355
+ " ):\n",
356
+ " t0 = time.perf_counter()\n",
357
+ " val_loss = validate(fabric, model, val_dataloader)\n",
358
+ " t1 = time.perf_counter() - t0\n",
359
+ " speed_monitor.eval_end(t1)\n",
360
+ " fabric.print(\n",
361
+ " f\"step {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms\"\n",
362
+ " )\n",
363
+ " fabric.barrier()\n",
364
+ " if not is_accumulating and state[\"step_count\"] % save_interval == 0:\n",
365
+ " checkpoint_path = out_dir / f\"iter-{state['iter_num']:06d}-ckpt.pth\"\n",
366
+ " fabric.print(f\"Saving checkpoint to {str(checkpoint_path)!r}\")\n",
367
+ " fabric.save(checkpoint_path, state)"
368
+ ]
369
+ },
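The training cell above divides each micro-batch loss by gradient_accumulation_steps and skips the gradient all-reduce on accumulation steps via fabric.no_backward_sync. A minimal plain-PyTorch sketch of the same accumulation pattern, with hypothetical model, optimizer, loader, grad_accum and grad_clip names and no Fabric involved:

    import torch

    def train_with_accumulation(model, optimizer, loader, grad_accum=4, grad_clip=1.0):
        # Accumulate gradients over `grad_accum` micro-batches before one optimizer
        # step, mirroring the `is_accumulating` test in the cell above.
        for i, (inputs, targets) in enumerate(loader):
            logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
            (loss / grad_accum).backward()      # scale so the accumulated grads average out
            if (i + 1) % grad_accum == 0:       # same condition as `not is_accumulating`
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                optimizer.zero_grad()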
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 10,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "@torch.inference_mode()\n",
377
+ "def validate(\n",
378
+ " fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader\n",
379
+ ") -> torch.Tensor:\n",
380
+ " fabric.print(\"Validating ...\")\n",
381
+ " model.eval()\n",
382
+ "\n",
383
+ " losses = torch.zeros(eval_iters, device=fabric.device)\n",
384
+ " for k, val_data in enumerate(val_dataloader):\n",
385
+ " input_ids = val_data[:, 0 : model.max_seq_length].contiguous()\n",
386
+ " targets = val_data[:, 1 : model.max_seq_length + 1].contiguous()\n",
387
+ " logits = model(input_ids)\n",
388
+ " losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0)\n",
389
+ " out = losses.mean()\n",
390
+ "\n",
391
+ " model.train()\n",
392
+ " return out"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": 11,
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "def create_dataloader(\n",
402
+ " batch_size: int,\n",
403
+ " block_size: int,\n",
404
+ " data_dir: Path,\n",
405
+ " fabric: L.Fabric,\n",
406
+ " shuffle: bool = True,\n",
407
+ " seed: int = 12345,\n",
408
+ ") -> DataLoader:\n",
409
+ " datasets = []\n",
410
+ " for prefix, _ in data_config:\n",
411
+ " filenames = glob.glob(str(data_dir / f\"{prefix}*\"))\n",
412
+ " dataset = PackedDataset(\n",
413
+ " filenames,\n",
414
+ " n_chunks=4,\n",
415
+ " block_size=block_size,\n",
416
+ " shuffle=shuffle,\n",
417
+ " seed=seed,\n",
418
+ " num_processes=fabric.world_size,\n",
419
+ " process_rank=fabric.global_rank,\n",
420
+ " )\n",
421
+ " datasets.append(dataset)\n",
422
+ "\n",
423
+ " if not datasets:\n",
424
+ " raise RuntimeError(\n",
425
+ " f\"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset.\"\n",
426
+ " )\n",
427
+ "\n",
428
+ " weights = [weight for _, weight in data_config]\n",
429
+ " sum_weights = sum(weights)\n",
430
+ " weights = [el / sum_weights for el in weights]\n",
431
+ "\n",
432
+ " combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)\n",
433
+ "\n",
434
+ " return DataLoader(\n",
435
+ " combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True\n",
436
+ " )"
437
+ ]
438
+ },
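create_dataloader above normalizes the per-source weights taken from data_config before passing them to CombinedDataset, so each training chunk is drawn from a source in proportion to its weight. A rough standard-library sketch of that weighted choice (the (prefix, weight) pairs are illustrative, not the actual RedPajama mix):

    import random

    data_config = [("arxiv", 2.5), ("book", 4.5), ("c4", 15.0)]   # hypothetical weights
    rng = random.Random(12345)

    prefixes = [prefix for prefix, _ in data_config]
    weights = [weight for _, weight in data_config]
    probs = [w / sum(weights) for w in weights]                   # normalize to sum to 1

    # the source that the next packed chunk would be sampled from
    next_source = rng.choices(prefixes, weights=probs, k=1)[0]
    print(next_source)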
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 12,
442
+ "metadata": {},
443
+ "outputs": [],
444
+ "source": [
445
+ "def create_dataloaders(\n",
446
+ " batch_size: int,\n",
447
+ " block_size: int,\n",
448
+ " fabric: L.Fabric,\n",
449
+ " train_data_dir: Path = Path(\"data/redpajama_sample\"),\n",
450
+ " val_data_dir: Optional[Path] = None,\n",
451
+ " seed: int = 12345,\n",
452
+ ") -> Tuple[DataLoader, DataLoader]:\n",
453
+ " # Increase by one because we need the next word as well\n",
454
+ " effective_block_size = block_size + 1\n",
455
+ " train_dataloader = create_dataloader(\n",
456
+ " batch_size=batch_size,\n",
457
+ " block_size=effective_block_size,\n",
458
+ " fabric=fabric,\n",
459
+ " data_dir=train_data_dir,\n",
460
+ " shuffle=True,\n",
461
+ " seed=seed,\n",
462
+ " )\n",
463
+ " val_dataloader = (\n",
464
+ " create_dataloader(\n",
465
+ " batch_size=batch_size,\n",
466
+ " block_size=effective_block_size,\n",
467
+ " fabric=fabric,\n",
468
+ " data_dir=val_data_dir,\n",
469
+ " shuffle=False,\n",
470
+ " seed=seed,\n",
471
+ " )\n",
472
+ " if val_data_dir\n",
473
+ " else None\n",
474
+ " )\n",
475
+ " return train_dataloader, val_dataloader"
476
+ ]
477
+ },
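create_dataloaders requests block_size + 1 tokens per row because the training cell slices the same row into inputs ([:block_size]) and next-token targets ([1:block_size + 1]). A tiny self-contained illustration of that shift with made-up token ids:

    import torch

    block_size = 4
    row = torch.tensor([[10, 11, 12, 13, 14]])            # block_size + 1 tokens
    input_ids = row[:, 0:block_size].contiguous()         # tensor([[10, 11, 12, 13]])
    targets = row[:, 1:block_size + 1].contiguous()       # tensor([[11, 12, 13, 14]])
    # position t in input_ids is trained to predict row[t + 1]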
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": 13,
481
+ "metadata": {},
482
+ "outputs": [],
483
+ "source": [
484
+ "def get_lr(it: int) -> float:\n",
485
+ " # 1) linear warmup for warmup_iters steps\n",
486
+ " if it < warmup_iters:\n",
487
+ " return learning_rate * it / warmup_iters\n",
488
+ " # 2) if it > lr_decay_iters, return min learning rate\n",
489
+ " if it > lr_decay_iters:\n",
490
+ " return min_lr\n",
491
+ " # 3) in between, use cosine decay down to min learning rate\n",
492
+ " decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n",
493
+ " assert 0 <= decay_ratio <= 1\n",
494
+ " coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1\n",
495
+ " return min_lr + coeff * (learning_rate - min_lr)"
496
+ ]
497
+ },
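The schedule above is a linear warmup followed by a cosine decay down to min_lr. A standalone copy with hypothetical hyperparameters (illustrative values, not the ones used for this run) shows the expected shape:

    import math

    learning_rate, min_lr = 6e-4, 6e-5
    warmup_iters, lr_decay_iters = 100, 1000

    def lr_at(it: int) -> float:
        if it < warmup_iters:                             # 1) linear warmup
            return learning_rate * it / warmup_iters
        if it > lr_decay_iters:                           # 2) floor at min_lr
            return min_lr
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (learning_rate - min_lr)  # 3) cosine decay

    print(lr_at(50))     # 3.0e-4: halfway through warmup
    print(lr_at(550))    # 3.3e-4: cosine midpoint, min_lr + 0.5 * (learning_rate - min_lr)
    print(lr_at(2000))   # 6.0e-5: past lr_decay_iters, clamped to min_lr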
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": 16,
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "# torch.set_float32_matmul_precision(\"medium\")\n",
505
+ "# setup(devices=1, train_data_dir=Path(\"data/lit-redpajama-sample\"))"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": 5,
511
+ "metadata": {},
512
+ "outputs": [],
513
+ "source": [
514
+ "from generate.base import main\n",
515
+ "from pathlib import Path"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": 7,
521
+ "metadata": {},
522
+ "outputs": [
523
+ {
524
+ "name": "stderr",
525
+ "output_type": "stream",
526
+ "text": [
527
+ "Loading model 'out/redpajama/lit_model.pth' with {'name': 'pythia-160m', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-160m'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 12, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 3072, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 64, 'rope_n_elem': 16}\n",
528
+ "Time to instantiate model: 0.17 seconds.\n"
529
+ ]
530
+ },
531
+ {
532
+ "name": "stderr",
533
+ "output_type": "stream",
534
+ "text": [
535
+ "Time to load the model weights: 0.50 seconds.\n",
536
+ "Seed set to 1234\n"
537
+ ]
538
+ },
539
+ {
540
+ "name": "stdout",
541
+ "output_type": "stream",
542
+ "text": [
543
+ "Earth is a planet with rocky core and 100,000 hectares of natural Earth. Our planet is a planet with rocky core and 100,000 hectares of natural Earth. The sun has a warm, warm surface and the sun has a\n"
544
+ ]
545
+ },
546
+ {
547
+ "name": "stderr",
548
+ "output_type": "stream",
549
+ "text": [
550
+ "Time for inference 1: 0.71 sec total, 70.90 tokens/sec\n",
551
+ "Memory used: 0.35 GB\n"
552
+ ]
553
+ }
554
+ ],
555
+ "source": [
556
+ "import torch\n",
557
+ "\n",
558
+ "torch.set_float32_matmul_precision(\"high\")\n",
559
+ "main(\n",
560
+ " prompt=\"Earth is a planet with rocky core and \",\n",
561
+ " checkpoint_dir=Path(\"out/redpajama\"),\n",
562
+ ")"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": 12,
568
+ "metadata": {},
569
+ "outputs": [
570
+ {
571
+ "name": "stderr",
572
+ "output_type": "stream",
573
+ "text": [
574
+ "Loading model 'out/redpajama/lit_model.pth' with {'name': 'pythia-160m', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-160m'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 12, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 3072, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 64, 'rope_n_elem': 16}\n",
575
+ "Time to instantiate model: 0.02 seconds.\n"
576
+ ]
577
+ },
578
+ {
579
+ "name": "stderr",
580
+ "output_type": "stream",
581
+ "text": [
582
+ "Time to load the model weights: 0.49 seconds.\n",
583
+ "Seed set to 1234\n"
584
+ ]
585
+ },
586
+ {
587
+ "name": "stdout",
588
+ "output_type": "stream",
589
+ "text": [
590
+ "I like to drive when it is raining outside and 100% of the time. The next day, I think you will see the right movement.\n",
591
+ "We already know that if you don't go to the center, you can be a hug, or a bit more vigor.\n"
592
+ ]
593
+ },
594
+ {
595
+ "name": "stderr",
596
+ "output_type": "stream",
597
+ "text": [
598
+ "Time for inference 1: 0.69 sec total, 72.80 tokens/sec\n",
599
+ "Memory used: 0.35 GB\n"
600
+ ]
601
+ }
602
+ ],
603
+ "source": [
604
+ "main(\n",
605
+ " prompt=\"I like to drive when it is raining outside and \",\n",
606
+ " checkpoint_dir=Path(\"out/redpajama\"),\n",
607
+ ")"
608
+ ]
609
+ },
610
+ {
611
+ "cell_type": "code",
612
+ "execution_count": 13,
613
+ "metadata": {},
614
+ "outputs": [
615
+ {
616
+ "name": "stderr",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "Loading model 'out/redpajama/lit_model.pth' with {'name': 'pythia-160m', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-160m'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 12, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 3072, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 64, 'rope_n_elem': 16}\n",
620
+ "Time to instantiate model: 0.02 seconds.\n",
621
+ "Time to load the model weights: 0.51 seconds.\n",
622
+ "Seed set to 1234\n"
623
+ ]
624
+ },
625
+ {
626
+ "name": "stdout",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "I like to drive when it is raining outside and 100% of the time. The next day, I think you will see the right movement.\n",
630
+ "We already know that if you don't go to the center, you can be a hug, or a bit more vigor.\n"
631
+ ]
632
+ },
633
+ {
634
+ "name": "stderr",
635
+ "output_type": "stream",
636
+ "text": [
637
+ "Time for inference 1: 0.65 sec total, 76.96 tokens/sec\n",
638
+ "Memory used: 0.35 GB\n"
639
+ ]
640
+ }
641
+ ],
642
+ "source": [
643
+ "main(\n",
644
+ " prompt=\"I like to drive when it is raining outside and \",\n",
645
+ " checkpoint_dir=Path(\"out/redpajama\"),\n",
646
+ ")"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": 10,
652
+ "metadata": {},
653
+ "outputs": [
654
+ {
655
+ "name": "stderr",
656
+ "output_type": "stream",
657
+ "text": [
658
+ "Loading model 'out/redpajama/lit_model.pth' with {'name': 'pythia-160m', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-160m'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 12, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 3072, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 64, 'rope_n_elem': 16}\n",
659
+ "Time to instantiate model: 0.02 seconds.\n",
660
+ "Time to load the model weights: 0.49 seconds.\n",
661
+ "Seed set to 1234\n"
662
+ ]
663
+ },
664
+ {
665
+ "name": "stdout",
666
+ "output_type": "stream",
667
+ "text": [
668
+ "What a beautiful day it was, never imagined I would be able to 100,000 times a month. It was the beginning of a carpet, and was about 15 minutes to drain from the carpet. We were so overwhelmed, ready to do the kits,\n"
669
+ ]
670
+ },
671
+ {
672
+ "name": "stderr",
673
+ "output_type": "stream",
674
+ "text": [
675
+ "Time for inference 1: 0.68 sec total, 73.18 tokens/sec\n",
676
+ "Memory used: 0.35 GB\n"
677
+ ]
678
+ }
679
+ ],
680
+ "source": [
681
+ "main(\n",
682
+ " prompt=\"What a beautiful day it was, never imagined I would be able to \",\n",
683
+ " checkpoint_dir=Path(\"out/redpajama\"),\n",
684
+ ")"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 11,
690
+ "metadata": {},
691
+ "outputs": [
692
+ {
693
+ "name": "stderr",
694
+ "output_type": "stream",
695
+ "text": [
696
+ "Loading model 'out/redpajama/lit_model.pth' with {'name': 'pythia-160m', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-160m'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 12, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 3072, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 64, 'rope_n_elem': 16}\n",
697
+ "Time to instantiate model: 0.02 seconds.\n",
698
+ "Time to load the model weights: 0.49 seconds.\n",
699
+ "Seed set to 1234\n"
700
+ ]
701
+ },
702
+ {
703
+ "name": "stdout",
704
+ "output_type": "stream",
705
+ "text": [
706
+ "Do you think Einstein was the greatest ever physicist ever lived? I think 1 of the 1980s wrote a very deep, poetic narration of my life. I know all of you and your life is beautiful, especially in the sense of storytelling. You are. I know all of you\n"
707
+ ]
708
+ },
709
+ {
710
+ "name": "stderr",
711
+ "output_type": "stream",
712
+ "text": [
713
+ "Time for inference 1: 0.68 sec total, 74.07 tokens/sec\n",
714
+ "Memory used: 0.35 GB\n"
715
+ ]
716
+ }
717
+ ],
718
+ "source": [
719
+ "main(\n",
720
+ " prompt=\"Do you think Einstein was the greatest ever physicist ever lived? I think \",\n",
721
+ " checkpoint_dir=Path(\"out/redpajama\"),\n",
722
+ ")"
723
+ ]
724
+ },
725
+ {
726
+ "cell_type": "code",
727
+ "execution_count": null,
728
+ "metadata": {},
729
+ "outputs": [],
730
+ "source": []
731
+ }
732
+ ],
733
+ "metadata": {
734
+ "kernelspec": {
735
+ "display_name": "base",
736
+ "language": "python",
737
+ "name": "python3"
738
+ },
739
+ "language_info": {
740
+ "codemirror_mode": {
741
+ "name": "ipython",
742
+ "version": 3
743
+ },
744
+ "file_extension": ".py",
745
+ "mimetype": "text/x-python",
746
+ "name": "python",
747
+ "nbconvert_exporter": "python",
748
+ "pygments_lexer": "ipython3",
749
+ "version": "3.10.13"
750
+ }
751
+ },
752
+ "nbformat": 4,
753
+ "nbformat_minor": 2
754
+ }
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "do_sample": true,
4
+ "eos_token_id": 2,
5
+ "max_length": 4096,
6
+ "pad_token_id": 0,
7
+ "temperature": 0.6,
8
+ "top_p": 0.9,
9
+ "transformers_version": "4.32.0.dev0"
10
+ }
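generation_config.json enables sampling with temperature 0.6 and top_p 0.9. A minimal sketch of how such settings are commonly applied to a logits vector (a generic illustration of temperature plus nucleus sampling, not the generation code used by this repository):

    import torch

    def sample_next_token(logits: torch.Tensor, temperature: float = 0.6, top_p: float = 0.9) -> int:
        probs = torch.softmax(logits / temperature, dim=-1)       # temperature < 1 sharpens
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # nucleus filter: keep the smallest prefix of tokens whose mass reaches top_p
        keep = cumulative - sorted_probs < top_p                   # always keeps the top token
        kept = sorted_probs * keep
        kept = kept / kept.sum()                                   # renormalize
        choice = torch.multinomial(kept, num_samples=1)
        return sorted_idx[choice].item()

    # toy vocabulary of five tokens
    print(sample_next_token(torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])))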
lit_config.json ADDED
@@ -0,0 +1 @@
1
+ {"name": "Llama-2-7b-chat-hf", "hf_config": {"org": "meta-llama", "name": "Llama-2-7b-chat-hf"}, "block_size": 4096, "vocab_size": 32000, "padding_multiple": 64, "padded_vocab_size": 32000, "n_layer": 32, "n_head": 32, "n_embd": 4096, "rotary_percentage": 1.0, "parallel_residual": false, "bias": false, "lm_head_bias": false, "n_query_groups": 32, "shared_attention_norm": false, "_norm_class": "RMSNorm", "norm_eps": 1e-05, "_mlp_class": "LLaMAMLP", "gelu_approximate": "none", "intermediate_size": 11008, "rope_condense_ratio": 1, "rope_base": 10000}
lit_gpt/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ import re
2
+ import logging
3
+
4
+ from lit_gpt.model import GPT
5
+ from lit_gpt.config import Config
6
+ from lit_gpt.tokenizer import Tokenizer
7
+
8
+ from lightning_utilities.core.imports import RequirementCache
9
+
10
+ _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.2.0.dev0")
11
+ if not bool(_LIGHTNING_AVAILABLE):
12
+ raise ImportError(
13
+ "Lit-GPT requires lightning nightly. Please run:\n"
14
+ f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
15
+ )
16
+
17
+ # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
18
+ pattern = re.compile(".*Profiler function .* will be ignored")
19
+ logging.getLogger("torch._dynamo.variables.torch").addFilter(lambda record: not pattern.search(record.getMessage()))
20
+
21
+
22
+ __all__ = ["GPT", "Config", "Tokenizer"]
lit_gpt/adapter.py ADDED
@@ -0,0 +1,165 @@
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
4
+ https://arxiv.org/abs/2303.16199
5
+
6
+ Port for Lit-GPT
7
+ """
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from typing_extensions import Self
14
+
15
+ from lit_gpt.config import Config as BaseConfig
16
+ from lit_gpt.model import GPT as BaseModel
17
+ from lit_gpt.model import Block as BaseBlock
18
+ from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
19
+
20
+
21
+ @dataclass
22
+ class Config(BaseConfig):
23
+ adapter_prompt_length: int = 10
24
+ adapter_start_layer: int = 2
25
+
26
+
27
+ class GPT(BaseModel):
28
+ """The implementation is identical to `lit_gpt.model.GPT` with the exception that
29
+ the `Block` saves the layer index and passes it down to the attention layer."""
30
+
31
+ def __init__(self, config: Config) -> None:
32
+ nn.Module.__init__(self)
33
+ assert config.padded_vocab_size is not None
34
+ self.config = config
35
+
36
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
37
+ self.transformer = nn.ModuleDict(
38
+ dict(
39
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
40
+ h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
41
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
42
+ )
43
+ )
44
+ self.max_seq_length = self.config.block_size
45
+ self.mask_cache: Optional[torch.Tensor] = None
46
+
47
+ def forward(
48
+ self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
49
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
50
+ T = idx.size(1)
51
+ if self.max_seq_length < T:
52
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
53
+
54
+ if input_pos is not None: # use the kv cache
55
+ cos = self.cos.index_select(0, input_pos)
56
+ sin = self.sin.index_select(0, input_pos)
57
+ if self.mask_cache is None:
58
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
59
+ mask = self.mask_cache.index_select(2, input_pos)
60
+ else:
61
+ cos = self.cos[:T]
62
+ sin = self.sin[:T]
63
+ mask = None
64
+
65
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
66
+ for block in self.transformer.h:
67
+ x = block(x, cos, sin, mask, input_pos)
68
+ x = self.transformer.ln_f(x)
69
+ if lm_head_chunk_size > 0:
70
+ # chunk the lm head logits to reduce the peak memory used by autograd
71
+ return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
72
+ return self.lm_head(x) # (b, t, vocab_size)
73
+
74
+ @classmethod
75
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
76
+ return cls(Config.from_name(name, **kwargs))
77
+
78
+ def _init_weights(self, module: nn.Module) -> None:
79
+ """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
80
+ super()._init_weights(module)
81
+ if isinstance(module, CausalSelfAttention):
82
+ module.reset_parameters()
83
+
84
+
85
+ class Block(BaseBlock):
86
+ """The implementation is identical to `lit_gpt.model.Block` with the exception that
87
+ we replace the attention layer where adaption is implemented."""
88
+
89
+ def __init__(self, config: Config, block_idx: int) -> None:
90
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
91
+ nn.Module.__init__(self)
92
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
93
+ self.attn = CausalSelfAttention(config, block_idx)
94
+ if not config.shared_attention_norm:
95
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
96
+ self.mlp = config.mlp_class(config)
97
+
98
+ self.config = config
99
+
100
+
101
+ class CausalSelfAttention(BaseCausalSelfAttention):
102
+ """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
103
+ over the adaption prompt."""
104
+
105
+ def __init__(self, config: Config, block_idx: int) -> None:
106
+ super().__init__(config)
107
+ if block_idx >= config.adapter_start_layer:
108
+ # adapter embedding layer
109
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
110
+ # gate for adaption
111
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
112
+ # kv cache for inference
113
+ self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
114
+ self.block_idx = block_idx
115
+
116
+ def scaled_dot_product_attention(
117
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
118
+ ) -> torch.Tensor:
119
+ y = super().scaled_dot_product_attention(q, k, v, mask)
120
+ if self.block_idx < self.config.adapter_start_layer:
121
+ return y
122
+
123
+ aT = self.config.adapter_prompt_length
124
+ if self.adapter_kv_cache is not None:
125
+ # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
126
+ # are the same every call
127
+ ak, av = self.adapter_kv_cache
128
+ else:
129
+ prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
130
+ aqkv = self.attn(prefix)
131
+ q_per_kv = self.config.n_head // self.config.n_query_groups
132
+ aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
133
+ aqkv = aqkv.permute(0, 2, 3, 1, 4)
134
+ _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
135
+ if self.config.n_query_groups != 1:
136
+ # for MHA this is a no-op
137
+ ak = ak.repeat_interleave(q_per_kv, dim=2)
138
+ av = av.repeat_interleave(q_per_kv, dim=2)
139
+ ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
140
+ av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
141
+ self.adapter_kv_cache = (ak, av)
142
+
143
+ T = q.size(2)
144
+ amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
145
+ ay = super().scaled_dot_product_attention(q, ak, av, amask)
146
+ return y + self.gating_factor * ay
147
+
148
+ def reset_parameters(self) -> None:
149
+ torch.nn.init.zeros_(self.gating_factor)
150
+
151
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
152
+ """For compatibility with older checkpoints."""
153
+ if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
154
+ state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
155
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
156
+
157
+
158
+ def mark_only_adapter_as_trainable(model: GPT) -> None:
159
+ """Sets `requires_grad=False` for all non-adapter weights."""
160
+ for name, param in model.named_parameters():
161
+ param.requires_grad = adapter_filter(name, param)
162
+
163
+
164
+ def adapter_filter(key: str, value: Any) -> bool:
165
+ return "adapter_wte" in key or "gating_factor" in key
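The adapter attention above attends over a learned prompt prefix and adds that result through gating_factor, which is zero-initialized so the module starts out behaving exactly like the frozen base model. A stripped-down, single-head sketch of that gating idea (not the actual Lit-GPT module; shapes and names are simplified):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedPrefixAttention(nn.Module):
        """Base attention output plus a zero-initialized gate times attention over a learned prefix."""

        def __init__(self, dim: int, prefix_len: int = 10):
            super().__init__()
            self.prefix = nn.Parameter(torch.randn(prefix_len, dim) * 0.02)
            self.gate = nn.Parameter(torch.zeros(1))   # zero init: no effect at step 0

        def forward(self, q, k, v):
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            ak = av = self.prefix.expand(q.size(0), -1, -1)        # (batch, prefix_len, dim)
            ay = F.scaled_dot_product_attention(q, ak, av)         # attend over the prefix
            return y + self.gate * ay

    attn = GatedPrefixAttention(dim=64)
    q = k = v = torch.randn(2, 16, 64)
    print(attn(q, k, v).shape)    # torch.Size([2, 16, 64])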
lit_gpt/adapter_v2.py ADDED
@@ -0,0 +1,197 @@
1
+ """Implementation of the paper:
2
+
3
+ LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
4
+ https://arxiv.org/abs/2304.15010
5
+
6
+ Port for Lit-GPT
7
+ """
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, Optional, Tuple, Type
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from typing_extensions import Self
14
+
15
+ import lit_gpt
16
+ from lit_gpt.adapter import GPT as BaseModel
17
+ from lit_gpt.adapter import Block as BaseBlock
18
+ from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
19
+ from lit_gpt.adapter import Config as BaseConfig
20
+ from lit_gpt.model import KVCache
21
+ from lit_gpt.utils import map_old_state_dict_weights
22
+
23
+
24
+ @dataclass
25
+ class Config(BaseConfig):
26
+ @property
27
+ def mlp_class(self) -> Type:
28
+ return getattr(lit_gpt.adapter_v2, self._mlp_class)
29
+
30
+
31
+ def adapter_filter(key: str, value: Any) -> bool:
32
+ adapter_substrings = (
33
+ # regular adapter v1 parameters
34
+ "adapter_wte",
35
+ "gating_factor",
36
+ # adapter v2: new bias and scale used in Linear
37
+ "adapter_scale",
38
+ "adapter_bias",
39
+ # adapter v2: Norm parameters are now trainable
40
+ "norm_1",
41
+ "norm_2",
42
+ "ln_f",
43
+ )
44
+ return any(s in key for s in adapter_substrings)
45
+
46
+
47
+ class AdapterV2Linear(torch.nn.Module):
48
+ def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
49
+ super().__init__()
50
+ self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
51
+ self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
52
+ self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ return self.adapter_scale * (self.linear(x) + self.adapter_bias)
56
+
57
+ def reset_parameters(self) -> None:
58
+ nn.init.zeros_(self.adapter_bias)
59
+ nn.init.ones_(self.adapter_scale)
60
+
61
+
62
+ class GPT(BaseModel):
63
+ def __init__(self, config: Config) -> None:
64
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
65
+ nn.Module.__init__(self)
66
+ assert config.padded_vocab_size is not None
67
+ self.config = config
68
+
69
+ self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
70
+ self.transformer = nn.ModuleDict(
71
+ dict(
72
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
73
+ h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
74
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
75
+ )
76
+ )
77
+ self.max_seq_length = self.config.block_size
78
+ self.mask_cache: Optional[torch.Tensor] = None
79
+
80
+ @classmethod
81
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
82
+ return cls(Config.from_name(name, **kwargs))
83
+
84
+ def _init_weights(self, module: nn.Module) -> None:
85
+ """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
86
+ super()._init_weights(module)
87
+ if isinstance(module, AdapterV2Linear):
88
+ module.reset_parameters()
89
+
90
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
91
+ """For compatibility with base checkpoints."""
92
+ mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
93
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
94
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
95
+
96
+
97
+ class Block(BaseBlock):
98
+ """The implementation is identical to `lit_gpt.model.Block` with the exception that
99
+ we replace the attention layer where adaption is implemented."""
100
+
101
+ def __init__(self, config: Config, block_idx: int) -> None:
102
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
103
+ nn.Module.__init__(self)
104
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
105
+ self.attn = CausalSelfAttention(config, block_idx)
106
+ if not config.shared_attention_norm:
107
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
108
+ self.mlp = config.mlp_class(config)
109
+
110
+ self.config = config
111
+
112
+
113
+ class CausalSelfAttention(BaseCausalSelfAttention):
114
+ """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
115
+
116
+ def __init__(self, config: Config, block_idx: int) -> None:
117
+ # Skip the parent class __init__ altogether and replace it to avoid useless allocations
118
+ nn.Module.__init__(self)
119
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
120
+ # key, query, value projections for all heads, but in a batch
121
+ self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
122
+ # output projection
123
+ self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
124
+ # disabled by default
125
+ self.kv_cache: Optional[KVCache] = None
126
+
127
+ if block_idx >= config.adapter_start_layer:
128
+ # adapter embedding layer
129
+ self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
130
+ # gate for adaption
131
+ self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
132
+ # kv cache for inference
133
+ self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
134
+ self.block_idx = block_idx
135
+
136
+ self.config = config
137
+
138
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
139
+ """For compatibility with base checkpoints."""
140
+ mapping = {
141
+ "attn.weight": "attn.linear.weight",
142
+ "attn.bias": "attn.linear.bias",
143
+ "proj.weight": "proj.linear.weight",
144
+ "proj.bias": "proj.linear.bias",
145
+ }
146
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
147
+ # For compatibility with older checkpoints
148
+ if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
149
+ state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
150
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
151
+
152
+
153
+ class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
154
+ def __init__(self, config: Config) -> None:
155
+ nn.Module.__init__(self)
156
+ self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
157
+ self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
158
+
159
+ self.config = config
160
+
161
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
162
+ """For compatibility with base checkpoints."""
163
+ mapping = {
164
+ "fc.weight": "fc.linear.weight",
165
+ "fc.bias": "fc.linear.bias",
166
+ "proj.weight": "proj.linear.weight",
167
+ "proj.bias": "proj.linear.bias",
168
+ }
169
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
170
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
171
+
172
+
173
+ class LLaMAMLP(lit_gpt.model.LLaMAMLP):
174
+ def __init__(self, config: Config) -> None:
175
+ nn.Module.__init__(self)
176
+ self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
177
+ self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
178
+ self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
179
+
180
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
181
+ """For compatibility with base checkpoints."""
182
+ mapping = {
183
+ "fc_1.weight": "fc_1.linear.weight",
184
+ "fc_1.bias": "fc_1.linear.bias",
185
+ "fc_2.weight": "fc_2.linear.weight",
186
+ "fc_2.bias": "fc_2.linear.bias",
187
+ "proj.weight": "proj.linear.weight",
188
+ "proj.bias": "proj.linear.bias",
189
+ }
190
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
191
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
192
+
193
+
194
+ def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
195
+ """Sets requires_grad=False for all non-adapter weights"""
196
+ for name, param in model.named_parameters():
197
+ param.requires_grad = adapter_filter(name, param)
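mark_only_adapter_v2_as_trainable freezes everything except the parameters matched by adapter_filter (adapter embeddings, gating factors, the AdapterV2 scales and biases, and the norms). A quick way to see the effect, assuming the package is installed as laid out in this commit and using one of the configs registered in lit_gpt/config.py:

    from lit_gpt.adapter_v2 import GPT, Config, mark_only_adapter_v2_as_trainable

    config = Config.from_name("pythia-160m")
    model = GPT(config)
    mark_only_adapter_v2_as_trainable(model)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable parameters: {trainable} / {total}")   # only adapter/scale/bias/norm params remain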
lit_gpt/config.py ADDED
@@ -0,0 +1,1203 @@
1
+ import json
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Any, Literal, Optional, Type, Union
6
+
7
+ import torch
8
+ from typing_extensions import Self
9
+
10
+ import lit_gpt.model
11
+ from lit_gpt.utils import find_multiple
12
+
13
+
14
+ @dataclass
15
+ class Config:
16
+ name: str = ""
17
+ hf_config: dict = field(default_factory=dict)
18
+ block_size: int = 4096
19
+ vocab_size: int = 50254
20
+ padding_multiple: int = 512
21
+ padded_vocab_size: Optional[int] = None
22
+ n_layer: int = 16
23
+ n_head: int = 32
24
+ n_embd: int = 4096
25
+ rotary_percentage: float = 0.25
26
+ parallel_residual: bool = True
27
+ bias: bool = True
28
+ lm_head_bias: bool = False
29
+ # to use multi-head attention (MHA), set this to `n_head` (default)
30
+ # to use multi-query attention (MQA), set this to 1
31
+ # to use grouped-query attention (GQA), set this to a value in between
32
+ # Example with `n_head=4`
33
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
34
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
35
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
36
+ # │ │ │ │ │ │ │
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
42
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
43
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
44
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
45
+ # MHA GQA MQA
46
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
47
+ #
48
+ # credit https://arxiv.org/pdf/2305.13245.pdf
49
+ n_query_groups: Optional[int] = None
50
+ shared_attention_norm: bool = False
51
+ _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
52
+ norm_eps: float = 1e-5
53
+ _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
54
+ gelu_approximate: str = "none"
55
+ intermediate_size: Optional[int] = None
56
+ rope_condense_ratio: int = 1
57
+ rope_base: int = 10000
58
+
59
+ def __post_init__(self):
60
+ if not self.name:
61
+ self.name = self.hf_config.get("name", self.name)
62
+
63
+ assert self.n_embd % self.n_head == 0
64
+ self.head_size = self.n_embd // self.n_head
65
+
66
+ # vocab size should be padded to a multiple of padding_multiple to be efficient on hardware. compute the closest value
67
+ if self.padded_vocab_size is None:
68
+ self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
69
+ else:
70
+ # vocab size shouldn't be larger than padded vocab size
71
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
72
+
73
+ # compute the number of query groups
74
+ if self.n_query_groups is not None:
75
+ assert self.n_head % self.n_query_groups == 0
76
+ else:
77
+ self.n_query_groups = self.n_head
78
+
79
+ # compute the intermediate size for MLP if not set
80
+ if self.intermediate_size is None:
81
+ if self._mlp_class == "LLaMAMLP":
82
+ raise ValueError("The config needs to set the `intermediate_size`")
83
+ self.intermediate_size = 4 * self.n_embd
84
+
85
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
86
+
87
+ @classmethod
88
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
89
+ if name not in name_to_config:
90
+ # search through all `config['hf_config']['name']`
91
+ try:
92
+ conf_dict = next(config for config in configs if name == config["hf_config"]["name"])
93
+ except StopIteration:
94
+ raise ValueError(f"{name!r} is not a supported config name")
95
+ else:
96
+ conf_dict = name_to_config[name]
97
+
98
+ conf_dict = conf_dict.copy()
99
+ if "condense_ratio" in kwargs: # legacy name
100
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
101
+ conf_dict.update(kwargs)
102
+ return cls(**conf_dict)
103
+
104
+ @classmethod
105
+ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
106
+ with open(path, encoding="utf-8") as fp:
107
+ json_kwargs = json.load(fp)
108
+ if "condense_ratio" in json_kwargs: # legacy name
109
+ json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
110
+ if "condense_ratio" in kwargs: # legacy name
111
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
112
+ if "org" in json_kwargs: # legacy name
113
+ json_kwargs["hf_config"] = {"name": json_kwargs["name"], "org": json_kwargs.pop("org")}
114
+ if "org" in kwargs: # legacy name
115
+ kwargs["hf_config"] = {"name": kwargs.get("name", json_kwargs["name"]), "org": kwargs.pop("org")}
116
+ json_kwargs.update(kwargs)
117
+ return cls(**json_kwargs)
118
+
119
+ @classmethod
120
+ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
121
+ """Automatically load `lit_config.json` and if it doesn't exist - a matching config from `lit_gpt/config.py`."""
122
+ if (config_path := path / "lit_config.json").is_file():
123
+ return cls.from_json(config_path, **kwargs)
124
+ if (model_name := path.name) in name_to_config:
125
+ return cls.from_name(model_name, **kwargs)
126
+ raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor matching config exists.")
127
+
128
+ @property
129
+ def mlp_class(self) -> Type:
130
+ # `self._mlp_class` cannot be the type to keep the config json serializable
131
+ return getattr(lit_gpt.model, self._mlp_class)
132
+
133
+ @property
134
+ def norm_class(self) -> Type:
135
+ # `self._norm_class` cannot be the type to keep the config json serializable
136
+ if self._norm_class == "RMSNorm":
137
+ from lit_gpt.rmsnorm import RMSNorm
138
+
139
+ return RMSNorm
140
+ return getattr(torch.nn, self._norm_class)
141
+
142
+
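With grouped-query attention (the MHA/GQA/MQA diagram in the Config comment above), several query heads share one key/value head, and k/v are expanded to match the query heads at attention time. A minimal tensor-level sketch of that expansion (standalone shapes, not the model's attention code):

    import torch

    batch, seq, head_size = 2, 8, 64
    n_head, n_query_groups = 12, 4                 # GQA: 3 query heads share each k/v head
    q_per_kv = n_head // n_query_groups

    q = torch.randn(batch, n_head, seq, head_size)
    k = torch.randn(batch, n_query_groups, seq, head_size)
    v = torch.randn(batch, n_query_groups, seq, head_size)

    # repeat k/v so every query head has a matching k/v head
    # (a no-op when n_query_groups == n_head, i.e. plain multi-head attention)
    k = k.repeat_interleave(q_per_kv, dim=1)
    v = v.repeat_interleave(q_per_kv, dim=1)

    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(y.shape)    # torch.Size([2, 12, 8, 64])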
143
+ ########################
144
+ # Stability AI StableLM
145
+ ########################
146
+ configs = [
147
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
148
+ dict(name="stablelm-base-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b")),
149
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
150
+ dict(
151
+ name="stablelm-base-alpha-7b",
152
+ hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"),
153
+ n_head=48,
154
+ n_embd=6144,
155
+ padding_multiple=256,
156
+ ),
157
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
158
+ dict(name="stablelm-tuned-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"), n_head=32),
159
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
160
+ dict(
161
+ name="stablelm-tuned-alpha-7b",
162
+ hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
163
+ n_head=48,
164
+ n_embd=6144,
165
+ padding_multiple=256,
166
+ ),
167
+ ]
168
+
169
+ ####################
170
+ # EleutherAI Pythia
171
+ ####################
172
+ pythia = [
173
+ # https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json
174
+ dict(
175
+ name="pythia-14m",
176
+ hf_config=dict(org="EleutherAI", name="pythia-14m"),
177
+ block_size=512,
178
+ n_layer=6,
179
+ n_embd=128,
180
+ n_head=4,
181
+ padding_multiple=128,
182
+ ),
183
+ # https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json
184
+ dict(
185
+ name="pythia-31m",
186
+ hf_config=dict(org="EleutherAI", name="pythia-31m"),
187
+ block_size=1024,
188
+ n_layer=6,
189
+ n_embd=256,
190
+ n_head=8,
191
+ padding_multiple=128,
192
+ ),
193
+ # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
194
+ dict(
195
+ name="pythia-70m",
196
+ hf_config=dict(org="EleutherAI", name="pythia-70m"),
197
+ block_size=2048,
198
+ n_layer=6,
199
+ n_embd=512,
200
+ n_head=8,
201
+ padding_multiple=128,
202
+ ),
203
+ # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
204
+ dict(
205
+ name="pythia-160m",
206
+ hf_config=dict(org="EleutherAI", name="pythia-160m"),
207
+ block_size=2048,
208
+ n_layer=12,
209
+ n_embd=768,
210
+ n_head=12,
211
+ padding_multiple=128,
212
+ ),
213
+ # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
214
+ dict(
215
+ name="pythia-410m",
216
+ hf_config=dict(org="EleutherAI", name="pythia-410m"),
217
+ block_size=2048,
218
+ n_layer=24,
219
+ n_embd=1024,
220
+ n_head=16,
221
+ padding_multiple=128,
222
+ ),
223
+ # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
224
+ dict(
225
+ name="pythia-1b",
226
+ hf_config=dict(org="EleutherAI", name="pythia-1b"),
227
+ block_size=2048,
228
+ n_embd=2048,
229
+ n_head=8,
230
+ padding_multiple=128,
231
+ ),
232
+ # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
233
+ dict(
234
+ name="pythia-1.4b",
235
+ hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
236
+ block_size=2048,
237
+ n_layer=24,
238
+ n_embd=2048,
239
+ n_head=16,
240
+ padding_multiple=128,
241
+ ),
242
+ # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
243
+ dict(
244
+ name="pythia-2.8b",
245
+ hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
246
+ block_size=2048,
247
+ n_layer=32,
248
+ n_embd=2560,
249
+ padding_multiple=128,
250
+ ),
251
+ # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
252
+ dict(
253
+ name="pythia-6.9b",
254
+ hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
255
+ block_size=2048,
256
+ n_layer=32,
257
+ padding_multiple=256,
258
+ ),
259
+ # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
260
+ dict(
261
+ name="pythia-12b",
262
+ hf_config=dict(org="EleutherAI", name="pythia-12b"),
263
+ block_size=2048,
264
+ n_layer=36,
265
+ n_embd=5120,
266
+ n_head=40,
267
+ ),
268
+ ]
269
+ configs.extend(pythia)
270
+ for c in pythia:
271
+ # "pythia-14m" and "pythia-31m" don't have a deduped version
272
+ if c["name"] in ("pythia-14m", "pythia-31m"):
273
+ continue
274
+ copy = deepcopy(c)
275
+ copy["name"] = f"{c['name']}-deduped"
276
+ copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
277
+ configs.append(copy)
278
+
279
+
280
+ ####################################
281
+ # togethercomputer RedPajama INCITE
282
+ ####################################
283
+ redpajama_incite = [
284
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
285
+ dict(
286
+ name="RedPajama-INCITE-{}-3B-v1",
287
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-3B-v1"),
288
+ block_size=2048,
289
+ n_layer=32,
290
+ n_embd=2560,
291
+ padding_multiple=256,
292
+ rotary_percentage=1.0,
293
+ parallel_residual=False,
294
+ ),
295
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
296
+ dict(
297
+ name="RedPajama-INCITE-7B-{}",
298
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-7B-{}"),
299
+ block_size=2048,
300
+ n_layer=32,
301
+ padding_multiple=256,
302
+ rotary_percentage=1.0,
303
+ parallel_residual=False,
304
+ ),
305
+ # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
306
+ dict(
307
+ name="RedPajama-INCITE-{}-7B-v0.1",
308
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-7B-v0.1"),
309
+ block_size=2048,
310
+ n_layer=32,
311
+ padding_multiple=256,
312
+ rotary_percentage=1.0,
313
+ parallel_residual=False,
314
+ ),
315
+ ]
316
+ for c in redpajama_incite:
317
+ for kind in ("Base", "Chat", "Instruct"):
318
+ copy = deepcopy(c)
319
+ copy["name"] = c["name"].format(kind)
320
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
321
+ configs.append(copy)
322
+
323
+
324
+ #################
325
+ # TII UAE Falcon
326
+ #################
327
+ falcon = [
328
+ # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
329
+ dict(
330
+ name="falcon-7b{}",
331
+ hf_config=dict(org="tiiuae", name="falcon-7b{}"),
332
+ block_size=2048,
333
+ vocab_size=65024,
334
+ padded_vocab_size=65024,
335
+ n_layer=32,
336
+ n_head=71,
337
+ n_embd=4544,
338
+ rotary_percentage=1.0,
339
+ n_query_groups=1,
340
+ bias=False,
341
+ # this is not in the config, but in the original model implementation, only for this config
342
+ shared_attention_norm=True,
343
+ ),
344
+ # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
345
+ dict(
346
+ name="falcon-40b{}",
347
+ hf_config=dict(org="tiiuae", name="falcon-40b{}"),
348
+ block_size=2048,
349
+ vocab_size=65024,
350
+ padded_vocab_size=65024,
351
+ n_layer=60,
352
+ n_head=128,
353
+ n_embd=8192,
354
+ rotary_percentage=1.0,
355
+ n_query_groups=8,
356
+ bias=False,
357
+ ),
358
+ ]
359
+ for c in falcon:
360
+ for kind in ("", "-instruct"):
361
+ copy = deepcopy(c)
362
+ copy["name"] = c["name"].format(kind)
363
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
364
+ configs.append(copy)
365
+
366
+ # https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
367
+ falcon180b = dict(
368
+ name="falcon-180B{}",
369
+ hf_config=dict(org="tiiuae", name="falcon-180B{}"),
370
+ block_size=2048,
371
+ vocab_size=65024,
372
+ padded_vocab_size=65024,
373
+ n_layer=80,
374
+ n_head=232,
375
+ n_embd=14848,
376
+ rotary_percentage=1.0,
377
+ n_query_groups=8,
378
+ bias=False,
379
+ )
380
+
381
+ for kind in ("", "-chat"):
382
+ copy = deepcopy(falcon180b)
383
+ copy["name"] = falcon180b["name"].format(kind)
384
+ copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
385
+ configs.append(copy)
386
+
387
+
388
+ #############################
389
+ # OpenLM Research Open LLaMA
390
+ #############################
391
+ open_LLaMA = [
392
+ # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
393
+ dict(
394
+ name="open_llama_3b",
395
+ hf_config=dict(org="openlm-research", name="open_llama_3b"),
396
+ block_size=2048,
397
+ vocab_size=32000,
398
+ padding_multiple=64,
399
+ n_layer=26,
400
+ n_embd=3200,
401
+ rotary_percentage=1.0,
402
+ parallel_residual=False,
403
+ bias=False,
404
+ _norm_class="RMSNorm",
405
+ norm_eps=1e-6,
406
+ _mlp_class="LLaMAMLP",
407
+ intermediate_size=8640,
408
+ ),
409
+ # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
410
+ dict(
411
+ name="open_llama_7b",
412
+ hf_config=dict(org="openlm-research", name="open_llama_7b"),
413
+ block_size=2048,
414
+ vocab_size=32000,
415
+ padding_multiple=64,
416
+ n_layer=32,
417
+ rotary_percentage=1.0,
418
+ parallel_residual=False,
419
+ bias=False,
420
+ _norm_class="RMSNorm",
421
+ norm_eps=1e-6,
422
+ _mlp_class="LLaMAMLP",
423
+ intermediate_size=11008,
424
+ ),
425
+ # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
426
+ dict(
427
+ name="open_llama_13b",
428
+ hf_config=dict(org="openlm-research", name="open_llama_13b"),
429
+ block_size=2048,
430
+ vocab_size=32000,
431
+ padding_multiple=64,
432
+ n_layer=40,
433
+ n_head=40,
434
+ n_embd=5120,
435
+ rotary_percentage=1.0,
436
+ parallel_residual=False,
437
+ bias=False,
438
+ _norm_class="RMSNorm",
439
+ norm_eps=1e-6,
440
+ _mlp_class="LLaMAMLP",
441
+ intermediate_size=13824,
442
+ ),
443
+ ]
444
+ configs.extend(open_LLaMA)
445
+
446
+
447
+ ###############
448
+ # LMSYS Vicuna
449
+ ###############
450
+ vicuna = [
451
+ # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
452
+ dict(
453
+ name="vicuna-7b-v1.3",
454
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.3"),
455
+ block_size=2048,
456
+ vocab_size=32000,
457
+ padding_multiple=64,
458
+ n_layer=32,
459
+ rotary_percentage=1.0,
460
+ parallel_residual=False,
461
+ bias=False,
462
+ _norm_class="RMSNorm",
463
+ norm_eps=1e-6,
464
+ _mlp_class="LLaMAMLP",
465
+ intermediate_size=11008,
466
+ ),
467
+ # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
468
+ dict(
469
+ name="vicuna-13b-v1.3",
470
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.3"),
471
+ block_size=2048,
472
+ vocab_size=32000,
473
+ padding_multiple=64,
474
+ n_layer=40,
475
+ n_head=40,
476
+ n_embd=5120,
477
+ rotary_percentage=1.0,
478
+ parallel_residual=False,
479
+ bias=False,
480
+ _norm_class="RMSNorm",
481
+ norm_eps=1e-6,
482
+ _mlp_class="LLaMAMLP",
483
+ intermediate_size=13824,
484
+ ),
485
+ # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
486
+ dict(
487
+ name="vicuna-33b-v1.3",
488
+ hf_config=dict(org="lmsys", name="vicuna-33b-v1.3"),
489
+ block_size=2048,
490
+ vocab_size=32000,
491
+ padding_multiple=64,
492
+ n_layer=60,
493
+ n_head=52,
494
+ n_embd=6656,
495
+ rotary_percentage=1.0,
496
+ parallel_residual=False,
497
+ bias=False,
498
+ _norm_class="RMSNorm",
499
+ norm_eps=1e-6,
500
+ _mlp_class="LLaMAMLP",
501
+ intermediate_size=17920,
502
+ ),
503
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
504
+ dict(
505
+ name="vicuna-7b-v1.5",
506
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.5"),
507
+ vocab_size=32000,
508
+ padding_multiple=64,
509
+ n_layer=32,
510
+ rotary_percentage=1.0,
511
+ parallel_residual=False,
512
+ bias=False,
513
+ _norm_class="RMSNorm",
514
+ _mlp_class="LLaMAMLP",
515
+ intermediate_size=11008,
516
+ ),
517
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
518
+ dict(
519
+ name="vicuna-7b-v1.5-16k",
520
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.5-16k"),
521
+ block_size=16384,
522
+ vocab_size=32000,
523
+ padding_multiple=64,
524
+ n_layer=32,
525
+ rotary_percentage=1.0,
526
+ parallel_residual=False,
527
+ bias=False,
528
+ _norm_class="RMSNorm",
529
+ _mlp_class="LLaMAMLP",
530
+ intermediate_size=11008,
531
+ rope_condense_ratio=4,
532
+ ),
533
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
534
+ dict(
535
+ name="vicuna-13b-v1.5",
536
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.5"),
537
+ vocab_size=32000,
538
+ padding_multiple=64,
539
+ n_layer=40,
540
+ n_head=40,
541
+ n_embd=5120,
542
+ rotary_percentage=1.0,
543
+ parallel_residual=False,
544
+ bias=False,
545
+ _norm_class="RMSNorm",
546
+ _mlp_class="LLaMAMLP",
547
+ intermediate_size=13824,
548
+ ),
549
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
550
+ dict(
551
+ name="vicuna-13b-v1.5-16k",
552
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.5-16k"),
553
+ block_size=16384,
554
+ vocab_size=32000,
555
+ padding_multiple=64,
556
+ n_layer=40,
557
+ n_head=40,
558
+ n_embd=5120,
559
+ rotary_percentage=1.0,
560
+ parallel_residual=False,
561
+ bias=False,
562
+ _norm_class="RMSNorm",
563
+ _mlp_class="LLaMAMLP",
564
+ intermediate_size=13824,
565
+ rope_condense_ratio=4,
566
+ ),
567
+ ]
568
+ configs.extend(vicuna)
569
+
570
+
571
+ #################
572
+ # LMSYS LongChat
573
+ #################
574
+ long_chat = [
575
+ # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
576
+ dict(
577
+ name="longchat-7b-16k",
578
+ hf_config=dict(org="lmsys", name="longchat-7b-16k"),
579
+ block_size=16384,
580
+ vocab_size=32000,
581
+ padding_multiple=64,
582
+ n_layer=32,
583
+ rotary_percentage=1.0,
584
+ parallel_residual=False,
585
+ bias=False,
586
+ _norm_class="RMSNorm",
587
+ norm_eps=1e-6,
588
+ _mlp_class="LLaMAMLP",
589
+ intermediate_size=11008,
590
+ rope_condense_ratio=8,
591
+ ),
592
+ # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
593
+ dict(
594
+ name="longchat-13b-16k",
595
+ hf_config=dict(org="lmsys", name="longchat-13b-16k"),
596
+ block_size=16384,
597
+ vocab_size=32000,
598
+ padding_multiple=64,
599
+ n_layer=40,
600
+ n_head=40,
601
+ n_embd=5120,
602
+ rotary_percentage=1.0,
603
+ parallel_residual=False,
604
+ bias=False,
605
+ _norm_class="RMSNorm",
606
+ norm_eps=1e-6,
607
+ _mlp_class="LLaMAMLP",
608
+ intermediate_size=13824,
609
+ rope_condense_ratio=8,
610
+ ),
611
+ ]
612
+ configs.extend(long_chat)
613
+
614
+
615
+ ######################
616
+ # NousResearch Hermes
617
+ ######################
618
+ nous_research = [
619
+ # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
620
+ dict(
621
+ name="Nous-Hermes-llama-2-7b",
622
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-llama-2-7b"),
623
+ padded_vocab_size=32000,
624
+ n_layer=32,
625
+ rotary_percentage=1.0,
626
+ parallel_residual=False,
627
+ bias=False,
628
+ _norm_class="RMSNorm",
629
+ norm_eps=1e-05,
630
+ _mlp_class="LLaMAMLP",
631
+ intermediate_size=11008,
632
+ ),
633
+ # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
634
+ dict(
635
+ name="Nous-Hermes-13b",
636
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-13b"),
637
+ block_size=2048,
638
+ vocab_size=32000,
639
+ padded_vocab_size=32001,
640
+ n_layer=40,
641
+ n_head=40,
642
+ n_embd=5120,
643
+ rotary_percentage=1.0,
644
+ parallel_residual=False,
645
+ bias=False,
646
+ _norm_class="RMSNorm",
647
+ norm_eps=1e-6,
648
+ _mlp_class="LLaMAMLP",
649
+ intermediate_size=13824,
650
+ ),
651
+ # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
652
+ dict(
653
+ name="Nous-Hermes-Llama2-13b",
654
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-Llama2-13b"),
655
+ vocab_size=32000,
656
+ padded_vocab_size=32032,
657
+ n_layer=40,
658
+ n_head=40,
659
+ n_embd=5120,
660
+ rotary_percentage=1.0,
661
+ parallel_residual=False,
662
+ bias=False,
663
+ _norm_class="RMSNorm",
664
+ norm_eps=1e-05,
665
+ _mlp_class="LLaMAMLP",
666
+ intermediate_size=13824,
667
+ ),
668
+ ]
669
+ configs.extend(nous_research)
670
+
671
+
672
+ ###############
673
+ # Meta LLaMA 2
674
+ ###############
675
+ llama_2 = [
676
+ # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
677
+ dict(
678
+ name="Llama-2-7b{}-hf",
679
+ hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
680
+ vocab_size=32000,
681
+ padding_multiple=64,
682
+ n_layer=32,
683
+ rotary_percentage=1.0,
684
+ parallel_residual=False,
685
+ bias=False,
686
+ _norm_class="RMSNorm",
687
+ _mlp_class="LLaMAMLP",
688
+ intermediate_size=11008,
689
+ ),
690
+ # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
691
+ dict(
692
+ name="Llama-2-13b{}-hf",
693
+ hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
694
+ vocab_size=32000,
695
+ padding_multiple=64,
696
+ n_layer=40,
697
+ n_head=40,
698
+ n_embd=5120,
699
+ rotary_percentage=1.0,
700
+ parallel_residual=False,
701
+ bias=False,
702
+ _norm_class="RMSNorm",
703
+ _mlp_class="LLaMAMLP",
704
+ intermediate_size=13824,
705
+ ),
706
+ # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
707
+ dict(
708
+ name="Llama-2-70b{}-hf",
709
+ hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
710
+ vocab_size=32000,
711
+ padding_multiple=64,
712
+ n_layer=80,
713
+ n_head=64,
714
+ n_embd=8192,
715
+ n_query_groups=8,
716
+ rotary_percentage=1.0,
717
+ parallel_residual=False,
718
+ bias=False,
719
+ _norm_class="RMSNorm",
720
+ _mlp_class="LLaMAMLP",
721
+ intermediate_size=28672,
722
+ ),
723
+ ]
724
+ for c in llama_2:
725
+ for kind in ("", "-chat"):
726
+ copy = deepcopy(c)
727
+ copy["name"] = c["name"].format(kind)
728
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
729
+ configs.append(copy)
730
+
731
+
732
+ ##########################
733
+ # Stability AI FreeWilly2
734
+ ##########################
735
+ freewilly_2 = [
736
+ # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
737
+ dict(
738
+ name="FreeWilly2",
739
+ hf_config=dict(org="stabilityai", name="FreeWilly2"),
740
+ vocab_size=32000,
741
+ padding_multiple=64,
742
+ n_layer=80,
743
+ n_head=64,
744
+ n_embd=8192,
745
+ n_query_groups=8,
746
+ rotary_percentage=1.0,
747
+ parallel_residual=False,
748
+ bias=False,
749
+ _norm_class="RMSNorm",
750
+ _mlp_class="LLaMAMLP",
751
+ intermediate_size=28672,
752
+ )
753
+ ]
754
+ configs.extend(freewilly_2)
755
+
756
+
757
+ ##################
758
+ # Meta Code Llama
759
+ ##################
760
+ code_llama = [
761
+ # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
762
+ dict(
763
+ name="CodeLlama-7b-hf",
764
+ hf_config=dict(org="codellama", name="CodeLlama-7b-hf"),
765
+ block_size=16384,
766
+ vocab_size=32016,
767
+ padding_multiple=16,
768
+ n_layer=32,
769
+ rotary_percentage=1.0,
770
+ parallel_residual=False,
771
+ bias=False,
772
+ _norm_class="RMSNorm",
773
+ norm_eps=1e-05,
774
+ _mlp_class="LLaMAMLP",
775
+ intermediate_size=11008,
776
+ rope_base=1000000,
777
+ ),
778
+ # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
779
+ dict(
780
+ name="CodeLlama-13b-hf",
781
+ hf_config=dict(org="codellama", name="CodeLlama-13b-hf"),
782
+ block_size=16384,
783
+ vocab_size=32016,
784
+ padding_multiple=16,
785
+ n_layer=40,
786
+ n_head=40,
787
+ n_embd=5120,
788
+ rotary_percentage=1.0,
789
+ parallel_residual=False,
790
+ bias=False,
791
+ _norm_class="RMSNorm",
792
+ norm_eps=1e-05,
793
+ _mlp_class="LLaMAMLP",
794
+ intermediate_size=13824,
795
+ rope_base=1000000,
796
+ ),
797
+ # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
798
+ dict(
799
+ name="CodeLlama-34b-hf",
800
+ hf_config=dict(org="codellama", name="CodeLlama-34b-hf"),
801
+ block_size=16384,
802
+ vocab_size=32000,
803
+ padding_multiple=64,
804
+ n_layer=48,
805
+ n_head=64,
806
+ n_embd=8192,
807
+ n_query_groups=8,
808
+ rotary_percentage=1.0,
809
+ parallel_residual=False,
810
+ bias=False,
811
+ _norm_class="RMSNorm",
812
+ norm_eps=1e-05,
813
+ _mlp_class="LLaMAMLP",
814
+ intermediate_size=22016,
815
+ rope_base=1000000,
816
+ ),
817
+ # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
818
+ dict(
819
+ name="CodeLlama-7b-Python-hf",
820
+ hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"),
821
+ block_size=16384,
822
+ vocab_size=32000,
823
+ padding_multiple=64,
824
+ n_layer=32,
825
+ rotary_percentage=1.0,
826
+ parallel_residual=False,
827
+ bias=False,
828
+ _norm_class="RMSNorm",
829
+ norm_eps=1e-05,
830
+ _mlp_class="LLaMAMLP",
831
+ intermediate_size=11008,
832
+ rope_base=1000000,
833
+ ),
834
+ # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
835
+ dict(
836
+ name="CodeLlama-13b-Python-hf",
837
+ hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"),
838
+ block_size=16384,
839
+ vocab_size=32000,
840
+ padding_multiple=64,
841
+ n_layer=40,
842
+ n_head=40,
843
+ n_embd=5120,
844
+ rotary_percentage=1.0,
845
+ parallel_residual=False,
846
+ bias=False,
847
+ _norm_class="RMSNorm",
848
+ norm_eps=1e-05,
849
+ _mlp_class="LLaMAMLP",
850
+ intermediate_size=13824,
851
+ rope_base=1000000,
852
+ ),
853
+ # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
854
+ dict(
855
+ name="CodeLlama-34b-Python-hf",
856
+ hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"),
857
+ block_size=16384,
858
+ vocab_size=32000,
859
+ padding_multiple=64,
860
+ n_layer=48,
861
+ n_head=64,
862
+ n_embd=8192,
863
+ n_query_groups=8,
864
+ rotary_percentage=1.0,
865
+ parallel_residual=False,
866
+ bias=False,
867
+ _norm_class="RMSNorm",
868
+ norm_eps=1e-05,
869
+ _mlp_class="LLaMAMLP",
870
+ intermediate_size=22016,
871
+ rope_base=1000000,
872
+ ),
873
+ # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
874
+ dict(
875
+ name="CodeLlama-7b-Instruct-hf",
876
+ hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"),
877
+ block_size=16384,
878
+ vocab_size=32016,
879
+ padding_multiple=16,
880
+ n_layer=32,
881
+ rotary_percentage=1.0,
882
+ parallel_residual=False,
883
+ bias=False,
884
+ _norm_class="RMSNorm",
885
+ norm_eps=1e-05,
886
+ _mlp_class="LLaMAMLP",
887
+ intermediate_size=11008,
888
+ rope_base=1000000,
889
+ ),
890
+ # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
891
+ dict(
892
+ name="CodeLlama-13b-Instruct-hf",
893
+ hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"),
894
+ block_size=2048,
895
+ vocab_size=32016,
896
+ padding_multiple=16,
897
+ n_layer=40,
898
+ n_head=40,
899
+ n_embd=5120,
900
+ rotary_percentage=1.0,
901
+ parallel_residual=False,
902
+ bias=False,
903
+ _norm_class="RMSNorm",
904
+ norm_eps=1e-05,
905
+ _mlp_class="LLaMAMLP",
906
+ intermediate_size=13824,
907
+ rope_base=1000000,
908
+ ),
909
+ # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
910
+ dict(
911
+ name="CodeLlama-34b-Instruct-hf",
912
+ hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"),
913
+ block_size=16384,
914
+ vocab_size=32000,
915
+ padding_multiple=64,
916
+ n_layer=48,
917
+ n_head=64,
918
+ n_embd=8192,
919
+ n_query_groups=8,
920
+ rotary_percentage=1.0,
921
+ parallel_residual=False,
922
+ bias=False,
923
+ _norm_class="RMSNorm",
924
+ norm_eps=1e-05,
925
+ _mlp_class="LLaMAMLP",
926
+ intermediate_size=22016,
927
+ rope_base=1000000,
928
+ ),
929
+ ]
930
+ configs.extend(code_llama)
931
+
932
+
933
+ ########################
934
+ # garage-bAInd Platypus
935
+ ########################
936
+ platypus = [
937
+ # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
938
+ dict(
939
+ name="Platypus-30B",
940
+ hf_config=dict(org="garage-bAInd", name="Platypus-30B"),
941
+ block_size=2048,
942
+ padded_vocab_size=32000,
943
+ n_layer=60,
944
+ n_head=52,
945
+ n_embd=6656,
946
+ rotary_percentage=1.0,
947
+ parallel_residual=False,
948
+ bias=False,
949
+ _norm_class="RMSNorm",
950
+ norm_eps=1e-06,
951
+ _mlp_class="LLaMAMLP",
952
+ intermediate_size=17920,
953
+ ),
954
+ # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
955
+ dict(
956
+ name="Platypus2-7B",
957
+ hf_config=dict(org="garage-bAInd", name="Platypus2-7B"),
958
+ padded_vocab_size=32000,
959
+ n_layer=32,
960
+ rotary_percentage=1.0,
961
+ parallel_residual=False,
962
+ bias=False,
963
+ _norm_class="RMSNorm",
964
+ norm_eps=1e-05,
965
+ _mlp_class="LLaMAMLP",
966
+ intermediate_size=11008,
967
+ ),
968
+ # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
969
+ dict(
970
+ name="Platypus2-13B",
971
+ hf_config=dict(org="garage-bAInd", name="Platypus2-13B"),
972
+ padded_vocab_size=32000,
973
+ n_layer=40,
974
+ n_head=40,
975
+ n_embd=5120,
976
+ rotary_percentage=1.0,
977
+ parallel_residual=False,
978
+ bias=False,
979
+ _norm_class="RMSNorm",
980
+ norm_eps=1e-05,
981
+ _mlp_class="LLaMAMLP",
982
+ intermediate_size=13824,
983
+ ),
984
+ # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
985
+ dict(
986
+ name="Platypus2-70B",
987
+ hf_config=dict(org="garage-bAInd", name="Platypus2-70B"),
988
+ padded_vocab_size=32000,
989
+ n_layer=80,
990
+ n_head=64,
991
+ n_embd=8192,
992
+ rotary_percentage=1.0,
993
+ parallel_residual=False,
994
+ bias=False,
995
+ _norm_class="RMSNorm",
996
+ _mlp_class="LLaMAMLP",
997
+ intermediate_size=28672,
998
+ ),
999
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
1000
+ dict(
1001
+ name="Camel-Platypus2-13B",
1002
+ hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"),
1003
+ padded_vocab_size=32000,
1004
+ n_layer=40,
1005
+ n_head=40,
1006
+ n_embd=5120,
1007
+ rotary_percentage=1.0,
1008
+ parallel_residual=False,
1009
+ bias=False,
1010
+ _norm_class="RMSNorm",
1011
+ _mlp_class="LLaMAMLP",
1012
+ intermediate_size=13824,
1013
+ ),
1014
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
1015
+ dict(
1016
+ name="Camel-Platypus2-70B",
1017
+ hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"),
1018
+ padded_vocab_size=32000,
1019
+ n_layer=80,
1020
+ n_head=64,
1021
+ n_embd=8192,
1022
+ n_query_groups=8,
1023
+ rotary_percentage=1.0,
1024
+ parallel_residual=False,
1025
+ bias=False,
1026
+ _norm_class="RMSNorm",
1027
+ _mlp_class="LLaMAMLP",
1028
+ intermediate_size=28672,
1029
+ ),
1030
+ # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
1031
+ dict(
1032
+ name="Stable-Platypus2-13B",
1033
+ hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"),
1034
+ padded_vocab_size=32000,
1035
+ n_layer=40,
1036
+ n_head=40,
1037
+ n_embd=5120,
1038
+ rotary_percentage=1.0,
1039
+ parallel_residual=False,
1040
+ bias=False,
1041
+ _norm_class="RMSNorm",
1042
+ _mlp_class="LLaMAMLP",
1043
+ intermediate_size=13824,
1044
+ ),
1045
+ # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
1046
+ dict(
1047
+ name="Platypus2-70B-instruct",
1048
+ hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"),
1049
+ padded_vocab_size=32000,
1050
+ n_layer=80,
1051
+ n_head=64,
1052
+ n_embd=8192,
1053
+ n_query_groups=8,
1054
+ rotary_percentage=1.0,
1055
+ parallel_residual=False,
1056
+ bias=False,
1057
+ _norm_class="RMSNorm",
1058
+ _mlp_class="LLaMAMLP",
1059
+ intermediate_size=28672,
1060
+ ),
1061
+ ]
1062
+ configs.extend(platypus)
1063
+
1064
+
1065
+ ##########################
1066
+ # Stability AI StableCode
1067
+ ##########################
1068
+ stablecode = [
1069
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
1070
+ dict(
1071
+ name="stablecode-completion-alpha-3b",
1072
+ hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
1073
+ block_size=16384,
1074
+ vocab_size=49152,
1075
+ n_layer=32,
1076
+ n_embd=2560,
1077
+ ),
1078
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
1079
+ dict(
1080
+ name="stablecode-completion-alpha-3b-4k",
1081
+ hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
1082
+ vocab_size=49152,
1083
+ n_layer=32,
1084
+ n_embd=2560,
1085
+ ),
1086
+ # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
1087
+ dict(
1088
+ name="stablecode-instruct-alpha-3b",
1089
+ hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
1090
+ vocab_size=49152,
1091
+ n_layer=32,
1092
+ n_embd=2560,
1093
+ ),
1094
+ ]
1095
+ configs.extend(stablecode)
1096
+
1097
+
1098
+ ##################################
1099
+ # togethercomputer LLaMA-2-7B-32K
1100
+ ##################################
1101
+ together_llama2_32k = [
1102
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
1103
+ dict(
1104
+ name="LLaMA-2-7B-32K",
1105
+ hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"),
1106
+ vocab_size=32000,
1107
+ padding_multiple=64,
1108
+ n_layer=32,
1109
+ rotary_percentage=1.0,
1110
+ parallel_residual=False,
1111
+ bias=False,
1112
+ _norm_class="RMSNorm",
1113
+ _mlp_class="LLaMAMLP",
1114
+ intermediate_size=11008,
1115
+ rope_condense_ratio=8,
1116
+ )
1117
+ ]
1118
+ configs.extend(together_llama2_32k)
1119
+
1120
+
1121
+ ################
1122
+ # Microsoft Phi
1123
+ ################
1124
+ phi = [
1125
+ # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
1126
+ dict(
1127
+ name="phi-1_5",
1128
+ hf_config=dict(org="microsoft", name="phi-1_5"),
1129
+ vocab_size=50257,
1130
+ padded_vocab_size=51200,
1131
+ block_size=2048,
1132
+ n_embd=2048,
1133
+ n_layer=24,
1134
+ rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
1135
+ shared_attention_norm=True,
1136
+ lm_head_bias=True,
1137
+ gelu_approximate="tanh",
1138
+ )
1139
+ ]
1140
+ configs.extend(phi)
1141
+
1142
+
1143
+ #############
1144
+ # Mistral AI
1145
+ #############
1146
+ mistral = [
1147
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
1148
+ dict(
1149
+ name="Mistral-7B-{}v0.1",
1150
+ hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"),
1151
+ padded_vocab_size=32000,
1152
+ block_size=4096, # should be 32768 but sliding window attention is not implemented
1153
+ n_layer=32,
1154
+ n_query_groups=8,
1155
+ rotary_percentage=1.0,
1156
+ parallel_residual=False,
1157
+ bias=False,
1158
+ _norm_class="RMSNorm",
1159
+ norm_eps=1e-05,
1160
+ _mlp_class="LLaMAMLP",
1161
+ intermediate_size=14336,
1162
+ )
1163
+ ]
1164
+ for c in mistral:
1165
+ for kind in ("", "Instruct-"):
1166
+ copy = deepcopy(c)
1167
+ copy["name"] = c["name"].format(kind)
1168
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
1169
+ configs.append(copy)
1170
+
1171
+
1172
+ ############
1173
+ # TinyLlama
1174
+ ############
1175
+ tiny_llama = [
1176
+ dict(
1177
+ name="tiny-llama-1.1b{}",
1178
+ hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"),
1179
+ block_size=2048,
1180
+ vocab_size=32000,
1181
+ padding_multiple=64,
1182
+ n_layer=22,
1183
+ n_head=32,
1184
+ n_embd=2048,
1185
+ rotary_percentage=1.0,
1186
+ parallel_residual=False,
1187
+ bias=False,
1188
+ _norm_class="RMSNorm", # original TinyLlama uses FusedRMSNorm
1189
+ norm_eps=1e-5,
1190
+ _mlp_class="LLaMAMLP",
1191
+ intermediate_size=5632,
1192
+ n_query_groups=4,
1193
+ ),
1194
+ ]
1195
+ for c in tiny_llama:
1196
+ for kind, hf_postfix in (("", "-intermediate-step-955k-token-2T"), ("chat", "-Chat-v0.6")):
1197
+ copy = deepcopy(c)
1198
+ copy["name"] = c["name"].format(kind)
1199
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix)
1200
+ configs.append(copy)
1201
+
1202
+
1203
+ name_to_config = {config["name"]: config for config in configs}
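The dictionaries above are registered by name in `name_to_config`; templated entries such as "Llama-2-7b{}-hf" or "Mistral-7B-{}v0.1" are expanded by the loops so that both the base and the chat/instruct variants end up registered. A minimal lookup sketch (illustrative, not part of the file; it assumes `Config.from_name` resolves names through `name_to_config`, as `GPT.from_name` in `lit_gpt/model.py` suggests):

from lit_gpt.config import Config

# "Llama-2-7b-chat-hf" exists because the llama_2 loop above formats the "{}" placeholder
# with "" and "-chat"; fields not listed in a dict fall back to the Config defaults.
config = Config.from_name("Llama-2-7b-chat-hf")
print(config.n_layer)            # 32, as set in the llama_2 entry
print(config.intermediate_size)  # 11008, as set in the llama_2 entry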
lit_gpt/lora.py ADDED
@@ -0,0 +1,659 @@
1
+ # Derived from https://github.com/microsoft/LoRA
2
+ # ------------------------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft Corporation. All rights reserved.
4
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
5
+ # ------------------------------------------------------------------------------------------
6
+
7
+ r"""
8
+ Low-Rank Adaptation (LoRA) scheme for LLMs.
9
+
10
+ ┌───────────────────┐
11
+ ┆ h ┆
12
+ └───────────────────┘
13
+
14
+ |
15
+ +
16
+ / \
17
+ ┌─────────────────┐ ╭───────────────╮ Matrix initialization:
18
+ ┆ ┆ \ B / B = 0
19
+ ┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
20
+ ┆ weights ┆ ╰─────────╯
21
+ ┆ ┆ | r | r - rank
22
+ ┆ W e R^(d*d) ┆ | ◀─────▶ |
23
+ ┆ ┆ ╭─────────╮
24
+ └─────────────────┘ / A \
25
+ ▲ / d*r \
26
+ \ ╰───────────────╯
27
+ \ ▲
28
+ \ /
29
+ \ /
30
+ ┌───────────────────┐
31
+ ┆ x ┆
32
+ └───────────────────┘
33
+
34
+ With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning weights of size d*d,
35
+ we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates
36
+ for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of
37
+ course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen
38
+ pretrained weights and thus fine-tune the model.
39
+
40
+ The goal of this approach is to move weight updates into a separate matrix which is decomposed with
41
+ two matrices of a lower rank.
42
+ """
43
+
44
+ import math
45
+ from dataclasses import dataclass
46
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
47
+
48
+ import torch
49
+ import torch.nn as nn
50
+ from torch.nn import functional as F
51
+ from typing_extensions import Self
52
+
53
+ import lit_gpt
54
+ from lit_gpt.config import Config as BaseConfig
55
+ from lit_gpt.model import GPT as BaseModel
56
+ from lit_gpt.model import Block as BaseBlock
57
+ from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
58
+ from lit_gpt.model import KVCache
59
+ from lit_gpt.utils import map_old_state_dict_weights
60
+
61
+
62
+ class LoRALayer(nn.Module):
63
+ def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
64
+ """Store LoRA specific attributes in a class.
65
+
66
+ Args:
67
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
68
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
69
+ lora_alpha: alpha is needed for scaling updates as alpha/r
70
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
71
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
72
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
73
+ """
74
+ super().__init__()
75
+ assert r >= 0
76
+ self.r = r
77
+ self.lora_alpha = lora_alpha
78
+ # Optional dropout
79
+ if lora_dropout > 0.0:
80
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
81
+ else:
82
+ self.lora_dropout = lambda x: x
83
+ # Mark the weight as unmerged
84
+ self.merged = False
85
+
86
+
87
+ class LoRALinear(LoRALayer):
88
+ # LoRA implemented in a dense layer
89
+ def __init__(
90
+ self,
91
+ # ↓ this part is for pretrained weights
92
+ in_features: int,
93
+ out_features: int,
94
+ # ↓ the remaining part is for LoRA
95
+ r: int = 0,
96
+ lora_alpha: int = 1,
97
+ lora_dropout: float = 0.0,
98
+ **kwargs,
99
+ ):
100
+ """LoRA wrapper around linear class.
101
+
102
+ This class has three weight matrices:
103
+ 1. Pretrained weights are stored as `self.linear.weight`
104
+ 2. LoRA A matrix as `self.lora_A`
105
+ 3. LoRA B matrix as `self.lora_B`
106
+ Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
107
+
108
+ Args:
109
+ in_features: number of input features of the pretrained weights
110
+ out_features: number of output features of the pretrained weights
111
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
112
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
113
+ lora_alpha: alpha is needed for scaling updates as alpha/r
114
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
115
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
116
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
117
+ """
118
+ super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
119
+ self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
120
+
121
+ # Actual trainable parameters
122
+ if r > 0:
123
+ self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
124
+ self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
125
+ self.scaling = self.lora_alpha / self.r
126
+ self.reset_parameters()
127
+
128
+ def reset_parameters(self) -> None:
129
+ """Reset all the weights, even including pretrained ones."""
130
+ if hasattr(self, "lora_A"):
131
+ # initialize A the same way as the default for nn.Linear and B to zero
132
+ # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
133
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
134
+ nn.init.zeros_(self.lora_B)
135
+
136
+ def merge(self) -> None:
137
+ """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
138
+ if self.r > 0 and not self.merged:
139
+ # Merge the weights and mark it
140
+ self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
141
+ self.merged = True
142
+
143
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
144
+ # if the weights are merged or the rank is less than or equal to zero (LoRA is disabled), this is only a regular nn.Linear forward pass;
145
+ # otherwise, additionally do the forward pass with the LoRA weights and add its output to the output from the pretrained weights
146
+ pretrained = self.linear(x)
147
+ if self.r == 0 or self.merged:
148
+ return pretrained
149
+ lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
150
+ return pretrained + lora
151
+
152
+
153
+ class LoRAQKVLinear(LoRALinear):
154
+ # LoRA implemented in a dense layer
155
+ def __init__(
156
+ self,
157
+ # ↓ this part is for pretrained weights
158
+ in_features: int,
159
+ out_features: int,
160
+ # ↓ the remaining part is for LoRA
161
+ n_head: int,
162
+ n_query_groups: int,
163
+ r: int = 0,
164
+ lora_alpha: int = 1,
165
+ lora_dropout: float = 0.0,
166
+ enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
167
+ **kwargs,
168
+ ):
169
+ """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
170
+
171
+ This class has three weight matrices:
172
+ 1. Pretrained weights are stored as `self.linear.weight`
173
+ 2. LoRA A matrix as `self.lora_A`
174
+ 3. LoRA B matrix as `self.lora_B`
175
+ Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
176
+
177
+ Args:
178
+ in_features: number of input features of the pretrained weights
179
+ out_features: number of output features of the pretrained weights
180
+ n_head: number of attention heads
181
+ n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
182
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
183
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
184
+ lora_alpha: alpha is needed for scaling updates as alpha/r
185
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
186
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
187
+ lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
188
+ enable_lora: this class is used for the attention mechanism, where q, k and v are calculated with a single weight matrix. If we
189
+ don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query`
190
+ and `value` but keep `key` without weight updates we should pass `[True, False, True]`
191
+ """
192
+ super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
193
+ self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
194
+ self.n_head = n_head
195
+ self.n_query_groups = n_query_groups
196
+ if isinstance(enable_lora, bool):
197
+ enable_lora = [enable_lora] * 3
198
+ assert len(enable_lora) == 3
199
+ self.enable_lora = enable_lora
200
+
201
+ # Actual trainable parameters
202
+ # To better understand initialization let's imagine that we have such parameters:
203
+ # ⚬ in_features: 128 (embeddings_size)
204
+ # ⚬ out_features: 384 (3 * embedding_size)
205
+ # ⚬ r: 2
206
+ # ⚬ enable_lora: [True, False, True]
207
+ if r > 0 and any(enable_lora):
208
+ self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features))) # (4, 128)
209
+ enable_q, enable_k, enable_v = enable_lora
210
+ self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
211
+ # qkv_shapes will be used to split a tensor with weights correctly
212
+ qkv_shapes = (
213
+ self.linear.in_features * enable_q,
214
+ self.kv_embd_size * enable_k,
215
+ self.kv_embd_size * enable_v,
216
+ )
217
+ self.qkv_shapes = [s for s in qkv_shapes if s]
218
+ self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r)) # (256, 2))
219
+ # Notes about shapes above
220
+ # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
221
+ # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
222
+ # F.linear function weights are automatically transposed. In addition conv1d requires channels to
223
+ # be before seq length
224
+ # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
225
+ # 128*2; 2 tells to have two channels per group for group convolution
226
+
227
+ # Scaling:
228
+ # This balances the pretrained model's knowledge and the new task-specific adaptation
229
+ # https://lightning.ai/pages/community/tutorial/lora-llm/
230
+ # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
231
+ # alpha to a lower value. If the LoRA seems to have too little effect, set alpha to a value higher than 1.0. You can
232
+ # tune these values to your needs. This value can be even slightly greater than 1.0!
233
+ # https://github.com/cloneofsimo/lora
234
+ self.scaling = self.lora_alpha / self.r
235
+
236
+ # Compute the indices
237
+ # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
238
+ # but not keys, then the weights update should be:
239
+ #
240
+ # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
241
+ # [....................................],
242
+ # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
243
+ # ↑ ↑ ↑
244
+ # ________________________________________
245
+ # | query | key | value |
246
+ # ----------------------------------------
247
+ self.lora_ind = []
248
+ if enable_q:
249
+ self.lora_ind.extend(range(0, self.linear.in_features))
250
+ if enable_k:
251
+ self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
252
+ if enable_v:
253
+ self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
254
+ self.reset_parameters()
255
+
256
+ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
257
+ """Properly pad weight updates with zeros.
258
+
259
+ If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
260
+ then the weights update should be:
261
+
262
+ [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
263
+ [....................................],
264
+ [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
265
+ ↑ ↑ ↑
266
+ ________________________________________
267
+ | query | key | value |
268
+ ----------------------------------------
269
+
270
+ Args:
271
+ x: tensor with weights update that will be padded with zeros if necessary
272
+
273
+ Returns:
274
+ A tensor with weight updates and zeros for deselected q, k or v
275
+ """
276
+ # we need to do zero padding only if LoRA is disabled for one of QKV matrices
277
+ if all(self.enable_lora):
278
+ return x
279
+
280
+ # Let's imagine that:
281
+ # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
282
+ # ⚬ embeddings_size: 128
283
+ # ⚬ self.linear.out_features: 384 (3 * embeddings_size)
284
+ # ⚬ enable_lora: [True, False, True]
285
+ # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
286
+ # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
287
+ # only for key updates (this is where self.lora_ind comes in handy)
288
+ # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
289
+ # for example when we want to merge/unmerge LoRA weights and pretrained weights
290
+ x = x.transpose(0, 1)
291
+ result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
292
+ result = result.view(-1, self.linear.out_features) # (4096, 384)
293
+ result = result.index_copy(
294
+ 1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
295
+ ) # (4096, 256)
296
+ return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)
297
+
298
+ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
299
+ """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
300
+
301
+ If the number of heads is equal to the number of query groups - grouped queries are disabled
302
+ (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
303
+ query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the
304
+ # input and weight matrices will be split into equally sized parts and applied separately (like having multiple
305
+ conv layers side by side).
306
+
307
+ Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,
308
+ apply each part of the weight matrix to the corresponding input's part and concatenate the result.
309
+
310
+ Args:
311
+ input: input matrix of shape (B, C, T)
312
+ weight: weight matrix of shape (C_output, rank, 1).
313
+ "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
314
+
315
+ Returns:
316
+ A tensor with a shape (B, C_output, T)
317
+
318
+ """
319
+ if self.n_head == self.n_query_groups:
320
+ return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T)
321
+
322
+ # Notation:
323
+ # ⚬ N: number of enabled LoRA layers (self.enable_lora)
324
+ # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
325
+ # ⚬ r: rank of all LoRA layers (equal in size)
326
+
327
+ input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
328
+ weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
329
+ return torch.cat(
330
+ [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T)
331
+ ) # (B, C_output, T)
332
+
333
+ def merge(self) -> None:
334
+ """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
335
+
336
+ # Let's assume that:
337
+ # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
338
+ # ⚬ self.lora_A.data: (4, 128)
339
+ # ⚬ self.lora_B.data: (256, 2)
340
+ if self.r > 0 and any(self.enable_lora) and not self.merged:
341
+ delta_w = self.conv1d(
342
+ self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
343
+ self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
344
+ ).squeeze(
345
+ 0
346
+ ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
347
+ # W = W + delta_W (merge)
348
+ self.linear.weight.data += self.zero_pad(delta_w * self.scaling) # (256, 128) after zero_pad (384, 128)
349
+ self.merged = True
350
+
351
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
352
+ """Do the forward pass.
353
+
354
+ If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
355
+ If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
356
+
357
+ Args:
358
+ x: input tensor of shape (batch_size, context_length, embedding_size)
359
+
360
+ Returns:
361
+ Output tensor of shape (batch_size, context_length, 3 * embedding_size)
362
+ """
363
+
364
+ # Let's assume that:
365
+ # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
366
+ # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
367
+ # ⚬ self.lora_A.data: (4, 128)
368
+ # ⚬ self.lora_B.data: (256, 2)
369
+
370
+ # if the weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False), this is only a regular nn.Linear forward pass;
371
+ # otherwise, additionally do the forward pass with the LoRA weights and add its output to the output from the pretrained weights
372
+ pretrained = self.linear(x)
373
+ if self.r == 0 or not any(self.enable_lora) or self.merged:
374
+ return pretrained
375
+ after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
376
+ # For F.conv1d:
377
+ # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
378
+ # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
379
+ after_B = self.conv1d(
380
+ after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
381
+ self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
382
+ ).transpose(
383
+ -2, -1
384
+ ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
385
+ lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
386
+ return pretrained + lora
387
+
388
+
389
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
390
+ """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
391
+
392
+ Args:
393
+ model: model with LoRA layers
394
+ bias:
395
+ ``"none"``: all bias weights will be frozen,
396
+ ``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
397
+ ``"all"``: all bias weights will be unfrozen.
398
+
399
+ Raises:
400
+ NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
401
+ """
402
+ # freeze all layers except LoRA's
403
+ for n, p in model.named_parameters():
404
+ if "lora_" not in n:
405
+ p.requires_grad = False
406
+
407
+ # depending on the `bias` value unfreeze bias weights
408
+ if bias == "none":
409
+ return
410
+ if bias == "all":
411
+ for n, p in model.named_parameters():
412
+ if "bias" in n:
413
+ p.requires_grad = True
414
+ elif bias == "lora_only":
415
+ for m in model.modules():
416
+ if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
417
+ m.bias.requires_grad = True
418
+ else:
419
+ raise NotImplementedError
420
+
421
+
422
+ def lora_filter(key: str, value: Any) -> bool:
423
+ return "lora_" in key
424
+
425
+
426
+ @dataclass
427
+ class Config(BaseConfig):
428
+ """
429
+ Args:
430
+ r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
431
+ the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
432
+ alpha: alpha is needed for scaling updates as alpha/r
433
+ "This scaling helps to reduce the need to retune hyperparameters when we vary r"
434
+ https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
435
+ dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
436
+ to_*: either apply LoRA to the specified weights or not
437
+ """
438
+
439
+ r: int = 0
440
+ alpha: int = 1
441
+ dropout: float = 0.0
442
+ to_query: bool = False
443
+ to_key: bool = False
444
+ to_value: bool = False
445
+ to_projection: bool = False
446
+ to_mlp: bool = False
447
+ to_head: bool = False
448
+
449
+ @property
450
+ def mlp_class(self) -> Type:
451
+ return getattr(lit_gpt.lora, self._mlp_class)
452
+
453
+
454
+ class GPT(BaseModel):
455
+ def __init__(self, config: Config) -> None:
456
+ nn.Module.__init__(self)
457
+ assert config.padded_vocab_size is not None
458
+ self.config = config
459
+
460
+ self.lm_head = LoRALinear(
461
+ config.n_embd,
462
+ config.padded_vocab_size,
463
+ bias=config.lm_head_bias,
464
+ r=(config.r if config.to_head else 0),
465
+ lora_alpha=config.alpha,
466
+ lora_dropout=config.dropout,
467
+ )
468
+ self.transformer = nn.ModuleDict(
469
+ dict(
470
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
471
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
472
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
473
+ )
474
+ )
475
+ self.max_seq_length = self.config.block_size
476
+ self.mask_cache: Optional[torch.Tensor] = None
477
+
478
+ def forward(
479
+ self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
480
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
481
+ T = idx.size(1)
482
+ if self.max_seq_length < T:
483
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
484
+
485
+ if input_pos is not None: # use the kv cache
486
+ cos = self.cos.index_select(0, input_pos)
487
+ sin = self.sin.index_select(0, input_pos)
488
+ if self.mask_cache is None:
489
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
490
+ mask = self.mask_cache.index_select(2, input_pos)
491
+ else:
492
+ cos = self.cos[:T]
493
+ sin = self.sin[:T]
494
+ mask = None
495
+
496
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
497
+ for block in self.transformer.h:
498
+ x = block(x, cos, sin, mask, input_pos)
499
+ x = self.transformer.ln_f(x)
500
+ if lm_head_chunk_size > 0:
501
+ # chunk the lm head logits to reduce the peak memory used by autograd
502
+ return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
503
+ return self.lm_head(x) # (B, T, vocab_size)
504
+
505
+ @classmethod
506
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
507
+ return cls(Config.from_name(name, **kwargs))
508
+
509
+ def _init_weights(self, module: nn.Module) -> None:
510
+ """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
511
+ super()._init_weights(module)
512
+ if isinstance(module, LoRALinear):
513
+ module.reset_parameters()
514
+
515
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
516
+ """For compatibility with base checkpoints."""
517
+ mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
518
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
519
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
520
+
521
+
522
+ class Block(BaseBlock):
523
+ def __init__(self, config: Config) -> None:
524
+ nn.Module.__init__(self)
525
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
526
+ self.attn = CausalSelfAttention(config)
527
+ if not config.shared_attention_norm:
528
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
529
+ self.mlp = config.mlp_class(config)
530
+
531
+ self.config = config
532
+
533
+
534
+ class CausalSelfAttention(BaseCausalSelfAttention):
535
+ def __init__(self, config: Config) -> None:
536
+ # Skip the parent class __init__ altogether and replace it to avoid
537
+ # useless allocations
538
+ nn.Module.__init__(self)
539
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
540
+ # key, query, value projections for all heads, but in a batch
541
+ self.attn = LoRAQKVLinear(
542
+ in_features=config.n_embd,
543
+ out_features=shape,
544
+ r=config.r,
545
+ lora_alpha=config.alpha,
546
+ lora_dropout=config.dropout,
547
+ enable_lora=(config.to_query, config.to_key, config.to_value),
548
+ bias=config.bias,
549
+ # for MQA/GQA support
550
+ n_head=config.n_head,
551
+ n_query_groups=config.n_query_groups,
552
+ )
553
+ # output projection
554
+ self.proj = LoRALinear(
555
+ config.n_embd,
556
+ config.n_embd,
557
+ bias=config.bias,
558
+ r=(config.r if config.to_projection else 0),
559
+ lora_alpha=config.alpha,
560
+ lora_dropout=config.dropout,
561
+ )
562
+ # disabled by default
563
+ self.kv_cache: Optional[KVCache] = None
564
+
565
+ self.config = config
566
+
567
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
568
+ """For compatibility with base checkpoints."""
569
+ mapping = {
570
+ "attn.weight": "attn.linear.weight",
571
+ "attn.bias": "attn.linear.bias",
572
+ "proj.weight": "proj.linear.weight",
573
+ "proj.bias": "proj.linear.bias",
574
+ }
575
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
576
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
577
+
578
+
579
+ class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
580
+ def __init__(self, config: Config) -> None:
581
+ nn.Module.__init__(self)
582
+ self.fc = LoRALinear(
583
+ config.n_embd,
584
+ config.intermediate_size,
585
+ bias=config.bias,
586
+ r=(config.r if config.to_mlp else 0),
587
+ lora_alpha=config.alpha,
588
+ lora_dropout=config.dropout,
589
+ )
590
+ self.proj = LoRALinear(
591
+ config.intermediate_size,
592
+ config.n_embd,
593
+ bias=config.bias,
594
+ r=(config.r if config.to_mlp else 0),
595
+ lora_alpha=config.alpha,
596
+ lora_dropout=config.dropout,
597
+ )
598
+
599
+ self.config = config
600
+
601
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
602
+ """For compatibility with base checkpoints."""
603
+ mapping = {
604
+ "fc.weight": "fc.linear.weight",
605
+ "fc.bias": "fc.linear.bias",
606
+ "proj.weight": "proj.linear.weight",
607
+ "proj.bias": "proj.linear.bias",
608
+ }
609
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
610
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
611
+
612
+
613
+ class LLaMAMLP(lit_gpt.model.LLaMAMLP):
614
+ def __init__(self, config: Config) -> None:
615
+ nn.Module.__init__(self)
616
+ self.fc_1 = LoRALinear(
617
+ config.n_embd,
618
+ config.intermediate_size,
619
+ bias=config.bias,
620
+ r=(config.r if config.to_mlp else 0),
621
+ lora_alpha=config.alpha,
622
+ lora_dropout=config.dropout,
623
+ )
624
+ self.fc_2 = LoRALinear(
625
+ config.n_embd,
626
+ config.intermediate_size,
627
+ bias=config.bias,
628
+ r=(config.r if config.to_mlp else 0),
629
+ lora_alpha=config.alpha,
630
+ lora_dropout=config.dropout,
631
+ )
632
+ self.proj = LoRALinear(
633
+ config.intermediate_size,
634
+ config.n_embd,
635
+ bias=config.bias,
636
+ r=(config.r if config.to_mlp else 0),
637
+ lora_alpha=config.alpha,
638
+ lora_dropout=config.dropout,
639
+ )
640
+
641
+ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
642
+ """For compatibility with base checkpoints."""
643
+ mapping = {
644
+ "fc_1.weight": "fc_1.linear.weight",
645
+ "fc_1.bias": "fc_1.linear.bias",
646
+ "fc_2.weight": "fc_2.linear.weight",
647
+ "fc_2.bias": "fc_2.linear.bias",
648
+ "proj.weight": "proj.linear.weight",
649
+ "proj.bias": "proj.linear.bias",
650
+ }
651
+ state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
652
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
653
+
654
+
655
+ def merge_lora_weights(model: GPT) -> None:
656
+ """Merge LoRA weights into the full-rank weights to speed up inference."""
657
+ for module in model.modules():
658
+ if isinstance(module, LoRALinear):
659
+ module.merge()
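Taken together, these classes support a small fine-tuning flow: build a LoRA `Config`, freeze everything except the adapter weights with `mark_only_lora_as_trainable`, train, save only the adapter tensors via `lora_filter`, and call `merge_lora_weights` before inference. A minimal sketch (illustrative only; the model name and the LoRA hyperparameters are placeholders, not values used in this repository):

import torch
from lit_gpt.lora import GPT, Config, lora_filter, mark_only_lora_as_trainable, merge_lora_weights

# Placeholder hyperparameters; r/alpha/dropout and the targeted projections are illustrative.
config = Config.from_name(
    "tiny-llama-1.1b",
    r=8,
    alpha=16,
    dropout=0.05,
    to_query=True,
    to_value=True,
)
model = GPT(config)                  # weights are randomly initialized here; no checkpoint is loaded
mark_only_lora_as_trainable(model)   # freezes every parameter whose name does not contain "lora_"

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable:,}")

# ... the fine-tuning loop would go here ...

# Persist only the adapter weights, mirroring the `lora_filter` predicate above.
adapter_state = {k: v for k, v in model.state_dict().items() if lora_filter(k, v)}
torch.save(adapter_state, "lora_adapter.pth")

# Fold the adapters into the base weights so inference uses plain nn.Linear layers again.
merge_lora_weights(model)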
lit_gpt/model.py ADDED
@@ -0,0 +1,345 @@
1
+ """Full definition of a GPT NeoX Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
4
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
5
+ """
6
+ import math
7
+ from typing import Any, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing_extensions import Self
12
+
13
+ from lit_gpt.config import Config
14
+
15
+
16
+ class GPT(nn.Module):
17
+ def __init__(self, config: Config) -> None:
18
+ super().__init__()
19
+ assert config.padded_vocab_size is not None
20
+ self.config = config
21
+
22
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
23
+ self.transformer = nn.ModuleDict(
24
+ dict(
25
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
26
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
27
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
28
+ )
29
+ )
30
+ self.max_seq_length = self.config.block_size
31
+ self.mask_cache: Optional[torch.Tensor] = None
32
+
33
+ @property
34
+ def max_seq_length(self) -> int:
35
+ return self._max_seq_length
36
+
37
+ @max_seq_length.setter
38
+ def max_seq_length(self, value: int) -> None:
39
+ """
40
+ When doing inference, the sequences used might be shorter than the model's context length.
41
+ This allows setting a smaller number to avoid allocating unused memory
42
+ """
43
+ if value > self.config.block_size:
44
+ raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
45
+ self._max_seq_length = value
46
+ if not hasattr(self, "cos"):
47
+ # first call
48
+ cos, sin = self.rope_cache()
49
+ self.register_buffer("cos", cos, persistent=False)
50
+ self.register_buffer("sin", sin, persistent=False)
51
+ elif value != self.cos.size(0):
52
+ # override
53
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
54
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
55
+ # if the kv cache is expected
56
+
57
+ def reset_parameters(self) -> None:
58
+ # Trigger resetting the rope-cache
59
+ self.max_seq_length = self.config.block_size
60
+
61
+ def _init_weights(self, module: nn.Module) -> None:
62
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
63
+ if isinstance(module, nn.Linear):
64
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
65
+ if module.bias is not None:
66
+ torch.nn.init.zeros_(module.bias)
67
+ elif isinstance(module, nn.Embedding):
68
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
69
+
70
+ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
71
+ T = idx.size(1)
72
+ if self.max_seq_length < T:
73
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
74
+
75
+ if input_pos is not None: # use the kv cache
76
+ cos = self.cos.index_select(0, input_pos)
77
+ sin = self.sin.index_select(0, input_pos)
78
+ if self.mask_cache is None:
79
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
80
+ mask = self.mask_cache.index_select(2, input_pos)
81
+ else:
82
+ cos = self.cos[:T]
83
+ sin = self.sin[:T]
84
+ mask = None
85
+
86
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
87
+ for block in self.transformer.h:
88
+ x = block(x, cos, sin, mask, input_pos)
89
+ x = self.transformer.ln_f(x)
90
+ return self.lm_head(x) # (b, t, vocab_size)
91
+
92
+ @classmethod
93
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
94
+ return cls(Config.from_name(name, **kwargs))
95
+
96
+ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
97
+ return build_rope_cache(
98
+ seq_len=self.max_seq_length,
99
+ n_elem=self.config.rope_n_elem,
100
+ device=device,
101
+ condense_ratio=self.config.rope_condense_ratio,
102
+ base=self.config.rope_base,
103
+ )
104
+
105
+ def set_kv_cache(
106
+ self,
107
+ batch_size: int,
108
+ rope_cache_length: Optional[int] = None,
109
+ device: Optional[torch.device] = None,
110
+ dtype: Optional[torch.dtype] = None,
111
+ ) -> None:
112
+ if rope_cache_length is None:
113
+ rope_cache_length = self.cos.size(-1)
114
+ max_seq_length = self.max_seq_length
115
+
116
+ # initialize the kv cache for all blocks
117
+ for block in self.transformer.h:
118
+ block.attn.kv_cache = block.attn.build_kv_cache(
119
+ batch_size, max_seq_length, rope_cache_length, device, dtype
120
+ )
121
+
122
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
123
+ # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
124
+ # for the kv-cache support (only during inference), we only create it in that situation
125
+ # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
126
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
127
+ self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
128
+
129
+ def clear_kv_cache(self) -> None:
130
+ self.mask_cache = None
131
+ for block in self.transformer.h:
132
+ block.attn.kv_cache = None
133
+
134
+
135
+ class Block(nn.Module):
136
+ def __init__(self, config: Config) -> None:
137
+ super().__init__()
138
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
139
+ self.attn = CausalSelfAttention(config)
140
+ self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
141
+ self.mlp = config.mlp_class(config)
142
+
143
+ self.config = config
144
+
145
+ def forward(
146
+ self,
147
+ x: torch.Tensor,
148
+ cos: torch.Tensor,
149
+ sin: torch.Tensor,
150
+ mask: Optional[torch.Tensor] = None,
151
+ input_pos: Optional[torch.Tensor] = None,
152
+ ) -> torch.Tensor:
153
+ n_1 = self.norm_1(x)
154
+ h = self.attn(n_1, cos, sin, mask, input_pos)
155
+ if self.config.parallel_residual:
156
+ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
157
+ x = self.mlp(n_2) + h + x
158
+ else:
159
+ if self.config.shared_attention_norm:
160
+ raise NotImplementedError(
161
+ "No checkpoint amongst the ones we support uses this configuration"
162
+ " (non-parallel residual and shared attention norm)."
163
+ )
164
+ x = h + x
165
+ x = self.mlp(self.norm_2(x)) + x
166
+ return x
167
+
168
+
169
+ class CausalSelfAttention(nn.Module):
170
+ def __init__(self, config: Config) -> None:
171
+ super().__init__()
172
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
173
+ # key, query, value projections for all heads, but in a batch
174
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
175
+ # output projection
176
+ self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
177
+ # disabled by default
178
+ self.kv_cache: Optional[KVCache] = None
179
+
180
+ self.config = config
181
+
182
+ def forward(
183
+ self,
184
+ x: torch.Tensor,
185
+ cos: torch.Tensor,
186
+ sin: torch.Tensor,
187
+ mask: Optional[torch.Tensor] = None,
188
+ input_pos: Optional[torch.Tensor] = None,
189
+ ) -> torch.Tensor:
190
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
191
+
192
+ qkv = self.attn(x)
193
+
194
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
195
+ q_per_kv = self.config.n_head // self.config.n_query_groups
196
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
197
+ qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
198
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
199
+
200
+ # split batched computation into three
201
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
202
+
203
+ # maybe repeat k and v for the non-multi-head-attention cases (MQA/GQA)
204
+ # training: flash attention requires it
205
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
206
+ if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
207
+ k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
208
+ v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
209
+
210
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
211
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
212
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
213
+
214
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
215
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
216
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
217
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
218
+
219
+ if input_pos is not None:
220
+ if not isinstance(self.kv_cache, KVCache):
221
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
222
+ k, v = self.kv_cache(input_pos, k, v)
223
+
224
+ y = self.scaled_dot_product_attention(q, k, v, mask)
225
+
226
+ y = y.reshape(B, T, C) # re-assemble all head outputs side by side
227
+
228
+ # output projection
229
+ return self.proj(y)
230
+
231
+ def scaled_dot_product_attention(
232
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
233
+ ) -> torch.Tensor:
234
+ scale = 1.0 / math.sqrt(self.config.head_size)
235
+ y = torch.nn.functional.scaled_dot_product_attention(
236
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
237
+ )
238
+ return y.transpose(1, 2)
239
+
240
+ def build_kv_cache(
241
+ self,
242
+ batch_size: int,
243
+ max_seq_length: int,
244
+ rope_cache_length: Optional[int] = None,
245
+ device: Optional[torch.device] = None,
246
+ dtype: Optional[torch.dtype] = None,
247
+ ) -> "KVCache":
248
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
249
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
250
+ if rope_cache_length is None:
251
+ if self.config.rotary_percentage != 1.0:
252
+ raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
253
+ k_shape = v_shape
254
+ else:
255
+ k_shape = (
256
+ batch_size,
257
+ heads,
258
+ max_seq_length,
259
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
260
+ )
261
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
262
+
263
+
264
+ class GptNeoxMLP(nn.Module):
265
+ def __init__(self, config: Config) -> None:
266
+ super().__init__()
267
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
268
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
269
+
270
+ self.config = config
271
+
272
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
273
+ x = self.fc(x)
274
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
275
+ return self.proj(x)
276
+
277
+
278
+ class LLaMAMLP(nn.Module):
279
+ def __init__(self, config: Config) -> None:
280
+ super().__init__()
281
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
282
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
283
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
284
+
285
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
286
+ x_fc_1 = self.fc_1(x)
287
+ x_fc_2 = self.fc_2(x)
288
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
289
+ return self.proj(x)
290
+
291
+
292
+ def build_rope_cache(
293
+ seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
294
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
295
+ """Enhanced Transformer with Rotary Position Embedding.
296
+
297
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
298
+ transformers/rope/__init__.py. MIT License:
299
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
300
+ """
301
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
302
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
303
+
304
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
305
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
306
+
307
+ # Calculate the product of position index and $\theta_i$
308
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
309
+
310
+ return torch.cos(idx_theta), torch.sin(idx_theta)
311
+
312
+
313
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
314
+ head_size = x.size(-1)
315
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
316
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
317
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
318
+ roped = (x * cos) + (rotated * sin)
319
+ return roped.type_as(x)
320
+
321
+
322
+ class KVCache(nn.Module):
323
+ def __init__(
324
+ self,
325
+ k_shape: Tuple[int, int, int, int],
326
+ v_shape: Tuple[int, int, int, int],
327
+ device: Optional[torch.device] = None,
328
+ dtype: Optional[torch.dtype] = None,
329
+ ) -> None:
330
+ super().__init__()
331
+ self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
332
+ self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
333
+
334
+ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
335
+ # move the buffer to the activation dtype for when AMP is used
336
+ self.k = self.k.to(k.dtype)
337
+ self.v = self.v.to(v.dtype)
338
+ # update the cache
339
+ k = self.k.index_copy_(2, input_pos, k)
340
+ v = self.v.index_copy_(2, input_pos, v)
341
+ return k, v
342
+
343
+ def reset_parameters(self) -> None:
344
+ torch.nn.init.zeros_(self.k)
345
+ torch.nn.init.zeros_(self.v)
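
The query-group reshape in the attention forward pass above is easiest to follow with concrete shapes. The sketch below is a minimal, standalone illustration and not part of the repository; the head counts and sizes are made-up values, and only plain torch is used.

    import torch

    # made-up sizes for illustration: 8 query heads sharing 2 key/value groups
    B, T, n_head, n_query_groups, head_size = 1, 4, 8, 2, 16
    q_per_kv = n_head // n_query_groups      # 4 queries per key/value pair (GQA)
    total_qkv = q_per_kv + 2                 # per group: q_per_kv queries + 1 key + 1 value

    # stand-in for the fused `self.attn(x)` projection output
    qkv = torch.randn(B, T, n_query_groups * total_qkv * head_size)
    qkv = qkv.view(B, T, n_query_groups, total_qkv, head_size).permute(0, 2, 3, 1, 4)
    q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
    print(q.shape)  # torch.Size([1, 2, 4, 4, 16]) -> (B, n_query_groups, q_per_kv, T, hs)
    print(k.shape)  # torch.Size([1, 2, 1, 4, 16]) -> one key per group (same for v)

With n_query_groups equal to n_head this grouping reduces to standard multi-head attention, and with n_query_groups equal to 1 it becomes multi-query attention, which is what the comment about supporting MHA, MQA and GQA together refers to.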
lit_gpt/packed_dataset.py ADDED
@@ -0,0 +1,237 @@
1
+ # Very loosely inspired by indexed_dataset in Fairseq, Megatron
2
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
3
+
4
+
5
+ import os
6
+ import random
7
+ import struct
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import IterableDataset, get_worker_info
12
+
13
+ dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}
14
+
15
+
16
+ def code(dtype):
17
+ for k in dtypes:
18
+ if dtypes[k] == dtype:
19
+ return k
20
+ raise ValueError(dtype)
21
+
22
+
23
+ HDR_MAGIC = b"LITPKDS"
24
+ HDR_SIZE = 24 # bytes
25
+
26
+
27
+ class PackedDataset(IterableDataset):
28
+ def __init__(
29
+ self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
30
+ ):
31
+ self._filenames = filenames
32
+ self._n_chunks = n_chunks
33
+ self._block_size = block_size
34
+ self._seed = seed
35
+ self._shuffle = shuffle
36
+ self._wrap = wrap
37
+ self._num_processes = num_processes
38
+ self._process_rank = process_rank
39
+
40
+ def __iter__(self):
41
+ worker_info = get_worker_info()
42
+ num_workers = worker_info.num_workers if worker_info is not None else 1
43
+ worker_id = worker_info.id if worker_info is not None else 0
44
+ num_shards = num_workers * self._num_processes
45
+ shard_id = self._process_rank * num_workers + worker_id
46
+
47
+ max_num_files = len(self._filenames) // num_shards * num_shards
48
+ filenames = self._filenames[shard_id:max_num_files:num_shards]
49
+
50
+ return PackedDatasetIterator(
51
+ filenames=filenames,
52
+ n_chunks=self._n_chunks,
53
+ block_size=self._block_size,
54
+ seed=self._seed,
55
+ shuffle=self._shuffle,
56
+ wrap=self._wrap,
57
+ )
58
+
59
+
60
+ class PackedDatasetBuilder(object):
61
+ def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
62
+ if dtype == "auto":
63
+ if vocab_size is None:
64
+ raise ValueError("vocab_size cannot be None when dtype='auto'")
65
+ if vocab_size is not None and vocab_size < 65500:
66
+ self._dtype = np.uint16
67
+ else:
68
+ self._dtype = np.int32
69
+ else:
70
+ self._dtype = dtype
71
+ self._counter = 0
72
+ self._chunk_size = chunk_size
73
+ self._outdir = outdir
74
+ self._prefix = prefix
75
+ self._sep_token = sep_token
76
+ self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
77
+ self._arr.fill(self._sep_token)
78
+ self._idx = 0
79
+ self._version = 1
80
+ self._filenames = []
81
+
82
+ def _write_chunk(self):
83
+ filename = f"{self._prefix}_{self._counter:010d}.bin"
84
+ filename = os.path.join(self._outdir, filename)
85
+
86
+ with open(filename, "wb") as f:
87
+ f.write(HDR_MAGIC)
88
+ f.write(struct.pack("<Q", self._version))
89
+ f.write(struct.pack("<B", code(self._dtype)))
90
+ f.write(struct.pack("<Q", self._chunk_size))
91
+ f.write(self._arr.tobytes(order="C"))
92
+
93
+ self._filenames.append(filename)
94
+ self._counter += 1
95
+ self._arr.fill(self._sep_token)
96
+ self._idx = 0
97
+
98
+ @property
99
+ def dtype(self):
100
+ return self._dtype
101
+
102
+ @property
103
+ def filenames(self):
104
+ return self._filenames.copy()
105
+
106
+ def add_array(self, arr):
107
+ while self._idx + arr.shape[0] > self._chunk_size:
108
+ part_len = self._chunk_size - self._idx
109
+ self._arr[self._idx : self._idx + part_len] = arr[:part_len]
110
+ self._write_chunk()
111
+ arr = arr[part_len:]
112
+
113
+ arr_len = arr.shape[0]
114
+ self._arr[self._idx : self._idx + arr_len] = arr
115
+ self._idx += arr_len
116
+
117
+ def write_reminder(self):
118
+ self._write_chunk()
119
+
120
+
121
+ class PackedDatasetIterator:
122
+ def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
123
+ self._seed = seed
124
+ self._shuffle = shuffle
125
+ self._rng = np.random.default_rng(seed) if shuffle else None
126
+ self._block_idxs = None
127
+
128
+ self._wrap = wrap
129
+
130
+ # TODO: instead of filenames, we could have a single text stream
131
+ # (or text file) with the sequence of all files to be
132
+ # fetched/loaded.
133
+ self._filenames = filenames
134
+ self._file_idx = 0
135
+
136
+ self._n_chunks = n_chunks
137
+
138
+ self._dtype = None
139
+ self._block_size = block_size
140
+ self._n_blocks = None
141
+
142
+ self._mmaps = []
143
+ self._buffers = []
144
+
145
+ self._block_idxs = []
146
+ self._curr_idx = 0
147
+
148
+ self._load_n_chunks()
149
+
150
+ def _read_header(self, path):
151
+ with open(path, "rb") as f:
152
+ magic = f.read(len(HDR_MAGIC))
153
+ assert magic == HDR_MAGIC, "File doesn't match expected format."
154
+ version = struct.unpack("<Q", f.read(8))
155
+ assert version == (1,)
156
+ (dtype_code,) = struct.unpack("<B", f.read(1))
157
+ dtype = dtypes[dtype_code]
158
+ (chunk_size,) = struct.unpack("<Q", f.read(8))
159
+ return dtype, chunk_size
160
+
161
+ def _close_mmaps(self):
162
+ for mmap in self._mmaps:
163
+ mmap._mmap.close()
164
+
165
+ def _load_n_chunks(self):
166
+ self._close_mmaps()
167
+ self._mmaps = []
168
+ self._buffers = []
169
+
170
+ if self._n_chunks > len(self._filenames[self._file_idx :]):
171
+ if not self._wrap:
172
+ raise StopIteration
173
+ self._file_idx = 0
174
+
175
+ for i in range(self._n_chunks):
176
+ filename = self._filenames[self._file_idx + i]
177
+ if self._dtype is None:
178
+ self._dtype, self._chunk_size = self._read_header(filename)
179
+ self._n_blocks = self._chunk_size // self._block_size
180
+ # TODO: check header matches with previous files
181
+ mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
182
+ self._mmaps.append(mmap)
183
+ self._buffers.append(memoryview(mmap))
184
+
185
+ self._file_idx += self._n_chunks
186
+ n_all_blocks = self._n_chunks * self._n_blocks
187
+
188
+ self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
189
+
190
+ self._curr_idx = 0
191
+
192
+ def __del__(self):
193
+ self._close_mmaps()
194
+ del self._mmaps
195
+ del self._buffers
196
+
197
+ def __iter__(self):
198
+ return self
199
+
200
+ def __next__(self):
201
+ if self._curr_idx >= len(self._block_idxs):
202
+ self._load_n_chunks()
203
+ # TODO: trigger fetching next next n_chunks if remote
204
+ block_idx = self._block_idxs[self._curr_idx]
205
+ chunk_id = block_idx // self._n_blocks
206
+ buffer = self._buffers[chunk_id]
207
+ elem_id = (block_idx % self._n_blocks) * self._block_size
208
+ offset = np.dtype(self._dtype).itemsize * elem_id
209
+ arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
210
+ self._curr_idx += 1
211
+ return torch.from_numpy(arr.astype(np.int64))
212
+
213
+
214
+ class CombinedDataset(IterableDataset):
215
+ def __init__(self, datasets, seed, weights=None):
216
+ self._seed = seed
217
+ self._datasets = datasets
218
+ self._weights = weights
219
+ n_datasets = len(datasets)
220
+ if weights is None:
221
+ self._weights = [1 / n_datasets] * n_datasets
222
+ else:
223
+ self._weights = [w / sum(weights) for w in weights]
224
+
225
+ def __iter__(self):
226
+ return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
227
+
228
+
229
+ class CombinedDatasetIterator:
230
+ def __init__(self, datasets, seed, weights):
231
+ self._datasets = [iter(el) for el in datasets]
232
+ self._weights = weights
233
+ self._rng = random.Random(seed)
234
+
235
+ def __next__(self):
236
+ (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
237
+ return next(dataset)
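
The binary chunk layout that PackedDatasetBuilder writes and PackedDatasetIterator reads back (HDR_MAGIC, version, dtype code, chunk size, then the raw token array) can be exercised in isolation. The snippet below is a rough sketch under the same header constants; the file name and token values are invented for illustration.

    import struct
    import numpy as np

    HDR_MAGIC = b"LITPKDS"  # same 7-byte magic as above; the full header is 24 bytes

    # write one chunk the way `PackedDatasetBuilder._write_chunk` lays it out
    tokens = np.array([5, 6, 7, 8], dtype=np.uint16)      # toy token ids
    with open("example_0000000000.bin", "wb") as f:       # hypothetical file name
        f.write(HDR_MAGIC)
        f.write(struct.pack("<Q", 1))                     # version
        f.write(struct.pack("<B", 8))                     # dtype code: 8 -> np.uint16
        f.write(struct.pack("<Q", tokens.size))           # chunk size in elements
        f.write(tokens.tobytes(order="C"))

    # read the header back the way `PackedDatasetIterator._read_header` does
    with open("example_0000000000.bin", "rb") as f:
        assert f.read(len(HDR_MAGIC)) == HDR_MAGIC
        (version,) = struct.unpack("<Q", f.read(8))
        (dtype_code,) = struct.unpack("<B", f.read(1))
        (chunk_size,) = struct.unpack("<Q", f.read(8))
    print(version, dtype_code, chunk_size)  # 1 8 4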
lit_gpt/rmsnorm.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+
3
+
4
+ class RMSNorm(torch.nn.Module):
5
+ """Root Mean Square Layer Normalization.
6
+
7
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
8
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
9
+ """
10
+
11
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
12
+ super().__init__()
13
+ self.weight = torch.nn.Parameter(torch.ones(size))
14
+ self.eps = eps
15
+ self.dim = dim
16
+
17
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
+ dtype = x.dtype
19
+ x = x.float()
20
+ # NOTE: the original RMSNorm paper implementation is not equivalent
21
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
22
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
23
+ return (self.weight * x_normed).to(dtype=dtype)
24
+
25
+ def reset_parameters(self) -> None:
26
+ torch.nn.init.ones_(self.weight)
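
As a quick sanity check, the forward pass above is the standard RMSNorm formula x / sqrt(mean(x^2) + eps) scaled by a learned weight. The snippet below is a minimal standalone comparison with arbitrary input values, not code from the repository.

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 3, 8)                  # (batch, time, features), arbitrary values
    eps = 1e-5
    weight = torch.ones(8)

    # textbook form: x / sqrt(mean(x^2) + eps), then scale
    manual = weight * (x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps))

    # the rsqrt form used in `RMSNorm.forward` above
    norm_x = torch.mean(x * x, dim=-1, keepdim=True)
    module_style = weight * (x * torch.rsqrt(norm_x + eps))

    print(torch.allclose(manual, module_style))  # True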
lit_gpt/tokenizer.py ADDED
@@ -0,0 +1,107 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+
7
+
8
+ class Tokenizer:
9
+ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
10
+ checkpoint_dir = Path(checkpoint_dir)
11
+ if not checkpoint_dir.exists():
12
+ raise NotADirectoryError(f"The checkpoint directory does not exist: {str(checkpoint_dir)}")
13
+
14
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
15
+ self.bos_id = None
16
+ self.eos_id = None
17
+
18
+ # some checkpoints have both files, `.model` takes precedence
19
+ if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
20
+ from sentencepiece import SentencePieceProcessor
21
+
22
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
23
+ self.backend = "sentencepiece"
24
+ self.bos_id = self.processor.bos_id()
25
+ self.eos_id = self.processor.eos_id()
26
+
27
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
28
+ from tokenizers import Tokenizer as HFTokenizer
29
+
30
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
31
+ self.backend = "huggingface"
32
+
33
+ if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
34
+ with open(special_tokens_path) as fp:
35
+ config = json.load(fp)
36
+ bos_token = config.get("bos_token")
37
+ self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
38
+ eos_token = config.get("eos_token")
39
+ self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
40
+ if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
41
+ with open(special_tokens_path) as fp:
42
+ config = json.load(fp)
43
+ if self.bos_id is None:
44
+ self.bos_id = config.get("bos_token_id")
45
+ if self.eos_id is None:
46
+ self.eos_id = config.get("eos_token_id")
47
+ else:
48
+ raise NotImplementedError
49
+
50
+ @property
51
+ def vocab_size(self) -> int:
52
+ if self.backend == "huggingface":
53
+ return self.processor.get_vocab_size(with_added_tokens=False)
54
+ if self.backend == "sentencepiece":
55
+ return self.processor.vocab_size()
56
+ raise RuntimeError
57
+
58
+ def token_to_id(self, token: str) -> int:
59
+ if self.backend == "huggingface":
60
+ id_ = self.processor.token_to_id(token)
61
+ elif self.backend == "sentencepiece":
62
+ id_ = self.processor.piece_to_id(token)
63
+ else:
64
+ raise RuntimeError
65
+ if id_ is None:
66
+ raise ValueError(f"token {token!r} not found in the collection.")
67
+ return id_
68
+
69
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
70
+ if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
71
+ return False
72
+ with open(tokenizer_config_path) as fp:
73
+ config = json.load(fp)
74
+ if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")):
75
+ return True
76
+ # for checkpoints that also use the Llama tokenizer but do not have add_bos_token, or do not set it to True.
77
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
78
+ return config.get("add_bos_token") is None and config.get("tokenizer_class") == "LlamaTokenizer"
79
+
80
+ def encode(
81
+ self,
82
+ string: str,
83
+ device: Optional[torch.device] = None,
84
+ bos: Optional[bool] = None,
85
+ eos: bool = False,
86
+ max_length: int = -1,
87
+ ) -> torch.Tensor:
88
+ if self.backend == "huggingface":
89
+ tokens = self.processor.encode(string).ids
90
+ elif self.backend == "sentencepiece":
91
+ tokens = self.processor.encode(string)
92
+ else:
93
+ raise RuntimeError
94
+ if bos or (bos is None and self.use_bos):
95
+ bos_id = self.bos_id
96
+ if bos_id is None:
97
+ raise NotImplementedError("This tokenizer does not have a defined a bos token")
98
+ tokens = [bos_id] + tokens
99
+ if eos:
100
+ tokens = tokens + [self.eos_id]
101
+ if max_length > 0:
102
+ tokens = tokens[:max_length]
103
+ return torch.tensor(tokens, dtype=torch.int, device=device)
104
+
105
+ def decode(self, tensor: torch.Tensor) -> str:
106
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
107
+ return self.processor.decode(tokens)
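
Typical usage of the Tokenizer class, assuming the lit_gpt package from this repository is importable and that the checkpoint directory (a hypothetical path below) already contains a tokenizer.json or tokenizer.model plus tokenizer_config.json, with the matching backend (tokenizers or sentencepiece) installed:

    from pathlib import Path

    from lit_gpt.tokenizer import Tokenizer

    # hypothetical, already-downloaded checkpoint directory
    tokenizer = Tokenizer(Path("checkpoints/EleutherAI/pythia-160m"))

    ids = tokenizer.encode("hello world")    # 1D tensor of token ids (torch.int)
    print(ids.shape, tokenizer.decode(ids))  # decodes back to (approximately) the input text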
lit_gpt/utils.py ADDED
@@ -0,0 +1,351 @@
1
+ """Utility functions for training and inference."""
2
+ import math
3
+ import pickle
4
+ import sys
5
+ from contextlib import nullcontext
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union
9
+
10
+ import lightning as L
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.utils._device
14
+ from lightning.fabric.strategies import FSDPStrategy
15
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
16
+ from torch.serialization import normalize_storage_type
17
+
18
+ if TYPE_CHECKING:
19
+ from lit_gpt import GPT
20
+
21
+
22
+ def find_multiple(n: int, k: int) -> int:
23
+ assert k > 0
24
+ if n % k == 0:
25
+ return n
26
+ return n + k - (n % k)
27
+
28
+
29
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
30
+ total = 0
31
+ for p in module.parameters():
32
+ if requires_grad is None or p.requires_grad == requires_grad:
33
+ if hasattr(p, "quant_state"):
34
+ # bitsandbytes 4bit layer support
35
+ total += math.prod(p.quant_state[1])
36
+ else:
37
+ total += p.numel()
38
+ return total
39
+
40
+
41
+ def gptq_quantization(enabled: bool = False) -> ContextManager:
42
+ if not enabled:
43
+ return nullcontext()
44
+
45
+ from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager
46
+
47
+ from quantize.gptq import ColBlockQuantizedLinear
48
+
49
+ class QuantizedLinear(ColBlockQuantizedLinear):
50
+ def __init__(self, *args, **kwargs):
51
+ super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
52
+
53
+ return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
54
+
55
+
56
+ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
57
+ files = {
58
+ "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
59
+ "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
60
+ "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
61
+ checkpoint_dir / "tokenizer.model"
62
+ ).is_file(),
63
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
64
+ }
65
+ if checkpoint_dir.is_dir():
66
+ if all(files.values()):
67
+ # we're good
68
+ return
69
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
70
+ else:
71
+ problem = " is not a checkpoint directory"
72
+
73
+ # list locally available checkpoints
74
+ available = list(Path("checkpoints").glob("*/*"))
75
+ if available:
76
+ options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
77
+ extra = f"\nYou have downloaded locally:{options}\n"
78
+ else:
79
+ extra = ""
80
+
81
+ error_message = (
82
+ f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
83
+ "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
84
+ f"{extra}\nSee all download options by running:\n python scripts/download.py"
85
+ )
86
+ print(error_message, file=sys.stderr)
87
+ raise SystemExit(1)
88
+
89
+
90
+ class SavingProxyForStorage:
91
+ def __init__(self, obj, saver, protocol_version=5):
92
+ self.protocol_version = protocol_version
93
+ self.saver = saver
94
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
95
+ raise TypeError(f"expected storage, not {type(obj)}")
96
+
97
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
98
+ if isinstance(obj, torch.storage.TypedStorage):
99
+ # PT upstream wants to deprecate this eventually...
100
+ storage = obj._untyped_storage
101
+ storage_type_str = obj._pickle_storage_type()
102
+ storage_type = getattr(torch, storage_type_str)
103
+ storage_numel = obj._size()
104
+ else:
105
+ storage = obj
106
+ storage_type = normalize_storage_type(type(obj))
107
+ storage_numel = storage.nbytes()
108
+
109
+ storage_key = saver._write_storage_and_return_key(storage)
110
+ location = torch.serialization.location_tag(storage)
111
+
112
+ self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
113
+
114
+ def __reduce_ex__(self, protocol_version):
115
+ assert False, "this should be handled with out of band"
116
+
117
+
118
+ class SavingProxyForTensor:
119
+ def __init__(self, tensor, saver, protocol_version=5):
120
+ self.protocol_version = protocol_version
121
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
122
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
123
+ # for Tensors with Python attributes
124
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
125
+ assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
126
+ storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
127
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
128
+ else:
129
+ (storage, *other_reduce_args) = reduce_args
130
+ assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
131
+ storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
132
+ self.reduce_args = (storage_proxy, *other_reduce_args)
133
+
134
+ def __reduce_ex__(self, protocol_version):
135
+ if protocol_version != self.protocol_version:
136
+ raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
137
+ return self.reduce_ret_fn, self.reduce_args
138
+
139
+
140
+ class IncrementalPyTorchPickler(pickle.Pickler):
141
+ def __init__(self, saver, *args, **kwargs):
142
+ super().__init__(*args, **kwargs)
143
+ self.storage_dtypes = {}
144
+ self.saver = saver
145
+ self.id_map = {}
146
+
147
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
148
+ def persistent_id(self, obj):
149
+ # FIXME: the docs say that persistent_id should only return a string
150
+ # but torch store returns tuples. This works only in the binary protocol
151
+ # see
152
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
153
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
154
+ if isinstance(obj, SavingProxyForStorage):
155
+ return obj.storage_info
156
+
157
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
158
+ if isinstance(obj, torch.storage.TypedStorage):
159
+ # TODO: Once we decide to break serialization FC, this case
160
+ # can be deleted
161
+ storage = obj._untyped_storage
162
+ storage_dtype = obj.dtype
163
+ storage_type_str = obj._pickle_storage_type()
164
+ storage_type = getattr(torch, storage_type_str)
165
+ storage_numel = obj._size()
166
+
167
+ else:
168
+ storage = obj
169
+ storage_dtype = torch.uint8
170
+ storage_type = normalize_storage_type(type(obj))
171
+ storage_numel = storage.nbytes()
172
+
173
+ # If storage is allocated, ensure that any other saved storages
174
+ # pointing to the same data all have the same dtype. If storage is
175
+ # not allocated, don't perform this check
176
+ if storage.data_ptr() != 0:
177
+ if storage.data_ptr() in self.storage_dtypes:
178
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
179
+ raise RuntimeError(
180
+ "Cannot save multiple tensors or storages that view the same data as different types"
181
+ )
182
+ else:
183
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
184
+
185
+ storage_key = self.id_map.get(storage._cdata)
186
+ if storage_key is None:
187
+ storage_key = self.saver._write_storage_and_return_key(storage)
188
+ self.id_map[storage._cdata] = storage_key
189
+ location = torch.serialization.location_tag(storage)
190
+
191
+ return ("storage", storage_type, storage_key, location, storage_numel)
192
+
193
+ return None
194
+
195
+
196
+ class incremental_save:
197
+ def __init__(self, name):
198
+ self.name = name
199
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
200
+ self.has_saved = False
201
+ self.next_key = 0
202
+
203
+ def __enter__(self):
204
+ return self
205
+
206
+ def store_early(self, tensor):
207
+ if isinstance(tensor, torch.Tensor):
208
+ return SavingProxyForTensor(tensor, self)
209
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
210
+
211
+ def save(self, obj):
212
+ if self.has_saved:
213
+ raise RuntimeError("have already saved")
214
+ # Write the pickle data for `obj`
215
+ data_buf = BytesIO()
216
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
217
+ pickler.dump(obj)
218
+ data_value = data_buf.getvalue()
219
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
220
+ self.has_saved = True
221
+
222
+ def _write_storage_and_return_key(self, storage):
223
+ if self.has_saved:
224
+ raise RuntimeError("have already saved")
225
+ key = self.next_key
226
+ self.next_key += 1
227
+ name = f"data/{key}"
228
+ if storage.device.type != "cpu":
229
+ storage = storage.cpu()
230
+ num_bytes = storage.nbytes()
231
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
232
+ return key
233
+
234
+ def __exit__(self, type, value, traceback):
235
+ self.zipfile.write_end_of_file()
236
+
237
+
238
+ T = TypeVar("T")
239
+
240
+
241
+ def chunked_cross_entropy(
242
+ logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
243
+ ) -> torch.Tensor:
244
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
245
+ # the memory usage in fine-tuning settings with low number of parameters.
246
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
247
+ # the memory spike's magnitude
248
+
249
+ # lm_head was chunked (we are fine-tuning)
250
+ if isinstance(logits, list):
251
+ # don't want to chunk cross entropy
252
+ if chunk_size == 0:
253
+ logits = torch.cat(logits, dim=1)
254
+ logits = logits.reshape(-1, logits.size(-1))
255
+ targets = targets.reshape(-1)
256
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
257
+
258
+ # chunk cross entropy
259
+ logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
260
+ target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
261
+ loss_chunks = [
262
+ torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
263
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
264
+ ]
265
+ non_masked_elems = (targets != -1).sum()
266
+ mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
267
+ return mean_loss
268
+
269
+ # no chunking at all
270
+ logits = logits.reshape(-1, logits.size(-1))
271
+ targets = targets.reshape(-1)
272
+ if chunk_size == 0:
273
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
274
+
275
+ # lm_head wasn't chunked, chunk cross entropy
276
+ logit_chunks = logits.split(chunk_size)
277
+ target_chunks = targets.split(chunk_size)
278
+ loss_chunks = [
279
+ torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
280
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
281
+ ]
282
+ non_masked_elems = (targets != -1).sum()
283
+ mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
284
+ return mean_loss
285
+
286
+
287
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
288
+ for checkpoint_name, attribute_name in mapping.items():
289
+ full_checkpoint_name = prefix + checkpoint_name
290
+ if full_checkpoint_name in state_dict:
291
+ full_attribute_name = prefix + attribute_name
292
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
293
+ return state_dict
294
+
295
+
296
+ def get_default_supported_precision(training: bool) -> str:
297
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
298
+
299
+ Args:
300
+ training: `-mixed` or `-true` version of the precision to use
301
+
302
+ Returns:
303
+ default precision that is suitable for the task and is supported by the hardware
304
+ """
305
+ from lightning.fabric.accelerators import MPSAccelerator
306
+
307
+ if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
308
+ return "16-mixed" if training else "16-true"
309
+ return "bf16-mixed" if training else "bf16-true"
310
+
311
+
312
+ def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
313
+ if isinstance(fabric.strategy, FSDPStrategy):
314
+ fabric.load_raw(checkpoint_path, model, strict=strict)
315
+ else:
316
+ state_dict = lazy_load(checkpoint_path)
317
+ state_dict = state_dict.get("model", state_dict)
318
+ model.load_state_dict(state_dict, strict=strict)
319
+
320
+
321
+ def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
322
+ flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
323
+ # this assumes that all samples have a fixed length equal to the block size
324
+ # which is most likely false during finetuning
325
+ flops_per_seq = flops_per_token * max_seq_length
326
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
327
+ return flops_per_seq + attn_flops_per_seq
328
+
329
+
330
+ def estimate_flops(model: "GPT", training: bool) -> int:
331
+ """Measures estimated FLOPs for MFU.
332
+
333
+ Refs:
334
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
335
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
336
+ """
337
+ # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
338
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
339
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
340
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
341
+ n_trainable_params = num_parameters(model, requires_grad=True)
342
+ trainable_flops = flops_per_param(
343
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
344
+ )
345
+ # forward + backward + gradients (assumes no gradient accumulation)
346
+ ops_per_step = 3 if training else 1
347
+ n_frozen_params = num_parameters(model, requires_grad=False)
348
+ frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
349
+ # forward + backward
350
+ frozen_ops_per_step = 2 if training else 1
351
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
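
chunked_cross_entropy above is meant to be numerically equivalent to a single un-chunked cross entropy over the flattened logits; only the peak memory differs. The sketch below checks that equivalence on toy tensors (the shapes and the chunk size of 4 are arbitrary choices for illustration, not values from the repository).

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(2, 16, 100)            # (batch, time, vocab), toy sizes
    targets = torch.randint(0, 100, (2, 16))
    targets[0, :3] = -1                         # a few ignored positions

    plain = F.cross_entropy(logits.reshape(-1, 100), targets.reshape(-1), ignore_index=-1)

    # the same computation chunked by hand, mirroring the chunk_size > 0 branch above
    logit_chunks = logits.reshape(-1, 100).split(4)
    target_chunks = targets.reshape(-1).split(4)
    loss_chunks = [
        F.cross_entropy(lc, tc, ignore_index=-1, reduction="none")
        for lc, tc in zip(logit_chunks, target_chunks)
    ]
    chunked = torch.cat(loss_chunks).sum() / max(1, (targets != -1).sum())

    print(torch.allclose(plain, chunked))       # True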
main.ipynb ADDED
@@ -0,0 +1,714 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "True"
12
+ ]
13
+ },
14
+ "execution_count": 1,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "import torch\n",
21
+ "\n",
22
+ "torch.cuda.is_available()"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "import glob\n",
32
+ "import math\n",
33
+ "import sys\n",
34
+ "import time\n",
35
+ "from pathlib import Path\n",
36
+ "from typing import Optional, Tuple, Union\n",
37
+ "\n",
38
+ "import lightning as L\n",
39
+ "import torch\n",
40
+ "from lightning.fabric.loggers import CSVLogger\n",
41
+ "from lightning.fabric.strategies import FSDPStrategy\n",
42
+ "from torch.utils.data import DataLoader\n",
43
+ "\n",
44
+ "# # support running without installing as a package\n",
45
+ "# wd = Path(__file__).parent.parent.resolve()\n",
46
+ "# sys.path.append(str(wd))\n",
47
+ "\n",
48
+ "from tsai_gpt.model import GPT, Block, Config\n",
49
+ "from tsai_gpt.packed_dataset import CombinedDataset, PackedDataset\n",
50
+ "from tsai_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops\n",
51
+ "from tsai_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor\n",
52
+ "from tsai_gpt.utils import (\n",
53
+ " chunked_cross_entropy,\n",
54
+ " get_default_supported_precision,\n",
55
+ " num_parameters,\n",
56
+ " load_checkpoint,\n",
57
+ ")"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "model_name = \"pythia-160m\"\n",
67
+ "name = \"redpajama\"\n",
68
+ "out_dir = Path(\"out\") / name\n",
69
+ "save_interval = 1000\n",
70
+ "eval_interval = 1000\n",
71
+ "eval_iters = 100\n",
72
+ "log_interval = 100"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 4,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "# Hyperparameters\n",
82
+ "learning_rate = 6e-3\n",
83
+ "batch_size = 32\n",
84
+ "micro_batch_size = 8\n",
85
+ "gradient_accumulation_steps = batch_size // micro_batch_size\n",
86
+ "assert gradient_accumulation_steps > 0\n",
87
+ "# max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices\n",
88
+ "max_iters = 15000\n",
89
+ "weight_decay = 1e-1\n",
90
+ "beta1 = 0.9\n",
91
+ "beta2 = 0.95\n",
92
+ "grad_clip = 1.0\n",
93
+ "decay_lr = True\n",
94
+ "warmup_iters = 2000\n",
95
+ "lr_decay_iters = max_iters\n",
96
+ "min_lr = 6e-6"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 5,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "# Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1\n",
106
+ "data_config = [\n",
107
+ " (\"arxiv\", 2.5),\n",
108
+ " (\"book\", 4.5),\n",
109
+ " (\"c4\", 15.0),\n",
110
+ " (\"cc\", 67.0),\n",
111
+ " (\"github\", 4.5),\n",
112
+ " (\"stackexchange\", 2.0),\n",
113
+ " (\"wikipedia\", 4.5),\n",
114
+ "]"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 6,
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "hparams = {\n",
124
+ " k: v\n",
125
+ " for k, v in locals().items()\n",
126
+ " if isinstance(v, (int, float, str)) and not k.startswith(\"_\")\n",
127
+ "}\n",
128
+ "logger = CSVLogger(\"out\", name, flush_logs_every_n_steps=log_interval)\n",
129
+ "\n",
130
+ "\n",
131
+ "def setup(\n",
132
+ " devices: int = 4,\n",
133
+ " train_data_dir: Path = Path(\"data/redpajama_sample\"),\n",
134
+ " val_data_dir: Optional[Path] = None,\n",
135
+ " precision: Optional[str] = None,\n",
136
+ " resume: Union[bool, Path] = False,\n",
137
+ ") -> None:\n",
138
+ " precision = precision or get_default_supported_precision(training=True)\n",
139
+ "\n",
140
+ " if devices > 1:\n",
141
+ " strategy = FSDPStrategy(\n",
142
+ " auto_wrap_policy={Block},\n",
143
+ " activation_checkpointing_policy={Block},\n",
144
+ " state_dict_type=\"full\",\n",
145
+ " limit_all_gathers=True,\n",
146
+ " cpu_offload=False,\n",
147
+ " )\n",
148
+ " else:\n",
149
+ " strategy = \"auto\"\n",
150
+ "\n",
151
+ " fabric = L.Fabric(\n",
152
+ " devices=devices, strategy=strategy, precision=precision, loggers=logger\n",
153
+ " )\n",
154
+ " fabric.print(hparams)\n",
155
+ " fabric.launch(main, train_data_dir, val_data_dir, resume)"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": 7,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "model_copy = None"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 8,
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "def main(\n",
174
+ " fabric: L.Fabric,\n",
175
+ " train_data_dir: Path,\n",
176
+ " val_data_dir: Path,\n",
177
+ " resume: Union[bool, Path],\n",
178
+ ") -> None:\n",
179
+ " global model_copy\n",
180
+ " speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit=\"seconds\")\n",
181
+ "\n",
182
+ " if fabric.global_rank == 0:\n",
183
+ " out_dir.mkdir(parents=True, exist_ok=True)\n",
184
+ "\n",
185
+ " config = Config.from_name(model_name)\n",
186
+ "\n",
187
+ " train_dataloader, val_dataloader = create_dataloaders(\n",
188
+ " batch_size=micro_batch_size,\n",
189
+ " block_size=config.block_size,\n",
190
+ " fabric=fabric,\n",
191
+ " train_data_dir=train_data_dir,\n",
192
+ " val_data_dir=val_data_dir,\n",
193
+ " seed=(1337 + fabric.global_rank),\n",
194
+ " )\n",
195
+ " if val_dataloader is None:\n",
196
+ " train_dataloader = fabric.setup_dataloaders(train_dataloader)\n",
197
+ " else:\n",
198
+ " train_dataloader, val_dataloader = fabric.setup_dataloaders(\n",
199
+ " train_dataloader, val_dataloader\n",
200
+ " )\n",
201
+ "\n",
202
+ " fabric.seed_everything(1337) # same seed for every process to init model (FSDP)\n",
203
+ "\n",
204
+ " fabric.print(f\"Loading model with {config.__dict__}\")\n",
205
+ " t0 = time.perf_counter()\n",
206
+ " import torch\n",
207
+ " import torch.nn as nn\n",
208
+ "\n",
209
+ " def _init_weights(module: nn.Module) -> None:\n",
210
+ " \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`.\"\"\"\n",
211
+ " if isinstance(module, nn.Linear):\n",
212
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
213
+ " if module.bias is not None:\n",
214
+ " torch.nn.init.zeros_(module.bias)\n",
215
+ " elif isinstance(module, nn.Embedding):\n",
216
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
217
+ "\n",
218
+ " with fabric.init_module(empty_init=True):\n",
219
+ " model = GPT(config)\n",
220
+ " model.apply(_init_weights)\n",
221
+ " model.apply(_init_weights)\n",
222
+ "\n",
223
+ " # checkpoint_path = Path(\"out/redpajama/iter-000999-ckpt.pth\")\n",
224
+ "\n",
225
+ " # load_checkpoint(fabric, model, checkpoint_path)\n",
226
+ "\n",
227
+ " # print(model.transformer.h[0].mlp.fc.weight)\n",
228
+ "\n",
229
+ " fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\")\n",
230
+ " fabric.print(f\"Total parameters {num_parameters(model):,}\")\n",
231
+ "\n",
232
+ " model = fabric.setup(model)\n",
233
+ " optimizer = torch.optim.AdamW(\n",
234
+ " model.parameters(),\n",
235
+ " lr=learning_rate,\n",
236
+ " weight_decay=weight_decay,\n",
237
+ " betas=(beta1, beta2),\n",
238
+ " foreach=False,\n",
239
+ " )\n",
240
+ "\n",
241
+ " # model_copy = model\n",
242
+ "\n",
243
+ " optimizer = fabric.setup_optimizers(optimizer)\n",
244
+ "\n",
245
+ " state = {\n",
246
+ " \"model\": model,\n",
247
+ " \"optimizer\": optimizer,\n",
248
+ " \"hparams\": hparams,\n",
249
+ " \"iter_num\": 0,\n",
250
+ " \"step_count\": 0,\n",
251
+ " }\n",
252
+ "\n",
253
+ " if resume is True:\n",
254
+ " resume = max(out_dir.glob(\"*.pth\"), key=lambda p: int(p.name.split(\"-\")[1]))\n",
255
+ " if resume:\n",
256
+ " fabric.print(f\"Resuming training from {resume}\")\n",
257
+ " fabric.load(resume, state)\n",
258
+ "\n",
259
+ " train_time = time.perf_counter()\n",
260
+ " train(fabric, state, train_dataloader, val_dataloader, speed_monitor)\n",
261
+ " fabric.print(f\"Training time: {(time.perf_counter()-train_time):.2f}s\")\n",
262
+ " if fabric.device.type == \"cuda\":\n",
263
+ " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\")"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": 9,
269
+ "metadata": {},
270
+ "outputs": [],
271
+ "source": [
272
+ "def train(\n",
273
+ " fabric: L.Fabric,\n",
274
+ " state: dict,\n",
275
+ " train_dataloader: DataLoader,\n",
276
+ " val_dataloader: DataLoader,\n",
277
+ " speed_monitor: SpeedMonitorBase,\n",
278
+ ") -> None:\n",
279
+ " model = state[\"model\"]\n",
280
+ " optimizer = state[\"optimizer\"]\n",
281
+ "\n",
282
+ " if val_dataloader is not None:\n",
283
+ " validate(fabric, model, val_dataloader) # sanity check\n",
284
+ "\n",
285
+ " with torch.device(\"meta\"):\n",
286
+ " meta_model = GPT(model.config)\n",
287
+ " # \"estimated\" is not as precise as \"measured\". Estimated is optimistic but widely used in the wild.\n",
288
+ " # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,\n",
289
+ " # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead\n",
290
+ " estimated_flops = estimate_flops(meta_model) * micro_batch_size\n",
291
+ " fabric.print(\n",
292
+ " f\"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}\"\n",
293
+ " )\n",
294
+ " x = torch.randint(0, 1, (micro_batch_size, model.max_seq_length))\n",
295
+ " measured_flops = measure_flops(meta_model, x)\n",
296
+ " fabric.print(\n",
297
+ " f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\"\n",
298
+ " )\n",
299
+ " del meta_model, x\n",
300
+ "\n",
301
+ " total_lengths = 0\n",
302
+ " total_t0 = time.perf_counter()\n",
303
+ "\n",
304
+ " for state[\"iter_num\"], train_data in enumerate(train_dataloader, state[\"iter_num\"]):\n",
305
+ " if state[\"iter_num\"] >= max_iters:\n",
306
+ " checkpoint_path = out_dir / f\"iter-{state['iter_num']:06d}-ckpt.pth\"\n",
307
+ " fabric.print(f\"Saving checkpoint to {str(checkpoint_path)!r}\")\n",
308
+ " fabric.save(checkpoint_path, state)\n",
309
+ " break\n",
310
+ "\n",
311
+ " # determine and set the learning rate for this iteration\n",
312
+ " lr = get_lr(state[\"iter_num\"]) if decay_lr else learning_rate\n",
313
+ " for param_group in optimizer.param_groups:\n",
314
+ " param_group[\"lr\"] = lr\n",
315
+ "\n",
316
+ " iter_t0 = time.perf_counter()\n",
317
+ "\n",
318
+ " input_ids = train_data[:, 0 : model.max_seq_length].contiguous()\n",
319
+ " targets = train_data[:, 1 : model.max_seq_length + 1].contiguous()\n",
320
+ "\n",
321
+ " is_accumulating = (state[\"iter_num\"] + 1) % gradient_accumulation_steps != 0\n",
322
+ " with fabric.no_backward_sync(model, enabled=is_accumulating):\n",
323
+ " logits = model(input_ids)\n",
324
+ " loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n",
325
+ " fabric.backward(loss / gradient_accumulation_steps)\n",
326
+ "\n",
327
+ " # return\n",
328
+ "\n",
329
+ " if not is_accumulating:\n",
330
+ " fabric.clip_gradients(model, optimizer, max_norm=grad_clip)\n",
331
+ " optimizer.step()\n",
332
+ " optimizer.zero_grad()\n",
333
+ " state[\"step_count\"] += 1\n",
334
+ "\n",
335
+ " t1 = time.perf_counter()\n",
336
+ " total_lengths += input_ids.size(1)\n",
337
+ " speed_monitor.on_train_batch_end(\n",
338
+ " (state[\"iter_num\"] + 1) * micro_batch_size,\n",
339
+ " t1 - total_t0,\n",
340
+ " # this assumes that device FLOPs are the same and that all devices have the same batch size\n",
341
+ " fabric.world_size,\n",
342
+ " flops_per_batch=measured_flops,\n",
343
+ " lengths=total_lengths,\n",
344
+ " )\n",
345
+ " if state[\"iter_num\"] % log_interval == 0:\n",
346
+ " fabric.print(\n",
347
+ " f\"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, LR: {lr:.6f}, iter time:\"\n",
348
+ " f\" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}\"\n",
349
+ " )\n",
350
+ "\n",
351
+ " if (\n",
352
+ " val_dataloader is not None\n",
353
+ " and not is_accumulating\n",
354
+ " and state[\"step_count\"] % eval_interval == 0\n",
355
+ " ):\n",
356
+ " t0 = time.perf_counter()\n",
357
+ " val_loss = validate(fabric, model, val_dataloader)\n",
358
+ " t1 = time.perf_counter() - t0\n",
359
+ " speed_monitor.eval_end(t1)\n",
360
+ " fabric.print(\n",
361
+ " f\"step {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms\"\n",
362
+ " )\n",
363
+ " fabric.barrier()\n",
364
+ " if not is_accumulating and state[\"step_count\"] % save_interval == 0:\n",
365
+ " checkpoint_path = out_dir / f\"iter-{state['iter_num']:06d}-ckpt.pth\"\n",
366
+ " fabric.print(f\"Saving checkpoint to {str(checkpoint_path)!r}\")\n",
367
+ " fabric.save(checkpoint_path, state)"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 10,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "@torch.inference_mode()\n",
377
+ "def validate(\n",
378
+ " fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader\n",
379
+ ") -> torch.Tensor:\n",
380
+ " fabric.print(\"Validating ...\")\n",
381
+ " model.eval()\n",
382
+ "\n",
383
+ " losses = torch.zeros(eval_iters, device=fabric.device)\n",
384
+ " for k, val_data in enumerate(val_dataloader):\n",
385
+ " input_ids = val_data[:, 0 : model.max_seq_length].contiguous()\n",
386
+ " targets = val_data[:, 1 : model.max_seq_length + 1].contiguous()\n",
387
+ " logits = model(input_ids)\n",
388
+ " losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0)\n",
389
+ " out = losses.mean()\n",
390
+ "\n",
391
+ " model.train()\n",
392
+ " return out"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": 11,
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "def create_dataloader(\n",
402
+ " batch_size: int,\n",
403
+ " block_size: int,\n",
404
+ " data_dir: Path,\n",
405
+ " fabric: L.Fabric,\n",
406
+ " shuffle: bool = True,\n",
407
+ " seed: int = 12345,\n",
408
+ ") -> DataLoader:\n",
409
+ " datasets = []\n",
410
+ " for prefix, _ in data_config:\n",
411
+ " filenames = glob.glob(str(data_dir / f\"{prefix}*\"))\n",
412
+ " dataset = PackedDataset(\n",
413
+ " filenames,\n",
414
+ " n_chunks=4,\n",
415
+ " block_size=block_size,\n",
416
+ " shuffle=shuffle,\n",
417
+ " seed=seed,\n",
418
+ " num_processes=fabric.world_size,\n",
419
+ " process_rank=fabric.global_rank,\n",
420
+ " )\n",
421
+ " datasets.append(dataset)\n",
422
+ "\n",
423
+ " if not datasets:\n",
424
+ " raise RuntimeError(\n",
425
+ " f\"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset.\"\n",
426
+ " )\n",
427
+ "\n",
428
+ " weights = [weight for _, weight in data_config]\n",
429
+ " sum_weights = sum(weights)\n",
430
+ " weights = [el / sum_weights for el in weights]\n",
431
+ "\n",
432
+ " combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)\n",
433
+ "\n",
434
+ " return DataLoader(\n",
435
+ " combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True\n",
436
+ " )"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 12,
442
+ "metadata": {},
443
+ "outputs": [],
444
+ "source": [
445
+ "def create_dataloaders(\n",
446
+ " batch_size: int,\n",
447
+ " block_size: int,\n",
448
+ " fabric: L.Fabric,\n",
449
+ " train_data_dir: Path = Path(\"data/redpajama_sample\"),\n",
450
+ " val_data_dir: Optional[Path] = None,\n",
451
+ " seed: int = 12345,\n",
452
+ ") -> Tuple[DataLoader, DataLoader]:\n",
453
+ " # Increase by one because we need the next word as well\n",
454
+ " effective_block_size = block_size + 1\n",
455
+ " train_dataloader = create_dataloader(\n",
456
+ " batch_size=batch_size,\n",
457
+ " block_size=effective_block_size,\n",
458
+ " fabric=fabric,\n",
459
+ " data_dir=train_data_dir,\n",
460
+ " shuffle=True,\n",
461
+ " seed=seed,\n",
462
+ " )\n",
463
+ " val_dataloader = (\n",
464
+ " create_dataloader(\n",
465
+ " batch_size=batch_size,\n",
466
+ " block_size=effective_block_size,\n",
467
+ " fabric=fabric,\n",
468
+ " data_dir=val_data_dir,\n",
469
+ " shuffle=False,\n",
470
+ " seed=seed,\n",
471
+ " )\n",
472
+ " if val_data_dir\n",
473
+ " else None\n",
474
+ " )\n",
475
+ " return train_dataloader, val_dataloader"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": 13,
481
+ "metadata": {},
482
+ "outputs": [],
483
+ "source": [
484
+ "def get_lr(it: int) -> float:\n",
485
+ " # 1) linear warmup for warmup_iters steps\n",
486
+ " if it < warmup_iters:\n",
487
+ " return learning_rate * it / warmup_iters\n",
488
+ " # 2) if it > lr_decay_iters, return min learning rate\n",
489
+ " if it > lr_decay_iters:\n",
490
+ " return min_lr\n",
491
+ " # 3) in between, use cosine decay down to min learning rate\n",
492
+ " decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n",
493
+ " assert 0 <= decay_ratio <= 1\n",
494
+ " coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1\n",
495
+ " return min_lr + coeff * (learning_rate - min_lr)"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": 14,
501
+ "metadata": {},
502
+ "outputs": [
503
+ {
504
+ "name": "stderr",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "Using bfloat16 Automatic Mixed Precision (AMP)\n",
508
+ "Seed set to 1337\n"
509
+ ]
510
+ },
511
+ {
512
+ "name": "stdout",
513
+ "output_type": "stream",
514
+ "text": [
515
+ "{'model_name': 'pythia-160m', 'name': 'redpajama', 'save_interval': 1000, 'eval_interval': 1000, 'eval_iters': 100, 'log_interval': 100, 'learning_rate': 0.006, 'batch_size': 32, 'micro_batch_size': 8, 'gradient_accumulation_steps': 4, 'max_iters': 15000, 'weight_decay': 0.1, 'beta1': 0.9, 'beta2': 0.95, 'grad_clip': 1.0, 'decay_lr': True, 'warmup_iters': 2000, 'lr_decay_iters': 15000, 'min_lr': 6e-06}\n",
516
+ "Loading model with {'name': 'pythia-160m', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-160m-deduped'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 12, 'n_head': 12, 'n_embd': 768, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 12, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 3072, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 64, 'rope_n_elem': 16}\n",
517
+ "Time to instantiate model: 1.99 seconds.\n",
518
+ "Total parameters 162,322,944\n",
519
+ "Estimated TFLOPs: 22.14\n",
520
+ "Measured TFLOPs: 15.86\n",
521
+ "iter 0 step 0: loss 11.0478, LR: 0.000000, iter time: 1312.30ms\n",
522
+ "iter 100 step 25: loss 7.3711, LR: 0.000300, iter time: 282.00ms\n",
523
+ "iter 200 step 50: loss 5.9653, LR: 0.000600, iter time: 293.93ms\n",
524
+ "iter 300 step 75: loss 6.1456, LR: 0.000900, iter time: 290.72ms\n",
525
+ "iter 400 step 100: loss 6.4233, LR: 0.001200, iter time: 291.77ms\n",
526
+ "iter 500 step 125: loss 5.8922, LR: 0.001500, iter time: 292.98ms\n",
527
+ "iter 600 step 150: loss 5.7330, LR: 0.001800, iter time: 292.54ms\n",
528
+ "iter 700 step 175: loss 5.2412, LR: 0.002100, iter time: 293.18ms\n",
529
+ "iter 800 step 200: loss 4.7973, LR: 0.002400, iter time: 291.61ms\n",
530
+ "iter 900 step 225: loss 5.4157, LR: 0.002700, iter time: 292.85ms\n",
531
+ "iter 1000 step 250: loss 5.1732, LR: 0.003000, iter time: 292.74ms\n",
532
+ "iter 1100 step 275: loss 5.1144, LR: 0.003300, iter time: 291.97ms\n",
533
+ "iter 1200 step 300: loss 4.6204, LR: 0.003600, iter time: 291.41ms\n",
534
+ "iter 1300 step 325: loss 5.2649, LR: 0.003900, iter time: 292.33ms\n",
535
+ "iter 1400 step 350: loss 5.3906, LR: 0.004200, iter time: 291.61ms\n",
536
+ "iter 1500 step 375: loss 5.1544, LR: 0.004500, iter time: 292.87ms\n",
537
+ "iter 1600 step 400: loss 5.2281, LR: 0.004800, iter time: 291.19ms\n",
538
+ "iter 1700 step 425: loss 4.6215, LR: 0.005100, iter time: 290.65ms\n",
539
+ "iter 1800 step 450: loss 5.1470, LR: 0.005400, iter time: 291.07ms\n",
540
+ "iter 1900 step 475: loss 5.1262, LR: 0.005700, iter time: 291.85ms\n",
541
+ "iter 2000 step 500: loss 4.7982, LR: 0.006000, iter time: 291.74ms\n",
542
+ "iter 2100 step 525: loss 4.7870, LR: 0.005999, iter time: 291.40ms\n",
543
+ "iter 2200 step 550: loss 4.6758, LR: 0.005997, iter time: 291.24ms\n",
544
+ "iter 2300 step 575: loss 4.2770, LR: 0.005992, iter time: 290.94ms\n",
545
+ "iter 2400 step 600: loss 4.9993, LR: 0.005986, iter time: 290.82ms\n",
546
+ "iter 2500 step 625: loss 4.7006, LR: 0.005978, iter time: 291.72ms\n",
547
+ "iter 2600 step 650: loss 4.4606, LR: 0.005969, iter time: 291.41ms\n",
548
+ "iter 2700 step 675: loss 4.2507, LR: 0.005957, iter time: 291.65ms\n",
549
+ "iter 2800 step 700: loss 4.2737, LR: 0.005944, iter time: 298.98ms\n",
550
+ "iter 2900 step 725: loss 3.2729, LR: 0.005929, iter time: 291.06ms\n",
551
+ "iter 3000 step 750: loss 3.6851, LR: 0.005913, iter time: 290.95ms\n",
552
+ "iter 3100 step 775: loss 4.3133, LR: 0.005895, iter time: 291.41ms\n",
553
+ "iter 3200 step 800: loss 4.0082, LR: 0.005875, iter time: 290.55ms\n",
554
+ "iter 3300 step 825: loss 4.4818, LR: 0.005853, iter time: 291.40ms\n",
555
+ "iter 3400 step 850: loss 4.0966, LR: 0.005830, iter time: 291.75ms\n",
556
+ "iter 3500 step 875: loss 3.3417, LR: 0.005805, iter time: 291.56ms\n",
557
+ "iter 3600 step 900: loss 3.3930, LR: 0.005779, iter time: 291.98ms\n",
558
+ "iter 3700 step 925: loss 3.9926, LR: 0.005751, iter time: 291.38ms\n",
559
+ "iter 3800 step 950: loss 4.4130, LR: 0.005721, iter time: 290.98ms\n",
560
+ "iter 3900 step 975: loss 4.2273, LR: 0.005690, iter time: 290.82ms\n",
561
+ "Saving checkpoint to 'out/redpajama/iter-003999-ckpt.pth'\n",
562
+ "iter 4000 step 1000: loss 4.1836, LR: 0.005657, iter time: 289.39ms\n",
563
+ "iter 4100 step 1025: loss 3.8898, LR: 0.005622, iter time: 290.57ms\n",
564
+ "iter 4200 step 1050: loss 3.2994, LR: 0.005586, iter time: 290.66ms\n",
565
+ "iter 4300 step 1075: loss 3.5536, LR: 0.005549, iter time: 291.97ms\n",
566
+ "iter 4400 step 1100: loss 4.0568, LR: 0.005510, iter time: 290.74ms\n",
567
+ "iter 4500 step 1125: loss 4.0688, LR: 0.005469, iter time: 291.51ms\n",
568
+ "iter 4600 step 1150: loss 3.9602, LR: 0.005428, iter time: 291.69ms\n",
569
+ "iter 4700 step 1175: loss 3.9015, LR: 0.005384, iter time: 291.05ms\n",
570
+ "iter 4800 step 1200: loss 3.9838, LR: 0.005340, iter time: 290.89ms\n",
571
+ "iter 4900 step 1225: loss 4.1498, LR: 0.005294, iter time: 291.43ms\n",
572
+ "iter 5000 step 1250: loss 3.9890, LR: 0.005246, iter time: 292.04ms\n",
573
+ "iter 5100 step 1275: loss 3.7998, LR: 0.005198, iter time: 291.67ms\n",
574
+ "iter 5200 step 1300: loss 4.3898, LR: 0.005148, iter time: 292.07ms\n",
575
+ "iter 5300 step 1325: loss 3.8301, LR: 0.005096, iter time: 291.71ms\n",
576
+ "iter 5400 step 1350: loss 3.9250, LR: 0.005044, iter time: 291.87ms\n",
577
+ "iter 5500 step 1375: loss 3.4592, LR: 0.004990, iter time: 292.45ms\n",
578
+ "iter 5600 step 1400: loss 3.9057, LR: 0.004936, iter time: 292.48ms\n",
579
+ "iter 5700 step 1425: loss 3.4640, LR: 0.004880, iter time: 292.17ms\n",
580
+ "iter 5800 step 1450: loss 3.5189, LR: 0.004823, iter time: 291.53ms\n",
581
+ "iter 5900 step 1475: loss 3.8723, LR: 0.004765, iter time: 291.76ms\n",
582
+ "iter 6000 step 1500: loss 3.5505, LR: 0.004705, iter time: 291.40ms\n",
583
+ "iter 6100 step 1525: loss 2.7599, LR: 0.004645, iter time: 290.44ms\n",
584
+ "iter 6200 step 1550: loss 4.0639, LR: 0.004584, iter time: 290.73ms\n",
585
+ "iter 6300 step 1575: loss 3.9124, LR: 0.004522, iter time: 290.77ms\n",
586
+ "iter 6400 step 1600: loss 3.7831, LR: 0.004459, iter time: 290.48ms\n",
587
+ "iter 6500 step 1625: loss 3.6439, LR: 0.004396, iter time: 291.02ms\n",
588
+ "iter 6600 step 1650: loss 3.6231, LR: 0.004331, iter time: 293.27ms\n",
589
+ "iter 6700 step 1675: loss 3.4389, LR: 0.004266, iter time: 291.11ms\n",
590
+ "iter 6800 step 1700: loss 3.5385, LR: 0.004200, iter time: 290.80ms\n",
591
+ "iter 6900 step 1725: loss 3.4988, LR: 0.004133, iter time: 291.01ms\n",
592
+ "iter 7000 step 1750: loss 3.8966, LR: 0.004066, iter time: 290.56ms\n",
593
+ "iter 7100 step 1775: loss 3.6816, LR: 0.003998, iter time: 290.93ms\n",
594
+ "iter 7200 step 1800: loss 3.4510, LR: 0.003929, iter time: 291.20ms\n",
595
+ "iter 7300 step 1825: loss 3.9102, LR: 0.003860, iter time: 292.28ms\n",
596
+ "iter 7400 step 1850: loss 3.6360, LR: 0.003790, iter time: 291.56ms\n",
597
+ "iter 7500 step 1875: loss 3.8664, LR: 0.003720, iter time: 290.58ms\n",
598
+ "iter 7600 step 1900: loss 3.6073, LR: 0.003650, iter time: 291.40ms\n",
599
+ "iter 7700 step 1925: loss 2.9199, LR: 0.003579, iter time: 290.78ms\n",
600
+ "iter 7800 step 1950: loss 2.7844, LR: 0.003508, iter time: 290.67ms\n",
601
+ "iter 7900 step 1975: loss 3.1176, LR: 0.003436, iter time: 291.73ms\n",
602
+ "Saving checkpoint to 'out/redpajama/iter-007999-ckpt.pth'\n",
603
+ "iter 8000 step 2000: loss 3.7936, LR: 0.003364, iter time: 290.49ms\n",
604
+ "iter 8100 step 2025: loss 3.6197, LR: 0.003292, iter time: 290.46ms\n",
605
+ "iter 8200 step 2050: loss 3.7480, LR: 0.003220, iter time: 291.78ms\n",
606
+ "iter 8300 step 2075: loss 3.6900, LR: 0.003148, iter time: 291.11ms\n",
607
+ "iter 8400 step 2100: loss 2.8864, LR: 0.003075, iter time: 291.39ms\n",
608
+ "iter 8500 step 2125: loss 3.6963, LR: 0.003003, iter time: 291.51ms\n",
609
+ "iter 8600 step 2150: loss 3.7093, LR: 0.002931, iter time: 291.80ms\n",
610
+ "iter 8700 step 2175: loss 3.3042, LR: 0.002858, iter time: 290.53ms\n",
611
+ "iter 8800 step 2200: loss 3.0944, LR: 0.002786, iter time: 290.83ms\n",
612
+ "iter 8900 step 2225: loss 3.4312, LR: 0.002714, iter time: 290.81ms\n",
613
+ "iter 9000 step 2250: loss 3.5048, LR: 0.002642, iter time: 290.99ms\n",
614
+ "iter 9100 step 2275: loss 3.2803, LR: 0.002570, iter time: 291.00ms\n",
615
+ "iter 9200 step 2300: loss 3.5930, LR: 0.002498, iter time: 292.10ms\n",
616
+ "iter 9300 step 2325: loss 2.2495, LR: 0.002427, iter time: 290.29ms\n",
617
+ "iter 9400 step 2350: loss 2.9088, LR: 0.002356, iter time: 290.19ms\n",
618
+ "iter 9500 step 2375: loss 2.6597, LR: 0.002286, iter time: 291.29ms\n",
619
+ "iter 9600 step 2400: loss 3.6206, LR: 0.002216, iter time: 291.64ms\n",
620
+ "iter 9700 step 2425: loss 2.3134, LR: 0.002146, iter time: 289.83ms\n",
621
+ "iter 9800 step 2450: loss 2.4301, LR: 0.002077, iter time: 289.59ms\n",
622
+ "iter 9900 step 2475: loss 2.4800, LR: 0.002008, iter time: 290.77ms\n",
623
+ "iter 10000 step 2500: loss 2.2368, LR: 0.001940, iter time: 290.11ms\n",
624
+ "iter 10100 step 2525: loss 3.1508, LR: 0.001873, iter time: 291.03ms\n",
625
+ "iter 10200 step 2550: loss 3.2954, LR: 0.001806, iter time: 291.14ms\n",
626
+ "iter 10300 step 2575: loss 3.0130, LR: 0.001740, iter time: 291.20ms\n",
627
+ "iter 10400 step 2600: loss 3.0044, LR: 0.001675, iter time: 290.75ms\n",
628
+ "iter 10500 step 2625: loss 2.8596, LR: 0.001610, iter time: 290.14ms\n",
629
+ "iter 10600 step 2650: loss 2.0126, LR: 0.001547, iter time: 290.53ms\n",
630
+ "iter 10700 step 2675: loss 3.0040, LR: 0.001484, iter time: 292.51ms\n",
631
+ "iter 10800 step 2700: loss 3.4691, LR: 0.001422, iter time: 290.79ms\n",
632
+ "iter 10900 step 2725: loss 3.3719, LR: 0.001361, iter time: 291.21ms\n",
633
+ "iter 11000 step 2750: loss 2.9904, LR: 0.001301, iter time: 292.52ms\n",
634
+ "iter 11100 step 2775: loss 2.7121, LR: 0.001241, iter time: 291.23ms\n",
635
+ "iter 11200 step 2800: loss 3.2472, LR: 0.001183, iter time: 291.06ms\n",
636
+ "iter 11300 step 2825: loss 3.3517, LR: 0.001126, iter time: 291.27ms\n",
637
+ "iter 11400 step 2850: loss 3.2715, LR: 0.001070, iter time: 292.14ms\n",
638
+ "iter 11500 step 2875: loss 3.4200, LR: 0.001016, iter time: 290.81ms\n",
639
+ "iter 11600 step 2900: loss 3.4924, LR: 0.000962, iter time: 291.75ms\n",
640
+ "iter 11700 step 2925: loss 2.2736, LR: 0.000910, iter time: 290.48ms\n",
641
+ "iter 11800 step 2950: loss 3.1776, LR: 0.000858, iter time: 291.91ms\n",
642
+ "iter 11900 step 2975: loss 3.1710, LR: 0.000808, iter time: 291.62ms\n",
643
+ "Saving checkpoint to 'out/redpajama/iter-011999-ckpt.pth'\n",
644
+ "iter 12000 step 3000: loss 3.6688, LR: 0.000760, iter time: 290.94ms\n",
645
+ "iter 12100 step 3025: loss 3.0179, LR: 0.000712, iter time: 290.84ms\n",
646
+ "iter 12200 step 3050: loss 3.2257, LR: 0.000666, iter time: 291.06ms\n",
647
+ "iter 12300 step 3075: loss 3.1653, LR: 0.000622, iter time: 292.47ms\n",
648
+ "iter 12400 step 3100: loss 3.4042, LR: 0.000578, iter time: 291.42ms\n",
649
+ "iter 12500 step 3125: loss 3.1884, LR: 0.000537, iter time: 290.93ms\n",
650
+ "iter 12600 step 3150: loss 3.4705, LR: 0.000496, iter time: 291.49ms\n",
651
+ "iter 12700 step 3175: loss 3.5805, LR: 0.000457, iter time: 291.72ms\n",
652
+ "iter 12800 step 3200: loss 2.8953, LR: 0.000420, iter time: 292.49ms\n",
653
+ "iter 12900 step 3225: loss 3.3408, LR: 0.000384, iter time: 297.87ms\n",
654
+ "iter 13000 step 3250: loss 3.0779, LR: 0.000349, iter time: 298.95ms\n",
655
+ "iter 13100 step 3275: loss 2.5973, LR: 0.000316, iter time: 291.06ms\n",
656
+ "iter 13200 step 3300: loss 3.5901, LR: 0.000285, iter time: 291.16ms\n",
657
+ "iter 13300 step 3325: loss 2.4544, LR: 0.000255, iter time: 290.62ms\n",
658
+ "iter 13400 step 3350: loss 2.9969, LR: 0.000227, iter time: 290.56ms\n",
659
+ "iter 13500 step 3375: loss 3.1975, LR: 0.000201, iter time: 291.62ms\n",
660
+ "iter 13600 step 3400: loss 2.8946, LR: 0.000176, iter time: 290.60ms\n",
661
+ "iter 13700 step 3425: loss 3.4701, LR: 0.000153, iter time: 291.61ms\n",
662
+ "iter 13800 step 3450: loss 2.6274, LR: 0.000131, iter time: 289.90ms\n",
663
+ "iter 13900 step 3475: loss 3.3881, LR: 0.000111, iter time: 291.66ms\n",
664
+ "iter 14000 step 3500: loss 3.0832, LR: 0.000093, iter time: 291.88ms\n",
665
+ "iter 14100 step 3525: loss 3.2224, LR: 0.000077, iter time: 291.17ms\n",
666
+ "iter 14200 step 3550: loss 3.5854, LR: 0.000062, iter time: 290.77ms\n",
667
+ "iter 14300 step 3575: loss 3.3620, LR: 0.000049, iter time: 292.27ms\n",
668
+ "iter 14400 step 3600: loss 3.5590, LR: 0.000037, iter time: 291.91ms\n",
669
+ "iter 14500 step 3625: loss 3.2781, LR: 0.000028, iter time: 290.50ms\n",
670
+ "iter 14600 step 3650: loss 3.4279, LR: 0.000020, iter time: 291.54ms\n",
671
+ "iter 14700 step 3675: loss 2.8695, LR: 0.000014, iter time: 291.52ms\n",
672
+ "iter 14800 step 3700: loss 2.8212, LR: 0.000009, iter time: 291.34ms\n",
673
+ "iter 14900 step 3725: loss 3.3649, LR: 0.000007, iter time: 292.48ms\n",
674
+ "Saving checkpoint to 'out/redpajama/iter-015000-ckpt.pth'\n",
675
+ "Training time: 4615.15s\n",
676
+ "Memory used: 21.58 GB\n"
677
+ ]
678
+ }
679
+ ],
680
+ "source": [
681
+ "torch.set_float32_matmul_precision(\"medium\")\n",
682
+ "setup(devices=1, train_data_dir=Path(\"data/lit-redpajama-sample\"))"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": null,
688
+ "metadata": {},
689
+ "outputs": [],
690
+ "source": []
691
+ }
692
+ ],
693
+ "metadata": {
694
+ "kernelspec": {
695
+ "display_name": "base",
696
+ "language": "python",
697
+ "name": "python3"
698
+ },
699
+ "language_info": {
700
+ "codemirror_mode": {
701
+ "name": "ipython",
702
+ "version": 3
703
+ },
704
+ "file_extension": ".py",
705
+ "mimetype": "text/x-python",
706
+ "name": "python",
707
+ "nbconvert_exporter": "python",
708
+ "pygments_lexer": "ipython3",
709
+ "version": "3.10.12"
710
+ }
711
+ },
712
+ "nbformat": 4,
713
+ "nbformat_minor": 2
714
+ }
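
The cell above pretrains the pythia-160m configuration on the lit-redpajama-sample data and writes its checkpoints under out/redpajama/. As a minimal sketch (not part of the committed files) of how those weights could be loaded back for inference; it assumes the checkpoint is either a plain state dict (as lit_model.pth usually is) or a Fabric-style dict that nests the weights under a "model" key:

    import torch
    from tsai_gpt import GPT, Config

    # Rebuild the model from the saved hyperparameters, then load the weights.
    config = Config.from_json("out/redpajama/lit_config.json")
    model = GPT(config)
    state = torch.load("out/redpajama/lit_model.pth", map_location="cpu")
    model.load_state_dict(state.get("model", state))  # fall back to the raw dict if not nested
    model.eval()
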
out/redpajama/iter-003999-ckpt.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:396f17fb6dcf0dff11914ce7b427547fa35b9fe9691a70084ceefc3f6b1d2a69
3
+ size 42205184
out/redpajama/iter-007999-ckpt.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c169e321ef26a1bcf3fe750aab25264f781c69e4763858824cb08979ebe7b13a
3
+ size 41943040
out/redpajama/iter-011999-ckpt.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad33194d951debfaf63810e94385dc23b0379e058ee7d22f9d059038d8f137e7
3
+ size 41943040
out/redpajama/lit_config.json ADDED
@@ -0,0 +1 @@
1
+ {"name": "pythia-160m", "hf_config": {"org": "EleutherAI", "name": "pythia-160m"}, "block_size": 2048, "vocab_size": 50254, "padding_multiple": 128, "padded_vocab_size": 50304, "n_layer": 12, "n_head": 12, "n_embd": 768, "rotary_percentage": 0.25, "parallel_residual": true, "bias": true, "lm_head_bias": false, "n_query_groups": 12, "shared_attention_norm": false, "_norm_class": "LayerNorm", "norm_eps": 1e-05, "_mlp_class": "GptNeoxMLP", "gelu_approximate": "none", "intermediate_size": 3072, "rope_condense_ratio": 1, "rope_base": 10000}
out/redpajama/lit_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aae789bf9e490f230f8347baf067918c95be2d71b47112e9e63476a1894a19ad
3
+ size 44826624
out/redpajama/lit_model2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:568b2c0443dc4464590b9bab5953f53eadc9c4ae3bcd00679e59d924fa3f7778
3
+ size 44826624
out/redpajama/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
out/redpajama/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
out/redpajama/tokenizer_config.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
13
+ "clean_up_tokenization_spaces": false,
14
+ "eos_token": {
15
+ "__type": "AddedToken",
16
+ "content": "</s>",
17
+ "lstrip": false,
18
+ "normalized": false,
19
+ "rstrip": false,
20
+ "single_word": false
21
+ },
22
+ "legacy": false,
23
+ "model_max_length": 1000000000000000019884624838656,
24
+ "pad_token": null,
25
+ "padding_side": "right",
26
+ "sp_model_kwargs": {},
27
+ "tokenizer_class": "LlamaTokenizer",
28
+ "unk_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ }
36
+ }
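
The tokenizer files above describe a LlamaTokenizer with a Llama-2 style chat template. A hedged sketch of rendering that template, assuming the Hugging Face transformers library is installed (it is not listed in requirements.txt) and that the files under out/redpajama/ load as a standalone tokenizer:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("out/redpajama")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    # The chat_template above wraps the system prompt in <<SYS>> markers and each
    # user turn in <s>[INST] ... [/INST].
    print(tok.apply_chat_template(messages, tokenize=False))
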
out/redpajama/version_1/metrics.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch>=2.1.0
2
+ lightning @ git+https://github.com/Lightning-AI/lightning@6cbe9ceb560d798892bdae9186291acf9bf5d2e3
3
+ jsonargparse[signatures] # CLI
4
+ gradio
5
+ sentencepiece
tokenizer_config.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
13
+ "clean_up_tokenization_spaces": false,
14
+ "eos_token": {
15
+ "__type": "AddedToken",
16
+ "content": "</s>",
17
+ "lstrip": false,
18
+ "normalized": false,
19
+ "rstrip": false,
20
+ "single_word": false
21
+ },
22
+ "legacy": false,
23
+ "model_max_length": 1000000000000000019884624838656,
24
+ "pad_token": null,
25
+ "padding_side": "right",
26
+ "sp_model_kwargs": {},
27
+ "tokenizer_class": "LlamaTokenizer",
28
+ "unk_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ }
36
+ }
tsai_gpt/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from tsai_gpt.model import GPT
2
+ from tsai_gpt.config import Config
3
+ from tsai_gpt.tokenizer import Tokenizer
4
+
5
+ from lightning_utilities.core.imports import RequirementCache
6
+
7
+ _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
8
+ if not bool(_LIGHTNING_AVAILABLE):
9
+ raise ImportError(
10
+ "Lit-GPT requires lightning==2.1. Please run:\n"
11
+ f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
12
+ )
13
+
14
+
15
+ __all__ = ["GPT", "Config", "Tokenizer"]
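
The package __init__ above guards on lightning>=2.1 at import time and re-exports GPT, Config and Tokenizer. A quick usage sketch, assuming the pinned requirements are installed:

    from tsai_gpt import GPT, Config

    cfg = Config.from_name("pythia-160m")  # from_name is defined in tsai_gpt/config.py below
    model = GPT(cfg)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params / 1e6:.1f}M parameters")
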
tsai_gpt/config.py ADDED
@@ -0,0 +1,1181 @@
1
+ import json
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Any, Literal, Optional, Type, Union
6
+
7
+ import torch
8
+ from typing_extensions import Self
9
+
10
+ import tsai_gpt.model
11
+ from tsai_gpt.utils import find_multiple
12
+
13
+
14
+ @dataclass
15
+ class Config:
16
+ name: str = ""
17
+ hf_config: dict = field(default_factory=dict)
18
+ block_size: int = 4096
19
+ vocab_size: int = 50254
20
+ padding_multiple: int = 512
21
+ padded_vocab_size: Optional[int] = None
22
+ n_layer: int = 16
23
+ n_head: int = 32
24
+ n_embd: int = 4096
25
+ rotary_percentage: float = 0.25
26
+ parallel_residual: bool = True
27
+ bias: bool = True
28
+ lm_head_bias: bool = False
29
+ # to use multi-head attention (MHA), set this to `n_head` (default)
30
+ # to use multi-query attention (MQA), set this to 1
31
+ # to use grouped-query attention (GQA), set this to a value in between
32
+ # Example with `n_head=4`
33
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
34
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
35
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
36
+ # │ │ │ │ │ │ │
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
42
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
43
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
44
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
45
+ # MHA GQA MQA
46
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
47
+ #
48
+ # credit https://arxiv.org/pdf/2305.13245.pdf
49
+ n_query_groups: Optional[int] = None
50
+ shared_attention_norm: bool = False
51
+ _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
52
+ norm_eps: float = 1e-5
53
+ _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
54
+ gelu_approximate: str = "none"
55
+ intermediate_size: Optional[int] = None
56
+ rope_condense_ratio: int = 1
57
+ rope_base: int = 10000
58
+
59
+ def __post_init__(self):
60
+ if not self.name:
61
+ self.name = self.hf_config.get("name", self.name)
62
+
63
+ assert self.n_embd % self.n_head == 0
64
+ self.head_size = self.n_embd // self.n_head
65
+
66
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
67
+ if self.padded_vocab_size is None:
68
+ self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
69
+ else:
70
+ # vocab size shouldn't be larger than padded vocab size
71
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
72
+
73
+ # compute the number of query groups
74
+ if self.n_query_groups is not None:
75
+ assert self.n_head % self.n_query_groups == 0
76
+ else:
77
+ self.n_query_groups = self.n_head
78
+
79
+ # compute the intermediate size for MLP if not set
80
+ if self.intermediate_size is None:
81
+ if self._mlp_class == "LLaMAMLP":
82
+ raise ValueError("The config needs to set the `intermediate_size`")
83
+ self.intermediate_size = 4 * self.n_embd
84
+
85
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
86
+
87
+ @classmethod
88
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
89
+ if name not in name_to_config:
90
+ # search through all `config['hf_config']['name']`
91
+ conf_dict = next(config for config in configs if name == config["hf_config"]["name"])
92
+ else:
93
+ conf_dict = name_to_config[name]
94
+
95
+ conf_dict = conf_dict.copy()
96
+ if "condense_ratio" in kwargs: # legacy name
97
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
98
+ conf_dict.update(kwargs)
99
+ return cls(**conf_dict)
100
+
101
+ @classmethod
102
+ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
103
+ with open(path, encoding="utf-8") as fp:
104
+ json_kwargs = json.load(fp)
105
+ if "condense_ratio" in json_kwargs: # legacy name
106
+ json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
107
+ if "condense_ratio" in kwargs: # legacy name
108
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
109
+ if "org" in json_kwargs: # legacy name
110
+ json_kwargs["hf_config"] = {"name": json_kwargs["name"], "org": json_kwargs.pop("org")}
111
+ if "org" in kwargs: # legacy name
112
+ kwargs["hf_config"] = {"name": kwargs.get("name", json_kwargs["name"]), "org": kwargs.pop("org")}
113
+ json_kwargs.update(kwargs)
114
+ return cls(**json_kwargs)
115
+
116
+ @property
117
+ def mlp_class(self) -> Type:
118
+ # `self._mlp_class` cannot be the type to keep the config json serializable
119
+ return getattr(tsai_gpt.model, self._mlp_class)
120
+
121
+ @property
122
+ def norm_class(self) -> Type:
123
+ # `self._norm_class` cannot be the type to keep the config json serializable
124
+ if self._norm_class == "RMSNorm":
125
+ from tsai_gpt.rmsnorm import RMSNorm
126
+
127
+ return RMSNorm
128
+ return getattr(torch.nn, self._norm_class)
129
+
130
+
131
+ ########################
132
+ # Stability AI StableLM
133
+ ########################
134
+ configs = [
135
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
136
+ dict(name="stablelm-base-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b")),
137
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
138
+ dict(
139
+ name="stablelm-base-alpha-7b",
140
+ hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"),
141
+ n_head=48,
142
+ n_embd=6144,
143
+ padding_multiple=256,
144
+ ),
145
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
146
+ dict(name="stablelm-tuned-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"), n_head=32),
147
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
148
+ dict(
149
+ name="stablelm-tuned-alpha-7b",
150
+ hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
151
+ n_head=48,
152
+ n_embd=6144,
153
+ padding_multiple=256,
154
+ ),
155
+ ]
156
+
157
+ ####################
158
+ # EleutherAI Pythia
159
+ ####################
160
+ pythia = [
161
+ # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
162
+ dict(
163
+ name="pythia-70m",
164
+ hf_config=dict(org="EleutherAI", name="pythia-70m"),
165
+ block_size=2048,
166
+ n_layer=6,
167
+ n_embd=512,
168
+ n_head=8,
169
+ padding_multiple=128,
170
+ ),
171
+ # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
172
+ dict(
173
+ name="pythia-160m",
174
+ hf_config=dict(org="EleutherAI", name="pythia-160m"),
175
+ block_size=2048,
176
+ n_layer=12,
177
+ n_embd=768,
178
+ n_head=12,
179
+ padding_multiple=128,
180
+ ),
181
+ # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
182
+ dict(
183
+ name="pythia-410m",
184
+ hf_config=dict(org="EleutherAI", name="pythia-410m"),
185
+ block_size=2048,
186
+ n_layer=24,
187
+ n_embd=1024,
188
+ n_head=16,
189
+ padding_multiple=128,
190
+ ),
191
+ # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
192
+ dict(
193
+ name="pythia-1b",
194
+ hf_config=dict(org="EleutherAI", name="pythia-1b"),
195
+ block_size=2048,
196
+ n_embd=2048,
197
+ n_head=8,
198
+ padding_multiple=128,
199
+ ),
200
+ # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
201
+ dict(
202
+ name="pythia-1.4b",
203
+ hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
204
+ block_size=2048,
205
+ n_layer=24,
206
+ n_embd=2048,
207
+ n_head=16,
208
+ padding_multiple=128,
209
+ ),
210
+ # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
211
+ dict(
212
+ name="pythia-2.8b",
213
+ hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
214
+ block_size=2048,
215
+ n_layer=32,
216
+ n_embd=2560,
217
+ padding_multiple=128,
218
+ ),
219
+ # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
220
+ dict(
221
+ name="pythia-6.9b",
222
+ hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
223
+ block_size=2048,
224
+ n_layer=32,
225
+ padding_multiple=256,
226
+ ),
227
+ # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
228
+ dict(
229
+ name="pythia-12b",
230
+ hf_config=dict(org="EleutherAI", name="pythia-12b"),
231
+ block_size=2048,
232
+ n_layer=36,
233
+ n_embd=5120,
234
+ n_head=40,
235
+ ),
236
+ ]
237
+ configs.extend(pythia)
238
+ for c in pythia:
239
+ copy = c.copy()
240
+ copy["name"] = f"{c['name']}-deduped"
241
+ copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
242
+ configs.append(copy)
243
+
244
+
245
+ ####################################
246
+ # togethercomputer RedPajama INCITE
247
+ ####################################
248
+ redpajama_incite = [
249
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
250
+ dict(
251
+ name="RedPajama-INCITE-{}-3B-v1",
252
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-3B-v1"),
253
+ block_size=2048,
254
+ n_layer=32,
255
+ n_embd=2560,
256
+ padding_multiple=256,
257
+ rotary_percentage=1.0,
258
+ parallel_residual=False,
259
+ ),
260
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
261
+ dict(
262
+ name="RedPajama-INCITE-7B-{}",
263
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-7B-{}"),
264
+ block_size=2048,
265
+ n_layer=32,
266
+ padding_multiple=256,
267
+ rotary_percentage=1.0,
268
+ parallel_residual=False,
269
+ ),
270
+ # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
271
+ dict(
272
+ name="RedPajama-INCITE-{}-7B-v0.1",
273
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-7B-v0.1"),
274
+ block_size=2048,
275
+ n_layer=32,
276
+ padding_multiple=256,
277
+ rotary_percentage=1.0,
278
+ parallel_residual=False,
279
+ ),
280
+ ]
281
+ for c in redpajama_incite:
282
+ for kind in ("Base", "Chat", "Instruct"):
283
+ copy = c.copy()
284
+ copy["name"] = c["name"].format(kind)
285
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
286
+ configs.append(copy)
287
+
288
+
289
+ #################
290
+ # TII UAE Falcon
291
+ #################
292
+ falcon = [
293
+ # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
294
+ dict(
295
+ name="falcon-7b{}",
296
+ hf_config=dict(org="tiiuae", name="falcon-7b{}"),
297
+ block_size=2048,
298
+ vocab_size=65024,
299
+ padded_vocab_size=65024,
300
+ n_layer=32,
301
+ n_head=71,
302
+ n_embd=4544,
303
+ rotary_percentage=1.0,
304
+ n_query_groups=1,
305
+ bias=False,
306
+ # this is not in the config, but in the original model implementation, only for this config
307
+ shared_attention_norm=True,
308
+ ),
309
+ # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
310
+ dict(
311
+ name="falcon-40b{}",
312
+ hf_config=dict(org="tiiuae", name="falcon-40b{}"),
313
+ block_size=2048,
314
+ vocab_size=65024,
315
+ padded_vocab_size=65024,
316
+ n_layer=60,
317
+ n_head=128,
318
+ n_embd=8192,
319
+ rotary_percentage=1.0,
320
+ n_query_groups=8,
321
+ bias=False,
322
+ ),
323
+ ]
324
+ for c in falcon:
325
+ for kind in ("", "-instruct"):
326
+ copy = c.copy()
327
+ copy["name"] = c["name"].format(kind)
328
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
329
+ configs.append(copy)
330
+
331
+ # https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
332
+ falcon180b = dict(
333
+ name="falcon-180B{}",
334
+ hf_config=dict(org="tiiuae", name="falcon-180B{}"),
335
+ block_size=2048,
336
+ vocab_size=65024,
337
+ padded_vocab_size=65024,
338
+ n_layer=80,
339
+ n_head=232,
340
+ n_embd=14848,
341
+ rotary_percentage=1.0,
342
+ n_query_groups=8,
343
+ bias=False,
344
+ )
345
+
346
+ for kind in ("", "-chat"):
347
+ copy = falcon180b.copy()
348
+ copy["name"] = falcon180b["name"].format(kind)
349
+ copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
350
+ configs.append(copy)
351
+
352
+
353
+ #############################
354
+ # OpenLM Research Open LLaMA
355
+ #############################
356
+ open_LLaMA = [
357
+ # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
358
+ dict(
359
+ name="open_llama_3b",
360
+ hf_config=dict(org="openlm-research", name="open_llama_3b"),
361
+ block_size=2048,
362
+ vocab_size=32000,
363
+ padding_multiple=64,
364
+ n_layer=26,
365
+ n_embd=3200,
366
+ rotary_percentage=1.0,
367
+ parallel_residual=False,
368
+ bias=False,
369
+ _norm_class="RMSNorm",
370
+ norm_eps=1e-6,
371
+ _mlp_class="LLaMAMLP",
372
+ intermediate_size=8640,
373
+ ),
374
+ # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
375
+ dict(
376
+ name="open_llama_7b",
377
+ hf_config=dict(org="openlm-research", name="open_llama_7b"),
378
+ block_size=2048,
379
+ vocab_size=32000,
380
+ padding_multiple=64,
381
+ n_layer=32,
382
+ rotary_percentage=1.0,
383
+ parallel_residual=False,
384
+ bias=False,
385
+ _norm_class="RMSNorm",
386
+ norm_eps=1e-6,
387
+ _mlp_class="LLaMAMLP",
388
+ intermediate_size=11008,
389
+ ),
390
+ # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
391
+ dict(
392
+ name="open_llama_13b",
393
+ hf_config=dict(org="openlm-research", name="open_llama_13b"),
394
+ block_size=2048,
395
+ vocab_size=32000,
396
+ padding_multiple=64,
397
+ n_layer=40,
398
+ n_head=40,
399
+ n_embd=5120,
400
+ rotary_percentage=1.0,
401
+ parallel_residual=False,
402
+ bias=False,
403
+ _norm_class="RMSNorm",
404
+ norm_eps=1e-6,
405
+ _mlp_class="LLaMAMLP",
406
+ intermediate_size=13824,
407
+ ),
408
+ ]
409
+ configs.extend(open_LLaMA)
410
+
411
+
412
+ ###############
413
+ # LMSYS Vicuna
414
+ ###############
415
+ vicuna = [
416
+ # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
417
+ dict(
418
+ name="vicuna-7b-v1.3",
419
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.3"),
420
+ block_size=2048,
421
+ vocab_size=32000,
422
+ padding_multiple=64,
423
+ n_layer=32,
424
+ rotary_percentage=1.0,
425
+ parallel_residual=False,
426
+ bias=False,
427
+ _norm_class="RMSNorm",
428
+ norm_eps=1e-6,
429
+ _mlp_class="LLaMAMLP",
430
+ intermediate_size=11008,
431
+ ),
432
+ # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
433
+ dict(
434
+ name="vicuna-13b-v1.3",
435
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.3"),
436
+ block_size=2048,
437
+ vocab_size=32000,
438
+ padding_multiple=64,
439
+ n_layer=40,
440
+ n_head=40,
441
+ n_embd=5120,
442
+ rotary_percentage=1.0,
443
+ parallel_residual=False,
444
+ bias=False,
445
+ _norm_class="RMSNorm",
446
+ norm_eps=1e-6,
447
+ _mlp_class="LLaMAMLP",
448
+ intermediate_size=13824,
449
+ ),
450
+ # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
451
+ dict(
452
+ name="vicuna-33b-v1.3",
453
+ hf_config=dict(org="lmsys", name="vicuna-33b-v1.3"),
454
+ block_size=2048,
455
+ vocab_size=32000,
456
+ padding_multiple=64,
457
+ n_layer=60,
458
+ n_head=52,
459
+ n_embd=6656,
460
+ rotary_percentage=1.0,
461
+ parallel_residual=False,
462
+ bias=False,
463
+ _norm_class="RMSNorm",
464
+ norm_eps=1e-6,
465
+ _mlp_class="LLaMAMLP",
466
+ intermediate_size=17920,
467
+ ),
468
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
469
+ dict(
470
+ name="vicuna-7b-v1.5",
471
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.5"),
472
+ vocab_size=32000,
473
+ padding_multiple=64,
474
+ n_layer=32,
475
+ rotary_percentage=1.0,
476
+ parallel_residual=False,
477
+ bias=False,
478
+ _norm_class="RMSNorm",
479
+ _mlp_class="LLaMAMLP",
480
+ intermediate_size=11008,
481
+ ),
482
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
483
+ dict(
484
+ name="vicuna-7b-v1.5-16k",
485
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.5-16k"),
486
+ block_size=16384,
487
+ vocab_size=32000,
488
+ padding_multiple=64,
489
+ n_layer=32,
490
+ rotary_percentage=1.0,
491
+ parallel_residual=False,
492
+ bias=False,
493
+ _norm_class="RMSNorm",
494
+ _mlp_class="LLaMAMLP",
495
+ intermediate_size=11008,
496
+ rope_condense_ratio=4,
497
+ ),
498
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
499
+ dict(
500
+ name="vicuna-13b-v1.5",
501
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.5"),
502
+ vocab_size=32000,
503
+ padding_multiple=64,
504
+ n_layer=40,
505
+ n_head=40,
506
+ n_embd=5120,
507
+ rotary_percentage=1.0,
508
+ parallel_residual=False,
509
+ bias=False,
510
+ _norm_class="RMSNorm",
511
+ _mlp_class="LLaMAMLP",
512
+ intermediate_size=13824,
513
+ ),
514
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
515
+ dict(
516
+ name="vicuna-13b-v1.5-16k",
517
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.5-16k"),
518
+ block_size=16384,
519
+ vocab_size=32000,
520
+ padding_multiple=64,
521
+ n_layer=40,
522
+ n_head=40,
523
+ n_embd=5120,
524
+ rotary_percentage=1.0,
525
+ parallel_residual=False,
526
+ bias=False,
527
+ _norm_class="RMSNorm",
528
+ _mlp_class="LLaMAMLP",
529
+ intermediate_size=13824,
530
+ rope_condense_ratio=4,
531
+ ),
532
+ ]
533
+ configs.extend(vicuna)
534
+
535
+
536
+ #################
537
+ # LMSYS LongChat
538
+ #################
539
+ long_chat = [
540
+ # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
541
+ dict(
542
+ name="longchat-7b-16k",
543
+ hf_config=dict(org="lmsys", name="longchat-7b-16k"),
544
+ block_size=16384,
545
+ vocab_size=32000,
546
+ padding_multiple=64,
547
+ n_layer=32,
548
+ rotary_percentage=1.0,
549
+ parallel_residual=False,
550
+ bias=False,
551
+ _norm_class="RMSNorm",
552
+ norm_eps=1e-6,
553
+ _mlp_class="LLaMAMLP",
554
+ intermediate_size=11008,
555
+ rope_condense_ratio=8,
556
+ ),
557
+ # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
558
+ dict(
559
+ name="longchat-13b-16k",
560
+ hf_config=dict(org="lmsys", name="longchat-13b-16k"),
561
+ block_size=16384,
562
+ vocab_size=32000,
563
+ padding_multiple=64,
564
+ n_layer=40,
565
+ n_head=40,
566
+ n_embd=5120,
567
+ rotary_percentage=1.0,
568
+ parallel_residual=False,
569
+ bias=False,
570
+ _norm_class="RMSNorm",
571
+ norm_eps=1e-6,
572
+ _mlp_class="LLaMAMLP",
573
+ intermediate_size=13824,
574
+ rope_condense_ratio=8,
575
+ ),
576
+ ]
577
+ configs.extend(long_chat)
578
+
579
+
580
+ ######################
581
+ # NousResearch Hermes
582
+ ######################
583
+ nous_research = [
584
+ # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
585
+ dict(
586
+ name="Nous-Hermes-llama-2-7b",
587
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-llama-2-7b"),
588
+ padded_vocab_size=32000,
589
+ n_layer=32,
590
+ rotary_percentage=1.0,
591
+ parallel_residual=False,
592
+ bias=False,
593
+ _norm_class="RMSNorm",
594
+ norm_eps=1e-05,
595
+ _mlp_class="LLaMAMLP",
596
+ intermediate_size=11008,
597
+ ),
598
+ # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
599
+ dict(
600
+ name="Nous-Hermes-13b",
601
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-13b"),
602
+ block_size=2048,
603
+ vocab_size=32000,
604
+ padded_vocab_size=32001,
605
+ n_layer=40,
606
+ n_head=40,
607
+ n_embd=5120,
608
+ rotary_percentage=1.0,
609
+ parallel_residual=False,
610
+ bias=False,
611
+ _norm_class="RMSNorm",
612
+ norm_eps=1e-6,
613
+ _mlp_class="LLaMAMLP",
614
+ intermediate_size=13824,
615
+ ),
616
+ # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
617
+ dict(
618
+ name="Nous-Hermes-Llama2-13b",
619
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-Llama2-13b"),
620
+ vocab_size=32000,
621
+ padded_vocab_size=32032,
622
+ n_layer=40,
623
+ n_head=40,
624
+ n_embd=5120,
625
+ rotary_percentage=1.0,
626
+ parallel_residual=False,
627
+ bias=False,
628
+ _norm_class="RMSNorm",
629
+ norm_eps=1e-05,
630
+ _mlp_class="LLaMAMLP",
631
+ intermediate_size=13824,
632
+ ),
633
+ ]
634
+ configs.extend(nous_research)
635
+
636
+
637
+ ###############
638
+ # Meta LLaMA 2
639
+ ###############
640
+ llama_2 = [
641
+ # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
642
+ dict(
643
+ name="Llama-2-7b{}-hf",
644
+ hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
645
+ vocab_size=32000,
646
+ padding_multiple=64,
647
+ n_layer=32,
648
+ rotary_percentage=1.0,
649
+ parallel_residual=False,
650
+ bias=False,
651
+ _norm_class="RMSNorm",
652
+ _mlp_class="LLaMAMLP",
653
+ intermediate_size=11008,
654
+ ),
655
+ # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
656
+ dict(
657
+ name="Llama-2-13b{}-hf",
658
+ hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
659
+ vocab_size=32000,
660
+ padding_multiple=64,
661
+ n_layer=40,
662
+ n_head=40,
663
+ n_embd=5120,
664
+ rotary_percentage=1.0,
665
+ parallel_residual=False,
666
+ bias=False,
667
+ _norm_class="RMSNorm",
668
+ _mlp_class="LLaMAMLP",
669
+ intermediate_size=13824,
670
+ ),
671
+ # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
672
+ dict(
673
+ name="Llama-2-70b{}-hf",
674
+ hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
675
+ vocab_size=32000,
676
+ padding_multiple=64,
677
+ n_layer=80,
678
+ n_head=64,
679
+ n_embd=8192,
680
+ n_query_groups=8,
681
+ rotary_percentage=1.0,
682
+ parallel_residual=False,
683
+ bias=False,
684
+ _norm_class="RMSNorm",
685
+ _mlp_class="LLaMAMLP",
686
+ intermediate_size=28672,
687
+ ),
688
+ ]
689
+ for c in llama_2:
690
+ for kind in ("", "-chat"):
691
+ copy = c.copy()
692
+ copy["name"] = c["name"].format(kind)
693
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
694
+ configs.append(copy)
695
+
696
+
697
+ ##########################
698
+ # Stability AI FreeWilly2
699
+ ##########################
700
+ freewilly_2 = [
701
+ # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
702
+ dict(
703
+ name="FreeWilly2",
704
+ hf_config=dict(org="stabilityai", name="FreeWilly2"),
705
+ vocab_size=32000,
706
+ padding_multiple=64,
707
+ n_layer=80,
708
+ n_head=64,
709
+ n_embd=8192,
710
+ n_query_groups=8,
711
+ rotary_percentage=1.0,
712
+ parallel_residual=False,
713
+ bias=False,
714
+ _norm_class="RMSNorm",
715
+ _mlp_class="LLaMAMLP",
716
+ intermediate_size=28672,
717
+ )
718
+ ]
719
+ configs.extend(freewilly_2)
720
+
721
+
722
+ ##################
723
+ # Meta Code Llama
724
+ ##################
725
+ code_llama = [
726
+ # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
727
+ dict(
728
+ name="CodeLlama-7b-hf",
729
+ hf_config=dict(org="codellama", name="CodeLlama-7b-hf"),
730
+ block_size=16384,
731
+ vocab_size=32016,
732
+ padding_multiple=16,
733
+ n_layer=32,
734
+ rotary_percentage=1.0,
735
+ parallel_residual=False,
736
+ bias=False,
737
+ _norm_class="RMSNorm",
738
+ norm_eps=1e-05,
739
+ _mlp_class="LLaMAMLP",
740
+ intermediate_size=11008,
741
+ rope_base=1000000,
742
+ ),
743
+ # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
744
+ dict(
745
+ name="CodeLlama-13b-hf",
746
+ hf_config=dict(org="codellama", name="CodeLlama-13b-hf"),
747
+ block_size=16384,
748
+ vocab_size=32016,
749
+ padding_multiple=16,
750
+ n_layer=40,
751
+ n_head=40,
752
+ n_embd=5120,
753
+ rotary_percentage=1.0,
754
+ parallel_residual=False,
755
+ bias=False,
756
+ _norm_class="RMSNorm",
757
+ norm_eps=1e-05,
758
+ _mlp_class="LLaMAMLP",
759
+ intermediate_size=13824,
760
+ rope_base=1000000,
761
+ ),
762
+ # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
763
+ dict(
764
+ name="CodeLlama-34b-hf",
765
+ hf_config=dict(org="codellama", name="CodeLlama-34b-hf"),
766
+ block_size=16384,
767
+ vocab_size=32000,
768
+ padding_multiple=64,
769
+ n_layer=48,
770
+ n_head=64,
771
+ n_embd=8192,
772
+ n_query_groups=8,
773
+ rotary_percentage=1.0,
774
+ parallel_residual=False,
775
+ bias=False,
776
+ _norm_class="RMSNorm",
777
+ norm_eps=1e-05,
778
+ _mlp_class="LLaMAMLP",
779
+ intermediate_size=22016,
780
+ rope_base=1000000,
781
+ ),
782
+ # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
783
+ dict(
784
+ name="CodeLlama-7b-Python-hf",
785
+ hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"),
786
+ block_size=16384,
787
+ vocab_size=32000,
788
+ padding_multiple=64,
789
+ n_layer=32,
790
+ rotary_percentage=1.0,
791
+ parallel_residual=False,
792
+ bias=False,
793
+ _norm_class="RMSNorm",
794
+ norm_eps=1e-05,
795
+ _mlp_class="LLaMAMLP",
796
+ intermediate_size=11008,
797
+ rope_base=1000000,
798
+ ),
799
+ # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
800
+ dict(
801
+ name="CodeLlama-13b-Python-hf",
802
+ hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"),
803
+ block_size=16384,
804
+ vocab_size=32000,
805
+ padding_multiple=64,
806
+ n_layer=40,
807
+ n_head=40,
808
+ n_embd=5120,
809
+ rotary_percentage=1.0,
810
+ parallel_residual=False,
811
+ bias=False,
812
+ _norm_class="RMSNorm",
813
+ norm_eps=1e-05,
814
+ _mlp_class="LLaMAMLP",
815
+ intermediate_size=13824,
816
+ rope_base=1000000,
817
+ ),
818
+ # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
819
+ dict(
820
+ name="CodeLlama-34b-Python-hf",
821
+ hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"),
822
+ block_size=16384,
823
+ vocab_size=32000,
824
+ padding_multiple=64,
825
+ n_layer=48,
826
+ n_head=64,
827
+ n_embd=8192,
828
+ n_query_groups=8,
829
+ rotary_percentage=1.0,
830
+ parallel_residual=False,
831
+ bias=False,
832
+ _norm_class="RMSNorm",
833
+ norm_eps=1e-05,
834
+ _mlp_class="LLaMAMLP",
835
+ intermediate_size=22016,
836
+ rope_base=1000000,
837
+ ),
838
+ # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
839
+ dict(
840
+ name="CodeLlama-7b-Instruct-hf",
841
+ hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"),
842
+ block_size=16384,
843
+ vocab_size=32016,
844
+ padding_multiple=16,
845
+ n_layer=32,
846
+ rotary_percentage=1.0,
847
+ parallel_residual=False,
848
+ bias=False,
849
+ _norm_class="RMSNorm",
850
+ norm_eps=1e-05,
851
+ _mlp_class="LLaMAMLP",
852
+ intermediate_size=11008,
853
+ rope_base=1000000,
854
+ ),
855
+ # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
856
+ dict(
857
+ name="CodeLlama-13b-Instruct-hf",
858
+ hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"),
859
+ block_size=2048,
860
+ vocab_size=32016,
861
+ padding_multiple=16,
862
+ n_layer=40,
863
+ n_head=40,
864
+ n_embd=5120,
865
+ rotary_percentage=1.0,
866
+ parallel_residual=False,
867
+ bias=False,
868
+ _norm_class="RMSNorm",
869
+ norm_eps=1e-05,
870
+ _mlp_class="LLaMAMLP",
871
+ intermediate_size=13824,
872
+ rope_base=1000000,
873
+ ),
874
+ # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
875
+ dict(
876
+ name="CodeLlama-34b-Instruct-hf",
877
+ hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"),
878
+ block_size=16384,
879
+ vocab_size=32000,
880
+ padding_multiple=64,
881
+ n_layer=48,
882
+ n_head=64,
883
+ n_embd=8192,
884
+ n_query_groups=8,
885
+ rotary_percentage=1.0,
886
+ parallel_residual=False,
887
+ bias=False,
888
+ _norm_class="RMSNorm",
889
+ norm_eps=1e-05,
890
+ _mlp_class="LLaMAMLP",
891
+ intermediate_size=22016,
892
+ rope_base=1000000,
893
+ ),
894
+ ]
895
+ configs.extend(code_llama)
896
+
897
+
898
+ ########################
899
+ # garage-bAInd Platypus
900
+ ########################
901
+ platypus = [
902
+ # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
903
+ dict(
904
+ name="Platypus-30B",
905
+ hf_config=dict(org="garage-bAInd", name="Platypus-30B"),
906
+ block_size=2048,
907
+ padded_vocab_size=32000,
908
+ n_layer=60,
909
+ n_head=52,
910
+ n_embd=6656,
911
+ rotary_percentage=1.0,
912
+ parallel_residual=False,
913
+ bias=False,
914
+ _norm_class="RMSNorm",
915
+ norm_eps=1e-06,
916
+ _mlp_class="LLaMAMLP",
917
+ intermediate_size=17920,
918
+ ),
919
+ # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
920
+ dict(
921
+ name="Platypus2-7B",
922
+ hf_config=dict(org="garage-bAInd", name="Platypus2-7B"),
923
+ padded_vocab_size=32000,
924
+ n_layer=32,
925
+ rotary_percentage=1.0,
926
+ parallel_residual=False,
927
+ bias=False,
928
+ _norm_class="RMSNorm",
929
+ norm_eps=1e-05,
930
+ _mlp_class="LLaMAMLP",
931
+ intermediate_size=11008,
932
+ ),
933
+ # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
934
+ dict(
935
+ name="Platypus2-13B",
936
+ hf_config=dict(org="garage-bAInd", name="Platypus2-13B"),
937
+ padded_vocab_size=32000,
938
+ n_layer=40,
939
+ n_head=40,
940
+ n_embd=5120,
941
+ rotary_percentage=1.0,
942
+ parallel_residual=False,
943
+ bias=False,
944
+ _norm_class="RMSNorm",
945
+ norm_eps=1e-05,
946
+ _mlp_class="LLaMAMLP",
947
+ intermediate_size=13824,
948
+ ),
949
+ # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
950
+ dict(
951
+ name="Platypus2-70B",
952
+ hf_config=dict(org="garage-bAInd", name="Platypus2-70B"),
953
+ padded_vocab_size=32000,
954
+ n_layer=80,
955
+ n_head=64,
956
+ n_embd=8192,
957
+ rotary_percentage=1.0,
958
+ parallel_residual=False,
959
+ bias=False,
960
+ _norm_class="RMSNorm",
961
+ _mlp_class="LLaMAMLP",
962
+ intermediate_size=28672,
963
+ ),
964
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
965
+ dict(
966
+ name="Camel-Platypus2-13B",
967
+ hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"),
968
+ padded_vocab_size=32000,
969
+ n_layer=40,
970
+ n_head=40,
971
+ n_embd=5120,
972
+ rotary_percentage=1.0,
973
+ parallel_residual=False,
974
+ bias=False,
975
+ _norm_class="RMSNorm",
976
+ _mlp_class="LLaMAMLP",
977
+ intermediate_size=13824,
978
+ ),
979
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
980
+ dict(
981
+ name="Camel-Platypus2-70B",
982
+ hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"),
983
+ padded_vocab_size=32000,
984
+ n_layer=80,
985
+ n_head=64,
986
+ n_embd=8192,
987
+ n_query_groups=8,
988
+ rotary_percentage=1.0,
989
+ parallel_residual=False,
990
+ bias=False,
991
+ _norm_class="RMSNorm",
992
+ _mlp_class="LLaMAMLP",
993
+ intermediate_size=28672,
994
+ ),
995
+ # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
996
+ dict(
997
+ name="Stable-Platypus2-13B",
998
+ hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"),
999
+ padded_vocab_size=32000,
1000
+ n_layer=40,
1001
+ n_head=40,
1002
+ n_embd=5120,
1003
+ rotary_percentage=1.0,
1004
+ parallel_residual=False,
1005
+ bias=False,
1006
+ _norm_class="RMSNorm",
1007
+ _mlp_class="LLaMAMLP",
1008
+ intermediate_size=13824,
1009
+ ),
1010
+ # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
1011
+ dict(
1012
+ name="Platypus2-70B-instruct",
1013
+ hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"),
1014
+ padded_vocab_size=32000,
1015
+ n_layer=80,
1016
+ n_head=64,
1017
+ n_embd=8192,
1018
+ n_query_groups=8,
1019
+ rotary_percentage=1.0,
1020
+ parallel_residual=False,
1021
+ bias=False,
1022
+ _norm_class="RMSNorm",
1023
+ _mlp_class="LLaMAMLP",
1024
+ intermediate_size=28672,
1025
+ ),
1026
+ ]
1027
+ configs.extend(platypus)
1028
+
1029
+
1030
+ ##########################
1031
+ # Stability AI StableCode
1032
+ ##########################
1033
+ stablecode = [
1034
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
1035
+ dict(
1036
+ name="stablecode-completion-alpha-3b",
1037
+ hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
1038
+ block_size=16384,
1039
+ vocab_size=49152,
1040
+ n_layer=32,
1041
+ n_embd=2560,
1042
+ ),
1043
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
1044
+ dict(
1045
+ name="stablecode-completion-alpha-3b-4k",
1046
+ hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
1047
+ vocab_size=49152,
1048
+ n_layer=32,
1049
+ n_embd=2560,
1050
+ ),
1051
+ # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
1052
+ dict(
1053
+ name="stablecode-instruct-alpha-3b",
1054
+ hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
1055
+ vocab_size=49152,
1056
+ n_layer=32,
1057
+ n_embd=2560,
1058
+ ),
1059
+ ]
1060
+ configs.extend(stablecode)
1061
+
1062
+
1063
+ ##################################
1064
+ # togethercomputer LLaMA-2-7B-32K
1065
+ ##################################
1066
+ together_llama2_32k = [
1067
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
1068
+ dict(
1069
+ name="LLaMA-2-7B-32K",
1070
+ hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"),
1071
+ vocab_size=32000,
1072
+ padding_multiple=64,
1073
+ n_layer=32,
1074
+ rotary_percentage=1.0,
1075
+ parallel_residual=False,
1076
+ bias=False,
1077
+ _norm_class="RMSNorm",
1078
+ _mlp_class="LLaMAMLP",
1079
+ intermediate_size=11008,
1080
+ rope_condense_ratio=8,
1081
+ )
1082
+ ]
1083
+ configs.extend(together_llama2_32k)
1084
+
1085
+
1086
+ ################
1087
+ # Microsoft Phi
1088
+ ################
1089
+ phi = [
1090
+ # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
1091
+ dict(
1092
+ name="phi-1_5",
1093
+ hf_config=dict(org="microsoft", name="phi-1_5"),
1094
+ vocab_size=50257,
1095
+ padded_vocab_size=51200,
1096
+ block_size=2048,
1097
+ n_embd=2048,
1098
+ n_layer=24,
1099
+ rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
1100
+ shared_attention_norm=True,
1101
+ lm_head_bias=True,
1102
+ gelu_approximate="tanh",
1103
+ )
1104
+ ]
1105
+ configs.extend(phi)
1106
+
1107
+
1108
+ #############
1109
+ # Mistral AI
1110
+ #############
1111
+ mistral = [
1112
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
1113
+ dict(
1114
+ name="Mistral-7B-{}v0.1",
1115
+ hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"),
1116
+ padded_vocab_size=32000,
1117
+ block_size=4096, # should be 32768 but sliding window attention is not implemented
1118
+ n_layer=32,
1119
+ n_query_groups=8,
1120
+ rotary_percentage=1.0,
1121
+ parallel_residual=False,
1122
+ bias=False,
1123
+ _norm_class="RMSNorm",
1124
+ norm_eps=1e-05,
1125
+ _mlp_class="LLaMAMLP",
1126
+ intermediate_size=14336,
1127
+ )
1128
+ ]
1129
+ for c in mistral:
1130
+ for kind in ("", "Instruct-"):
1131
+ copy = c.copy()
1132
+ copy["name"] = c["name"].format(kind)
1133
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
1134
+ configs.append(copy)
1135
+
1136
+
1137
+ ############
1138
+ # TinyLlama
1139
+ ############
1140
+ tiny_llama = [
1141
+ dict(
1142
+ name="tiny-llama-1.1b",
1143
+ hf_config=dict(org="PY007", name="TinyLlama-1.1B-intermediate-step-480k-1T"),
1144
+ block_size=2048,
1145
+ vocab_size=32000,
1146
+ padding_multiple=64,
1147
+ n_layer=22,
1148
+ n_head=32,
1149
+ n_embd=2048,
1150
+ rotary_percentage=1.0,
1151
+ parallel_residual=False,
1152
+ bias=False,
1153
+ _norm_class="RMSNorm", # original TinyLlama uses FusedRMSNorm
1154
+ norm_eps=1e-5,
1155
+ _mlp_class="LLaMAMLP",
1156
+ intermediate_size=5632,
1157
+ n_query_groups=4,
1158
+ ),
1159
+ dict(
1160
+ name="tiny-llama-new",
1161
+ hf_config=dict(org="PY007", name="TinyLlama-1.1B-intermediate-step-480k-1T"),
1162
+ block_size=768,
1163
+ vocab_size=32000,
1164
+ padding_multiple=64,
1165
+ n_layer=18,
1166
+ n_head=32,
1167
+ n_embd=1024,
1168
+ rotary_percentage=1.0,
1169
+ parallel_residual=False,
1170
+ bias=False,
1171
+ _norm_class="RMSNorm", # original TinyLlama uses FusedRMSNorm
1172
+ norm_eps=1e-5,
1173
+ _mlp_class="LLaMAMLP",
1174
+ intermediate_size=5632,
1175
+ n_query_groups=4,
1176
+ ),
1177
+ ]
1178
+ configs.extend(tiny_llama)
1179
+
1180
+
1181
+ name_to_config = {config["name"]: config for config in configs}
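
A note on the lookup table built at the end: name_to_config indexes every dict in configs, including the programmatically generated variants (the -deduped Pythia entries, the Base/Chat/Instruct RedPajama-INCITE entries, the Falcon instruct/chat and Llama-2 chat variants), and Config.from_name additionally falls back to matching on hf_config["name"]. A short sketch of how it behaves:

    from tsai_gpt.config import Config, name_to_config

    print("pythia-160m-deduped" in name_to_config)          # True, a generated variant
    print("RedPajama-INCITE-Chat-3B-v1" in name_to_config)  # True, a generated variant

    # Keyword arguments override the preset before the dataclass is built.
    cfg = Config.from_name("pythia-160m", block_size=1024)
    print(cfg.block_size, cfg.intermediate_size)  # 1024 and 3072 (4 * n_embd)
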
tsai_gpt/model.py ADDED
@@ -0,0 +1,342 @@
1
+ """Full definition of a GPT NeoX Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
4
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
5
+ """
6
+ import math
7
+ from typing import Any, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing_extensions import Self
12
+
13
+ from tsai_gpt.config import Config
14
+
15
+
16
+
17
+ class GPT(nn.Module):
18
+ def __init__(self, config: Config) -> None:
19
+ super().__init__()
20
+ assert config.padded_vocab_size is not None
21
+ self.config = config
22
+
23
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
24
+ self.transformer = nn.ModuleDict(
25
+ dict(
26
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
27
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
28
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
29
+ )
30
+ )
31
+ self.max_seq_length = self.config.block_size
32
+ self.mask_cache: Optional[torch.Tensor] = None
33
+
34
+ @property
35
+ def max_seq_length(self) -> int:
36
+ return self._max_seq_length
37
+
38
+ @max_seq_length.setter
39
+ def max_seq_length(self, value: int) -> None:
40
+ """
41
+ When doing inference, the sequences used might be shorter than the model's context length.
42
+ This allows setting a smaller number to avoid allocating unused memory
43
+ """
44
+ if value > self.config.block_size:
45
+ raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
46
+ self._max_seq_length = value
47
+ if not hasattr(self, "cos"):
48
+ # first call
49
+ cos, sin = self.rope_cache()
50
+ self.register_buffer("cos", cos, persistent=False)
51
+ self.register_buffer("sin", sin, persistent=False)
52
+ elif value != self.cos.size(0):
53
+ # override
54
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
55
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
56
+ # if the kv cache is expected
57
+
58
+ def reset_parameters(self) -> None:
59
+ # Trigger resetting the rope-cache
60
+ self.max_seq_length = self.config.block_size
61
+
62
+ def _init_weights(self, module: nn.Module) -> None:
63
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
64
+ if isinstance(module, nn.Linear):
65
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
66
+ if module.bias is not None:
67
+ torch.nn.init.zeros_(module.bias)
68
+ elif isinstance(module, nn.Embedding):
69
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
70
+
71
+ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
72
+ T = idx.size(1)
73
+ if self.max_seq_length < T:
74
+ raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
75
+
76
+ if input_pos is not None: # use the kv cache
77
+ cos = self.cos.index_select(0, input_pos)
78
+ sin = self.sin.index_select(0, input_pos)
79
+ if self.mask_cache is None:
80
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
81
+ mask = self.mask_cache.index_select(2, input_pos)
82
+ else:
83
+ cos = self.cos[:T]
84
+ sin = self.sin[:T]
85
+ mask = None
86
+
87
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
88
+ for block in self.transformer.h:
89
+ x = block(x, cos, sin, mask, input_pos)
90
+ x = self.transformer.ln_f(x)
91
+ return self.lm_head(x) # (b, t, vocab_size)
92
+
93
+ @classmethod
94
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
95
+ return cls(Config.from_name(name, **kwargs))
96
+
97
+ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
98
+ return build_rope_cache(
99
+ seq_len=self.max_seq_length,
100
+ n_elem=self.config.rope_n_elem,
101
+ device=device,
102
+ condense_ratio=self.config.rope_condense_ratio,
103
+ base=self.config.rope_base,
104
+ )
105
+
106
+ def set_kv_cache(
107
+ self,
108
+ batch_size: int,
109
+ rope_cache_length: Optional[int] = None,
110
+ device: Optional[torch.device] = None,
111
+ dtype: Optional[torch.dtype] = None,
112
+ ) -> None:
113
+ if rope_cache_length is None:
114
+ rope_cache_length = self.cos.size(-1)
115
+ max_seq_length = self.max_seq_length
116
+
117
+ # initialize the kv cache for all blocks
118
+ for block in self.transformer.h:
119
+ block.attn.kv_cache = block.attn.build_kv_cache(
120
+ batch_size, max_seq_length, rope_cache_length, device, dtype
121
+ )
122
+
123
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
124
+ # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
125
+ # for the kv-cache support (only during inference), we only create it in that situation
126
+ # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
127
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
128
+ self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
129
+
130
+ def clear_kv_cache(self) -> None:
131
+ self.mask_cache = None
132
+ for block in self.transformer.h:
133
+ block.attn.kv_cache = None
134
+
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self, config: Config) -> None:
138
+ super().__init__()
139
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
140
+ self.attn = CausalSelfAttention(config)
141
+ self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
142
+ self.mlp = config.mlp_class(config)
143
+
144
+ self.config = config
145
+
146
+ def forward(
147
+ self,
148
+ x: torch.Tensor,
149
+ cos: torch.Tensor,
150
+ sin: torch.Tensor,
151
+ mask: Optional[torch.Tensor] = None,
152
+ input_pos: Optional[torch.Tensor] = None,
153
+ ) -> torch.Tensor:
154
+ n_1 = self.norm_1(x)
155
+ h = self.attn(n_1, cos, sin, mask, input_pos)
156
+ if self.config.parallel_residual:
157
+ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
158
+ x = self.mlp(n_2) + h + x
159
+ else:
160
+ if self.config.shared_attention_norm:
161
+ raise NotImplementedError(
162
+ "No checkpoint amongst the ones we support uses this configuration"
163
+ " (non-parallel residual and shared attention norm)."
164
+ )
165
+ x = h + x
166
+ x = self.mlp(self.norm_2(x)) + x
167
+ return x
168
+
169
+
170
+ class CausalSelfAttention(nn.Module):
171
+ def __init__(self, config: Config) -> None:
172
+ super().__init__()
173
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
174
+ # key, query, value projections for all heads, but in a batch
175
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
176
+ # output projection
177
+ self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
178
+ # disabled by default
179
+ self.kv_cache: Optional[KVCache] = None
180
+
181
+ self.config = config
182
+
183
+ def forward(
184
+ self,
185
+ x: torch.Tensor,
186
+ cos: torch.Tensor,
187
+ sin: torch.Tensor,
188
+ mask: Optional[torch.Tensor] = None,
189
+ input_pos: Optional[torch.Tensor] = None,
190
+ ) -> torch.Tensor:
191
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
192
+
193
+ qkv = self.attn(x)
194
+
195
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
196
+ q_per_kv = self.config.n_head // self.config.n_query_groups
197
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
198
+ qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
199
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
200
+
201
+ # split batched computation into three
202
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
203
+
204
+ # maybe repeat k and v if needed for the non multi-head attention cases
205
+ # training: flash attention requires it
206
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
207
+ if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
208
+ k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
209
+ v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
210
+
211
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
212
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
213
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
214
+
215
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
216
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
217
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
218
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
219
+
220
+ if input_pos is not None:
221
+ if not isinstance(self.kv_cache, KVCache):
222
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
223
+ k, v = self.kv_cache(input_pos, k, v)
224
+
225
+ y = self.scaled_dot_product_attention(q, k, v, mask)
226
+
227
+ y = y.reshape(B, T, C) # re-assemble all head outputs side by side
228
+
229
+ # output projection
230
+ return self.proj(y)
231
+
232
+ def scaled_dot_product_attention(
233
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
234
+ ) -> torch.Tensor:
235
+ scale = 1.0 / math.sqrt(self.config.head_size)
236
+ y = torch.nn.functional.scaled_dot_product_attention(
237
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
238
+ )
239
+ return y.transpose(1, 2)
240
+
241
+ def build_kv_cache(
242
+ self,
243
+ batch_size: int,
244
+ max_seq_length: int,
245
+ rope_cache_length: Optional[int] = None,
246
+ device: Optional[torch.device] = None,
247
+ dtype: Optional[torch.dtype] = None,
248
+ ) -> "KVCache":
249
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
250
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
251
+ if rope_cache_length is None:
252
+ if self.config.rotary_percentage != 1.0:
253
+ raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
254
+ k_shape = v_shape
255
+ else:
256
+ k_shape = (
257
+ batch_size,
258
+ heads,
259
+ max_seq_length,
260
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
261
+ )
262
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
263
+
264
+
265
+ class GptNeoxMLP(nn.Module):
266
+ def __init__(self, config: Config) -> None:
267
+ super().__init__()
268
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
269
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
270
+
271
+ self.config = config
272
+
273
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
274
+ x = self.fc(x)
275
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
276
+ return self.proj(x)
277
+
278
+
279
+ class LLaMAMLP(nn.Module):
280
+ def __init__(self, config: Config) -> None:
281
+ super().__init__()
282
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
283
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
284
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
285
+
286
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
287
+ x_fc_1 = self.fc_1(x)
288
+ x_fc_2 = self.fc_2(x)
289
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
290
+ return self.proj(x)
291
+
292
+
293
+ def build_rope_cache(
294
+ seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
295
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
296
+ """Enhanced Transformer with Rotary Position Embedding.
297
+
298
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
299
+ transformers/rope/__init__.py. MIT License:
300
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
301
+ """
302
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
303
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
304
+
305
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
306
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
307
+
308
+ # Calculate the product of position index and $\theta_i$
309
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
310
+
311
+ return torch.cos(idx_theta), torch.sin(idx_theta)
312
+
313
+
314
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
315
+ head_size = x.size(-1)
316
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
317
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
318
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
319
+ roped = (x * cos) + (rotated * sin)
320
+ return roped.type_as(x)
321
+
322
+
323
+ class KVCache(nn.Module):
324
+ def __init__(
325
+ self,
326
+ k_shape: Tuple[int, int, int, int],
327
+ v_shape: Tuple[int, int, int, int],
328
+ device: Optional[torch.device] = None,
329
+ dtype: Optional[torch.dtype] = None,
330
+ ) -> None:
331
+ super().__init__()
332
+ self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
333
+ self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
334
+
335
+ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
336
+ # move the buffer to the activation dtype for when AMP is used
337
+ self.k = self.k.to(k.dtype)
338
+ self.v = self.v.to(v.dtype)
339
+ # update the cache
340
+ k = self.k.index_copy_(2, input_pos, k)
341
+ v = self.v.index_copy_(2, input_pos, v)
342
+ return k, v
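To make the pieces above easier to follow, a minimal inference sketch (not part of this commit). It assumes one of the config names registered in tsai_gpt/config.py; building a large config on CPU is slow, so a small one is advisable:

    import torch

    from tsai_gpt.config import name_to_config
    from tsai_gpt.model import GPT

    model = GPT.from_name(next(iter(name_to_config)))  # any registered name; a small config is advisable
    model.eval()
    idx = torch.randint(0, model.config.padded_vocab_size, (1, 8))  # (batch, time) token ids
    with torch.no_grad():
        logits = model(idx)                            # full-sequence forward: (1, 8, padded_vocab_size)

    # incremental decoding path: shrink the rope/mask caches, then attach a kv cache
    model.max_seq_length = 32
    model.set_kv_cache(batch_size=1)
    input_pos = torch.arange(8)
    with torch.no_grad():
        logits = model(idx, input_pos)                 # fills the cache; later calls can pass one new token at a time

The `max_seq_length` setter rebuilds the rope cache, and `set_kv_cache` allocates the causal mask plus per-block key/value buffers, mirroring the code above.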
tsai_gpt/packed_dataset.py ADDED
@@ -0,0 +1,235 @@
1
+ # Very loosely inspired by indexed_dataset in Fairseq, Megatron
2
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
3
+
4
+
5
+ import os
6
+ import random
7
+ import struct
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import IterableDataset, get_worker_info
12
+
13
+ dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}
14
+
15
+
16
+ def code(dtype):
17
+ for k in dtypes:
18
+ if dtypes[k] == dtype:
19
+ return k
20
+ raise ValueError(dtype)
21
+
22
+
23
+ HDR_MAGIC = b"LITPKDS"
24
+ HDR_SIZE = 24 # bytes
25
+
26
+
27
+ class PackedDataset(IterableDataset):
28
+ def __init__(
29
+ self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
30
+ ):
31
+ self._filenames = filenames
32
+ self._n_chunks = n_chunks
33
+ self._block_size = block_size
34
+ self._seed = seed
35
+ self._shuffle = shuffle
36
+ self._wrap = wrap
37
+ self._num_processes = num_processes
38
+ self._process_rank = process_rank
39
+
40
+ def __iter__(self):
41
+ worker_info = get_worker_info()
42
+ num_workers = worker_info.num_workers if worker_info is not None else 1
43
+ worker_id = worker_info.id if worker_info is not None else 0
44
+ num_shards = num_workers * self._num_processes
45
+ shard_id = self._process_rank * num_workers + worker_id
46
+
47
+ max_num_files = len(self._filenames) // num_shards * num_shards
48
+ filenames = self._filenames[shard_id:max_num_files:num_shards]
49
+
50
+ return PackedDatasetIterator(
51
+ filenames=filenames,
52
+ n_chunks=self._n_chunks,
53
+ block_size=self._block_size,
54
+ seed=self._seed,
55
+ shuffle=self._shuffle,
56
+ wrap=self._wrap,
57
+ )
58
+
59
+
60
+ class PackedDatasetBuilder(object):
61
+ def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
62
+ if dtype == "auto":
63
+ if vocab_size is None:
64
+ raise ValueError("vocab_size cannot be None when dtype='auto'")
65
+ if vocab_size is not None and vocab_size < 65500:
66
+ self._dtype = np.uint16
67
+ else:
68
+ self._dtype = np.int32
69
+ else:
70
+ self._dtype = dtype
71
+ self._counter = 0
72
+ self._chunk_size = chunk_size
73
+ self._outdir = outdir
74
+ self._prefix = prefix
75
+ self._sep_token = sep_token
76
+ self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
77
+ self._arr.fill(self._sep_token)
78
+ self._idx = 0
79
+ self._version = 1
80
+ self._filenames = []
81
+
82
+ def _write_chunk(self):
83
+ filename = f"{self._prefix}_{self._counter:010d}.bin"
84
+ filename = os.path.join(self._outdir, filename)
85
+
86
+ with open(filename, "wb") as f:
87
+ f.write(HDR_MAGIC)
88
+ f.write(struct.pack("<Q", self._version))
89
+ f.write(struct.pack("<B", code(self._dtype)))
90
+ f.write(struct.pack("<Q", self._chunk_size))
91
+ f.write(self._arr.tobytes(order="C"))
92
+
93
+ self._filenames.append(filename)
94
+ self._counter += 1
95
+ self._arr.fill(self._sep_token)
96
+ self._idx = 0
97
+
98
+ @property
99
+ def dtype(self):
100
+ return self._dtype
101
+
102
+ @property
103
+ def filenames(self):
104
+ return self._filenames.copy()
105
+
106
+ def add_array(self, arr):
107
+ while self._idx + arr.shape[0] > self._chunk_size:
108
+ part_len = self._chunk_size - self._idx
109
+ self._arr[self._idx : self._idx + part_len] = arr[:part_len]
110
+ self._write_chunk()
111
+ arr = arr[part_len:]
112
+
113
+ arr_len = arr.shape[0]
114
+ self._arr[self._idx : self._idx + arr_len] = arr
115
+ self._idx += arr_len
116
+
117
+ def write_reminder(self):
118
+ self._write_chunk()
119
+
120
+
121
+ class PackedDatasetIterator:
122
+ def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
123
+ self._seed = seed
124
+ self._shuffle = shuffle
125
+ self._rng = np.random.default_rng(seed) if shuffle else None
126
+ self._block_idxs = None
127
+
128
+ self._wrap = wrap
129
+
130
+ # TODO: instead of filenames, we could have a single text stream
131
+ # (or text file) with the sequence of all files to be
132
+ # fetched/loaded.
133
+ self._filenames = filenames
134
+ self._file_idx = 0
135
+
136
+ self._n_chunks = n_chunks
137
+
138
+ self._dtype = None
139
+ self._block_size = block_size
140
+ self._n_blocks = None
141
+
142
+ self._mmaps = []
143
+ self._buffers = []
144
+
145
+ self._block_idxs = []
146
+ self._curr_idx = 0
147
+
148
+ self._load_n_chunks()
149
+
150
+ def _read_header(self, path):
151
+ with open(path, "rb") as f:
152
+ magic = f.read(len(HDR_MAGIC))
153
+ assert magic == HDR_MAGIC, "File doesn't match expected format."
154
+ version = struct.unpack("<Q", f.read(8))
155
+ assert version == (1,)
156
+ (dtype_code,) = struct.unpack("<B", f.read(1))
157
+ dtype = dtypes[dtype_code]
158
+ (chunk_size,) = struct.unpack("<Q", f.read(8))
159
+ return dtype, chunk_size
160
+
161
+ def _close_mmaps(self):
162
+ for mmap in self._mmaps:
163
+ mmap._mmap.close()
164
+
165
+ def _load_n_chunks(self):
166
+ self._close_mmaps()
167
+ self._mmaps = []
168
+ self._buffers = []
169
+
170
+ if self._n_chunks > len(self._filenames[self._file_idx :]):
171
+ if not self._wrap:
172
+ raise StopIteration
173
+ self._file_idx = 0
174
+
175
+ for i in range(self._n_chunks):
176
+ filename = self._filenames[self._file_idx + i]
177
+ if self._dtype is None:
178
+ self._dtype, self._chunk_size = self._read_header(filename)
179
+ self._n_blocks = self._chunk_size // self._block_size
180
+ # TODO: check header matches with previous files
181
+ mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
182
+ self._mmaps.append(mmap)
183
+ self._buffers.append(memoryview(mmap))
184
+
185
+ self._file_idx += self._n_chunks
186
+ n_all_blocks = self._n_chunks * self._n_blocks
187
+
188
+ self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
189
+
190
+ self._curr_idx = 0
191
+
192
+ def __del__(self):
193
+ self._close_mmaps()
194
+ del self._mmaps
195
+ del self._buffers
196
+
197
+ def __iter__(self):
198
+ return self
199
+
200
+ def __next__(self):
201
+ if self._curr_idx >= len(self._block_idxs):
202
+ self._load_n_chunks()
203
+ # TODO: trigger fetching next next n_chunks if remote
204
+ block_idx = self._block_idxs[self._curr_idx]
205
+ chunk_id = block_idx // self._n_blocks
206
+ buffer = self._buffers[chunk_id]
207
+ elem_id = (block_idx % self._n_blocks) * self._block_size
208
+ offset = np.dtype(self._dtype).itemsize * elem_id
209
+ arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
210
+ self._curr_idx += 1
211
+ return torch.from_numpy(arr.astype(np.int64))
212
+
213
+
214
+ class CombinedDataset(IterableDataset):
215
+ def __init__(self, datasets, seed, weights=None):
216
+ self._seed = seed
217
+ self._datasets = datasets
218
+ self._weights = weights
219
+ n_datasets = len(datasets)
220
+ if weights is None:
221
+ self._weights = [1 / n_datasets] * n_datasets
222
+
223
+ def __iter__(self):
224
+ return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
225
+
226
+
227
+ class CombinedDatasetIterator:
228
+ def __init__(self, datasets, seed, weights):
229
+ self._datasets = [iter(el) for el in datasets]
230
+ self._weights = weights
231
+ self._rng = random.Random(seed)
232
+
233
+ def __next__(self):
234
+ (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
235
+ return next(dataset)
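A small end-to-end sketch of the packed-dataset format defined above; the output directory, token values, and sizes are illustrative only:

    import os

    import numpy as np
    from torch.utils.data import DataLoader

    from tsai_gpt.packed_dataset import PackedDataset, PackedDatasetBuilder

    os.makedirs("data/packed", exist_ok=True)                   # illustrative output directory
    builder = PackedDatasetBuilder(
        outdir="data/packed", prefix="demo", chunk_size=16, sep_token=0, dtype="auto", vocab_size=32000
    )
    builder.add_array(np.array([5, 6, 7, 8], dtype=np.uint16))  # pretend token ids
    builder.add_array(np.arange(20, dtype=np.uint16))           # spills across a chunk boundary
    builder.write_reminder()                                    # flush the partially filled last chunk

    dataset = PackedDataset(builder.filenames, n_chunks=1, block_size=8, shuffle=True, wrap=True)
    batch = next(iter(DataLoader(dataset, batch_size=2)))       # LongTensor of shape (2, 8)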
tsai_gpt/rmsnorm.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+
3
+
4
+ class RMSNorm(torch.nn.Module):
5
+ """Root Mean Square Layer Normalization.
6
+
7
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
8
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
9
+ """
10
+
11
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
12
+ super().__init__()
13
+ self.weight = torch.nn.Parameter(torch.ones(size))
14
+ self.eps = eps
15
+ self.dim = dim
16
+
17
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
+ dtype = x.dtype
19
+ x = x.float()
20
+ # NOTE: not strictly equivalent to the original RMSNorm paper implementation (here eps is added to the mean of squares inside the square root)
21
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
22
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
23
+ return (self.weight * x_normed).to(dtype=dtype)
24
+
25
+ def reset_parameters(self) -> None:
26
+ torch.nn.init.ones_(self.weight)
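A quick sanity check of the module above (shapes are illustrative):

    import torch

    from tsai_gpt.rmsnorm import RMSNorm

    norm = RMSNorm(size=64)
    x = torch.randn(2, 10, 64)
    y = norm(x)                                        # same shape; roughly unit RMS over the last dim
    print(y.shape, y.pow(2).mean(dim=-1).sqrt().mean())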
tsai_gpt/speed_monitor.py ADDED
@@ -0,0 +1,425 @@
1
+ import time
2
+ from collections import deque
3
+ from contextlib import nullcontext
4
+ from typing import Any, Callable, Deque, Dict, Optional
5
+
6
+ import torch
7
+ from lightning import Callback, Fabric, LightningModule, Trainer
8
+ from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
9
+ from lightning.fabric.plugins import (
10
+ BitsandbytesPrecision,
11
+ DoublePrecision,
12
+ FSDPPrecision,
13
+ HalfPrecision,
14
+ MixedPrecision,
15
+ Precision,
16
+ TransformerEnginePrecision,
17
+ XLAPrecision,
18
+ )
19
+ from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
20
+ from lightning.pytorch.plugins import (
21
+ DoublePrecisionPlugin,
22
+ FSDPPrecisionPlugin,
23
+ HalfPrecisionPlugin,
24
+ MixedPrecisionPlugin,
25
+ XLAPrecisionPlugin,
26
+ )
27
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
28
+ from torch.utils.flop_counter import FlopCounterMode
29
+
30
+ from tsai_gpt import GPT
31
+ from tsai_gpt.utils import num_parameters
32
+
33
+ GPU_AVAILABLE_FLOPS = {
34
+ # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
35
+ # nvidia publishes spec sheet with a 2x sparsity factor
36
+ "h100-sxm": {
37
+ torch.float64: 67e12,
38
+ torch.float32: 67e12,
39
+ torch.bfloat16: 1.979e15 / 2,
40
+ torch.float16: 1.979e15 / 2,
41
+ torch.int8: 3.958e15 / 2,
42
+ },
43
+ "h100-pcie": {
44
+ torch.float64: 51e12,
45
+ torch.float32: 51e12,
46
+ torch.bfloat16: 1.513e15 / 2,
47
+ torch.float16: 1.513e15 / 2,
48
+ torch.int8: 3.026e15 / 2,
49
+ },
50
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
51
+ # sxm and pcie have same flop counts
52
+ "a100": {torch.float64: 19.5e12, torch.float32: 19.5e12, torch.bfloat16: 312e12, torch.float16: 312e12},
53
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
54
+ "a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12},
55
+ # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
56
+ "v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12},
57
+ "v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12},
58
+ "v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12},
59
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
60
+ # sxm and pcie have same flop counts
61
+ "t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12},
62
+ # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
63
+ "quadro rtx 5000": {torch.float32: 11.2e12, torch.float16: 89.2e12},
64
+ }
65
+
66
+ TPU_AVAILABLE_FLOPS = {
67
+ # flop count for each TPU generation is the same for all precisions
68
+ # since bfloat16 precision is always used for performing matrix operations
69
+ # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
70
+ # source: https://arxiv.org/pdf/1907.10701.pdf
71
+ "v2": 45e12,
72
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
73
+ "v3": 123e12,
74
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
75
+ "v4": 275e12,
76
+ # source: https://cloud.google.com/tpu/docs/v5e-training
77
+ "v5litepod": 197e12,
78
+ }
79
+
80
+
81
+ def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]:
82
+ if device.type == "cuda":
83
+ device_name = torch.cuda.get_device_name(device).lower()
84
+ if "h100" in device_name and "hbm3" in device_name:
85
+ device_name = "h100-sxm"
86
+ elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
87
+ device_name = "h100-pcie"
88
+ elif "a100" in device_name:
89
+ device_name = "a100"
90
+ elif "a10g" in device_name:
91
+ device_name = "a10g"
92
+ elif "v100-sxm" in device_name:
93
+ device_name = "v100-sxm"
94
+ elif "v100-pcie" in device_name:
95
+ device_name = "v100-pcie"
96
+ elif "t4" in device_name:
97
+ device_name = "t4"
98
+ elif "quadro rtx 5000" in device_name:
99
+ device_name = "quadro rtx 5000"
100
+ else:
101
+ device_name = None
102
+
103
+ if device_name is not None:
104
+ try:
105
+ return int(GPU_AVAILABLE_FLOPS[device_name][dtype])
106
+ except KeyError:
107
+ raise KeyError(
108
+ f"flop count not found for {device_name} with dtype: {dtype}; "
109
+ "MFU cannot be calculated and reported."
110
+ )
111
+ elif device.type == "xla":
112
+ if _XLA_GREATER_EQUAL_2_1:
113
+ from torch_xla._internal import tpu
114
+ else:
115
+ from torch_xla.experimental import tpu
116
+
117
+ device_name = tpu.get_tpu_env()["TYPE"].lower()
118
+ try:
119
+ return int(TPU_AVAILABLE_FLOPS[device_name])
120
+ except KeyError:
121
+ raise KeyError(
122
+ f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported."
123
+ )
124
+
125
+ return None
126
+
127
+
128
+ # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py
129
+
130
+
131
+ class SpeedMonitorBase:
132
+ """Logs the training throughput and utilization.
133
+
134
+ +-------------------------------------+-----------------------------------------------------------+
135
+ | Key | Logged data |
136
+ +=====================================+===========================================================+
137
+ | | Rolling average (over `window_size` most recent |
138
+ | `throughput/batches_per_sec` | batches) of the number of batches processed per second |
139
+ | | |
140
+ +-------------------------------------+-----------------------------------------------------------+
141
+ | | Rolling average (over `window_size` most recent |
142
+ | `throughput/samples_per_sec` | batches) of the number of samples processed per second |
143
+ | | |
144
+ +-------------------------------------+-----------------------------------------------------------+
145
+ | | Rolling average (over `window_size` most recent |
146
+ | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
147
+ | | This may include padding depending on dataset |
148
+ +-------------------------------------+-----------------------------------------------------------+
149
+ | | Estimates flops by `flops_per_batch * batches_per_sec` |
150
+ | `throughput/flops_per_sec` | |
151
+ | | |
152
+ +-------------------------------------+-----------------------------------------------------------+
153
+ | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
154
+ +-------------------------------------+-----------------------------------------------------------+
155
+ | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
156
+ +-------------------------------------+-----------------------------------------------------------+
157
+ | | `throughput/tokens_per_sec` divided by world size. This |
158
+ | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset |
159
+ | | |
160
+ +-------------------------------------+-----------------------------------------------------------+
161
+ | | `throughput/flops_per_sec` divided by world size. Only |
162
+ | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
163
+ | | |
164
+ +-------------------------------------+-----------------------------------------------------------+
165
+ | | `throughput/device/flops_per_sec` divided by world size. |
166
+ | `throughput/device/mfu` | |
167
+ | | |
168
+ +-------------------------------------+-----------------------------------------------------------+
169
+ | `time/train` | Total elapsed training time |
170
+ +-------------------------------------+-----------------------------------------------------------+
171
+ | `time/val` | Total elapsed validation time |
172
+ +-------------------------------------+-----------------------------------------------------------+
173
+ | `time/total` | Total elapsed time (time/train + time/val) |
174
+ +-------------------------------------+-----------------------------------------------------------+
175
+
176
+ Notes:
177
+ - The implementation assumes that devices are homogeneous as it normalizes by the world size.
178
+ - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
179
+ batches/sec to measure throughput under this circumstance.
180
+ - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
181
+ There is no widespread, realistic, and reliable implementation to compute them.
182
+ We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which
183
+ will almost always be an overestimate when compared to the true value.
184
+
185
+ Args:
186
+ window_size (int, optional): Number of batches to use for a rolling average of throughput.
187
+ Defaults to 100.
188
+ time_unit (str, optional): Time unit to use for `time` logging. Can be one of
189
+ 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ flops_available: float,
195
+ log_dict: Callable[[Dict, int], None],
196
+ window_size: int = 100,
197
+ time_unit: str = "hours",
198
+ ):
199
+ self.flops_available = flops_available
200
+ self.log_dict = log_dict
201
+
202
+ # Track the batch num samples and wct to compute throughput over a window of batches
203
+ self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
204
+ self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
205
+ self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
206
+ self.history_flops: Deque[int] = deque(maxlen=window_size + 1)
207
+
208
+ self.divider = 1
209
+ if time_unit == "seconds":
210
+ self.divider = 1
211
+ elif time_unit == "minutes":
212
+ self.divider = 60
213
+ elif time_unit == "hours":
214
+ self.divider = 60 * 60
215
+ elif time_unit == "days":
216
+ self.divider = 60 * 60 * 24
217
+ else:
218
+ raise ValueError(
219
+ f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
220
+ )
221
+
222
+ # Keep track of time spent evaluating
223
+ self.total_eval_wct = 0.0
224
+ self.step = -1
225
+
226
+ def on_train_batch_end(
227
+ self,
228
+ samples: int, # total samples seen (per device)
229
+ train_elapsed: float, # total training time (seconds)
230
+ world_size: int,
231
+ flops_per_batch: Optional[int] = None, # (per device)
232
+ lengths: Optional[int] = None, # total length of the samples seen (per device)
233
+ ) -> None:
234
+ self.step += 1
235
+ step = self.step
236
+ metrics = {}
237
+
238
+ self.history_samples.append(samples)
239
+ if lengths is not None:
240
+ self.history_lengths.append(lengths)
241
+ # if lengths are passed, there should be as many values as samples
242
+ assert len(self.history_samples) == len(self.history_lengths)
243
+ self.history_wct.append(train_elapsed)
244
+ if len(self.history_wct) == self.history_wct.maxlen:
245
+ elapsed_batches = len(self.history_samples) - 1
246
+ elapsed_samples = self.history_samples[-1] - self.history_samples[0]
247
+ elapsed_wct = self.history_wct[-1] - self.history_wct[0]
248
+ samples_per_sec = elapsed_samples * world_size / elapsed_wct
249
+ dev_samples_per_sec = elapsed_samples / elapsed_wct
250
+ metrics.update(
251
+ {
252
+ "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
253
+ "throughput/samples_per_sec": samples_per_sec,
254
+ "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
255
+ "throughput/device/samples_per_sec": dev_samples_per_sec,
256
+ }
257
+ )
258
+ if lengths is not None:
259
+ elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
260
+ avg_length = elapsed_lengths / elapsed_batches
261
+ metrics.update(
262
+ {
263
+ "throughput/tokens_per_sec": samples_per_sec * avg_length,
264
+ "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
265
+ }
266
+ )
267
+
268
+ if flops_per_batch is not None:
269
+ # sum of flops per batch across ranks
270
+ self.history_flops.append(flops_per_batch * world_size)
271
+ if len(self.history_flops) == self.history_flops.maxlen:
272
+ elapsed_flops = sum(self.history_flops) - self.history_flops[0]
273
+ elapsed_wct = self.history_wct[-1] - self.history_wct[0]
274
+ flops_per_sec = elapsed_flops / elapsed_wct
275
+ device_flops_per_sec = flops_per_sec / world_size
276
+ metrics.update(
277
+ {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
278
+ )
279
+ if self.flops_available:
280
+ metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available
281
+
282
+ metrics.update(
283
+ {
284
+ "time/train": train_elapsed / self.divider,
285
+ "time/val": self.total_eval_wct / self.divider,
286
+ "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
287
+ "samples": samples,
288
+ }
289
+ )
290
+
291
+ self.log_dict(metrics, step)
292
+
293
+ def eval_end(self, eval_elapsed: float) -> None:
294
+ self.total_eval_wct += eval_elapsed # seconds
295
+
296
+
297
+ def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
298
+ if isinstance(plugin, BitsandbytesPrecision):
299
+ return plugin.dtype
300
+ if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)):
301
+ return plugin._desired_input_dtype
302
+ if isinstance(plugin, MixedPrecisionPlugin):
303
+ return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
304
+ if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)):
305
+ return torch.double
306
+ if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)):
307
+ return plugin._desired_dtype
308
+ if isinstance(plugin, TransformerEnginePrecision):
309
+ return torch.int8
310
+ if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)):
311
+ return plugin.mixed_precision_config.reduce_dtype
312
+ if isinstance(plugin, Precision):
313
+ return torch.float32
314
+ raise NotImplementedError(plugin)
315
+
316
+
317
+ class SpeedMonitorFabric(SpeedMonitorBase):
318
+ def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
319
+ dtype = plugin_to_compute_dtype(fabric.strategy.precision)
320
+ flops_available = get_flops_available(fabric.device, dtype)
321
+ super().__init__(flops_available, fabric.log_dict, *args, **kwargs)
322
+
323
+ @fabric_rank_zero_only
324
+ def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
325
+ super().on_train_batch_end(*args, **kwargs)
326
+
327
+
328
+ class SpeedMonitorCallback(Callback):
329
+ def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
330
+ super().__init__()
331
+ self.speed_monitor: Optional[SpeedMonitorBase] = None
332
+ self.speed_monitor_kwargs = kwargs
333
+ self.length_fn = length_fn
334
+ self.batch_size = batch_size
335
+ self.eval_t0: int = 0
336
+ self.train_t0: int = 0
337
+ self.total_lengths: int = 0
338
+
339
+ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
340
+ if self.speed_monitor is not None:
341
+ return # already setup
342
+ dtype = plugin_to_compute_dtype(trainer.precision_plugin)
343
+ flops_available = get_flops_available(trainer.strategy.root_device, dtype)
344
+ self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)
345
+
346
+ @trainer_rank_zero_only
347
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
348
+ if trainer.fit_loop._should_accumulate():
349
+ return
350
+
351
+ self.train_t0 = time.perf_counter()
352
+
353
+ @trainer_rank_zero_only
354
+ def on_train_batch_end(
355
+ self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
356
+ ) -> None:
357
+ self.total_lengths += self.length_fn(batch)
358
+ if trainer.fit_loop._should_accumulate():
359
+ return
360
+ train_elapsed = time.perf_counter() - self.train_t0
361
+ assert self.speed_monitor is not None
362
+ iter_num = trainer.fit_loop.total_batch_idx
363
+ assert (measured_flops := pl_module.measured_flops) is not None
364
+ self.speed_monitor.on_train_batch_end(
365
+ (iter_num + 1) * self.batch_size,
366
+ train_elapsed,
367
+ # this assumes that device FLOPs are the same and that all devices have the same batch size
368
+ trainer.world_size,
369
+ flops_per_batch=measured_flops,
370
+ lengths=self.total_lengths,
371
+ )
372
+
373
+ @trainer_rank_zero_only
374
+ def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
375
+ self.eval_t0 = time.perf_counter()
376
+
377
+ @trainer_rank_zero_only
378
+ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
379
+ eval_elapsed = time.perf_counter() - self.eval_t0
380
+ assert self.speed_monitor is not None
381
+ self.speed_monitor.eval_end(eval_elapsed)
382
+
383
+
384
+ def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
385
+ flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
386
+ # this assumes that all samples have a fixed length equal to the block size
387
+ # which is most likely false during finetuning
388
+ flops_per_seq = flops_per_token * max_seq_length
389
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
390
+ return flops_per_seq + attn_flops_per_seq
391
+
392
+
393
+ def estimate_flops(model: GPT) -> int:
394
+ """Measures estimated FLOPs for MFU.
395
+
396
+ Refs:
397
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
398
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
399
+ """
400
+ # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
401
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
402
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
403
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
404
+ n_trainable_params = num_parameters(model, requires_grad=True)
405
+ trainable_flops = flops_per_param(
406
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
407
+ )
408
+ # forward + backward + gradients (assumes no gradient accumulation)
409
+ ops_per_step = 3 if model.training else 1
410
+ n_frozen_params = num_parameters(model, requires_grad=False)
411
+ frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
412
+ # forward + backward
413
+ frozen_ops_per_step = 2 if model.training else 1
414
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
415
+
416
+
417
+ def measure_flops(model: GPT, x: torch.Tensor) -> int:
418
+ """Measures real FLOPs for HFU"""
419
+ flop_counter = FlopCounterMode(model, display=False)
420
+ ctx = nullcontext() if model.training else torch.no_grad()
421
+ with ctx, flop_counter:
422
+ y = model(x)
423
+ if model.training:
424
+ y.sum().backward()
425
+ return flop_counter.get_total_flops()
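A self-contained sketch of SpeedMonitorBase's bookkeeping; the FLOP numbers and batch sizes are made up, and `print` stands in for a logger's `log_dict`:

    import time

    from tsai_gpt.speed_monitor import SpeedMonitorBase

    monitor = SpeedMonitorBase(
        flops_available=312e12,                        # e.g. A100 bf16 peak from GPU_AVAILABLE_FLOPS
        log_dict=lambda metrics, step: print(step, metrics),
        window_size=2,                                 # tiny window so metrics appear after a few batches
        time_unit="seconds",
    )
    t0 = time.perf_counter()
    for i in range(4):
        time.sleep(0.01)                               # stand-in for a training step
        monitor.on_train_batch_end(
            samples=(i + 1) * 8,                       # 8 samples per (per-device) batch
            train_elapsed=time.perf_counter() - t0,
            world_size=1,
            flops_per_batch=int(1e12),                 # pretend per-device FLOPs for one batch
            lengths=(i + 1) * 8 * 512,                 # 512 tokens per sample
        )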
tsai_gpt/tokenizer.py ADDED
@@ -0,0 +1,103 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ class Tokenizer:
9
+ def __init__(self, checkpoint_dir: Path) -> None:
10
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
11
+ self.bos_id = None
12
+ self.eos_id = None
13
+
14
+ # some checkpoints have both files, `.model` takes precedence
15
+ if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
16
+ from sentencepiece import SentencePieceProcessor
17
+
18
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
19
+ self.backend = "sentencepiece"
20
+ self.bos_id = self.processor.bos_id()
21
+ self.eos_id = self.processor.eos_id()
22
+
23
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
30
+ with open(special_tokens_path) as fp:
31
+ config = json.load(fp)
32
+ bos_token = config.get("bos_token")
33
+ self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
34
+ eos_token = config.get("eos_token")
35
+ self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
36
+ if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
37
+ with open(special_tokens_path) as fp:
38
+ config = json.load(fp)
39
+ if self.bos_id is None:
40
+ self.bos_id = config.get("bos_token_id")
41
+ if self.eos_id is None:
42
+ self.eos_id = config.get("eos_token_id")
43
+ else:
44
+ raise NotImplementedError
45
+
46
+ @property
47
+ def vocab_size(self) -> int:
48
+ if self.backend == "huggingface":
49
+ return self.processor.get_vocab_size(with_added_tokens=False)
50
+ if self.backend == "sentencepiece":
51
+ return self.processor.vocab_size()
52
+ raise RuntimeError
53
+
54
+ def token_to_id(self, token: str) -> int:
55
+ if self.backend == "huggingface":
56
+ id_ = self.processor.token_to_id(token)
57
+ elif self.backend == "sentencepiece":
58
+ id_ = self.processor.piece_to_id(token)
59
+ else:
60
+ raise RuntimeError
61
+ if id_ is None:
62
+ raise ValueError(f"token {token!r} not found in the collection.")
63
+ return id_
64
+
65
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
66
+ if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
67
+ return False
68
+ with open(tokenizer_config_path) as fp:
69
+ config = json.load(fp)
70
+ if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")):
71
+ return True
72
+ # for checkpoints that use the Llama tokenizer but do not define add_bos_token at all (the Llama tokenizer adds a BOS token by default).
73
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
74
+ return config.get("add_bos_token") is None and config.get("tokenizer_class") == "LlamaTokenizer"
75
+
76
+ def encode(
77
+ self,
78
+ string: str,
79
+ device: Optional[torch.device] = None,
80
+ bos: Optional[bool] = None,
81
+ eos: bool = False,
82
+ max_length: int = -1,
83
+ ) -> torch.Tensor:
84
+ if self.backend == "huggingface":
85
+ tokens = self.processor.encode(string).ids
86
+ elif self.backend == "sentencepiece":
87
+ tokens = self.processor.encode(string)
88
+ else:
89
+ raise RuntimeError
90
+ if bos or (bos is None and self.use_bos):
91
+ bos_id = self.bos_id
92
+ if bos_id is None:
93
+ raise NotImplementedError("This tokenizer does not have a defined a bos token")
94
+ tokens = [bos_id] + tokens
95
+ if eos:
96
+ tokens = tokens + [self.eos_id]
97
+ if max_length > 0:
98
+ tokens = tokens[:max_length]
99
+ return torch.tensor(tokens, dtype=torch.int, device=device)
100
+
101
+ def decode(self, tensor: torch.Tensor) -> str:
102
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
103
+ return self.processor.decode(tokens)
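A usage sketch for the wrapper above; `checkpoints/some-model` is a hypothetical directory that would hold a downloaded tokenizer.json or tokenizer.model plus its configs:

    from pathlib import Path

    from tsai_gpt.tokenizer import Tokenizer

    tokenizer = Tokenizer(Path("checkpoints/some-model"))  # hypothetical checkpoint directory
    ids = tokenizer.encode("Hello, world!")                # 1-D tensor of token ids (BOS added if the config asks for it)
    print(tokenizer.vocab_size, ids)
    print(tokenizer.decode(ids))                           # round-trips back to text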
tsai_gpt/utils.py ADDED
@@ -0,0 +1,399 @@
1
+ """Utility functions for training and inference."""
2
+ import math
3
+ import pickle
4
+ import sys
5
+ from contextlib import nullcontext
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ ContextManager,
11
+ Dict,
12
+ List,
13
+ Mapping,
14
+ Optional,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+
19
+ import lightning as L
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.utils._device
23
+ from lightning.fabric.strategies import FSDPStrategy
24
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
25
+ from torch.serialization import normalize_storage_type
26
+
27
+ if TYPE_CHECKING:
28
+ from model import GPT
29
+
30
+
31
+ def find_multiple(n: int, k: int) -> int:
32
+ assert k > 0
33
+ if n % k == 0:
34
+ return n
35
+ return n + k - (n % k)
36
+
37
+
38
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
39
+ total = 0
40
+ for p in module.parameters():
41
+ if requires_grad is None or p.requires_grad == requires_grad:
42
+ if hasattr(p, "quant_state"):
43
+ # bitsandbytes 4bit layer support
44
+ total += math.prod(p.quant_state[1])
45
+ else:
46
+ total += p.numel()
47
+ return total
48
+
49
+
50
+ def gptq_quantization(enabled: bool = False) -> ContextManager:
51
+ if not enabled:
52
+ return nullcontext()
53
+
54
+ from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager
55
+
56
+ from quantize.gptq import ColBlockQuantizedLinear
57
+
58
+ class QuantizedLinear(ColBlockQuantizedLinear):
59
+ def __init__(self, *args, **kwargs):
60
+ super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
61
+
62
+ return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
63
+
64
+
65
+ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
66
+ files = {
67
+ "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
68
+ "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
69
+ "tokenizer.json OR tokenizer.model": (
70
+ checkpoint_dir / "tokenizer.json"
71
+ ).is_file()
72
+ or (checkpoint_dir / "tokenizer.model").is_file(),
73
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
74
+ }
75
+ if checkpoint_dir.is_dir():
76
+ if all(files.values()):
77
+ # we're good
78
+ return
79
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
80
+ else:
81
+ problem = " is not a checkpoint directory"
82
+
83
+ # list locally available checkpoints
84
+ available = list(Path("checkpoints").glob("*/*"))
85
+ if available:
86
+ options = "\n --checkpoint_dir ".join(
87
+ [""] + [repr(str(p.resolve())) for p in available]
88
+ )
89
+ extra = f"\nYou have downloaded locally:{options}\n"
90
+ else:
91
+ extra = ""
92
+
93
+ error_message = (
94
+ f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
95
+ "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
96
+ f"{extra}\nSee all download options by running:\n python scripts/download.py"
97
+ )
98
+ print(error_message, file=sys.stderr)
99
+ raise SystemExit(1)
100
+
101
+
102
+ class SavingProxyForStorage:
103
+ def __init__(self, obj, saver, protocol_version=5):
104
+ self.protocol_version = protocol_version
105
+ self.saver = saver
106
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
107
+ raise TypeError(f"expected storage, not {type(obj)}")
108
+
109
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
110
+ if isinstance(obj, torch.storage.TypedStorage):
111
+ # PT upstream wants to deprecate this eventually...
112
+ storage = obj._untyped_storage
113
+ storage_type_str = obj._pickle_storage_type()
114
+ storage_type = getattr(torch, storage_type_str)
115
+ storage_numel = obj._size()
116
+ else:
117
+ storage = obj
118
+ storage_type = normalize_storage_type(type(obj))
119
+ storage_numel = storage.nbytes()
120
+
121
+ storage_key = saver._write_storage_and_return_key(storage)
122
+ location = torch.serialization.location_tag(storage)
123
+
124
+ self.storage_info = (
125
+ "storage",
126
+ storage_type,
127
+ storage_key,
128
+ location,
129
+ storage_numel,
130
+ )
131
+
132
+ def __reduce_ex__(self, protocol_version):
133
+ assert False, "this should be handled out of band"
134
+
135
+
136
+ class SavingProxyForTensor:
137
+ def __init__(self, tensor, saver, protocol_version=5):
138
+ self.protocol_version = protocol_version
139
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
140
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
141
+ # for Tensors with Python attributes
142
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
143
+ assert isinstance(
144
+ storage, torch.storage.TypedStorage
145
+ ), "Please check for updates"
146
+ storage_proxy = SavingProxyForStorage(
147
+ storage, saver, protocol_version=protocol_version
148
+ )
149
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
150
+ else:
151
+ (storage, *other_reduce_args) = reduce_args
152
+ assert isinstance(
153
+ storage, torch.storage.TypedStorage
154
+ ), "Please check for updates"
155
+ storage_proxy = SavingProxyForStorage(
156
+ storage, saver, protocol_version=protocol_version
157
+ )
158
+ self.reduce_args = (storage_proxy, *other_reduce_args)
159
+
160
+ def __reduce_ex__(self, protocol_version):
161
+ if protocol_version != self.protocol_version:
162
+ raise RuntimeError(
163
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
164
+ )
165
+ return self.reduce_ret_fn, self.reduce_args
166
+
167
+
168
+ class IncrementalPyTorchPickler(pickle.Pickler):
169
+ def __init__(self, saver, *args, **kwargs):
170
+ super().__init__(*args, **kwargs)
171
+ self.storage_dtypes = {}
172
+ self.saver = saver
173
+ self.id_map = {}
174
+
175
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
176
+ def persistent_id(self, obj):
177
+ # FIXME: the docs say that persistent_id should only return a string
178
+ # but torch store returns tuples. This works only in the binary protocol
179
+ # see
180
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
181
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
182
+ if isinstance(obj, SavingProxyForStorage):
183
+ return obj.storage_info
184
+
185
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
186
+ if isinstance(obj, torch.storage.TypedStorage):
187
+ # TODO: Once we decide to break serialization FC, this case
188
+ # can be deleted
189
+ storage = obj._untyped_storage
190
+ storage_dtype = obj.dtype
191
+ storage_type_str = obj._pickle_storage_type()
192
+ storage_type = getattr(torch, storage_type_str)
193
+ storage_numel = obj._size()
194
+
195
+ else:
196
+ storage = obj
197
+ storage_dtype = torch.uint8
198
+ storage_type = normalize_storage_type(type(obj))
199
+ storage_numel = storage.nbytes()
200
+
201
+ # If storage is allocated, ensure that any other saved storages
202
+ # pointing to the same data all have the same dtype. If storage is
203
+ # not allocated, don't perform this check
204
+ if storage.data_ptr() != 0:
205
+ if storage.data_ptr() in self.storage_dtypes:
206
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
207
+ raise RuntimeError(
208
+ "Cannot save multiple tensors or storages that view the same data as different types"
209
+ )
210
+ else:
211
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
212
+
213
+ storage_key = self.id_map.get(storage._cdata)
214
+ if storage_key is None:
215
+ storage_key = self.saver._write_storage_and_return_key(storage)
216
+ self.id_map[storage._cdata] = storage_key
217
+ location = torch.serialization.location_tag(storage)
218
+
219
+ return ("storage", storage_type, storage_key, location, storage_numel)
220
+
221
+ return None
222
+
223
+
224
+ class incremental_save:
225
+ def __init__(self, name):
226
+ self.name = name
227
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
228
+ self.has_saved = False
229
+ self.next_key = 0
230
+
231
+ def __enter__(self):
232
+ return self
233
+
234
+ def store_early(self, tensor):
235
+ if isinstance(tensor, torch.Tensor):
236
+ return SavingProxyForTensor(tensor, self)
237
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
238
+
239
+ def save(self, obj):
240
+ if self.has_saved:
241
+ raise RuntimeError("have already saved")
242
+ # Write the pickle data for `obj`
243
+ data_buf = BytesIO()
244
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
245
+ pickler.dump(obj)
246
+ data_value = data_buf.getvalue()
247
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
248
+ self.has_saved = True
249
+
250
+ def _write_storage_and_return_key(self, storage):
251
+ if self.has_saved:
252
+ raise RuntimeError("have already saved")
253
+ key = self.next_key
254
+ self.next_key += 1
255
+ name = f"data/{key}"
256
+ if storage.device.type != "cpu":
257
+ storage = storage.cpu()
258
+ num_bytes = storage.nbytes()
259
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
260
+ return key
261
+
262
+ def __exit__(self, type, value, traceback):
263
+ self.zipfile.write_end_of_file()
264
+
265
+
266
+ T = TypeVar("T")
267
+
268
+
269
+ def chunked_cross_entropy(
270
+ logits: Union[torch.Tensor, List[torch.Tensor]],
271
+ targets: torch.Tensor,
272
+ chunk_size: int = 128,
273
+ ) -> torch.Tensor:
274
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
275
+ # the memory usage in fine-tuning settings with low number of parameters.
276
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
277
+ # the memory spike's magnitude
278
+
279
+ # lm_head was chunked (we are fine-tuning)
280
+ if isinstance(logits, list):
281
+ # don't want to chunk cross entropy
282
+ if chunk_size == 0:
283
+ logits = torch.cat(logits, dim=1)
284
+ logits = logits.reshape(-1, logits.size(-1))
285
+ targets = targets.reshape(-1)
286
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
287
+
288
+ # chunk cross entropy
289
+ logit_chunks = [
290
+ logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
291
+ ]
292
+ target_chunks = [
293
+ target_chunk.reshape(-1)
294
+ for target_chunk in targets.split(logits[0].size(1), dim=1)
295
+ ]
296
+ loss_chunks = [
297
+ torch.nn.functional.cross_entropy(
298
+ logit_chunk, target_chunk, ignore_index=-1, reduction="none"
299
+ )
300
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
301
+ ]
302
+ return torch.cat(loss_chunks).mean()
303
+
304
+ # no chunking at all
305
+ logits = logits.reshape(-1, logits.size(-1))
306
+ targets = targets.reshape(-1)
307
+ if chunk_size == 0:
308
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
309
+
310
+ # lm_head wasn't chunked, chunk cross entropy
311
+ logit_chunks = logits.split(chunk_size)
312
+ target_chunks = targets.split(chunk_size)
313
+ loss_chunks = [
314
+ torch.nn.functional.cross_entropy(
315
+ logit_chunk, target_chunk, ignore_index=-1, reduction="none"
316
+ )
317
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
318
+ ]
319
+ return torch.cat(loss_chunks).mean()
320
+
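A quick sketch of calling chunked_cross_entropy with a plain logits tensor and with a pre-chunked list (for example when the lm_head output was produced in chunks during fine-tuning); the shapes below are arbitrary. One subtlety worth noting: in the chunked paths the final .mean() averages over all positions, with ignored targets (ignore_index=-1) contributing a loss of 0, whereas the unchunked chunk_size=0 path lets cross_entropy average over non-ignored positions only, so the two can differ slightly when -1 targets are present.

import torch

batch, seq_len, vocab = 2, 8, 100
logits = torch.randn(batch, seq_len, vocab)
targets = torch.randint(0, vocab, (batch, seq_len))

# single logits tensor: flattened and split into chunks of `chunk_size` rows
loss = chunked_cross_entropy(logits, targets, chunk_size=4)

# list of logits chunks: targets are split along the sequence dimension to match
logit_chunks = list(logits.split(2, dim=1))   # four chunks of shape (2, 2, vocab)
loss_from_chunks = chunked_cross_entropy(logit_chunks, targets, chunk_size=4)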
321
+
322
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
323
+ for checkpoint_name, attribute_name in mapping.items():
324
+ full_checkpoint_name = prefix + checkpoint_name
325
+ if full_checkpoint_name in state_dict:
326
+ full_attribute_name = prefix + attribute_name
327
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
328
+ return state_dict
329
+
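A small sketch of the renaming that map_old_state_dict_weights performs; the mapping, prefix and keys below are illustrative and not the actual names used by this repository's checkpoints:

old_sd = {
    "transformer.h.0.attn.weight": 1,   # values are irrelevant for the renaming
    "transformer.h.0.mlp.weight": 2,
}
# old checkpoint name -> new attribute name, both resolved under the given prefix
mapping = {"attn.weight": "attention.weight"}
new_sd = map_old_state_dict_weights(old_sd, mapping, prefix="transformer.h.0.")
# -> {"transformer.h.0.mlp.weight": 2, "transformer.h.0.attention.weight": 1}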
330
+
331
+ def get_default_supported_precision(training: bool) -> str:
332
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
333
+
334
+ Args:
335
+ training: if True, return the `-mixed` variant of the precision (for training); otherwise the `-true` variant (for inference)
336
+
337
+ Returns:
338
+ default precision that is suitable for the task and is supported by the hardware
339
+ """
340
+ from lightning.fabric.accelerators import MPSAccelerator
341
+
342
+ if MPSAccelerator.is_available() or (
343
+ torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
344
+ ):
345
+ return "16-mixed" if training else "16-true"
346
+ return "bf16-mixed" if training else "bf16-true"
347
+
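The returned string can be passed directly to Fabric: on bf16-capable GPUs it resolves to "bf16-mixed"/"bf16-true", while on MPS or older CUDA devices it falls back to fp16. A minimal sketch, with a single-device Fabric assumed only for illustration:

import lightning as L

precision = get_default_supported_precision(training=True)
fabric = L.Fabric(devices=1, precision=precision)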
348
+
349
+ def load_checkpoint(
350
+ fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
351
+ ) -> None:
352
+ if isinstance(fabric.strategy, FSDPStrategy):
353
+ fabric.load_raw(checkpoint_path, model, strict=strict)
354
+ else:
355
+ state_dict = lazy_load(checkpoint_path)
356
+ state_dict = state_dict.get("model", state_dict)
357
+ model.load_state_dict(state_dict, strict=strict)
358
+
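A rough sketch of driving load_checkpoint; TinyModel below is a stand-in for whatever nn.Module the checkpoint was saved from (in this repository that would be the GPT model defined elsewhere), and the checkpoint path is only an example:

from pathlib import Path

import lightning as L
import torch.nn as nn

class TinyModel(nn.Module):          # stand-in for the real model class
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

fabric = L.Fabric(devices=1)
model = fabric.setup_module(TinyModel())
# path is illustrative; strict=False would tolerate missing/unexpected keys
load_checkpoint(fabric, model, Path("checkpoints/example/lit_model.pth"), strict=True)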
359
+
360
+ def flops_per_param(
361
+ max_seq_length: int, n_layer: int, n_embd: int, n_params: int
362
+ ) -> int:
363
+ flops_per_token = (
364
+ 2 * n_params
365
+ ) # each parameter is used for a MAC (multiply-accumulate, i.e. 2 FLOPs) per network operation
366
+ # this assumes that all samples have a fixed length equal to the block size
367
+ # which is most likely false during finetuning
368
+ flops_per_seq = flops_per_token * max_seq_length
369
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
370
+ return flops_per_seq + attn_flops_per_seq
371
+
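A tiny worked example of the formula above, with arbitrary numbers rather than a real configuration:

# parameter term: 2 * n_params FLOPs per token, times the sequence length
# attention term: n_layer * 2 * 2 * (n_embd * max_seq_length**2)
flops = flops_per_param(max_seq_length=1024, n_layer=8, n_embd=512, n_params=1_000_000)
# = 2 * 1_000_000 * 1024 + 8 * 2 * 2 * (512 * 1024**2)
# = 2_048_000_000 + 17_179_869_184 = 19_227_869_184 FLOPs per sequence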
372
+
373
+ def estimate_flops(model: "GPT", training: bool) -> int:
374
+ """Measures estimated FLOPs for MFU.
375
+
376
+ Refs:
377
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
378
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
379
+ """
380
+ # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
381
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a roughly fixed
382
+ # percentage (~10%) than the measured FLOPs, which are lower but more realistic.
383
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
384
+ n_trainable_params = num_parameters(model, requires_grad=True)
385
+ trainable_flops = flops_per_param(
386
+ model.max_seq_length,
387
+ model.config.n_layer,
388
+ model.config.n_embd,
389
+ n_trainable_params,
390
+ )
391
+ # forward + backward + gradients (assumes no gradient accumulation)
392
+ ops_per_step = 3 if training else 1
393
+ n_frozen_params = num_parameters(model, requires_grad=False)
394
+ frozen_flops = flops_per_param(
395
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
396
+ )
397
+ # forward + backward
398
+ frozen_ops_per_step = 2 if training else 1
399
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
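Downstream, this estimate is typically combined with a measured throughput to report MFU. A rough sketch under stated assumptions: model is an instantiated GPT, samples_per_second is whatever the training loop measures, and 312e12 is the approximate dense bf16 peak of an A100, used here purely as an example:

estimated_flops = estimate_flops(model, training=True)  # FLOPs per sample per step

samples_per_second = 4.0        # illustrative measured throughput
device_peak_flops = 312e12      # assumed accelerator peak (example: A100 bf16)

mfu = estimated_flops * samples_per_second / device_peak_flops
print(f"estimated MFU: {mfu:.1%}")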