thomwolf HF staff commited on
Commit
54ba632
·
1 Parent(s): 645b194
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. README.md +37 -0
  3. config_minicpm.py +119 -0
  4. config_tiny_mistral.yaml +92 -0
  5. convert_trfrs_to_brrr.py +262 -0
  6. dataloader.py +107 -0
  7. modeling_minicpm.py +1147 -0
  8. pretrained/MiniCPM-2B-dpo-bf16/checkpoint_metadata.json +9 -0
  9. pretrained/MiniCPM-2B-dpo-bf16/config.yaml +55 -0
  10. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  11. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  12. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/input_layernorm/model_weight.safetensors +3 -0
  13. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  14. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  15. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/post_attention_layernorm/model_weight.safetensors +3 -0
  16. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  17. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  18. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/input_layernorm/model_weight.safetensors +3 -0
  19. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  20. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  21. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/post_attention_layernorm/model_weight.safetensors +3 -0
  22. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  23. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  24. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/input_layernorm/model_weight.safetensors +3 -0
  25. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  26. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  27. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/post_attention_layernorm/model_weight.safetensors +3 -0
  28. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  29. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  30. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/input_layernorm/model_weight.safetensors +3 -0
  31. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  32. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  33. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/post_attention_layernorm/model_weight.safetensors +3 -0
  34. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  35. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  36. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/input_layernorm/model_weight.safetensors +3 -0
  37. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  38. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  39. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/post_attention_layernorm/model_weight.safetensors +3 -0
  40. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  41. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  42. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/input_layernorm/model_weight.safetensors +3 -0
  43. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  44. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  45. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/post_attention_layernorm/model_weight.safetensors +3 -0
  46. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  47. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  48. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/input_layernorm/model_weight.safetensors +3 -0
  49. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
  50. pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: nanotron
3
+ ---
4
+
5
+ # ⚙️ Nano-Mistral
6
+
7
+ Modeling code for Mistral to use with [Nanotron](https://github.com/huggingface/nanotron/)
8
+
9
+ Also contains converted pretrained weights for Mistral-7B-0.1: https://huggingface.co/mistralai/Mistral-7B-v0.1
10
+
11
+ ## 🚀 Quickstart
12
+
13
+ ```bash
14
+ # Generate a config file
15
+ python config_tiny_mistral.py
16
+
17
+ # Run training
18
+ export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
19
+ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
20
+ ```
21
+
22
+ ## 🚀 Run generation with pretrained Mistral-7B-0.1
23
+
24
+ ```bash
25
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
26
+ torchrun --nproc_per_node=1 run_generate.py --ckpt-path ./pretrained/Mistral-7B-v0.1
27
+ ```
28
+
29
+ ## 🚀 Use your custom model
30
+
31
+ - Update the `MistralConfig` class in `config_tiny_mistral.py` to match your model's configuration
32
+ - Update the `MistralForTraining` class in `modeling_mistral.py` to match your model's architecture
33
+ - Pass the previous to the `DistributedTrainer` class in `run_train.py`:
34
+ ```python
35
+ trainer = DistributedTrainer(config_file, model_class=MistralForTraining, model_config_class=MistralConfig)
36
+ ```
37
+ - Run training as usual
config_minicpm.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.
2
+
3
+ Usage:
4
+ ```
5
+ python config_tiny_mistral.py
6
+ ```
7
+ """
8
+ import os
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ from nanotron.config import (
13
+ CheckpointsArgs,
14
+ Config,
15
+ DataArgs,
16
+ GeneralArgs,
17
+ LoggingArgs,
18
+ LRSchedulerArgs,
19
+ ModelArgs,
20
+ OptimizerArgs,
21
+ ParallelismArgs,
22
+ PretrainDatasetsArgs,
23
+ RandomInit,
24
+ TokenizerArgs,
25
+ TokensArgs,
26
+ )
27
+ from nanotron.logging import human_format
28
+
29
+
30
+ @dataclass
31
+ class MiniCPMConfig:
32
+ """Configuration for a MiniCPM model.
33
+
34
+ Be careful on having a coherent typing as we use it to reconstruct the model from yaml
35
+ """
36
+
37
+ attn_pdrop: float = 0.0
38
+ bos_token_id: int =1
39
+ eos_token_id: int =2
40
+ pad_token_id: Optional[int] = None
41
+ hidden_act: str ="silu"
42
+ hidden_size: int =2304
43
+ initializer_range: float =0.1
44
+ intermediate_size: int =5760
45
+ max_position_embeddings: int =2048
46
+ num_attention_heads: int =36
47
+ num_hidden_layers: int =40
48
+ num_key_value_heads: int =36
49
+ pretraining_tp: int=1
50
+ rms_norm_eps: float=1e-05
51
+ rope_theta: float = 10000.0
52
+ tie_word_embeddings: bool =True
53
+ use_cache: bool =True
54
+ vocab_size: int = 122753
55
+ scale_emb: float = 12
56
+ dim_model_base: int= 256
57
+ scale_depth: float = 1.4
58
+
59
+ def __post_init__(self):
60
+ # for backward compatibility
61
+ if self.num_key_value_heads is None:
62
+ self.num_key_value_heads = self.num_attention_heads
63
+
64
+ def get_num_params(model_config: MiniCPMConfig) -> int:
65
+ num_params = model_config.vocab_size * model_config.hidden_size * 2 + \
66
+ model_config.num_hidden_layers * (
67
+ 3 * model_config.hidden_size * model_config.intermediate_size
68
+ + 2 * model_config.hidden_size * model_config.hidden_size
69
+ + 2 * model_config.hidden_size * (model_config.hidden_size / (model_config.num_attention_heads / model_config.num_key_value_heads))
70
+ )
71
+ return num_params
72
+
73
+ def get_num_params_no_embed(model_config: MiniCPMConfig) -> int:
74
+ num_params = model_config.num_hidden_layers * (
75
+ 3 * model_config.hidden_size * model_config.intermediate_size
76
+ + 2 * model_config.hidden_size * model_config.hidden_size
77
+ + 2 * model_config.hidden_size * (model_config.hidden_size / (model_config.num_attention_heads / model_config.num_key_value_heads))
78
+ )
79
+ return num_params
80
+
81
+ MODEL_CONFIG = MiniCPMConfig()
82
+
83
+ num_params = human_format(get_num_params(MODEL_CONFIG)).replace(".", "p")
84
+ num_params_no_embed = human_format(get_num_params_no_embed(MODEL_CONFIG)).replace(".", "p")
85
+
86
+ print(f"Model has {num_params} parameters or {num_params_no_embed} without embeddings")
87
+
88
+ PARALLELISM = ParallelismArgs(
89
+ dp=1,
90
+ pp=1,
91
+ tp=1,
92
+ pp_engine="1f1b",
93
+ tp_mode="REDUCE_SCATTER",
94
+ tp_linear_async_communication=True,
95
+ recompute_granularity="selective",
96
+ )
97
+
98
+ CONFIG = Config(
99
+ general=GeneralArgs(project="openbmb", run="MiniCPM-2B-dpo-bf16", seed=42, step=0),
100
+ checkpoints=None,
101
+ parallelism=PARALLELISM,
102
+ model=ModelArgs(init_method=RandomInit(std=0.025), model_config=MODEL_CONFIG),
103
+ tokenizer=TokenizerArgs("openbmb/MiniCPM-2B-dpo-bf16"),
104
+ optimizer=None,
105
+ logging=None,
106
+ tokens=None,
107
+ data=None,
108
+ profiler=None,
109
+ lighteval=None,
110
+ )
111
+
112
+ if __name__ == "__main__":
113
+ file_path = os.path.abspath(__file__)
114
+
115
+ file_path = file_path.replace(".py", ".yaml")
116
+ # Save config as YAML file
117
+ CONFIG.save_as_yaml(file_path)
118
+
119
+ # You can now train a model with this config using `/run_train.py`
config_tiny_mistral.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints:
2
+ checkpoint_interval: 10
3
+ checkpoints_path: /fsx/thomwolf/github/textbooks-proj/brrr/models/checkpoints
4
+ checkpoints_path_is_shared_file_system: false
5
+ resume_checkpoint_path: null
6
+ save_initial_state: false
7
+ data:
8
+ dataset:
9
+ dataset_overwrite_cache: false
10
+ dataset_processing_num_proc_per_process: 1
11
+ hf_dataset_config_name: null
12
+ hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
13
+ hf_dataset_splits: train
14
+ text_column_name: completion
15
+ num_loading_workers: 1
16
+ seed: 42
17
+ general:
18
+ benchmark_csv_path: null
19
+ consumed_train_samples: null
20
+ ignore_sanity_checks: false
21
+ project: debug
22
+ run: tiny_mistral
23
+ seed: 42
24
+ step: null
25
+ logging:
26
+ iteration_step_info_interval: 1
27
+ log_level: info
28
+ log_level_replica: info
29
+ model:
30
+ ddp_bucket_cap_mb: 25
31
+ dtype: bfloat16
32
+ init_method:
33
+ std: 0.025
34
+ make_vocab_size_divisible_by: 1
35
+ model_config:
36
+ attn_pdrop: 0.0
37
+ bos_token_id: 1
38
+ eos_token_id: 2
39
+ hidden_act: silu
40
+ hidden_size: 16
41
+ initializer_range: 0.02
42
+ intermediate_size: 64
43
+ is_mistral_config: true
44
+ max_position_embeddings: 256
45
+ num_attention_heads: 4
46
+ num_hidden_layers: 2
47
+ num_key_value_heads: 4
48
+ pad_token_id: null
49
+ pretraining_tp: 1
50
+ rms_norm_eps: 1.0e-05
51
+ rope_theta: 10000.0
52
+ sliding_window_size: 4096
53
+ tie_word_embeddings: true
54
+ use_cache: true
55
+ vocab_size: 256
56
+ optimizer:
57
+ accumulate_grad_in_fp32: true
58
+ adam_beta1: 0.9
59
+ adam_beta2: 0.95
60
+ adam_eps: 1.0e-08
61
+ clip_grad: 1.0
62
+ learning_rate_scheduler:
63
+ learning_rate: 0.0003
64
+ lr_decay_steps: 8
65
+ lr_decay_style: cosine
66
+ lr_warmup_steps: 2
67
+ lr_warmup_style: linear
68
+ min_decay_lr: 1.0e-05
69
+ torch_adam_is_fused: true
70
+ weight_decay: 0.01
71
+ zero_stage: 0
72
+ parallelism:
73
+ dp: 2
74
+ pp: 2
75
+ pp_engine: 1f1b
76
+ recompute_granularity: SELECTIVE
77
+ tp: 2
78
+ tp_linear_async_communication: true
79
+ tp_mode: REDUCE_SCATTER
80
+ profiler: null
81
+ tokenizer:
82
+ tokenizer_max_length: null
83
+ tokenizer_name_or_path: gpt2
84
+ tokenizer_revision: null
85
+ tokens:
86
+ batch_accumulation_per_replica: 1
87
+ limit_test_batches: 0
88
+ limit_val_batches: 0
89
+ micro_batch_size: 2
90
+ sequence_length: 32
91
+ train_steps: 10
92
+ val_check_interval: -1
convert_trfrs_to_brrr.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: E402
2
+ """
3
+ This module converts a transformers LlamaForCausalLM to a brrr model
4
+
5
+ Command:
6
+ torchrun --nproc_per_node=1 convert_trfrs_to_brrr.py \
7
+ --model_name openbmb/MiniCPM-2B-dpo-bf16 \
8
+ --save_path ./pretrained/MiniCPM-2B-dpo-bf16
9
+ """
10
+ import argparse
11
+ import sys
12
+ from dataclasses import asdict
13
+ from pathlib import Path
14
+ from typing import Dict, List
15
+
16
+ import torch
17
+
18
+ from brrr.trainer import DistributedTrainer
19
+
20
+ sys.path.append(Path(__file__).parent.parent.as_posix())
21
+ import os
22
+
23
+ from nanotron.parallel.parameters import NanotronParameter, sanity_check
24
+ from nanotron.parallel.pipeline_parallel.engine import (
25
+ AllForwardAllBackwardPipelineEngine,
26
+ )
27
+ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
28
+ from transformers import MistralConfig as MistralConfig_trfs, MistralForCausalLM
29
+
30
+ import nanotron.distributed as dist
31
+ from nanotron.config import ParallelismArgs, RecomputeGranularity
32
+ from nanotron.parallel.context import ParallelContext
33
+ from nanotron.models import build_model
34
+ from nanotron.trainer import mark_tied_parameters
35
+ from nanotron.serialize import save_meta, save_weights, save
36
+
37
+ from modeling_minicpm import MiniCPMForTraining
38
+ from config_minicpm import PARALLELISM as PARALLELISM_BRRR, CONFIG as CONFIG_BRRR
39
+
40
+
41
+ def get_args():
42
+ parser = argparse.ArgumentParser(description="Convert transformers weights to brrr weights")
43
+ parser.add_argument("--model_name", type=str, default="openbmb/MiniCPM-2B-dpo-bf16")
44
+ parser.add_argument("--save_path", type=str, default="pretrained/MiniCPM-2B-dpo-bf16")
45
+ parser.add_argument("--dp", type=int, default=1)
46
+ parser.add_argument("--pp", type=int, default=1)
47
+ parser.add_argument("--tp", type=int, default=1)
48
+ return parser.parse_args()
49
+
50
+
51
+ def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
52
+ return (
53
+ tensor.view(num_heads, 2, per_head_hidden_size // 2, hidden_size)
54
+ .transpose(1, 2)
55
+ .contiguous()
56
+ .view(num_heads * per_head_hidden_size, hidden_size)
57
+ )
58
+
59
+
60
+ def get_transformers_weight(
61
+ name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MistralForCausalLM, get_grad: bool = False
62
+ ) -> torch.Tensor:
63
+ """From our brrr implementation, we get the equivalent tensor in transformers implementation"""
64
+ config = ref_module.config
65
+ brrr_prefix = "model."
66
+ assert name.startswith(brrr_prefix)
67
+ name = name[len(brrr_prefix) :]
68
+
69
+ path = name.split(".")
70
+ path.remove("pp_block")
71
+ name = ".".join(path)
72
+
73
+ if get_grad is False:
74
+
75
+ def get_tensor(path: str):
76
+ return ref_module_state_dict[path]
77
+
78
+ def get_tensors(path: List[str]):
79
+ return [get_tensor(p) for p in path]
80
+
81
+ else:
82
+
83
+ def get_tensor(path: str):
84
+ weight = ref_module.get_parameter(path)
85
+ return weight.grad
86
+
87
+ def get_tensors(path: List[str]):
88
+ return [get_tensor(p) for p in path]
89
+
90
+ if name == "token_position_embeddings.token_embedding.weight":
91
+ return get_tensor("model.embed_tokens.weight")
92
+
93
+ elif name == "lm_head.weight":
94
+ # This only used when weights are not shared
95
+ return get_tensor("lm_head.weight")
96
+
97
+ elif name == "final_layer_norm.weight":
98
+ return get_tensor("model.norm.weight")
99
+
100
+ if path[0] == "decoder":
101
+ transformer_path = ["model"] + ["layers"] + [path[1]]
102
+
103
+ if path[2] == "attn":
104
+ path[2] = "self_attn"
105
+
106
+ if path[2] == "ff":
107
+ path[2] = "mlp"
108
+
109
+ if path[3] == "qkv_proj":
110
+ proj_names = ["q_proj", "k_proj", "v_proj"]
111
+ tensor_list = get_tensors(
112
+ [".".join(transformer_path + path[2:3] + [proj_name] + path[4:]) for proj_name in proj_names]
113
+ )
114
+ # Permute q/k
115
+ per_head_hidden_size = config.hidden_size // config.num_attention_heads
116
+ # Permute q
117
+ print(f"Permuting q {tensor_list[0].shape}")
118
+ tensor_list[0] = permute_for_rotary(
119
+ tensor=tensor_list[0],
120
+ num_heads=config.num_attention_heads,
121
+ per_head_hidden_size=per_head_hidden_size,
122
+ hidden_size=config.hidden_size,
123
+ )
124
+ # Permute k
125
+ print(f"Permuting k {tensor_list[1].shape}")
126
+ tensor_list[1] = permute_for_rotary(
127
+ tensor=tensor_list[1],
128
+ num_heads=config.num_key_value_heads,
129
+ per_head_hidden_size=per_head_hidden_size,
130
+ hidden_size=config.hidden_size,
131
+ )
132
+ return torch.cat(tensor_list, dim=0)
133
+
134
+ if path[3] == "gate_up_proj":
135
+ tensor_list = get_tensors(
136
+ [
137
+ ".".join(transformer_path + path[2:3] + [proj_name] + path[4:])
138
+ for proj_name in ["gate_proj", "up_proj"]
139
+ ]
140
+ )
141
+ return torch.cat(tensor_list, dim=0)
142
+
143
+ return get_tensor(".".join(transformer_path + path[2:]))
144
+
145
+ else:
146
+ raise ValueError(f"Couldn't find transformer equivalent of {name}")
147
+
148
+
149
+ def convert_trfrs_to_brrr(dp, pp, tp, model_name="huggyllama/llama-7b", save_path="pretrained/llama-7b"):
150
+ # check save_path doesnt exist or is empty
151
+ save_path = Path(save_path)
152
+ # assert not save_path.exists() or len(list(save_path.iterdir())) == 0, f"save_path {save_path} is not empty"
153
+
154
+ parallel_config = PARALLELISM_BRRR
155
+
156
+ parallel_config.dp = dp
157
+ parallel_config.pp = pp
158
+ parallel_config.tp = tp
159
+
160
+ # Initialise all process groups
161
+ parallel_context = ParallelContext(
162
+ data_parallel_size=parallel_config.dp,
163
+ pipeline_parallel_size=parallel_config.pp,
164
+ tensor_parallel_size=parallel_config.tp,
165
+ )
166
+ # params
167
+ dtype = torch.bfloat16 # Flash attention doesn't support fp32
168
+
169
+ # Initialise brrr model
170
+ model_config_brrr = CONFIG_BRRR.model.model_config
171
+
172
+ model = build_model(
173
+ model_builder=lambda: MiniCPMForTraining(
174
+ config=model_config_brrr,
175
+ parallel_context=parallel_context,
176
+ parallel_config=parallel_config,
177
+ random_states=None,
178
+ ),
179
+ dtype=dtype,
180
+ parallel_context=parallel_context,
181
+ device=torch.device("cpu"),
182
+ )
183
+
184
+ # Initialise transformers model
185
+ device_map = {}
186
+ current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
187
+ device_map["model.embed_tokens"] = (
188
+ model.model.token_position_embeddings.rank
189
+ if current_pp_rank == model.model.token_position_embeddings.rank
190
+ else "meta"
191
+ )
192
+ for i in range(model_config_brrr.num_hidden_layers):
193
+ device_map[f"model.layers.{i}"] = (
194
+ model.model.decoder[i].rank if current_pp_rank == model.model.decoder[i].rank else "meta"
195
+ )
196
+ device_map["model.norm"] = (
197
+ model.model.final_layer_norm.rank if current_pp_rank == model.model.final_layer_norm.rank else "meta"
198
+ )
199
+ device_map["lm_head"] = model.model.lm_head.rank if current_pp_rank == model.model.lm_head.rank else "meta"
200
+ model_ref = MistralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)
201
+
202
+ # Copy weights from trfrs to brrr
203
+ ref_state_dict = model_ref.state_dict()
204
+ for name, param in model.named_parameters():
205
+ print(f"Syncing {name}")
206
+ ref_param = get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref)
207
+
208
+ param_is_tp_sharded = (
209
+ isinstance(param, NanotronParameter)
210
+ and param.is_sharded
211
+ and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
212
+ )
213
+
214
+ if param_is_tp_sharded:
215
+ sharded_info = param.get_sharded_info()
216
+ # copy param data (not just the reference)
217
+ with torch.no_grad():
218
+ for local_global_slices_pair in sharded_info.local_global_slices_pairs:
219
+ local_slices = local_global_slices_pair.local_slices
220
+ global_slices = local_global_slices_pair.global_slices
221
+ param[local_slices].copy_(ref_param[global_slices])
222
+ else:
223
+ assert (
224
+ ref_param.shape == param.shape
225
+ ), f"Parameter shape don't match for {name}\n{ref_param.shape} != {param.shape}"
226
+ # copy param data (not just the reference)
227
+ with torch.no_grad():
228
+ param.copy_(ref_param)
229
+ ref_param = None
230
+ # torch.cuda.empty_cache()
231
+
232
+ # TODO @nouamanetazi: assert weights are the same
233
+ # Marks parameters as NanotronParameters
234
+ mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
235
+
236
+ sanity_check(root_module=model)
237
+
238
+ checkpoint_metadata = {
239
+ "last_train_step": 0,
240
+ "consumed_train_samples": 0,
241
+ }
242
+ save(config=CONFIG_BRRR, model=model, optimizer=None, lr_scheduler=None, parallel_context=parallel_context, root_folder=save_path,
243
+ should_save_optimizer=False, should_save_lr_scheduler=False, checkpoint_metadata=checkpoint_metadata,
244
+ sanity_checks=False)
245
+ # save_weights(model=model, parallel_context=parallel_context, root_folder=save_path)
246
+ # save_meta(root_folder=save_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata)
247
+
248
+ if dist.get_rank(parallel_context.world_pg) == 0:
249
+ print(save_path)
250
+ import json
251
+
252
+ with open(save_path / "model_config.json", mode="w") as fo:
253
+ fo.write(json.dumps(asdict(CONFIG_BRRR.model.model_config), indent=4))
254
+
255
+
256
+ def main():
257
+ args = get_args()
258
+ convert_trfrs_to_brrr(**vars(args))
259
+
260
+
261
+ if __name__ == "__main__":
262
+ main()
dataloader.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nanotron import logging
2
+ from nanotron.config import (
3
+ PretrainDatasetsArgs,
4
+ )
5
+ from nanotron.dataloader import (
6
+ clm_process,
7
+ dummy_infinite_data_generator,
8
+ get_datasets,
9
+ get_train_dataloader,
10
+ )
11
+ from nanotron.logging import log_rank
12
+ from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
13
+ from nanotron.trainer import DistributedTrainer
14
+ from nanotron.utils import (
15
+ main_rank_first,
16
+ )
17
+
18
+ try:
19
+ from huggingface_hub import __version__ as hf_hub_version
20
+ from transformers import AutoTokenizer
21
+ from transformers import __version__ as tf_version
22
+ except ImportError:
23
+ hf_hub_version = None
24
+ tf_version = None
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ def get_dataloader(trainer: DistributedTrainer):
30
+ """Returns a dataloader for training."""
31
+
32
+ # First, we need to know which ranks to feed the dataloader to
33
+ input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
34
+
35
+ # Case 1: Dummy data generator
36
+ if trainer.config.data.dataset is None:
37
+ log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
38
+ dataloader = dummy_infinite_data_generator(
39
+ micro_batch_size=trainer.micro_batch_size,
40
+ sequence_length=trainer.sequence_length,
41
+ input_pp_rank=input_pp_rank,
42
+ output_pp_rank=output_pp_rank,
43
+ vocab_size=trainer.model_config.vocab_size,
44
+ seed=trainer.config.data.seed,
45
+ parallel_context=trainer.parallel_context,
46
+ )()
47
+
48
+ # Case 2: HuggingFace datasets
49
+ elif isinstance(trainer.config.data.dataset, PretrainDatasetsArgs):
50
+ log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
51
+ tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
52
+ log_rank(
53
+ f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
54
+ logger=logger,
55
+ level=logging.INFO,
56
+ rank=0,
57
+ )
58
+
59
+ # We need to the 1st device to process dataset and cache it, then other devices load from cache
60
+ with main_rank_first(trainer.parallel_context.world_pg):
61
+ # TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout?
62
+ # TODO: generalise to include for validation/test splits
63
+
64
+ # We load the raw dataset
65
+ raw_dataset = get_datasets(
66
+ hf_dataset_or_datasets=trainer.config.data.dataset.hf_dataset_or_datasets,
67
+ splits=trainer.config.data.dataset.hf_dataset_splits,
68
+ )["train"]
69
+
70
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
71
+ tokenizer.pad_token = tokenizer.eos_token
72
+ tokenizer.padding_side = "left"
73
+
74
+ # We apply the Causal Language Modeling preprocessing
75
+ train_dataset = clm_process(
76
+ raw_dataset=raw_dataset,
77
+ tokenizer=tokenizer,
78
+ text_column_name=trainer.config.data.dataset.text_column_name,
79
+ dataset_processing_num_proc_per_process=trainer.config.data.dataset.dataset_processing_num_proc_per_process,
80
+ dataset_overwrite_cache=trainer.config.data.dataset.dataset_overwrite_cache,
81
+ sequence_length=trainer.sequence_length,
82
+ )
83
+
84
+ # We load the processed dataset on the ranks requiring it
85
+ dataloader = get_train_dataloader(
86
+ train_dataset=train_dataset,
87
+ sequence_length=trainer.sequence_length,
88
+ parallel_context=trainer.parallel_context,
89
+ input_pp_rank=input_pp_rank,
90
+ output_pp_rank=output_pp_rank,
91
+ micro_batch_size=trainer.micro_batch_size,
92
+ consumed_train_samples=trainer.consumed_train_samples,
93
+ dataloader_num_workers=trainer.config.data.num_loading_workers,
94
+ seed_worker=trainer.config.data.seed,
95
+ dataloader_drop_last=True,
96
+ )
97
+ # Check if we have enough samples for train_steps
98
+ assert (
99
+ trainer.config.tokens.train_steps - trainer.start_iteration_step
100
+ ) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size() < len(dataloader), (
101
+ f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size()}), "
102
+ f"Try train_steps<={len(dataloader) * trainer.parallel_context.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}"
103
+ )
104
+ else:
105
+ raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {trainer.config.data.dataset}")
106
+
107
+ return dataloader
modeling_minicpm.py ADDED
@@ -0,0 +1,1147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch MiniCPM model.
16
+ """
17
+ from typing import Dict, Optional, Union
18
+ import inspect
19
+ import math
20
+
21
+ import torch
22
+ from flash_attn import bert_padding
23
+ from flash_attn.flash_attn_interface import (
24
+ flash_attn_varlen_func,
25
+ flash_attn_with_kvcache,
26
+ )
27
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
28
+ from nanotron import distributed as dist
29
+ from nanotron import logging
30
+ from nanotron.config import ParallelismArgs, RecomputeGranularity
31
+ from nanotron.generation.generate_store import AttachableStore
32
+ from nanotron.logging import log_rank
33
+ from nanotron.models import NanotronModel
34
+ from nanotron.nn.layer_norm import TritonRMSNorm
35
+ from nanotron.parallel import ParallelContext
36
+ from nanotron.parallel.parameters import NanotronParameter
37
+ from nanotron.parallel.pipeline_parallel.block import (
38
+ PipelineBlock,
39
+ TensorPointer,
40
+ )
41
+ from nanotron.parallel.pipeline_parallel.p2p import P2P
42
+ from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
43
+ from nanotron.parallel.tensor_parallel.nn import (
44
+ TensorParallelColumnLinear,
45
+ TensorParallelEmbedding,
46
+ TensorParallelLinearMode,
47
+ TensorParallelRowLinear,
48
+ )
49
+ from nanotron.random import RandomStates
50
+ from nanotron.utils import checkpoint_method
51
+ from nanotron.nn.activations import ACT2FN
52
+ from torch import nn
53
+
54
+ from config_minicpm import MiniCPMConfig
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters)
59
+
60
+
61
+ class RotaryEmbedding(nn.Module):
62
+ def __init__(self, dim: int, end: int, theta: float = 10000.0):
63
+ super().__init__()
64
+ assert dim % 2 == 0
65
+ self.dim = dim
66
+ self.end = end
67
+ self.theta = theta
68
+ # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
69
+ # TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
70
+ self.freqs_cis: torch.Tensor
71
+ self._initialized_buffer = False
72
+
73
+ def init_rotary_embeddings(self):
74
+ if self._initialized_buffer is True:
75
+ # Buffer if already initialized
76
+ return
77
+ self.register_buffer(
78
+ "freqs_cis",
79
+ torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"),
80
+ persistent=False,
81
+ )
82
+ assert self.freqs_cis.device.type == "cuda"
83
+ # TODO @nouamane: One we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
84
+ if self.freqs_cis.dtype != torch.float:
85
+ self.freqs_cis = self.freqs_cis.to(torch.float)
86
+ assert self.freqs_cis.dtype == torch.float
87
+ freqs = 1.0 / (
88
+ self.theta
89
+ ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim)
90
+ )
91
+ t = torch.arange(self.end, device="cuda")
92
+ freqs = torch.outer(t, freqs).float()
93
+ complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
94
+ freqs = torch.view_as_real(complex_freqs)
95
+ self.freqs_cis.copy_(freqs)
96
+ self._initialized_buffer = True
97
+
98
+ def forward(
99
+ self,
100
+ x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
101
+ position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
102
+ ):
103
+ batch_size, seq_length, num_heads, inner_dim = x.shape
104
+ while (
105
+ position_ids is not None and position_ids[-1, -1] >= self.end
106
+ ) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync
107
+ self.end *= 2
108
+ self._initialized_buffer = False
109
+ if self._initialized_buffer is False:
110
+ self.init_rotary_embeddings()
111
+ dtype = x.dtype
112
+ assert inner_dim % 2 == 0
113
+ x = x.view(
114
+ batch_size, seq_length, num_heads, inner_dim // 2, 2
115
+ ) # [batch_size, q_length, num_heads, inner_dim]
116
+ if x.dtype == torch.bfloat16:
117
+ x = x.float()
118
+ complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2]
119
+ if position_ids is None:
120
+ freqs_cis = self.freqs_cis[None, :seq_length, None, :]
121
+ else:
122
+ # TODO(kunhao): Should None follow the num_heads dimension?
123
+ if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully
124
+ raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
125
+ freqs_cis = self.freqs_cis[position_ids][:, :, None, :]
126
+ complex_freqs = torch.view_as_complex(freqs_cis)
127
+ x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim)
128
+ return x_out.type(dtype)
129
+
130
+
131
+ class GLUActivation(nn.Module):
132
+ def __init__(self, act_fn_name: str):
133
+ super().__init__()
134
+ self.act = ACT2FN[act_fn_name]
135
+
136
+ def forward(self, merged_states: torch.Tensor):
137
+ gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
138
+ return self.act(gate_states) * up_states
139
+
140
+
141
+ class MLP(nn.Module):
142
+ def __init__(
143
+ self,
144
+ config: MiniCPMConfig,
145
+ parallel_config: Optional[ParallelismArgs],
146
+ tp_pg: dist.ProcessGroup,
147
+ ):
148
+ super().__init__()
149
+
150
+ # TODO @thomasw21: refactor so that we store that default in a single place.
151
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
152
+ tp_linear_async_communication = (
153
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
154
+ )
155
+
156
+ gate_up_contiguous_chunks = (
157
+ config.intermediate_size, # shape of gate_linear
158
+ config.intermediate_size, # shape of up_linear
159
+ )
160
+ self.gate_up_proj = TensorParallelColumnLinear(
161
+ config.hidden_size,
162
+ 2 * config.intermediate_size,
163
+ pg=tp_pg,
164
+ mode=tp_mode,
165
+ bias=False,
166
+ async_communication=tp_linear_async_communication,
167
+ contiguous_chunks=gate_up_contiguous_chunks,
168
+ )
169
+
170
+ self.down_proj = TensorParallelRowLinear(
171
+ config.intermediate_size,
172
+ config.hidden_size,
173
+ pg=tp_pg,
174
+ mode=tp_mode,
175
+ bias=False,
176
+ async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
177
+ )
178
+ # TODO @nouamane: why can't we torch.jit.script GLUActivation?
179
+ self.split_silu_mul = GLUActivation(config.hidden_act)
180
+
181
+ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
182
+ merged_states = self.gate_up_proj(hidden_states)
183
+ hidden_states = self.down_proj(self.split_silu_mul(merged_states))
184
+ return {"hidden_states": hidden_states}
185
+
186
+
187
+ class CoreAttention(nn.Module):
188
+ def __init__(self, config: MiniCPMConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
189
+ super().__init__()
190
+ # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentically a `d_qkv`
191
+ assert (
192
+ config.hidden_size % config.num_attention_heads == 0
193
+ ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
194
+ self.d_qk = config.hidden_size // config.num_attention_heads
195
+ self.d_v = config.hidden_size // config.num_attention_heads
196
+ self.dropout = config.attn_pdrop
197
+
198
+ self.checkpoint_attention = False # Because flash_attn already does checkpointing
199
+
200
+ # if config.sliding_window_size is not None:
201
+ # assert (
202
+ # _flash_supports_window_size
203
+ # ), "Current version of flash-attn doesn't support sliding window: `pip install flash-attn>=2.3`"
204
+ # self.sliding_window_size = config.sliding_window_size # if layer_idx not in config.global_attn_layers else None
205
+
206
+ @checkpoint_method(attr_name="checkpoint_attention")
207
+ def forward(
208
+ self,
209
+ query_states: torch.Tensor, # [batch_size * q_length, num_heads, inner_dim]
210
+ key_states: torch.Tensor, # [batch_size * kv_length, 1, inner_dim]
211
+ value_states: torch.Tensor, # [batch_size * kv_length, 1, inner_dim]
212
+ q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
213
+ kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
214
+ ):
215
+ # TODO @thomasw21: Compute once, instead of computing for each layers.
216
+ cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
217
+ cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
218
+ torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
219
+ torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
220
+
221
+ # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
222
+ # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
223
+ causal = False if q_sequence_mask.shape[1] == 1 else True
224
+ attn_output = flash_attn_varlen_func(
225
+ q=query_states,
226
+ k=key_states,
227
+ v=value_states,
228
+ cu_seqlens_q=cu_seqlens_q,
229
+ cu_seqlens_k=cu_seqlens_k,
230
+ max_seqlen_q=q_sequence_mask.shape[1],
231
+ max_seqlen_k=kv_sequence_mask.shape[1],
232
+ dropout_p=self.dropout if self.training else 0.0,
233
+ softmax_scale=None, # defaults to 1/sqrt(d_qk)
234
+ causal=causal,
235
+ # window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
236
+ return_attn_probs=False,
237
+ )
238
+
239
+ return attn_output
240
+
241
+
242
+ def pad_to_right(tensor, mask, new_tensor=None):
243
+ """Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
244
+ Args:
245
+ tensor: (batch_size, seqlen, d1, d2)
246
+ mask: (batch_size, seqlen)
247
+ new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
248
+ Returns:
249
+ new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
250
+ right_padded_mask: (batch_size, seqlen)
251
+ """
252
+ # First, we need to find the number of padding for each row
253
+ unpad_seqlens = mask.sum(1)
254
+ # Then, we need to find the maximum length of the tensor
255
+ max_seqlen = mask.shape[1]
256
+ # We can then create the indices to select the padded values
257
+ # The indices are the same for each row
258
+ indices = torch.arange(max_seqlen, device=mask.device)
259
+ # We can then create the mask for the padded values
260
+ right_padded_mask = indices < unpad_seqlens[:, None]
261
+ # We select the useful values
262
+ useful_values = tensor[mask]
263
+ # We create the new tensor (if not provided)
264
+ new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
265
+ # We fill the new tensor with the useful values
266
+ new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
267
+ return new_tensor, right_padded_mask
268
+
269
+
270
+ class CausalSelfAttention(nn.Module, AttachableStore):
271
+ def __init__(
272
+ self,
273
+ config: MiniCPMConfig,
274
+ parallel_config: Optional[ParallelismArgs],
275
+ tp_pg: dist.ProcessGroup,
276
+ layer_idx: int,
277
+ ):
278
+ super().__init__()
279
+ # Tensor parallel considerations: We split tensors along head dimension
280
+ assert (
281
+ config.num_attention_heads % tp_pg.size() == 0
282
+ ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
283
+ try:
284
+ assert (
285
+ config.num_key_value_heads % tp_pg.size() == 0
286
+ ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
287
+ except AttributeError:
288
+ log_rank(
289
+ "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
290
+ logger=logger,
291
+ level=logging.WARNING,
292
+ rank=0,
293
+ )
294
+ # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
295
+ config.num_key_value_heads = config.num_attention_heads
296
+ assert (
297
+ config.num_attention_heads % config.num_key_value_heads == 0
298
+ ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
299
+ self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
300
+ self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
301
+ self.n_repeats = config.num_attention_heads // config.num_key_value_heads
302
+ self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not
303
+ self.d_qk = config.hidden_size // config.num_attention_heads
304
+ self.d_v = config.hidden_size // config.num_attention_heads
305
+ self.d_model = config.hidden_size
306
+
307
+ # TODO @thomasw21: refactor so that we store that default in a single place.
308
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
309
+ tp_linear_async_communication = (
310
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
311
+ )
312
+
313
+ # build the slice config for self.qkv for save/load
314
+ # shard are done within the contiguous chunk
315
+ qkv_contiguous_chunks = (
316
+ config.num_attention_heads * self.d_qk, # shape of q
317
+ config.num_key_value_heads * self.d_qk, # shape of k
318
+ config.num_key_value_heads * self.d_qk, # shape of v
319
+ )
320
+ self.qkv_proj = TensorParallelColumnLinear(
321
+ self.d_model,
322
+ config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
323
+ pg=tp_pg,
324
+ mode=tp_mode,
325
+ bias=False,
326
+ async_communication=tp_linear_async_communication,
327
+ contiguous_chunks=qkv_contiguous_chunks,
328
+ )
329
+ # TODO(kunhao): We want to have only one version per device and not one version per layer.
330
+ self.rotary_embedding = RotaryEmbedding(
331
+ dim=self.d_qk,
332
+ end=config.max_position_embeddings,
333
+ theta=config.rope_theta
334
+ )
335
+
336
+ # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
337
+ self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
338
+
339
+ self.o_proj = TensorParallelRowLinear(
340
+ config.num_attention_heads * self.d_qk,
341
+ self.d_model,
342
+ pg=tp_pg,
343
+ mode=tp_mode,
344
+ bias=False,
345
+ async_communication=tp_linear_async_communication,
346
+ )
347
+
348
+ self.attention = CoreAttention(
349
+ config,
350
+ parallel_config=parallel_config,
351
+ layer_idx=layer_idx,
352
+ )
353
+
354
+ self.prefill_kv_len = (
355
+ config.max_position_embeddings
356
+ ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
357
+
358
+ def forward(
359
+ self,
360
+ hidden_states, # [seq_length, batch_size, hidden_size]
361
+ sequence_mask, # [batch_size, seq_length]
362
+ ):
363
+ qkv_states = self.qkv_proj(
364
+ hidden_states
365
+ ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
366
+ q_length, batch_size, _ = qkv_states.shape
367
+
368
+ if self.is_gqa:
369
+ query_states, key_states, value_states = torch.split(
370
+ qkv_states,
371
+ [
372
+ self.n_local_q_heads * self.d_qk,
373
+ self.n_local_kv_heads * self.d_qk,
374
+ self.n_local_kv_heads * self.d_qk,
375
+ ],
376
+ dim=-1,
377
+ )
378
+
379
+ query_states = (
380
+ query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
381
+ )
382
+ key_states = (
383
+ key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
384
+ )
385
+ value_states = (
386
+ value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
387
+ )
388
+ else:
389
+ query_states, key_states, value_states = (
390
+ qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
391
+ .permute(2, 1, 0, 3, 4)
392
+ .contiguous()
393
+ ) # [3, batch_size, seq_length, n_local_q_heads, d_qk]
394
+
395
+ store = self.get_local_store()
396
+ if store is not None: # Inference case
397
+ # Double check that we use store only at inference time
398
+ assert key_states.requires_grad is False
399
+ assert value_states.requires_grad is False
400
+ if "position_offsets" in store:
401
+ old_position_offsets = store["position_offsets"]
402
+ position_ids = old_position_offsets[:, None] + sequence_mask
403
+ else:
404
+ position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
405
+ position_offsets = position_ids[:, -1]
406
+
407
+ # Compute rotary embeddings
408
+ # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
409
+ old_rotary_embed_end = self.rotary_embedding.end
410
+ query_states = self.rotary_embedding(query_states, position_ids=position_ids)
411
+ key_states = self.rotary_embedding(key_states, position_ids=position_ids)
412
+
413
+ if "key" not in store:
414
+ # First inference iteration (Prefill)
415
+ # TODO @nouamane: support custom masking
416
+ # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
417
+ # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
418
+ assert ~(
419
+ sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
420
+ ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
421
+
422
+ # preallocate k_cache, v_cache to self.prefill_kv_len
423
+ k_cache = torch.zeros(
424
+ (
425
+ batch_size,
426
+ self.prefill_kv_len,
427
+ self.n_local_kv_heads,
428
+ self.d_qk,
429
+ ),
430
+ dtype=query_states.dtype,
431
+ device=query_states.device,
432
+ )
433
+ v_cache = torch.zeros(
434
+ (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
435
+ dtype=query_states.dtype,
436
+ device=query_states.device,
437
+ )
438
+ # Remove pad tokens from key_states and concatenate samples in key_unpad
439
+ # cu_seqlens_k is the cumulative sequence lengths of key_states
440
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
441
+ query_states,
442
+ sequence_mask,
443
+ )
444
+ (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
445
+ key_states, sequence_mask
446
+ )
447
+ (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
448
+
449
+ output_unpad = flash_attn_varlen_func(
450
+ q=query_unpad, # (total_q, n_local_q_heads, d_qk)
451
+ k=key_unpad, # (total_kv, n_local_kv_heads, d_qk)
452
+ v=value_unpad, # (total_kv, n_local_kv_heads, d_v)
453
+ cu_seqlens_q=cu_seqlens_q,
454
+ cu_seqlens_k=cu_seqlens_k,
455
+ max_seqlen_q=max_seqlen_q,
456
+ max_seqlen_k=max_seqlen_k,
457
+ dropout_p=0.0,
458
+ softmax_scale=None,
459
+ causal=True, # True in prefill phase, False in subsequent phases
460
+ return_attn_probs=False,
461
+ ) # (total_unpadded, n_local_q_heads, d_v)
462
+
463
+ attention_output = bert_padding.pad_input(
464
+ output_unpad, indices_q, batch_size, q_length
465
+ ) # (batch_size, q_length, n_local_q_heads, d_v)
466
+
467
+ pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
468
+ pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
469
+
470
+ else:
471
+ # Pull pre-computed key/value states
472
+ # Subsequent inference iterations (q_length=1)
473
+ k_cache = store["key"]
474
+ v_cache = store["value"]
475
+
476
+ # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
477
+ # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
478
+ if self.rotary_embedding.end > old_rotary_embed_end:
479
+ k_cache = torch.cat(
480
+ [
481
+ k_cache,
482
+ torch.zeros(
483
+ (
484
+ batch_size,
485
+ self.rotary_embedding.end - old_rotary_embed_end,
486
+ self.n_local_kv_heads,
487
+ self.d_qk,
488
+ ),
489
+ dtype=query_states.dtype,
490
+ device=query_states.device,
491
+ ),
492
+ ],
493
+ dim=1,
494
+ )
495
+
496
+ v_cache = torch.cat(
497
+ [
498
+ v_cache,
499
+ torch.zeros(
500
+ (
501
+ batch_size,
502
+ self.rotary_embedding.end - old_rotary_embed_end,
503
+ self.n_local_kv_heads,
504
+ self.d_v,
505
+ ),
506
+ dtype=query_states.dtype,
507
+ device=query_states.device,
508
+ ),
509
+ ],
510
+ dim=1,
511
+ )
512
+
513
+ assert (
514
+ k_cache.shape[1] == self.rotary_embedding.end
515
+ ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
516
+ assert (
517
+ v_cache.shape[1] == self.rotary_embedding.end
518
+ ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
519
+
520
+ # [batch_size, seq_length, num_heads, d_qk]
521
+ query_states = query_states.view(
522
+ batch_size, q_length, self.n_local_q_heads, self.d_qk
523
+ ) # [batch_size, q_length, self.n_heads, d_qk]
524
+ kv_length = key_states.shape[1]
525
+ key_states = key_states.view(
526
+ batch_size, kv_length, self.n_local_kv_heads, self.d_qk
527
+ ) # [batch_size, kv_length, self.n_heads, d_qk]
528
+ value_states = value_states.view(
529
+ batch_size, kv_length, self.n_local_kv_heads, self.d_v
530
+ ) # [batch_size, kv_length, self.n_heads, d_v]
531
+
532
+ attention_output = flash_attn_with_kvcache(
533
+ query_states,
534
+ k_cache,
535
+ v_cache,
536
+ key_states,
537
+ value_states,
538
+ rotary_cos=None,
539
+ rotary_sin=None,
540
+ # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0)
541
+ cache_seqlens=position_offsets.contiguous(),
542
+ softmax_scale=None,
543
+ causal=True,
544
+ rotary_interleaved=False, # GPT-NeoX style
545
+ )
546
+
547
+ store.update(
548
+ {
549
+ "key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
550
+ "value": v_cache,
551
+ "position_offsets": position_offsets,
552
+ }
553
+ )
554
+
555
+ else: # Training case
556
+ # Apply rotary embeddings to query/key states
557
+ # NOTE: The layout is different from models/MiniCPM.py which is [batch_size, num_heads, seq_length, d_qk]
558
+ # Here it is, [batch_size, seq_length, num_heads, d_qk]
559
+ # [2, batch_size, seq_length, num_heads, d_qk]
560
+ key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
561
+ # [batch_size, seq_length, 2, num_heads, d_qk]
562
+ key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
563
+ query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
564
+ # [batch_size, seq_length, num_heads, d_qk]
565
+ key_states, value_states = torch.split(key_value_states, 1, dim=2)
566
+
567
+ q_sequence_mask = sequence_mask
568
+ kv_sequence_mask = sequence_mask
569
+
570
+ kv_length = key_states.shape[1]
571
+ # [batch_size, seq_length, num_heads, d_qk]
572
+ # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
573
+ query_states = query_states.view(
574
+ batch_size * q_length, self.n_local_q_heads, self.d_qk
575
+ ) # [batch_size * q_length, self.n_heads, d_qk]
576
+
577
+ key_states = key_states.view(
578
+ batch_size * kv_length, self.n_local_kv_heads, self.d_qk
579
+ ) # [batch_size * kv_length, self.n_heads, d_qk]
580
+ value_states = value_states.view(
581
+ batch_size * kv_length, self.n_local_kv_heads, self.d_v
582
+ ) # [batch_size * kv_length, self.n_heads, d_v]
583
+
584
+ attention_output = self.attention(
585
+ query_states=query_states,
586
+ key_states=key_states,
587
+ value_states=value_states,
588
+ q_sequence_mask=q_sequence_mask,
589
+ kv_sequence_mask=kv_sequence_mask,
590
+ )
591
+
592
+ attention_output = (
593
+ attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
594
+ )
595
+ output = self.o_proj(attention_output)
596
+
597
+ return {"hidden_states": output, "sequence_mask": sequence_mask}
598
+
599
+
600
+ class MiniCPMDecoderLayer(nn.Module):
601
+ def __init__(
602
+ self,
603
+ config: MiniCPMConfig,
604
+ parallel_config: Optional[ParallelismArgs],
605
+ tp_pg: dist.ProcessGroup,
606
+ layer_idx: int,
607
+ ):
608
+ super().__init__()
609
+ self.scale_depth = config.scale_depth
610
+ self.num_hidden_layers = config.num_hidden_layers
611
+ self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
612
+ self.attn = CausalSelfAttention(
613
+ config=config,
614
+ parallel_config=parallel_config,
615
+ tp_pg=tp_pg,
616
+ layer_idx=layer_idx,
617
+ )
618
+
619
+ self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
620
+ self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
621
+
622
+ def forward(
623
+ self,
624
+ hidden_states: Union[torch.Tensor, TensorPointer],
625
+ sequence_mask: Union[torch.Tensor, TensorPointer],
626
+ ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
627
+ residual = hidden_states
628
+ hidden_states = self.input_layernorm(hidden_states)
629
+
630
+ output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
631
+ hidden_states = output["hidden_states"]
632
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
633
+
634
+ residual = hidden_states
635
+ hidden_states = self.post_attention_layernorm(hidden_states)
636
+ hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
637
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
638
+
639
+ return {
640
+ "hidden_states": hidden_states,
641
+ "sequence_mask": output["sequence_mask"],
642
+ }
643
+
644
+
645
+ class Embedding(nn.Module, AttachableStore):
646
+ def __init__(self, tp_pg: dist.ProcessGroup, config: MiniCPMConfig, parallel_config: Optional[ParallelismArgs]):
647
+ super().__init__()
648
+ self.token_embedding = TensorParallelEmbedding(
649
+ num_embeddings=config.vocab_size,
650
+ embedding_dim=config.hidden_size,
651
+ padding_idx=config.pad_token_id,
652
+ pg=tp_pg,
653
+ mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
654
+ )
655
+ self.pg = tp_pg
656
+
657
+ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
658
+ store = self.get_local_store()
659
+ if store is not None:
660
+ if "past_length" in store:
661
+ past_length = store["past_length"]
662
+ else:
663
+ past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
664
+
665
+ cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
666
+ # Store new past_length in store
667
+ store["past_length"] = past_length + cumsum_mask[:, -1]
668
+
669
+ # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
670
+ input_ids = input_ids.transpose(0, 1)
671
+ input_embeds = self.token_embedding(input_ids)
672
+ return {"input_embeds": input_embeds}
673
+
674
+
675
+ class MiniCPMModel(nn.Module):
676
+ """Build pipeline graph"""
677
+
678
+ def __init__(
679
+ self,
680
+ config: MiniCPMConfig,
681
+ parallel_context: ParallelContext,
682
+ parallel_config: Optional[ParallelismArgs],
683
+ ):
684
+ super().__init__()
685
+
686
+ # Declare all the nodes
687
+ self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
688
+ self.config = config
689
+ self.parallel_config = parallel_config
690
+ self.parallel_context = parallel_context
691
+ self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
692
+ tp_linear_async_communication = (
693
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
694
+ )
695
+
696
+ self.token_position_embeddings = PipelineBlock(
697
+ p2p=self.p2p,
698
+ module_builder=Embedding,
699
+ module_kwargs={
700
+ "tp_pg": parallel_context.tp_pg,
701
+ "config": config,
702
+ "parallel_config": parallel_config,
703
+ },
704
+ module_input_keys={"input_ids", "input_mask"},
705
+ module_output_keys={"input_embeds"},
706
+ )
707
+
708
+ self.decoder = nn.ModuleList(
709
+ [
710
+ PipelineBlock(
711
+ p2p=self.p2p,
712
+ module_builder=MiniCPMDecoderLayer,
713
+ module_kwargs={
714
+ "config": config,
715
+ "parallel_config": parallel_config,
716
+ "tp_pg": parallel_context.tp_pg,
717
+ "layer_idx": layer_idx,
718
+ },
719
+ module_input_keys={"hidden_states", "sequence_mask"},
720
+ module_output_keys={"hidden_states", "sequence_mask"},
721
+ )
722
+ for layer_idx in range(config.num_hidden_layers)
723
+ ]
724
+ )
725
+
726
+ self.final_layer_norm = PipelineBlock(
727
+ p2p=self.p2p,
728
+ module_builder=TritonRMSNorm,
729
+ module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
730
+ module_input_keys={"input"},
731
+ module_output_keys={"hidden_states"},
732
+ ) # TODO
733
+
734
+ self.lm_head = PipelineBlock(
735
+ p2p=self.p2p,
736
+ # Understand that this means that we return sharded logits that are going to need to be gathered
737
+ module_builder=TensorParallelColumnLinear,
738
+ module_kwargs={
739
+ "in_features": config.hidden_size,
740
+ "out_features": config.vocab_size,
741
+ "pg": parallel_context.tp_pg,
742
+ "bias": False,
743
+ # TODO @thomasw21: refactor so that we store that default in a single place.
744
+ "mode": self.tp_mode,
745
+ "async_communication": tp_linear_async_communication,
746
+ },
747
+ module_input_keys={"x"},
748
+ module_output_keys={"logits"},
749
+ )
750
+
751
+ self.cast_to_fp32 = PipelineBlock(
752
+ p2p=self.p2p,
753
+ module_builder=lambda: lambda x: x.float(),
754
+ module_kwargs={},
755
+ module_input_keys={"x"},
756
+ module_output_keys={"output"},
757
+ )
758
+
759
+ def forward(
760
+ self,
761
+ input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
762
+ input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
763
+ ):
764
+ return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
765
+
766
+ def forward_with_hidden_states(
767
+ self,
768
+ input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
769
+ input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
770
+ ):
771
+ # all tensors are optional as most ranks don't need anything from the dataloader.
772
+
773
+ output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
774
+
775
+ hidden_encoder_states = {
776
+ "hidden_states": output["input_embeds"] * self.config.scale_emb,
777
+ "sequence_mask": input_mask,
778
+ }
779
+ for encoder_block in self.decoder:
780
+ hidden_encoder_states = encoder_block(**hidden_encoder_states)
781
+
782
+ hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
783
+
784
+ sharded_logits = self.lm_head(x=hidden_states / (self.config.hidden_size / self.config.dim_model_base))["logits"]
785
+
786
+ fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
787
+
788
+ return fp32_sharded_logits, hidden_states
789
+
790
+ def get_block_compute_costs(self):
791
+ """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
792
+ model_config = self.config
793
+ d_ff = model_config.intermediate_size
794
+ d_qkv = model_config.hidden_size // model_config.num_attention_heads
795
+ block_compute_costs = {
796
+ # CausalSelfAttention (qkv proj + attn out) + MLP
797
+ MiniCPMDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
798
+ + 3 * d_ff * model_config.hidden_size,
799
+ # This is the last lm_head
800
+ TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
801
+ }
802
+ return block_compute_costs
803
+
804
+ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
805
+ """Get flops per second for a given model"""
806
+ world_size = self.parallel_context.world_pg.size()
807
+ try:
808
+ num_key_values_heads = self.config.num_key_value_heads
809
+ except AttributeError:
810
+ num_key_values_heads = self.config.num_attention_heads
811
+
812
+ model_flops, hardware_flops = get_flops(
813
+ num_layers=self.config.num_hidden_layers,
814
+ hidden_size=self.config.hidden_size,
815
+ num_heads=self.config.num_attention_heads,
816
+ num_key_value_heads=num_key_values_heads,
817
+ vocab_size=self.config.vocab_size,
818
+ ffn_hidden_size=self.config.intermediate_size,
819
+ seq_len=sequence_length,
820
+ batch_size=global_batch_size,
821
+ recompute_granularity=self.parallel_config.recompute_granularity,
822
+ )
823
+
824
+ model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
825
+ hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
826
+ return model_flops_per_s, hardware_flops_per_s
827
+
828
+
829
+ @torch.jit.script
830
+ def masked_mean(loss, label_mask, dtype):
831
+ # type: (Tensor, Tensor, torch.dtype) -> Tensor
832
+ return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
833
+
834
+
835
+ class Loss(nn.Module):
836
+ def __init__(self, tp_pg: dist.ProcessGroup):
837
+ super().__init__()
838
+ self.tp_pg = tp_pg
839
+
840
+ def forward(
841
+ self,
842
+ sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
843
+ label_ids: torch.Tensor, # [batch_size, seq_length]
844
+ label_mask: torch.Tensor, # [batch_size, seq_length]
845
+ ) -> Dict[str, torch.Tensor]:
846
+ # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
847
+ # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
848
+ loss = sharded_cross_entropy(
849
+ sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
850
+ ).transpose(0, 1)
851
+ # TODO @thomasw21: It's unclear what kind of normalization we want to do.
852
+ loss = masked_mean(loss, label_mask, dtype=torch.float)
853
+ # I think indexing causes a sync we don't actually want
854
+ # loss = loss[label_mask].sum()
855
+ return {"loss": loss}
856
+
857
+
858
+ class MiniCPMForTraining(NanotronModel):
859
+ def __init__(
860
+ self,
861
+ config: MiniCPMConfig,
862
+ parallel_context: ParallelContext,
863
+ parallel_config: Optional[ParallelismArgs],
864
+ random_states: Optional[RandomStates] = None,
865
+ ):
866
+ super().__init__()
867
+ import warnings
868
+
869
+ self.model = MiniCPMModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
870
+ self.loss = PipelineBlock(
871
+ p2p=self.model.p2p,
872
+ module_builder=Loss,
873
+ module_kwargs={"tp_pg": parallel_context.tp_pg},
874
+ module_input_keys={
875
+ "sharded_logits",
876
+ "label_ids",
877
+ "label_mask",
878
+ },
879
+ module_output_keys={"loss"},
880
+ )
881
+ self.parallel_context = parallel_context
882
+ self.config = config
883
+ self.parallel_config = parallel_config
884
+
885
+ def forward(
886
+ self,
887
+ input_ids: Union[torch.Tensor, TensorPointer],
888
+ input_mask: Union[torch.Tensor, TensorPointer],
889
+ label_ids: Union[torch.Tensor, TensorPointer],
890
+ label_mask: Union[torch.Tensor, TensorPointer],
891
+ ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
892
+ sharded_logits = self.model(
893
+ input_ids=input_ids,
894
+ input_mask=input_mask,
895
+ )
896
+ loss = self.loss(
897
+ sharded_logits=sharded_logits,
898
+ label_ids=label_ids,
899
+ label_mask=label_mask,
900
+ )["loss"]
901
+ return {"loss": loss}
902
+
903
+ @torch.no_grad()
904
+ def init_model_randomly(self, init_method, scaled_init_method):
905
+ """Initialize model parameters randomly.
906
+ Args:
907
+ init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/
908
+ scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/
909
+
910
+ Note:
911
+ Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
912
+ """
913
+ model = self
914
+ initialized_parameters = set()
915
+ # Handle tensor parallelism
916
+ module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
917
+ # Fix the root_model
918
+ module_id_to_prefix[id(model)] = ""
919
+
920
+ for module_name, module in model.named_modules():
921
+ if isinstance(module, TensorParallelColumnLinear):
922
+ # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
923
+ # What it does:
924
+ # - instantiate a buffer of the `full size` in fp32
925
+ # - run init method on it
926
+ # - shard result to get only a specific shard
927
+ # Instead I'm lazy and just going to run init_method, since they are scalar independent
928
+ assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
929
+ name for name, _ in module.named_parameters()
930
+ }
931
+ for param_name, param in module.named_parameters():
932
+ assert isinstance(param, NanotronParameter)
933
+ if param.is_tied:
934
+ tied_info = param.get_tied_info()
935
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
936
+ module_id_to_prefix=module_id_to_prefix
937
+ )
938
+ else:
939
+ full_param_name = f"{module_name}.{param_name}"
940
+
941
+ if full_param_name in initialized_parameters:
942
+ # Already initialized
943
+ continue
944
+
945
+ if "weight" == param_name:
946
+ init_method(param)
947
+ elif "bias" == param_name:
948
+ param.zero_()
949
+ else:
950
+ raise ValueError(f"Who the fuck is {param_name}?")
951
+
952
+ assert full_param_name not in initialized_parameters
953
+ initialized_parameters.add(full_param_name)
954
+ elif isinstance(module, TensorParallelRowLinear):
955
+ # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
956
+ # What it does:
957
+ # - instantiate a buffer of the `full size` in fp32
958
+ # - run init method on it
959
+ # - shard result to get only a specific shard
960
+ # Instead I'm lazy and just going to run init_method, since they are scalar independent
961
+ assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
962
+ name for name, _ in module.named_parameters()
963
+ }
964
+ for param_name, param in module.named_parameters():
965
+ assert isinstance(param, NanotronParameter)
966
+ if param.is_tied:
967
+ tied_info = param.get_tied_info()
968
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
969
+ module_id_to_prefix=module_id_to_prefix
970
+ )
971
+ else:
972
+ full_param_name = f"{module_name}.{param_name}"
973
+
974
+ if full_param_name in initialized_parameters:
975
+ # Already initialized
976
+ continue
977
+
978
+ if "weight" == param_name:
979
+ scaled_init_method(param)
980
+ elif "bias" == param_name:
981
+ param.zero_()
982
+ else:
983
+ raise ValueError(f"Who the fuck is {param_name}?")
984
+
985
+ assert full_param_name not in initialized_parameters
986
+ initialized_parameters.add(full_param_name)
987
+ elif isinstance(module, TritonRMSNorm):
988
+ assert {"weight"} == {name for name, _ in module.named_parameters()}
989
+ for param_name, param in module.named_parameters():
990
+ assert isinstance(param, NanotronParameter)
991
+ if param.is_tied:
992
+ tied_info = param.get_tied_info()
993
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
994
+ module_id_to_prefix=module_id_to_prefix
995
+ )
996
+ else:
997
+ full_param_name = f"{module_name}.{param_name}"
998
+
999
+ if full_param_name in initialized_parameters:
1000
+ # Already initialized
1001
+ continue
1002
+
1003
+ if "weight" == param_name:
1004
+ # TODO @thomasw21: Sometimes we actually want 0
1005
+ param.fill_(1)
1006
+ elif "bias" == param_name:
1007
+ param.zero_()
1008
+ else:
1009
+ raise ValueError(f"Who the fuck is {param_name}?")
1010
+
1011
+ assert full_param_name not in initialized_parameters
1012
+ initialized_parameters.add(full_param_name)
1013
+ elif isinstance(module, TensorParallelEmbedding):
1014
+ # TODO @thomasw21: Handle tied embeddings
1015
+ # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
1016
+ # What it does:
1017
+ # - instantiate a buffer of the `full size` in fp32
1018
+ # - run init method on it
1019
+ # - shard result to get only a specific shard
1020
+ # Instead I'm lazy and just going to run init_method, since they are scalar independent
1021
+ assert {"weight"} == {name for name, _ in module.named_parameters()}
1022
+
1023
+ assert isinstance(module.weight, NanotronParameter)
1024
+ if module.weight.is_tied:
1025
+ tied_info = module.weight.get_tied_info()
1026
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
1027
+ module_id_to_prefix=module_id_to_prefix
1028
+ )
1029
+ else:
1030
+ full_param_name = f"{module_name}.weight"
1031
+
1032
+ if full_param_name in initialized_parameters:
1033
+ # Already initialized
1034
+ continue
1035
+
1036
+ init_method(module.weight)
1037
+ assert full_param_name not in initialized_parameters
1038
+ initialized_parameters.add(full_param_name)
1039
+
1040
+ assert initialized_parameters == {
1041
+ param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
1042
+ if param.is_tied
1043
+ else name
1044
+ for name, param in model.named_parameters()
1045
+ }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
1046
+
1047
+ def get_block_compute_costs(self):
1048
+ """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
1049
+ return self.model.get_block_compute_costs()
1050
+
1051
+ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
1052
+ """Get flops per second for a given model"""
1053
+ return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
1054
+
1055
+
1056
+ def get_flops(
1057
+ num_layers,
1058
+ hidden_size,
1059
+ num_heads,
1060
+ vocab_size,
1061
+ seq_len,
1062
+ kv_channels=None,
1063
+ ffn_hidden_size=None,
1064
+ batch_size=1,
1065
+ recompute_granularity=None,
1066
+ glu_activation=False,
1067
+ ):
1068
+ """Counts flops in an decoder-only model
1069
+ Args:
1070
+ num_layers: number of decoder layers
1071
+ hidden_size: hidden size of the model
1072
+ num_heads: number of heads in the model
1073
+ num_key_value_heads: number of key/value heads in the model
1074
+ ffn_hidden_size: hidden size of the FFN
1075
+ vocab_size: size of the vocabulary
1076
+ seq_len: sequence length of the decoder
1077
+ batch_size: batch size
1078
+ recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info.
1079
+ Returns:
1080
+ model_flops: flops in the model (should be independent of the hardware and model implementation)
1081
+ hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
1082
+ """
1083
+ if kv_channels is None:
1084
+ assert hidden_size % num_heads == 0
1085
+ kv_channels = hidden_size // num_heads
1086
+ if ffn_hidden_size is None:
1087
+ ffn_hidden_size = 4 * hidden_size
1088
+
1089
+ # In the following we mark the reduced dimension with parentheses
1090
+ # decoder
1091
+ # self attention (MQA)
1092
+ ## q projection
1093
+ decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels
1094
+ ## kv projection, shared across heads
1095
+ decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels
1096
+ ## qk logits
1097
+ decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len
1098
+ ### SWA (sliding window attention / local attention)
1099
+ # window_size = 4096
1100
+ # decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size
1101
+ ## v logits
1102
+ decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels
1103
+ # decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels
1104
+ ## attn out
1105
+ decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size
1106
+ # FF
1107
+ ## 1st layer
1108
+ decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
1109
+ if glu_activation:
1110
+ # 3 matmuls instead of 2 in FFN
1111
+ # ref. https://arxiv.org/pdf/2002.05202.pdf
1112
+ # Used for example in T5 v1.1
1113
+ decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
1114
+ ## 2nd layer
1115
+ decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
1116
+
1117
+ decoder_flops_fwd = (
1118
+ decoder_q_proj_flops_fwd
1119
+ + decoder_kv_proj_flops_fwd
1120
+ + decoder_qk_logits_flops_fwd
1121
+ + decoder_v_logits_flops_fwd
1122
+ + decoder_attn_out_flops_fwd
1123
+ + decoder_ffn_1_flops_fwd
1124
+ + decoder_ffn_2_flops_fwd
1125
+ )
1126
+
1127
+ # lm head
1128
+ lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
1129
+
1130
+ # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
1131
+ # both input and weight tensors
1132
+ model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
1133
+
1134
+ if recompute_granularity is None:
1135
+ hardware_flops = model_flops
1136
+ elif recompute_granularity is RecomputeGranularity.FULL:
1137
+ # Note: we don't recompute lm head activs
1138
+ hardware_flops = model_flops + decoder_flops_fwd # + activ recomputation
1139
+ elif recompute_granularity is RecomputeGranularity.SELECTIVE:
1140
+ # all terms with s^2 are flops that are recomputed
1141
+ # ref. appendix A: https://arxiv.org/pdf/2205.05198.pdf
1142
+ recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd
1143
+ hardware_flops = model_flops + recomputed_decoder_flops
1144
+ else:
1145
+ raise ValueError("recompute_granularity must be one of 'full' or 'selective'")
1146
+
1147
+ return model_flops, hardware_flops
pretrained/MiniCPM-2B-dpo-bf16/checkpoint_metadata.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dp": 1,
3
+ "metas": {
4
+ "consumed_train_samples": 0,
5
+ "last_train_step": 0
6
+ },
7
+ "tp": 1,
8
+ "version": "1.2"
9
+ }
pretrained/MiniCPM-2B-dpo-bf16/config.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints: null
2
+ data: null
3
+ general:
4
+ benchmark_csv_path: null
5
+ consumed_train_samples: null
6
+ ignore_sanity_checks: false
7
+ project: openbmb
8
+ run: MiniCPM-2B-dpo-bf16
9
+ seed: 42
10
+ step: 0
11
+ lighteval: null
12
+ logging: null
13
+ model:
14
+ ddp_bucket_cap_mb: 25
15
+ dtype: bfloat16
16
+ init_method:
17
+ std: 0.025
18
+ make_vocab_size_divisible_by: 1
19
+ model_config:
20
+ attn_pdrop: 0.0
21
+ bos_token_id: 1
22
+ dim_model_base: 256
23
+ eos_token_id: 2
24
+ hidden_act: silu
25
+ hidden_size: 2304
26
+ initializer_range: 0.1
27
+ intermediate_size: 5760
28
+ max_position_embeddings: 2048
29
+ num_attention_heads: 36
30
+ num_hidden_layers: 40
31
+ num_key_value_heads: 36
32
+ pad_token_id: null
33
+ pretraining_tp: 1
34
+ rms_norm_eps: 1.0e-05
35
+ rope_theta: 10000.0
36
+ scale_depth: 1.4
37
+ scale_emb: 12
38
+ tie_word_embeddings: true
39
+ use_cache: true
40
+ vocab_size: 122753
41
+ optimizer: null
42
+ parallelism:
43
+ dp: 1
44
+ pp: 1
45
+ pp_engine: 1f1b
46
+ recompute_granularity: SELECTIVE
47
+ tp: 1
48
+ tp_linear_async_communication: true
49
+ tp_mode: REDUCE_SCATTER
50
+ profiler: null
51
+ tokenizer:
52
+ tokenizer_max_length: null
53
+ tokenizer_name_or_path: openbmb/MiniCPM-2B-dpo-bf16
54
+ tokenizer_revision: null
55
+ tokens: null
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36d4080827449bba1682daf6cab4b546e2fc02bb6a5c62efff2470a7e83202f5
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8843d169ca3025d42ddfdb590fc51d01122c1f9212f9581d4dd319430f0e45b
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:beb60636c7e802f851754225517f7f375a0faff02721fd37a2d460a2ca61458d
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5fec63ea6756d045071b457036fe09b783b34570efa1705d1d5120611db651f
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b895ee74603d6a88a024041a16266e351714207a61f966ceeb5e638b4ba1a4ed
3
+ size 53084456
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/0/pp_block/post_attention_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a319fc293f19f00db8aa2811b1a730a38a5966459493151020240c7adcb5dfd6
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:495d95a04eea9c36ec15b8c1add6e51a52052ee0fd90b446683d7a412d01b526
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:908e028114e8dc887ca2e80245678ea7ab08662b8462bdf9111d5e342f9de729
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6a6bc147e3a4d94a72c4a48c86b7b4517662141923380b80e216897e0c647bf
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88805e4d760f1e78e2d42e7ff18c1ffa61051571bf8b9eefc6fdd9f3843770e4
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16ff202e99dfe71bb7619cc3580a9f59c788868f8f6fc58c6c694f5eae562c89
3
+ size 53084456
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/1/pp_block/post_attention_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f240f0046868ee13ee354feb943d5c382358774d8e1533dd53e16040d9618124
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56ab2210f9e6e6018e84b6e96ad6bdc525e46396049ff66caa0285fb7a170120
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8736369fe6db6ca62bdd30014a6c920c6d6bb570c85b79d5ce43bff75366a499
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e37863cb49b75e41eff480d6fb94a3240cbc884dc2e76215d36f8d6207fba5ee
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b7d3522e6a49963803f5e4248ff76e2a9d51a044c76cfac156bc581b794f492
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1fdd9830af039908d5a402c6edeccb2527d5cdddfbe020dd3e5adc4d33e0371
3
+ size 53084456
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/10/pp_block/post_attention_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:639ee973e90fb837781436fbf443facd7ed1f73f070041ffd71c14196d06e2ce
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70f9c571fb7c57aa4a57d51fe4c5d05fe21a34e32a2f4896049c1d9bed6ea32b
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f4a20b46bd3098c0f1efdd665bba8f99adc2363257db31a8b7a0e4e73a540bd
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c572d06613b5e508d138b5af0180b5b3a09c78abc7b06a07039fec736487d1e9
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d021e4f9a713730b67f2e8df642639628ebd5674eb960c39575c62e3d92212f
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a290b7bf65372eb56ae8036fbe048ee0d94a39e01debf2c07600d2228bd6cb0c
3
+ size 53084456
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/11/pp_block/post_attention_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03a6490c81420efe127f1f8dc68c7f8734a9b936ca901e2f5e2485d90d178469
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1d3c77097833aa4ea96466ae247220e1f2a4ef6957c65f3257ad8b1da84ef70
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ca582e7a768fa2766bf655a6acb54ee6d8c3ec1fe63a7b8bc5ebb53846828f2
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b66f9fc608826efaa57dff502e1f0ba1bea00996208f1e8e875f58459d82c65
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36035e96236f5c5f39b19d5f53af38da02d05317849852bc5f2dd0eca315148a
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1c6047d8d338b808b94f73626c1a0ef9ac33a4c8cf964041ba8c9db3d8f9629
3
+ size 53084456
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/12/pp_block/post_attention_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3c926d53d3eb31a3c6aea0390794e16fd9ac99d4a99038a48de9ce750895ed4
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f1b53cb9c0ea5fe7080e68a0ad6e988aaf958315d3b58276d9ec7e2bd239e67
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08eca1997fce9a5815f81299a83848066b309a2d031f88e5012878236ae4650c
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab3f62ae9b040d8aa3e1b9d3a8b81d89ad731c5d9bb32757a43b116fcfd79ca0
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb415ebe5cbfc55a84d7657c74cf1634a9f534a5980242e59484b9c2a041d896
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4baf79283bdd79707aabf64311ee95bb4149435bb4469aa0e43aed65caa8c667
3
+ size 53084456
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/13/pp_block/post_attention_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5543638dcaccb01fa706781c9dc67b820df41a7209c25494dc82b6db8e63b1f3
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0883579e29c4b28ee7f4af72be5d77c2f045fca9757ca06f6d0a8c75de6abd0f
3
+ size 10617072
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/attn/qkv_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c436a6adf43608478642ab6efbf750101b4b3fc4d19425b649443309bde8417a
3
+ size 31850848
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/input_layernorm/model_weight.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43b8fe2b4b03bbe4ddbd0dc572b4cfac0286624eb2865d96f9afdb75c727643f
3
+ size 4704
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/mlp/down_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6194cb03e3b7c3260e8a34e632f7736177a5ee15a46d184d90e6e94d5a6802e
3
+ size 26542320
pretrained/MiniCPM-2B-dpo-bf16/model/model/decoder/14/pp_block/mlp/gate_up_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec68dd15fe1f1643fec13bb3cb19d8139beed19fdf48075c097baf16c5f23ba3
3
+ size 53084456