Context-aware Prompt Tuning: Advancing In-Context Learning with Adversarial Methods

CPT combines In-Context Learning (ICL), Prompt Tuning (PT), and adversarial optimization to improve few-shot learning by refining context embeddings. CPT updates the context tokens by optimizing both the context and the training examples together, using a loss design that reduces overfitting, enables more effective optimization, and drives significant improvements in classification tasks.

The abstract from the paper is:

Large Language Models (LLMs) can perform few-shot learning using either optimization-based approaches or In-Context Learning (ICL). Optimization-based methods often suffer from overfitting, as they require updating a large number of parameters with limited data. In contrast, ICL avoids overfitting but typically underperforms compared to optimization-based methods and is highly sensitive to the selection, order, and format of demonstration examples. To overcome these challenges, we introduce Context-aware Prompt Tuning (CPT), a method inspired by ICL, Prompt Tuning (PT), and adversarial attacks. CPT builds on the ICL strategy of concatenating examples before the input, extending it by incorporating PT-like learning to refine the context embedding through iterative optimization, extracting deeper insights from the training examples. Our approach carefully modifies specific context tokens, considering the unique structure of the examples within the context. In addition to updating the context with PT-like optimization, CPT draws inspiration from adversarial attacks, adjusting the input based on the labels present in the context while preserving the inherent value of the user-provided data. To ensure robustness and stability during optimization, we employ a projected gradient descent algorithm, constraining token embeddings to remain close to their original values and safeguarding the quality of the context. Our method has demonstrated superior accuracy across multiple classification tasks using various LLM models, outperforming existing baselines and effectively addressing the overfitting challenge in few-shot learning.

Take a look at the Example for a step-by-step guide on how to train a model with CPT.
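
The following is a minimal sketch of wrapping a causal language model with CPT. The model name, the context string, and the mask values are placeholders rather than recommended settings; in practice the context token IDs and the two masks are built from your tokenized few-shot template, as shown in the Example.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import CPTConfig, get_peft_model

# Placeholder base model and few-shot context; substitute your own.
model_name = "bigscience/bloom-1b7"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)

# Tokenize the few-shot context that CPT will refine during training.
context = (
    "Review: great movie! Sentiment: positive\n"
    "Review: boring plot. Sentiment: negative\n"
)
context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]

config = CPTConfig(
    task_type="CAUSAL_LM",
    cpt_token_ids=context_ids,
    # Placeholder masks: both must align with cpt_token_ids. Their values encode
    # which role each context token plays (template text vs. example input vs.
    # example label); see the Example for how they are constructed.
    cpt_mask=[1] * len(context_ids),
    cpt_tokens_type_mask=[1] * len(context_ids),
    opt_weighted_loss_type="decay",
    opt_loss_decay_factor=0.95,
    opt_projection_epsilon=0.2,
    opt_projection_format_epsilon=0.1,
    tokenizer_name_or_path=model_name,
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```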

CPTConfig

class peft.CPTConfig

(
  task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None
  peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None
  auto_mapping: typing.Optional[dict] = None
  base_model_name_or_path: typing.Optional[str] = None
  revision: typing.Optional[str] = None
  inference_mode: bool = False
  num_virtual_tokens: int = None
  token_dim: int = None
  num_transformer_submodules: typing.Optional[int] = None
  num_attention_heads: typing.Optional[int] = None
  num_layers: typing.Optional[int] = None
  cpt_token_ids: typing.Optional[list[int]] = None
  cpt_mask: typing.Optional[list[int]] = None
  cpt_tokens_type_mask: typing.Optional[list[int]] = None
  opt_weighted_loss_type: typing.Optional[typing.Literal['none', 'decay']] = 'none'
  opt_loss_decay_factor: typing.Optional[float] = 1.0
  opt_projection_epsilon: typing.Optional[float] = 0.1
  opt_projection_format_epsilon: typing.Optional[float] = 0.1
  tokenizer_name_or_path: typing.Optional[str] = None
)

CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT).

This class introduces additional parameters required for CPT, such as:

  • Token type masks
  • Prompt tuning initialization
  • Loss weighting
  • Projection settings

For more details, see the paper: https://arxiv.org/abs/2410.17222

CPTEmbedding

class peft.CPTEmbedding

( config word_embeddings )

CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. It initializes embeddings, applies prompt-specific projections, and computes loss using label masks.
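
You normally do not construct CPTEmbedding yourself; it is created when a model is wrapped with get_peft_model and a CPTConfig. Continuing from the sketch above, it can be inspected roughly as follows (the prompt_encoder["default"] attribute path is an assumption about PEFT's prompt-learning layout, not a documented CPT API):

```python
# `model` is the PeftModel returned by get_peft_model(base_model, config) above.
cpt_embedding = model.prompt_encoder["default"]  # assumed attribute path
print(type(cpt_embedding).__name__)              # expected to print "CPTEmbedding"
```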

calculate_loss

( base_model_output labels cpt_type_mask config ) → ModelOutput

Parameters

  • base_model_output (ModelOutput) — Output from the base model containing logits.
  • labels (torch.Tensor) — Ground-truth labels for the input tokens.
  • cpt_type_mask (torch.Tensor) — Token type mask used for filtering valid loss terms.
  • config (CPTConfig) — CPT configuration containing the loss-related hyperparameters (e.g. opt_weighted_loss_type and opt_loss_decay_factor).

Returns

ModelOutput

The base model output with computed loss.

Computes the loss for CPT models with optional exponential decay.
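
The sketch below illustrates the idea of a decay-weighted loss in isolation. It is not the PEFT implementation: the grouping of label tokens by training example and the direction of the decay are assumptions made for illustration.

```python
import torch

def decay_weighted_loss(per_token_loss: torch.Tensor,
                        example_index: torch.Tensor,
                        decay_factor: float = 0.95) -> torch.Tensor:
    """Weight label-token losses per training example with an exponential decay.

    per_token_loss: (seq_len,) cross-entropy values already masked to label positions.
    example_index:  (seq_len,) index of the training example each token belongs to.
    """
    num_examples = int(example_index.max().item()) + 1
    # Assumed direction: the last example keeps weight 1.0, earlier examples are
    # down-weighted by one factor of decay_factor per step back in the context.
    weights = decay_factor ** (num_examples - 1 - example_index).float()
    return (per_token_loss * weights).sum() / weights.sum()

loss = decay_weighted_loss(torch.tensor([0.7, 0.2, 0.9]), torch.tensor([0, 0, 1]))
```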

forward

( indices ) → torch.Tensor

Parameters

  • indices (torch.Tensor) — Indices of the tokens to be embedded.

Returns

torch.Tensor

Sum of prompt embeddings and delta embeddings.

Computes the prompt embeddings and applies delta adjustments.
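
Conceptually, the returned tensor is the frozen embedding of each context token plus a trainable delta. The sketch below only illustrates that sum, with made-up shapes; it is not the actual forward implementation.

```python
import torch

batch_size, num_virtual_tokens, token_dim = 2, 8, 16

# Stand-ins for the frozen embeddings of cpt_token_ids and the trainable delta.
frozen_prompt = torch.randn(num_virtual_tokens, token_dim)
delta = torch.zeros(num_virtual_tokens, token_dim, requires_grad=True)

indices = torch.arange(num_virtual_tokens).repeat(batch_size, 1)  # (batch, num_virtual_tokens)
prompt_embeddings = frozen_prompt[indices] + delta[indices]       # the sum forward() returns
print(prompt_embeddings.shape)                                    # torch.Size([2, 8, 16])
```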

get_projection

( )

Applies epsilon-based projection to the delta embeddings to control their norm.
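
A minimal sketch of an epsilon-based projection, assuming a per-token L2 constraint; the exact criterion, and the separate epsilon applied to format tokens via opt_projection_format_epsilon, may differ in the actual implementation.

```python
import torch

def project_to_epsilon_ball(delta: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Scale each token's delta back inside an L2 ball of radius epsilon."""
    norms = delta.norm(dim=-1, keepdim=True)                 # per-token L2 norm
    scale = torch.clamp(epsilon / (norms + 1e-12), max=1.0)  # 1.0 if already inside the ball
    return delta * scale

delta = torch.randn(8, 16)
projected = project_to_epsilon_ball(delta, epsilon=0.1)
print(projected.norm(dim=-1).max())  # <= 0.1 up to numerical precision
```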

set_updated_tokens

( )

Sets up a backward hook to selectively update token gradients based on the CPT token type mask.
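
A minimal sketch of the mechanism, assuming the hook zeroes gradients for context tokens that must stay fixed; in the real implementation the affected tokens are determined by cpt_tokens_type_mask, whereas the mask values below are placeholders.

```python
import torch

delta = torch.zeros(8, 16, requires_grad=True)
# 1 = token whose embedding may be updated, 0 = token kept frozen (placeholder values).
update_mask = torch.tensor([1, 1, 0, 0, 1, 1, 0, 0], dtype=torch.float32).unsqueeze(-1)

def mask_gradients(grad: torch.Tensor) -> torch.Tensor:
    return grad * update_mask  # gradients of frozen tokens become exactly zero

delta.register_hook(mask_gradients)

loss = (delta + 1.0).pow(2).sum()
loss.backward()
print(delta.grad[2].abs().sum())  # tensor(0.) because token 2 is frozen
```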
