jupyter lab fixes (#1139) [skip ci]

* add a basic notebook for lab users in the root
* update notebook and fix cors for jupyter
* cell is code
* fix eval batch size check
* remove intro notebook
- docker/Dockerfile-cloud +1 -1
- scripts/cloud-entrypoint.sh +1 -1
- src/axolotl/cli/train.py +13 -7
- src/axolotl/core/trainer_builder.py +4 -3
- src/axolotl/utils/bench.py +2 -1
- src/axolotl/utils/models.py +6 -2
docker/Dockerfile-cloud CHANGED

@@ -12,7 +12,7 @@ EXPOSE 22

 COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh

-RUN pip install jupyterlab notebook && \
+RUN pip install jupyterlab notebook ipywidgets && \
     jupyter lab clean
 RUN apt install --yes --no-install-recommends openssh-server tmux && \
     mkdir -p ~/.ssh && \
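One note for lab users: without ipywidgets, the tqdm progress bars that transformers and datasets emit in notebooks degrade to plain text with an "IProgress not found" warning, which is presumably why it joins jupyterlab and notebook here. A quick sanity check to run in a notebook cell (illustrative only, not part of this PR):

    import importlib.util

    # ipywidgets must be importable for tqdm's notebook progress bars to render
    assert importlib.util.find_spec("ipywidgets") is not None, "pip install ipywidgets"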
scripts/cloud-entrypoint.sh CHANGED

@@ -33,7 +33,7 @@ fi

 if [ "$JUPYTER_DISABLE" != "1" ]; then
     # Run Jupyter Lab in the background
-    jupyter lab --allow-root --
+    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
 fi

 # Execute the passed arguments (CMD)
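The new flags are the CORS fix from the commit message: --ip=* binds all interfaces, --ServerApp.allow_origin=* relaxes the origin check so proxied cloud frontends can reach the server, and --ServerApp.preferred_dir=/workspace opens lab in the workspace. A rough way to confirm the relaxed origin policy from inside the container (the port matches the diff, but the unauthenticated /api endpoint and the echoed header are assumptions about jupyter-server's behavior):

    import urllib.request

    # Send a cross-origin request and inspect the echoed CORS header;
    # with allow_origin=* the server should answer Access-Control-Allow-Origin: *
    req = urllib.request.Request(
        "http://localhost:8888/api", headers={"Origin": "https://example.com"}
    )
    with urllib.request.urlopen(req) as resp:
        print(resp.headers.get("Access-Control-Allow-Origin"))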
src/axolotl/cli/train.py CHANGED

@@ -3,9 +3,11 @@ CLI to run training on a model
 """
 import logging
 from pathlib import Path
+from typing import Tuple

 import fire
 import transformers
+from transformers import PreTrainedModel, PreTrainedTokenizer

 from axolotl.cli import (
     check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    print_axolotl_text_art()
-    check_accelerate_default_config()
-    check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
+    return do_train(parsed_cfg, parsed_cli_args)
+

-    if parsed_cfg.rl:
-        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    print_axolotl_text_art()
+    check_accelerate_default_config()
+    check_user_token()
+    if cfg.rl:
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
-        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


 if __name__ == "__main__":
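Splitting do_train out of do_cli is what makes the notebook workflow work: a lab user can call training programmatically and get the trained model and tokenizer back, instead of going through the CLI argument parser. A minimal sketch of that usage (the example config path is illustrative, and the TrainerCliArgs import location is an assumption):

    from axolotl.cli import load_cfg
    from axolotl.cli.train import do_train
    from axolotl.common.cli import TrainerCliArgs

    # load_cfg parses the YAML into the same cfg object do_cli builds
    cfg = load_cfg("examples/openllama-3b/lora.yml")  # hypothetical path
    model, tokenizer = do_train(cfg, TrainerCliArgs())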
src/axolotl/core/trainer_builder.py CHANGED

@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "per_device_train_batch_size"
         ] = self.cfg.micro_batch_size
-        training_arguments_kwargs[
-            "per_device_eval_batch_size"
-        ] = self.cfg.eval_batch_size
+        if self.cfg.eval_batch_size:
+            training_arguments_kwargs[
+                "per_device_eval_batch_size"
+            ] = self.cfg.eval_batch_size
         training_arguments_kwargs[
             "gradient_accumulation_steps"
         ] = self.cfg.gradient_accumulation_steps
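The guard means per_device_eval_batch_size is only forwarded when eval_batch_size was actually set; otherwise the key is omitted entirely and transformers.TrainingArguments keeps its own default instead of receiving None. A self-contained sketch of the difference (the kwargs dict mirrors the diff; values are illustrative):

    from transformers import TrainingArguments

    training_arguments_kwargs = {"per_device_train_batch_size": 2}

    eval_batch_size = None  # e.g. the user left eval_batch_size unset in the YAML
    if eval_batch_size:
        training_arguments_kwargs["per_device_eval_batch_size"] = eval_batch_size

    # Without the guard this would pass per_device_eval_batch_size=None and break
    # the eval dataloader; with it, TrainingArguments falls back to its default (8)
    args = TrainingArguments(output_dir="./out", **training_arguments_kwargs)
    print(args.per_device_eval_batch_size)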
src/axolotl/utils/bench.py CHANGED

@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
             device = kwargs.get("device", args[0] if args else None)

             if (
-                not torch.cuda.is_available()
+                device is None
+                or not torch.cuda.is_available()
                 or device == "auto"
                 or torch.device(device).type == "cpu"
             ):
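The added device is None clause matters because torch.device(None) raises a TypeError, so a decorated call that never received a device argument would previously crash inside the guard itself. A standalone sketch of the short-circuit order (the function name is hypothetical; the real logic lives in the check_cuda_device decorator):

    import torch

    def gpu_memory_or_default(device=None, default=0.0):
        # Order matters: `device is None` must run before torch.device(device),
        # which raises TypeError on None; `== "auto"` before torch.device("auto")
        if (
            device is None
            or not torch.cuda.is_available()
            or device == "auto"
            or torch.device(device).type == "cpu"
        ):
            return default
        return torch.cuda.memory_allocated(device) / 1024**3

    print(gpu_memory_or_default())  # 0.0 instead of a TypeError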
src/axolotl/utils/models.py CHANGED

@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Any, Optional, Tuple, Union  # noqa: F401
+from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401

 import addict
 import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
     LOG.info("patching _expand_mask")
     hijack_expand_mask()

-    model_kwargs = {}
+    model_kwargs: Dict[str, Any] = {}
+
+    if cfg.model_kwargs:
+        for key, val in cfg.model_kwargs.items():
+            model_kwargs[key] = val

     max_memory = cfg.max_memory
     device_map = cfg.device_map
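The net effect is that arbitrary model_kwargs from the YAML config are merged into the dict that load_model later splats into the from_pretrained call. A trimmed illustration of the merge (the config value shown is hypothetical):

    from typing import Any, Dict

    model_kwargs: Dict[str, Any] = {}

    # Stand-in for cfg.model_kwargs as parsed from the YAML config
    cfg_model_kwargs = {"rope_scaling": {"type": "linear", "factor": 2.0}}

    if cfg_model_kwargs:
        for key, val in cfg_model_kwargs.items():
            model_kwargs[key] = val

    print(model_kwargs)  # forwarded as **model_kwargs to from_pretrained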