import os
from typing import Optional, List, Callable, Mapping, Any, Union

import torch
from transformers import AutoTokenizer, PretrainedConfig, T5Config, PreTrainedModel, T5ForConditionalGeneration, \
    AutoModelForSeq2SeqLM

from .configuration_sip_finetune import SIPFinetuningModelConfig


class SIPFinetuningModel(PreTrainedModel):
    """T5 encoder-decoder model with a tunable soft prefix prepended to the encoder input.

    The prefix ``prefix_embedding`` has shape ``(1, config.prefix_length, config.d_model)``.
    It can be initialized from the buffer ``prefix_init_tensor`` of shape
    ``(config.num_precomputed_examples, config.prefix_max_init_length, config.d_model)``
    by averaging ``config.num_examples`` of its entries (see ``_initialize_prefix``).
    """

    config_class = SIPFinetuningModelConfig

    def __init__(self, config: SIPFinetuningModelConfig):
        super().__init__(config)
        self.model = T5ForConditionalGeneration(config)

        # Precomputed per-example prefix embeddings used to initialize the tunable prefix.
        # Zero-initialized here; the real values are expected to come from a loaded checkpoint.
        self.register_buffer(
            "prefix_init_tensor",
            torch.zeros(config.num_precomputed_examples, config.prefix_max_init_length, config.d_model),
        )

        # There are two cases:
        #   (1) the model is created from a SIP-pretraining checkpoint, so the tunable prefix is not set yet;
        #   (2) the model has been fine-tuned on downstream data, so the checkpoint holds a meaningful prefix.
        # The prefix starts out as NaNs. In case (1) the checkpoint does not overwrite them, and
        # from_pretrained detects this and calls _initialize_prefix. In case (2) the checkpoint
        # overwrites the NaNs.
        self.prefix_embedding = torch.nn.Parameter(
            torch.full((1, config.prefix_length, config.d_model), float("nan"))
        )
        # Becomes True once the prefix holds real values (computed here or loaded from a checkpoint).
        self.prefix_has_been_initialized = False

    def _initialize_prefix(self):
        """Set ``prefix_embedding`` to the mean of precomputed prefix embeddings.

        Selects the first ``config.num_examples`` entries of ``prefix_init_tensor``.
        If ``config.random_selection`` is set, the entries are shuffled first.
        Each selected entry is truncated to ``config.prefix_length`` positions, and the
        entries are averaged over the example dimension.

        :raises ValueError: if no examples are selected, or if the precomputed embeddings
            are shorter than ``config.prefix_length``.
        """
        prefix_init_tensor = self.prefix_init_tensor
        if self.config.random_selection:
            # Randomize which precomputed examples are averaged to initialize the prefix.
            prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]

        # shape (num ex, prefix length, d model)
        selected = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length, :]
        if selected.shape[0] == 0:
            # The mean over zero examples would be NaN and leave the prefix unusable.
            raise ValueError(
                f"Cannot initialize the prefix: no precomputed examples selected "
                f"(num_examples={self.config.num_examples}, available={self.prefix_init_tensor.shape[0]})."
            )
        if selected.shape[1] < self.config.prefix_length:
            raise ValueError(
                f"Cannot initialize a prefix of length {self.config.prefix_length} from precomputed "
                f"embeddings of length {selected.shape[1]}."
            )

        # Update the parameter in place without recording the copy in the autograd graph.
        with torch.no_grad():
            self.prefix_embedding.copy_(selected.mean(dim=0, keepdim=True))
        self.prefix_has_been_initialized = True

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
            *model_args,
            **kwargs,
    ):
        """Load the model via ``PreTrainedModel.from_pretrained`` and initialize the prefix if needed.

        If the loaded ``prefix_embedding`` is still entirely NaN, the checkpoint did not contain a
        prefix (e.g. it came from SIP-pretraining). In that case the prefix is computed with
        ``_initialize_prefix``.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        if model.prefix_embedding.isnan().all():
            model._initialize_prefix()
        else:
            # The checkpoint already provided a (fine-tuned) prefix.
            model.prefix_has_been_initialized = True
        return model

    def prepare_input(self, kwargs):
        """Prepend the tunable prefix to the encoder input.

        :param kwargs: keyword arguments destined for the wrapped T5 model. Must contain a
            non-None ``input_ids`` or ``inputs_embeds``. The mapping is not modified in place.
        :return: a new dict without ``input_ids``. Its ``inputs_embeds`` is the embedded input
            with the prefix prepended along dim 1. If an ``attention_mask`` was given (and is not
            None), it is extended with ones for the prefix positions.
        :raises ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is provided.
        """
        kwargs = dict(kwargs)
        input_ids = kwargs.pop("input_ids", None)
        if input_ids is not None:
            embedded_inputs = self.model.get_input_embeddings()(input_ids)
        else:
            # Allow callers that already embedded their input.
            embedded_inputs = kwargs.get("inputs_embeds")
            if embedded_inputs is None:
                raise ValueError("prepare_input requires either `input_ids` or `inputs_embeds`.")

        batch_size = embedded_inputs.shape[0]
        prefix = self.prefix_embedding.expand(batch_size, -1, -1)  # shape (batch, prefix length, embed dim)
        # shape (batch, prefix + seq length, embed dim)
        kwargs["inputs_embeds"] = torch.cat([prefix, embedded_inputs], dim=1)

        # HF-style callers often pass attention_mask=None explicitly; only extend a real mask.
        attention_mask = kwargs.get("attention_mask")
        if attention_mask is not None:
            ones = torch.ones((batch_size, self.config.prefix_length),
                              device=embedded_inputs.device, dtype=attention_mask.dtype)
            kwargs["attention_mask"] = torch.cat([ones, attention_mask], dim=1)
        return kwargs

    def forward(self, **kwargs):
        """Run the wrapped T5 model on the prefixed input (see ``prepare_input``)."""
        return self.model(**self.prepare_input(kwargs))

    def generate(self, **kwargs):
        """Call ``generate`` on the wrapped T5 model with the prefixed input (see ``prepare_input``)."""
        return self.model.generate(**self.prepare_input(kwargs))

    def get_optimizer(self, optimizer: Callable[..., torch.optim.Optimizer], prefix_lr: float = 1.0,
                      **kwargs) -> torch.optim.Optimizer:
        """
        Return an optimizer that uses a different learning rate (typically higher) for the prefix
        than for the rest of the model.

        :param optimizer: optimizer class or factory that accepts ``params=`` parameter groups.
        :param prefix_lr: learning rate for the ``prefix_embedding`` parameter group.
        :param kwargs: forwarded to ``optimizer``. These apply to the other parameters, e.g. their ``lr``.
        :return: the constructed optimizer with two parameter groups (prefix, everything else).
        """
        prefix_params = []
        other_params = []
        for name, param in self.named_parameters():
            if name == "prefix_embedding":
                prefix_params.append(param)
            else:
                other_params.append(param)
        return optimizer(params=[{"params": prefix_params, "lr": prefix_lr}, {"params": other_params}], **kwargs)