Spaces:
Sleeping
Sleeping
import torch | |
# from transformers.models.bart.modeling_bart import BartForConditionalGeneration
# from transformers.models.bert.modeling_bert import BertForSequenceClassification
# model = BartForConditionalGeneration(None)
class PrefixEncoder(torch.nn.Module):
    r"""
    The torch.nn model to encode the prefix.

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    Attributes read from ``config``:
        prefix_projection: truthy to re-parameterize the prefix embedding
            through a two-layer MLP; falsy to look the full-width vectors up
            directly from a single embedding table.
        pre_seq_len: number of rows in the prefix embedding table.
        hidden_size: per-layer hidden width.
        num_hidden_layers: number of layers the output is sized for.
        prefix_hidden_size: width of the MLP bottleneck. Only read when
            ``prefix_projection`` is truthy.
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        # Width of the last output dimension: 2 * layers * hidden.
        # NOTE(review): the factor 2 presumably covers one key and one value
        # vector per layer (cf. the name ``past_key_values``) -- confirm
        # against the consumer of this module.
        out_features = config.num_hidden_layers * 2 * config.hidden_size
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, out_features),
            )
        else:
            # No projection: the embedding table itself holds full-width rows.
            self.embedding = torch.nn.Embedding(config.pre_seq_len, out_features)

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        """Encode prefix indices into the flattened per-layer prefix vectors.

        Args:
            prefix: integer index tensor into the prefix embedding table,
                shaped (batch-size, prefix-length) per the class contract.

        Returns:
            Tensor shaped like ``prefix`` with one extra trailing dimension
            of size ``2 * num_hidden_layers * hidden_size``.
        """
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)  # [..., prefix-length, hidden_size]
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values