zetavg committed
Commit 98dba8d
Parent: f3676ed

support hf_access_token
README.md CHANGED
@@ -78,11 +78,12 @@ setup: |
   echo "Pre-downloading base models so that you won't have to wait for long once the app is ready..."
   python llm_tuner/download_base_model.py --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'
 
-# Start the app. `wandb_api_key` and `wandb_project_name` are optional.
+# Start the app. `hf_access_token`, `wandb_api_key` and `wandb_project_name` are optional.
 run: |
   conda activate llm-tuner
   python llm_tuner/app.py \
     --data_dir='/data' \
+    --hf_access_token="$([ -f /data/secrets/hf_access_token.txt ] && cat /data/secrets/hf_access_token.txt | tr -d '\n')" \
     --wandb_api_key="$([ -f /data/secrets/wandb_api_key.txt ] && cat /data/secrets/wandb_api_key.txt | tr -d '\n')" \
     --wandb_project_name='llm-tuner' \
     --timezone='Atlantic/Reykjavik' \
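The `$([ -f FILE ] && cat FILE | tr -d '\n')` substitution expands to the secret file's contents with newlines stripped, or to an empty string when the file is missing, so both flags degrade gracefully on machines without secrets. A minimal Python sketch of the same lookup (the `read_secret` helper is illustrative, not part of the repo):

```python
from pathlib import Path

def read_secret(path: str) -> str:
    """Return the secret file's contents with newlines stripped,
    or an empty string if the file does not exist."""
    p = Path(path)
    return p.read_text().replace('\n', '') if p.is_file() else ''

hf_access_token = read_secret('/data/secrets/hf_access_token.txt')
```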
app.py CHANGED
@@ -29,6 +29,7 @@ def main(
     ui_dev_mode: Union[bool, None] = None,
     wandb_api_key: Union[str, None] = None,
     wandb_project: Union[str, None] = None,
+    hf_access_token: Union[str, None] = None,
     timezone: Union[str, None] = None,
     config: Union[str, None] = None,
 ):
@@ -45,6 +46,8 @@ def main(
 
     :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
     :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
+
+    :param hf_access_token: Provide an access token to load private models from Hugging Face Hub. An access token can be created at https://huggingface.co/settings/tokens.
     '''
 
     config_from_file = read_yaml_config(config_path=config)
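The diff shows only the new parameter and its docstring; the wiring is not visible here, but since `llama_lora/models.py` (below) reads `Config.hf_access_token`, `main` presumably copies the CLI value onto the `Config` singleton. A hedged sketch of that assumed wiring (the helper name is hypothetical):

```python
from llama_lora.config import Config

def apply_hf_access_token(hf_access_token):
    # Hypothetical wiring: store the CLI value on the Config singleton,
    # which the model-loading code in llama_lora/models.py reads.
    if hf_access_token is not None:
        Config.hf_access_token = hf_access_token
```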
llama_lora/config.py CHANGED
@@ -8,19 +8,25 @@ class Config:
     Stores the application configuration. This is a singleton class.
     """
 
+    # Where data is stored
     data_dir: str = ""
-    load_8bit: bool = False
 
+    # Model Related
     default_base_model_name: str = ""
     base_model_choices: Union[List[str], str] = []
-
+    load_8bit: bool = False
     trust_remote_code: bool = False
 
+    # Application Settings
     timezone: Any = pytz.UTC
 
+    # Authentication
     auth_username: Union[str, None] = None
     auth_password: Union[str, None] = None
 
+    # Hugging Face
+    hf_access_token: Union[str, None] = None
+
     # WandB
     enable_wandb: Union[bool, None] = None
     wandb_api_key: Union[str, None] = None
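Because `Config` keeps settings as class attributes on a singleton, consumers can read the token without a config object being passed around; this is exactly how `models.py` uses it below. A minimal sketch of the pattern:

```python
from typing import Union

class Config:
    """Singleton-style config: settings live on the class itself."""
    hf_access_token: Union[str, None] = None

# Set once at startup, read anywhere, no instance needed:
Config.hf_access_token = "hf_xxx"  # placeholder token
token = Config.hf_access_token
```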
llama_lora/lib/finetune.py CHANGED
@@ -71,6 +71,7 @@ def train(
     wandb_watch: str = "false",  # options: false | gradients | all
     wandb_log_model: str = "true",  # options: false | true
     additional_wandb_config: Union[dict, None] = None,
+    hf_access_token: Union[str, None] = None,
     status_message_callback: Any = None,
     params_info_callback: Any = None,
 ):
@@ -88,9 +89,11 @@ def train(
         additional_training_arguments = None
     if isinstance(additional_training_arguments, str):
         try:
-            additional_training_arguments = json.loads(additional_training_arguments)
+            additional_training_arguments = json.loads(
+                additional_training_arguments)
         except Exception as e:
-            raise ValueError(f"Could not parse additional_training_arguments: {e}")
+            raise ValueError(
+                f"Could not parse additional_training_arguments: {e}")
 
     if isinstance(additional_lora_config, str):
         additional_lora_config = additional_lora_config.strip()
@@ -183,11 +186,13 @@ def train(
 
     if status_message_callback:
         if isinstance(base_model, str):
-            cb_result = status_message_callback(f"Preparing model '{base_model}' for training...")
+            cb_result = status_message_callback(
+                f"Preparing model '{base_model}' for training...")
             if cb_result:
                 return
         else:
-            cb_result = status_message_callback("Preparing model for training...")
+            cb_result = status_message_callback(
+                "Preparing model for training...")
             if cb_result:
                 return
 
@@ -201,6 +206,7 @@ def train(
             torch_dtype=torch.float16,
             llm_int8_skip_modules=lora_modules_to_save,
             device_map=device_map,
+            use_auth_token=hf_access_token
         )
         if re.match("[^/]+/llama", model_name):
             print(f"Setting special tokens for LLaMA model {model_name}...")
@@ -213,11 +219,14 @@ def train(
     if isinstance(tokenizer, str):
         tokenizer_name = tokenizer
         try:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer, use_auth_token=hf_access_token
+            )
         except Exception as e:
             if 'LLaMATokenizer' in str(e):
                 tokenizer = LlamaTokenizer.from_pretrained(
                     tokenizer_name,
+                    use_auth_token=hf_access_token
                 )
             else:
                 raise e
@@ -243,7 +252,8 @@ def train(
                 f"Got error while running prepare_model_for_int8_training(model), maybe the model has already be prepared. Original error: {e}.")
 
     if status_message_callback:
-        cb_result = status_message_callback("Preparing PEFT model for training...")
+        cb_result = status_message_callback(
+            "Preparing PEFT model for training...")
         if cb_result:
             return
 
@@ -299,7 +309,8 @@ def train(
         wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
                                        "trainable%": 100 * trainable_params / all_params}})
     if params_info_callback:
-        cb_result = params_info_callback(all_params=all_params, trainable_params=trainable_params)
+        cb_result = params_info_callback(
+            all_params=all_params, trainable_params=trainable_params)
         if cb_result:
             return
 
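`train` threads the token into each `from_pretrained` call via `use_auth_token`. Since `None` is that keyword's default in Transformers, leaving `hf_access_token` unset keeps the old behaviour for public models; newer Transformers releases accept `token` in place of the since-deprecated `use_auth_token`. A small sketch of the call pattern (the model name is just an example):

```python
from transformers import AutoTokenizer

hf_access_token = None  # or "hf_xxx" for a private repository

# use_auth_token=None matches the keyword's default, so public models
# load exactly as before when no token is configured.
tokenizer = AutoTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf",
    use_auth_token=hf_access_token,
)
```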
llama_lora/models.py CHANGED
@@ -47,7 +47,11 @@ def get_new_base_model(base_model_name):
     while True:
         try:
             model = _get_model_from_pretrained(
-                model_class, base_model_name, from_tf=from_tf, force_download=force_download)
+                model_class,
+                base_model_name,
+                from_tf=from_tf,
+                force_download=force_download
+            )
             break
         except Exception as e:
             if 'from_tf' in str(e):
@@ -83,7 +87,9 @@ def get_new_base_model(base_model_name):
     return model
 
 
-def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
+def _get_model_from_pretrained(
+        model_class, model_name,
+        from_tf=False, force_download=False):
     torch = get_torch()
     device = get_device()
 
@@ -97,7 +103,8 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
             device_map={'': 0},
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
     elif device == "mps":
         return model_class.from_pretrained(
@@ -106,7 +113,8 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
             torch_dtype=torch.float16,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
     else:
         return model_class.from_pretrained(
@@ -115,7 +123,8 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
             low_cpu_mem_usage=True,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
 
 
@@ -133,13 +142,15 @@ def get_tokenizer(base_model_name):
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             base_model_name,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
     except Exception as e:
         if 'LLaMATokenizer' in str(e):
             tokenizer = LlamaTokenizer.from_pretrained(
                 base_model_name,
-                trust_remote_code=Config.trust_remote_code
+                trust_remote_code=Config.trust_remote_code,
+                use_auth_token=Config.hf_access_token
             )
         else:
             raise e
@@ -210,6 +221,7 @@ def get_model(
             torch_dtype=torch.float16,
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
+            use_auth_token=Config.hf_access_token
         )
     elif device == "mps":
         model = PeftModel.from_pretrained(
@@ -217,12 +229,14 @@ def get_model(
             peft_model_name_or_path,
             device_map={"": device},
             torch_dtype=torch.float16,
+            use_auth_token=Config.hf_access_token
         )
     else:
         model = PeftModel.from_pretrained(
             model,
             peft_model_name_or_path,
             device_map={"": device},
+            use_auth_token=Config.hf_access_token
         )
 
     if re.match("[^/]+/llama", base_model_name):