|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from types import MethodType |
|
from typing import TYPE_CHECKING, Any, Dict |
|
|
|
import torch |
|
from peft import PeftModel |
|
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available |
|
from transformers.integrations import is_deepspeed_zero3_enabled |
|
from transformers.modeling_utils import is_fsdp_enabled |
|
|
|
from ..extras.logging import get_logger |
|
from ..extras.misc import infer_optim_dtype |
|
from .model_utils.attention import configure_attn_implementation, print_attn_implementation |
|
from .model_utils.checkpointing import prepare_model_for_training |
|
from .model_utils.embedding import resize_embedding_layer |
|
from .model_utils.longlora import configure_longlora |
|
from .model_utils.moe import add_z3_leaf_module, configure_moe |
|
from .model_utils.quantization import configure_quantization |
|
from .model_utils.rope import configure_rope |
|
from .model_utils.valuehead import prepare_valuehead_model |
|
|
|
if TYPE_CHECKING: |
|
from transformers import PretrainedConfig, PreTrainedTokenizer |
|
from trl import AutoModelForCausalLMWithValueHead |
|
|
|
from ..hparams import ModelArguments |
|
|
|
|
|
# Module-level logger for the patch functions in this file.
logger = get_logger(__name__)
|
|
|
|
|
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Ensures the tokenizer uses the base-class ``_pad`` implementation.

    The string form of the function behind ``tokenizer._pad`` is inspected. If it does
    not mention ``PreTrainedTokenizerBase``, then ``PreTrainedTokenizerBase._pad`` is bound
    to this tokenizer instance and stored as its ``_pad`` attribute. Otherwise the
    tokenizer is left untouched.
    """
    current_pad_func = tokenizer._pad.__func__
    if "PreTrainedTokenizerBase" in str(current_pad_func):
        return

    tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
|
|
|
|
def patch_config(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
    is_trainable: bool,
) -> None:
    r"""
    Patches ``config``, ``model_args`` and ``init_kwargs`` in place before the model is built.

    Args:
        config: the model config. It is passed to the ``configure_*`` helpers and receives
            the model-type specific attributes set below.
        tokenizer: forwarded to ``configure_quantization``.
        model_args: model arguments. ``compute_dtype`` is filled in when it is ``None``.
        init_kwargs: keyword arguments for loading the model (presumably forwarded to
            ``from_pretrained`` — confirm with the caller). Always receives
            ``low_cpu_mem_usage``; may receive ``torch_dtype``, ``device_map`` and
            ``offload_folder``.
        is_trainable: whether the model is being loaded for training.
    """
    if model_args.compute_dtype is None:
        # An explicit `infer_dtype` is honoured only when not training. Otherwise the dtype
        # is derived from the config's `torch_dtype` through `infer_optim_dtype`.
        honour_infer_dtype = model_args.infer_dtype != "auto" and not is_trainable
        model_args.compute_dtype = (
            getattr(torch, model_args.infer_dtype)
            if honour_infer_dtype
            else infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
        )

    if is_torch_npu_available():
        # NPU JIT compilation is opt-in via the JIT_COMPILE environment variable ("true"/"1").
        jit_env_value = os.environ.get("JIT_COMPILE", "0").lower()
        torch.npu.set_compile_mode(jit_compile=jit_env_value in ("true", "1"))

    configure_attn_implementation(config, model_args, is_trainable)
    configure_rope(config, model_args, is_trainable)
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, init_kwargs)
    configure_moe(config, model_args, is_trainable)

    if model_args.use_cache and not is_trainable:
        setattr(config, "use_cache", True)
        logger.info("Using KV cache for faster generation.")

    model_type = getattr(config, "model_type", None)

    if model_type == "qwen":
        # Qwen configs get a flash-attn switch plus one boolean flag per precision.
        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
        precision_flags = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        for flag_name, flag_dtype in precision_flags.items():
            setattr(config, flag_name, model_args.compute_dtype == flag_dtype)

    if model_type == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
        # KV cache is disabled for Qwen2 when training with FlashAttention-2.
        setattr(config, "use_cache", False)

    # low_cpu_mem_usage is forced off whenever DeepSpeed ZeRO-3 is enabled.
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    # torch_dtype is pinned unless ZeRO-3 or FSDP is enabled; quantization always pins it.
    zero3_or_fsdp = is_deepspeed_zero3_enabled() or is_fsdp_enabled()
    if not zero3_or_fsdp or model_args.quantization_bit is not None:
        init_kwargs["torch_dtype"] = model_args.compute_dtype

    if init_kwargs["low_cpu_mem_usage"]:
        # A caller-supplied device_map takes precedence over model_args.device_map.
        if "device_map" not in init_kwargs and model_args.device_map:
            init_kwargs["device_map"] = model_args.device_map

        if init_kwargs.get("device_map", None) == "auto":
            init_kwargs["offload_folder"] = model_args.offload_folder
|
|
|
|
|
def patch_model(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    is_trainable: bool,
    add_valuehead: bool,
) -> None:
    r"""
    Patches a loaded model in place according to ``model_args``.

    Args:
        model: the loaded model. Its ``generation_config.do_sample`` may be switched on, and
            its ``generate`` may be rebound to ``PreTrainedModel.generate``.
        tokenizer: used only when ``model_args.resize_vocab`` is set, to resize the embeddings.
        model_args: model arguments. Reads ``resize_vocab``, ``visual_inputs`` and
            ``use_unsloth``, and is forwarded to the visual/training helpers.
        is_trainable: when true, runs ``prepare_model_for_training`` and ``add_z3_leaf_module``.
        add_valuehead: when true, runs ``prepare_valuehead_model`` on the model.
    """
    # Turn sampling on when the generation config sets a non-default temperature,
    # top_p or typical_p but leaves do_sample off.
    gen_config = model.generation_config
    if not gen_config.do_sample and (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    ):
        gen_config.do_sample = True

    # If the function behind `model.generate` is not GenerationMixin's, bind
    # PreTrainedModel.generate to this instance instead.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    if add_valuehead:
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(model, tokenizer)

    if model_args.visual_inputs:
        # Imported locally because this helper is not among the module-level imports, so
        # the bare name would raise NameError here.
        # NOTE(review): assumes it lives in .model_utils.visual next to the other
        # model_utils helpers — confirm.
        from .model_utils.visual import autocast_projector_dtype

        autocast_projector_dtype(model, model_args)

    if is_trainable:
        prepare_model_for_training(model, model_args)
        add_z3_leaf_module(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)
|
|
|
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    r"""
    Attaches helper methods to a value-head model that forward to its ``pretrained_model``.

    Every parameter name containing ``pretrained_model`` is recorded in
    ``_keys_to_ignore_on_save``. ``tie_weights``, ``get_input_embeddings`` and
    ``get_output_embeddings`` forward only when ``pretrained_model`` is a
    ``PreTrainedModel``; otherwise they do nothing or return ``None``.
    ``create_or_update_model_card`` forwards only when ``pretrained_model`` is a ``PeftModel``.
    """

    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        inner = self.pretrained_model
        if isinstance(inner, PreTrainedModel):
            inner.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        inner = self.pretrained_model
        if not isinstance(inner, PreTrainedModel):
            return None
        return inner.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        inner = self.pretrained_model
        if not isinstance(inner, PreTrainedModel):
            return None
        return inner.get_output_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        inner = self.pretrained_model
        if isinstance(inner, PeftModel):
            inner.create_or_update_model_card(output_dir)

    skipped_keys = [
        param_name for param_name, _ in model.named_parameters() if "pretrained_model" in param_name
    ]
    setattr(model, "_keys_to_ignore_on_save", skipped_keys)

    # Bind each forwarder to this particular instance, in a fixed order.
    forwarders = (
        ("tie_weights", tie_weights),
        ("get_input_embeddings", get_input_embeddings),
        ("get_output_embeddings", get_output_embeddings),
        ("create_or_update_model_card", create_or_update_model_card),
    )
    for attr_name, func in forwarders:
        setattr(model, attr_name, MethodType(func, model))
|
|