import torch


class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn module that encodes the prefix.

    Input shape: (batch_size, prefix_length)
    Output shape: (batch_size, prefix_length, 2 * layers * hidden)
    '''

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            # Embed directly into the full past-key-value dimension
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len,
                config.num_hidden_layers * 2 * config.hidden_size
            )

    def forward(self, prefix: torch.Tensor):
        # Move the prefix ids to the embedding's device before lookup
        device = next(self.embedding.parameters()).device
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix.to(device))
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix.to(device))
        return past_key_values
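
# A minimal usage sketch (not part of the original file). The `config` below is
# a hypothetical stand-in for a real model config (e.g. a transformers
# PretrainedConfig) carrying the attributes PrefixEncoder reads:
# prefix_projection, pre_seq_len, hidden_size, prefix_hidden_size and
# num_hidden_layers. All concrete values here are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        prefix_projection=True,   # use the two-layer MLP reparameterization
        pre_seq_len=16,           # number of prefix positions
        hidden_size=768,
        prefix_hidden_size=512,
        num_hidden_layers=12,
    )
    encoder = PrefixEncoder(config)

    # One learnable prefix index per position, tiled across the batch.
    batch_size = 4
    prefix = torch.arange(config.pre_seq_len).unsqueeze(0).expand(batch_size, -1)

    past_key_values = encoder(prefix)
    # Matches the documented output shape:
    # (batch_size, prefix_length, 2 * layers * hidden)
    print(past_key_values.shape)  # torch.Size([4, 16, 18432])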