zetavg commited on
Commit
9c78439
·
1 Parent(s): 03b5741

try to load different type of models

Browse files
Files changed (1) hide show
  1. llama_lora/models.py +60 -16
llama_lora/models.py CHANGED
@@ -5,7 +5,10 @@ import json
5
  import re
6
 
7
  import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
 
 
9
  from peft import PeftModel
10
 
11
  from .globals import Global
@@ -27,42 +30,83 @@ def get_new_base_model(base_model_name):
27
  Global.name_of_new_base_model_that_is_ready_to_be_used = None
28
  clear_cache()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  device = get_device()
31
 
32
  if device == "cuda":
33
- model = AutoModelForCausalLM.from_pretrained(
34
- base_model_name,
35
  load_in_8bit=Global.load_8bit,
36
  torch_dtype=torch.float16,
37
  # device_map="auto",
38
  # ? https://github.com/tloen/alpaca-lora/issues/21
39
  device_map={'': 0},
 
 
40
  trust_remote_code=Global.trust_remote_code
41
  )
42
  elif device == "mps":
43
- model = AutoModelForCausalLM.from_pretrained(
44
- base_model_name,
45
  device_map={"": device},
46
  torch_dtype=torch.float16,
 
 
47
  trust_remote_code=Global.trust_remote_code
48
  )
49
  else:
50
- model = AutoModelForCausalLM.from_pretrained(
51
- base_model_name,
52
  device_map={"": device},
53
  low_cpu_mem_usage=True,
 
 
54
  trust_remote_code=Global.trust_remote_code
55
  )
56
 
57
- tokenizer = get_tokenizer(base_model_name)
58
-
59
- if re.match("[^/]+/llama", base_model_name):
60
- model.config.pad_token_id = tokenizer.pad_token_id = 0
61
- model.config.bos_token_id = tokenizer.bos_token_id = 1
62
- model.config.eos_token_id = tokenizer.eos_token_id = 2
63
-
64
- return model
65
-
66
 
67
  def get_tokenizer(base_model_name):
68
  if Global.ui_dev_mode:
 
5
  import re
6
 
7
  import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM, AutoModel,
10
+ AutoTokenizer, LlamaTokenizer
11
+ )
12
  from peft import PeftModel
13
 
14
  from .globals import Global
 
30
  Global.name_of_new_base_model_that_is_ready_to_be_used = None
31
  clear_cache()
32
 
33
+ model_class = AutoModelForCausalLM
34
+ from_tf = False
35
+ force_download = False
36
+ has_tried_force_download = False
37
+ while True:
38
+ try:
39
+ model = _get_model_from_pretrained(
40
+ model_class, base_model_name, from_tf=from_tf, force_download=force_download)
41
+ break
42
+ except Exception as e:
43
+ if 'from_tf' in str(e):
44
+ print(
45
+ f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
46
+ print("Retrying with from_tf=True...")
47
+ from_tf = True
48
+ force_download = False
49
+ elif model_class == AutoModelForCausalLM:
50
+ print(
51
+ f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
52
+ print("Retrying with AutoModel...")
53
+ model_class = AutoModel
54
+ force_download = False
55
+ else:
56
+ if has_tried_force_download:
57
+ raise e
58
+ print(
59
+ f"Got error while loading model {base_model_name}: {e}.")
60
+ print("Retrying with force_download=True...")
61
+ model_class = AutoModelForCausalLM
62
+ from_tf = False
63
+ force_download = True
64
+ has_tried_force_download = True
65
+
66
+ tokenizer = get_tokenizer(base_model_name)
67
+
68
+ if re.match("[^/]+/llama", base_model_name):
69
+ model.config.pad_token_id = tokenizer.pad_token_id = 0
70
+ model.config.bos_token_id = tokenizer.bos_token_id = 1
71
+ model.config.eos_token_id = tokenizer.eos_token_id = 2
72
+
73
+ return model
74
+
75
+
76
+ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
77
  device = get_device()
78
 
79
  if device == "cuda":
80
+ return model_class.from_pretrained(
81
+ model_name,
82
  load_in_8bit=Global.load_8bit,
83
  torch_dtype=torch.float16,
84
  # device_map="auto",
85
  # ? https://github.com/tloen/alpaca-lora/issues/21
86
  device_map={'': 0},
87
+ from_tf=from_tf,
88
+ force_download=force_download,
89
  trust_remote_code=Global.trust_remote_code
90
  )
91
  elif device == "mps":
92
+ return model_class.from_pretrained(
93
+ model_name,
94
  device_map={"": device},
95
  torch_dtype=torch.float16,
96
+ from_tf=from_tf,
97
+ force_download=force_download,
98
  trust_remote_code=Global.trust_remote_code
99
  )
100
  else:
101
+ return model_class.from_pretrained(
102
+ model_name,
103
  device_map={"": device},
104
  low_cpu_mem_usage=True,
105
+ from_tf=from_tf,
106
+ force_download=force_download,
107
  trust_remote_code=Global.trust_remote_code
108
  )
109
 
 
 
 
 
 
 
 
 
 
110
 
111
  def get_tokenizer(base_model_name):
112
  if Global.ui_dev_mode: