winglian committed
Commit 45f77dd · 1 Parent(s): 949a27b

better handling of llama model import

Files changed (1):
  scripts/finetune.py  +19 -9
scripts/finetune.py CHANGED

@@ -19,7 +19,7 @@ from peft import (
     get_peft_model_state_dict, PeftModel,
 )
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
 
 # add src to the pythonpath so we don't need to pip install this
 from transformers.trainer_pt_utils import get_parameter_names
@@ -53,16 +53,23 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn
-
         replace_llama_attn_with_flash_attn()
 
     try:
-        model = getattr(transformers, model_type).from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
-            device_map=cfg.device_map,
-        )
+        if "llama" in base_model:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = getattr(transformers, model_type).from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                device_map=cfg.device_map,
+            )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
@@ -72,7 +79,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         )
 
     try:
-        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
+        if "llama" in base_model:
+            tokenizer = LlamaTokenizer.from_pretrained(model)
+        else:
+            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
     except:
         tokenizer = AutoTokenizer.from_pretrained(base_model)
 
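For reference, below is a minimal sketch of the loading pattern the new code follows: use the explicit LlamaForCausalLM / LlamaTokenizer classes when the checkpoint name contains "llama", otherwise resolve the class named in the config from the transformers module, and fall back to the Auto* classes if that fails (presumably so the script keeps working on transformers versions that do not expose the Llama classes). This is a sketch, not the full scripts/finetune.py: the function name load_model_sketch is hypothetical, cfg is assumed to be an attribute-style config with load_in_8bit and device_map fields as in the diff, and the tokenizer here is loaded from the checkpoint name rather than the model object.

import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
)


def load_model_sketch(base_model, model_type, tokenizer_type, cfg):
    # Model: prefer the explicit Llama class for llama checkpoints, otherwise
    # look up the class named in the config on the transformers module.
    try:
        if "llama" in base_model:
            model_cls = LlamaForCausalLM
        else:
            model_cls = getattr(transformers, model_type)
        model = model_cls.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )
    except Exception:
        # Fall back to the generic Auto class if the explicit class fails.
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )

    # Tokenizer: same pattern, with AutoTokenizer as the fallback.
    try:
        if "llama" in base_model:
            tokenizer = LlamaTokenizer.from_pretrained(base_model)
        else:
            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(base_model)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    return model, tokenizer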