make sure to cleanup tmp output_dir for e2e tests

This change replaces the ad-hoc `tempfile.mkdtemp()` calls scattered across the e2e tests with a shared `with_temp_dir` decorator that creates a temporary output directory, passes it into the test, and removes it once the test finishes.

Files changed:

- tests/e2e/test_fused_llama.py +3 -3
- tests/e2e/test_lora_llama.py +7 -7
- tests/e2e/test_mistral.py +5 -5
- tests/e2e/test_mistral_samplepack.py +5 -5
- tests/e2e/test_phi.py +10 -5
- tests/utils.py +22 -0
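Before this change, each test allocated its own directory and never deleted it; a minimal sketch of the leaked pattern (illustrative, condensed from the diffs below):

```python
import tempfile

# pattern being removed: every e2e test did roughly this
output_dir = tempfile.mkdtemp()  # e.g. /tmp/tmpa1b2c3d4
# ... train() writes checkpoints and adapter weights into output_dir ...
# no shutil.rmtree(output_dir) afterwards, so each test run leaves a
# directory full of model artifacts behind in the system temp path
```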
 
    	
tests/e2e/test_fused_llama.py CHANGED

```diff
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +25,9 @@ class TestFusedLlama(unittest.TestCase):
     Test case for Llama models using Fused layers
     """

-    def test_fft_packing(self):
+    @with_temp_dir
+    def test_fft_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
```
tests/e2e/test_lora_llama.py CHANGED

```diff
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -13,6 +12,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -23,9 +23,9 @@ class TestLoraLlama(unittest.TestCase):
     Test case for Llama models using LoRA
     """

-    def test_lora(self):
+    @with_temp_dir
+    def test_lora(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -65,9 +65,9 @@ class TestLoraLlama(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()

-    def test_lora_packing(self):
+    @with_temp_dir
+    def test_lora_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -109,9 +109,9 @@ class TestLoraLlama(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()

-    def test_lora_gptq(self):
+    @with_temp_dir
+    def test_lora_gptq(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
```
tests/e2e/test_mistral.py CHANGED

```diff
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
     Test case for Llama models using LoRA
     """

-    def test_lora(self):
+    @with_temp_dir
+    def test_lora(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -70,9 +70,9 @@ class TestMistral(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()

-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
```
tests/e2e/test_mistral_samplepack.py CHANGED

```diff
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
 from pathlib import Path

@@ -15,6 +14,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +25,9 @@ class TestMistral(unittest.TestCase):
     Test case for Llama models using LoRA
     """

-    def test_lora_packing(self):
+    @with_temp_dir
+    def test_lora_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -71,9 +71,9 @@ class TestMistral(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()

-    def test_ft_packing(self):
+    @with_temp_dir
+    def test_ft_packing(self, output_dir):
         # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "openaccess-ai-collective/tiny-mistral",
```
tests/e2e/test_phi.py CHANGED

```diff
@@ -4,14 +4,15 @@ E2E tests for lora llama

 import logging
 import os
-import tempfile
 import unittest
+from pathlib import Path

 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
+from tests.utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -22,7 +23,8 @@ class TestPhi(unittest.TestCase):
     Test case for Llama models using LoRA
     """

-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, output_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -52,7 +54,7 @@ class TestPhi(unittest.TestCase):
             "num_epochs": 1,
             "micro_batch_size": 1,
             "gradient_accumulation_steps": 1,
-            "output_dir": tempfile.mkdtemp(),
+            "output_dir": output_dir,
             "learning_rate": 0.00001,
             "optimizer": "adamw_bnb_8bit",
             "lr_scheduler": "cosine",
@@ -64,8 +66,10 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()

-    def test_ft_packed(self):
+    @with_temp_dir
+    def test_ft_packed(self, output_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -95,7 +99,7 @@ class TestPhi(unittest.TestCase):
             "num_epochs": 1,
             "micro_batch_size": 1,
             "gradient_accumulation_steps": 1,
-            "output_dir": tempfile.mkdtemp(),
+            "output_dir": output_dir,
             "learning_rate": 0.00001,
             "optimizer": "adamw_bnb_8bit",
             "lr_scheduler": "cosine",
@@ -107,3 +111,4 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()
```
tests/utils.py ADDED

```diff
@@ -0,0 +1,22 @@
+"""
+helper utils for tests
+"""
+
+import shutil
+import tempfile
+from functools import wraps
+
+
+def with_temp_dir(test_func):
+    @wraps(test_func)
+    def wrapper(*args, **kwargs):
+        # Create a temporary directory
+        temp_dir = tempfile.mkdtemp()
+        try:
+            # Pass the temporary directory to the test function
+            # (after *args, so `self` stays the first positional argument)
+            test_func(*args, temp_dir, **kwargs)
+        finally:
+            # Clean up the directory after the test
+            shutil.rmtree(temp_dir)
+
+    return wrapper
```
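For reference, a minimal self-contained sketch of how the decorator behaves at test time; `DummyTest` is illustrative and not part of the commit, and the decorator body is inlined from tests/utils.py above so the sketch runs standalone:

```python
import shutil
import tempfile
import unittest
from functools import wraps
from pathlib import Path


def with_temp_dir(test_func):
    """Same decorator as tests/utils.py above, inlined for a standalone demo."""

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        temp_dir = tempfile.mkdtemp()
        try:
            # `self` arrives in *args; temp_dir is appended as the next
            # positional argument, matching `def test_x(self, output_dir)`
            test_func(*args, temp_dir, **kwargs)
        finally:
            shutil.rmtree(temp_dir)

    return wrapper


class DummyTest(unittest.TestCase):
    """Illustrative only: shows the injected dir and the automatic cleanup."""

    @with_temp_dir
    def test_output_is_cleaned_up(self, output_dir):
        artifact = Path(output_dir) / "adapter_model.bin"
        artifact.write_bytes(b"\x00")
        assert artifact.exists()
        # once this method returns, the decorator's finally block deletes
        # output_dir, so nothing accumulates in /tmp across test runs


if __name__ == "__main__":
    unittest.main()
```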