import os
from typing import Any, Callable, List, Mapping, Optional, Union

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    T5Config,
    T5ForConditionalGeneration,
)

from .configuration_sip_finetune import SIPFinetuningModelConfig


class SIPFinetuningModel(PreTrainedModel):
    config_class = SIPFinetuningModelConfig

    def __init__(self, config: SIPFinetuningModelConfig):
        super().__init__(config)

        self.model = T5ForConditionalGeneration(config)

        # Buffer holding the precomputed prefix representations shipped with the checkpoint;
        # its contents are loaded from the state dict by from_pretrained.
        self.register_buffer(
            "prefix_init_tensor",
            torch.zeros(config.num_precomputed_examples, config.prefix_max_init_length, config.d_model),
        )

        # The trainable prefix. It is filled with NaNs as a sentinel so that from_pretrained
        # can detect whether the checkpoint already contains an initialized prefix.
        self.prefix_embedding = torch.nn.Parameter(
            torch.nan + torch.zeros((1, self.config.prefix_length, self.config.d_model))
        )
        self.prefix_has_been_initialized = False

    def _initialize_prefix(self):
        """Initialize the prefix embedding from the precomputed prefix representations."""
        prefix_init_tensor = self.prefix_init_tensor
        if self.config.random_selection:
            # Shuffle the precomputed examples so that a random subset is used.
            prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]

        # Restrict to the chosen number of examples and the configured prefix length,
        # then average over the examples to obtain a single prefix.
        prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length, :]
        self.prefix_embedding.data.copy_(prefix_init_tensor.mean(dim=0, keepdim=True))

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs,
    ):
        model = super(SIPFinetuningModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        # If the loaded prefix is still the NaN sentinel, the checkpoint does not contain a
        # fine-tuned prefix yet, so derive one from the precomputed prefix representations.
        if torch.all(model.prefix_embedding.isnan()):
            model._initialize_prefix()
        return model

    def prepare_input(self, kwargs):
        """
        Prepends the prefix to the given input.

        :param kwargs: keyword arguments as they would be passed to the underlying T5 model,
            containing at least "input_ids" and optionally "attention_mask".
        :return: a copy of kwargs in which "input_ids" is replaced by "inputs_embeds" with the
            prefix prepended, and "attention_mask" (if present) is extended accordingly.
        """
        input_ids = kwargs["input_ids"]

        # Embed the tokens so the prefix can be prepended in embedding space.
        embedded_inputs = self.model.get_input_embeddings()(input_ids)

        batch_size = input_ids.shape[0]

        # Repeat the (1, prefix_length, d_model) prefix for every element of the batch.
        prefix = torch.repeat_interleave(self.prefix_embedding, batch_size, 0)

        kwargs = dict(kwargs)

        embedded_inputs = torch.cat([prefix, embedded_inputs], dim=1)

        del kwargs["input_ids"]
        kwargs["inputs_embeds"] = embedded_inputs

        if "attention_mask" in kwargs:
            # The prefix positions are always attended to, so extend the mask with ones.
            ones = torch.ones((batch_size, self.config.prefix_length), device=embedded_inputs.device, dtype=kwargs["attention_mask"].dtype)
            input_mask = torch.cat([ones, kwargs["attention_mask"]], dim=1)
            kwargs["attention_mask"] = input_mask

        return kwargs

    def forward(self, **kwargs):
        return self.model(**self.prepare_input(kwargs))

    def generate(self, **kwargs):
        return self.model.generate(**self.prepare_input(kwargs))

    def get_optimizer(self, optimizer: Callable[..., torch.optim.Optimizer], prefix_lr: float = 1.0, **kwargs) -> torch.optim.Optimizer:
        """
        Return an optimizer that uses a different learning rate (typically higher) for the prefix
        than for the rest of the model.

        :param optimizer: an optimizer constructor such as torch.optim.AdamW.
        :param prefix_lr: learning rate for the prefix embedding; the remaining parameters use the
            learning rate passed via **kwargs (or the optimizer's default).
        """
        prefix_params = []
        other_params = []
        for name, param in self.named_parameters():
            if name == "prefix_embedding":
                prefix_params.append(param)
            else:
                other_params.append(param)
        return optimizer(params=[{"params": prefix_params, "lr": prefix_lr}, {"params": other_params}], **kwargs)
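

# ----------------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition): shows how the class above can be used
# for fine-tuning. The checkpoint path "path/to/sip-checkpoint" and all hyperparameters are
# placeholders, not values prescribed by this file.
if __name__ == "__main__":
    checkpoint = "path/to/sip-checkpoint"  # placeholder: any checkpoint with a SIPFinetuningModelConfig
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = SIPFinetuningModel.from_pretrained(checkpoint)  # initializes the prefix if needed

    # Separate learning rate for the prefix and the rest of the model (both values are placeholders).
    optim = model.get_optimizer(torch.optim.AdamW, prefix_lr=1.0, lr=1e-4)

    batch = tokenizer(["example input"], return_tensors="pt")
    labels = tokenizer(["example output"], return_tensors="pt").input_ids
    loss = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
    loss.backward()
    optim.step()
    optim.zero_grad()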