Napuh
commited on
Warn users to login to HuggingFace (#645)
Browse files* added warning if user is not logged in HF
* updated doc to suggest logging in to HF
- README.md +5 -0
- scripts/finetune.py +2 -0
- src/axolotl/cli/__init__.py +15 -0
- src/axolotl/cli/train.py +2 -0
README.md
CHANGED
@@ -124,6 +124,11 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|
124 |
pip3 install packaging
|
125 |
pip3 install -e '.[flash-attn,deepspeed]'
|
126 |
```
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
- LambdaLabs
|
129 |
<details>
|
|
|
124 |
pip3 install packaging
|
125 |
pip3 install -e '.[flash-attn,deepspeed]'
|
126 |
```
|
127 |
+
4. (Optional) Login to Huggingface to use gated models/datasets.
|
128 |
+
```bash
|
129 |
+
huggingface-cli login
|
130 |
+
```
|
131 |
+
Get the token at huggingface.co/settings/tokens
|
132 |
|
133 |
- LambdaLabs
|
134 |
<details>
|
scripts/finetune.py
CHANGED
@@ -7,6 +7,7 @@ import transformers
|
|
7 |
|
8 |
from axolotl.cli import (
|
9 |
check_accelerate_default_config,
|
|
|
10 |
do_inference,
|
11 |
do_merge_lora,
|
12 |
load_cfg,
|
@@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
31 |
)
|
32 |
parsed_cfg = load_cfg(config, **kwargs)
|
33 |
check_accelerate_default_config()
|
|
|
34 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
35 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
36 |
return_remaining_strings=True
|
|
|
7 |
|
8 |
from axolotl.cli import (
|
9 |
check_accelerate_default_config,
|
10 |
+
check_user_token,
|
11 |
do_inference,
|
12 |
do_merge_lora,
|
13 |
load_cfg,
|
|
|
32 |
)
|
33 |
parsed_cfg = load_cfg(config, **kwargs)
|
34 |
check_accelerate_default_config()
|
35 |
+
check_user_token()
|
36 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
37 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
38 |
return_remaining_strings=True
|
src/axolotl/cli/__init__.py
CHANGED
@@ -14,6 +14,8 @@ import yaml
|
|
14 |
# add src to the pythonpath so we don't need to pip install this
|
15 |
from accelerate.commands.config import config_args
|
16 |
from art import text2art
|
|
|
|
|
17 |
from transformers import GenerationConfig, TextStreamer
|
18 |
|
19 |
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
@@ -247,3 +249,16 @@ def check_accelerate_default_config():
|
|
247 |
LOG.warning(
|
248 |
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
249 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# add src to the pythonpath so we don't need to pip install this
|
15 |
from accelerate.commands.config import config_args
|
16 |
from art import text2art
|
17 |
+
from huggingface_hub import HfApi
|
18 |
+
from huggingface_hub.utils import LocalTokenNotFoundError
|
19 |
from transformers import GenerationConfig, TextStreamer
|
20 |
|
21 |
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|
|
249 |
LOG.warning(
|
250 |
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
251 |
)
|
252 |
+
|
253 |
+
|
254 |
+
def check_user_token():
|
255 |
+
# Verify if token is valid
|
256 |
+
api = HfApi()
|
257 |
+
try:
|
258 |
+
user_info = api.whoami()
|
259 |
+
return bool(user_info)
|
260 |
+
except LocalTokenNotFoundError:
|
261 |
+
LOG.warning(
|
262 |
+
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
263 |
+
)
|
264 |
+
return False
|
src/axolotl/cli/train.py
CHANGED
@@ -8,6 +8,7 @@ import transformers
|
|
8 |
|
9 |
from axolotl.cli import (
|
10 |
check_accelerate_default_config,
|
|
|
11 |
load_cfg,
|
12 |
load_datasets,
|
13 |
print_axolotl_text_art,
|
@@ -21,6 +22,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
21 |
print_axolotl_text_art()
|
22 |
parsed_cfg = load_cfg(config, **kwargs)
|
23 |
check_accelerate_default_config()
|
|
|
24 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
25 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
26 |
return_remaining_strings=True
|
|
|
8 |
|
9 |
from axolotl.cli import (
|
10 |
check_accelerate_default_config,
|
11 |
+
check_user_token,
|
12 |
load_cfg,
|
13 |
load_datasets,
|
14 |
print_axolotl_text_art,
|
|
|
22 |
print_axolotl_text_art()
|
23 |
parsed_cfg = load_cfg(config, **kwargs)
|
24 |
check_accelerate_default_config()
|
25 |
+
check_user_token()
|
26 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
27 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
28 |
return_remaining_strings=True
|