File size: 33,276 Bytes
f4e5d86 6045345 bdbca8f 6045345 eaaeefc 6045345 40a6362 ffd1043 6045345 8c2e05a 03e5907 88e17ff e0b7eea 39a208c 3355706 88e17ff 39a208c 54d2ac1 2bc1a5b 40a6362 5698943 cb9797e e303d64 f8ae59b 125cccb 0f10080 6045345 553a86b 6045345 0f10080 a581e9f 0f10080 08719b9 0f10080 a581e9f 125cccb 732851f 11d1d60 fccb542 68b227a 7fabc4d fccb542 7fabc4d 40a6362 7fabc4d 40a6362 1bc1186 a581e9f 1bc1186 37293dc 6045345 efb3b2c 732851f 2bb0b78 47d601f efb3b2c 47d601f 2bb0b78 efb3b2c 1bc1186 efb3b2c 32e6fe9 71bd062 4c37bd0 fde091c 71bd062 cb9797e 32e6fe9 669f1d0 eb480df 1115c50 25e037f 32e6fe9 25e037f 0f10080 25e037f 1ffa386 732851f 1ffa386 25e037f 1ffa386 0f10080 1ffa386 0f10080 1ffa386 08719b9 1ffa386 e0b7eea fde091c 32e6fe9 e0b7eea 32e6fe9 25e037f 10388a8 f8ae59b 98b4762 f8ae59b 32e6fe9 6045345 125cccb f243c21 125cccb f4e5d86 7181022 f4e5d86 7181022 6b9b229 6045345 6b9b229 2d60ba3 1d70f24 5698943 1d70f24 312a9fa 6045345 1d70f24 895f0a0 2bc1a5b 1d70f24 00568c1 1d70f24 6cb2310 00568c1 1d70f24 6045345 1d70f24 6910e6a b6ab8aa 6910e6a b6ab8aa 2ce5c0d 2bb0b78 eaaeefc 73f1bda eaaeefc e62d590 bdfefaf e62d590 fac2d98 f243c21 e62d590 5ea3aa3 69a2350 3355706 a94f9cb 3355706 faecff9 3355706 1987e5c 41353d2 3b4d055 41353d2 3b4d055 1d70f24 c67fb71 19a600a 7fabc4d 1d70f24 e799e08 7fabc4d 5698943 bcc78d8 7fabc4d f1f60cb bcc78d8 f1f60cb 00568c1 e62d590 6045345 e799e08 563b6d8 919727b 9190ada 1bc1186 3b4d055 9190ada 15d3a65 1d5ab84 56f9ca5 1d5ab84 56f9ca5 1d5ab84 40a6362 d69da99 3355706 1bc1186 3355706 1bc1186 3355706 94f5e41 4ac9e25 136522f 1bc1186 136522f 1bc1186 553a86b 7f09106 1bc1186 7f09106 1bc1186 553a86b 03e5907 1bc1186 03e5907 1bc1186 03e5907 f4e5d86 553a86b 1bc1186 6045345 8c2e05a 1066751 40a6362 3607882 aa3c3f9 c9a149f 40a6362 136522f 5b67ea9 c9a149f 553a86b 5b67ea9 ab5cd28 637ed09 40a6362 637ed09 40a6362 637ed09 fac2d98 e303d64 0b7ba57 78c5b19 e923e62 2d65f47 0b7ba57 e923e62 0b7ba57 98bf76e 3e3229e 54d2ac1 c67fb71 54d2ac1 3e3229e 4cb7900 3355706 4cb7900 3e3229e f319b0b 6045345 f311df9 
78c5b19 3a011ea f311df9 8da1633 f311df9 78c5b19 f311df9 0b7ba57 f243c21 7523d1f 6045345 f243c21 4cb7900 6045345 964d858 cfcc549 ad2b48c 247825b 553a86b 40a6362 ad2b48c 1edc30c 00568c1 1edc30c 7b55fe6 6045345 32e6fe9 6045345 125cccb 6045345 176b888 7b5e762 125cccb 2255bb7 6045345 2255bb7 8bd7a49 8c2e05a 2255bb7 1d5ab84 8608d80 2255bb7 267b7b2 03e5907 ffd1043 440c3ab ffd1043 e799e08 ffd1043 7523d1f 6045345 8c2e05a 6045345 4c90633 9196237 267b7b2 553a86b 9196237 6045345 4cb7900 2255bb7 ffd1043 8487b97 2255bb7 2c73c81 2255bb7 4cb7900 2255bb7 6045345 7523d1f 2255bb7 8608d80 bdfefaf 2255bb7 125cccb bdfefaf 2255bb7 6045345 2255bb7 6045345 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 |
"""Module for models and model loading"""
import logging
import math
import os
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
import bitsandbytes as bnb
import torch
import transformers
from peft import (
LoftQConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
# Shared module-level logger, registered under the "axolotl" name and used by
# the functions defined below.
LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
    """
    Sanity-check the user config against the loaded model config.

    Raises:
        ValueError: if `cfg.gptq` is set but the model config does not carry a
            GPTQ `quantization_config`; if the model config carries a
            `quantization_config` but `cfg.gptq` is not set; or if an adapter is
            trained with new `cfg.tokens` while `cfg.lora_modules_to_save` does
            not list every module returned by `get_linear_embedding_layers`.
    """
    has_quant_config = (
        hasattr(model_config, "quantization_config")
        and model_config.quantization_config
    )
    is_gptq_checkpoint = (
        has_quant_config
        and "quant_method" in model_config.quantization_config
        and model_config.quantization_config["quant_method"] == "gptq"
    )

    if cfg.gptq:
        # the gptq flag only makes sense when pointing at a GPTQ checkpoint
        if not is_gptq_checkpoint:
            raise ValueError(
                "model_config.quantization_config is not set or quant_method is not set to gptq. "
                "Please make sure to point to a GPTQ model."
            )
    elif has_quant_config:
        # a quantized checkpoint without the gptq flag is rejected
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )

    required_modules = get_linear_embedding_layers(model_config.model_type)
    if cfg.adapter and cfg.tokens:
        configured_modules = cfg.lora_modules_to_save
        if not configured_modules or any(
            module not in configured_modules for module in required_modules
        ):
            lora_modules_to_save = ", ".join(f"`{x}`" for x in required_modules)
            raise ValueError(
                f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`."
            )
def load_model_config(cfg):
    """
    Load the model config for `cfg` via `AutoConfig.from_pretrained`.

    The config source is `cfg.base_model_config`, then `cfg.base_model`, then
    `cfg.tokenizer_config`. Entries of `cfg.model_config` are applied as
    attribute overrides, and the result is validated with `check_model_config`.

    If `AutoConfig` raises a ValueError and the config name contains "mamba",
    a minimal `addict.Dict` with `model_type="mamba"` is returned instead
    (without overrides or validation); any other ValueError is re-raised.
    """
    config_source = cfg.base_model_config or cfg.base_model
    if not config_source and cfg.tokenizer_config:
        config_source = cfg.tokenizer_config

    # only an explicit `True` enables remote code for the config
    allow_remote_code = cfg.trust_remote_code is True

    extra_kwargs = {"revision": cfg.model_revision} if cfg.model_revision else {}

    try:
        model_config = AutoConfig.from_pretrained(
            config_source,
            trust_remote_code=allow_remote_code,
            **extra_kwargs,
        )
    except ValueError as err:
        if "mamba" not in config_source:
            raise err
        return addict.Dict({"model_type": "mamba"})

    overrides = cfg.model_config
    if overrides:
        for attr_name, attr_value in overrides.items():
            setattr(model_config, attr_name, attr_value)

    check_model_config(cfg, model_config)

    return model_config
def load_tokenizer(cfg):
    """
    Build and configure the tokenizer described by `cfg`.

    Loads the tokenizer class (`AutoTokenizer` unless `cfg.tokenizer_type` names
    another `transformers` class), applies model-family specific pad/special
    token fixes, registers `cfg.special_tokens` and `cfg.tokens`, and assigns a
    chat template when `cfg.chat_template` is set.

    Raises:
        ValueError: when an adapter is used, a special token is being changed,
            and `cfg.lora_modules_to_save` does not cover the embedding layers.
    """
    model_config = load_model_config(cfg)

    # fast tokenizers are the default unless the config says otherwise
    use_fast = True if cfg.tokenizer_use_fast is None else cfg.tokenizer_use_fast

    tokenizer_kwargs = {}
    if cfg.tokenizer_legacy is not None:
        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy

    tokenizer_cls = (
        getattr(transformers, cfg.tokenizer_type)
        if cfg.tokenizer_type
        else AutoTokenizer
    )

    tokenizer_source = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
    tokenizer = tokenizer_cls.from_pretrained(
        tokenizer_source,
        trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
        **tokenizer_kwargs,
    )
    tokenizer_class_name = tokenizer.__class__.__name__

    llama_tokenizer_names = (
        "LlamaTokenizer",
        "LlamaTokenizerFast",
        "CodeLlamaTokenizer",
        "CodeLlamaTokenizerFast",
    )
    if (
        tokenizer_class_name in llama_tokenizer_names
        and hasattr(tokenizer, "pad_token")
        and not tokenizer.pad_token
    ):
        # set a pad_token, but use eos_token so we don't add a new token
        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

    if tokenizer_class_name == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Mistral's official FA implementation requires left padding
    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        tokenizer.padding_side = "left"

    # Qwen base only has single token, so we need to set the special tokens
    if cfg.is_qwen_derived_model:
        # ids are filled in first, then the token strings
        for id_attr in ("bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"):
            if getattr(tokenizer, id_attr) is None:
                setattr(tokenizer, id_attr, tokenizer.eod_id)

        for token_attr in ("bos_token", "eos_token", "pad_token", "unk_token"):
            if getattr(tokenizer, token_attr) is None:
                setattr(tokenizer, token_attr, "<|endoftext|>")

    additional_special_tokens = None
    if cfg.special_tokens:
        special_tokens = cfg.special_tokens.to_dict()
        additional_special_tokens = special_tokens.pop(
            "additional_special_tokens", None
        )
        embedding_layers = get_linear_embedding_layers(model_config.model_type)
        for token_attr, token_text in special_tokens.items():
            # check if new special token is not already in tokenizer and
            # is adapter training to make sure lora_modules_to_save is set
            current_token = getattr(tokenizer, token_attr)
            token_changes = current_token is None or current_token != token_text
            if (
                token_changes
                and len(tokenizer.encode(token_text, add_special_tokens=False)) > 2
                and cfg.adapter
            ):
                saved_modules = cfg.lora_modules_to_save
                if not saved_modules or any(
                    layer not in saved_modules for layer in embedding_layers
                ):
                    lora_modules_to_save = ", ".join(
                        f"`{x}`" for x in embedding_layers
                    )
                    raise ValueError(
                        f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
                    )

            tokenizer.add_special_tokens(
                {
                    token_attr: AddedToken(
                        token_text, rstrip=False, lstrip=False, normalized=False
                    )
                }
            )

        # If we add bos_token and eos_token, we need to update the post processor to
        # handle them correctly.
        # https://github.com/huggingface/transformers/pull/24132
        # NOTE(review): despite the "or" in the name, this requires BOTH tokens to
        # be configured -- confirm whether `or` was intended.
        bos_or_eos_in_special_tokens = (
            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
        )
        if (
            tokenizer.__class__.__name__
            in ("LlamaTokenizerFast", "CodeLlamaTokenizerFast")
            and bos_or_eos_in_special_tokens
        ):
            tokenizer.update_post_processor()

    if cfg.tokens:
        new_tokens = [
            AddedToken(token, rstrip=False, lstrip=False, normalized=False)
            for token in cfg.tokens
        ]
        tokenizer.add_tokens(new_tokens)

    # Additional special tokens are a List, and need to be treated differently than regular special
    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
    # are new tokens.
    #
    # Usage:
    #
    # ```py
    # special_tokens:
    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
    # ```
    if additional_special_tokens is not None:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens}
        )

    LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
    LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
    LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
    LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

    if not cfg.chat_template:
        LOG.info(
            "No Chat template selected. Consider adding a chat template for easier inference."
        )
        return tokenizer

    chat_template_string = chat_templates(cfg.chat_template)
    if cfg.default_system_message and cfg.chat_template == "chatml":
        # swap the stock chatml system prompt for the configured one
        chat_template_string = chat_template_string.replace(
            "You are a helpful assistant.", cfg.default_system_message
        )
    tokenizer.chat_template = chat_template_string
    return tokenizer
def load_model(
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizerBase,
    inference: bool = False,
    reference_model: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
    """
    Load a model for a given configuration and tokenizer.

    Steps, in order: apply attention/multipack monkeypatches selected by `cfg`,
    assemble the `from_pretrained` keyword arguments (device map, dtype,
    quantization, attention implementation), instantiate the model, resize the
    token embeddings to the tokenizer, align config token ids, adjust module
    dtypes, and finally attach (or only configure) the PEFT adapter.

    Args:
        cfg: the run configuration.
        tokenizer: tokenizer whose length and bos/eos ids the model is aligned to.
        inference: when True, training-only patches (packed flash attention,
            fused MLP/QKV, `_expand_mask`) are skipped.
        reference_model: when True, no adapter is loaded unless
            `cfg.lora_model_dir` is set.

    Returns:
        A `(model, lora_config)` tuple; `lora_config` is None when no adapter
        config was created.

    Raises:
        ValueError: if both `sample_packing` and `s2_attention` are enabled.
        NotImplementedError: if `s2_attention` is requested for a llama-derived
            model without flash attention.
        Exception: anything raised while instantiating the model is logged and
            re-raised unchanged.
    """
    base_model = cfg.base_model
    model_type = cfg.model_type
    model_config = load_model_config(cfg)

    # TODO refactor as a kwarg
    load_in_8bit = cfg.load_in_8bit

    if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
        if cfg.flash_attention:
            from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                replace_btlm_attn_with_flash_attn,
            )

            replace_btlm_attn_with_flash_attn(cfg.base_model)

    if (
        hasattr(model_config, "model_type")
        and model_config.model_type == "stablelm_epoch"
    ):
        if cfg.flash_attention and cfg.sample_packing:
            from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
                replace_stablelm_attn_with_flash_attn,
            )

            replace_stablelm_attn_with_flash_attn(cfg.base_model)

    if cfg.sample_packing and cfg.s2_attention:
        # adjacent literals instead of a backslash continuation inside the string,
        # so source indentation can never leak into the message
        raise ValueError(
            "Received `sample_packing=true` and `s2_attention=true`; however, "
            "shifted-sparse attention does not currently support sample packing."
        )

    if (
        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
        and cfg.flash_attention
        and cfg.sample_packing
    ):
        patch_for_multipack(cfg.model_config_type)
    elif cfg.is_llama_derived_model:
        # Modify all llama derived models in one block

        if cfg.flash_attention:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                replace_llama_attn_with_flash_attn,
            )

            if cfg.sample_packing:
                if cfg.device not in ["mps", "cpu"] and not inference:
                    LOG.info("patching with flash attention for sample packing")
                    replace_llama_attn_with_flash_attn(
                        packed=True,
                        cross_entropy=cfg.flash_attn_cross_entropy,
                        rms_norm=cfg.flash_attn_rms_norm,
                    )
            elif cfg.s2_attention:
                LOG.info("patching w/ flash-enabled, shifted-sparse attention")
                replace_llama_attn_with_flash_attn(
                    packed=False,
                    cross_entropy=cfg.flash_attn_cross_entropy,
                    rms_norm=cfg.flash_attn_rms_norm,
                    use_shifted_sparse_attn=True,
                )
        elif cfg.xformers_attention:
            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
                hijack_llama_attention,
            )

            LOG.info("patching with xformers attention")
            hijack_llama_attention()
        elif cfg.sample_packing:
            from axolotl.monkeypatch.llama_patch_multipack import (
                hijack_llama_prepare_4d_mask,
            )

            LOG.info("patching llama _prepare_4d_causal_attention_mask*")
            hijack_llama_prepare_4d_mask()
        elif cfg.s2_attention:
            raise NotImplementedError(
                "Shifted-sparse attention not currently implemented without flash attention."
            )

    # Modify mistral derived models
    if (
        cfg.model_config_type == "mistral"
        and cfg.flash_attention
        and cfg.sample_packing
    ):
        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            replace_mistral_attn_with_flash_attn,
        )

        LOG.info("patching mistral with flash attention")
        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

    if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

        LOG.info("patching _expand_mask")
        hijack_expand_mask()

    model_kwargs: Dict[str, Any] = {}

    if cfg.model_kwargs:
        model_kwargs.update(cfg.model_kwargs.items())

    max_memory = cfg.max_memory
    device_map = cfg.device_map

    if cfg.gpu_memory_limit:
        # a bare int is interpreted as a number of GiB
        gpu_memory_limit = (
            str(cfg.gpu_memory_limit) + "GiB"
            if isinstance(cfg.gpu_memory_limit, int)
            else cfg.gpu_memory_limit
        )

        max_memory = {
            i: gpu_memory_limit for i in range(torch.cuda.device_count())
        }
        max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything

    if max_memory is not None:
        # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
        from accelerate import infer_auto_device_map, init_empty_weights

        with init_empty_weights():
            model_canvas = AutoModelForCausalLM.from_config(model_config)
        model_canvas.tie_weights()
        device_map = infer_auto_device_map(
            model_canvas,
            max_memory=max_memory,
            dtype=cfg.torch_dtype,
        )
        # We can discard max_memory now as we have a device map set up for us
        max_memory = None

    model_kwargs["device_map"] = device_map
    model_kwargs["torch_dtype"] = cfg.torch_dtype

    if torch.backends.mps.is_available():
        model_kwargs["device_map"] = "mps:0"

    # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
    # if cfg.rl:
    #     if torch.cuda.device_count() > 1:
    #         if reference_model:
    #             model_kwargs["device_map"] = "cuda:" + str(
    #                 torch.cuda.current_device() + 1
    #             )
    #         else:
    #             model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())

    if is_deepspeed_zero3_enabled():
        del model_kwargs["device_map"]

    if cfg.model_revision:
        model_kwargs["revision"] = cfg.model_revision
    if cfg.gptq:
        if not hasattr(model_config, "quantization_config"):
            LOG.warning("model config does not contain quantization_config information")
        else:
            if cfg.gptq_disable_exllama is not None:
                model_config.quantization_config[
                    "disable_exllama"
                ] = cfg.gptq_disable_exllama
            model_kwargs["quantization_config"] = GPTQConfig(
                **model_config.quantization_config
            )
    if cfg.adapter == "qlora" and cfg.load_in_4bit:
        bnb_config = {
            "load_in_4bit": True,
            "llm_int8_threshold": 6.0,
            "llm_int8_has_fp16_weight": False,
            "bnb_4bit_compute_dtype": cfg.torch_dtype,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
        }

        # user-supplied overrides win over the defaults above
        if cfg.bnb_config_kwargs:
            bnb_config.update(cfg.bnb_config_kwargs)

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )

    if cfg.load_in_8bit and cfg.adapter is not None:
        model_kwargs["load_in_8bit"] = True
    if cfg.load_in_4bit and cfg.adapter is not None:
        model_kwargs["load_in_4bit"] = True

    # no longer needed per https://github.com/huggingface/transformers/pull/26610
    if "quantization_config" in model_kwargs or cfg.gptq:
        model_kwargs.pop("load_in_8bit", None)
        model_kwargs.pop("load_in_4bit", None)

    # sample packing uses custom FA2 patch
    attn_implementation = None
    if cfg.flash_attention:
        if (
            not cfg.sample_packing
            or model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES
        ):
            # most other models support flash attention, we can define exceptions as they come up
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "eager"
    elif cfg.sdp_attention:
        attn_implementation = "sdpa"
    elif cfg.eager_attention:
        attn_implementation = "eager"

    if attn_implementation is not None:
        model_kwargs["attn_implementation"] = attn_implementation
        model_config._attn_implementation = (  # pylint: disable=protected-access
            attn_implementation
        )

    try:
        if (
            model_config.model_type == "llama"
            and not cfg.trust_remote_code
            and not cfg.gptq
        ):
            from transformers import LlamaForCausalLM

            model = LlamaForCausalLM.from_pretrained(
                base_model,
                config=model_config,
                **model_kwargs,
            )

            if cfg.flash_attention and not inference:
                from axolotl.monkeypatch.llama_attn_hijack_flash import (
                    replace_llama_mlp_with_swiglu,
                    replace_llama_qkv_with_fused,
                )

                if cfg.flash_attn_fuse_mlp:
                    LOG.info("patching with SwiGLU")
                    replace_llama_mlp_with_swiglu(model)

                if cfg.flash_attn_fuse_qkv:
                    LOG.info("patching with fused QKV")
                    replace_llama_qkv_with_fused(model)
        # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
        #     This is a WIP, still an issue with the backward pass
        #     RuntimeError: grad can be implicitly created only for scalar outputs
        #     TODO: try config.sequence_parallel = False
        #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
        #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
        #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
        #     from flash_attn.utils.pretrained import state_dict_from_pretrained
        #     from flash_attn.models.gpt import GPTLMHeadModel
        #     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
        #     from transformers import GPTNeoXConfig
        #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
        #     config.use_flash_attn = True
        #     config.fused_bias_fc = True
        #     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
        #     config.activation_function = "gelu_fast"
        #     config.fused_dropout_add_ln = True
        #     # config.residual_in_fp32 = True
        #
        #     model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
        #         base_model,
        #         config,
        #         dtype=torch_dtype,
        #         device=cfg.device,
        #     )
        #     model.train() # sets to train instead of eval mode
        elif model_type == "MambaLMHeadModel":
            # FIXME this is janky at best and hacked together to make it work
            MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name

            # the mamba loader takes `dtype`/`device` instead of the HF kwargs
            model_kwargs["dtype"] = model_kwargs.pop("torch_dtype")
            model_kwargs["device"] = torch.cuda.current_device()
            # `device_map` is already gone when DeepSpeed ZeRO-3 is enabled, so a
            # plain `del` here would raise KeyError
            model_kwargs.pop("device_map", None)

            model = MambaLMHeadModel.from_pretrained(
                base_model,
                **model_kwargs,
            )
        elif model_type and not cfg.trust_remote_code:
            # GPTQ checkpoints go through the auto class; otherwise use the
            # transformers class named by `cfg.model_type`
            model_cls = (
                AutoModelForCausalLM
                if cfg.gptq
                else getattr(transformers, model_type)
            )
            model = model_cls.from_pretrained(
                base_model,
                config=model_config,
                trust_remote_code=cfg.trust_remote_code or False,
                **model_kwargs,
            )
        else:
            # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
            # when training starts
            if (
                hasattr(model_config, "max_seq_len")
                and model_config.max_seq_len
                and cfg.sequence_len > model_config.max_seq_len
            ):
                model_config.max_seq_len = cfg.sequence_len
                LOG.warning(f"increasing context length to {cfg.sequence_len}")
            elif (
                hasattr(model_config, "max_sequence_length")
                and model_config.max_sequence_length
                and cfg.sequence_len > model_config.max_sequence_length
            ):
                model_config.max_sequence_length = cfg.sequence_len
                LOG.warning(f"increasing context length to {cfg.sequence_len}")
            # the gptq and non-gptq paths issued the exact same call here, so a
            # single call covers both
            model = AutoModelForCausalLM.from_pretrained(
                base_model,
                config=model_config,
                trust_remote_code=cfg.trust_remote_code or False,
                **model_kwargs,
            )
    except Exception as err:  # pylint: disable=broad-exception-caught
        LOG.exception(err)
        raise

    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
        model = model.merge_and_unload()

    # optionally round the vocabulary size up to the next multiple of 32
    embeddings_len = (
        math.ceil(len(tokenizer) / 32) * 32
        if cfg.resize_token_embeddings_to_32x
        else len(tokenizer)
    )
    if (
        hasattr(model, "get_input_embeddings")
        and model.get_input_embeddings().num_embeddings < embeddings_len
    ):
        model.resize_token_embeddings(embeddings_len)
    else:
        model.tie_weights()

    if (
        hasattr(model, "config")
        and hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings
        and cfg.sequence_len > model.config.max_position_embeddings
    ):
        LOG.warning(
            f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
        )
        model.config.max_position_embeddings = cfg.sequence_len

    # NOTE(review): the truthiness checks below skip the sync when the model's
    # bos/eos id is 0 -- confirm this is intended.
    if (
        hasattr(model, "config")
        and hasattr(model.config, "bos_token_id")
        and model.config.bos_token_id
        and model.config.bos_token_id != tokenizer.bos_token_id
    ):
        model.config.bos_token_id = tokenizer.bos_token_id

    if (
        hasattr(model, "config")
        and hasattr(model.config, "eos_token_id")
        and model.config.eos_token_id
        and model.config.eos_token_id != tokenizer.eos_token_id
    ):
        model.config.eos_token_id = tokenizer.eos_token_id

    if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
        log_gpu_memory_usage(LOG, "after model load", model.device)

    # make sure these are fp32 per Ramesh et al. (2021)
    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
    if not cfg.fsdp:
        # FSDP doesn't like mixed Float and BFloat16
        for name, module in model.named_modules():
            if "norm" in name or name.endswith(".gate"):
                module.to(torch.float32)
            if model_config.model_type == "btlm":
                # don't upcast lm_head for btlm
                continue
            if any(m in name for m in embedding_modules):
                if hasattr(module, "weight"):
                    module.to(torch.float32)

    needs_fa2_dtype = cfg.adapter or cfg.fsdp
    skip_prepare_model_for_kbit_training = False

    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
        from deepspeed.utils import (  # pylint: disable=no-name-in-module
            set_z3_leaf_modules,
        )
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
        # Qwen doesn't play nicely with LoRA if this is enabled
        skip_prepare_model_for_kbit_training = True

    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if cfg.adapter == "lora" and loftq_bits:
        skip_prepare_model_for_kbit_training = True

    if cfg.adapter in ["lora", "qlora"]:
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
        if (
            cfg.load_in_8bit or cfg.load_in_4bit
        ) and not skip_prepare_model_for_kbit_training:
            LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=cfg.gradient_checkpointing
            )
        needs_fa2_dtype = True

    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
    if needs_fa2_dtype or cfg.flash_attention:
        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(cfg.torch_dtype)
            if any(m in name for m in embedding_modules):
                if hasattr(module, "weight"):
                    module.to(cfg.torch_dtype)

    lora_config = None
    if not reference_model or cfg.lora_model_dir:
        # if we're not loading the reference model, then we're loading the model for training
        # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
            _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
        else:
            # NOTE(review): `inference` is not forwarded here, so the adapter is
            # always loaded with load_adapter's default -- confirm intended.
            model, lora_config = load_adapter(model, cfg, cfg.adapter)

    if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
        # TODO revaldate this conditional
        model.to(f"cuda:{cfg.local_rank}")

    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
        # multiple visible GPUs in a single process: flag the model as model-parallel
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)

    if not any(
        param.requires_grad for _, param in model.named_parameters(recurse=True)
    ):
        LOG.warning("there are no parameters that require gradient updates")
    if hasattr(model, "config"):
        model.config.use_cache = False

    if cfg.flash_optimum:
        from optimum.bettertransformer import BetterTransformer

        model = BetterTransformer.transform(model)

    if cfg.adapter is not None:
        log_gpu_memory_usage(LOG, "after adapters", model.device)

    # TODO resume_from_checkpoint handling
    return model, lora_config
def load_adapter(model, cfg, adapter, inference=False):
    # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Wrap `model` with the PEFT adapter named by `adapter`.

    Args:
        model: the base model to wrap.
        cfg: the run configuration, passed through to the adapter loader.
        adapter: "lora", "qlora", "llama-adapter", or None for no adapter.
        inference: forwarded to `load_lora` for the "lora"/"qlora" adapters.

    Returns:
        `(model, peft_config)`; `(model, None)` unchanged when `adapter` is None.

    Raises:
        NotImplementedError: if `adapter` is not one of the supported names.
    """
    if adapter is None:
        return model, None

    # reject unknown adapters up front so the model is left untouched on error
    if adapter not in ("lora", "qlora", "llama-adapter"):
        raise NotImplementedError(f"{adapter} peft adapter not available")

    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)
    return load_lora(model, cfg, inference=inference)
def load_llama_adapter(model, cfg):
    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Attach a llama-adapter (adaption prompt) to `model`.

    When `cfg.lora_model_dir` is set the adapter is loaded from that directory;
    otherwise a new one is created from an `AdaptionPromptConfig` built with
    `cfg.peft_adapter.layers` and `cfg.peft_adapter.len`.

    Returns the wrapped model together with the `AdaptionPromptConfig`.
    """
    from peft import AdaptionPromptConfig, get_peft_model

    adaption_config = AdaptionPromptConfig(
        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
        task_type="CAUSAL_LM",
    )

    if not cfg.lora_model_dir:
        model = get_peft_model(model, adaption_config)
    else:
        LOG.debug("Loading pretrained PEFT - llama_adapter")
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            torch_dtype=torch.float16,
        )

    model.print_trainable_parameters()

    return model, adaption_config
def find_all_linear_names(model):
    """
    Collect the leaf names of the linear-like modules in `model`.

    A module qualifies if it is an instance of one of the known linear classes,
    or if its class name contains "Linear" (excluding
    `LlamaLinearScalingRotaryEmbedding`). The second entry returned by
    `get_linear_embedding_layers` for the model type is dropped from the result.
    """
    linear_classes = (
        bnb.nn.Linear4bit,
        bnb.nn.Linear8bitLt,
        torch.nn.Linear,
        QuantLinear,
    )
    excluded_class_names = ("LlamaLinearScalingRotaryEmbedding",)

    leaf_names = set()
    for qualified_name, module in model.named_modules():
        class_name = module.__class__.__name__
        looks_linear = (
            "Linear" in class_name and class_name not in excluded_class_names
        )
        if isinstance(module, linear_classes) or looks_linear:
            # keep only the last component of the dotted module path
            leaf_names.add(qualified_name.split(".")[-1])

    output_embedding = get_linear_embedding_layers(model.config.model_type)[1]
    leaf_names.discard(output_embedding)  # needed for 16-bit

    return list(leaf_names)
def load_lora(model, cfg, inference=False, config_only=False):
    # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
    """
    Build a `LoraConfig` from `cfg` and, unless `config_only`, apply it to `model`.

    Args:
        model: the base model to wrap.
        cfg: the run configuration holding the `lora_*` / `peft` settings.
        inference: when True, a pretrained adapter is loaded as non-trainable.
        config_only: when True, only the config is built and `(None, config)`
            is returned.

    Returns:
        `(model, lora_config)`, where `model` is the PEFT-wrapped model (loaded
        from `cfg.lora_model_dir` when set, freshly created otherwise).
    """
    from peft import LoraConfig, get_peft_model

    target_modules = list(cfg.lora_target_modules or [])

    if cfg.lora_target_linear:
        # merge the explicitly configured targets with every linear layer found
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(linear_names)}")
        target_modules = list(set(target_modules + linear_names))

    extra_config_kwargs = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        extra_config_kwargs.update(
            loftq_config=LoftQConfig(loftq_bits=loftq_bits),
            init_lora_weights="loftq",
        )

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=target_modules,
        layers_to_transform=cfg.peft_layers_to_transform,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save or None,
        bias="none",
        task_type="CAUSAL_LM",
        **extra_config_kwargs,
    )

    if config_only:
        return None, lora_config

    if not cfg.lora_model_dir:
        model = get_peft_model(model, lora_config)
    else:
        LOG.debug("Loading pretrained PEFT - LoRA")
        pretrained_kwargs: Any = (
            {"max_memory": {"cpu": "256GiB"}, "device_map": {"": "cpu"}}
            if cfg.lora_on_cpu
            else {}
        )
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            is_trainable=(not inference),
            **pretrained_kwargs,
        )

    model.print_trainable_parameters()

    return model, lora_config
|