With the AutoModelForCausalLMWithValueHead class, TRL supports all decoder model architectures in transformers, such as GPT-2, OPT, and GPT-Neo.
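As a minimal sketch of what this looks like in practice, the helper below loads a decoder checkpoint with a value head and generates a short continuation. The helper name generate_with_value_head, the gpt2 checkpoint, and the generation settings are illustrative assumptions. It also assumes that torch and transformers are installed and that the installed TRL release still exports AutoModelForCausalLMWithValueHead at the top level.

```python
def generate_with_value_head(prompt: str, model_name: str = "gpt2", max_new_tokens: int = 8) -> str:
    """Hypothetical helper: load a causal LM with a value head and generate text."""
    # Imports live inside the function so that defining it does not require
    # trl/transformers to be installed or a checkpoint to be downloaded.
    from transformers import AutoTokenizer
    from trl import AutoModelForCausalLMWithValueHead

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The wrapper loads the underlying transformers model and adds a value head.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

    inputs = tokenizer(prompt, return_tensors="pt")
    # generate is forwarded to the generate method of the wrapped model.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Calling generate_with_value_head("Hello, my name is") downloads the checkpoint on first use, so it needs network access or a local cache.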
PreTrainedModelWrapper: a wrapper class around a transformers.PreTrainedModel. It keeps some attributes and methods of the transformers.PreTrainedModel class so that the wrapped model stays compatible with it.
from_pretrained ( pretrained_model_name_or_path *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or transformers.PreTrainedModel): The path to the pretrained model or its name.
- *model_args (list, optional): Additional positional arguments passed along to the underlying model's from_pretrained method.
- **kwargs (dict, optional): Additional keyword arguments passed along to the underlying model's from_pretrained method. The kwargs are also pre-processed to separate the arguments that are specific to the transformers.PreTrainedModel class from the arguments that are specific to trl models.
Instantiates a new model from a pretrained transformers model. The pretrained model is loaded using the from_pretrained method of the transformers.PreTrainedModel class. The arguments that are specific to the transformers.PreTrainedModel class are passed along to this method and filtered out from the kwargs argument.
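The kwargs pre-processing described above can be pictured as a simple partition of a dictionary. The sketch below is not TRL's actual code: the function name split_kwargs is hypothetical, and the tuple of trl-specific argument names is an assumption based on the value head arguments that TRL documents.

```python
from typing import Any, Dict, Tuple

# Assumed names of the trl-specific (value head) arguments.
SUPPORTED_ARGS: Tuple[str, ...] = (
    "summary_dropout_prob",
    "v_head_initializer_range",
    "v_head_init_strategy",
)


def split_kwargs(
    kwargs: Dict[str, Any],
    supported_args: Tuple[str, ...] = SUPPORTED_ARGS,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition kwargs into (trl_kwargs, pretrained_kwargs)."""
    # Arguments named in supported_args are kept for the trl model.
    trl_kwargs = {k: v for k, v in kwargs.items() if k in supported_args}
    # Everything else would be handed to the underlying from_pretrained call.
    pretrained_kwargs = {k: v for k, v in kwargs.items() if k not in supported_args}
    return trl_kwargs, pretrained_kwargs


trl_kwargs, pretrained_kwargs = split_kwargs(
    {"v_head_init_strategy": "normal", "torch_dtype": "float16"}
)
```

Here trl_kwargs holds only the value head option, while torch_dtype stays in pretrained_kwargs for the wrapped model.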
Post initialization method. This method is called after the model is instantiated and loaded from a checkpoint. It can be used to perform additional operations such as loading the state_dict.
push_to_hub ( *args **kwargs )
Pushes the pretrained model to the hub. This method is a wrapper around transformers.PreTrainedModel.push_to_hub. Please refer to the documentation of transformers.PreTrainedModel.push_to_hub for more information.
save_pretrained ( *args **kwargs )
Saves the pretrained model to a directory. This method is a wrapper around transformers.PreTrainedModel.save_pretrained. Please refer to the documentation of transformers.PreTrainedModel.save_pretrained for more information.
Return the state_dict of the pretrained model.
AutoModelForCausalLMWithValueHead: an autoregressive model with a value head in addition to the language model head. This class inherits from trl.PreTrainedModelWrapper and wraps a transformers.PreTrainedModel class. The wrapper class supports classic functions such as from_pretrained, push_to_hub and generate. To call any other method of the wrapped model, access it through the pretrained_model attribute of this class.
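To illustrate the wrapping pattern described above, here is a library-free sketch. ToyModel and ToyWrapper are hypothetical stand-ins, not TRL classes; only the pretrained_model attribute name comes from the text.

```python
class ToyModel:
    """Hypothetical stand-in for a transformers.PreTrainedModel."""

    def __init__(self) -> None:
        self.config = {"hidden_size": 4}

    def generate(self, prompt: str) -> str:
        return prompt + " ..."

    def resize(self, new_size: int) -> None:
        self.config["hidden_size"] = new_size


class ToyWrapper:
    """Hypothetical wrapper: exposes a few methods, keeps the rest on pretrained_model."""

    def __init__(self, pretrained_model: ToyModel) -> None:
        self.pretrained_model = pretrained_model

    def generate(self, *args, **kwargs):
        # Supported method: forwarded directly to the wrapped model.
        return self.pretrained_model.generate(*args, **kwargs)


wrapper = ToyWrapper(ToyModel())
text = wrapper.generate("Hello")
# Methods the wrapper does not expose are reached through pretrained_model.
wrapper.pretrained_model.resize(8)
```

The same shape applies to the real class: supported calls go through the wrapper, and anything else goes through its pretrained_model attribute.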
Class attributes:
- transformer_parent_class (transformers.PreTrainedModel): The parent class of the wrapped model. This should be set to transformers.AutoModelForCausalLM for this class.
- lm_head_namings (tuple): A tuple of strings that are used to identify the language model head of the wrapped model. This is set to ("lm_head", "embed_out") for this class but can be changed for other models in the future.
- supported_args (tuple): A tuple of strings that are used to identify the arguments that are supported by the ValueHead class. Currently the supported args are:
    - summary_dropout_prob (float, optional, defaults to None): The dropout probability for the ValueHead class.
    - v_head_initializer_range (float, optional, defaults to 0.2): The initializer range for the ValueHead if a specific initialization strategy is selected.
    - v_head_init_strategy (str, optional, defaults to None): The initialization strategy for the ValueHead. Currently supported strategies are:
        - None: Initializes the weights of the ValueHead with a random distribution. This is the default strategy.
        - "normal": Initializes the weights of the ValueHead with a normal distribution.
__init__ ( pretrained_model **kwargs )
Initializes the model.
forward ( input_ids = None past_key_values = None attention_mask = None **kwargs )
Parameters
- attention_mask: Mask values are selected in [0, 1].
Applies a forward pass to the wrapped model and returns the logits of the value head.
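Conceptually, a value head maps each token's final hidden state to a single scalar. The sketch below uses plain Python lists instead of tensors. The single linear projection, the function name value_head_forward, and the toy numbers are assumptions made for illustration, not a description of TRL's ValueHead implementation.

```python
from typing import List


def value_head_forward(
    hidden_states: List[List[float]],
    weight: List[float],
    bias: float,
) -> List[float]:
    """Return one scalar per token: dot(hidden_state, weight) + bias."""
    return [
        sum(h * w for h, w in zip(token_state, weight)) + bias
        for token_state in hidden_states
    ]


# Two tokens with hidden size 3 (toy numbers).
hidden = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]]
values = value_head_forward(hidden, weight=[0.5, -1.0, 0.25], bias=0.5)
```

The output has one entry per input token, which is the shape a per-token value estimate needs.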
generate ( *args **kwargs )
A simple wrapper around the generate method of the wrapped model. Please refer to the generate method of the wrapped model for more information about the supported arguments.
_init_weights ( **kwargs )
Initializes the weights of the value head. The default initialization strategy is random. Users can select a different initialization strategy by passing the v_head_init_strategy argument when calling .from_pretrained. Supported strategies are:
- normal: initializes the weights with a normal distribution.
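The strategy switch can be sketched without any deep-learning library. This is an illustration only: the function name init_value_head, the zeroed bias, the use of the initializer range as the standard deviation, the error raised for unknown strategies, and the choice to leave the weights untouched for the default strategy are all assumptions, not confirmed details of TRL's implementation.

```python
import random
from typing import List, Optional, Tuple


def init_value_head(
    weights: List[float],
    bias: float,
    strategy: Optional[str] = None,
    initializer_range: float = 0.2,
    seed: Optional[int] = None,
) -> Tuple[List[float], float]:
    """Return (weights, bias) after applying the chosen initialization strategy."""
    if strategy is None:
        # Default: keep whatever random initialization the layer already has.
        return list(weights), bias
    if strategy == "normal":
        rng = random.Random(seed)
        # Assumed: zero-mean normal with std = initializer_range, bias reset to 0.
        return [rng.gauss(0.0, initializer_range) for _ in weights], 0.0
    raise ValueError(f"Unsupported v_head_init_strategy: {strategy!r}")


default_w, default_b = init_value_head([0.1, -0.3], 0.05)
normal_w, normal_b = init_value_head([0.1, -0.3], 0.05, strategy="normal", seed=0)
```

With the real class, the equivalent choice is made by passing v_head_init_strategy="normal" to .from_pretrained.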