zetavg committed
Commit 03b5741
Parent: 37f2c31

support --trust_remote_code, resolves #6

Files changed (3):
  1. app.py +3 -0
  2. llama_lora/globals.py +2 -0
  3. llama_lora/models.py +24 -8
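
For reference, this is roughly how the new flag is exposed on the command line. A minimal sketch, assuming the main() entry point is wired through python-fire (an assumption based on its keyword-argument signature; the project may use a different CLI layer):

    import fire

    def main(
        base_model: str = "",
        trust_remote_code: bool = False,  # the keyword added by this commit
        # ... remaining parameters elided ...
    ):
        print(f"trust_remote_code={trust_remote_code}")

    if __name__ == "__main__":
        # fire maps keyword arguments to CLI flags, so passing
        # --trust_remote_code on the command line sets the bool to True.
        fire.Fire(main)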
app.py CHANGED
@@ -15,6 +15,7 @@ def main(
     base_model: str = "",
     data_dir: str = "",
     base_model_choices: str = "",
+    trust_remote_code: bool = False,
     # Allows to listen on all interfaces by providing '0.0.0.0'.
     server_name: str = "127.0.0.1",
     share: bool = False,
@@ -60,6 +61,8 @@ def main(
     if base_model not in Global.base_model_choices:
         Global.base_model_choices = [base_model] + Global.base_model_choices
 
+    Global.trust_remote_code = trust_remote_code
+
     Global.data_dir = os.path.abspath(data_dir)
     Global.load_8bit = load_8bit
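The net effect of the app.py change: the CLI value is copied onto Global before any model load happens. A minimal sketch of that flow using the module paths from this commit; the model name is a placeholder and calling get_new_base_model directly is for illustration only:

    from llama_lora.globals import Global
    from llama_lora.models import get_new_base_model

    # What --trust_remote_code now does, in effect: set the shared flag first,
    # so every subsequent load inside llama_lora.models picks it up.
    Global.trust_remote_code = True
    model = get_new_base_model("some-org/some-model")  # placeholder name
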
llama_lora/globals.py CHANGED
@@ -20,6 +20,8 @@ class Global:
     base_model_name: str = ""
     base_model_choices: List[str] = []
 
+    trust_remote_code = False
+
     # Functions
     train_fn: Any = train
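The default lives as a class attribute, so callers read it without instantiating Global. A self-contained illustration of the pattern (Config below is a hypothetical stand-in for Global):

    class Config:
        # Class attribute read at call time, like Global.trust_remote_code
        trust_remote_code = False

    def load_model_stub():
        return f"trust_remote_code={Config.trust_remote_code}"

    print(load_model_stub())         # trust_remote_code=False
    Config.trust_remote_code = True  # flipped once at startup, as app.py does
    print(load_model_stub())         # trust_remote_code=True
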
llama_lora/models.py CHANGED
@@ -37,16 +37,21 @@ def get_new_base_model(base_model_name):
             # device_map="auto",
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
+            trust_remote_code=Global.trust_remote_code
         )
     elif device == "mps":
         model = AutoModelForCausalLM.from_pretrained(
             base_model_name,
             device_map={"": device},
             torch_dtype=torch.float16,
+            trust_remote_code=Global.trust_remote_code
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
-            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
+            base_model_name,
+            device_map={"": device},
+            low_cpu_mem_usage=True,
+            trust_remote_code=Global.trust_remote_code
         )
 
     tokenizer = get_tokenizer(base_model_name)
@@ -68,10 +73,16 @@ def get_tokenizer(base_model_name):
         return loaded_tokenizer
 
     try:
-        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_name,
+            trust_remote_code=Global.trust_remote_code
+        )
     except Exception as e:
         if 'LLaMATokenizer' in str(e):
-            tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
+            tokenizer = LlamaTokenizer.from_pretrained(
+                base_model_name,
+                trust_remote_code=Global.trust_remote_code
+            )
         else:
             raise e
 
@@ -100,13 +111,15 @@ def get_model(
     peft_model_name_or_path = peft_model_name
 
     if peft_model_name:
-        lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
+        lora_models_directory_path = os.path.join(
+            Global.data_dir, "lora_models")
         possible_lora_model_path = os.path.join(
             lora_models_directory_path, peft_model_name)
         if os.path.isdir(possible_lora_model_path):
             peft_model_name_or_path = possible_lora_model_path
 
-            possible_model_info_json_path = os.path.join(possible_lora_model_path, "info.json")
+            possible_model_info_json_path = os.path.join(
+                possible_lora_model_path, "info.json")
             if os.path.isfile(possible_model_info_json_path):
                 try:
                     with open(possible_model_info_json_path, "r") as file:
@@ -115,7 +128,8 @@ def get_model(
                         if possible_hf_model_name and json_data.get("load_from_hf"):
                             peft_model_name_or_path = possible_hf_model_name
                 except Exception as e:
-                    raise ValueError("Error reading model info from {possible_model_info_json_path}: {e}")
+                    raise ValueError(
+                        "Error reading model info from {possible_model_info_json_path}: {e}")
 
     Global.loaded_models.prepare_to_set()
     clear_cache()
@@ -148,7 +162,8 @@ def get_model(
     )
 
     if re.match("[^/]+/llama", base_model_name):
-        model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
+        model.config.pad_token_id = get_tokenizer(
+            base_model_name).pad_token_id = 0
         model.config.bos_token_id = 1
         model.config.eos_token_id = 2
 
@@ -166,7 +181,8 @@ def get_model(
 
 
 def prepare_base_model(base_model_name=Global.default_base_model_name):
-    Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(base_model_name)
+    Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(
+        base_model_name)
     Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
 
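For context on what the forwarded flag means: in Hugging Face transformers, trust_remote_code=True lets from_pretrained download and execute Python modeling or tokenizer code shipped inside the model repository. Some architectures require it, but it runs arbitrary code, which is why the commit defaults it to False. A minimal sketch; the repo id is a placeholder:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # With trust_remote_code=True, custom code in the repo (e.g. a
    # modeling_*.py file) may be fetched and executed during loading.
    tokenizer = AutoTokenizer.from_pretrained(
        "some-org/custom-model",  # placeholder repo id
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "some-org/custom-model",  # placeholder repo id
        trust_remote_code=True,
    )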
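One small issue visible in the reflowed get_model lines (the commit only rewraps them, so the bug predates it): the ValueError message uses {possible_model_info_json_path} and {e} placeholders but lacks the f prefix, so the braces are raised literally. The corrected call would be:

    raise ValueError(
        f"Error reading model info from {possible_model_info_json_path}: {e}")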