Prompt tuning

Prompt tuning prepends a small set of task-specific, trainable prompt embeddings (a "soft prompt") to the input. Only these prompt parameters are updated during training; the pretrained model parameters stay frozen.
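As a minimal sketch of the mechanism, the soft prompt is a small matrix that is prepended to the input embeddings of every example in the batch, and the attention mask is extended so that the virtual tokens are always attended to. The sketch assumes plain Python lists stand in for tensors. The helpers `prepend_soft_prompt` and `extend_attention_mask` are hypothetical and not part of PEFT.

```python
from typing import List

Matrix = List[List[float]]


def prepend_soft_prompt(input_embeds: List[Matrix], soft_prompt: Matrix) -> List[Matrix]:
    """Prepend the same soft prompt to every sequence in a batch (hypothetical helper).

    input_embeds: batch_size x seq_len x token_dim, produced by the frozen
        word embeddings of the base model.
    soft_prompt: num_virtual_tokens x token_dim, the only trainable weights.
    Returns batch_size x (num_virtual_tokens + seq_len) x token_dim.
    """
    token_dim = len(soft_prompt[0])
    batch = []
    for sequence in input_embeds:
        if any(len(vec) != token_dim for vec in sequence):
            raise ValueError("prompt and input embeddings must share token_dim")
        # Copy rows so the shared prompt parameters are never aliased in the output.
        batch.append([row[:] for row in soft_prompt] + [vec[:] for vec in sequence])
    return batch


def extend_attention_mask(mask: List[List[int]], num_virtual_tokens: int) -> List[List[int]]:
    """Virtual tokens are always visible, so prepend ones to each mask row."""
    return [[1] * num_virtual_tokens + row for row in mask]


# Toy example: batch of 2 sequences of 3 tokens, token_dim = 4, 2 virtual tokens.
soft_prompt = [[0.1] * 4, [0.2] * 4]
input_embeds = [[[1.0] * 4 for _ in range(3)] for _ in range(2)]
out = prepend_soft_prompt(input_embeds, soft_prompt)
mask = extend_attention_mask([[1, 1, 1], [1, 1, 0]], num_virtual_tokens=2)
print(len(out), len(out[0]), len(out[0][0]))  # 2 5 4
print(mask)  # [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]
```

In an actual training loop, gradients would flow only into `soft_prompt`; the base model's embedding table and transformer weights receive no updates.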

The abstract from the paper, The Power of Scale for Parameter-Efficient Prompt Tuning (Lester et al., 2021), is:

In this work, we explore “prompt tuning”, a simple yet effective mechanism for learning “soft prompts” to condition frozen language models to perform specific downstream tasks. Unlike the discrete text prompts used by GPT-3, soft prompts are learned through backpropagation and can be tuned to incorporate signal from any number of labeled examples. Our end-to-end learned approach outperforms GPT-3’s “few-shot” learning by a large margin. More remarkably, through ablations on model size using T5, we show that prompt tuning becomes more competitive with scale: as models exceed billions of parameters, our method “closes the gap” and matches the strong performance of model tuning (where all model weights are tuned). This finding is especially relevant in that large models are costly to share and serve, and the ability to reuse one frozen model for multiple downstream tasks can ease this burden. Our method can be seen as a simplification of the recently proposed “prefix tuning” of Li and Liang (2021), and we provide a comparison to this and other similar approaches. Finally, we show that conditioning a frozen model with soft prompts confers benefits in robustness to domain transfer, as compared to full model tuning.

PromptTuningConfig

class peft.PromptTuningConfig

(
    task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None,
    peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None,
    auto_mapping: typing.Optional[dict] = None,
    base_model_name_or_path: typing.Optional[str] = None,
    revision: typing.Optional[str] = None,
    inference_mode: bool = False,
    num_virtual_tokens: int = None,
    token_dim: int = None,
    num_transformer_submodules: typing.Optional[int] = None,
    num_attention_heads: typing.Optional[int] = None,
    num_layers: typing.Optional[int] = None,
    prompt_tuning_init: typing.Union[peft.tuners.prompt_tuning.config.PromptTuningInit, str] = <PromptTuningInit.RANDOM: 'RANDOM'>,
    prompt_tuning_init_text: typing.Optional[str] = None,
    tokenizer_name_or_path: typing.Optional[str] = None,
    tokenizer_kwargs: typing.Optional[dict] = None,
)

Parameters

  • prompt_tuning_init (Union[PromptTuningInit, str]) — How the prompt embedding is initialized, e.g. RANDOM (the default) or TEXT.
  • prompt_tuning_init_text (str, optional) — The text to initialize the prompt embedding. Only used if prompt_tuning_init is TEXT.
  • tokenizer_name_or_path (str, optional) — The name or path of the tokenizer. Only used if prompt_tuning_init is TEXT.
  • tokenizer_kwargs (dict, optional) — The keyword arguments to pass to AutoTokenizer.from_pretrained. Only used if prompt_tuning_init is TEXT.

This is the configuration class to store the configuration of a PromptEmbedding.
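When prompt_tuning_init is TEXT, the prompt embedding is seeded from the word embeddings of the tokenized prompt_tuning_init_text. The sketch below assumes the token ids are truncated when the text is too long and repeated cyclically when it is too short, so that exactly total_virtual_tokens ids remain. This reflects our reading of PEFT's behavior rather than its actual code, and `init_token_ids_from_text` and `lookup` are hypothetical helpers operating on plain Python lists.

```python
import math
from typing import List


def init_token_ids_from_text(token_ids: List[int], total_virtual_tokens: int) -> List[int]:
    """Fit tokenized init text to exactly total_virtual_tokens ids (hypothetical helper)."""
    if not token_ids:
        raise ValueError("prompt_tuning_init_text produced no tokens")
    if len(token_ids) >= total_virtual_tokens:
        # Too many tokens: keep only the first total_virtual_tokens.
        return token_ids[:total_virtual_tokens]
    # Too few tokens: repeat the sequence, then trim to the exact length.
    num_reps = math.ceil(total_virtual_tokens / len(token_ids))
    return (token_ids * num_reps)[:total_virtual_tokens]


def lookup(word_embeddings: List[List[float]], ids: List[int]) -> List[List[float]]:
    """Copy rows of the frozen vocabulary embedding table to seed the prompt."""
    return [list(word_embeddings[i]) for i in ids]


# Toy vocabulary of 10 tokens with token_dim = 2; token i embeds to [i, -i].
vocab = [[float(i), float(-i)] for i in range(10)]
ids = init_token_ids_from_text([3, 7, 5], total_virtual_tokens=8)
print(ids)  # [3, 7, 5, 3, 7, 5, 3, 7]
prompt_init = lookup(vocab, ids)
print(prompt_init[1])  # [7.0, -7.0]
```

The seeded rows are copies, so training the prompt afterwards leaves the base model's vocabulary embeddings untouched.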

PromptEmbedding

class peft.PromptEmbedding

(config, word_embeddings)

Parameters

  • config (PromptTuningConfig) — The configuration of the prompt embedding.
  • word_embeddings (torch.nn.Module) — The word embeddings of the base transformer model.

The model to encode virtual tokens into prompt embeddings.

Attributes:

  • embedding (torch.nn.Embedding) — The embedding layer of the prompt embedding.

Example:

>>> import torch
>>> from peft import PromptEmbedding, PromptTuningConfig
>>> from transformers import AutoModelForSeq2SeqLM

>>> t5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

>>> config = PromptTuningConfig(
...     peft_type="PROMPT_TUNING",
...     task_type="SEQ_2_SEQ_LM",
...     num_virtual_tokens=20,
...     token_dim=768,
...     num_transformer_submodules=1,
...     num_attention_heads=12,
...     num_layers=12,
...     prompt_tuning_init="TEXT",
...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
...     tokenizer_name_or_path="t5-base",
... )

>>> # t5_model.shared is the word embedding layer of the base model
>>> prompt_embedding = PromptEmbedding(config, t5_model.shared)

>>> # Embed the 20 virtual-token indices for a batch of one example
>>> prompt_embedding(torch.arange(20).unsqueeze(0)).shape
torch.Size([1, 20, 768])

Input Shape: (batch_size, total_virtual_tokens)

Output Shape: (batch_size, total_virtual_tokens, token_dim)
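The shapes above follow from a plain embedding lookup. The pure-Python stand-in below, `ToyPromptEmbedding`, is hypothetical and not PEFT's class. It assumes, per our understanding of PEFT, that total_virtual_tokens equals num_virtual_tokens multiplied by num_transformer_submodules (one block of virtual tokens per submodule, e.g. encoder and decoder).

```python
from typing import List


class ToyPromptEmbedding:
    """Pure-Python stand-in for the prompt embedding's forward pass (hypothetical).

    Holds a (total_virtual_tokens x token_dim) table and maps indices of shape
    (batch_size, total_virtual_tokens) to (batch_size, total_virtual_tokens, token_dim).
    """

    def __init__(self, num_virtual_tokens: int, num_transformer_submodules: int, token_dim: int):
        # Assumption: one block of virtual tokens per transformer submodule.
        self.total_virtual_tokens = num_virtual_tokens * num_transformer_submodules
        self.token_dim = token_dim
        self.weight = [[0.0] * token_dim for _ in range(self.total_virtual_tokens)]

    def forward(self, indices: List[List[int]]) -> List[List[List[float]]]:
        # Plain embedding lookup: each index selects one row of the table.
        return [[list(self.weight[i]) for i in row] for row in indices]


# Same sizes as the example config: 20 virtual tokens, 1 submodule, token_dim 768.
emb = ToyPromptEmbedding(num_virtual_tokens=20, num_transformer_submodules=1, token_dim=768)
batch_size = 4
# Every example in the batch uses the same virtual-token indices 0..total_virtual_tokens-1.
indices = [list(range(emb.total_virtual_tokens)) for _ in range(batch_size)]
out = emb.forward(indices)
print(len(out), len(out[0]), len(out[0][0]))  # 4 20 768
```

With num_transformer_submodules=2, the same 20 virtual tokens would yield a table of 40 rows, and the input indices would span 0 to 39.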
