# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
import unittest

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from transformers.testing_utils import require_peft, require_wandb
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_peft_available

from tests.testing_utils import require_comet, require_mergekit
from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
from trl.mergekit_utils import MergeConfig


if is_peft_available():
    from peft import LoraConfig


class HalfPairwiseJudge(BasePairwiseJudge):
    """Naive pairwise judge for batches of two prompts: it always returns the winner indices [1, 0], and the
    fixed scores [0.3, 0.9] when `return_scores=True`."""

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        # just check that the batch size is 2
        assert len(prompts) == 2
        if return_scores:
            return [0.3, 0.9]
        return [1, 0]


class TrainerWithRefModel(Trainer):
    # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
    # ref_model attribute
    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
        )
        self.ref_model = ref_model


class WinRateCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        dataset["train"] = dataset["train"].select(range(8))
        self.expected_winrates = [
            {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
            {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
            {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
            {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
            {"eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
            {"eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
            {"eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
        ]

        def tokenize_function(examples):
            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
            out["labels"] = out["input_ids"].copy()
            return out

        self.dataset = dataset.map(tokenize_function, batched=True)
        self.generation_config = GenerationConfig(max_length=32)
        self.judge = HalfPairwiseJudge()

    def test_basic(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = TrainerWithRefModel(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

    def test_without_ref_model(self):
        # Same as before, but without the ref_model attribute. It should use the model attribute instead
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

    def test_soft_judge(self):
        """Test that the soft judge functionality works correctly"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = TrainerWithRefModel(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()

            # Expected values based on judge returning [0.3, 0.9] for each pair
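            # Assuming the scores are read as the reference completions' win probabilities, the trained model's
            # average win probability is 1 - mean([0.3, 0.9]) = 0.4, matching "eval_avg_win_prob" below; the hard
            # win rate stays at 0.5 as in the other tests.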
            expected_soft_winrates = [
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
            ]
            winrate_history = [
                {k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
                for h in trainer.state.log_history
                if "eval_avg_win_prob" in h
            ]
            self.assertListEqual(winrate_history, expected_soft_winrates)

    @require_peft
    def test_lora(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model.add_adapter(peft_config)
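            # No explicit ref_model here: when the model carries a PEFT adapter, the callback can generate the
            # reference completions with the adapter disabled (base weights only), so the plain Trainer suffices.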
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)


class LogCompletionsCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        dataset["train"] = dataset["train"].select(range(8))

        def tokenize_function(examples):
            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
            out["labels"] = out["input_ids"].copy()
            return out

        self.dataset = dataset.map(tokenize_function, batched=True)
        self.generation_config = GenerationConfig(max_length=32)

    @require_wandb
    def test_basic_wandb(self):
        import wandb

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="wandb",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
            trainer.add_callback(completions_callback)
            trainer.train()

            # Read the "completions" table logged to the current wandb run; the table is stored as a JSON file
            # inside the run directory, with its relative path recorded in the run summary
            completions_path = wandb.run.summary.completions["path"]
            json_path = os.path.join(wandb.run.dir, completions_path)
            with open(json_path) as f:
                completions = json.load(f)

            # Check that the columns are correct
            self.assertIn("step", completions["columns"])
            self.assertIn("prompt", completions["columns"])
            self.assertIn("completion", completions["columns"])

            # Check that the prompt is in the log
            self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])

    @require_comet
    def test_basic_comet(self):
        import comet_ml

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="comet_ml",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
            trainer.add_callback(completions_callback)
            trainer.train()

            # close the experiment to make sure all pending data is flushed
            experiment = comet_ml.get_running_experiment()
            assert experiment is not None
            experiment.end()

            # get the experiment assets and check that all expected completion tables were logged
            steps = len(self.dataset["train"]) + len(self.dataset["test"])
            tables_logged = int(steps / 2) + 1  # +1 to include zero step
            api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id)
            tables = api_experiment.get_asset_list("dataframe")
            assert tables is not None
            assert len(tables) == tables_logged
            assert all(table["fileName"] == "completions.csv" for table in tables)


class MergeModelCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    @require_mergekit
    def test_callback(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                num_train_epochs=1,
                report_to="none",
                save_strategy="steps",
                save_steps=1,
            )
            config = MergeConfig()
            merge_callback = MergeModelCallback(config)
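            # With the default settings the callback performs a single mergekit merge at the end of training and
            # writes the result to a "merged" subfolder of the last checkpoint; merge_at_every_checkpoint=True
            # (exercised in the next test) merges at every saved checkpoint instead.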
            trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                processing_class=self.tokenizer,
                callbacks=[merge_callback],
            )
            trainer.train()

            last_checkpoint = get_last_checkpoint(tmp_dir)
            merged_path = os.path.join(last_checkpoint, "merged")
            self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.")

    @require_mergekit
    def test_every_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                num_train_epochs=1,
                report_to="none",
                save_strategy="steps",
                save_steps=1,
            )
            config = MergeConfig()
            merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True)
            trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                processing_class=self.tokenizer,
                callbacks=[merge_callback],
            )
            trainer.train()

            checkpoints = sorted(
                [os.path.join(tmp_dir, cp) for cp in os.listdir(tmp_dir) if cp.startswith("checkpoint-")]
            )
            for checkpoint in checkpoints:
                merged_path = os.path.join(checkpoint, "merged")
                self.assertTrue(
                    os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
                )