IE101TW / models /basic_modules /prefix_encoder.py
DeepLearning101's picture
Upload 6 files
437e42f
raw
history blame
1.46 kB
import torch
# from transformers.models.bart.modeling_bart import BartForConditionalGeneration
# from transformers.models.bert.modeling_bert import BertForSequenceClassification
# model = BartForConditionalGeneration(None)
class PrefixEncoder(torch.nn.Module):
    r"""Map prefix token indices to flattened per-layer prefix vectors.

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    Two variants are selected by ``config.prefix_projection``:

    * truthy -- an embedding of width ``hidden_size`` is fed through a
      two-layer MLP (Linear -> Tanh -> Linear). The MLP widens it to
      ``num_hidden_layers * 2 * hidden_size``.
    * falsy  -- an embedding table that already has width
      ``num_hidden_layers * 2 * hidden_size`` is looked up directly.

    Config attributes read: ``prefix_projection``, ``pre_seq_len``,
    ``hidden_size``, ``num_hidden_layers``, and ``prefix_hidden_size``
    (the last one only when projecting).
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection

        # Width of every encoded prefix position. The factor 2 presumably
        # leaves room for one key and one value vector per layer.
        # NOTE(review): confirm against the caller that reshapes this output.
        kv_width = config.num_hidden_layers * 2 * config.hidden_size
        table_width = config.hidden_size if self.prefix_projection else kv_width

        # Build the embedding before the MLP (if any). Reordering the two
        # would change seeded parameter initialisation. The Sequential is
        # positional, so its parameters appear as trans.0.* / trans.2.*
        # in the state_dict.
        self.embedding = torch.nn.Embedding(config.pre_seq_len, table_width)
        if self.prefix_projection:
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, kv_width),
            )

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        """Encode ``prefix`` indices into prefix vectors.

        Args:
            prefix: integer index tensor, documented above as
                (batch-size, prefix-length). Valid indices are in
                ``[0, pre_seq_len)``.

        Returns:
            Tensor of shape ``prefix.shape + (num_hidden_layers * 2 * hidden_size,)``.
        """
        embedded = self.embedding(prefix)  # (..., table_width)
        if not self.prefix_projection:
            return embedded
        return self.trans(embedded)