Models

With the AutoModelForCausalLMWithValueHead class, TRL supports all decoder model architectures in transformers, such as GPT-2, OPT, and GPT-Neo.

PreTrainedModelWrapper

class trl.PreTrainedModelWrapper

(pretrained_model=None, **kwargs)

A wrapper class around a transformers.PreTrainedModel that remains compatible with the transformers.PreTrainedModel interface, so that it keeps some of that class's attributes and methods.

from_pretrained

(pretrained_model_name_or_path, *model_args, **kwargs)

Parameters

  • pretrained_model_name_or_path (str or transformers.PreTrainedModel) — The path to the pretrained model, its name, or an already instantiated transformers.PreTrainedModel.
  • *model_args (list, optional) — Additional positional arguments passed along to the underlying model’s from_pretrained method.
  • **kwargs (dict, optional) — Additional keyword arguments passed along to the underlying model’s from_pretrained method. The kwargs are also pre-processed to separate the arguments specific to the transformers.PreTrainedModel class from those specific to TRL models.

Instantiates a new model from a pretrained transformers model. The pretrained model is loaded using the from_pretrained method of the transformers.PreTrainedModel class. The arguments specific to the transformers.PreTrainedModel class are filtered out of kwargs and passed along to that method.
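
As a rough illustration of the kwargs filtering described above, the sketch below splits a kwargs dictionary into a TRL-specific part and a part forwarded to transformers. The helper split_kwargs and the TRL_SPECIFIC_ARGS tuple are hypothetical; treating v_head_init_strategy and v_head_initializer_range as the TRL-specific keys is an assumption for illustration, and TRL's actual implementation and full argument list may differ.

```python
# Hypothetical set of TRL-specific argument names (an assumption for this sketch).
TRL_SPECIFIC_ARGS = ("v_head_init_strategy", "v_head_initializer_range")


def split_kwargs(kwargs):
    """Split kwargs into (trl_kwargs, transformers_kwargs).

    trl_kwargs holds the arguments the TRL wrapper consumes itself; everything
    else would be forwarded to the underlying transformers from_pretrained call.
    """
    trl_kwargs = {}
    transformers_kwargs = {}
    for key, value in kwargs.items():
        if key in TRL_SPECIFIC_ARGS:
            trl_kwargs[key] = value
        else:
            transformers_kwargs[key] = value
    return trl_kwargs, transformers_kwargs


trl_kwargs, hf_kwargs = split_kwargs(
    {"v_head_init_strategy": "normal", "revision": "main", "v_head_initializer_range": 0.2}
)
print(trl_kwargs)  # {'v_head_init_strategy': 'normal', 'v_head_initializer_range': 0.2}
print(hf_kwargs)   # {'revision': 'main'}
```

Splitting once up front keeps the transformers call free of arguments it would not recognize.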

post_init

(*args, **kwargs)

Post-initialization method. This method is called after the model is instantiated and loaded from a checkpoint. It can be used to perform additional operations, such as loading the state_dict.
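
Assuming, for illustration, that a checkpoint stores the value head's weights under keys prefixed with v_head., one operation post_init could perform is to pick those entries out of a loaded state_dict and strip the prefix before loading them into the value head. The sketch below uses plain floats in place of tensors, and the extract_prefixed helper and the key names are hypothetical.

```python
def extract_prefixed(state_dict, prefix="v_head."):
    """Return the entries whose key starts with prefix, with the prefix stripped.

    Plain floats stand in for tensors so the sketch runs without PyTorch.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical checkpoint mixing base-model and value-head entries.
checkpoint = {
    "transformer.h.0.attn.c_attn.weight": 0.1,
    "v_head.summary.weight": 0.5,
    "v_head.summary.bias": 0.0,
}
v_head_state = extract_prefixed(checkpoint)
print(v_head_state)  # {'summary.weight': 0.5, 'summary.bias': 0.0}
```

With real tensors, the resulting dictionary would then be passed to the value head module's load_state_dict method.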

push_to_hub

(*args, **kwargs)

Parameters

  • *args (list, optional) — Positional arguments passed along to the underlying model’s push_to_hub method.
  • **kwargs (dict, optional) — Keyword arguments passed along to the underlying model’s push_to_hub method.

Push the pretrained model to the hub. This method is a wrapper around transformers.PreTrainedModel.push_to_hub. Please refer to the documentation of transformers.PreTrainedModel.push_to_hub for more information.

save_pretrained

(*args, **kwargs)

Parameters

  • *args (list, optional) — Positional arguments passed along to the underlying model’s save_pretrained method.
  • **kwargs (dict, optional) — Keyword arguments passed along to the underlying model’s save_pretrained method.

Save the pretrained model to a directory. This method is a wrapper around transformers.PreTrainedModel.save_pretrained. Please refer to the documentation of transformers.PreTrainedModel.save_pretrained for more information.
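
The push_to_hub and save_pretrained wrappers forward their arguments to the wrapped model. A minimal sketch of that delegation pattern follows, using a hypothetical FakePretrainedModel that records calls instead of writing files or contacting the Hub; ModelWrapper and the repository name are also hypothetical. It shows only the forwarding and does not model anything extra TRL may do when saving, such as handling the value head's weights.

```python
class FakePretrainedModel:
    """Stand-in for a transformers.PreTrainedModel that records calls instead of doing I/O."""

    def __init__(self):
        self.calls = []

    def save_pretrained(self, *args, **kwargs):
        self.calls.append(("save_pretrained", args, kwargs))

    def push_to_hub(self, *args, **kwargs):
        self.calls.append(("push_to_hub", args, kwargs))


class ModelWrapper:
    """Hypothetical minimal wrapper that forwards calls to pretrained_model."""

    def __init__(self, pretrained_model):
        self.pretrained_model = pretrained_model

    def save_pretrained(self, *args, **kwargs):
        # Positional and keyword arguments are forwarded unchanged.
        return self.pretrained_model.save_pretrained(*args, **kwargs)

    def push_to_hub(self, *args, **kwargs):
        return self.pretrained_model.push_to_hub(*args, **kwargs)


inner = FakePretrainedModel()
wrapper = ModelWrapper(inner)
wrapper.save_pretrained("./my-model", safe_serialization=True)
wrapper.push_to_hub("my-user/my-model")
```

Because the arguments pass through untouched, the wrapped model's own documentation remains the reference for what they mean.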

state_dict

(*args, **kwargs)

Return the state_dict of the pretrained model.

AutoModelForCausalLMWithValueHead

class trl.AutoModelForCausalLMWithValueHead

(pretrained_model, **kwargs)

An autoregressive model with a value head in addition to the language model head. This class inherits from trl.PreTrainedModelWrapper and wraps a transformers.PreTrainedModel instance. The wrapper supports classic methods such as from_pretrained, push_to_hub, and generate. To call a method of the wrapped model directly, use the pretrained_model attribute of this class.

Class attributes:

__init__

(pretrained_model, **kwargs)

Parameters

  • pretrained_model (transformers.PreTrainedModel) — The model to wrap. It should be a causal language model, such as GPT-2, or any model mapped in the AutoModelForCausalLM class.
  • **kwargs (dict, optional) — Additional keyword arguments that are passed to the ValueHead class.

Initializes the model.
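
Conceptually, the value head maps the hidden state of every token to a single scalar value. The pure-Python sketch below assumes the head is one linear layer with output size 1 and leaves out dropout; ValueHeadSketch is a hypothetical name, and TRL's actual ValueHead is a PyTorch module whose exact layers may differ.

```python
import random


class ValueHeadSketch:
    """Hypothetical value head: maps each token's hidden state to one scalar.

    Dropout, which a real value head may apply during training, is left out,
    so this matches evaluation-mode behavior only.
    """

    def __init__(self, hidden_size, seed=0):
        rng = random.Random(seed)
        # One weight per hidden dimension plus a bias: a linear layer with output size 1.
        self.weight = [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
        self.bias = 0.0

    def __call__(self, hidden_states):
        # hidden_states: [batch][seq_len][hidden_size] -> values: [batch][seq_len]
        return [
            [sum(w * h for w, h in zip(self.weight, token)) + self.bias for token in sequence]
            for sequence in hidden_states
        ]


head = ValueHeadSketch(hidden_size=2)
head.weight, head.bias = [1.0, 2.0], 0.5  # fixed weights so the result is easy to check
values = head([[[1.0, 1.0], [0.0, 2.0]]])  # one sequence of two tokens
print(values)  # [[3.5, 4.5]]
```

The output keeps the batch and sequence dimensions and drops the hidden dimension, giving one value estimate per token.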

forward

(input_ids=None, past_key_values=None, attention_mask=None, **kwargs)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary.
  • past_key_values (tuple(tuple(torch.FloatTensor)), optional) — Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values input) to speed up sequential decoding.
  • attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.
  • **kwargs (dict, optional) — Additional keyword arguments that are passed to the wrapped model.

Applies a forward pass to the wrapped model and returns the logits of the value head.
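
The attention_mask convention listed above (1 for tokens that are not masked, 0 for masked padding tokens) can be illustrated by right-padding a batch of token id lists. The pad_and_mask helper, the pad id 0, and the choice of right padding are assumptions for illustration; in practice a tokenizer usually builds input_ids and attention_mask for you.

```python
def pad_and_mask(sequences, pad_token_id=0):
    """Right-pad token id lists to equal length and build the matching attention mask.

    The mask holds 1 for real tokens (not masked) and 0 for padding (masked).
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_token_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask


ids, mask = pad_and_mask([[5, 6, 7], [8]])
print(ids)   # [[5, 6, 7], [8, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0]]
```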

generate

(*args, **kwargs)

Parameters

  • *args (list, optional) — Positional arguments passed to the generate method of the wrapped model.
  • **kwargs (dict, optional) — Keyword arguments passed to the generate method of the wrapped model.

A simple wrapper around the generate method of the wrapped model. Please refer to the generate method of the wrapped model for more information about the supported arguments.

_init_weights

(**kwargs)

Parameters

  • **kwargs (dict, optional) — Additional keyword arguments that are passed to the ValueHead class. These can include the v_head_init_strategy and v_head_initializer_range arguments.
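
To make v_head_init_strategy and v_head_initializer_range concrete, here is a sketch of how an initializer might act on them. It assumes a strategy named "normal" that draws weights from a normal distribution with standard deviation v_head_initializer_range and zeroes the bias, and it uses a uniform draw as a stand-in for the default random initialization. The function init_value_head, the "normal" strategy name, and the 0.2 default are assumptions, so check the list of supported strategies.

```python
import random


def init_value_head(hidden_size, v_head_init_strategy=None, v_head_initializer_range=0.2, seed=0):
    """Hypothetical initializer returning (weight, bias) for a scalar value head.

    The "normal" strategy name and the 0.2 default range are illustrative assumptions.
    """
    rng = random.Random(seed)
    if v_head_init_strategy is None:
        # Default: keep a random initialization (a uniform draw stands in here).
        weight = [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
        bias = rng.uniform(-0.1, 0.1)
    elif v_head_init_strategy == "normal":
        # Weights drawn from N(0, v_head_initializer_range**2); bias starts at zero.
        weight = [rng.gauss(0.0, v_head_initializer_range) for _ in range(hidden_size)]
        bias = 0.0
    else:
        raise ValueError(f"unsupported v_head_init_strategy: {v_head_init_strategy!r}")
    return weight, bias


weight, bias = init_value_head(4, v_head_init_strategy="normal", v_head_initializer_range=0.02)
```

Rejecting unknown strategy names early avoids silently falling back to an initialization the caller did not ask for.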

Initializes the weights of the value head. The default initialization strategy is random. Users can pass a different initialization strategy by passing the v_head_init_strategy argument when calling .from_pretrained. Supported strategies are: