# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Phi3 mini (3.8B) model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --hf-token <HF_TOKEN> --ignore-patterns ""
#
# To launch on 2 devices, run the following command from root:
#   tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single-device LoRA finetuning please use mini_lora_single_device.yaml
# or mini_qlora_single_device.yaml.

# Model Arguments
model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: ./phi3/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./phi3
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  output_dir: lora-phi3-math
  model_type: PHI3_MINI
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: TIGER-Lab/MATH-plus
  template: AlpacaInstructTemplate
  train_on_input: True
  packed: False
  max_seq_len: 4096
  split: train
seed: 123
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: 2000
gradient_accumulation_steps: 16

# Logging
output_dir: lora-phi3-math
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  project: lora-phi3-math
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False
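
# Note: the command below is a sketch, assuming the standard torchtune CLI
# override syntax shown in the header; the override keys are the same
# top-level keys defined in this file. If you hit out-of-memory errors on
# 2 GPUs, a common first step is to enable activation checkpointing and
# lower the per-device batch size while raising gradient accumulation, e.g.:
#   tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora \
#     enable_activation_checkpointing=True batch_size=1 gradient_accumulation_steps=32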