# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import pickle
import tempfile
import warnings

import pytest

from peft import (
    AdaLoraConfig,
    AdaptionPromptConfig,
    BOFTConfig,
    BoneConfig,
    C3AConfig,
    FourierFTConfig,
    HRAConfig,
    IA3Config,
    LNTuningConfig,
    LoHaConfig,
    LoKrConfig,
    LoraConfig,
    MissConfig,
    MultitaskPromptTuningConfig,
    OFTConfig,
    PeftConfig,
    PeftType,
    PolyConfig,
    PrefixTuningConfig,
    PromptEncoder,
    PromptEncoderConfig,
    PromptTuningConfig,
    RoadConfig,
    ShiraConfig,
    TaskType,
    TrainableTokensConfig,
    VBLoRAConfig,
    VeraConfig,
    XLoraConfig,
)


PEFT_MODELS_TO_TEST = [("peft-internal-testing/tiny-opt-lora-revision", "test")]

# Config classes and their mandatory parameters
ALL_CONFIG_CLASSES = (
    (AdaLoraConfig, {"total_step": 1}),
    (AdaptionPromptConfig, {}),
    (BOFTConfig, {}),
    (BoneConfig, {}),
    (C3AConfig, {}),
    (FourierFTConfig, {}),
    (HRAConfig, {}),
    (IA3Config, {}),
    (LNTuningConfig, {}),
    (LoHaConfig, {}),
    (LoKrConfig, {}),
    (LoraConfig, {}),
    (MissConfig, {}),
    (MultitaskPromptTuningConfig, {}),
    (OFTConfig, {}),
    (PolyConfig, {}),
    (PrefixTuningConfig, {}),
    (PromptEncoderConfig, {}),
    (PromptTuningConfig, {}),
    (RoadConfig, {}),
    (ShiraConfig, {}),
    (TrainableTokensConfig, {}),
    (VeraConfig, {}),
    (VBLoRAConfig, {}),
    (XLoraConfig, {"hidden_size": 32, "adapters": {}}),
)
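
# NOTE: the test methods in TestPeftConfig below receive `config_class` and `mandatory_kwargs`
# arguments. A minimal sketch of how these would typically be supplied from ALL_CONFIG_CLASSES
# via pytest parametrization (assumed wiring, not necessarily the exact decorators used):
#
#     @pytest.mark.parametrize("config_class,mandatory_kwargs", ALL_CONFIG_CLASSES)
#     def test_methods(self, config_class, mandatory_kwargs):
#         ...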


class TestPeftConfig:
    def test_methods(self, config_class, mandatory_kwargs):
        r"""
        Test if all configs have the expected methods. Here we test
        - to_dict
        - save_pretrained
        - from_pretrained
        - from_json_file
        """
        # test if all configs have the expected methods
        config = config_class(**mandatory_kwargs)
        assert hasattr(config, "to_dict")
        assert hasattr(config, "save_pretrained")
        assert hasattr(config, "from_pretrained")
        assert hasattr(config, "from_json_file")

    def test_valid_task_type(self, config_class, mandatory_kwargs, valid_task_type):
        r"""
        Test if all configs work correctly for all valid task types
        """
        config_class(task_type=valid_task_type, **mandatory_kwargs)

    def test_invalid_task_type(self, config_class, mandatory_kwargs):
        r"""
        Test if all configs correctly raise the defined error message for invalid task types.
        """
        invalid_task_type = "invalid-task-type"
        with pytest.raises(
            ValueError,
            match=f"Invalid task type: '{invalid_task_type}'. Must be one of the following task types: {', '.join(TaskType)}.",
        ):
            config_class(task_type=invalid_task_type, **mandatory_kwargs)

    def test_from_peft_type(self):
        r"""
        Test if the config is correctly loaded using:
        - from_peft_type
        """
        from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

        for peft_type in PeftType:
            expected_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
            mandatory_config_kwargs = {}
            if expected_cls == AdaLoraConfig:
                mandatory_config_kwargs = {"total_step": 1}
            config = PeftConfig.from_peft_type(peft_type=peft_type, **mandatory_config_kwargs)
            assert type(config) is expected_cls

    def test_from_pretrained(self, config_class, mandatory_kwargs):
        r"""
        Test if the config is correctly loaded using:
        - from_pretrained
        """
        for model_name, revision in PEFT_MODELS_TO_TEST:
            # Test we can load config from delta
            config_class.from_pretrained(model_name, revision=revision)

    def test_save_pretrained(self, config_class, mandatory_kwargs):
        r"""
        Test if the config is correctly saved and loaded using
        - save_pretrained
        """
        config = config_class(**mandatory_kwargs)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
            assert config.to_dict() == config_from_pretrained.to_dict()

    def test_from_json_file(self, config_class, mandatory_kwargs):
        config = config_class(**mandatory_kwargs)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)
            config_path = os.path.join(tmp_dirname, "adapter_config.json")

            config_from_json = config_class.from_json_file(config_path)
            assert config.to_dict() == config_from_json

            # Also test with a runtime_config entry -- they should be ignored, even if they
            # were accidentally saved to disk
            config_from_json["runtime_config"] = {"ephemeral_gpu_offload": True}
            with open(config_path, "w") as f:
                json.dump(config_from_json, f)

            config_from_json = config_class.from_json_file(config_path)
            assert config.to_dict() == config_from_json

    def test_to_dict(self, config_class, mandatory_kwargs):
        r"""
        Test if the config can be correctly converted to a dict using:
        - to_dict
        """
        config = config_class(**mandatory_kwargs)
        assert isinstance(config.to_dict(), dict)

    def test_from_pretrained_cache_dir(self, config_class, mandatory_kwargs):
        r"""
        Test if the config is correctly loaded with extra kwargs
        """
        with tempfile.TemporaryDirectory() as tmp_dirname:
            for model_name, revision in PEFT_MODELS_TO_TEST:
                # Test we can load config from delta
                config_class.from_pretrained(model_name, revision=revision, cache_dir=tmp_dirname)

    def test_from_pretrained_cache_dir_remote(self):
        r"""
        Test if the config is correctly loaded with a checkpoint from the hub
        """
        with tempfile.TemporaryDirectory() as tmp_dirname:
            PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname)
            assert "models--ybelkada--test-st-lora" in os.listdir(tmp_dirname)

    def test_save_pretrained_with_runtime_config(self, config_class, mandatory_kwargs):
        r"""
        Test if the config correctly removes runtime config when saving
        """
        with tempfile.TemporaryDirectory() as tmp_dirname:
            for model_name, revision in PEFT_MODELS_TO_TEST:
                cfg = config_class.from_pretrained(model_name, revision=revision)
                # NOTE: cfg is always a LoraConfig here, because the configuration of the loaded model was a LoRA.
                # Hence we can expect a runtime_config to exist regardless of config_class.
                cfg.runtime_config.ephemeral_gpu_offload = True
                cfg.save_pretrained(tmp_dirname)
                cfg = config_class.from_pretrained(tmp_dirname)
                assert not cfg.runtime_config.ephemeral_gpu_offload

    def test_set_attributes(self, config_class, mandatory_kwargs):
        # manually set attributes and check if they are correctly written
        config = config_class(peft_type="test", **mandatory_kwargs)

        # save pretrained
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
            assert config.to_dict() == config_from_pretrained.to_dict()

    def test_config_copy(self, config_class, mandatory_kwargs):
        # see https://github.com/huggingface/peft/issues/424
        config = config_class(**mandatory_kwargs)
        copied = copy.copy(config)
        assert config.to_dict() == copied.to_dict()

    def test_config_deepcopy(self, config_class, mandatory_kwargs):
        # see https://github.com/huggingface/peft/issues/424
        config = config_class(**mandatory_kwargs)
        copied = copy.deepcopy(config)
        assert config.to_dict() == copied.to_dict()

    def test_config_pickle_roundtrip(self, config_class, mandatory_kwargs):
        # see https://github.com/huggingface/peft/issues/424
        config = config_class(**mandatory_kwargs)
        copied = pickle.loads(pickle.dumps(config))
        assert config.to_dict() == copied.to_dict()

    def test_prompt_encoder_warning_num_layers(self):
        # This test checks that if a prompt encoder config is created with an argument that is ignored, there should
        # be a warning. However, there should be no warning if the default value is used.
        kwargs = {
            "num_virtual_tokens": 20,
            "num_transformer_submodules": 1,
            "token_dim": 768,
            "encoder_hidden_size": 768,
        }

        # there should be no warning with just the default argument for encoder_num_layers
        config = PromptEncoderConfig(**kwargs)
        with warnings.catch_warnings():
            # turn warnings into errors so that an unexpected warning fails this check
            warnings.simplefilter("error")
            PromptEncoder(config)
        # when changing encoder_num_layers, there should be a warning for MLP since that value is not used
        config = PromptEncoderConfig(encoder_num_layers=123, **kwargs)
        with pytest.warns(UserWarning) as record:
            PromptEncoder(config)
        expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
        assert str(record.list[0].message) == expected_msg

    def test_save_pretrained_with_target_modules(self, config_class):
        # See #1041, #1045
        config = config_class(target_modules=["a", "list"])
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
            assert config.to_dict() == config_from_pretrained.to_dict()
            # explicit test that target_modules should be converted to set
            assert isinstance(config_from_pretrained.target_modules, set)

    def test_regex_with_layer_indexing_lora(self):
        # This test checks that an error is raised if `target_modules` is a regex expression and `layers_to_transform`
        # or `layers_pattern` are not None
        invalid_config1 = {"target_modules": ".*foo", "layers_to_transform": [0]}
        invalid_config2 = {"target_modules": ".*foo", "layers_pattern": ["bar"]}
        valid_config = {"target_modules": ["foo"], "layers_pattern": ["bar"], "layers_to_transform": [0]}

        with pytest.raises(ValueError, match="`layers_to_transform` cannot be used when `target_modules` is a str."):
            LoraConfig(**invalid_config1)

        with pytest.raises(ValueError, match="`layers_pattern` cannot be used when `target_modules` is a str."):
            LoraConfig(**invalid_config2)

        # should run without errors
        LoraConfig(**valid_config)

    def test_ia3_is_feedforward_subset_invalid_config(self):
        # This test checks that the IA3 config raises a value error if the feedforward_modules argument
        # is not a subset of the target_modules argument

        # an example invalid config
        invalid_config = {"target_modules": ["k", "v"], "feedforward_modules": ["q"]}

        with pytest.raises(ValueError, match="^`feedforward_modules` should be a subset of `target_modules`$"):
            IA3Config(**invalid_config)

    def test_ia3_is_feedforward_subset_valid_config(self):
        # This test checks that the IA3 config is created without errors with valid arguments.
        # feedforward_modules should be a subset of target_modules if both are lists

        # an example valid config with regex expressions.
        valid_config_regex_exp = {
            "target_modules": ".*.(SelfAttention|EncDecAttention|DenseReluDense).*(q|v|wo)$",
            "feedforward_modules": ".*.DenseReluDense.wo$",
        }
        # an example valid config with module lists.
        valid_config_list = {"target_modules": ["k", "v", "wo"], "feedforward_modules": ["wo"]}

        # should run without errors
        IA3Config(**valid_config_regex_exp)
        IA3Config(**valid_config_list)

    def test_adalora_config_r_warning(self):
        # This test checks that a warning is raised when `r` is set to a non-default value in AdaLoraConfig.
        # No warning should be raised when initializing AdaLoraConfig with default values.
        kwargs = {"peft_type": "ADALORA", "task_type": "SEQ_2_SEQ_LM", "init_r": 12, "lora_alpha": 32, "total_step": 1}
        # Test that no warning is raised with default initialization
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            try:
                AdaLoraConfig(**kwargs)
            except Warning:
                pytest.fail("AdaLoraConfig raised a warning with default initialization.")
        # Test that a warning is raised when r != 8 in AdaLoraConfig
        with pytest.warns(UserWarning, match="Note that `r` is not used in AdaLora and will be ignored."):
            AdaLoraConfig(r=10, total_step=1)

    def test_adalora_config_correct_timing_still_works(self):
        pass

    def test_adalora_config_valid_timing_works(self, timing_kwargs):
        # Make sure that passing correct timing values is not prevented by faulty config checks.
        AdaLoraConfig(**timing_kwargs)  # does not raise

    def test_adalora_config_invalid_total_step_raises(self):
        with pytest.raises(ValueError) as e:
            AdaLoraConfig(total_step=None)
        assert "AdaLoRA does not work when `total_step` is None, supply a value > 0." in str(e)

    def test_adalora_config_timing_bounds_error(self, timing_kwargs):
        # Check if the user supplied timing values that will certainly fail because they break
        # AdaLoRA's assumptions.
        with pytest.raises(ValueError) as e:
            AdaLoraConfig(**timing_kwargs)
        assert "The supplied schedule values don't allow for a budgeting phase" in str(e)

    def test_from_pretrained_forward_compatible(self, config_class, mandatory_kwargs, tmp_path, recwarn):
        """
        Make it possible to load configs that contain unknown keys by ignoring them.
        The idea is to make PEFT configs forward-compatible with future versions of the library.
        """
        config = config_class(**mandatory_kwargs)
        config.save_pretrained(tmp_path)
        # add spurious keys to the config
        with open(tmp_path / "adapter_config.json") as f:
            config_dict = json.load(f)
        config_dict["foobar"] = "baz"
        config_dict["spam"] = 123
        with open(tmp_path / "adapter_config.json", "w") as f:
            json.dump(config_dict, f)

        msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
        config_from_pretrained = config_class.from_pretrained(tmp_path)

        expected_num_warnings = 1
        # TODO: remove once Bone is removed in v0.19.0
        if config_class == BoneConfig:
            expected_num_warnings = 2  # Bone has 1 more warning about it being deprecated
        assert len(recwarn) == expected_num_warnings
        assert recwarn.list[-1].message.args[0].startswith(msg)
| assert "foo" not in config_from_pretrained.to_dict() | |
| assert "spam" not in config_from_pretrained.to_dict() | |
| assert config.to_dict() == config_from_pretrained.to_dict() | |
| assert isinstance(config_from_pretrained, config_class) | |

    def test_from_pretrained_forward_compatible_load_from_peft_config(
        self, config_class, mandatory_kwargs, tmp_path, recwarn
    ):
        """Exact same test as before, but instead of using LoraConfig.from_pretrained, AdaLoraConfig.from_pretrained,
        etc., use PeftConfig.from_pretrained. This covers a previously existing bug where only the known arguments from
        PeftConfig would be used instead of the more specific config (which is known thanks to the peft_type
        attribute).
        """
        config = config_class(**mandatory_kwargs)
        config.save_pretrained(tmp_path)
        # add spurious keys to the config
        with open(tmp_path / "adapter_config.json") as f:
            config_dict = json.load(f)
        config_dict["foobar"] = "baz"
        config_dict["spam"] = 123
        with open(tmp_path / "adapter_config.json", "w") as f:
            json.dump(config_dict, f)

        msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
        config_from_pretrained = PeftConfig.from_pretrained(tmp_path)  # <== use PeftConfig here

        expected_num_warnings = 1
        # TODO: remove once Bone is removed in v0.19.0
        if config_class == BoneConfig:
            expected_num_warnings = 2  # Bone has 1 more warning about it being deprecated
        assert len(recwarn) == expected_num_warnings
        assert recwarn.list[-1].message.args[0].startswith(msg)
| assert "foo" not in config_from_pretrained.to_dict() | |
| assert "spam" not in config_from_pretrained.to_dict() | |
| assert config.to_dict() == config_from_pretrained.to_dict() | |
| assert isinstance(config_from_pretrained, config_class) | |

    def test_from_pretrained_sanity_check(self, config_class, mandatory_kwargs, tmp_path):
        """Following up on the previous test about forward compatibility, we *don't* want any random json to be accepted
        as a PEFT config. There should be a minimum set of required keys.
        """
        non_peft_json = {"foo": "bar", "baz": 123}
        with open(tmp_path / "adapter_config.json", "w") as f:
            json.dump(non_peft_json, f)

        msg = f"The {config_class.__name__} config that is trying to be loaded is missing required keys: {{'peft_type'}}."
        with pytest.raises(TypeError, match=msg):
            config_class.from_pretrained(tmp_path)

    def test_lora_config_layers_to_transform_validation(self):
        """Test that specifying layers_pattern without layers_to_transform raises an error"""
        with pytest.raises(
            ValueError, match="When `layers_pattern` is specified, `layers_to_transform` must also be specified."
        ):
            LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"], layers_pattern="model.layers")

        # Test that specifying both layers_to_transform and layers_pattern works fine
        config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["query", "value"],
            layers_to_transform=[0, 1, 2],
            layers_pattern="model.layers",
        )
        assert config.layers_to_transform == [0, 1, 2]
        assert config.layers_pattern == "model.layers"

        # Test that not specifying either works fine
        config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["query", "value"],
        )
        assert config.layers_to_transform is None
        assert config.layers_pattern is None

    def test_peft_version_is_stored(self, version, config_class, mandatory_kwargs, monkeypatch, tmp_path):
        # Check that the PEFT version is automatically stored in/restored from the config file.
        from peft import config

        monkeypatch.setattr(config, "__version__", version)
        peft_config = config_class(**mandatory_kwargs)
        assert peft_config.peft_version == version

        peft_config.save_pretrained(tmp_path)
        with open(tmp_path / "adapter_config.json") as f:
            config_dict = json.load(f)
        assert config_dict["peft_version"] == version

        # ensure that the version from the config is being loaded, not just the current version
        monkeypatch.setattr(config, "__version__", "0.1.another-version")

        # load from config
        config_loaded = PeftConfig.from_pretrained(tmp_path)
        assert config_loaded.peft_version == version

        # load from json
        config_path = tmp_path / "adapter_config.json"
        config_json = PeftConfig.from_json_file(str(config_path))
        assert config_json["peft_version"] == version

    def test_peft_version_is_dev_version(self, config_class, mandatory_kwargs, monkeypatch, tmp_path):
        # When a dev version of PEFT is installed, the actual state of PEFT is ambiguous. Therefore, try to determine
        # the commit hash too and store it as part of the version string.
        from peft import config

        version = "0.15.0.dev7"
        monkeypatch.setattr(config, "__version__", version)

        def fake_commit_hash(pkg_name):
            return "abcdef012345"

        monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash)

        peft_config = config_class(**mandatory_kwargs)
        expected_version = f"{version}@{fake_commit_hash('peft')}"
        assert peft_config.peft_version == expected_version

        peft_config.save_pretrained(tmp_path)
        config_loaded = PeftConfig.from_pretrained(tmp_path)
        assert config_loaded.peft_version == expected_version

    def test_peft_version_is_dev_version_but_commit_hash_cannot_be_determined(
        self, config_class, mandatory_kwargs, monkeypatch, tmp_path
    ):
        # There can be cases where PEFT is using a dev version but the commit hash cannot be determined. In this case,
        # just store the dev version string.
        from peft import config

        version = "0.15.0.dev7"
        monkeypatch.setattr(config, "__version__", version)

        def fake_commit_hash(pkg_name):
            return None

        monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash)

        peft_config = config_class(**mandatory_kwargs)
        assert peft_config.peft_version == version + "@UNKNOWN"

        peft_config.save_pretrained(tmp_path)
        config_loaded = PeftConfig.from_pretrained(tmp_path)
        assert config_loaded.peft_version == version + "@UNKNOWN"

    def test_peft_version_warn_when_commit_hash_errors(self, config_class, mandatory_kwargs, monkeypatch, tmp_path):
        # We try to get the PEFT commit hash if a dev version is installed. But in case there is any kind of error
        # there, we don't want user code to break. Instead, the code should run and a version without commit hash
        # should be recorded. In addition, there should be a warning.
        from peft import config

        version = "0.15.0.dev7"
        monkeypatch.setattr(config, "__version__", version)

        def fake_commit_hash_raises(pkg_name):
            raise Exception("Error for testing purpose")

        monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash_raises)

        msg = "A dev version of PEFT is used but there was an error while trying to determine the commit hash"
        with pytest.warns(UserWarning, match=msg):
            peft_config = config_class(**mandatory_kwargs)
        assert peft_config.peft_version == version + "@UNKNOWN"