Context-Aware Prompt Tuning (CPT)
Context-aware Prompt Tuning (CPT), proposed in Context-aware Prompt Tuning: Advancing In-Context Learning with Adversarial Methods, combines In-Context Learning (ICL) with Prompt Tuning (PT) and adversarial optimization to improve few-shot learning by refining context embeddings. CPT optimizes only the context-token embeddings, which reduces overfitting and improves performance on classification tasks.
The abstract from the paper is:
Traditional fine-tuning is effective but computationally intensive, as it requires updating billions of parameters. CPT, inspired by ICL, PT, and adversarial attacks, refines context embeddings in a parameter-efficient manner. By optimizing context tokens and applying a controlled gradient descent, CPT achieves superior accuracy across various few-shot classification tasks, showing significant improvement over existing methods such as LoRA, PT, and ICL.
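To use CPT, create a `CPTConfig` and wrap a causal language model with `get_peft_model`. The snippet below is a minimal sketch: the model name is arbitrary, and the token ids and masks are dummy placeholders standing in for a tokenized few-shot context.

```python
from transformers import AutoModelForCausalLM
from peft import CPTConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")  # any causal LM

# Placeholder values; in practice these come from tokenizing your few-shot
# context and marking the role (input/output/separator) of each token.
config = CPTConfig(
    cpt_token_ids=[0, 1, 2, 3, 4, 5, 6, 7],
    cpt_mask=[1, 1, 1, 1, 1, 1, 1, 1],
    cpt_tokens_type_mask=[1, 1, 2, 2, 3, 3, 4, 4],
)

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```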
CPTConfig
class peft.CPTConfig
< source >( peft_type: Union[str, PeftType] = None, auto_mapping: Optional[dict] = None, base_model_name_or_path: Optional[str] = None, revision: Optional[str] = None, task_type: Union[str, TaskType] = None, inference_mode: bool = False, num_virtual_tokens: int = None, token_dim: int = None, num_transformer_submodules: Optional[int] = None, num_attention_heads: Optional[int] = None, num_layers: Optional[int] = None, cpt_token_ids: Optional[list[int]] = None, cpt_mask: Optional[list[int]] = None, cpt_tokens_type_mask: Optional[list[int]] = None, opt_weighted_loss_type: Optional[str] = 'none', opt_loss_decay_factor: Optional[float] = 1.0, opt_projection_epsilon: Optional[float] = 0.1, opt_projection_format_epsilon: Optional[float] = 0.1, tokenizer_name_or_path: Optional[str] = None )
CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT).
This class introduces additional parameters required for CPT, such as:
- Token type masks
- Prompt tuning initialization
- Loss weighting
- Projection settings
For more details, see the paper: https://arxiv.org/abs/2410.17222
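For instance, the loss-weighting and projection parameters can be configured as follows (the values shown are illustrative, not recommended defaults):

```python
from peft import CPTConfig

config = CPTConfig(
    cpt_token_ids=[0, 1, 2, 3],         # tokenized few-shot context (placeholder ids)
    cpt_mask=[1, 1, 1, 1],              # attention mask over the context tokens
    cpt_tokens_type_mask=[1, 2, 2, 4],  # token type mask marking each token's role
    opt_weighted_loss_type="decay",     # weight examples by their position in the context
    opt_loss_decay_factor=0.95,         # per-example decay applied to the loss
    opt_projection_epsilon=0.2,         # norm bound for the delta projection
    opt_projection_format_epsilon=0.1,  # tighter bound for format/separator tokens
)
```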
CPTEmbedding
CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. It initializes embeddings, applies prompt-specific projections, and computes loss using label masks.
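Conceptually, the layer keeps a frozen copy of the context-token embeddings and learns a small additive delta on top of them. A toy sketch of that idea (not the library's internal implementation; names and shapes are assumptions):

```python
import torch

class ToyCPTEmbedding(torch.nn.Module):
    # Illustrative only: frozen context embeddings plus a trainable delta.
    def __init__(self, word_embeddings: torch.nn.Embedding, cpt_token_ids: list):
        super().__init__()
        ids = torch.tensor(cpt_token_ids)
        self.register_buffer("base", word_embeddings(ids).detach())   # frozen
        self.delta = torch.nn.Parameter(torch.zeros_like(self.base))  # learned

    def forward(self) -> torch.Tensor:
        # The effective prompt embedding is the base plus the adjustment.
        return self.base + self.delta
```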
calculate_loss
< source >( base_model_output, labels, cpt_type_mask, config ) → ModelOutput
Parameters
- base_model_output (ModelOutput) — Output from the base model containing logits.
- labels (torch.Tensor) — Ground-truth labels for the input tokens.
- cpt_type_mask (torch.Tensor) — Token type mask used for filtering valid loss terms.
- config (Namespace) — Configuration object containing loss-related hyperparameters.
Returns
ModelOutput
The base model output with computed loss.
Computes the loss for CPT models with optional exponential decay.
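As a sketch of the optional exponential decay (function and variable names are hypothetical, not the library's internals), each token's cross-entropy can be scaled by the decay factor raised to its example's distance from the end of the context, so that more recent examples carry more weight:

```python
import torch

def decay_weighted_loss(per_token_ce, example_idx, num_examples, decay_factor=0.95):
    # per_token_ce: cross-entropy per token, shape (seq_len,)
    # example_idx:  index of the few-shot example each token belongs to, shape (seq_len,)
    # The last example gets weight 1.0; earlier examples decay geometrically.
    weights = decay_factor ** (num_examples - 1 - example_idx).float()
    return (per_token_ce * weights).sum() / weights.sum()
```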
forward
< source >( indices ) → torch.Tensor
Computes the prompt embeddings and applies delta adjustments.
get_projection
< source >( )
Applies epsilon-based projection to the delta embeddings to control their norm.
set_updated_tokens
< source >( )
Sets up a backward hook to selectively update token gradients based on the CPT token type mask.
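A rough sketch of these two mechanisms (tensor names, shapes, and the mask convention are assumptions, not the library's internals):

```python
import torch

def project_deltas(delta: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Scale each delta row back onto the epsilon-ball when its L2 norm exceeds epsilon.
    norms = delta.norm(dim=1, keepdim=True).clamp_min(1e-12)
    return delta * torch.clamp(epsilon / norms, max=1.0)

def register_grad_mask(delta: torch.nn.Parameter, type_mask: torch.Tensor) -> None:
    # Zero the gradient for tokens whose type marks them as frozen, so only the
    # selected context tokens are updated during training.
    keep = (type_mask > 0).float().unsqueeze(1)  # 1 = trainable, 0 = frozen
    delta.register_hook(lambda grad: grad * keep)
```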