winglian committed
Commit 0402d19
1 Parent(s): b2430ce

make sure to cleanup tmp output_dir for e2e tests

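The change is the same in every test file: drop the per-test `tempfile.mkdtemp()` call, decorate the test with `@with_temp_dir`, and accept the directory as an `output_dir` argument so it is removed once the test finishes. A minimal sketch of the new pattern, assuming `tests.utils` is importable from the repo root; the class and test names here are hypothetical, and the real tests build a `DictDefault` config and call `train()` instead of writing a file:

import os
import unittest

from tests.utils import with_temp_dir


class TestExample(unittest.TestCase):
    """
    Hypothetical test illustrating the pattern introduced by this commit.
    """

    @with_temp_dir
    def test_writes_into_managed_dir(self, output_dir):
        # before this commit each test called tempfile.mkdtemp() itself and the
        # directory was never removed; the decorator now owns its lifetime
        artifact = os.path.join(output_dir, "adapter_model.bin")
        with open(artifact, "wb") as fout:
            fout.write(b"stub")
        assert os.path.exists(artifact)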
tests/e2e/test_fused_llama.py CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
 
 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path
 
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +25,9 @@ class TestFusedLlama(unittest.TestCase):
     Test case for Llama models using Fused layers
     """
 
-    def test_fft_packing(self):
+    @with_temp_dir
+    def test_fft_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
tests/e2e/test_lora_llama.py CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
 
 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path
 
@@ -13,6 +12,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -23,9 +23,9 @@ class TestLoraLlama(unittest.TestCase):
     Test case for Llama models using LoRA
     """
 
-    def test_lora(self):
+    @with_temp_dir
+    def test_lora(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -65,9 +65,9 @@ class TestLoraLlama(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()
 
-    def test_lora_packing(self):
+    @with_temp_dir
+    def test_lora_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -109,9 +109,9 @@ class TestLoraLlama(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()
 
-    def test_lora_gptq(self):
+    @with_temp_dir
+    def test_lora_gptq(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
tests/e2e/test_mistral.py CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
 
 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path
 
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
     Test case for Llama models using LoRA
     """
 
-    def test_lora(self):
+    @with_temp_dir
+    def test_lora(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -70,9 +70,9 @@ class TestMistral(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()
 
-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
tests/e2e/test_mistral_samplepack.py CHANGED
@@ -4,7 +4,6 @@ E2E tests for lora llama
 
 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path
 
@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
     Test case for Llama models using LoRA
     """
 
-    def test_lora_packing(self):
+    @with_temp_dir
+    def test_lora_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -71,9 +71,9 @@ class TestMistral(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()
 
-    def test_ft_packing(self):
+    @with_temp_dir
+    def test_ft_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
tests/e2e/test_phi.py CHANGED
@@ -4,14 +4,15 @@ E2E tests for lora llama
 
 import logging
 import os
-import tempfile
 import unittest
+from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -22,7 +23,8 @@ class TestPhi(unittest.TestCase):
     Test case for Llama models using LoRA
     """
 
-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, output_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -52,7 +54,7 @@ class TestPhi(unittest.TestCase):
                 "num_epochs": 1,
                 "micro_batch_size": 1,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
@@ -64,8 +66,10 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()
 
-    def test_ft_packed(self):
+    @with_temp_dir
+    def test_ft_packed(self, output_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -95,7 +99,7 @@ class TestPhi(unittest.TestCase):
                 "num_epochs": 1,
                 "micro_batch_size": 1,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
@@ -107,3 +111,4 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()
tests/utils.py ADDED
@@ -0,0 +1,22 @@
+"""
+helper utils for tests
+"""
+
+import shutil
+import tempfile
+from functools import wraps
+
+
+def with_temp_dir(test_func):
+    @wraps(test_func)
+    def wrapper(*args, **kwargs):
+        # Create a temporary directory
+        temp_dir = tempfile.mkdtemp()
+        try:
+            # Pass the temporary directory to the test function, after self
+            test_func(*args, temp_dir, **kwargs)
+        finally:
+            # Clean up the directory after the test
+            shutil.rmtree(temp_dir)
+
+    return wrapper
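A quick way to sanity-check the decorator outside of the e2e suite is to call a wrapped method directly and confirm the directory exists during the call and is gone afterwards. This is a hypothetical probe, not part of the commit; it only assumes `tests.utils` is importable from the repo root:

import os

from tests.utils import with_temp_dir


class _Probe:
    """Hypothetical stand-in for a test class, used to exercise the decorator."""

    captured = None

    @with_temp_dir
    def run(self, output_dir):
        # the wrapper appends the temporary directory after self
        _Probe.captured = output_dir
        assert os.path.isdir(output_dir)


_Probe().run()
# once the wrapped call returns, the directory has been cleaned up
assert not os.path.isdir(_Probe.captured)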