# .KTO Example

https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py

In [2]:
from dataclasses import dataclass

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format

In [3]:
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_name: str = "trl-lib/kto-mix-14k"


# Initialize the arguments directly
script_args = ScriptArguments(
    dataset_name="trl-lib/kto-mix-14k"
)

training_args = KTOConfig(
    output_dir="kto-aligned-model",
    num_train_epochs=1,
    per_device_train_batch_size=16,
    learning_rate=5e-7,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=1,
    logging_steps=10,
    eval_steps=500,
    warmup_ratio=0.1,
    bf16=True,
    logging_first_step=True
)

model_args = ModelConfig(
    model_name_or_path="trl-lib/qwen1.5-1.8b-sft",
    # any additional model-specific arguments
)

- @dataclass makes it easier to create classes that only contain data, making your argument definitions compact, easier to read, and automatically initialized without the need to write a custom __init__ method.
- @dataclass is used here to define a structure for the arguments that you are going to pass to the training script:
- You define a simple data structure (ScriptArguments) with a list of variables (e.g., dataset_name).
- You can quickly create instances of this structure (script_args = ScriptArguments(...)) without manually writing the initializer.


In [4]:
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
ref_model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)

# load a tokenaizer
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer)

config.json:   0%|          | 0.00/702 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/117 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/80.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/419 [00:00<?, ?B/s]

In [5]:
# Load the dataset
dataset = load_dataset(script_args.dataset_name)

# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)

README.md:   0%|          | 0.00/814 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/16.3M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.81M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/13500 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1500 [00:00<?, ? examples/s]

In [6]:
# Apply chat template
def format_dataset(example):
    example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
    example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
    return example

In [7]:
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
    dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)


Map:   0%|          | 0/13500 [00:00<?, ? examples/s]

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

In [8]:
# Initialize the KTO trainer
trainer = KTOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_args),
)

# Train and push the model to the Hub
trainer.train()

# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
    trainer.push_to_hub()



Tokenizing train dataset:   0%|          | 0/13500 [00:00<?, ? examples/s]

Processing tokenized train dataset:   0%|          | 0/13500 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]

Processing tokenized eval dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]

Extracting KL train dataset:   0%|          | 0/13500 [00:00<?, ? examples/s]

Processing tokenized train KL dataset:   0%|          | 0/13500 [00:00<?, ? examples/s]

Extracting eval KL dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]

Processing tokenized eval KL dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]

  0%|          | 0/844 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 17.37 GB, other allocations: 664.64 MB, max allowed: 18.13 GB). Tried to allocate 172.34 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).