Napuh commited on
Commit
85b0be2
·
unverified ·
1 Parent(s): 8fe0e63

Warn users to login to HuggingFace (#645)

Browse files

* added warning if user is not logged in HF

* updated doc to suggest logging in to HF

README.md CHANGED
@@ -124,6 +124,11 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
124
  pip3 install packaging
125
  pip3 install -e '.[flash-attn,deepspeed]'
126
  ```
 
 
 
 
 
127
 
128
  - LambdaLabs
129
  <details>
 
124
  pip3 install packaging
125
  pip3 install -e '.[flash-attn,deepspeed]'
126
  ```
127
+ 4. (Optional) Login to Huggingface to use gated models/datasets.
128
+ ```bash
129
+ huggingface-cli login
130
+ ```
131
+ Get the token at huggingface.co/settings/tokens
132
 
133
  - LambdaLabs
134
  <details>
scripts/finetune.py CHANGED
@@ -7,6 +7,7 @@ import transformers
7
 
8
  from axolotl.cli import (
9
  check_accelerate_default_config,
 
10
  do_inference,
11
  do_merge_lora,
12
  load_cfg,
@@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
31
  )
32
  parsed_cfg = load_cfg(config, **kwargs)
33
  check_accelerate_default_config()
 
34
  parser = transformers.HfArgumentParser((TrainerCliArgs))
35
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
36
  return_remaining_strings=True
 
7
 
8
  from axolotl.cli import (
9
  check_accelerate_default_config,
10
+ check_user_token,
11
  do_inference,
12
  do_merge_lora,
13
  load_cfg,
 
32
  )
33
  parsed_cfg = load_cfg(config, **kwargs)
34
  check_accelerate_default_config()
35
+ check_user_token()
36
  parser = transformers.HfArgumentParser((TrainerCliArgs))
37
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
38
  return_remaining_strings=True
src/axolotl/cli/__init__.py CHANGED
@@ -14,6 +14,8 @@ import yaml
14
  # add src to the pythonpath so we don't need to pip install this
15
  from accelerate.commands.config import config_args
16
  from art import text2art
 
 
17
  from transformers import GenerationConfig, TextStreamer
18
 
19
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
@@ -247,3 +249,16 @@ def check_accelerate_default_config():
247
  LOG.warning(
248
  f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
249
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # add src to the pythonpath so we don't need to pip install this
15
  from accelerate.commands.config import config_args
16
  from art import text2art
17
+ from huggingface_hub import HfApi
18
+ from huggingface_hub.utils import LocalTokenNotFoundError
19
  from transformers import GenerationConfig, TextStreamer
20
 
21
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 
249
  LOG.warning(
250
  f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
251
  )
252
+
253
+
254
+ def check_user_token():
255
+ # Verify if token is valid
256
+ api = HfApi()
257
+ try:
258
+ user_info = api.whoami()
259
+ return bool(user_info)
260
+ except LocalTokenNotFoundError:
261
+ LOG.warning(
262
+ "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
263
+ )
264
+ return False
src/axolotl/cli/train.py CHANGED
@@ -8,6 +8,7 @@ import transformers
8
 
9
  from axolotl.cli import (
10
  check_accelerate_default_config,
 
11
  load_cfg,
12
  load_datasets,
13
  print_axolotl_text_art,
@@ -21,6 +22,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
21
  print_axolotl_text_art()
22
  parsed_cfg = load_cfg(config, **kwargs)
23
  check_accelerate_default_config()
 
24
  parser = transformers.HfArgumentParser((TrainerCliArgs))
25
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
26
  return_remaining_strings=True
 
8
 
9
  from axolotl.cli import (
10
  check_accelerate_default_config,
11
+ check_user_token,
12
  load_cfg,
13
  load_datasets,
14
  print_axolotl_text_art,
 
22
  print_axolotl_text_art()
23
  parsed_cfg = load_cfg(config, **kwargs)
24
  check_accelerate_default_config()
25
+ check_user_token()
26
  parser = transformers.HfArgumentParser((TrainerCliArgs))
27
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
28
  return_remaining_strings=True