Prefix tuning
Prefix tuning prefixes a series of task-specific vectors to the input sequence that can be learned while keeping the pretrained model frozen. The prefix parameters are inserted in all of the model layers.
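The mechanism described above can be sketched for a single attention layer in plain PyTorch. This is a minimal illustration under our own assumptions (one head, no causal mask, made-up sizes, randomly initialized projections standing in for pretrained weights). It is not PEFT's implementation. Trainable prefix keys and values are concatenated in front of the keys and values computed from the input, so every input position can attend to the prefix, while the frozen layer weights receive no gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d_model, seq_len, n_prefix = 16, 5, 3  # illustrative sizes, not from any real model

# Frozen projections standing in for one pretrained attention layer (assumption).
w_q = torch.nn.Linear(d_model, d_model)
w_k = torch.nn.Linear(d_model, d_model)
w_v = torch.nn.Linear(d_model, d_model)
for layer in (w_q, w_k, w_v):
    layer.requires_grad_(False)

# The only trainable parameters: one key and one value vector per virtual token.
prefix_k = torch.nn.Parameter(torch.randn(n_prefix, d_model))
prefix_v = torch.nn.Parameter(torch.randn(n_prefix, d_model))


def attend_with_prefix(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Single-head attention over [prefix; x] for x of shape (seq_len, d_model)."""
    q = w_q(x)
    k = torch.cat([prefix_k, w_k(x)], dim=0)  # (n_prefix + seq_len, d_model)
    v = torch.cat([prefix_v, w_v(x)], dim=0)  # (n_prefix + seq_len, d_model)
    weights = F.softmax(q @ k.T / d_model**0.5, dim=-1)  # (seq_len, n_prefix + seq_len)
    return weights @ v, weights


x = torch.randn(seq_len, d_model)
out, weights = attend_with_prefix(x)
out.sum().backward()  # gradients reach only prefix_k and prefix_v
```

In a full model the same prefix trick is repeated in every layer, with separate prefix keys and values per layer.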
The abstract from the paper is:
Fine-tuning is the de facto way to leverage large pretrained language models to perform downstream tasks. However, it modifies all the language model parameters and therefore necessitates storing a full copy for each task. In this paper, we propose prefix-tuning, a lightweight alternative to fine-tuning for natural language generation tasks, which keeps language model parameters frozen, but optimizes a small continuous task-specific vector (called the prefix). Prefix-tuning draws inspiration from prompting, allowing subsequent tokens to attend to this prefix as if it were “virtual tokens”. We apply prefix-tuning to GPT-2 for table-to-text generation and to BART for summarization. We find that by learning only 0.1% of the parameters, prefix-tuning obtains comparable performance in the full data setting, outperforms fine-tuning in low-data settings, and extrapolates better to examples with topics unseen during training.
PrefixTuningConfig
class peft.PrefixTuningConfig(
    task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None,
    peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None,
    auto_mapping: typing.Optional[dict] = None,
    base_model_name_or_path: typing.Optional[str] = None,
    revision: typing.Optional[str] = None,
    inference_mode: bool = False,
    num_virtual_tokens: int = None,
    token_dim: int = None,
    num_transformer_submodules: typing.Optional[int] = None,
    num_attention_heads: typing.Optional[int] = None,
    num_layers: typing.Optional[int] = None,
    encoder_hidden_size: int = None,
    prefix_projection: bool = False,
)
This is the configuration class to store the configuration of a PrefixEncoder.
PrefixEncoder
class peft.PrefixEncoder(config)
Parameters
- config (PrefixTuningConfig) — The configuration of the prefix encoder.
The torch.nn model to encode the prefix.
Example:
>>> from peft import PrefixEncoder, PrefixTuningConfig
>>> config = PrefixTuningConfig(
... peft_type="PREFIX_TUNING",
... task_type="SEQ_2_SEQ_LM",
... num_virtual_tokens=20,
... token_dim=768,
... num_transformer_submodules=1,
... num_attention_heads=12,
... num_layers=12,
... encoder_hidden_size=768,
... )
>>> prefix_encoder = PrefixEncoder(config)
Attributes:
- embedding (torch.nn.Embedding) — The embedding layer of the prefix encoder.
- transform (torch.nn.Sequential) — The two-layer MLP to transform the prefix embeddings if prefix_projection is True.
- prefix_projection (bool) — Whether to project the prefix embeddings.
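To make the two modes concrete, here is a hypothetical, torch-only re-implementation that follows the attributes listed above. It is not PEFT's source. The class name `SketchPrefixEncoder`, the `count_params` helper, and the `Tanh` activation between the two MLP layers are our assumptions. Without projection the embedding directly stores `2 * num_layers * token_dim` values per virtual token; with projection a small `token_dim` embedding is expanded to that size by the two-layer MLP.

```python
import torch


class SketchPrefixEncoder(torch.nn.Module):
    """Hypothetical sketch mirroring the documented attributes (not PEFT's code)."""

    def __init__(self, num_virtual_tokens: int, token_dim: int, num_layers: int,
                 encoder_hidden_size: int, prefix_projection: bool = False):
        super().__init__()
        self.prefix_projection = prefix_projection
        out_dim = 2 * num_layers * token_dim  # a key and a value vector per layer
        if prefix_projection:
            # Small embedding re-parameterized through a two-layer MLP (Tanh is assumed).
            self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
            self.transform = torch.nn.Sequential(
                torch.nn.Linear(token_dim, encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(encoder_hidden_size, out_dim),
            )
        else:
            # Learn the full per-layer prefix activations directly.
            self.embedding = torch.nn.Embedding(num_virtual_tokens, out_dim)

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        # prefix: (batch_size, num_virtual_tokens) integer indices
        if self.prefix_projection:
            return self.transform(self.embedding(prefix))
        return self.embedding(prefix)


def count_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# Same sizes as the PrefixEncoder example.
sizes = dict(num_virtual_tokens=20, token_dim=768, num_layers=12, encoder_hidden_size=768)
direct = SketchPrefixEncoder(**sizes)
projected = SketchPrefixEncoder(**sizes, prefix_projection=True)

ids = torch.arange(20).unsqueeze(0).expand(2, -1)
with torch.no_grad():
    direct_out, projected_out = direct(ids), projected(ids)
print(count_params(direct), count_params(projected))
```

With these sizes the direct variant has 20 × 18,432 = 368,640 parameters, while the projected variant has 14,780,160, most of them in the MLP's output layer (768 × 18,432 weights plus biases). Both produce outputs of shape (batch_size, 20, 18,432).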
Input shape: (batch_size, num_virtual_tokens)
Output shape: (batch_size, num_virtual_tokens, 2*layers*hidden)
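The `2*layers*hidden` output dimension packs one key and one value vector per transformer layer. The sketch below shows one way to unpack it into per-layer tensors shaped like attention key/value caches, `(batch, heads, num_virtual_tokens, head_dim)`. The layout (layer-major, key before value, each vector split evenly across heads) is our assumption about how the flat output is consumed, and a random tensor stands in for the real encoder output.

```python
import torch

batch_size, num_virtual_tokens = 2, 20
num_layers, num_heads, token_dim = 12, 12, 768
head_dim = token_dim // num_heads

# Stand-in for an encoder output of shape (batch_size, num_virtual_tokens, 2*layers*hidden).
flat = torch.randn(batch_size, num_virtual_tokens, 2 * num_layers * token_dim)

# Split the last axis into (2 * layers, heads, head_dim).
per_layer = flat.view(batch_size, num_virtual_tokens, 2 * num_layers, num_heads, head_dim)
# Reorder to (2 * layers, batch, heads, num_virtual_tokens, head_dim).
per_layer = per_layer.permute(2, 0, 3, 1, 4)
# One chunk per layer, each stacking (key, value): (2, batch, heads, tokens, head_dim).
layer_kv = per_layer.split(2)
print(len(layer_kv), layer_kv[0].shape)
```

Under this layout, `layer_kv[i][0]` and `layer_kv[i][1]` are the prefix keys and values that layer `i` would prepend to the keys and values computed from the real input tokens.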