Nanobit commited on
Commit
f4e5d86
·
1 Parent(s): daf47cc

Lint models.py

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +34 -30
src/axolotl/utils/models.py CHANGED
@@ -1,13 +1,16 @@
 
 
 
1
  import logging
2
  import math
3
  import os
4
  from pathlib import Path
5
- from typing import Optional, Tuple, TYPE_CHECKING
6
 
7
  import bitsandbytes as bnb
8
  import torch
9
  import transformers
10
- from transformers import (
11
  AutoModelForCausalLM,
12
  AutoTokenizer,
13
  PreTrainedModel,
@@ -18,9 +21,8 @@ from transformers import (
18
  try:
19
  from transformers import (
20
  LlamaForCausalLM,
21
- LlamaTokenizer,
22
  )
23
- except:
24
  logging.warning(
25
  "This version of transformers does not support Llama. Consider upgrading."
26
  )
@@ -28,9 +30,9 @@ except:
28
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
29
 
30
  if TYPE_CHECKING:
31
- from peft import PeftModel, PeftConfig
32
- from axolotl.utils.dict import DictDefault
33
- from transformers import PreTrainedTokenizer
34
 
35
 
36
  def load_tokenizer(
@@ -62,8 +64,8 @@ def load_tokenizer(
62
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
63
 
64
  if cfg.special_tokens:
65
- for k, v in cfg.special_tokens.items():
66
- tokenizer.add_special_tokens({k: v})
67
  if cfg.tokens:
68
  tokenizer.add_tokens(list(cfg.tokens))
69
 
@@ -80,6 +82,9 @@ def load_model(
80
  inference=False,
81
  ):
82
  # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
 
 
 
83
 
84
  # TODO refactor as a kwarg
85
  load_in_8bit = cfg.load_in_8bit
@@ -115,9 +120,9 @@ def load_model(
115
 
116
  replace_peft_model_with_int4_lora_model()
117
  from peft import prepare_model_for_int8_training
118
- except Exception as e:
119
- logging.exception(e)
120
- raise e
121
 
122
  model_kwargs = {}
123
  if cfg.adapter == "qlora" and cfg.load_in_4bit:
@@ -155,7 +160,7 @@ def load_model(
155
  "unable to find a cached model file, this will likely fail..."
156
  )
157
  model_path = str(cache_model_path)
158
- except:
159
  model_path = cfg.base_model
160
  model, _ = load_llama_model_4bit_low_ram(
161
  base_model_config if base_model_config else base_model,
@@ -210,13 +215,13 @@ def load_model(
210
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
211
  torch_dtype=torch_dtype,
212
  device_map=cfg.device_map,
213
- trust_remote_code=True if cfg.trust_remote_code is True else False,
214
  **model_kwargs,
215
  )
216
  else:
217
  config = AutoConfig.from_pretrained(
218
  base_model,
219
- trust_remote_code=True if cfg.trust_remote_code is True else False,
220
  )
221
  model = AutoModelForCausalLM.from_pretrained(
222
  base_model,
@@ -225,30 +230,29 @@ def load_model(
225
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
226
  torch_dtype=torch_dtype,
227
  device_map=cfg.device_map,
228
- trust_remote_code=True if cfg.trust_remote_code is True else False,
229
  **model_kwargs,
230
  )
231
- except Exception as e:
232
  logging.error(
233
  "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
234
  )
235
- logging.exception(e)
236
  model = AutoModelForCausalLM.from_pretrained(
237
  base_model,
238
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
239
  torch_dtype=torch_dtype,
240
  device_map=cfg.device_map,
241
- trust_remote_code=True if cfg.trust_remote_code is True else False,
242
  **model_kwargs,
243
  )
244
 
245
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
246
  model.resize_token_embeddings(embeddings_len)
247
 
248
- if (
249
- ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
250
- and not cfg.gptq
251
- and (load_in_8bit or cfg.load_in_4bit)
252
  ):
253
  logging.info("converting PEFT model w/ prepare_model_for_int8_training")
254
  model = prepare_model_for_int8_training(model)
@@ -261,14 +265,14 @@ def load_model(
261
  if cfg.gptq:
262
  # Scales to half
263
  logging.info("Fitting 4bit scales and zeros to half")
264
- for n, m in model.named_modules():
265
- if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
266
- type(m)
267
  ):
268
- if hasattr(m, "is_v1_model") and m.is_v1_model:
269
- m.zeros = m.zeros.half()
270
- m.scales = m.scales.half()
271
- m.bias = m.bias.half()
272
 
273
  if (
274
  torch.cuda.device_count() > 1
 
1
+ """Module for models and model loading"""
2
+
3
+
4
  import logging
5
  import math
6
  import os
7
  from pathlib import Path
8
+ from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401
9
 
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
+ from transformers import ( # noqa: F401
14
  AutoModelForCausalLM,
15
  AutoTokenizer,
16
  PreTrainedModel,
 
21
  try:
22
  from transformers import (
23
  LlamaForCausalLM,
 
24
  )
25
+ except ImportError:
26
  logging.warning(
27
  "This version of transformers does not support Llama. Consider upgrading."
28
  )
 
30
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
31
 
32
  if TYPE_CHECKING:
33
+ from peft import PeftConfig # noqa: F401
34
+ from axolotl.utils.dict import DictDefault # noqa: F401
35
+ from transformers import PreTrainedTokenizer # noqa: F401
36
 
37
 
38
  def load_tokenizer(
 
64
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
65
 
66
  if cfg.special_tokens:
67
+ for k, val in cfg.special_tokens.items():
68
+ tokenizer.add_special_tokens({k: val})
69
  if cfg.tokens:
70
  tokenizer.add_tokens(list(cfg.tokens))
71
 
 
82
  inference=False,
83
  ):
84
  # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
85
+ """
86
+ Load a model from a base model and a model type.
87
+ """
88
 
89
  # TODO refactor as a kwarg
90
  load_in_8bit = cfg.load_in_8bit
 
120
 
121
  replace_peft_model_with_int4_lora_model()
122
  from peft import prepare_model_for_int8_training
123
+ except Exception as err:
124
+ logging.exception(err)
125
+ raise err
126
 
127
  model_kwargs = {}
128
  if cfg.adapter == "qlora" and cfg.load_in_4bit:
 
160
  "unable to find a cached model file, this will likely fail..."
161
  )
162
  model_path = str(cache_model_path)
163
+ except Exception: # pylint: disable=broad-exception-caught
164
  model_path = cfg.base_model
165
  model, _ = load_llama_model_4bit_low_ram(
166
  base_model_config if base_model_config else base_model,
 
215
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
216
  torch_dtype=torch_dtype,
217
  device_map=cfg.device_map,
218
+ trust_remote_code=cfg.trust_remote_code or False,
219
  **model_kwargs,
220
  )
221
  else:
222
  config = AutoConfig.from_pretrained(
223
  base_model,
224
+ trust_remote_code=cfg.trust_remote_code or False,
225
  )
226
  model = AutoModelForCausalLM.from_pretrained(
227
  base_model,
 
230
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
231
  torch_dtype=torch_dtype,
232
  device_map=cfg.device_map,
233
+ trust_remote_code=cfg.trust_remote_code or False,
234
  **model_kwargs,
235
  )
236
+ except Exception as err: # pylint: disable=broad-exception-caught
237
  logging.error(
238
  "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
239
  )
240
+ logging.exception(err)
241
  model = AutoModelForCausalLM.from_pretrained(
242
  base_model,
243
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
244
  torch_dtype=torch_dtype,
245
  device_map=cfg.device_map,
246
+ trust_remote_code=cfg.trust_remote_code or False,
247
  **model_kwargs,
248
  )
249
 
250
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
251
  model.resize_token_embeddings(embeddings_len)
252
 
253
+ if not cfg.gptq and (
254
+ (cfg.adapter == "lora" and load_in_8bit)
255
+ or (cfg.adapter == "qlora" and cfg.load_in_4bit)
 
256
  ):
257
  logging.info("converting PEFT model w/ prepare_model_for_int8_training")
258
  model = prepare_model_for_int8_training(model)
 
265
  if cfg.gptq:
266
  # Scales to half
267
  logging.info("Fitting 4bit scales and zeros to half")
268
+ for _, module in model.named_modules():
269
+ if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
270
+ type(module)
271
  ):
272
+ if hasattr(module, "is_v1_model") and module.is_v1_model:
273
+ module.zeros = module.zeros.half()
274
+ module.scales = module.scales.half()
275
+ module.bias = module.bias.half()
276
 
277
  if (
278
  torch.cuda.device_count() > 1