"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

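# put the repo's src/ directory on sys.path so the axolotl package resolves when run from a checkout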
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

configure_logging()
LOG = logging.getLogger("axolotl.train")


@dataclass
class TrainDatasetMeta:
    """
    Dataclass to capture the dataset-specific options for training.
    """

    train_dataset: Dataset
    eval_dataset: Optional[Dataset] = None
    total_num_steps: Optional[int] = None


def train(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
    dataset_meta: TrainDatasetMeta,
):
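    # load the tokenizer first; load_model needs it below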
						    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") | 
					
					
						
						| 
							 | 
						    tokenizer = load_tokenizer(cfg) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    train_dataset = dataset_meta.train_dataset | 
					
					
						
						| 
							 | 
						    eval_dataset = dataset_meta.eval_dataset | 
					
					
						
						| 
							 | 
						    total_num_steps = dataset_meta.total_num_steps | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
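    # load the base model and, when an adapter is configured, its peft config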
						    LOG.info("loading model and (optionally) peft_config...") | 
					
					
						
						| 
							 | 
						    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    safe_serialization = cfg.save_safetensors is True | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
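    # auto-resume: scan output_dir for checkpoint-* dirs and resume from the highest step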
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        possible_checkpoints = [
            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
        ]
        if len(possible_checkpoints) > 0:
            sorted_paths = sorted(
                possible_checkpoints,
                key=lambda path: int(path.split("-")[-1]),
            )
            cfg.resume_from_checkpoint = sorted_paths[-1]
            LOG.info(
                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
            )
    resume_from_checkpoint = cfg.resume_from_checkpoint

    trainer = setup_trainer(
        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
    )

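    # the KV cache is only useful for generation; disable it during training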
    model.config.use_cache = False

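    # pre-save the adapter config so it is already in output_dir while training runs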
    if peft_config:
        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)
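    # likewise pre-save the tokenizer and model config to output_dir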
    if not Path(cfg.output_dir).is_dir():
        os.makedirs(cfg.output_dir, exist_ok=True)
    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
    model.config.save_pretrained(str(Path(cfg.output_dir)))

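    # on rank 0, install a SIGINT (ctrl+c) handler that saves the model weights before exiting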
    if cfg.local_rank == 0:

        def terminate_handler(_, __, model):
            if cfg.flash_optimum:
                model = BetterTransformer.reverse(model)
            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
            sys.exit(0)

        signal.signal(
            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
        )

						    LOG.info("Starting trainer...") | 
					
					
						
						| 
							 | 
						    if cfg.group_by_length: | 
					
					
						
						| 
							 | 
						        LOG.info("hang tight... sorting dataset for group_by_length") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
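    # run training inside the SDPA kernel context so the flash / mem-efficient attention backends are enabled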
    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

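    # FSDP shards weights across ranks; switch to FULL_STATE_DICT so a complete model can be saved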
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
        LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")

    if cfg.relora_steps:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            model = model.merge_and_unload()
        else:
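            # a quantized (4-bit/8-bit) LoRA can't be merged here; return it unmerged and skip the final save below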
            return model, tokenizer

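    # under FSDP, save through the trainer; otherwise save on rank 0 only so multiple processes don't write the same files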
    if cfg.fsdp:
        trainer.save_model(cfg.output_dir)
    elif cfg.local_rank == 0:
        if cfg.flash_optimum:
            model = BetterTransformer.reverse(model)

        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

    return model, tokenizer