Spaces:
Runtime error
Runtime error
File size: 7,800 Bytes
f6f64ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
except Exception:
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if "Processor" not in processor.__class__.__name__:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args)
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
else:
model = load_class.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:,}".format(all_param)
logger.info(param_stats)
if model_args.print_param_status:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
return model
|