AutoTrain documentation

AutoTrain Configs

You are viewing version v0.7.104. A newer version, v0.8.24, is available.

AutoTrain config files let you define and run model training with AutoTrain locally.

Once you have installed AutoTrain Advanced, you can use the following command to train models using AutoTrain config files:

$ export HF_USERNAME=your_hugging_face_username
$ export HF_TOKEN=your_hugging_face_write_token

$ autotrain --config path/to/config.yaml
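
Because the example config below reads the Hub username and token from these environment variables, it can help to confirm both are set before launching. The following is a minimal sketch, assuming a POSIX shell; `check_hf_env` is a hypothetical helper written for illustration and is not part of AutoTrain:

```shell
# check_hf_env is a hypothetical helper, not part of AutoTrain.
# It returns 0 only when both variables are set and non-empty.
check_hf_env() {
  status=0
  [ -n "${HF_USERNAME:-}" ] || { echo "HF_USERNAME is not set" >&2; status=1; }
  [ -n "${HF_TOKEN:-}" ] || { echo "HF_TOKEN is not set" >&2; status=1; }
  return "$status"
}

# Usage: launch training only if the check passes.
# check_hf_env && autotrain --config path/to/config.yaml
```

The launch line is left as a comment so the snippet can be sourced without starting a training run.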

Example configurations for all tasks can be found in the configs directory of the AutoTrain Advanced GitHub repository.

Here is an example of an AutoTrain config file:

task: llm
base_model: meta-llama/Meta-Llama-3-8B-Instruct
project_name: autotrain-llama3-8b-orpo
log: tensorboard
backend: local

data:
  path: argilla/distilabel-capybara-dpo-7k-binarized
  train_split: train
  valid_split: null
  chat_template: chatml
  column_mapping:
    text_column: chosen
    rejected_text_column: rejected

params:
  trainer: orpo
  block_size: 1024
  model_max_length: 2048
  max_prompt_length: 512
  epochs: 3
  batch_size: 2
  lr: 3e-5
  peft: true
  quantization: int4
  target_modules: all-linear
  padding: right
  optimizer: adamw_torch
  scheduler: linear
  gradient_accumulation: 4
  mixed_precision: bf16

hub:
  username: ${HF_USERNAME}
  token: ${HF_TOKEN}
  push_to_hub: true

In this config, we fine-tune the meta-llama/Meta-Llama-3-8B-Instruct model on the argilla/distilabel-capybara-dpo-7k-binarized dataset using the orpo trainer for 3 epochs, with a batch size of 2 and a learning rate of 3e-5. More information on the available parameters can be found in the Data Formats and Parameters section.
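
The config also sets gradient_accumulation to 4, which interacts with the batch size. As a rough sketch, assuming standard gradient-accumulation semantics (one optimizer step per gradient_accumulation batches) and a single device, the effective batch size per optimizer step can be worked out as follows; the variable names are illustrative:

```shell
# Values copied from the params section of the example config.
batch_size=2
gradient_accumulation=4

# Assumption: one optimizer step per `gradient_accumulation` batches, single device.
effective_batch_size=$((batch_size * gradient_accumulation))
echo "effective batch size per optimizer step: $effective_batch_size"
```

With more than one device, the figure would typically be multiplied by the device count as well.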

If you don't want to push the model to the Hub, set push_to_hub to false in the config file. In that case, username and token are not required. Note: they may still be needed if you are accessing gated models or datasets.
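
The push_to_hub setting can also be switched in a copy of the config, leaving the original untouched. The following is a minimal sketch, assuming a POSIX shell with `sed`; the temporary directory and file names are illustrative, and the stand-in file contains only the hub section of the example config:

```shell
# Work in a temporary directory so no existing config is overwritten.
demo_dir=$(mktemp -d)

# Stand-in for the hub section of the example config (quoted EOF keeps ${...} literal).
cat > "$demo_dir/config.yaml" <<'EOF'
hub:
  username: ${HF_USERNAME}
  token: ${HF_TOKEN}
  push_to_hub: true
EOF

# Write a copy with push_to_hub set to false; the original file is not modified.
sed 's/^\([[:space:]]*push_to_hub:\).*/\1 false/' "$demo_dir/config.yaml" > "$demo_dir/config-local.yaml"

cat "$demo_dir/config-local.yaml"
```

The copy keeps the username and token lines as they were. If they are not needed, for example when no gated model or dataset is involved, they can be removed from the copy as well.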
