PEFT documentation

P-tuning

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v0.11.0).

P-tuning adds trainable prompt embeddings to the input. These embeddings are optimized by a prompt encoder to find a better prompt, which eliminates the need to design prompts manually. The prompt tokens can be added anywhere in the input sequence, and p-tuning also introduces anchor tokens to improve performance.
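The mechanism can be illustrated with a minimal PyTorch sketch. It is not PEFT's implementation. A hypothetical `TinyPTuning` module holds one trainable vector per virtual token and reparameterizes those vectors with a small MLP prompt encoder. It then splices the result into the embedded input at an arbitrary position. The class name, the `insert_at` argument, and all sizes are illustrative assumptions. The random `x` tensor stands in for the frozen base model's token embeddings, and anchor tokens are not modeled.

```python
import torch
import torch.nn as nn


class TinyPTuning(nn.Module):
    """Hypothetical illustration: trainable virtual-token embeddings,
    reparameterized by an MLP prompt encoder, spliced into the input."""

    def __init__(self, num_virtual_tokens: int, token_dim: int, hidden_size: int):
        super().__init__()
        self.num_virtual_tokens = num_virtual_tokens
        # One trainable vector per virtual token.
        self.embedding = nn.Embedding(num_virtual_tokens, token_dim)
        # MLP prompt encoder that maps raw prompt vectors to the final prompt.
        self.mlp_head = nn.Sequential(
            nn.Linear(token_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, token_dim),
        )

    def forward(self, input_embeds: torch.Tensor, insert_at: int) -> torch.Tensor:
        batch_size = input_embeds.size(0)
        indices = torch.arange(self.num_virtual_tokens, device=input_embeds.device)
        # (num_virtual_tokens, token_dim) -> (batch_size, num_virtual_tokens, token_dim)
        prompt = self.mlp_head(self.embedding(indices)).unsqueeze(0).expand(batch_size, -1, -1)
        # Splice the same prompt into every sequence at position `insert_at`.
        return torch.cat(
            [input_embeds[:, :insert_at], prompt, input_embeds[:, insert_at:]], dim=1
        )


torch.manual_seed(0)
model = TinyPTuning(num_virtual_tokens=4, token_dim=16, hidden_size=32)
x = torch.randn(2, 10, 16)  # stand-in for the frozen model's token embeddings
out = model(x, insert_at=3)
print(out.shape)  # torch.Size([2, 14, 16])
```

During training, only the embedding table and the MLP head would receive gradient updates. The base model that consumes `out` stays frozen.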

The abstract from the paper, GPT Understands, Too, is:

While GPTs with traditional fine-tuning fail to achieve strong results on natural language understanding (NLU), we show that GPTs can be better than or comparable to similar-sized BERTs on NLU tasks with a novel method P-tuning — which employs trainable continuous prompt embeddings. On the knowledge probing (LAMA) benchmark, the best GPT recovers 64% (P@1) of world knowledge without any additional text provided during test time, which substantially improves the previous best by 20+ percentage points. On the SuperGlue benchmark, GPTs achieve comparable and sometimes better performance to similar-sized BERTs in supervised learning. Importantly, we find that P-tuning also improves BERTs’ performance in both few-shot and supervised settings while largely reducing the need for prompt engineering. Consequently, P-tuning outperforms the state-of-the-art approaches on the few-shot SuperGlue benchmark.

PromptEncoderConfig

class peft.PromptEncoderConfig


(
    peft_type: Union = None,
    auto_mapping: Optional = None,
    base_model_name_or_path: Optional = None,
    revision: Optional = None,
    task_type: Union = None,
    inference_mode: bool = False,
    num_virtual_tokens: int = None,
    token_dim: int = None,
    num_transformer_submodules: Optional = None,
    num_attention_heads: Optional = None,
    num_layers: Optional = None,
    encoder_reparameterization_type: Union = <PromptEncoderReparameterizationType.MLP: 'MLP'>,
    encoder_hidden_size: int = None,
    encoder_num_layers: int = 2,
    encoder_dropout: float = 0.0
)

Parameters

  • encoder_reparameterization_type (Union[PromptEncoderReparameterizationType, str]) — The type of reparameterization to use: "MLP" (the default) or "LSTM".
  • encoder_hidden_size (int) — The hidden size of the prompt encoder.
  • encoder_num_layers (int) — The number of layers of the prompt encoder.
  • encoder_dropout (float) — The dropout probability of the prompt encoder.

This is the configuration class to store the configuration of a PromptEncoder.

PromptEncoder

class peft.PromptEncoder


( config )

Parameters

  • config (PromptEncoderConfig): The configuration of the prompt encoder.

The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.

Example:

>>> from peft import PromptEncoder, PromptEncoderConfig

>>> config = PromptEncoderConfig(
...     peft_type="P_TUNING",
...     task_type="SEQ_2_SEQ_LM",
...     num_virtual_tokens=20,
...     token_dim=768,
...     num_transformer_submodules=1,
...     num_attention_heads=12,
...     num_layers=12,
...     encoder_reparameterization_type="MLP",
...     encoder_hidden_size=768,
... )

>>> prompt_encoder = PromptEncoder(config)

Attributes:

  • embedding (torch.nn.Embedding) — The embedding layer of the prompt encoder.
  • mlp_head (torch.nn.Sequential) — The MLP head of the prompt encoder if inference_mode=False.
  • lstm_head (torch.nn.LSTM) — The LSTM head of the prompt encoder if inference_mode=False and encoder_reparameterization_type="LSTM".
  • token_dim (int) — The hidden embedding dimension of the base transformer model.
  • input_size (int) — The input size of the prompt encoder.
  • output_size (int) — The output size of the prompt encoder.
  • hidden_size (int) — The hidden size of the prompt encoder.
  • total_virtual_tokens (int): The total number of virtual tokens of the prompt encoder.
  • encoder_type (Union[PromptEncoderReparameterizationType, str]): The encoder type of the prompt encoder.

Input shape: (batch_size, total_virtual_tokens)

Output shape: (batch_size, total_virtual_tokens, token_dim)
