P-tuning
P-tuning adds trainable prompt embeddings to the input. A prompt encoder optimizes these embeddings to find a better prompt, which eliminates the need to design prompts manually. The prompt tokens can be added anywhere in the input sequence, and P-tuning also introduces anchor tokens to improve performance.
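As a rough illustration of the idea that prompt tokens can sit anywhere in the input, the following pure-Python sketch splices a block of virtual-token embeddings into a sequence of input embeddings at a chosen position. The function name `insert_virtual_tokens` and its `position` argument are hypothetical and not part of PEFT; a real implementation works on tensors and trains the prompt embeddings through the prompt encoder.

```python
from typing import List

Vector = List[float]


def insert_virtual_tokens(
    input_embeds: List[Vector],
    prompt_embeds: List[Vector],
    position: int = 0,
) -> List[Vector]:
    """Return a new sequence with prompt_embeds spliced in at `position`.

    Hypothetical helper for illustration only; not a PEFT API.
    """
    if not 0 <= position <= len(input_embeds):
        raise ValueError("position must lie within the input sequence")
    return input_embeds[:position] + prompt_embeds + input_embeds[position:]


# Three input tokens and two virtual tokens, each with embedding dimension 2.
input_embeds = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
prompt_embeds = [[0.1, 0.1], [0.2, 0.2]]

# Virtual tokens placed before the input, then after the first input token.
prepended = insert_virtual_tokens(input_embeds, prompt_embeds, position=0)
infix = insert_virtual_tokens(input_embeds, prompt_embeds, position=1)
print(len(prepended), len(infix))
```

In both cases the sequence grows by the number of virtual tokens, while the original input embeddings are left unchanged.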
The abstract from the paper is:
While GPTs with traditional fine-tuning fail to achieve strong results on natural language understanding (NLU), we show that GPTs can be better than or comparable to similar-sized BERTs on NLU tasks with a novel method P-tuning — which employs trainable continuous prompt embeddings. On the knowledge probing (LAMA) benchmark, the best GPT recovers 64% (P@1) of world knowledge without any additional text provided during test time, which substantially improves the previous best by 20+ percentage points. On the SuperGlue benchmark, GPTs achieve comparable and sometimes better performance to similar-sized BERTs in supervised learning. Importantly, we find that P-tuning also improves BERTs’ performance in both few-shot and supervised settings while largely reducing the need for prompt engineering. Consequently, P-tuning outperforms the state-of-the-art approaches on the few-shot SuperGlue benchmark.
PromptEncoderConfig
class peft.PromptEncoderConfig
( task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None
  peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None
  auto_mapping: typing.Optional[dict] = None
  base_model_name_or_path: typing.Optional[str] = None
  revision: typing.Optional[str] = None
  inference_mode: bool = False
  num_virtual_tokens: int = None
  token_dim: int = None
  num_transformer_submodules: typing.Optional[int] = None
  num_attention_heads: typing.Optional[int] = None
  num_layers: typing.Optional[int] = None
  encoder_reparameterization_type: typing.Union[str, peft.tuners.p_tuning.config.PromptEncoderReparameterizationType] = <PromptEncoderReparameterizationType.MLP: 'MLP'>
  encoder_hidden_size: int = None
  encoder_num_layers: int = 2
  encoder_dropout: float = 0.0 )
Parameters
- encoder_reparameterization_type (Union[PromptEncoderReparameterizationType,str]) — The type of reparameterization to use, either "MLP" or "LSTM".
- encoder_hidden_size (int) — The hidden size of the prompt encoder.
- encoder_num_layers (int) — The number of layers of the prompt encoder.
- encoder_dropout (float) — The dropout probability of the prompt encoder.
This is the configuration class to store the configuration of a PromptEncoder.
PromptEncoder
class peft.PromptEncoder
( config )
Parameters
- config (PromptEncoderConfig) — The configuration of the prompt encoder.
The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
Example:
>>> from peft import PromptEncoder, PromptEncoderConfig
>>> config = PromptEncoderConfig(
...     peft_type="P_TUNING",
...     task_type="SEQ_2_SEQ_LM",
...     num_virtual_tokens=20,
...     token_dim=768,
...     num_transformer_submodules=1,
...     num_attention_heads=12,
...     num_layers=12,
...     encoder_reparameterization_type="MLP",
...     encoder_hidden_size=768,
... )
>>> prompt_encoder = PromptEncoder(config)
Attributes:
- embedding (torch.nn.Embedding) — The embedding layer of the prompt encoder.
- mlp_head (torch.nn.Sequential) — The MLP head of the prompt encoder if inference_mode=False.
- lstm_head (torch.nn.LSTM) — The LSTM head of the prompt encoder if inference_mode=False and encoder_reparameterization_type="LSTM".
- token_dim (int) — The hidden embedding dimension of the base transformer model.
- input_size (int) — The input size of the prompt encoder.
- output_size (int) — The output size of the prompt encoder.
- hidden_size (int) — The hidden size of the prompt encoder.
- total_virtual_tokens (int): The total number of virtual tokens of the prompt encoder.
- encoder_type (Union[PromptEncoderReparameterizationType,str]): The encoder type of the prompt encoder.
Input shape: (batch_size, total_virtual_tokens)
Output shape: (batch_size, total_virtual_tokens, token_dim)
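The shapes above can be sanity-checked with a small helper. This is only a sketch: it assumes that `total_virtual_tokens` equals `num_virtual_tokens * num_transformer_submodules`, and the function name `prompt_encoder_shapes` is hypothetical and not part of PEFT. The batch size of 4 is arbitrary.

```python
from typing import Tuple


def prompt_encoder_shapes(
    batch_size: int,
    num_virtual_tokens: int,
    token_dim: int,
    num_transformer_submodules: int = 1,
) -> Tuple[Tuple[int, int], Tuple[int, int, int]]:
    """Return (input_shape, output_shape) for a prompt encoder.

    Hypothetical helper for illustration only; not a PEFT API.
    """
    # Assumption: total virtual tokens = tokens per submodule * submodules.
    total_virtual_tokens = num_virtual_tokens * num_transformer_submodules
    # Input: one virtual-token index per position.
    input_shape = (batch_size, total_virtual_tokens)
    # Output: one embedding of size token_dim per virtual token.
    output_shape = (batch_size, total_virtual_tokens, token_dim)
    return input_shape, output_shape


# num_virtual_tokens and token_dim match the example configuration.
input_shape, output_shape = prompt_encoder_shapes(
    batch_size=4, num_virtual_tokens=20, token_dim=768
)
print(input_shape, output_shape)
```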