TRL documentation

Iterative Trainer

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v0.8.6).

Iterative fine-tuning is a training method that makes it possible to perform custom actions (for example, generation and filtering) between optimization steps. TRL provides an easy-to-use API for fine-tuning models this way in just a few lines of code.
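The control flow this enables can be sketched without any library code. In the minimal sketch below, `generate_fn`, `reward_fn`, and `step_fn` are hypothetical stand-ins: with TRL, generation would typically use the model's `generate` method, and `step_fn` would correspond to a call such as `trainer.step(texts=...)`. Keeping the top-k samples by reward is only one illustrative filtering rule.

```python
from typing import Callable, List


def iterative_finetune(
    prompts: List[str],
    generate_fn: Callable[[str], str],     # hypothetical: produces a completion for a prompt
    reward_fn: Callable[[str], float],     # hypothetical: scores a full text
    step_fn: Callable[[List[str]], None],  # stand-in for e.g. trainer.step(texts=...)
    num_iterations: int = 3,
    keep_top_k: int = 2,
) -> List[List[str]]:
    """Alternate custom actions (generation, filtering) with optimization steps."""
    history = []
    for _ in range(num_iterations):
        # 1. Custom action: generate a candidate text for every prompt.
        candidates = [prompt + generate_fn(prompt) for prompt in prompts]
        # 2. Custom action: keep only the highest-scoring candidates.
        selected = sorted(candidates, key=reward_fn, reverse=True)[:keep_top_k]
        # 3. Optimization step on the filtered texts.
        step_fn(selected)
        history.append(selected)
    return history


# Toy usage with dummy functions so the loop runs without a model.
trained_on: List[List[str]] = []
history = iterative_finetune(
    prompts=["a", "bb", "ccc"],
    generate_fn=lambda p: "!" * len(p),
    reward_fn=len,
    step_fn=trained_on.append,
    num_iterations=2,
)
print(history[0])  # ['ccc!!!', 'bb!!']
```

Keeping generation and filtering outside the trainer is what lets each iteration apply arbitrary logic before the next update.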

Usage

To get started quickly, instantiate a model and a tokenizer, then pass them to the IterativeSFTTrainer.


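from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer

model_name = "gpt2"  # example checkpoint (assumption); any causal language model works

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Some tokenizers (e.g. GPT-2) define no padding token, so reuse EOS for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Pass the tokenizer by keyword: the second positional parameter is `args`
trainer = IterativeSFTTrainer(
    model,
    tokenizer=tokenizer,
)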

You can pass either a list of strings or a list of tensors to the step method.

Using a list of tensors as input:


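# step() takes a list with one tensor per example, so tokenize each text separately
examples = ["Hello world!", "Iterative fine-tuning with TRL."]
encodings = [tokenizer(text, return_tensors="pt") for text in examples]
input_ids = [enc["input_ids"].squeeze(0) for enc in encodings]
attention_mask = [enc["attention_mask"].squeeze(0) for enc in encodings]

inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask
}

trainer.step(**inputs)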

Using a list of strings as input:


texts = ["Hello world!", "Iterative fine-tuning with TRL."]

inputs = {
    "texts": texts
}

trainer.step(**inputs)

For causal language models, labels are created automatically from input_ids or from texts. When using sequence-to-sequence models, you have to provide your own labels or texts_labels.
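For causal language models the labels are essentially a copy of the input IDs; the model shifts them internally so that each position predicts the next token. When a batch is padded, padded positions are usually excluded from the loss by setting their labels to -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The pure-Python sketch below assumes right padding and masks by position; `pad_and_make_labels` is a hypothetical helper. In TRL the actual collation is done by the configured data collator (per the `data_collator` parameter, a `DataCollatorForLanguageModeling` or `DataCollatorForSeq2Seq`), which may differ in details.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # labels with this value are skipped by torch.nn.CrossEntropyLoss


def pad_and_make_labels(batch: List[List[int]], pad_token_id: int) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: right-pad token-ID sequences and derive causal-LM labels."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        # Labels copy the inputs; padded positions are ignored in the loss.
        labels.append(ids + [IGNORE_INDEX] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


out = pad_and_make_labels([[5, 6, 7], [8, 9]], pad_token_id=0)
print(out["labels"])  # [[5, 6, 7], [8, 9, -100]]
```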

IterativeTrainer

class trl.IterativeSFTTrainer

(
    model: Optional = None,
    args: Optional = None,
    tokenizer: Optional = None,
    optimizers: Tuple = (None, None),
    data_collator: Optional = None,
    eval_dataset: Union = None,
    max_length: Optional = None,
    truncation_mode: Optional = 'keep_end',
    preprocess_logits_for_metrics: Optional = None,
    compute_metrics: Optional = None,
    optimize_device_cache: Optional = False
)

Parameters

  • **model** (PreTrainedModel) — Model to be optimized, either an AutoModelForCausalLM or an AutoModelForSeq2SeqLM. Check the documentation of PreTrainedModel for more details.
  • **args** (transformers.TrainingArguments) — The arguments to use for training.
  • **tokenizer** (PreTrainedTokenizerBase) — Tokenizer to be used for encoding the data. Check the documentation of transformers.PreTrainedTokenizer and transformers.PreTrainedTokenizerFast for more details.
  • **optimizers** (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
  • **data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], optional) — Data collator to be used for training and passed along to the dataloader.
  • **eval_dataset** (datasets.Dataset) — The dataset to use for evaluation.
  • **max_length** (int, defaults to None) — The maximum length of the input.
  • **truncation_mode** (str, defaults to keep_end) — The truncation mode to use, either keep_end or keep_start.
  • **preprocess_logits_for_metrics** (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function used to preprocess the logits before computing the metrics.
  • **compute_metrics** (Callable[[EvalPrediction], Dict], optional) — The function used to compute the metrics. It must take an EvalPrediction and return a dictionary mapping metric names to values.
  • **optimize_device_cache** (bool, optional, defaults to False) — Optimize the CUDA cache for slightly more memory-efficient training.
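As an illustration of the compute_metrics contract (take an EvalPrediction, return a dictionary of metric names to values), the sketch below computes next-token accuracy. It assumes `predictions` already holds predicted token IDs (for example, after a preprocess_logits_for_metrics function that takes the argmax over the vocabulary) and that ignored label positions are marked with -100. To keep it self-contained, a `namedtuple` with the same `predictions` and `label_ids` fields stands in for `transformers.EvalPrediction`.

```python
from collections import namedtuple
from typing import Dict

# Stand-in with the same field names as transformers.EvalPrediction (assumption for self-containment).
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])


def compute_metrics(eval_pred) -> Dict[str, float]:
    """Next-token accuracy for a causal LM, skipping positions labelled -100."""
    correct, total = 0, 0
    for pred_row, label_row in zip(eval_pred.predictions, eval_pred.label_ids):
        # The prediction at position i is for the token at position i + 1.
        for pred, label in zip(pred_row[:-1], label_row[1:]):
            if label == -100:
                continue
            total += 1
            correct += int(pred == label)
    return {"next_token_accuracy": correct / total if total else 0.0}


eval_pred = EvalPrediction(
    predictions=[[6, 7, 0, 1], [9, 4, 4, 4]],
    label_ids=[[5, 6, 7, 8], [8, 9, -100, -100]],
)
print(compute_metrics(eval_pred))  # {'next_token_accuracy': 0.75}
```

The shift by one position matters because causal LM labels are passed unshifted and the model aligns them internally.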

The IterativeSFTTrainer can be used to fine-tune models with methods that require custom steps (such as generation or filtering) between optimization steps.

step

(
    input_ids: Optional = None,
    attention_mask: Optional = None,
    labels: Optional = None,
    texts: Optional = None,
    texts_labels: Optional = None
) → dict[str, Any]

Parameters

  • input_ids (List[torch.LongTensor], optional) — List of tensors containing the input_ids (if not provided, texts will be used)
  • attention_mask (List[torch.LongTensor], optional) — List of tensors containing the attention_mask
  • labels (List[torch.LongTensor], optional) — List of tensors containing the labels (if set to None, will default to input_ids)
  • texts (List[str], optional) — List of strings containing the text input (if not provided, input_ids will be used directly)
  • texts_labels (List[str], optional) — List of strings containing the text labels (if set to None, will default to texts)
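The defaults listed in these parameter descriptions can be summarized as a small decision function. This is not TRL's implementation: `resolve_step_inputs` is a hypothetical helper that only encodes the documented rules (texts or input_ids as the input source, labels defaulting to input_ids, texts_labels defaulting to texts, and sequence-to-sequence models requiring explicit labels). It uses plain lists instead of tensors, and it simply rejects the case where both input sources are given, which the descriptions leave unspecified.

```python
from typing import Optional, Tuple


def resolve_step_inputs(
    input_ids: Optional[list] = None,
    labels: Optional[list] = None,
    texts: Optional[list] = None,
    texts_labels: Optional[list] = None,
    is_encoder_decoder: bool = False,
) -> Tuple[str, list, list]:
    """Hypothetical helper: pick the input source and labels per the documented defaults."""
    if (input_ids is None) == (texts is None):
        # This sketch expects exactly one input source.
        raise ValueError("Provide exactly one of input_ids or texts.")
    if is_encoder_decoder and labels is None and texts_labels is None:
        # Sequence-to-sequence models cannot fall back to the inputs as labels.
        raise ValueError("Sequence-to-sequence models need labels or texts_labels.")
    if texts is not None:
        # texts_labels default to the texts themselves.
        return "texts", texts, texts_labels if texts_labels is not None else texts
    # labels default to the input IDs.
    return "input_ids", input_ids, labels if labels is not None else input_ids


print(resolve_step_inputs(texts=["Hello"]))  # ('texts', ['Hello'], ['Hello'])
```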

Returns

dict[str, Any]

A summary of the training statistics

Runs an optimization step given either lists of input_ids, attention_mask, and labels, or lists of texts and texts_labels.
