make sure to cleanup tmp output_dir for e2e tests
Browse files- tests/e2e/test_fused_llama.py +3 -3
- tests/e2e/test_lora_llama.py +7 -7
- tests/e2e/test_mistral.py +5 -5
- tests/e2e/test_mistral_samplepack.py +5 -5
- tests/e2e/test_phi.py +10 -5
- tests/utils.py +22 -0
tests/e2e/test_fused_llama.py
CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
7 |
-
import tempfile
|
8 |
import unittest
|
9 |
from pathlib import Path
|
10 |
|
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
15 |
from axolotl.train import train
|
16 |
from axolotl.utils.config import normalize_config
|
17 |
from axolotl.utils.dict import DictDefault
|
|
|
18 |
|
19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
os.environ["WANDB_DISABLED"] = "true"
|
@@ -25,9 +25,9 @@ class TestFusedLlama(unittest.TestCase):
|
|
25 |
Test case for Llama models using Fused layers
|
26 |
"""
|
27 |
|
28 |
-
|
|
|
29 |
# pylint: disable=duplicate-code
|
30 |
-
output_dir = tempfile.mkdtemp()
|
31 |
cfg = DictDefault(
|
32 |
{
|
33 |
"base_model": "JackFram/llama-68m",
|
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
|
|
7 |
import unittest
|
8 |
from pathlib import Path
|
9 |
|
|
|
14 |
from axolotl.train import train
|
15 |
from axolotl.utils.config import normalize_config
|
16 |
from axolotl.utils.dict import DictDefault
|
17 |
+
from tests.utils import with_temp_dir
|
18 |
|
19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
25 |
Test case for Llama models using Fused layers
|
26 |
"""
|
27 |
|
28 |
+
@with_temp_dir
|
29 |
+
def test_fft_packing(self, output_dir):
|
30 |
# pylint: disable=duplicate-code
|
|
|
31 |
cfg = DictDefault(
|
32 |
{
|
33 |
"base_model": "JackFram/llama-68m",
|
tests/e2e/test_lora_llama.py
CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
7 |
-
import tempfile
|
8 |
import unittest
|
9 |
from pathlib import Path
|
10 |
|
@@ -13,6 +12,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
13 |
from axolotl.train import train
|
14 |
from axolotl.utils.config import normalize_config
|
15 |
from axolotl.utils.dict import DictDefault
|
|
|
16 |
|
17 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
18 |
os.environ["WANDB_DISABLED"] = "true"
|
@@ -23,9 +23,9 @@ class TestLoraLlama(unittest.TestCase):
|
|
23 |
Test case for Llama models using LoRA
|
24 |
"""
|
25 |
|
26 |
-
|
|
|
27 |
# pylint: disable=duplicate-code
|
28 |
-
output_dir = tempfile.mkdtemp()
|
29 |
cfg = DictDefault(
|
30 |
{
|
31 |
"base_model": "JackFram/llama-68m",
|
@@ -65,9 +65,9 @@ class TestLoraLlama(unittest.TestCase):
|
|
65 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
66 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
67 |
|
68 |
-
|
|
|
69 |
# pylint: disable=duplicate-code
|
70 |
-
output_dir = tempfile.mkdtemp()
|
71 |
cfg = DictDefault(
|
72 |
{
|
73 |
"base_model": "JackFram/llama-68m",
|
@@ -109,9 +109,9 @@ class TestLoraLlama(unittest.TestCase):
|
|
109 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
110 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
111 |
|
112 |
-
|
|
|
113 |
# pylint: disable=duplicate-code
|
114 |
-
output_dir = tempfile.mkdtemp()
|
115 |
cfg = DictDefault(
|
116 |
{
|
117 |
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
|
|
7 |
import unittest
|
8 |
from pathlib import Path
|
9 |
|
|
|
12 |
from axolotl.train import train
|
13 |
from axolotl.utils.config import normalize_config
|
14 |
from axolotl.utils.dict import DictDefault
|
15 |
+
from tests.utils import with_temp_dir
|
16 |
|
17 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
18 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
23 |
Test case for Llama models using LoRA
|
24 |
"""
|
25 |
|
26 |
+
@with_temp_dir
|
27 |
+
def test_lora(self, output_dir):
|
28 |
# pylint: disable=duplicate-code
|
|
|
29 |
cfg = DictDefault(
|
30 |
{
|
31 |
"base_model": "JackFram/llama-68m",
|
|
|
65 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
66 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
67 |
|
68 |
+
@with_temp_dir
|
69 |
+
def test_lora_packing(self, output_dir):
|
70 |
# pylint: disable=duplicate-code
|
|
|
71 |
cfg = DictDefault(
|
72 |
{
|
73 |
"base_model": "JackFram/llama-68m",
|
|
|
109 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
110 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
111 |
|
112 |
+
@with_temp_dir
|
113 |
+
def test_lora_gptq(self, output_dir):
|
114 |
# pylint: disable=duplicate-code
|
|
|
115 |
cfg = DictDefault(
|
116 |
{
|
117 |
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
tests/e2e/test_mistral.py
CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
7 |
-
import tempfile
|
8 |
import unittest
|
9 |
from pathlib import Path
|
10 |
|
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
15 |
from axolotl.train import train
|
16 |
from axolotl.utils.config import normalize_config
|
17 |
from axolotl.utils.dict import DictDefault
|
|
|
18 |
|
19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
os.environ["WANDB_DISABLED"] = "true"
|
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
|
|
25 |
Test case for Llama models using LoRA
|
26 |
"""
|
27 |
|
28 |
-
|
|
|
29 |
# pylint: disable=duplicate-code
|
30 |
-
output_dir = tempfile.mkdtemp()
|
31 |
cfg = DictDefault(
|
32 |
{
|
33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
@@ -70,9 +70,9 @@ class TestMistral(unittest.TestCase):
|
|
70 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
71 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
72 |
|
73 |
-
|
|
|
74 |
# pylint: disable=duplicate-code
|
75 |
-
output_dir = tempfile.mkdtemp()
|
76 |
cfg = DictDefault(
|
77 |
{
|
78 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
|
|
7 |
import unittest
|
8 |
from pathlib import Path
|
9 |
|
|
|
14 |
from axolotl.train import train
|
15 |
from axolotl.utils.config import normalize_config
|
16 |
from axolotl.utils.dict import DictDefault
|
17 |
+
from tests.utils import with_temp_dir
|
18 |
|
19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
25 |
Test case for Llama models using LoRA
|
26 |
"""
|
27 |
|
28 |
+
@with_temp_dir
|
29 |
+
def test_lora(self, output_dir):
|
30 |
# pylint: disable=duplicate-code
|
|
|
31 |
cfg = DictDefault(
|
32 |
{
|
33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
70 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
71 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
72 |
|
73 |
+
@with_temp_dir
|
74 |
+
def test_ft(self, output_dir):
|
75 |
# pylint: disable=duplicate-code
|
|
|
76 |
cfg = DictDefault(
|
77 |
{
|
78 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
tests/e2e/test_mistral_samplepack.py
CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
7 |
-
import tempfile
|
8 |
import unittest
|
9 |
from pathlib import Path
|
10 |
|
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
|
|
15 |
from axolotl.train import train
|
16 |
from axolotl.utils.config import normalize_config
|
17 |
from axolotl.utils.dict import DictDefault
|
|
|
18 |
|
19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
os.environ["WANDB_DISABLED"] = "true"
|
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
|
|
25 |
Test case for Llama models using LoRA
|
26 |
"""
|
27 |
|
28 |
-
|
|
|
29 |
# pylint: disable=duplicate-code
|
30 |
-
output_dir = tempfile.mkdtemp()
|
31 |
cfg = DictDefault(
|
32 |
{
|
33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
@@ -71,9 +71,9 @@ class TestMistral(unittest.TestCase):
|
|
71 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
72 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
73 |
|
74 |
-
|
|
|
75 |
# pylint: disable=duplicate-code
|
76 |
-
output_dir = tempfile.mkdtemp()
|
77 |
cfg = DictDefault(
|
78 |
{
|
79 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
|
|
7 |
import unittest
|
8 |
from pathlib import Path
|
9 |
|
|
|
14 |
from axolotl.train import train
|
15 |
from axolotl.utils.config import normalize_config
|
16 |
from axolotl.utils.dict import DictDefault
|
17 |
+
from tests.utils import with_temp_dir
|
18 |
|
19 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
25 |
Test case for Llama models using LoRA
|
26 |
"""
|
27 |
|
28 |
+
@with_temp_dir
|
29 |
+
def test_lora_packing(self, output_dir):
|
30 |
# pylint: disable=duplicate-code
|
|
|
31 |
cfg = DictDefault(
|
32 |
{
|
33 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
|
|
71 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
72 |
assert (Path(output_dir) / "adapter_model.bin").exists()
|
73 |
|
74 |
+
@with_temp_dir
|
75 |
+
def test_ft_packing(self, output_dir):
|
76 |
# pylint: disable=duplicate-code
|
|
|
77 |
cfg = DictDefault(
|
78 |
{
|
79 |
"base_model": "openaccess-ai-collective/tiny-mistral",
|
tests/e2e/test_phi.py
CHANGED
@@ -4,14 +4,15 @@ E2E tests for lora llama
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
7 |
-
import tempfile
|
8 |
import unittest
|
|
|
9 |
|
10 |
from axolotl.cli import load_datasets
|
11 |
from axolotl.common.cli import TrainerCliArgs
|
12 |
from axolotl.train import train
|
13 |
from axolotl.utils.config import normalize_config
|
14 |
from axolotl.utils.dict import DictDefault
|
|
|
15 |
|
16 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
17 |
os.environ["WANDB_DISABLED"] = "true"
|
@@ -22,7 +23,8 @@ class TestPhi(unittest.TestCase):
|
|
22 |
Test case for Llama models using LoRA
|
23 |
"""
|
24 |
|
25 |
-
|
|
|
26 |
# pylint: disable=duplicate-code
|
27 |
cfg = DictDefault(
|
28 |
{
|
@@ -52,7 +54,7 @@ class TestPhi(unittest.TestCase):
|
|
52 |
"num_epochs": 1,
|
53 |
"micro_batch_size": 1,
|
54 |
"gradient_accumulation_steps": 1,
|
55 |
-
"output_dir":
|
56 |
"learning_rate": 0.00001,
|
57 |
"optimizer": "adamw_bnb_8bit",
|
58 |
"lr_scheduler": "cosine",
|
@@ -64,8 +66,10 @@ class TestPhi(unittest.TestCase):
|
|
64 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
65 |
|
66 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
67 |
|
68 |
-
|
|
|
69 |
# pylint: disable=duplicate-code
|
70 |
cfg = DictDefault(
|
71 |
{
|
@@ -95,7 +99,7 @@ class TestPhi(unittest.TestCase):
|
|
95 |
"num_epochs": 1,
|
96 |
"micro_batch_size": 1,
|
97 |
"gradient_accumulation_steps": 1,
|
98 |
-
"output_dir":
|
99 |
"learning_rate": 0.00001,
|
100 |
"optimizer": "adamw_bnb_8bit",
|
101 |
"lr_scheduler": "cosine",
|
@@ -107,3 +111,4 @@ class TestPhi(unittest.TestCase):
|
|
107 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
108 |
|
109 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
|
4 |
|
5 |
import logging
|
6 |
import os
|
|
|
7 |
import unittest
|
8 |
+
from pathlib import Path
|
9 |
|
10 |
from axolotl.cli import load_datasets
|
11 |
from axolotl.common.cli import TrainerCliArgs
|
12 |
from axolotl.train import train
|
13 |
from axolotl.utils.config import normalize_config
|
14 |
from axolotl.utils.dict import DictDefault
|
15 |
+
from tests.utils import with_temp_dir
|
16 |
|
17 |
LOG = logging.getLogger("axolotl.tests.e2e")
|
18 |
os.environ["WANDB_DISABLED"] = "true"
|
|
|
23 |
Test case for Llama models using LoRA
|
24 |
"""
|
25 |
|
26 |
+
@with_temp_dir
|
27 |
+
def test_ft(self, output_dir):
|
28 |
# pylint: disable=duplicate-code
|
29 |
cfg = DictDefault(
|
30 |
{
|
|
|
54 |
"num_epochs": 1,
|
55 |
"micro_batch_size": 1,
|
56 |
"gradient_accumulation_steps": 1,
|
57 |
+
"output_dir": output_dir,
|
58 |
"learning_rate": 0.00001,
|
59 |
"optimizer": "adamw_bnb_8bit",
|
60 |
"lr_scheduler": "cosine",
|
|
|
66 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
67 |
|
68 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
69 |
+
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
70 |
|
71 |
+
@with_temp_dir
|
72 |
+
def test_ft_packed(self, output_dir):
|
73 |
# pylint: disable=duplicate-code
|
74 |
cfg = DictDefault(
|
75 |
{
|
|
|
99 |
"num_epochs": 1,
|
100 |
"micro_batch_size": 1,
|
101 |
"gradient_accumulation_steps": 1,
|
102 |
+
"output_dir": output_dir,
|
103 |
"learning_rate": 0.00001,
|
104 |
"optimizer": "adamw_bnb_8bit",
|
105 |
"lr_scheduler": "cosine",
|
|
|
111 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
112 |
|
113 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
114 |
+
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
tests/utils.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
helper utils for tests
|
3 |
+
"""
|
4 |
+
|
5 |
+
import shutil
|
6 |
+
import tempfile
|
7 |
+
from functools import wraps
|
8 |
+
|
9 |
+
|
10 |
+
def with_temp_dir(test_func):
|
11 |
+
@wraps(test_func)
|
12 |
+
def wrapper(*args, **kwargs):
|
13 |
+
# Create a temporary directory
|
14 |
+
temp_dir = tempfile.mkdtemp()
|
15 |
+
try:
|
16 |
+
# Pass the temporary directory to the test function
|
17 |
+
test_func(temp_dir, *args, **kwargs)
|
18 |
+
finally:
|
19 |
+
# Clean up the directory after the test
|
20 |
+
shutil.rmtree(temp_dir)
|
21 |
+
|
22 |
+
return wrapper
|