import torch
from transformers import PreTrainedModel, T5ForConditionalGeneration

from typing import Optional, Callable, Union
import os

from .configuration_sip_finetune import SIPFinetuningModelConfig


class SIPFinetuningModel(PreTrainedModel):
    config_class = SIPFinetuningModelConfig

    def __init__(self, config: SIPFinetuningModelConfig):
        super().__init__(config)

        self.model = T5ForConditionalGeneration(config)

        # Buffer holding the precomputed example prefixes from SIP-pretraining; the actual values are loaded from the checkpoint.
        self.register_buffer("prefix_init_tensor", torch.zeros(config.num_precomputed_examples, config.prefix_max_init_length, config.d_model))

        # There are two cases: (1) the model is loaded right after SIP-pretraining, i.e. no tunable prefix has been
        # trained yet, and (2) the model has already been fine-tuned on downstream data, so the checkpoint contains a
        # meaningful tunable prefix.

        # Initialize the prefix with NaNs as a sentinel. When loading a SIP-pretrained checkpoint, the NaNs survive
        # loading and the custom from_pretrained below re-initializes the prefix from the precomputed examples;
        # when loading a fine-tuned checkpoint, the NaNs are simply overwritten by the stored prefix.

        self.prefix_embedding = torch.nn.Parameter(torch.nan + torch.zeros((1, self.config.prefix_length, self.config.d_model)))
        self.prefix_has_been_initialized = False

    def _initialize_prefix(self):
        prefix_init_tensor = self.prefix_init_tensor
        if self.config.random_selection:
            # Randomly select which of the precomputed FST prefixes are averaged to initialize the tunable prefix.
            prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]

        # shape (num_examples, prefix_length, d_model)
        prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length, :]
        self.prefix_embedding.data.copy_(prefix_init_tensor.mean(dim=0, keepdim=True))
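
    # Illustration (assumed numbers, not defaults from the config): with num_examples=4 and
    # prefix_length=10, the slice above has shape (4, 10, d_model) and the mean over dim 0
    # yields the (1, 10, d_model) tensor that is copied into the tunable prefix.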

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs,
    ):
        model = super(SIPFinetuningModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        # An all-NaN prefix means the checkpoint comes straight from SIP-pretraining, so the tunable
        # prefix still has to be built from the precomputed example prefixes.
        if torch.all(model.prefix_embedding.isnan()):
            model._initialize_prefix()
        return model
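
    # Usage sketch (illustrative; the checkpoint id below is a placeholder, not a real repository name):
    #
    #   model = SIPFinetuningModel.from_pretrained("path/or/hub-id/of/sip-checkpoint")
    #
    # For a SIP-pretrained checkpoint the loaded prefix_embedding is all-NaN, so the override above
    # calls _initialize_prefix() and fills the prefix with the mean of the precomputed example
    # prefixes; a fine-tuned checkpoint simply loads its stored prefix.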


    def prepare_input(self, kwargs):
        """
        Prepends the tunable prefix to the given input.
        :param kwargs: keyword arguments as passed to forward/generate; must contain "input_ids" and may contain "attention_mask".
        :return: a copy of kwargs in which "input_ids" is replaced by "inputs_embeds" with the prefix prepended,
                 and the attention mask (if present) is extended accordingly.
        """
        input_ids = kwargs["input_ids"]

        embedded_inputs = self.model.get_input_embeddings()(input_ids)

        batch_size = input_ids.shape[0]

        prefix = torch.repeat_interleave(self.prefix_embedding, batch_size, 0)  # shape (batch, prefix_length, embed dim)

        kwargs = dict(kwargs)

        embedded_inputs = torch.cat([prefix, embedded_inputs], dim=1)  # shape (batch, prefix + seq length, embed dim)

        del kwargs["input_ids"]
        kwargs["inputs_embeds"] = embedded_inputs

        if "attention_mask" in kwargs:
            ones = torch.ones((batch_size, self.config.prefix_length), device=embedded_inputs.device, dtype=kwargs["attention_mask"].dtype)
            input_mask = torch.cat([ones, kwargs["attention_mask"]], dim=1)
            kwargs["attention_mask"] = input_mask

        return kwargs

    def forward(self, **kwargs):
        return self.model(**self.prepare_input(kwargs))

    def generate(self, **kwargs):
        return self.model.generate(**self.prepare_input(kwargs))
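
    # Usage sketch (illustrative): generate() accepts the usual tokenizer output, and prepare_input
    # swaps "input_ids" for "inputs_embeds" with the prefix prepended. The tokenizer name "t5-base" is
    # an assumption based on the underlying T5 architecture, not something this module prescribes.
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained("t5-base")
    #   batch = tokenizer(["some input string"], return_tensors="pt")
    #   outputs = model.generate(**batch, max_new_tokens=32)
    #   print(tokenizer.batch_decode(outputs, skip_special_tokens=True))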


    def get_optimizer(self, optimizer: Callable[..., torch.optim.Optimizer], prefix_lr: float = 1.0, **kwargs) -> torch.optim.Optimizer:
        """
        Returns an optimizer that uses a different (typically higher) learning rate for the prefix than for the rest of the model.
        :param optimizer: an optimizer constructor, e.g. a class from torch.optim.
        :param prefix_lr: learning rate for the prefix parameters.
        :param kwargs: passed on to the optimizer constructor (e.g. the learning rate "lr" for the remaining parameters).
        """

        prefix_params = []
        other_params = []
        for name, param in self.named_parameters():
            if name == "prefix_embedding":
                prefix_params.append(param)
            else:
                other_params.append(param)
        return optimizer(params=[{"params": prefix_params, "lr": prefix_lr}, {"params": other_params}], **kwargs)
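

# Usage sketch (illustrative): pair a higher learning rate for the prefix with a smaller one for the
# T5 weights. The optimizer class and the learning rates below are assumptions, not values prescribed
# by this module.
#
#   model = SIPFinetuningModel.from_pretrained("path/or/hub-id/of/sip-checkpoint")
#   optimizer = model.get_optimizer(torch.optim.AdamW, prefix_lr=1e-3, lr=1e-4)
#   # "prefix_embedding" is updated with lr=1e-3 (prefix_lr); all other parameters with lr=1e-4.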