|
|
|
"""Module for testing the validation module""" |
|
|
|
import logging |
|
import os |
|
import warnings |
|
from typing import Optional |
|
|
|
import pytest |
|
from pydantic import ValidationError |
|
|
|
from axolotl.utils.config import validate_config |
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities |
|
from axolotl.utils.dict import DictDefault |
|
from axolotl.utils.models import check_model_config |
|
from axolotl.utils.wandb_ import setup_wandb_env_vars |
|
|
|
warnings.filterwarnings("error") |
|
|
|
|
|
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
    """Build the baseline config that the tests in this module extend.

    It holds a base model, a learning rate, one alpaca-typed dataset, and the
    micro batch size / gradient accumulation settings.
    """
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }
    return DictDefault(settings)
|
|
|
|
|
class BaseValidation:
    """
    Common parent of the validation test classes; exposes pytest's log
    capture on the test instance
    """

    # Placeholder until the autouse fixture below assigns the real fixture.
    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Store the ``caplog`` fixture on ``self`` so tests can inspect log records."""
        self._caplog = caplog
|
|
|
|
|
|
|
class TestValidation(BaseValidation):
    """
    Test the validation module

    Tests that check log output in several phases call ``self._caplog.clear()``
    between phases: captured records accumulate for the whole test, so without
    clearing, a later assertion could be satisfied by an earlier phase's warning.
    """

    def test_defaults(self, minimal_cfg):
        """``train_on_inputs`` validates to False and a None ``weight_decay`` stays None."""
        test_cfg = DictDefault({"weight_decay": None}) | minimal_cfg
        cfg = validate_config(test_cfg)

        assert cfg.train_on_inputs is False
        assert cfg.weight_decay is None

    def test_datasets_min_length(self):
        """An empty ``datasets`` list must fail validation."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "datasets": [],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

        with pytest.raises(
            ValidationError,
            match=r".*List should have at least 1 item after validation*",
        ):
            validate_config(cfg)

    def test_datasets_min_length_empty(self):
        """A config with neither ``datasets`` nor ``pretraining_dataset`` must fail."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

        with pytest.raises(
            ValueError, match=r".*either datasets or pretraining_dataset is required*"
        ):
            validate_config(cfg)

    def test_pretrain_dataset_min_length(self):
        """An empty ``pretraining_dataset`` list must fail validation."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "pretraining_dataset": [],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "max_steps": 100,
            }
        )

        with pytest.raises(
            ValidationError,
            match=r".*List should have at least 1 item after validation*",
        ):
            validate_config(cfg)

    def test_valid_pretrain_dataset(self):
        """A config with one pretraining dataset and ``max_steps`` validates."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "pretraining_dataset": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    }
                ],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "max_steps": 100,
            }
        )

        validate_config(cfg)

    def test_valid_sft_dataset(self):
        """A config with one regular dataset validates."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    }
                ],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

        validate_config(cfg)

    def test_batch_size_unused_warning(self):
        """Setting ``batch_size`` together with ``micro_batch_size`` logs a warning."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    }
                ],
                "micro_batch_size": 4,
                "batch_size": 32,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            # Search all records rather than indexing [0], so an unrelated
            # earlier warning cannot break the test.
            assert any(
                "batch_size is not recommended" in record.message
                for record in self._caplog.records
            )

    def test_batch_size_more_params(self):
        """``batch_size`` alone (no other batch setting) must fail validation."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    }
                ],
                "batch_size": 32,
            }
        )

        with pytest.raises(ValueError, match=r".*At least two of*"):
            validate_config(cfg)

    def test_lr_as_float(self, minimal_cfg):
        """A string ``learning_rate`` is converted to the equivalent float."""
        cfg = DictDefault({"learning_rate": "5e-5"}) | minimal_cfg

        new_cfg = validate_config(cfg)

        assert new_cfg.learning_rate == 0.00005

    def test_model_config_remap(self, minimal_cfg):
        """``model_config`` is exposed as ``overrides_of_model_config`` after validation."""
        cfg = DictDefault({"model_config": {"model_type": "mistral"}}) | minimal_cfg

        new_cfg = validate_config(cfg)
        assert new_cfg.overrides_of_model_config["model_type"] == "mistral"

    def test_model_type_remap(self, minimal_cfg):
        """``model_type`` is exposed as ``type_of_model`` after validation."""
        cfg = DictDefault({"model_type": "AutoModelForCausalLM"}) | minimal_cfg

        new_cfg = validate_config(cfg)
        assert new_cfg.type_of_model == "AutoModelForCausalLM"

    def test_model_revision_remap(self, minimal_cfg):
        """``model_revision`` is exposed as ``revision_of_model`` after validation."""
        cfg = DictDefault({"model_revision": "main"}) | minimal_cfg

        new_cfg = validate_config(cfg)
        assert new_cfg.revision_of_model == "main"

    def test_qlora(self, minimal_cfg):
        """qlora rejects 8bit, gptq and ``load_in_4bit: False``; 4bit loading passes."""
        base_cfg = DictDefault({"adapter": "qlora"}) | minimal_cfg

        cfg = DictDefault({"load_in_8bit": True}) | base_cfg

        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = DictDefault({"gptq": True}) | base_cfg

        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = DictDefault({"load_in_4bit": False}) | base_cfg

        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

        cfg = DictDefault({"load_in_4bit": True}) | base_cfg

        validate_config(cfg)

    def test_qlora_merge(self, minimal_cfg):
        """With ``merge_lora`` set, qlora rejects 8bit, gptq and 4bit loading."""
        base_cfg = DictDefault({"adapter": "qlora", "merge_lora": True}) | minimal_cfg

        cfg = DictDefault({"load_in_8bit": True}) | base_cfg

        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = DictDefault({"gptq": True}) | base_cfg

        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = DictDefault({"load_in_4bit": True}) | base_cfg

        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

    def test_hf_use_auth_token(self, minimal_cfg):
        """``push_dataset_to_hub`` is only accepted together with ``hf_use_auth_token``."""
        cfg = DictDefault({"push_dataset_to_hub": "namespace/repo"}) | minimal_cfg

        with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
            validate_config(cfg)

        cfg = (
            DictDefault(
                {
                    "push_dataset_to_hub": "namespace/repo",
                    "hf_use_auth_token": True,
                }
            )
            | minimal_cfg
        )
        validate_config(cfg)

    def test_gradient_accumulations_or_batch_size(self):
        """``gradient_accumulation_steps`` combined with ``batch_size`` must fail."""
        cfg = DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
                "learning_rate": 0.000001,
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    }
                ],
                "gradient_accumulation_steps": 1,
                "batch_size": 1,
            }
        )

        with pytest.raises(
            ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
        ):
            validate_config(cfg)

    def test_falcon_fsdp(self, minimal_cfg):
        """FSDP with a falcon base model fails; falcon without FSDP validates."""
        regex_exp = r".*FSDP is not supported for falcon models.*"
        fsdp_opts = ["full_shard", "auto_wrap"]

        cfg = (
            DictDefault({"base_model": "tiiuae/falcon-7b", "fsdp": fsdp_opts})
            | minimal_cfg
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        # Same check with a differently cased model name.
        cfg = DictDefault({"base_model": "Falcon-7b", "fsdp": fsdp_opts}) | minimal_cfg

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault({"base_model": "tiiuae/falcon-7b"}) | minimal_cfg

        validate_config(cfg)

    def test_mpt_gradient_checkpointing(self, minimal_cfg):
        """``gradient_checkpointing`` with an MPT base model must fail."""
        regex_exp = r".*gradient_checkpointing is not supported for MPT models*"

        cfg = (
            DictDefault(
                {
                    "base_model": "mosaicml/mpt-7b",
                    "gradient_checkpointing": True,
                }
            )
            | minimal_cfg
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_flash_optimum(self, minimal_cfg):
        """``flash_optimum`` warns about adapters / dtype and rejects fp16 and bf16."""
        cfg = (
            DictDefault(
                {
                    "flash_optimum": True,
                    "adapter": "lora",
                    "bf16": False,
                }
            )
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "BetterTransformers probably doesn't work with PEFT adapters"
                in record.message
                for record in self._caplog.records
            )

        # Drop the first phase's records so the next assertion can only be
        # satisfied by the warning emitted for the next config.
        self._caplog.clear()

        cfg = DictDefault({"flash_optimum": True, "bf16": False}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "probably set bfloat16 or float16" in record.message
                for record in self._caplog.records
            )

        regex_exp = r".*AMP is not supported.*"

        cfg = DictDefault({"flash_optimum": True, "fp16": True}) | minimal_cfg

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault({"flash_optimum": True, "bf16": True}) | minimal_cfg

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_adamw_hyperparams(self, minimal_cfg):
        """Adam hyperparameters without an adamw optimizer log a warning."""
        expected_warning = "adamw hyperparameters found, but no adamw optimizer set"

        cfg = DictDefault({"optimizer": None, "adam_epsilon": 0.0001}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                expected_warning in record.message for record in self._caplog.records
            )

        # Both phases expect the same message; without clearing, the second
        # assertion would pass on the first phase's record alone.
        self._caplog.clear()

        cfg = (
            DictDefault({"optimizer": "adafactor", "adam_beta1": 0.0001}) | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                expected_warning in record.message for record in self._caplog.records
            )

        cfg = (
            DictDefault(
                {
                    "optimizer": "adamw_bnb_8bit",
                    "adam_beta1": 0.9,
                    "adam_beta2": 0.99,
                    "adam_epsilon": 0.0001,
                }
            )
            | minimal_cfg
        )

        validate_config(cfg)

        cfg = DictDefault({"optimizer": "adafactor"}) | minimal_cfg

        validate_config(cfg)

    def test_deprecated_packing(self, minimal_cfg):
        """``max_packed_sequence_len`` surfaces a DeprecationWarning as an exception."""
        cfg = DictDefault({"max_packed_sequence_len": 1024}) | minimal_cfg
        with pytest.raises(
            DeprecationWarning,
            match=r"`max_packed_sequence_len` is no longer supported",
        ):
            validate_config(cfg)

    def test_packing(self, minimal_cfg):
        """``sample_packing`` without ``pad_to_sequence_len`` logs a recommendation."""
        cfg = (
            DictDefault(
                {
                    "sample_packing": True,
                    "pad_to_sequence_len": None,
                    "flash_attention": True,
                }
            )
            | minimal_cfg
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "`pad_to_sequence_len: true` is recommended when using sample_packing"
                in record.message
                for record in self._caplog.records
            )

    def test_merge_lora_no_bf16_fail(self, minimal_cfg):
        """
        This is assumed to be run on a CPU machine, so bf16 is not supported.
        """

        cfg = (
            DictDefault(
                {
                    "bf16": True,
                    "capabilities": {"bf16": False},
                }
            )
            | minimal_cfg
        )

        with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
            AxolotlConfigWCapabilities(**cfg.to_dict())

        cfg = (
            DictDefault(
                {
                    "bf16": True,
                    "merge_lora": True,
                    "capabilities": {"bf16": False},
                }
            )
            | minimal_cfg
        )

        validate_config(cfg)

    def test_sharegpt_deprecation(self, minimal_cfg):
        """Deprecated sharegpt dataset types warn and are rewritten to the new names."""
        cfg = (
            DictDefault(
                {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
            )
            | minimal_cfg
        )
        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert any(
                "`type: sharegpt:chat` will soon be deprecated." in record.message
                for record in self._caplog.records
            )
            assert new_cfg.datasets[0].type == "sharegpt"

        # Keep the two phases' log records separate.
        self._caplog.clear()

        cfg = (
            DictDefault(
                {
                    "datasets": [
                        {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
                    ]
                }
            )
            | minimal_cfg
        )
        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert any(
                "`type: sharegpt_simple` will soon be deprecated." in record.message
                for record in self._caplog.records
            )
            assert new_cfg.datasets[0].type == "sharegpt:load_role"

    def test_no_conflict_save_strategy(self, minimal_cfg):
        """``save_steps`` only combines with a ``steps`` (or unset) ``save_strategy``."""
        mismatch = r".*save_strategy and save_steps mismatch.*"

        cfg = DictDefault({"save_strategy": "epoch", "save_steps": 10}) | minimal_cfg

        with pytest.raises(ValueError, match=mismatch):
            validate_config(cfg)

        cfg = DictDefault({"save_strategy": "no", "save_steps": 10}) | minimal_cfg

        with pytest.raises(ValueError, match=mismatch):
            validate_config(cfg)

        cfg = DictDefault({"save_strategy": "steps"}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"save_strategy": "steps", "save_steps": 10}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"save_steps": 10}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"save_strategy": "no"}) | minimal_cfg

        validate_config(cfg)

    def test_no_conflict_eval_strategy(self, minimal_cfg):
        """Check ``evaluation_strategy`` / ``eval_steps`` / ``val_set_size`` combinations."""
        mismatch = r".*evaluation_strategy and eval_steps mismatch.*"
        no_val_set = r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*"

        cfg = (
            DictDefault({"evaluation_strategy": "epoch", "eval_steps": 10})
            | minimal_cfg
        )

        with pytest.raises(ValueError, match=mismatch):
            validate_config(cfg)

        cfg = (
            DictDefault({"evaluation_strategy": "no", "eval_steps": 10}) | minimal_cfg
        )

        with pytest.raises(ValueError, match=mismatch):
            validate_config(cfg)

        cfg = DictDefault({"evaluation_strategy": "steps"}) | minimal_cfg

        validate_config(cfg)

        cfg = (
            DictDefault({"evaluation_strategy": "steps", "eval_steps": 10})
            | minimal_cfg
        )

        validate_config(cfg)

        cfg = DictDefault({"eval_steps": 10}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"evaluation_strategy": "no"}) | minimal_cfg

        validate_config(cfg)

        cfg = (
            DictDefault({"evaluation_strategy": "epoch", "val_set_size": 0})
            | minimal_cfg
        )

        with pytest.raises(ValueError, match=no_val_set):
            validate_config(cfg)

        cfg = DictDefault({"eval_steps": 10, "val_set_size": 0}) | minimal_cfg

        with pytest.raises(ValueError, match=no_val_set):
            validate_config(cfg)

        cfg = DictDefault({"val_set_size": 0}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"eval_steps": 10, "val_set_size": 0.01}) | minimal_cfg

        validate_config(cfg)

        cfg = (
            DictDefault({"evaluation_strategy": "epoch", "val_set_size": 0.01})
            | minimal_cfg
        )

        validate_config(cfg)

    def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
        """``eval_table_size`` with sample packing needs ``eval_sample_packing: False``."""
        cfg = (
            DictDefault(
                {
                    "sample_packing": True,
                    "eval_table_size": 100,
                    "flash_attention": True,
                }
            )
            | minimal_cfg
        )

        with pytest.raises(
            ValueError, match=r".*Please set 'eval_sample_packing' to false.*"
        ):
            validate_config(cfg)

        cfg = (
            DictDefault(
                {
                    "sample_packing": True,
                    "eval_sample_packing": False,
                    "flash_attention": True,
                }
            )
            | minimal_cfg
        )

        validate_config(cfg)

        cfg = (
            DictDefault(
                {
                    "sample_packing": False,
                    "eval_table_size": 100,
                    "flash_attention": True,
                }
            )
            | minimal_cfg
        )

        validate_config(cfg)

        cfg = (
            DictDefault(
                {
                    "sample_packing": True,
                    "eval_table_size": 100,
                    "eval_sample_packing": False,
                    "flash_attention": True,
                }
            )
            | minimal_cfg
        )

        validate_config(cfg)

    def test_load_in_x_bit_without_adapter(self, minimal_cfg):
        """4bit / 8bit loading fails without an adapter and passes with one."""
        no_adapter = r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*"

        cfg = DictDefault({"load_in_4bit": True}) | minimal_cfg

        with pytest.raises(ValueError, match=no_adapter):
            validate_config(cfg)

        cfg = DictDefault({"load_in_8bit": True}) | minimal_cfg

        with pytest.raises(ValueError, match=no_adapter):
            validate_config(cfg)

        cfg = DictDefault({"load_in_4bit": True, "adapter": "qlora"}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"load_in_8bit": True, "adapter": "lora"}) | minimal_cfg

        validate_config(cfg)

    def test_warmup_step_no_conflict(self, minimal_cfg):
        """``warmup_steps`` and ``warmup_ratio`` are each fine alone but not together."""
        cfg = DictDefault({"warmup_steps": 10, "warmup_ratio": 0.1}) | minimal_cfg

        with pytest.raises(
            ValueError,
            match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
        ):
            validate_config(cfg)

        cfg = DictDefault({"warmup_steps": 10}) | minimal_cfg

        validate_config(cfg)

        cfg = DictDefault({"warmup_ratio": 0.1}) | minimal_cfg

        validate_config(cfg)

    def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
        """``unfrozen_parameters`` combined with ``peft_layers_to_transform`` must fail."""
        cfg = (
            DictDefault(
                {
                    "adapter": "lora",
                    "unfrozen_parameters": [
                        "model.layers.2[0-9]+.block_sparse_moe.gate.*"
                    ],
                    "peft_layers_to_transform": [0, 1],
                }
            )
            | minimal_cfg
        )

        with pytest.raises(
            ValueError,
            match=r".*can have unexpected behavior*",
        ):
            validate_config(cfg)

    def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
        """``hub_model_id`` with ``save_strategy: "no"`` logs exactly one warning."""
        cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 1

    def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
        """``hub_model_id`` with an unrecognised ``save_strategy`` logs exactly one warning."""
        cfg = (
            DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 1

    def test_hub_model_id_save_value_steps(self, minimal_cfg):
        """``hub_model_id`` with ``save_strategy: "steps"`` logs no warning."""
        cfg = (
            DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 0

    def test_hub_model_id_save_value_epochs(self, minimal_cfg):
        """``hub_model_id`` with ``save_strategy: "epoch"`` logs no warning."""
        cfg = (
            DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 0

    def test_hub_model_id_save_value_none(self, minimal_cfg):
        """``hub_model_id`` with ``save_strategy: None`` logs no warning."""
        cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 0

    def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
        """``hub_model_id`` without any ``save_strategy`` key logs no warning."""
        cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 0

    def test_dpo_beta_deprecation(self, minimal_cfg):
        """``dpo_beta`` is moved to ``rl_beta`` and exactly one warning is logged."""
        cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert new_cfg["rl_beta"] == 0.2
            assert new_cfg["dpo_beta"] is None
            assert len(self._caplog.records) == 1
|
|
|
|
|
class TestValidationCheckModelConfig(BaseValidation):
    """
    Checks of ``check_model_config``, which validates a config against a
    given model config
    """

    def test_llama_add_tokens_adapter(self, minimal_cfg):
        """For a llama model, new tokens need both modules in ``lora_modules_to_save``."""
        error_pattern = (
            r".*`lora_modules_to_save` not properly set when adding new tokens*"
        )
        qlora_with_tokens = {
            "adapter": "qlora",
            "load_in_4bit": True,
            "tokens": ["<|imstart|>"],
        }

        # Case 1: lora_modules_to_save missing entirely.
        cfg = DictDefault(qlora_with_tokens) | minimal_cfg
        model_config = DictDefault({"model_type": "llama"})
        with pytest.raises(ValueError, match=error_pattern):
            check_model_config(cfg, model_config)

        # Case 2: only one of the two modules listed.
        only_embeddings = {**qlora_with_tokens, "lora_modules_to_save": ["embed_tokens"]}
        cfg = DictDefault(only_embeddings) | minimal_cfg
        with pytest.raises(ValueError, match=error_pattern):
            check_model_config(cfg, model_config)

        # Case 3: both modules listed, so the check passes.
        both_modules = {
            **qlora_with_tokens,
            "lora_modules_to_save": ["embed_tokens", "lm_head"],
        }
        cfg = DictDefault(both_modules) | minimal_cfg
        check_model_config(cfg, model_config)

    def test_phi_add_tokens_adapter(self, minimal_cfg):
        """For a phi model, new tokens need ``embed_tokens`` and ``lm_head`` saved."""
        error_pattern = (
            r".*`lora_modules_to_save` not properly set when adding new tokens*"
        )
        qlora_with_tokens = {
            "adapter": "qlora",
            "load_in_4bit": True,
            "tokens": ["<|imstart|>"],
        }

        # Case 1: lora_modules_to_save missing entirely.
        cfg = DictDefault(qlora_with_tokens) | minimal_cfg
        model_config = DictDefault({"model_type": "phi"})
        with pytest.raises(ValueError, match=error_pattern):
            check_model_config(cfg, model_config)

        # Case 2: these module names are rejected for phi.
        rejected_names = {
            **qlora_with_tokens,
            "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
        }
        cfg = DictDefault(rejected_names) | minimal_cfg
        with pytest.raises(ValueError, match=error_pattern):
            check_model_config(cfg, model_config)

        # Case 3: the accepted module names.
        accepted_names = {
            **qlora_with_tokens,
            "lora_modules_to_save": ["embed_tokens", "lm_head"],
        }
        cfg = DictDefault(accepted_names) | minimal_cfg
        check_model_config(cfg, model_config)
|
|
|
|
|
class TestValidationWandb(BaseValidation):
    """
    Validation test for wandb

    Tests that call ``setup_wandb_env_vars`` remove the ``WANDB_*`` variables
    in a ``finally`` block, so a failing assertion cannot leave them behind in
    ``os.environ`` for later tests.
    """

    def test_wandb_set_run_id_to_name(self, minimal_cfg):
        """``wandb_run_id`` alone warns and is copied to ``wandb_name``; the reverse is not."""
        cfg = DictDefault({"wandb_run_id": "foo"}) | minimal_cfg

        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert any(
                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
                in record.message
                for record in self._caplog.records
            )

            assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id == "foo"

        cfg = DictDefault({"wandb_name": "foo"}) | minimal_cfg

        new_cfg = validate_config(cfg)

        assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id is None

    def test_wandb_sets_env(self, minimal_cfg):
        """Each ``wandb_*`` config value is exported to its ``WANDB_*`` variable."""
        cfg = (
            DictDefault(
                {
                    "wandb_project": "foo",
                    "wandb_name": "bar",
                    "wandb_run_id": "bat",
                    "wandb_entity": "baz",
                    "wandb_mode": "online",
                    "wandb_watch": "false",
                    "wandb_log_model": "checkpoint",
                }
            )
            | minimal_cfg
        )

        new_cfg = validate_config(cfg)

        try:
            setup_wandb_env_vars(new_cfg)

            assert os.environ.get("WANDB_PROJECT", "") == "foo"
            assert os.environ.get("WANDB_NAME", "") == "bar"
            assert os.environ.get("WANDB_RUN_ID", "") == "bat"
            assert os.environ.get("WANDB_ENTITY", "") == "baz"
            assert os.environ.get("WANDB_MODE", "") == "online"
            assert os.environ.get("WANDB_WATCH", "") == "false"
            assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
            assert os.environ.get("WANDB_DISABLED", "") != "true"
        finally:
            # Runs even when an assertion above fails.
            for env_name in (
                "WANDB_PROJECT",
                "WANDB_NAME",
                "WANDB_RUN_ID",
                "WANDB_ENTITY",
                "WANDB_MODE",
                "WANDB_WATCH",
                "WANDB_LOG_MODEL",
                "WANDB_DISABLED",
            ):
                os.environ.pop(env_name, None)

    def test_wandb_set_disabled(self, minimal_cfg):
        """``WANDB_DISABLED`` is "true" without a project and not "true" with one."""
        cfg = DictDefault({}) | minimal_cfg

        new_cfg = validate_config(cfg)

        try:
            setup_wandb_env_vars(new_cfg)

            assert os.environ.get("WANDB_DISABLED", "") == "true"

            cfg = DictDefault({"wandb_project": "foo"}) | minimal_cfg

            new_cfg = validate_config(cfg)

            setup_wandb_env_vars(new_cfg)

            assert os.environ.get("WANDB_DISABLED", "") != "true"
        finally:
            # Runs even when an assertion above fails.
            os.environ.pop("WANDB_PROJECT", None)
            os.environ.pop("WANDB_DISABLED", None)
|
|