attempt to also run e2e tests that need GPUs (#1070)
* attempt to also run e2e tests that need GPUs
* fix stray quote
* checkout specific github ref
* dockerfile for tests with proper checkout
* ensure wandb is disabled for docker pytests
* clear wandb env after testing
* make sure to provide a default val for pop
* try skipping wandb validation tests
* explicitly disable wandb in the e2e tests
* explicitly set report_to to None to see if that fixes the docker e2e tests
* split gpu from non-gpu unit tests
* skip bf16 check in test for now
* build docker w/o cache since it uses branch name ref
* revert some changes now that caching is fixed
* skip bf16 check if on gpu w/ support
* pytest skip for auto-gptq requirements
* skip mamba tests for now, split multipack and non-packed lora llama tests
* split tests that use monkeypatches
* fix relative import for prev commit
* move other tests using monkeypatches to the correct run
- .github/workflows/tests-docker.yml +10 -2
- docker/Dockerfile-tests +40 -0
- tests/e2e/patched/__init__.py +0 -0
- tests/e2e/{test_fused_llama.py → patched/test_fused_llama.py} +1 -1
- tests/e2e/patched/test_lora_llama_multipack.py +126 -0
- tests/e2e/{test_mistral_samplepack.py → patched/test_mistral_samplepack.py} +1 -1
- tests/e2e/{test_mixtral_samplepack.py → patched/test_mixtral_samplepack.py} +1 -1
- tests/e2e/{test_model_patches.py → patched/test_model_patches.py} +1 -1
- tests/e2e/{test_resume.py → patched/test_resume.py} +2 -2
- tests/e2e/test_lora_llama.py +0 -93
- tests/e2e/test_mamba.py +5 -2
- tests/e2e/test_phi.py +10 -2
- tests/test_validation.py +17 -0
.github/workflows/tests-docker.yml (+10 -2)

```diff
@@ -36,11 +36,19 @@ jobs:
           PYTORCH_VERSION="${{ matrix.pytorch }}"
           # Build the Docker image
           docker build . \
-            --file ./docker/Dockerfile \
+            --file ./docker/Dockerfile-tests \
             --build-arg BASE_TAG=$BASE_TAG \
             --build-arg CUDA=$CUDA \
+            --build-arg GITHUB_REF=$GITHUB_REF \
             --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
-            --tag test-axolotl
+            --tag test-axolotl \
+            --no-cache
       - name: Unit Tests w docker image
         run: |
           docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
+      - name: GPU Unit Tests w docker image
+        run: |
+          docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm test-axolotl pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
+      - name: GPU Unit Tests monkeypatched w docker image
+        run: |
+          docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm test-axolotl pytest /workspace/axolotl/tests/e2e/patched/
```
docker/Dockerfile-tests (new file, +40)

```dockerfile
ARG BASE_TAG=main-base
FROM winglian/axolotl-base:$BASE_TAG

ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ARG GITHUB_REF="main"

ENV PYTORCH_VERSION=$PYTORCH_VERSION

RUN apt-get update && \
    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

WORKDIR /workspace

RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git

WORKDIR /workspace/axolotl

RUN git fetch origin +$GITHUB_REF && \
    git checkout FETCH_HEAD

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
    else \
        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
    fi

# So we can test the Docker image
RUN pip install pytest

# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
    git config --get remote.origin.fetch

# helper for huggingface-login cli
RUN git config --global credential.helper store
```
tests/e2e/patched/__init__.py: new, empty file (no content to show)
tests/e2e/test_fused_llama.py → tests/e2e/patched/test_fused_llama.py (+1 -1)

```diff
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
```
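The tests moved under tests/e2e/patched/ pull `with_temp_dir` from the shared `tests/e2e/utils.py` helpers; only the relative import depth changes. The helper's body is not part of this diff, so as a rough sketch only (not the repo's actual implementation), a decorator like this would hand each test a scratch directory and clean it up afterwards:

```python
import functools
import shutil
import tempfile


def with_temp_dir(test_func):
    """Hypothetical sketch: pass a fresh temp dir to the test and remove it afterwards."""

    @functools.wraps(test_func)
    def wrapper(*args, **kwargs):
        temp_dir = tempfile.mkdtemp()  # one scratch dir per test invocation
        try:
            return test_func(*args, temp_dir, **kwargs)
        finally:
            shutil.rmtree(temp_dir)  # clean up even when the test fails

    return wrapper
```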
tests/e2e/patched/test_lora_llama_multipack.py (new file, +126)

```python
"""
E2E tests for lora llama
"""

import logging
import os
import unittest
from pathlib import Path

import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestLoraLlama(unittest.TestCase):
    """
    Test case for Llama models using LoRA w multipack
    """

    @with_temp_dir
    def test_lora_packing(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "sequence_len": 1024,
                "sample_packing": True,
                "flash_attention": True,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 32,
                "lora_alpha": 64,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
            }
        )
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
    @with_temp_dir
    def test_lora_gptq_packed(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
                "model_type": "AutoModelForCausalLM",
                "tokenizer_type": "LlamaTokenizer",
                "sequence_len": 1024,
                "sample_packing": True,
                "flash_attention": True,
                "load_in_8bit": True,
                "adapter": "lora",
                "gptq": True,
                "gptq_disable_exllama": True,
                "lora_r": 32,
                "lora_alpha": 64,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "save_steps": 0.5,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()
```
tests/e2e/test_mistral_samplepack.py → tests/e2e/patched/test_mistral_samplepack.py (+1 -1)

```diff
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
```
tests/e2e/test_mixtral_samplepack.py → tests/e2e/patched/test_mixtral_samplepack.py (+1 -1)

```diff
@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
```
tests/e2e/test_model_patches.py → tests/e2e/patched/test_model_patches.py (+1 -1)

```diff
@@ -9,7 +9,7 @@ from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 
-from .utils import with_temp_dir
+from ..utils import with_temp_dir
 
 
 class TestModelPatches(unittest.TestCase):
```
tests/e2e/test_resume.py → tests/e2e/patched/test_resume.py (+2 -2)

```diff
@@ -17,7 +17,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import most_recent_subdir, with_temp_dir
+from ..utils import most_recent_subdir, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -29,7 +29,7 @@ class TestResumeLlama(unittest.TestCase):
     """
 
     @with_temp_dir
-    def …
+    def test_resume_qlora_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
```
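The resume test also imports `most_recent_subdir`, again from the shared helpers rather than this diff. Assuming it simply returns the newest subdirectory (e.g. the latest checkpoint dir) under a path, a hypothetical version might be:

```python
import os
from pathlib import Path


def most_recent_subdir(path):
    """Hypothetical sketch: return the subdirectory of `path` with the newest mtime, or None."""
    subdirs = [d for d in Path(path).iterdir() if d.is_dir()]
    if not subdirs:
        return None
    # latest modification time wins, which usually means the newest checkpoint
    return max(subdirs, key=os.path.getmtime)
```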
tests/e2e/test_lora_llama.py (-93): the packed and GPTQ variants move to tests/e2e/patched/test_lora_llama_multipack.py

```diff
@@ -65,96 +65,3 @@ class TestLoraLlama(unittest.TestCase):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
-
-    @with_temp_dir
-    def test_lora_packing(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "sample_packing": True,
-                "flash_attention": True,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 32,
-                "lora_alpha": 64,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 2,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-                "bf16": True,
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
-
-    @with_temp_dir
-    def test_lora_gptq(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
-                "model_type": "AutoModelForCausalLM",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "sample_packing": True,
-                "flash_attention": True,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "gptq": True,
-                "gptq_disable_exllama": True,
-                "lora_r": 32,
-                "lora_alpha": 64,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 2,
-                "save_steps": 0.5,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()
```
tests/e2e/test_mamba.py (+5 -2)

```diff
@@ -7,6 +7,8 @@ import os
 import unittest
 from pathlib import Path
 
+import pytest
+
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -19,9 +21,10 @@ LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-…
+@pytest.mark.skip(reason="skipping until upstreamed into transformers")
+class TestMamba(unittest.TestCase):
     """
-    Test case for …
+    Test case for Mamba models
     """
 
     @with_temp_dir
```
tests/e2e/test_phi.py (+10 -2)

```diff
@@ -8,6 +8,7 @@ import unittest
 from pathlib import Path
 
 import pytest
+from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -59,7 +60,6 @@ class TestPhi(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "paged_adamw_8bit",
                 "lr_scheduler": "cosine",
-                "bf16": True,
                 "flash_attention": True,
                 "max_steps": 10,
                 "save_steps": 10,
@@ -67,6 +67,10 @@ class TestPhi(unittest.TestCase):
                 "save_safetensors": True,
             }
         )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -110,9 +114,13 @@ class TestPhi(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
-                "bf16": True,
             }
         )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
```
tests/test_validation.py (+17)

```diff
@@ -6,6 +6,7 @@ import unittest
 from typing import Optional
 
 import pytest
+from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
@@ -354,6 +355,10 @@ class ValidationTest(unittest.TestCase):
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
 
+    @pytest.mark.skipif(
+        is_torch_bf16_gpu_available(),
+        reason="test should only run on gpus w/o bf16 support",
+    )
     def test_merge_lora_no_bf16_fail(self):
         """
         This is assumed to be run on a CPU machine, so bf16 is not supported.
@@ -778,6 +783,15 @@ class ValidationWandbTest(ValidationTest):
         assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
         assert os.environ.get("WANDB_DISABLED", "") != "true"
 
+        os.environ.pop("WANDB_PROJECT", None)
+        os.environ.pop("WANDB_NAME", None)
+        os.environ.pop("WANDB_RUN_ID", None)
+        os.environ.pop("WANDB_ENTITY", None)
+        os.environ.pop("WANDB_MODE", None)
+        os.environ.pop("WANDB_WATCH", None)
+        os.environ.pop("WANDB_LOG_MODEL", None)
+        os.environ.pop("WANDB_DISABLED", None)
+
     def test_wandb_set_disabled(self):
         cfg = DictDefault({})
 
@@ -798,3 +812,6 @@ class ValidationWandbTest(ValidationTest):
         setup_wandb_env_vars(cfg)
 
         assert os.environ.get("WANDB_DISABLED", "") != "true"
+
+        os.environ.pop("WANDB_PROJECT", None)
+        os.environ.pop("WANDB_DISABLED", None)
```