Jayfeather1024 committed
Commit 3022639
1 Parent(s): 50de6a9
added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "<pad>": 32000
+ }
arguments.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "model_name_or_path": "huggyllama/llama-7b",
+   "max_length": 512,
+   "trust_remote_code": true,
+   "train_datasets": [
+     [
+       "alpaca",
+       {
+         "proportion": 1.0
+       }
+     ]
+   ],
+   "eval_datasets": null,
+   "epochs": 3,
+   "per_device_train_batch_size": 4,
+   "per_device_eval_batch_size": 4,
+   "gradient_accumulation_steps": 16,
+   "gradient_checkpointing": true,
+   "lr": 2e-05,
+   "lr_scheduler_type": "cosine",
+   "lr_warmup_ratio": 0.03,
+   "weight_decay": 0.0,
+   "seed": 42,
+   "fp16": false,
+   "bf16": true,
+   "tf32": true,
+   "eval_strategy": "epoch",
+   "eval_interval": 1000000,
+   "need_eval": false,
+   "eval_split_ratio": null,
+   "output_dir": "/data/jiongxiao_wang/rlhf_attack/safe-rlhf/output/sft",
+   "log_type": "wandb",
+   "log_dir": "/data/jiongxiao_wang/rlhf_attack/safe-rlhf/output/sft",
+   "log_project": "Safe-RLHF-SFT",
+   "log_run_name": "sft-2023-12-31-20-07-40",
+   "save_16bit": false,
+   "save_interval": 1000000,
+   "local_rank": 0,
+   "zero_stage": 3,
+   "deepspeed": false,
+   "deepspeed_config": null,
+   "deepscale": false,
+   "deepscale_config": null,
+   "deepspeed_mpi": false,
+   "global_rank": 0,
+   "device": {
+     "type": "torch.device",
+     "repr": "device(type='cuda', index=0)"
+   },
+   "num_update_steps_per_epoch": 204,
+   "total_training_steps": 612
+ }
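
Note: the step counts recorded above follow from the batch configuration. Each optimizer update sees per_device_train_batch_size × WORLD_SIZE × gradient_accumulation_steps = 4 × 4 × 16 = 256 sequences, and with the Stanford Alpaca set of 52,002 examples (an assumption; the dataset size is not stored in this commit) that yields the 204 steps per epoch and 612 total steps listed. A minimal Python check:

    import math

    num_examples = 52_002   # Stanford Alpaca instruction set (assumed)
    per_device_batch = 4    # per_device_train_batch_size
    world_size = 4          # GPUs, per WORLD_SIZE in environ.txt
    grad_accum = 16         # gradient_accumulation_steps
    epochs = 3

    effective_batch = per_device_batch * world_size * grad_accum  # 256
    steps_per_epoch = math.ceil(num_examples / effective_batch)   # 204
    total_steps = steps_per_epoch * epochs                        # 612
    print(effective_batch, steps_per_epoch, total_steps)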
arguments.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:558c472797170401090f0c1a08e8d1c8d31bcad35438ef6134aceb8a269cc318
+ size 1019
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "_name_or_path": "huggyllama/llama-7b",
+   "architectures": [
+     "LlamaForCausalLM"
+   ],
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "max_position_embeddings": 2048,
+   "max_sequence_length": 2048,
+   "model_type": "llama",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 32,
+   "pad_token_id": 32000,
+   "pretraining_tp": 1,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float16",
+   "transformers_version": "4.31.0",
+   "use_cache": true,
+   "vocab_size": 32001
+ }
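
Note: vocab_size is 32001 rather than the base LLaMA 32,000 because of the <pad> token registered in added_tokens.json. A minimal sanity-check sketch for loading the finished checkpoint with Hugging Face transformers; the ./sft path is a hypothetical local clone of this repo:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "./sft"  # hypothetical local directory holding the files in this commit
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)

    # The embedding matrix already contains the extra <pad> row.
    assert model.config.vocab_size == len(tokenizer) == 32001
    assert tokenizer.convert_tokens_to_ids("<pad>") == model.config.pad_token_id == 32000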
environ.txt ADDED
@@ -0,0 +1,244 @@
+ ADDR2LINE=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-addr2line
+ AR=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ar
+ AS=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-as
+ BASH_FUNC__spack_shell_wrapper()=() { for var in LD_LIBRARY_PATH DYLD_LIBRARY_PATH DYLD_FALLBACK_LIBRARY_PATH;
+     do
+         eval "if [ -n \"\${${var}-}\" ]; then export SPACK_$var=\${${var}}; fi";
+     done;
+     if [ -n "${ZSH_VERSION:-}" ]; then
+         emulate -L sh;
+     fi;
+     _sp_flags="";
+     while [ ! -z ${1+x} ] && [ "${1#-}" != "${1}" ]; do
+         _sp_flags="$_sp_flags $1";
+         shift;
+     done;
+     if [ -n "$_sp_flags" ] && [ "${_sp_flags#*h}" != "${_sp_flags}" ] || [ "${_sp_flags#*V}" != "${_sp_flags}" ]; then
+         command spack $_sp_flags "$@";
+         return;
+     fi;
+     _sp_subcommand="";
+     if [ ! -z ${1+x} ]; then
+         _sp_subcommand="$1";
+         shift;
+     fi;
+     case $_sp_subcommand in
+         "cd")
+             _sp_arg="";
+             if [ -n "$1" ]; then
+                 _sp_arg="$1";
+                 shift;
+             fi;
+             if [ "$_sp_arg" = "-h" ] || [ "$_sp_arg" = "--help" ]; then
+                 command spack cd -h;
+             else
+                 LOC="$(spack location $_sp_arg "$@")";
+                 if [ -d "$LOC" ]; then
+                     cd "$LOC";
+                 else
+                     return 1;
+                 fi;
+             fi;
+             return
+             ;;
+         "env")
+             _sp_arg="";
+             if [ -n "$1" ]; then
+                 _sp_arg="$1";
+                 shift;
+             fi;
+             if [ "$_sp_arg" = "-h" ] || [ "$_sp_arg" = "--help" ]; then
+                 command spack env -h;
+             else
+                 case $_sp_arg in
+                     activate)
+                         _a=" $@";
+                         if [ -z ${1+x} ] || [ "${_a#* --sh}" != "$_a" ] || [ "${_a#* --csh}" != "$_a" ] || [ "${_a#* -h}" != "$_a" ] || [ "${_a#* --help}" != "$_a" ]; then
+                             command spack env activate "$@";
+                         else
+                             stdout="$(command spack $_sp_flags env activate --sh "$@")" || return;
+                             eval "$stdout";
+                         fi
+                         ;;
+                     deactivate)
+                         _a=" $@";
+                         if [ "${_a#* --sh}" != "$_a" ] || [ "${_a#* --csh}" != "$_a" ]; then
+                             command spack env deactivate "$@";
+                         else
+                             if [ -n "$*" ]; then
+                                 command spack env deactivate -h;
+                             else
+                                 stdout="$(command spack $_sp_flags env deactivate --sh)" || return;
+                                 eval "$stdout";
+                             fi;
+                         fi
+                         ;;
+                     *)
+                         command spack env $_sp_arg "$@"
+                         ;;
+                 esac;
+             fi;
+             return
+             ;;
+         "load" | "unload")
+             _a=" $@";
+             if [ "${_a#* --sh}" != "$_a" ] || [ "${_a#* --csh}" != "$_a" ] || [ "${_a#* -h}" != "$_a" ] || [ "${_a#* --list}" != "$_a" ] || [ "${_a#* --help}" != "$_a" ]; then
+                 command spack $_sp_flags $_sp_subcommand "$@";
+             else
+                 stdout="$(command spack $_sp_flags $_sp_subcommand --sh "$@")" || return;
+                 eval "$stdout";
+             fi
+             ;;
+         *)
+             command spack $_sp_flags $_sp_subcommand "$@"
+             ;;
+     esac
+ }
+ BASH_FUNC_module()=() { eval `/usr/bin/modulecmd bash $*`
+ }
+ BASH_FUNC_spack()=() { : this is a shell function from: /nfs/cluster/spack/share/spack/setup-env.sh;
+     : the real spack script is here: /nfs/cluster/spack/bin/spack;
+     _spack_shell_wrapper "$@";
+     return $?
+ }
+ BUILD=x86_64-conda-linux-gnu
+ CC=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-cc
+ CC_FOR_BUILD=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-cc
+ CFLAGS=-march=nocona -mtune=haswell -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O2 -ffunction-sections -pipe -isystem /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/include
+ CMAKE_ARGS=-DCMAKE_AR=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ar -DCMAKE_CXX_COMPILER_AR=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-ar -DCMAKE_C_COMPILER_AR=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-ar -DCMAKE_RANLIB=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ranlib -DCMAKE_CXX_COMPILER_RANLIB=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-ranlib -DCMAKE_C_COMPILER_RANLIB=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-ranlib -DCMAKE_LINKER=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ld -DCMAKE_STRIP=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-strip -DCMAKE_BUILD_TYPE=Release
+ CMAKE_PREFIX_PATH=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf:/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/x86_64-conda-linux-gnu/sysroot/usr
+ COLORTERM=truecolor
+ CONDA_BUILD_SYSROOT=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/x86_64-conda-linux-gnu/sysroot
+ CONDA_DEFAULT_ENV=safe-rlhf
+ CONDA_EXE=/data/jiongxiao_wang/anaconda3/bin/conda
+ CONDA_PREFIX=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf
+ CONDA_PREFIX_1=/data/jiongxiao_wang/anaconda3
+ CONDA_PROMPT_MODIFIER=(safe-rlhf)
+ CONDA_PYTHON_EXE=/data/jiongxiao_wang/anaconda3/bin/python
+ CONDA_SHLVL=2
+ CONDA_TOOLCHAIN_BUILD=x86_64-conda-linux-gnu
+ CONDA_TOOLCHAIN_HOST=x86_64-conda-linux-gnu
+ CPP=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-cpp
+ CPPFLAGS=-DNDEBUG -D_FORTIFY_SOURCE=2 -O2 -isystem /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/include
+ CROSS_RANK=0
+ CROSS_SIZE=1
+ CUDA_MODULE_LOADING=LAZY
+ CUDA_VISIBLE_DEVICES=0,1,2,3
+ CXX=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-c++
+ CXXFILT=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-c++filt
+ CXXFLAGS=-fvisibility-inlines-hidden -fmessage-length=0 -march=nocona -mtune=haswell -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O2 -ffunction-sections -pipe -isystem /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/include
+ CXX_FOR_BUILD=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-c++
+ DEBUG_CFLAGS=-march=nocona -mtune=haswell -ftree-vectorize -fPIC -fstack-protector-all -fno-plt -Og -g -Wall -Wextra -fvar-tracking-assignments -ffunction-sections -pipe -isystem /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/include
+ DEBUG_CPPFLAGS=-D_DEBUG -D_FORTIFY_SOURCE=2 -Og -isystem /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/include
+ DEBUG_CXXFLAGS=-fvisibility-inlines-hidden -fmessage-length=0 -march=nocona -mtune=haswell -ftree-vectorize -fPIC -fstack-protector-all -fno-plt -Og -g -Wall -Wextra -fvar-tracking-assignments -ffunction-sections -pipe -isystem /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/include
+ ELFEDIT=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-elfedit
+ ENVIRONMENT=BATCH
+ GCC=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc
+ GCC_AR=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-ar
+ GCC_NM=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-nm
+ GCC_RANLIB=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gcc-ranlib
+ GPROF=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-gprof
+ GPU_DEVICE_ORDINAL=0,1,2,3
+ GXX=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-g++
+ HOME=/data/jiongxiao_wang
+ HOST=x86_64-conda-linux-gnu
+ HOSTNAME=compute-permanent-node-153
+ LANG=en_US.UTF-8
+ LD=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ld
+ LDFLAGS=-Wl,-O2 -Wl,--sort-common -Wl,--as-needed -Wl,-z,relro -Wl,-z,now -Wl,--disable-new-dtags -Wl,--gc-sections -Wl,--allow-shlib-undefined -Wl,-rpath,/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/lib -Wl,-rpath-link,/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/lib -L/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/lib
+ LD_GOLD=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ld.gold
+ LD_LIBRARY_PATH=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/lib:/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/lib:
+ LESSOPEN=||/usr/bin/lesspipe.sh %s
+ LOADEDMODULES=
+ LOCAL_RANK=0
+ LOCAL_SIZE=4
+ LOGLEVEL=WARNING
+ LOGNAME=jiongxiao_wang
+ LS_COLORS=rs=0:di=38;5;27:ln=38;5;51:mh=44;38;5;15:pi=40;38;5;11:so=38;5;13:do=38;5;5:bd=48;5;232;38;5;11:cd=48;5;232;38;5;3:or=48;5;232;38;5;9:mi=05;48;5;232;38;5;15:su=48;5;196;38;5;15:sg=48;5;11;38;5;16:ca=48;5;196;38;5;226:tw=48;5;10;38;5;16:ow=48;5;10;38;5;21:st=48;5;21;38;5;15:ex=38;5;34:*.tar=38;5;9:*.tgz=38;5;9:*.arc=38;5;9:*.arj=38;5;9:*.taz=38;5;9:*.lha=38;5;9:*.lz4=38;5;9:*.lzh=38;5;9:*.lzma=38;5;9:*.tlz=38;5;9:*.txz=38;5;9:*.tzo=38;5;9:*.t7z=38;5;9:*.zip=38;5;9:*.z=38;5;9:*.Z=38;5;9:*.dz=38;5;9:*.gz=38;5;9:*.lrz=38;5;9:*.lz=38;5;9:*.lzo=38;5;9:*.xz=38;5;9:*.bz2=38;5;9:*.bz=38;5;9:*.tbz=38;5;9:*.tbz2=38;5;9:*.tz=38;5;9:*.deb=38;5;9:*.rpm=38;5;9:*.jar=38;5;9:*.war=38;5;9:*.ear=38;5;9:*.sar=38;5;9:*.rar=38;5;9:*.alz=38;5;9:*.ace=38;5;9:*.zoo=38;5;9:*.cpio=38;5;9:*.7z=38;5;9:*.rz=38;5;9:*.cab=38;5;9:*.jpg=38;5;13:*.jpeg=38;5;13:*.gif=38;5;13:*.bmp=38;5;13:*.pbm=38;5;13:*.pgm=38;5;13:*.ppm=38;5;13:*.tga=38;5;13:*.xbm=38;5;13:*.xpm=38;5;13:*.tif=38;5;13:*.tiff=38;5;13:*.png=38;5;13:*.svg=38;5;13:*.svgz=38;5;13:*.mng=38;5;13:*.pcx=38;5;13:*.mov=38;5;13:*.mpg=38;5;13:*.mpeg=38;5;13:*.m2v=38;5;13:*.mkv=38;5;13:*.webm=38;5;13:*.ogm=38;5;13:*.mp4=38;5;13:*.m4v=38;5;13:*.mp4v=38;5;13:*.vob=38;5;13:*.qt=38;5;13:*.nuv=38;5;13:*.wmv=38;5;13:*.asf=38;5;13:*.rm=38;5;13:*.rmvb=38;5;13:*.flc=38;5;13:*.avi=38;5;13:*.fli=38;5;13:*.flv=38;5;13:*.gl=38;5;13:*.dl=38;5;13:*.xcf=38;5;13:*.xwd=38;5;13:*.yuv=38;5;13:*.cgm=38;5;13:*.emf=38;5;13:*.axv=38;5;13:*.anx=38;5;13:*.ogv=38;5;13:*.ogx=38;5;13:*.aac=38;5;45:*.au=38;5;45:*.flac=38;5;45:*.mid=38;5;45:*.midi=38;5;45:*.mka=38;5;45:*.mp3=38;5;45:*.mpc=38;5;45:*.ogg=38;5;45:*.ra=38;5;45:*.wav=38;5;45:*.axa=38;5;45:*.oga=38;5;45:*.spx=38;5;45:*.xspf=38;5;45:
+ MAIL=/var/mail/jiongxiao_wang
+ MASTER_ADDR=127.0.0.1
+ MASTER_PORT=56337
+ MESON_ARGS=--buildtype release
+ MODULEPATH=/usr/share/Modules/modulefiles:/etc/modulefiles
+ MODULESHOME=/usr/share/Modules
+ NIX_CONF_DIR=/nix
+ NM=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-nm
+ OBJCOPY=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-objcopy
+ OBJDUMP=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-objdump
+ PATH=/nix/var/nix/profiles/default/bin:/data/jiongxiao_wang/.nix-profile/bin:/data/jiongxiao_wang/.vscode-server/bin/899d46d82c4c95423fb7e10e68eba52050e30ba3/bin:/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin:/data/jiongxiao_wang/anaconda3/condabin:/nix/var/nix/profiles/default/bin:/data/jiongxiao_wang/.nix-profile/bin:/nfs/cluster/spack/bin:/usr/local/bin:/usr/bin:/var/lib/snapd/snap/bin
+ PWD=/data/jiongxiao_wang/rlhf_attack/safe-rlhf
+ PYTHONHASHSEED=42
+ PYTHONPATH=/data/jiongxiao_wang/rlhf_attack/safe-rlhf
+ RANK=0
+ RANLIB=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-ranlib
+ READELF=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-readelf
+ ROCR_VISIBLE_DEVICES=0,1,2,3
+ SHELL=/bin/bash
+ SHLVL=6
+ SIZE=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-size
+ SLURMD_NODENAME=compute-permanent-node-153
+ SLURM_CLUSTER_NAME=cluster
+ SLURM_CONF=/var/spool/slurmd/conf-cache/slurm.conf
+ SLURM_CPUS_ON_NODE=8
+ SLURM_GPUS_ON_NODE=4
+ SLURM_GTIDS=0
+ SLURM_JOBID=1031281
+ SLURM_JOB_ACCOUNT=chaowei_xiao
+ SLURM_JOB_CPUS_PER_NODE=8
+ SLURM_JOB_END_TIME=1704226008
+ SLURM_JOB_GID=10043
+ SLURM_JOB_GPUS=2,3,5,6
+ SLURM_JOB_ID=1031281
+ SLURM_JOB_NAME=rlhf
+ SLURM_JOB_NODELIST=compute-permanent-node-153
+ SLURM_JOB_NUM_NODES=1
+ SLURM_JOB_PARTITION=compute
+ SLURM_JOB_QOS=default_qos
+ SLURM_JOB_START_TIME=1704053208
+ SLURM_JOB_UID=10193
+ SLURM_JOB_USER=jiongxiao_wang
+ SLURM_LOCALID=0
+ SLURM_MEM_PER_NODE=40960
+ SLURM_NNODES=1
+ SLURM_NODEID=0
+ SLURM_NODELIST=compute-permanent-node-153
+ SLURM_NODE_ALIASES=(null)
+ SLURM_NPROCS=1
+ SLURM_NTASKS=1
+ SLURM_PRIO_PROCESS=0
+ SLURM_PROCID=0
+ SLURM_SUBMIT_DIR=/data/jiongxiao_wang/rlhf_attack/safe-rlhf
+ SLURM_SUBMIT_HOST=watch-tower-login
+ SLURM_TASKS_PER_NODE=1
+ SLURM_TASK_PID=189456
+ SLURM_TOPOLOGY_ADDR=watch-tower.watch-tower:f102d2c503fbef087183246a.compute-permanent-node-153
+ SLURM_TOPOLOGY_ADDR_PATTERN=switch.switch.node
+ SLURM_WORKING_CLUSTER=cluster:watch-tower-bastion:6817:9984:109
+ SPACK_PYTHON=/usr/bin/python3
+ SPACK_ROOT=/nfs/cluster/spack
+ SSH_CLIENT=73.208.16.93 60006 22
+ SSH_CONNECTION=73.208.16.93 60006 172.16.0.238 22
+ STRINGS=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-strings
+ STRIP=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/x86_64-conda-linux-gnu-strip
+ TERM=xterm-256color
+ TERM_PROGRAM=vscode
+ TERM_PROGRAM_VERSION=1.63.2
+ TF2_BEHAVIOR=1
+ TF_CPP_MIN_LOG_LEVEL=1
+ TMPDIR=/tmp
+ TPU_ML_PLATFORM=Tensorflow
+ USER=jiongxiao_wang
+ VSCODE_IPC_HOOK_CLI=/run/user/10193/vscode-ipc-dc9d7ab9-84f5-43c9-9d9b-fbbb0ce088b6.sock
+ WANDB_API_KEY=f6021dca133c93e80a7dae4620bd335d4d08cac6
+ WANDB_SERVICE=2-189882-tcp-localhost-56894
+ WORLD_SIZE=4
+ XDG_DATA_DIRS=/usr/local/share:/usr/share:/var/lib/snapd/desktop
+ XDG_RUNTIME_DIR=/run/user/10193
+ XDG_SESSION_ID=604022
+ ZE_AFFINITY_MASK=0,1,2,3
+ _=/data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/deepspeed
+ _CE_CONDA=
+ _CE_M=
+ _CONDA_PYTHON_SYSCONFIGDATA_NAME=_sysconfigdata_x86_64_conda_cos6_linux_gnu
+ build_alias=x86_64-conda-linux-gnu
+ host_alias=x86_64-conda-linux-gnu
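
Note: MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, and WORLD_SIZE above are the standard rendezvous variables consumed by torch.distributed. A minimal sketch of how a launcher-spawned worker typically picks them up (illustrative, not the safe-rlhf code itself):

    import os
    import torch
    import torch.distributed as dist

    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # env:// rendezvous reads MASTER_ADDR/MASTER_PORT (127.0.0.1:56337 above).
    dist.init_process_group(backend="nccl", init_method="env://",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)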
latest ADDED
@@ -0,0 +1 @@
+ global_step609
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f72e5f4678e96c70fadc38a9c9aa8fbacef8613f1487e8dd78bcb84bc9fa2e1
+ size 26953811021
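
Note: the size is consistent with an fp32 export of a 7B-parameter LLaMA. stdout.log (below) reports 6,738,423,808 reconstructed elements; at 4 bytes each that is 26,953,695,232 bytes, and the remaining ~113 KiB is state-dict pickle framing:

    params = 6_738_423_808              # elements reported by zero_to_fp32 in stdout.log
    raw_fp32_bytes = params * 4         # 26,953,695,232
    file_bytes = 26_953_811_021         # LFS pointer size above
    print(file_bytes - raw_fp32_bytes)  # 115789 bytes of framing overhead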
script.sh ADDED
@@ -0,0 +1,111 @@
+ #!/usr/bin/env bash
+ #
+ # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ export WANDB_API_KEY="f6021dca133c93e80a7dae4620bd335d4d08cac6"
+
+ if [ -z "${BASH_VERSION}" ]; then
+     echo "Please use bash to run this script." >&2
+     exit 1
+ fi
+
+ set -x
+
+ SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
+ ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
+ export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
+ export LOGLEVEL="${LOGLEVEL:-WARNING}"
+
+ MODEL_NAME_OR_PATH="huggyllama/llama-7b"
+ OUTPUT_DIR="${ROOT_DIR}/output/sft"
+ ZERO_STAGE=3
+ while [[ "$#" -gt 0 ]]; do
+     arg="$1"
+     shift
+     case "${arg}" in
+         --model_name_or_path)
+             MODEL_NAME_OR_PATH="$1"
+             shift
+             ;;
+         --model_name_or_path=*)
+             MODEL_NAME_OR_PATH="${arg#*=}"
+             ;;
+         --output_dir)
+             OUTPUT_DIR="$1"
+             shift
+             ;;
+         --output_dir=*)
+             OUTPUT_DIR="${arg#*=}"
+             ;;
+         --zero_stage)
+             ZERO_STAGE="$1"
+             shift
+             ;;
+         --zero_stage=*)
+             ZERO_STAGE="${arg#*=}"
+             ;;
+         *)
+             echo "Unknown parameter passed: '${arg}'" >&2
+             exit 1
+             ;;
+     esac
+ done
+
+ mkdir -p "${OUTPUT_DIR}"
+ OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
+ if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
+     echo '*' >"${OUTPUT_DIR}/.gitignore"
+ fi
+
+ cp -f "$0" "${OUTPUT_DIR}/script.sh"
+
+ if [[ -z "${WANDB_API_KEY}" ]]; then
+     export WANDB_MODE="offline"
+ fi
+
+ MASTER_PORT_START=10000
+ MASTER_PORT_END=65535
+ MASTER_PORT="$(
+     comm -23 \
+         <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
+         <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
+         shuf | head -n 1
+ )"
+
+ exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
+
+ deepspeed --num_nodes=1 --num_gpus=4 \
+     --master_port "${MASTER_PORT}" \
+     --module safe_rlhf.finetune \
+     --train_datasets alpaca \
+     --model_name_or_path "${MODEL_NAME_OR_PATH}" \
+     --max_length 512 \
+     --trust_remote_code True \
+     --epochs 3 \
+     --per_device_train_batch_size 4 \
+     --per_device_eval_batch_size 4 \
+     --gradient_accumulation_steps 16 \
+     --gradient_checkpointing \
+     --learning_rate 2e-5 \
+     --lr_scheduler_type cosine \
+     --lr_warmup_ratio 0.03 \
+     --weight_decay 0.0 \
+     --seed 42 \
+     --output_dir "${OUTPUT_DIR}" \
+     --log_type wandb \
+     --log_project Safe-RLHF-SFT \
+     --zero_stage "${ZERO_STAGE}" \
+     --bf16 True \
+     --tf32 True
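
Note: the MASTER_PORT block above picks a random TCP port not currently in use by diffing the candidate range against ss output. A rough Python equivalent of the same idea (a sketch, not part of the repo) is to let the kernel hand out a free port; like the shell version it is racy, since the port could be claimed between selection and use:

    import socket

    def find_free_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))  # port 0 -> OS assigns an unused port
            return s.getsockname()[1]

    print(find_free_port())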
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<pad>",
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
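
Note: the pad_token entry pairs with added_tokens.json (<pad> → 32000). LLaMA ships without a pad token, so the SFT code adds one and grows the embedding matrix by one row, which is why config.json reports vocab_size 32001. A hedged sketch of that setup step using standard transformers calls (not necessarily the repo's exact code):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

    # Register <pad> (id 32000) and add a matching embedding row: 32000 -> 32001.
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id  # 32000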
stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
stdout.log ADDED
@@ -0,0 +1,41 @@
+ [2023-12-31 20:06:51,176] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ [2023-12-31 20:06:55,387] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
+ Detected CUDA_VISIBLE_DEVICES=0,1,2,3 but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed.
+ [2023-12-31 20:06:55,387] [INFO] [runner.py:571:main] cmd = /data/jiongxiao_wang/anaconda3/envs/safe-rlhf/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgM119 --master_addr=127.0.0.1 --master_port=56337 --module --enable_each_rank_log=None safe_rlhf.finetune --train_datasets alpaca --model_name_or_path huggyllama/llama-7b --max_length 512 --trust_remote_code True --epochs 3 --per_device_train_batch_size 4 --per_device_eval_batch_size 4 --gradient_accumulation_steps 16 --gradient_checkpointing --learning_rate 2e-5 --lr_scheduler_type cosine --lr_warmup_ratio 0.03 --weight_decay 0.0 --seed 42 --output_dir /data/jiongxiao_wang/rlhf_attack/safe-rlhf/output/sft --log_type wandb --log_project Safe-RLHF-SFT --zero_stage 3 --bf16 True --tf32 True
+ [2023-12-31 20:06:57,487] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ [2023-12-31 20:07:00,478] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1, 2, 3]}
+ [2023-12-31 20:07:00,478] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=4, node_rank=0
+ [2023-12-31 20:07:00,478] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2, 3]})
+ [2023-12-31 20:07:00,478] [INFO] [launch.py:163:main] dist_world_size=4
+ [2023-12-31 20:07:00,478] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1,2,3
+ [2023-12-31 20:07:02,815] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ [2023-12-31 20:07:02,856] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ [2023-12-31 20:07:03,011] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ [2023-12-31 20:07:03,040] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ [2023-12-31 20:07:10,639] [INFO] [comm.py:637:init_distributed] cdb=None
+ [2023-12-31 20:07:10,640] [INFO] [comm.py:637:init_distributed] cdb=None
+ [2023-12-31 20:07:10,640] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
+ [2023-12-31 20:07:10,670] [INFO] [comm.py:637:init_distributed] cdb=None
+ [2023-12-31 20:07:10,675] [INFO] [comm.py:637:init_distributed] cdb=None
+ Set logger level to WARNING.
+ ninja: no work to do.
+ Time to load fused_adam op: 0.14865803718566895 seconds
+ Time to load fused_adam op: 0.2057504653930664 seconds
+ Time to load fused_adam op: 0.20213913917541504 seconds
+ Time to load fused_adam op: 0.2022261619567871 seconds
+ Parameter Offload: Total persistent parameters: 266240 in 65 params
+ ***** Running training *****
+ Saving model to "/data/jiongxiao_wang/rlhf_attack/safe-rlhf/output/sft" ...
+ Saving DeepSpeed Checkpoints...
+ Converting DeepSpeed Checkpoints to Hugging Face format...
+ [2023-12-31 21:51:42,560] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+ Processing zero checkpoint './global_step609'
+ Detected checkpoint of type zero stage 3, world_size: 4
+ Parsing checkpoint created by deepspeed==0.12.6
+ Reconstructed Trainable fp32 state dict with 291 params 6738423808 elements
+ Saving fp32 state dict to pytorch_model.bin
+ Model saved!
+ [2023-12-31 21:52:50,198] [INFO] [launch.py:347:main] Process 189883 exits successfully.
+ [2023-12-31 21:52:50,198] [INFO] [launch.py:347:main] Process 189885 exits successfully.
+ [2023-12-31 21:52:50,198] [INFO] [launch.py:347:main] Process 189884 exits successfully.
+ [2023-12-31 21:52:58,206] [INFO] [launch.py:347:main] Process 189882 exits successfully.
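
Note: the timestamps bracket the run: distributed init completes around 20:07:10 and checkpoint conversion starts at 21:51:42, so the 612 optimizer steps took roughly 104 minutes. A back-of-the-envelope rate (an upper bound per step, since the window also includes model loading and checkpoint saving):

    from datetime import datetime

    start = datetime(2023, 12, 31, 20, 7, 10)   # init_distributed complete
    end = datetime(2023, 12, 31, 21, 51, 42)    # conversion begins
    steps = 612                                 # total_training_steps
    print((end - start).total_seconds() / steps)  # ~10.2 s per step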
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "bos_token": {
+     "__type": "AddedToken",
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "clean_up_tokenization_spaces": false,
+   "eos_token": {
+     "__type": "AddedToken",
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "legacy": true,
+   "model_max_length": 512,
+   "pad_token": null,
+   "padding_side": "right",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": {
+     "__type": "AddedToken",
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
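
Note: pad_token is null here; when the tokenizer is loaded from this directory, the <pad> token is picked up from special_tokens_map.json and added_tokens.json instead. model_max_length is 512, matching max_length in arguments.json, with right-side padding. A small illustrative use (the ./sft path is again hypothetical):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./sft")

    batch = tokenizer(
        ["Hello, world."],
        padding="max_length",                   # right-pads with <pad> (id 32000)
        truncation=True,
        max_length=tokenizer.model_max_length,  # 512 per tokenizer_config.json
        return_tensors="pt",
    )
    print(batch.input_ids.shape)  # torch.Size([1, 512])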
zero_to_fp32.py ADDED
@@ -0,0 +1,587 @@
+ #!/usr/bin/env python
+
+ # Copyright (c) Microsoft Corporation.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # DeepSpeed Team
+
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+ # application.
+ #
+ # example: python zero_to_fp32.py . pytorch_model.bin
+
+ import argparse
+ import torch
+ import glob
+ import math
+ import os
+ import re
+ from collections import OrderedDict
+ from dataclasses import dataclass
+
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+ # DeepSpeed data structures it has to be available in the current python environment.
+ from deepspeed.utils import logger
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+                                             FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+                                             FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+ @dataclass
+ class zero_model_state:
+     buffers: dict()
+     param_shapes: dict()
+     shared_params: list
+     ds_version: int
+     frozen_param_shapes: dict()
+     frozen_param_fragments: dict()
+
+
+ debug = 0
+
+ # load to cpu
+ device = torch.device('cpu')
+
+
+ def atoi(text):
+     return int(text) if text.isdigit() else text
+
+
+ def natural_keys(text):
+     '''
+     alist.sort(key=natural_keys) sorts in human order
+     http://nedbatchelder.com/blog/200712/human_sorting.html
+     (See Toothy's implementation in the comments)
+     '''
+     return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+ def get_model_state_file(checkpoint_dir, zero_stage):
+     if not os.path.isdir(checkpoint_dir):
+         raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+     # there should be only one file
+     if zero_stage <= 2:
+         file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+     elif zero_stage == 3:
+         file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+     if not os.path.exists(file):
+         raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+     return file
+
+
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
+     # XXX: need to test that this simple glob rule works for multi-node setup too
+     ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+     if len(ckpt_files) == 0:
+         raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+     return ckpt_files
+
+
+ def get_optim_files(checkpoint_dir):
+     return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+ def get_model_state_files(checkpoint_dir):
+     return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+ def parse_model_states(files):
+     zero_model_states = []
+     for file in files:
+         state_dict = torch.load(file, map_location=device)
+
+         if BUFFER_NAMES not in state_dict:
+             raise ValueError(f"{file} is not a model state checkpoint")
+         buffer_names = state_dict[BUFFER_NAMES]
+         if debug:
+             print("Found buffers:", buffer_names)
+
+         # recover just the buffers while restoring them to fp32 if they were saved in fp16
+         buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+         param_shapes = state_dict[PARAM_SHAPES]
+
+         # collect parameters that are included in param_shapes
+         param_names = []
+         for s in param_shapes:
+             for name in s.keys():
+                 param_names.append(name)
+
+         # update with frozen parameters
+         frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+         if frozen_param_shapes is not None:
+             if debug:
+                 print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+             param_names += list(frozen_param_shapes.keys())
+
+         # handle shared params
+         shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+         ds_version = state_dict.get(DS_VERSION, None)
+
+         frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+         z_model_state = zero_model_state(buffers=buffers,
+                                          param_shapes=param_shapes,
+                                          shared_params=shared_params,
+                                          ds_version=ds_version,
+                                          frozen_param_shapes=frozen_param_shapes,
+                                          frozen_param_fragments=frozen_param_fragments)
+         zero_model_states.append(z_model_state)
+
+     return zero_model_states
+
+
+ def parse_optim_states(files, ds_checkpoint_dir):
+
+     total_files = len(files)
+     state_dicts = []
+     for f in files:
+         state_dict = torch.load(f, map_location=device)
+         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+         # and also handle the case where it was already removed by another helper script
+         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+         state_dicts.append(state_dict)
+
+     if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+         raise ValueError(f"{files[0]} is not a zero checkpoint")
+     zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+     world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+     # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+     # parameters can be different from data parallelism for non-expert parameters. So we can just
+     # use the max of the partition_count to get the dp world_size.
+
+     if type(world_size) is list:
+         world_size = max(world_size)
+
+     if world_size != total_files:
+         raise ValueError(
+             f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+             "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+         )
+
+     # the groups are named differently in each stage
+     if zero_stage <= 2:
+         fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+     elif zero_stage == 3:
+         fp32_groups_key = FP32_FLAT_GROUPS
+     else:
+         raise ValueError(f"unknown zero stage {zero_stage}")
+
+     if zero_stage <= 2:
+         fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+     elif zero_stage == 3:
+         # if there is more than one param group, there will be multiple flattened tensors - one
+         # flattened tensor per group - for simplicity merge them into a single tensor
+         #
+         # XXX: could make the script more memory efficient for when there are multiple groups - it
+         # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+         fp32_flat_groups = [
+             torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+         ]
+
+     return zero_stage, world_size, fp32_flat_groups
+
+
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+     """
+     Returns fp32 state_dict reconstructed from ds checkpoint
+
+     Args:
+         - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+     """
+     print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+     optim_files = get_optim_files(ds_checkpoint_dir)
+     zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+     print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+     model_files = get_model_state_files(ds_checkpoint_dir)
+
+     zero_model_states = parse_model_states(model_files)
+     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+     if zero_stage <= 2:
+         return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+     elif zero_stage == 3:
+         return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+
+
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
+     if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+         return
+
+     frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+     frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+     if debug:
+         num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+         print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+     wanted_params = len(frozen_param_shapes)
+     wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+     avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+     print(f'Frozen params: Have {avail_numel} numels to process.')
+     print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+     total_params = 0
+     total_numel = 0
+     for name, shape in frozen_param_shapes.items():
+         total_params += 1
+         unpartitioned_numel = shape.numel()
+         total_numel += unpartitioned_numel
+
+         state_dict[name] = frozen_param_fragments[name]
+
+         if debug:
+             print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+     param_shapes = zero_model_states[0].param_shapes
+
+     # Reconstruction protocol:
+     #
+     # XXX: document this
+
+     if debug:
+         for i in range(world_size):
+             for j in range(len(fp32_flat_groups[0])):
+                 print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+     # XXX: memory usage doubles here (zero2)
+     num_param_groups = len(fp32_flat_groups[0])
+     merged_single_partition_of_fp32_groups = []
+     for i in range(num_param_groups):
+         merged_partitions = [sd[i] for sd in fp32_flat_groups]
+         full_single_fp32_vector = torch.cat(merged_partitions, 0)
+         merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+     avail_numel = sum(
+         [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+     if debug:
+         wanted_params = sum([len(shapes) for shapes in param_shapes])
+         wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+         # not asserting if there is a mismatch due to possible padding
+         print(f"Have {avail_numel} numels to process.")
+         print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+     # params
+     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+     # out-of-core computing solution
+     total_numel = 0
+     total_params = 0
+     for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+         offset = 0
+         avail_numel = full_single_fp32_vector.numel()
+         for name, shape in shapes.items():
+
+             unpartitioned_numel = shape.numel()
+             total_numel += unpartitioned_numel
+             total_params += 1
+
+             if debug:
+                 print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+             state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+             offset += unpartitioned_numel
+
+         # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+         # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+         # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+         # live optimizer object, so we are checking that the numbers are within the right range
+         align_to = 2 * world_size
+
+         def zero2_align(x):
+             return align_to * math.ceil(x / align_to)
+
+         if debug:
+             print(f"original offset={offset}, avail_numel={avail_numel}")
+
+         offset = zero2_align(offset)
+         avail_numel = zero2_align(avail_numel)
+
+         if debug:
+             print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+         # Sanity check
+         if offset != avail_numel:
+             raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+     state_dict = OrderedDict()
+
+     # buffers
+     buffers = zero_model_states[0].buffers
+     state_dict.update(buffers)
+     if debug:
+         print(f"added {len(buffers)} buffers")
+
+     _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+     _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+     # recover shared parameters
+     for pair in zero_model_states[0].shared_params:
+         if pair[1] in state_dict:
+             state_dict[pair[0]] = state_dict[pair[1]]
+
+     return state_dict
+
+
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+     remainder = unpartitioned_numel % world_size
+     padding_numel = (world_size - remainder) if remainder else 0
+     partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+     return partitioned_numel, padding_numel
+
+
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+     if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+         return
+
+     if debug:
+         for i in range(world_size):
+             num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+             print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+     frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+     wanted_params = len(frozen_param_shapes)
+     wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+     avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+     print(f'Frozen params: Have {avail_numel} numels to process.')
+     print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+     total_params = 0
+     total_numel = 0
+     for name, shape in zero_model_states[0].frozen_param_shapes.items():
+         total_params += 1
+         unpartitioned_numel = shape.numel()
+         total_numel += unpartitioned_numel
+
+         param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+         state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+         if debug:
+             print(
+                 f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+             )
+
+     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+     param_shapes = zero_model_states[0].param_shapes
+     avail_numel = fp32_flat_groups[0].numel() * world_size
+     # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+     # param, re-consolidating each param, while dealing with padding if any
+
+     # merge list of dicts, preserving order
+     param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+     if debug:
+         for i in range(world_size):
+             print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+         wanted_params = len(param_shapes)
+         wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+         # not asserting if there is a mismatch due to possible padding
+         avail_numel = fp32_flat_groups[0].numel() * world_size
+         print(f"Trainable params: Have {avail_numel} numels to process.")
+         print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+     # params
+     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+     # out-of-core computing solution
+     offset = 0
+     total_numel = 0
+     total_params = 0
+     for name, shape in param_shapes.items():
+
+         unpartitioned_numel = shape.numel()
+         total_numel += unpartitioned_numel
+         total_params += 1
+
+         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+         if debug:
+             print(
+                 f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+             )
+
+         # XXX: memory usage doubles here
+         state_dict[name] = torch.cat(
+             tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+             0).narrow(0, 0, unpartitioned_numel).view(shape)
+         offset += partitioned_numel
+
+     offset *= world_size
+
+     # Sanity check
+     if offset != avail_numel:
+         raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+     state_dict = OrderedDict()
+
+     # buffers
+     buffers = zero_model_states[0].buffers
+     state_dict.update(buffers)
+     if debug:
+         print(f"added {len(buffers)} buffers")
+
+     _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+     _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+     # recover shared parameters
+     for pair in zero_model_states[0].shared_params:
+         if pair[1] in state_dict:
+             state_dict[pair[0]] = state_dict[pair[1]]
+
+     return state_dict
+
+
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+     """
+     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+     via a model hub.
+
+     Args:
+         - ``checkpoint_dir``: path to the desired checkpoint folder
+         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+
+     Returns:
+         - pytorch ``state_dict``
+
+     Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+     you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+     the checkpoint.
+
+     A typical usage might be ::
+
+         from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+         # do the training and checkpoint saving
+         state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+         model = model.cpu() # move to cpu
+         model.load_state_dict(state_dict)
+         # submit to model hub or save the model to share with others
+
+     In this example the ``model`` will no longer be usable in the deepspeed context of the same
+     application. i.e. you will need to re-initialize the deepspeed engine, since
+     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+     If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+     """
+     if tag is None:
+         latest_path = os.path.join(checkpoint_dir, 'latest')
+         if os.path.isfile(latest_path):
+             with open(latest_path, 'r') as fd:
+                 tag = fd.read().strip()
+         else:
+             raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+     ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+     if not os.path.isdir(ds_checkpoint_dir):
+         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+     """
+     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+     Args:
+         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+         - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+     """
+
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+     print(f"Saving fp32 state dict to {output_file}")
+     torch.save(state_dict, output_file)
+
+
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+     """
+     1. Put the provided model to cpu
+     2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+     3. Load it into the provided model
+
+     Args:
+         - ``model``: the model object to update
+         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+     Returns:
+         - ``model`: modified model
+
+     Make sure you have plenty of CPU memory available before you call this function. If you don't
+     have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+     conveniently placed for you in the checkpoint folder.
+
+     A typical usage might be ::
+
+         from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+         model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+         # submit to model hub or save the model to share with others
+
+     Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+     of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+     """
+     logger.info(f"Extracting fp32 weights")
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+     logger.info(f"Overwriting model with fp32 weights")
+     model = model.cpu()
+     model.load_state_dict(state_dict, strict=False)
+
+     return model
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("checkpoint_dir",
+                         type=str,
+                         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+     parser.add_argument(
+         "output_file",
+         type=str,
+         help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+     parser.add_argument("-t",
+                         "--tag",
+                         type=str,
+                         default=None,
+                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+     args = parser.parse_args()
+
+     debug = args.debug
+
+     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
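
Note: because the commit ships both the latest tag file (global_step609) and this script, the fp32 export in pytorch_model.bin can be reproduced offline. A hedged usage sketch, assuming the DeepSpeed checkpoint directory from this run is present locally:

    # Run from the output directory that contains 'latest' and global_step609/.
    from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    # tag=None -> the script reads 'latest' and resolves it to 'global_step609'.
    convert_zero_checkpoint_to_fp32_state_dict(".", "pytorch_model.bin", tag=None)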