Problem with finetuning with Axolotl (qlora, lora, and FFT)

#1
Opened by Undi95

Hello, I'm having a hard time finetuning this model with Axolotl.
I use 1xA100 and 125GB of RAM. When Axolotl loads the model (in 4-bit, 8-bit, or full precision), RAM usage goes past 125GB before the model is even loaded into VRAM, so every attempt ends in an Out Of Memory crash.
Is there an issue with the model, or with the config? I don't really know.

I was excited because 9B with 200k context sounded great, but it's impossible to train anything on my side.
I tried the latest version of Axolotl, some older commits, etc., but no luck.

Maybe it's a problem on the Axolotl side too?
Thanks for any replies!

Edit:
6B works, though (batch size 1 for the test, QLoRA, loaded in 4-bit, 4096 ctx).


I had the same issue doing QLoRA at 4096 context. It crashed with a generic error, which is weird because I thought this model was Llama-compatible.

01-ai org

I don't have experience with Axolotl.
If Yi-6B-200K works fine, Yi-9B-200K may simply need more memory. Also, max_position_embeddings for Yi-9B-200K is set to 256K by default. Could you try lowering it to 200K and see if that solves the issue? https://huggingface.co/01-ai/Yi-9B-200K/blob/main/config.json#L13
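A rough back-of-the-envelope check of why max_position_embeddings could matter for host RAM. This rests on an assumption to verify against your installed version: that the transformers Llama implementation pre-allocates a dense boolean causal mask of shape max_position_embeddings × max_position_embeddings when the model is built (to my understanding, some releases around 4.38 did this). Under that assumption the mask alone grows quadratically with the setting, independent of 4-bit/8-bit quantization, and a multi-GPU launch holds one copy per process. The helper name is hypothetical.

```python
def square_bool_mask_gib(max_positions: int, processes: int = 1) -> float:
    """GiB used by a dense (max_positions x max_positions) torch.bool tensor.

    Assumes 1 byte per bool element and one copy per process; ignores any
    temporary copies made while the mask is being built.
    """
    bytes_per_copy = max_positions * max_positions  # 1 byte per element
    return bytes_per_copy * processes / 2**30


if __name__ == "__main__":
    # 262144 = 256K (the default mentioned above); the rest are lower candidates.
    for n in (262_144, 200_000, 32_768, 4_096):
        print(f"max_position_embeddings={n:>7}: {square_bool_mask_gib(n):8.2f} GiB per process")
```

Under this assumption, the 256K default works out to exactly 64 GiB per process before any weights are loaded, while 4096 needs only 16 MiB, which would be consistent with a 125GB machine running out of RAM regardless of the quantization setting.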

01-ai org

It has been confirmed to be compatible with Llama. Could you provide more complete crash logs?

Hi, thank you for your response. Unfortunately the log isn't very informative. I observed an extreme amount of RAM (not VRAM) use while the model was loading; it hung for several minutes and then crashed with the following:

torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -9) local_rank: 3 (pid: 13727) of binary: /root/miniconda3/envs/py3.10/bin/python3
Traceback (most recent call last):
  File "/root/miniconda3/envs/py3.10/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1014, in launch_command
    multi_gpu_launcher(args)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 672, in multi_gpu_launcher
    distrib_run.run(args)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

axolotl.cli.train FAILED

Failures:

Root Cause (first observed failure):
[0]:
  time      : 2024-03-16_19:34:31
  host      : ba783a85a837
  rank      : 3 (local_rank: 3)
  exitcode  : -9 (pid: 13727)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 13727
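The exit code in this log carries a bit more information than it seems. As its last line says, exitcode -9 means the worker was terminated by signal 9 (SIGKILL), the same negative-code convention Python's subprocess module uses for return codes. On Linux the kernel OOM killer sends SIGKILL when host memory runs out, which fits the RAM spike described above, though other things can send SIGKILL too. A small standard-library sketch for decoding such codes; the helper name is hypothetical:

```python
import signal


def describe_exit_code(code: int) -> str:
    """Turn a process exit code into a readable description.

    Negative codes follow the subprocess / torch.distributed convention:
    -N means the process was terminated by signal N.
    """
    if code >= 0:
        return f"exited with status {code}"
    signum = -code
    try:
        # Signal names are platform-specific; this lookup targets Linux/POSIX.
        name = signal.Signals(signum).name
    except ValueError:
        name = f"unknown signal {signum}"
    return f"killed by {name} (signal {signum})"


if __name__ == "__main__":
    print(describe_exit_code(-9))
```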

01-ai org

Unfortunately, we can't draw any conclusions from this. Does the Yi-9B base model work for you with QLoRA at a context length of 4096?
https://huggingface.co/01-ai/Yi-9B

I had similar issues when finetuning Yi-34B-200K on 24GB of VRAM. Set max_position_embeddings (I literally edit the config.json file) to something like 4096 or 32768 for finetuning, then change it back to the higher value for inference. With 16384 the resulting LoRA was broken, so not every value works. I think you need to pick a value that 01.ai used internally for pre-training; they probably scaled the context to bigger and bigger values over the course of pre-training. Also, make sure flash_attention_2 is installed and enabled. Most, if not all, of the long-context performance survives this once you revert the value after tuning.
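The workaround above (lowering max_position_embeddings in config.json for finetuning, then restoring it for inference) can be scripted so the original value is always written back. A minimal standard-library sketch; the helper name and the demo file are hypothetical, and 4096 is simply one of the values reported to work above. Note that rewriting the file with json.dumps may change its whitespace formatting.

```python
import json
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Union


@contextmanager
def temporary_max_positions(config_path: Union[str, Path], value: int) -> Iterator[int]:
    """Temporarily set max_position_embeddings in a local config.json.

    Yields the original value and restores it on exit, even if the body
    raises. Hypothetical helper, not part of Axolotl or transformers.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    original = config["max_position_embeddings"]
    config["max_position_embeddings"] = value
    path.write_text(json.dumps(config, indent=2))
    try:
        yield original
    finally:
        # Re-read so only this one key is reverted.
        config = json.loads(path.read_text())
        config["max_position_embeddings"] = original
        path.write_text(json.dumps(config, indent=2))


if __name__ == "__main__":
    # Self-contained demo on a throwaway config file.
    with tempfile.TemporaryDirectory() as tmp:
        cfg = Path(tmp) / "config.json"
        cfg.write_text(json.dumps({"max_position_embeddings": 262144}))
        with temporary_max_positions(cfg, 4096) as original:
            print("during finetuning:", json.loads(cfg.read_text())["max_position_embeddings"])
            # ... launch finetuning against this checkpoint directory here ...
        print("restored to:", json.loads(cfg.read_text())["max_position_embeddings"], "was", original)
```

A context manager keeps the revert step from being forgotten if the training run crashes partway through.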
