|
"""Module for models and model loading""" |
|
import logging |
|
import math |
|
import os |
|
from typing import Optional, Tuple |
|
|
|
import bitsandbytes as bnb |
|
import torch |
|
import transformers |
|
from optimum.bettertransformer import BetterTransformer |
|
from peft import PeftConfig, prepare_model_for_kbit_training |
|
from peft.tuners.lora import QuantLinear |
|
from transformers import ( |
|
AddedToken, |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
BitsAndBytesConfig, |
|
GPTQConfig, |
|
LlamaConfig, |
|
PreTrainedModel, |
|
PreTrainedTokenizerBase, |
|
) |
|
|
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN |
|
from axolotl.utils.bench import log_gpu_memory_usage |
|
from axolotl.utils.dict import DictDefault |
|
|
|
LOG = logging.getLogger("axolotl") |
|
|
|
|
|
def load_model_config(cfg):
    """
    Fetch the Hugging Face config object for the configured base model.

    The config is looked up under `cfg.base_model_config` when that value is
    truthy, otherwise under `cfg.base_model`. Remote code is trusted only when
    `cfg.trust_remote_code` is literally `True` (any other value, including
    other truthy ones, is treated as `False`).
    """
    config_source = cfg.base_model_config
    if not config_source:
        config_source = cfg.base_model

    return AutoConfig.from_pretrained(
        config_source,
        trust_remote_code=(cfg.trust_remote_code is True),
    )
|
|
|
|
|
def load_tokenizer(cfg):
    """
    Build and configure the tokenizer described by the run configuration.

    The tokenizer is loaded from `cfg.tokenizer_config`, falling back to
    `cfg.base_model_config` and finally to `cfg.base_model`, so a config that
    only sets `base_model` resolves the same path here as it does in
    `load_model_config`.

    Args:
        cfg: run configuration. Fields read here: `tokenizer_use_fast`,
            `tokenizer_legacy`, `tokenizer_type`, `tokenizer_config`,
            `base_model_config`, `base_model`, `trust_remote_code`,
            `is_mistral_derived_model`, `flash_attention`, `sample_packing`,
            `special_tokens` and `tokens`.

    Returns:
        The loaded tokenizer, with padding settings adjusted and any
        configured special/extra tokens added.
    """
    tokenizer_kwargs = {}
    if cfg.tokenizer_legacy is not None:
        # only forward `legacy` when the user set it explicitly
        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy

    # fast tokenizers are used unless the config explicitly says otherwise
    use_fast = True
    if cfg.tokenizer_use_fast is not None:
        use_fast = cfg.tokenizer_use_fast

    tokenizer_cls = AutoTokenizer
    if cfg.tokenizer_type:
        # resolve an explicit tokenizer class by name from `transformers`
        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

    # Same fallback chain as load_model_config(): without the final
    # `cfg.base_model` a config that omits both tokenizer_config and
    # base_model_config would call from_pretrained(None).
    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
    tokenizer = tokenizer_cls.from_pretrained(
        tokenizer_config,
        trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
        **tokenizer_kwargs,
    )

    tokenizer_name = tokenizer.__class__.__name__

    if (
        tokenizer_name
        in (
            "LlamaTokenizer",
            "LlamaTokenizerFast",
            "CodeLlamaTokenizer",
        )
        and hasattr(tokenizer, "pad_token")
        and not tokenizer.pad_token
    ):
        # Llama-family tokenizers without a pad token get the default Llama
        # EOS token as padding (presumably so no new token has to be added
        # to the vocabulary).
        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

    if tokenizer_name == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # process-wide side effect: turns off tokenizers' own parallelism
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # NOTE(review): left padding is presumably required by the upstream
    # flash-attention code path for Mistral when not sample packing - confirm.
    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        tokenizer.padding_side = "left"

    if cfg.special_tokens:
        for k, val in cfg.special_tokens.items():
            tokenizer.add_special_tokens(
                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )
    if cfg.tokens:
        tokenizer.add_tokens(
            [
                AddedToken(token, rstrip=False, lstrip=False, normalized=False)
                for token in cfg.tokens
            ]
        )

    LOG.debug("EOS: %s / %s", tokenizer.eos_token_id, tokenizer.eos_token)
    LOG.debug("BOS: %s / %s", tokenizer.bos_token_id, tokenizer.bos_token)
    LOG.debug("PAD: %s / %s", tokenizer.pad_token_id, tokenizer.pad_token)
    LOG.debug("UNK: %s / %s", tokenizer.unk_token_id, tokenizer.unk_token)

    return tokenizer
|
|
|
|
|
def load_model(
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizerBase,
    inference: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
    """
    Load a model for a given configuration and tokenizer.

    In order, this function:

    1. applies the architecture-specific attention / RoPE / mask monkeypatches
       selected by `cfg`,
    2. assembles the `from_pretrained` keyword arguments (device map, dtype,
       revision, GPTQ / bitsandbytes quantization, flash attention 2),
    3. loads the model, retrying once with `AutoModelForCausalLM` if the
       first attempt raises,
    4. grows the token embeddings to cover the tokenizer and syncs
       `max_position_embeddings`, `bos_token_id` and `eos_token_id`,
    5. adjusts module dtypes, prepares the model for k-bit training when
       applicable, and attaches the configured adapter.

    Args:
        cfg: run configuration.
        tokenizer: tokenizer paired with the model. It is mutated when
            landmark attention is enabled (a memory token is added).
        inference: when True, the training-only patches (llama flash
            attention for sample packing, `_expand_mask`, SwiGLU and
            fused-QKV) are skipped.

    Returns:
        A `(model, peft_config)` tuple; `peft_config` is `None` when no
        adapter is configured.
    """
    base_model = cfg.base_model
    # Fall back to the base model path when no separate config path is set,
    # matching load_model_config(). Otherwise LlamaConfig.from_pretrained(None)
    # raises inside the try-block below and the load silently diverts to the
    # generic AutoModelForCausalLM retry, losing rope_scaling and the
    # SwiGLU / fused-QKV patches.
    base_model_config = cfg.base_model_config or cfg.base_model
    model_type = cfg.model_type
    model_config = load_model_config(cfg)

    load_in_8bit = cfg.load_in_8bit

    # ------------------------------------------------------------------
    # architecture-specific monkeypatches
    # ------------------------------------------------------------------
    if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
        if cfg.flash_attention:
            from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                replace_btlm_attn_with_flash_attn,
            )

            replace_btlm_attn_with_flash_attn(cfg.base_model)

    if (
        hasattr(model_config, "model_type")
        and model_config.model_type == "stablelm_epoch"
    ):
        if cfg.flash_attention and cfg.sample_packing:
            from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
                replace_stablelm_attn_with_flash_attn,
            )

            replace_stablelm_attn_with_flash_attn(cfg.base_model)

    # The llama attention patches are mutually exclusive: the first matching
    # option wins.
    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
        if cfg.device not in ["mps", "cpu"] and not inference:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                replace_llama_attn_with_flash_attn,
            )

            LOG.info("patching with flash attention for sample packing")
            replace_llama_attn_with_flash_attn(
                packed=cfg.sample_packing,
                cross_entropy=cfg.flash_attn_cross_entropy,
                rms_norm=cfg.flash_attn_rms_norm,
            )
    elif cfg.is_llama_derived_model and cfg.xformers_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        LOG.info("patching with xformers attention")
        hijack_llama_attention()
    elif cfg.is_llama_derived_model and cfg.sdp_attention:
        from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention

        LOG.info("patching with sdp attention")
        hijack_llama_sdp_attention()
    elif cfg.is_llama_derived_model and cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import (
            MEM_TOKEN,
            patch_llama_with_landmark_attn,
        )

        LOG.info("patching with landmark attention")
        patch_llama_with_landmark_attn()

        # NOTE(review): this call may replace any additional_special_tokens
        # already registered on the tokenizer - confirm against the
        # tokenizer's add_special_tokens semantics.
        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            replace_mistral_attn_with_flash_attn,
        )

        LOG.info("patching with flash attention")
        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

    if cfg.is_llama_derived_model and cfg.xpos_rope:
        from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
            replace_llama_rope_with_xpos_rope,
        )

        LOG.info("patching with xpos rope")
        replace_llama_rope_with_xpos_rope()

    if (
        cfg.is_llama_derived_model
        and (cfg.max_packed_sequence_len or cfg.sample_packing)
        and not inference
    ):
        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

        LOG.info("patching _expand_mask")
        hijack_expand_mask()

    # ------------------------------------------------------------------
    # from_pretrained keyword arguments
    # ------------------------------------------------------------------
    model_kwargs = {
        "device_map": cfg.device_map,
        "torch_dtype": cfg.torch_dtype,
    }

    if cfg.model_revision:
        model_kwargs["revision"] = cfg.model_revision
    if cfg.gptq:
        if not hasattr(model_config, "quantization_config"):
            LOG.warning("model config does not contain quantization_config information")
        else:
            if cfg.gptq_disable_exllama is not None:
                model_config.quantization_config[
                    "disable_exllama"
                ] = cfg.gptq_disable_exllama
            model_kwargs["quantization_config"] = GPTQConfig(
                **model_config.quantization_config
            )
    if cfg.adapter == "qlora" and cfg.load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=cfg.torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # With sample packing the attention is monkeypatched above instead, so the
    # built-in flash attention 2 switch is only set when packing is off.
    if cfg.flash_attention and not cfg.sample_packing:
        if (
            cfg.is_llama_derived_model
            or cfg.is_falcon_derived_model
            or cfg.is_mistral_derived_model
        ):
            model_kwargs["use_flash_attention_2"] = True

    # 8-bit / 4-bit loading is only requested when an adapter is configured.
    quant_kwargs = {
        "load_in_8bit": cfg.load_in_8bit and cfg.adapter is not None,
        "load_in_4bit": cfg.load_in_4bit and cfg.adapter is not None,
    }

    # ------------------------------------------------------------------
    # model loading
    # ------------------------------------------------------------------
    try:
        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
            from transformers import LlamaForCausalLM

            config_kwargs = {}
            if cfg.rope_scaling:
                config_kwargs["rope_scaling"] = cfg.rope_scaling
            config = LlamaConfig.from_pretrained(
                base_model_config,
                **config_kwargs,
            )
            model = LlamaForCausalLM.from_pretrained(
                base_model,
                config=config,
                **quant_kwargs,
                **model_kwargs,
            )

            if cfg.flash_attention and not inference:
                from axolotl.monkeypatch.llama_attn_hijack_flash import (
                    replace_llama_mlp_with_swiglu,
                    replace_llama_qkv_with_fused,
                )

                if cfg.flash_attn_fuse_mlp:
                    LOG.info("patching with SwiGLU")
                    replace_llama_mlp_with_swiglu(model)

                if cfg.flash_attn_fuse_qkv:
                    LOG.info("patching with fused QKV")
                    replace_llama_qkv_with_fused(model)
        elif model_type == "MixFormerSequentialForCausalLM":
            from axolotl.models.phi import MixFormerSequentialForCausalLM

            model = MixFormerSequentialForCausalLM.from_pretrained(
                base_model,
                **quant_kwargs,
                **model_kwargs,
            )
        elif model_type and not cfg.trust_remote_code:
            if cfg.gptq:
                # GPTQ checkpoints carry their own quantization config, so no
                # load_in_8bit / load_in_4bit flags are passed.
                model = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **model_kwargs,
                )
            else:
                model = getattr(transformers, model_type).from_pretrained(
                    base_model,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **quant_kwargs,
                    **model_kwargs,
                )
        else:
            config = AutoConfig.from_pretrained(
                base_model,
                trust_remote_code=cfg.trust_remote_code or False,
            )

            # Raise the config's context-length field (whichever of the two
            # spellings it uses) when the requested sequence length exceeds it.
            if (
                hasattr(config, "max_seq_len")
                and config.max_seq_len
                and cfg.sequence_len > config.max_seq_len
            ):
                config.max_seq_len = cfg.sequence_len
                LOG.warning(f"increasing context length to {cfg.sequence_len}")
            elif (
                hasattr(config, "max_sequence_length")
                and config.max_sequence_length
                and cfg.sequence_len > config.max_sequence_length
            ):
                config.max_sequence_length = cfg.sequence_len
                LOG.warning(f"increasing context length to {cfg.sequence_len}")
            if cfg.gptq:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    config=config,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **model_kwargs,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    config=config,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **quant_kwargs,
                    **model_kwargs,
                )
    except Exception as err:  # pylint: disable=broad-exception-caught
        # Deliberate best-effort: log the failure and retry once with the
        # generic loader. A failure of the retry propagates to the caller.
        LOG.error(
            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
        )
        LOG.exception(err)
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            trust_remote_code=cfg.trust_remote_code or False,
            **quant_kwargs,
            **model_kwargs,
        )

    # ------------------------------------------------------------------
    # reconcile model with tokenizer / requested sequence length
    # ------------------------------------------------------------------
    embeddings_len = (
        math.ceil(len(tokenizer) / 32) * 32
        if cfg.resize_token_embeddings_to_32x
        else len(tokenizer)
    )
    # embeddings are only ever grown, never shrunk
    if model.get_input_embeddings().num_embeddings < embeddings_len:
        model.resize_token_embeddings(embeddings_len)
    else:
        model.tie_weights()

    if (
        hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings
        and cfg.sequence_len > model.config.max_position_embeddings
    ):
        LOG.warning(
            f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
        )
        model.config.max_position_embeddings = cfg.sequence_len

    # Align the model's BOS/EOS ids with the tokenizer. Ids that are unset or
    # 0 on the model config are left untouched (truthiness check).
    if (
        hasattr(model.config, "bos_token_id")
        and model.config.bos_token_id
        and model.config.bos_token_id != tokenizer.bos_token_id
    ):
        model.config.bos_token_id = tokenizer.bos_token_id

    if (
        hasattr(model.config, "eos_token_id")
        and model.config.eos_token_id
        and model.config.eos_token_id != tokenizer.eos_token_id
    ):
        model.config.eos_token_id = tokenizer.eos_token_id

    if model.device.type == "cuda":
        log_gpu_memory_usage(LOG, "after model load", model.device)

    # ------------------------------------------------------------------
    # dtype handling
    # ------------------------------------------------------------------
    # Upcast normalization modules to fp32; lm_head / embed_tokens are upcast
    # as well, except for btlm models.
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch.float32)
        if model_config.model_type == "btlm":
            continue
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch.float32)

    needs_fa2_dtype = cfg.adapter or cfg.fsdp
    if (cfg.adapter == "lora" and load_in_8bit) or (
        cfg.adapter == "qlora" and cfg.load_in_4bit
    ):
        LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=cfg.gradient_checkpointing
        )
        needs_fa2_dtype = True

    # Cast the modules touched above back to the configured dtype where flash
    # attention needs it (adapter / FSDP runs, or llama with flash attention).
    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(cfg.torch_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module.to(cfg.torch_dtype)

    # ------------------------------------------------------------------
    # adapter + device placement
    # ------------------------------------------------------------------
    model, lora_config = load_adapter(model, cfg, cfg.adapter)

    if cfg.ddp and not load_in_8bit:
        model.to(f"cuda:{cfg.local_rank}")

    # several GPUs visible to a single process: flag the model as
    # parallelizable / model-parallel
    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)

    if not any(
        param.requires_grad for _, param in model.named_parameters(recurse=True)
    ):
        LOG.warning("there are no parameters that require gradient updates")
    model.config.use_cache = False

    if cfg.flash_optimum:
        model = BetterTransformer.transform(model)

    if cfg.adapter is not None:
        log_gpu_memory_usage(LOG, "after adapters", model.device)

    return model, lora_config
|
|
|
|
|
def load_adapter(model, cfg, adapter, inference=False):
    """
    Attach the requested PEFT adapter to `model`.

    `adapter` selects the adapter kind: `None` leaves the model untouched,
    `"lora"` / `"qlora"` delegate to `load_lora`, and `"llama-adapter"`
    delegates to `load_llama_adapter`. Before delegating (and before
    rejecting an unknown kind) `model.enable_input_require_grads()` is
    called when the model provides it.

    Returns a `(model, peft_config)` tuple, with `peft_config` being `None`
    when no adapter was requested.

    Raises `NotImplementedError` for any other adapter name.
    """
    if adapter is None:
        return model, None

    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    if adapter in ("lora", "qlora"):
        return load_lora(model, cfg, inference=inference)

    if adapter != "llama-adapter":
        raise NotImplementedError(f"{adapter} peft adapter not available")

    return load_llama_adapter(model, cfg)
|
|
|
|
|
def load_llama_adapter(model, cfg):
    """
    Wrap `model` as a PEFT adaption-prompt ("llama-adapter") model.

    An `AdaptionPromptConfig` is built from `cfg.peft_adapter.layers` and
    `cfg.peft_adapter.len`. When `cfg.lora_model_dir` is set the adapter is
    loaded from that directory (with `torch_dtype=torch.float16`); otherwise
    a fresh adapter is created from the config.

    Returns `(peft_model, peft_config)`. The returned config is always the
    one built here, even when the adapter weights come from
    `cfg.lora_model_dir`.
    """
    # imported lazily, inside the function, as in the rest of this module
    from peft import AdaptionPromptConfig, PeftModel, get_peft_model

    adapter_settings = cfg.peft_adapter
    peft_config = AdaptionPromptConfig(
        adapter_layers=adapter_settings.layers,
        adapter_len=adapter_settings.len,
        task_type="CAUSAL_LM",
    )

    pretrained_dir = cfg.lora_model_dir
    if not pretrained_dir:
        peft_model = get_peft_model(model, peft_config)
    else:
        LOG.debug("Loading pretained PEFT - llama_adapter")
        peft_model = PeftModel.from_pretrained(
            model,
            pretrained_dir,
            torch_dtype=torch.float16,
        )

    peft_model.print_trainable_parameters()
    return peft_model, peft_config
|
|
|
|
|
def find_all_linear_names(model):
    """
    Collect the leaf names of every linear-like submodule of `model`.

    A submodule counts as linear-like when it is an instance of one of the
    known linear classes (bitsandbytes 4-bit / 8-bit, `torch.nn.Linear`,
    peft `QuantLinear`), or when its class name contains "Linear" and is not
    `LlamaLinearScalingRotaryEmbedding`. Only the last component of the
    dotted module name is kept, and `lm_head` is always excluded.

    Returns the unique names as a list (order is unspecified).
    """
    linear_types = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
    # class names containing "Linear" that are not linear layers
    excluded_class_names = ("LlamaLinearScalingRotaryEmbedding",)

    def _is_linear_like(candidate):
        if isinstance(candidate, linear_types):
            return True
        class_name = candidate.__class__.__name__
        return "Linear" in class_name and class_name not in excluded_class_names

    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, submodule in model.named_modules()
        if _is_linear_like(submodule)
    }
    leaf_names.discard("lm_head")

    return list(leaf_names)
|
|
|
|
|
def load_lora(model, cfg, inference=False):
    """
    Wrap `model` with a LoRA adapter described by `cfg`.

    Target modules come from `cfg.lora_target_modules`; when
    `cfg.lora_target_linear` is set, every linear module name found by
    `find_all_linear_names` is merged in (duplicates removed). If
    `cfg.lora_model_dir` is set the adapter is loaded from that directory,
    trainable unless `inference` is true; otherwise a new adapter is created.

    Returns `(peft_model, lora_config)`. The returned config is always the
    one built here, even when the adapter is loaded from
    `cfg.lora_model_dir`.
    """
    # imported lazily, inside the function, as in the rest of this module
    from peft import LoraConfig, PeftModel, get_peft_model

    target_modules = list(cfg.lora_target_modules or [])
    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(linear_names)}")
        target_modules = list(set(target_modules + linear_names))

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=target_modules,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save or None,
        bias="none",
        task_type="CAUSAL_LM",
    )

    if not cfg.lora_model_dir:
        peft_model = get_peft_model(model, lora_config)
    else:
        LOG.debug("Loading pretained PEFT - LoRA")
        peft_model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            is_trainable=(not inference),
        )

    peft_model.print_trainable_parameters()
    return peft_model, lora_config
|
|