|  | """Module for models and model loading""" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import logging | 
					
						
						|  | import math | 
					
						
						|  | import os | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from typing import TYPE_CHECKING, Optional, Tuple | 
					
						
						|  |  | 
					
						
						|  | import bitsandbytes as bnb | 
					
						
						|  | import torch | 
					
						
						|  | import transformers | 
					
						
						|  | from transformers import PreTrainedModel | 
					
						
						|  | from transformers import ( | 
					
						
						|  | AutoConfig, | 
					
						
						|  | AutoModelForCausalLM, | 
					
						
						|  | AutoTokenizer, | 
					
						
						|  | BitsAndBytesConfig, | 
					
						
						|  | LlamaConfig, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | from transformers import LlamaForCausalLM | 
					
						
						|  | except ImportError: | 
					
						
						|  | logging.warning( | 
					
						
						|  | "This version of transformers does not support Llama. Consider upgrading." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN | 
					
						
						|  |  | 
					
						
						|  | if TYPE_CHECKING: | 
					
						
						|  | from peft import PeftConfig | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.dict import DictDefault | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						


def load_tokenizer(
    tokenizer_config,
    tokenizer_type,
    cfg,
):
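    """
    Load and configure a tokenizer.

    If `tokenizer_type` names a tokenizer class in `transformers`, that class
    is used; otherwise the class is resolved with `AutoTokenizer`. A pad token
    is set for Llama and GPT-NeoX tokenizers, and any `cfg.special_tokens` /
    `cfg.tokens` entries are added before the tokenizer is returned.
    """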
					
						
    if tokenizer_type:
        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
            tokenizer_config,
            trust_remote_code=cfg.trust_remote_code or False,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_config,
            trust_remote_code=cfg.trust_remote_code or False,
        )

    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

    if tokenizer.__class__.__name__ in [
        "LlamaTokenizer",
        "LlamaTokenizerFast",
    ]:
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if cfg.special_tokens:
        for k, val in cfg.special_tokens.items():
            tokenizer.add_special_tokens({k: val})
    if cfg.tokens:
        tokenizer.add_tokens(list(cfg.tokens))

    return tokenizer


def load_model(
    base_model,
    base_model_config,
    model_type,
    tokenizer,
    cfg,
    adapter="lora",
    inference=False,
):
    """
    Load a model from a base model and a model type.
    """
					
						
    load_in_8bit = cfg.load_in_8bit
    is_llama_derived_model = "llama" in base_model or (
        cfg.model_type and "llama" in cfg.model_type.lower()
    )

    if is_llama_derived_model and cfg.flash_attention:
        if cfg.device not in ["mps", "cpu"] and inference is False:
            from axolotl.flash_attn import replace_llama_attn_with_flash_attn

            logging.info("patching with flash attention")
            replace_llama_attn_with_flash_attn()
    elif is_llama_derived_model and cfg.xformers_attention:
        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        logging.info("patching with xformers attention")
        hijack_llama_attention()

    if cfg.bf16:
        torch_dtype = torch.bfloat16
    elif cfg.load_in_8bit or cfg.fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    try:
        if cfg.gptq:
            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
                replace_peft_model_with_int4_lora_model,
            )

            replace_peft_model_with_int4_lora_model()
        from peft import prepare_model_for_int8_training
    except Exception as err:
        logging.exception(err)
        raise err

    model_kwargs = {}
    if cfg.adapter == "qlora" and cfg.load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    try:
        if cfg.gptq and is_llama_derived_model:
            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
            from huggingface_hub import snapshot_download

            try:
                snapshot_download_kwargs = {}
                if cfg.base_model_ignore_patterns:
                    snapshot_download_kwargs[
                        "ignore_patterns"
                    ] = cfg.base_model_ignore_patterns
                cache_model_path = Path(
                    snapshot_download(base_model, **snapshot_download_kwargs)
                )
                files = (
                    list(cache_model_path.glob("*.pt"))
                    + list(cache_model_path.glob("*.safetensors"))
                    + list(cache_model_path.glob("*.bin"))
                )
                if len(files) > 0:
                    model_path = str(files[0])
                else:
                    logging.warning(
                        "unable to find a cached model file, this will likely fail..."
                    )
                    model_path = str(cache_model_path)
            except Exception:
                model_path = cfg.base_model
            model, _ = load_llama_model_4bit_low_ram(
                base_model_config if base_model_config else base_model,
                model_path,
                device_map=cfg.device_map,
                half=cfg.fp16,
                groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
                is_v1_model=cfg.gptq_model_v1
                if cfg.gptq_model_v1 is not None
                else True,
            )
            load_in_8bit = False
        elif is_llama_derived_model and "LlamaForCausalLM" in globals():
            config = LlamaConfig.from_pretrained(base_model_config)
            model = LlamaForCausalLM.from_pretrained(
                base_model,
                config=config,
                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                torch_dtype=torch_dtype,
                device_map="auto" if cfg.world_size == 1 else cfg.device_map,
                **model_kwargs,
            )
        elif model_type:
            model = getattr(transformers, model_type).from_pretrained(
                base_model,
                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                torch_dtype=torch_dtype,
                device_map=cfg.device_map,
                trust_remote_code=cfg.trust_remote_code or False,
                **model_kwargs,
            )
        else:
            config = AutoConfig.from_pretrained(
                base_model,
                trust_remote_code=cfg.trust_remote_code or False,
            )
            model = AutoModelForCausalLM.from_pretrained(
                base_model,
                config=config,
                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                torch_dtype=torch_dtype,
                device_map=cfg.device_map,
                trust_remote_code=cfg.trust_remote_code or False,
                **model_kwargs,
            )
    except Exception as err:
        logging.error(
            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
        )
        logging.exception(err)
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
        )

    embeddings_len = math.ceil(len(tokenizer) / 32) * 32
    model.resize_token_embeddings(embeddings_len)

    if not cfg.gptq and (
        (cfg.adapter == "lora" and load_in_8bit)
        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
    ):
        logging.info("converting PEFT model w/ prepare_model_for_int8_training")
        model = prepare_model_for_int8_training(model)

    model, lora_config = load_adapter(model, cfg, adapter)

    if cfg.ddp and not load_in_8bit:
        model.to(f"cuda:{cfg.local_rank}")

    if cfg.gptq:
        logging.info("Fitting 4bit scales and zeros to half")
        for _, module in model.named_modules():
            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
                type(module)
            ):
                if hasattr(module, "is_v1_model") and module.is_v1_model:
                    module.zeros = module.zeros.half()
                module.scales = module.scales.half()
                module.bias = module.bias.half()

    if (
        torch.cuda.device_count() > 1
        and int(os.getenv("WORLD_SIZE", "1")) > 1
        and (cfg.gptq or cfg.load_in_4bit)
    ):
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)

    requires_grad = []
    for name, param in model.named_parameters(recurse=True):
        if param.requires_grad:
            requires_grad.append(f"{name}: {param.requires_grad}")
    if len(requires_grad) == 0:
        logging.warning("there are no parameters that require gradient updates")
    model.config.use_cache = False

    return model, lora_config


def load_adapter(model, cfg, adapter):
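    """
    Wrap `model` with the PEFT adapter named by `adapter` ("lora", "qlora",
    or "llama-adapter") and return `(model, peft_config)`, or `(model, None)`
    when no adapter is requested.
    """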
					
						
    if adapter is None:
        return model, None
    if adapter in ["lora", "qlora"]:
        return load_lora(model, cfg)
    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)

    raise NotImplementedError(f"{adapter} peft adapter not available")


def load_llama_adapter(model, cfg):
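    """
    Attach a PEFT adaption-prompt ("llama-adapter") config to `model`, or load
    pretrained adapter weights from `cfg.lora_model_dir` when it is set.
    Returns `(model, peft_config)`.
    """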
					
						
    from peft import AdaptionPromptConfig, PeftModel, get_peft_model

    peft_config = AdaptionPromptConfig(
        adapter_layers=cfg.peft_adapter.layers,
        adapter_len=cfg.peft_adapter.len,
        task_type="CAUSAL_LM",
    )

    if cfg.lora_model_dir:
        logging.info("Loading pretrained LoRA")
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            device_map=cfg.device_map,
            torch_dtype=torch.float16,
        )
    else:
        model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()

    return model, peft_config


def find_all_linear_names(bits, model):
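    """
    Return the names of all linear modules in `model` (bitsandbytes 4-bit or
    8-bit linear layers when `bits` is 4 or 8, plain `torch.nn.Linear`
    otherwise), excluding `lm_head`, for use as LoRA target modules.
    """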
					
						
    cls = (
        bnb.nn.Linear4bit
        if bits == 4
        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
    )
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:
        lora_module_names.remove("lm_head")

    return list(lora_module_names)


def load_lora(model, cfg):
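    """
    Build a `LoraConfig` from `cfg` and either load pretrained LoRA weights
    from `cfg.lora_model_dir` or wrap `model` with a fresh PEFT LoRA adapter.
    Returns `(model, lora_config)`.
    """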
					
						
    from peft import LoraConfig, PeftModel, get_peft_model

    lora_target_modules = list(cfg.lora_target_modules or [])

    if cfg.lora_target_linear:
        bits = None
        if cfg.load_in_4bit:
            bits = 4
        elif cfg.load_in_8bit:
            bits = 8

        linear_names = find_all_linear_names(bits, model)
        logging.info(f"found linear modules: {repr(linear_names)}")
        lora_target_modules = list(set(lora_target_modules + linear_names))

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        bias="none",
        task_type="CAUSAL_LM",
    )

    if cfg.lora_model_dir:
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            device_map=cfg.device_map,
        )
    else:
        model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()

    return model, lora_config
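

# Example wiring (a minimal sketch, not executed by this module; the model
# name, tokenizer type, and cfg values shown are illustrative assumptions,
# not defined in this file):
#
#   tokenizer = load_tokenizer("huggyllama/llama-7b", "LlamaTokenizer", cfg)
#   model, lora_config = load_model(
#       "huggyllama/llama-7b",   # base_model
#       "huggyllama/llama-7b",   # base_model_config
#       None,                    # model_type
#       tokenizer,
#       cfg,
#       adapter="lora",
#   )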