# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import importlib
import os
from typing import Optional

from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
)

from .config import PeftConfig
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from .peft_model import (
PeftModel,
PeftModelForCausalLM,
PeftModelForFeatureExtraction,
PeftModelForQuestionAnswering,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from .utils.constants import TOKENIZER_CONFIG_NAME
from .utils.other import check_file_exists_on_hf_hub


class _BaseAutoPeftModel:
_target_class = None
_target_peft_class = None

    def __init__(self, *args, **kwargs):
# For consistency with transformers: https://github.com/huggingface/transformers/blob/91d7df58b6537d385e90578dac40204cb550f706/src/transformers/models/auto/auto_factory.py#L400
raise EnvironmentError( # noqa: UP024
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_config(config)` methods."
)

    @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
adapter_name: str = "default",
is_trainable: bool = False,
config: Optional[PeftConfig] = None,
**kwargs,
):
r"""
        A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs
        are passed along to `PeftConfig`, which automatically takes care of routing them between the Hub methods and
        the config init.
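
        Example (illustrative; the adapter repo id below is a placeholder):

        ```py
        >>> from peft import AutoPeftModelForCausalLM

        >>> model = AutoPeftModelForCausalLM.from_pretrained("some-user/opt-350m-lora")
        ```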
"""
peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
base_model_path = peft_config.base_model_name_or_path
task_type = getattr(peft_config, "task_type", None)
if cls._target_class is not None:
target_class = cls._target_class
        elif task_type is not None:
# this is only in the case where we use `AutoPeftModel`
            raise ValueError(
                "Cannot use `AutoPeftModel` with a task type; please use a task-specific class instead (e.g. "
                "`AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)."
            )
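
        # validate that the requested task-specific Auto class matches the task type stored in the adapter config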
if task_type is not None:
expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type]
if cls._target_peft_class.__name__ != expected_target_class.__name__:
                raise ValueError(
                    f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: "
                    f"{cls._target_peft_class.__name__}. Make sure that you are loading the correct model for your "
                    "task type."
                )
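
        # no task type in the adapter config: fall back to the base model class recorded under `auto_mapping`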
elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None:
auto_mapping = getattr(peft_config, "auto_mapping", None)
base_model_class = auto_mapping["base_model_class"]
parent_library_name = auto_mapping["parent_library"]
parent_library = importlib.import_module(parent_library_name)
target_class = getattr(parent_library, base_model_class)
else:
            raise ValueError(
                "Cannot infer the auto class from the config; please make sure that you are loading the correct "
                "model for your task type."
            )
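
        # instantiate the base (non-PEFT) model; remaining kwargs are forwarded to its `from_pretrained`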
base_model = target_class.from_pretrained(base_model_path, **kwargs)
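
        # check whether a tokenizer config was saved with the adapter, locally first and then on the Hub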
tokenizer_exists = False
if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)):
tokenizer_exists = True
else:
token = kwargs.get("token", None)
if token is None:
token = kwargs.get("use_auth_token", None)
tokenizer_exists = check_file_exists_on_hf_hub(
repo_id=pretrained_model_name_or_path,
filename=TOKENIZER_CONFIG_NAME,
revision=kwargs.get("revision", None),
repo_type=kwargs.get("repo_type", None),
token=token,
)
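
        # a tokenizer saved with the adapter may contain added tokens; resize the embeddings to match it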
if tokenizer_exists:
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False)
)
base_model.resize_token_embeddings(len(tokenizer))
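
        # wrap the base model with the matching PEFT model class and load the adapter weights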
return cls._target_peft_class.from_pretrained(
base_model,
pretrained_model_name_or_path,
adapter_name=adapter_name,
is_trainable=is_trainable,
config=config,
**kwargs,
)


class AutoPeftModel(_BaseAutoPeftModel):
    _target_class = None
    _target_peft_class = PeftModel


class AutoPeftModelForCausalLM(_BaseAutoPeftModel):
    _target_class = AutoModelForCausalLM
    _target_peft_class = PeftModelForCausalLM


class AutoPeftModelForSeq2SeqLM(_BaseAutoPeftModel):
    _target_class = AutoModelForSeq2SeqLM
    _target_peft_class = PeftModelForSeq2SeqLM


class AutoPeftModelForSequenceClassification(_BaseAutoPeftModel):
    _target_class = AutoModelForSequenceClassification
    _target_peft_class = PeftModelForSequenceClassification


class AutoPeftModelForTokenClassification(_BaseAutoPeftModel):
    _target_class = AutoModelForTokenClassification
    _target_peft_class = PeftModelForTokenClassification


class AutoPeftModelForQuestionAnswering(_BaseAutoPeftModel):
    _target_class = AutoModelForQuestionAnswering
    _target_peft_class = PeftModelForQuestionAnswering


class AutoPeftModelForFeatureExtraction(_BaseAutoPeftModel):
    _target_class = AutoModel
    _target_peft_class = PeftModelForFeatureExtraction