hamel commited on
Commit
9bca7db
1 Parent(s): 91cf4ee

add support for https remote yamls (#1277)

Browse files
.mypy.ini CHANGED
@@ -32,6 +32,9 @@ ignore_missing_imports = True
32
  [mypy-bitsandbytes]
33
  ignore_missing_imports = True
34
 
 
 
 
35
  [mypy-datasets]
36
  ignore_missing_imports = True
37
 
 
32
  [mypy-bitsandbytes]
33
  ignore_missing_imports = True
34
 
35
+ [mypy-requests]
36
+ ignore_missing_imports = True
37
+
38
  [mypy-datasets]
39
  ignore_missing_imports = True
40
 
README.md CHANGED
@@ -121,6 +121,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
121
  # gradio
122
  accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
123
  --lora_model_dir="./lora-out" --gradio
 
 
 
 
124
  ```
125
 
126
  ## Installation
@@ -988,6 +992,9 @@ Run
988
  accelerate launch -m axolotl.cli.train your_config.yml
989
  ```
990
 
 
 
 
991
  #### Preprocess dataset
992
 
993
  You can optionally pre-tokenize dataset with the following before finetuning.
 
121
  # gradio
122
  accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
123
  --lora_model_dir="./lora-out" --gradio
124
+
125
+ # remote yaml files - the yaml config can be hosted on a public URL
126
+ # Note: the yaml config must directly link to the **raw** yaml
127
+ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
128
  ```
129
 
130
  ## Installation
 
992
  accelerate launch -m axolotl.cli.train your_config.yml
993
  ```
994
 
995
+ > [!TIP]
996
+ > You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
997
+
998
  #### Preprocess dataset
999
 
1000
  You can optionally pre-tokenize dataset with the following before finetuning.
requirements-dev.txt CHANGED
@@ -1,3 +1,4 @@
1
  pre-commit
2
  black
3
  mypy
 
 
1
  pre-commit
2
  black
3
  mypy
4
+ types-requests
requirements.txt CHANGED
@@ -9,6 +9,7 @@ deepspeed>=0.13.1
9
  addict
10
  fire
11
  PyYAML>=6.0
 
12
  datasets>=2.15.0
13
  flash-attn==2.3.3
14
  sentencepiece
 
9
  addict
10
  fire
11
  PyYAML>=6.0
12
+ requests
13
  datasets>=2.15.0
14
  flash-attn==2.3.3
15
  sentencepiece
src/axolotl/cli/__init__.py CHANGED
@@ -1,16 +1,20 @@
1
  """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
2
 
3
  import importlib
 
4
  import logging
5
  import math
6
  import os
7
  import random
8
  import sys
 
9
  from pathlib import Path
10
  from threading import Thread
11
  from typing import Any, Dict, List, Optional, Union
 
12
 
13
  import gradio as gr
 
14
  import torch
15
  import yaml
16
 
@@ -59,6 +63,52 @@ def print_axolotl_text_art(suffix=None):
59
  print(ascii_art)
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def get_multi_line_input() -> Optional[str]:
63
  print("Give me an instruction (Ctrl + D to submit): ")
64
  instruction = ""
@@ -270,9 +320,10 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
270
  return not any(el in list2 for el in list1)
271
 
272
 
273
- def load_cfg(config: Path = Path("examples/"), **kwargs):
 
274
  if Path(config).is_dir():
275
- config = choose_config(config)
276
 
277
  # load the config from the yaml file
278
  with open(config, encoding="utf-8") as file:
 
1
  """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
2
 
3
  import importlib
4
+ import json
5
  import logging
6
  import math
7
  import os
8
  import random
9
  import sys
10
+ import tempfile
11
  from pathlib import Path
12
  from threading import Thread
13
  from typing import Any, Dict, List, Optional, Union
14
+ from urllib.parse import urlparse
15
 
16
  import gradio as gr
17
+ import requests
18
  import torch
19
  import yaml
20
 
 
63
  print(ascii_art)
64
 
65
 
66
+ def check_remote_config(config: Union[str, Path]):
67
+ # Check if the config is a valid HTTPS URL to a .yml or .yaml file
68
+ if not (isinstance(config, str) and config.startswith("https://")):
69
+ return config # Return the original value if it's not a valid URL
70
+
71
+ filename = os.path.basename(urlparse(config).path)
72
+ temp_dir = tempfile.mkdtemp()
73
+
74
+ try:
75
+ response = requests.get(config, timeout=30)
76
+ response.raise_for_status() # Check for HTTP errors
77
+
78
+ content = response.content
79
+ try:
80
+ # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
81
+ json.loads(content)
82
+ # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
83
+ LOG.warning(
84
+ f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
85
+ )
86
+ except json.JSONDecodeError:
87
+ # If it's not valid JSON, verify it's valid YAML
88
+ try:
89
+ yaml.safe_load(content)
90
+ except yaml.YAMLError as err:
91
+ raise ValueError(
92
+ f"Failed to parse the content at {config} as YAML: {err}"
93
+ ) from err
94
+
95
+ # Write the content to a file if it's valid YAML (or JSON treated as YAML)
96
+ output_path = Path(temp_dir) / filename
97
+ with open(output_path, "wb") as file:
98
+ file.write(content)
99
+ LOG.info(
100
+ f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
101
+ )
102
+ return output_path
103
+
104
+ except requests.RequestException as err:
105
+ # This catches all requests-related exceptions including HTTPError
106
+ raise RuntimeError(f"Failed to download {config}: {err}") from err
107
+ except Exception as err:
108
+ # Catch-all for any other exceptions
109
+ raise err
110
+
111
+
112
  def get_multi_line_input() -> Optional[str]:
113
  print("Give me an instruction (Ctrl + D to submit): ")
114
  instruction = ""
 
320
  return not any(el in list2 for el in list1)
321
 
322
 
323
+ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
324
+ config = check_remote_config(config)
325
  if Path(config).is_dir():
326
+ config = choose_config(Path(config))
327
 
328
  # load the config from the yaml file
329
  with open(config, encoding="utf-8") as file:
src/axolotl/cli/preprocess.py CHANGED
@@ -3,6 +3,7 @@ CLI to run training on a model
3
  """
4
  import logging
5
  from pathlib import Path
 
6
 
7
  import fire
8
  import transformers
@@ -23,7 +24,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
23
  LOG = logging.getLogger("axolotl.cli.preprocess")
24
 
25
 
26
- def do_cli(config: Path = Path("examples/"), **kwargs):
27
  # pylint: disable=duplicate-code
28
  print_axolotl_text_art()
29
  parsed_cfg = load_cfg(config, **kwargs)
 
3
  """
4
  import logging
5
  from pathlib import Path
6
+ from typing import Union
7
 
8
  import fire
9
  import transformers
 
24
  LOG = logging.getLogger("axolotl.cli.preprocess")
25
 
26
 
27
+ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
28
  # pylint: disable=duplicate-code
29
  print_axolotl_text_art()
30
  parsed_cfg = load_cfg(config, **kwargs)
src/axolotl/cli/shard.py CHANGED
@@ -3,6 +3,7 @@ CLI to shard a trained model into 10GiB chunks
3
  """
4
  import logging
5
  from pathlib import Path
 
6
 
7
  import fire
8
  import transformers
@@ -25,7 +26,7 @@ def shard(
25
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
26
 
27
 
28
- def do_cli(config: Path = Path("examples/"), **kwargs):
29
  # pylint: disable=duplicate-code
30
  print_axolotl_text_art()
31
  parsed_cfg = load_cfg(config, **kwargs)
 
3
  """
4
  import logging
5
  from pathlib import Path
6
+ from typing import Union
7
 
8
  import fire
9
  import transformers
 
26
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
27
 
28
 
29
+ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
30
  # pylint: disable=duplicate-code
31
  print_axolotl_text_art()
32
  parsed_cfg = load_cfg(config, **kwargs)
src/axolotl/cli/train.py CHANGED
@@ -3,7 +3,7 @@ CLI to run training on a model
3
  """
4
  import logging
5
  from pathlib import Path
6
- from typing import Tuple
7
 
8
  import fire
9
  from transformers.hf_argparser import HfArgumentParser
@@ -25,7 +25,7 @@ from axolotl.train import train
25
  LOG = logging.getLogger("axolotl.cli.train")
26
 
27
 
28
- def do_cli(config: Path = Path("examples/"), **kwargs):
29
  # pylint: disable=duplicate-code
30
  parsed_cfg = load_cfg(config, **kwargs)
31
  parser = HfArgumentParser((TrainerCliArgs))
 
3
  """
4
  import logging
5
  from pathlib import Path
6
+ from typing import Tuple, Union
7
 
8
  import fire
9
  from transformers.hf_argparser import HfArgumentParser
 
25
  LOG = logging.getLogger("axolotl.cli.train")
26
 
27
 
28
+ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
29
  # pylint: disable=duplicate-code
30
  parsed_cfg = load_cfg(config, **kwargs)
31
  parser = HfArgumentParser((TrainerCliArgs))