TRL documentation

Models

You are viewing version v0.12.0. A newer version, v0.24.0, is available.

With the AutoModelForCausalLMWithValueHead class, TRL supports all decoder model architectures in transformers, such as GPT-2, OPT, and GPT-Neo. In addition, AutoModelForSeq2SeqLMWithValueHead lets you use encoder-decoder architectures such as T5. TRL also requires reference models, which are frozen copies of the model being trained. With create_reference_model you can create a frozen copy and optionally share layers between the two models to save memory.
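
As a minimal sketch of the workflow described above, the helper below loads a policy with a value head and derives a frozen reference copy from it. The helper name build_policy_and_reference, the "gpt2" checkpoint, and the choice of 6 shared layers are illustrative assumptions, not recommendations. The trl import is deferred into the function so that defining the helper does not download any weights.

```python
def build_policy_and_reference(model_name="gpt2", num_shared_layers=6):
    """Load a policy with a value head and a frozen reference copy of it.

    `model_name` and `num_shared_layers` are illustrative defaults only.
    """
    # Imported lazily so that defining this helper does not require trl
    # to be installed or any checkpoint to be downloaded.
    from trl import AutoModelForCausalLMWithValueHead, create_reference_model

    # Policy model: a causal LM wrapped with an additional value head.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

    # Frozen reference copy. The first `num_shared_layers` layers are shared
    # with the policy and kept frozen, which saves memory.
    ref_model = create_reference_model(model, num_shared_layers=num_shared_layers)
    return model, ref_model
```

Calling build_policy_and_reference() returns the trainable policy and its reference model; passing num_shared_layers=None should yield a full, unshared copy.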

PreTrainedModelWrapper

class trl.PreTrainedModelWrapper

( pretrained_model = None score_module = None supports_rm_adapter = False rm_adapter_name = None **kwargs )

Parameters

  • pretrained_model (transformers.PreTrainedModel) — The model to be wrapped.
  • parent_class (transformers.PreTrainedModel) — The parent class of the model to be wrapped.
  • supported_args (list) — The list of arguments that are supported by the wrapper class.

A wrapper class around a transformers.PreTrainedModel that stays compatible with the transformers.PreTrainedModel class by keeping some of its attributes and methods.

add_and_load_reward_modeling_adapter

( pretrained_model adapter_model_id adapter_name = 'reward_model_adapter' token = None )

Add and load a reward modeling adapter. This method can only be used if the model is a PeftModel and if you have initialized the model with the reward_modeling_adapter_id argument, pointing to the id of the reward modeling adapter. The adapter must also contain the score head in order to produce the reward.

compute_reward_score

( input_ids attention_mask = None **kwargs )

Computes the reward score for a given input. The method first enables the reward modeling adapter and then computes the reward score. Afterwards, it disables the reward modeling adapter and re-enables the default PPO adapter.
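
The adapter switching described above can be sketched with a toy stand-in. FakeAdapterModel, its score method, and the adapter name "default" are hypothetical and exist only for illustration; the try/finally is a design choice of this sketch, not a statement about TRL's implementation.

```python
class FakeAdapterModel:
    """Hypothetical stand-in for a PEFT model with named adapters."""

    def __init__(self, default_adapter="default"):
        self.active_adapter = default_adapter
        self.history = [default_adapter]  # records every adapter activation

    def set_adapter(self, name):
        self.active_adapter = name
        self.history.append(name)

    def score(self, input_ids):
        # Toy "reward": only available while the reward adapter is active.
        if self.active_adapter != "reward_model_adapter":
            raise RuntimeError("reward adapter is not active")
        return float(len(input_ids))


def compute_reward_score(model, input_ids,
                         rm_adapter_name="reward_model_adapter",
                         policy_adapter_name="default"):
    """Enable the reward adapter, score the input, then restore the policy adapter."""
    model.set_adapter(rm_adapter_name)
    try:
        return model.score(input_ids)
    finally:
        # Always switch back, even if scoring raised an error.
        model.set_adapter(policy_adapter_name)
```

After the call, the policy adapter is active again, so subsequent generation or training steps are unaffected by the scoring pass.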

from_pretrained

( pretrained_model_name_or_path *model_args **kwargs )

Parameters

  • pretrained_model_name_or_path (str or transformers.PreTrainedModel) — The path to the pretrained model or its name.
  • *model_args (list, optional) — Additional positional arguments passed along to the underlying model’s from_pretrained method.
  • **kwargs (dict, optional) — Additional keyword arguments passed along to the underlying model’s from_pretrained method. The kwargs are pre-processed to separate the arguments that are specific to the transformers.PreTrainedModel class from the arguments that are specific to trl models. The kwargs also support the prepare_model_for_kbit_training arguments from the peft library.

Instantiates a new model from a pretrained transformers model. The pretrained model is loaded using the from_pretrained method of the transformers.PreTrainedModel class. Arguments that are specific to the transformers.PreTrainedModel class are filtered out of kwargs and passed along to that method.
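
The kwargs filtering can be pictured roughly as follows. This is a simplified sketch: split_kwargs is a hypothetical helper, not TRL's internal function, and it ignores the peft-specific arguments. SUPPORTED_ARGS mirrors the value-head arguments listed on this page.

```python
# Value-head arguments documented for the wrapper classes on this page.
SUPPORTED_ARGS = (
    "summary_dropout_prob",
    "v_head_initializer_range",
    "v_head_init_strategy",
)


def split_kwargs(kwargs, supported_args=SUPPORTED_ARGS):
    """Split kwargs into (wrapper_kwargs, model_kwargs).

    wrapper_kwargs go to the trl wrapper / value head; model_kwargs are
    forwarded to the underlying model's from_pretrained method.
    """
    wrapper_kwargs, model_kwargs = {}, {}
    for key, value in kwargs.items():
        if key in supported_args:
            wrapper_kwargs[key] = value
        else:
            model_kwargs[key] = value
    return wrapper_kwargs, model_kwargs
```

For example, passing v_head_init_strategy together with torch_dtype would route the first to the value head and the second to the transformers loader.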

post_init

( *args **kwargs )

Post initialization method. This method is called after the model is instantiated and loaded from a checkpoint. It can be used to perform additional operations such as loading the state_dict.

push_to_hub

( *args **kwargs )

Parameters

  • *args (list, optional) — Positional arguments passed along to the underlying model’s push_to_hub method.
  • **kwargs (dict, optional) — Keyword arguments passed along to the underlying model’s push_to_hub method.

Push the pretrained model to the hub. This method is a wrapper around transformers.PreTrainedModel.push_to_hub. Please refer to the documentation of transformers.PreTrainedModel.push_to_hub for more information.

save_pretrained

( *args **kwargs )

Parameters

  • *args (list, optional) — Positional arguments passed along to the underlying model’s save_pretrained method.
  • **kwargs (dict, optional) — Keyword arguments passed along to the underlying model’s save_pretrained method.

Save the pretrained model to a directory. This method is a wrapper around transformers.PreTrainedModel.save_pretrained. Please refer to the documentation of transformers.PreTrainedModel.save_pretrained for more information.

state_dict

( *args **kwargs )

Return the state_dict of the pretrained model.

AutoModelForCausalLMWithValueHead

class trl.AutoModelForCausalLMWithValueHead

( pretrained_model **kwargs )

An autoregressive model with a value head in addition to the language model head. This class inherits from trl.PreTrainedModelWrapper and wraps a transformers.PreTrainedModel class. The wrapper class supports classic functions such as from_pretrained, push_to_hub and generate. To call any other method of the wrapped model, access it through the pretrained_model attribute of this class.

Class attributes:

  • transformers_parent_class (transformers.PreTrainedModel) — The parent class of the wrapped model. This should be set to transformers.AutoModelForCausalLM for this class.
  • lm_head_namings (tuple) — A tuple of strings that are used to identify the language model head of the wrapped model. This is set to ("lm_head", "embed_out") for this class but may be extended for other models in the future.
  • supported_args (tuple) — A tuple of strings that are used to identify the arguments that are supported by the ValueHead class. Currently, the supported args are:
    • summary_dropout_prob (float, optional, defaults to None) — The dropout probability for the ValueHead class.
    • v_head_initializer_range (float, optional, defaults to 0.2) — The initializer range for the ValueHead if a specific initialization strategy is selected.
    • v_head_init_strategy (str, optional, defaults to None) — The initialization strategy for the ValueHead. Currently, the supported strategies are:
      • None — Initializes the weights of the ValueHead with a random distribution. This is the default strategy.
      • “normal” — Initializes the weights of the ValueHead with a normal distribution.

__init__

( pretrained_model **kwargs )

Parameters

  • pretrained_model (transformers.PreTrainedModel) — The model to wrap. It should be a causal language model such as GPT2, or any model mapped inside the AutoModelForCausalLM class.
  • kwargs (dict, optional) — Additional keyword arguments that are passed to the ValueHead class.

Initializes the model.

forward

( input_ids = None past_key_values = None attention_mask = None return_past_key_values = False **kwargs )

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary.
  • past_key_values (tuple(tuple(torch.FloatTensor)), optional) — Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values input) to speed up sequential decoding.
  • attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.
  • return_past_key_values (bool) — A flag indicating whether the computed hidden-states should be returned.
  • kwargs (dict, optional) — Additional keyword arguments that are passed to the wrapped model.

Applies a forward pass to the wrapped model and returns the logits of the value head.
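
To illustrate the attention_mask convention above (1 for tokens that are attended to, 0 for padding), here is a small plain-Python sketch. The helper build_attention_mask and the pad id 0 are assumptions made for the example; in practice the tokenizer usually produces this mask.

```python
def build_attention_mask(batch_input_ids, pad_token_id):
    """Return a mask with 1 for real tokens and 0 for padding tokens.

    Caveat: if the pad token id is also used as a real token (for example
    when the end-of-sequence token doubles as padding), this simple rule
    would mask those real tokens as well.
    """
    return [
        [0 if token_id == pad_token_id else 1 for token_id in row]
        for row in batch_input_ids
    ]


# Two sequences padded on the right to length 3 with the (assumed) pad id 0.
padded_batch = [[5, 6, 0], [7, 0, 0]]
mask = build_attention_mask(padded_batch, pad_token_id=0)
```

The resulting nested list has the same (batch_size, sequence_length) shape as input_ids and can be converted to a tensor before being passed to forward.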

generate

( *args **kwargs )

Parameters

  • *args (list, optional) — Positional arguments passed to the generate method of the wrapped model.
  • **kwargs (dict, optional) — Keyword arguments passed to the generate method of the wrapped model.

A simple wrapper around the generate method of the wrapped model. Please refer to the generate method of the wrapped model for more information about the supported arguments.

_init_weights

( **kwargs )

Parameters

  • **kwargs (dict, optional) — Additional keyword arguments that are passed to the ValueHead class. These can include the v_head_init_strategy and v_head_initializer_range arguments.

Initializes the weights of the value head. The default initialization strategy is random. Users can pass a different initialization strategy by passing the v_head_init_strategy argument when calling .from_pretrained. Supported strategies are:

  • normal: initializes the weights with a normal distribution.
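
The strategy dispatch can be sketched in plain Python as below. This is a toy model of the behaviour, not TRL's code: init_value_head is a hypothetical helper operating on lists instead of tensors, and it assumes that None leaves the existing random initialization untouched, that "normal" also zeroes the bias, and that unknown strategies raise an error.

```python
import random


def init_value_head(weights, bias, init_strategy=None,
                    initializer_range=0.2, seed=None):
    """Return (weights, bias) after applying the chosen init strategy."""
    if init_strategy is None:
        # Assumption: keep whatever (random) values the layer was created with.
        return list(weights), bias
    if init_strategy == "normal":
        rng = random.Random(seed)
        # Draw each weight from N(0, initializer_range^2); zero the bias.
        new_weights = [rng.gauss(0.0, initializer_range) for _ in weights]
        return new_weights, 0.0
    # Assumption of this sketch: reject strategies it does not know.
    raise ValueError(f"unsupported v_head_init_strategy: {init_strategy!r}")
```

With the real classes, the equivalent choice would be made by passing v_head_init_strategy="normal" (and optionally v_head_initializer_range) to from_pretrained.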

AutoModelForSeq2SeqLMWithValueHead

class trl.AutoModelForSeq2SeqLMWithValueHead

( pretrained_model **kwargs )

Parameters

  • pretrained_model (transformers.PreTrainedModel) — The model to wrap. It should be a sequence-to-sequence (encoder-decoder) language model such as T5, or any model mapped inside the AutoModelForSeq2SeqLM class.
  • kwargs (dict, optional) — Additional keyword arguments passed along to the ValueHead class.

A seq2seq model with a value head in addition to the language model head. This class inherits from trl.PreTrainedModelWrapper and wraps a transformers.PreTrainedModel class. The wrapper class supports classic functions such as from_pretrained and push_to_hub, and also provides additional functionality such as generate.

__init__

( pretrained_model **kwargs )

forward

( input_ids = None past_key_values = None attention_mask = None return_past_key_values = False **kwargs )

generate

( *args **kwargs )

Calls generate on the wrapped model.

_init_weights

( **kwargs )

Initializes the weights of the value head.

create_reference_model

trl.create_reference_model

( model: PreTrainedModelWrapper num_shared_layers: Optional[int] = None pattern: Optional[str] = None ) → PreTrainedModelWrapper

Parameters

  • model (PreTrainedModelWrapper) — The model to be copied.
  • num_shared_layers (int, optional) — The number of initial layers that are shared between both models and kept frozen.
  • pattern (str, optional) — The shared layers are selected with a string pattern (e.g. “transformer.h.{layer}” for GPT2) and if a custom pattern is necessary it can be passed here.

Returns

PreTrainedModelWrapper

Creates a static reference copy of a model. Note that the returned reference model will be in .eval() mode.
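
One way to picture how num_shared_layers and pattern interact is the sketch below. It works on parameter names only and is not TRL's implementation: split_shared_parameters is a hypothetical helper, and it assumes parameters are listed in forward order, so that everything before layer index num_shared_layers (including embeddings) counts as shared.

```python
def split_shared_parameters(param_names, num_shared_layers,
                            pattern="transformer.h.{layer}"):
    """Split parameter names into (shared, unshared) lists.

    The pattern is formatted with the index of the first layer that is
    NOT shared; every parameter from that layer onwards is unshared.
    """
    first_unshared = pattern.format(layer=num_shared_layers)
    shared, unshared = [], []
    reached_unshared = False
    for name in param_names:
        # Match on a full name component so "transformer.h.1" does not
        # accidentally match "transformer.h.10".
        if name == first_unshared or name.startswith(first_unshared + "."):
            reached_unshared = True
        (unshared if reached_unshared else shared).append(name)
    return shared, unshared


# Toy GPT2-style parameter names, in forward order.
names = [
    "transformer.wte.weight",
    "transformer.h.0.attn.weight",
    "transformer.h.1.attn.weight",
    "transformer.h.2.attn.weight",
    "lm_head.weight",
]
shared, unshared = split_shared_parameters(names, num_shared_layers=2)
```

Here the embedding and the first two blocks end up in shared, while block 2 and the language model head are unshared. For architectures whose layers are not named transformer.h.N, a custom pattern would be needed.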
