|
import torch.nn as nn |
|
|
|
class EvoDecoderModel(nn.Module):
    """Token-level Transformer decoder: embedding -> TransformerDecoder -> vocab logits.

    The decoder layers are built without ``batch_first``, so PyTorch's default
    sequence-first layout applies. ``tgt`` is therefore expected as
    ``(tgt_len, batch)`` token ids (or ``(tgt_len,)`` unbatched), and ``memory``
    as ``(mem_len, batch, d_model)``.

    NOTE(review): no positional encoding is added to the token embeddings, so
    self-attention receives no explicit position information. Confirm this is
    intentional.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1):
        """Build the embedding, the decoder stack and the output projection.

        Args:
            vocab_size: Number of token ids. It sizes both the embedding table
                and the output logits.
            d_model: Embedding / hidden width shared by every layer.
            nhead: Attention heads per decoder layer (must divide ``d_model``).
            num_layers: Number of stacked decoder layers.
            dim_feedforward: Hidden width of each layer's feed-forward block.
            dropout: Dropout probability used inside each decoder layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                *, causal=False):
        """Return vocabulary logits for each target position.

        Args:
            tgt: Integer token ids, sequence-first (see class docstring).
            memory: Encoder output that the decoder cross-attends to.
            tgt_mask: Optional self-attention mask, forwarded unchanged to
                ``nn.TransformerDecoder``.
            memory_mask: Optional cross-attention mask, forwarded unchanged.
            tgt_key_padding_mask: Optional padding mask for ``tgt``.
            memory_key_padding_mask: Optional padding mask for ``memory``.
            causal: If True, build a square subsequent mask so that position
                ``i`` only attends to positions ``<= i``. It cannot be combined
                with an explicit ``tgt_mask``.

        Returns:
            The decoder output projected so its last dimension is ``vocab_size``.

        Raises:
            ValueError: If both ``causal=True`` and ``tgt_mask`` are given.
        """
        if causal and tgt_mask is not None:
            raise ValueError("pass either causal=True or an explicit tgt_mask, not both")
        embedded = self.embedding(tgt)
        if causal:
            # The layout is sequence-first, so dim 0 is the target length for
            # both batched and unbatched input. The -inf entries above the
            # diagonal block attention to future positions. new_full inherits
            # the embeddings' float dtype and device, so no extra import is needed.
            tgt_len = embedded.size(0)
            tgt_mask = embedded.new_full((tgt_len, tgt_len), float("-inf")).triu(1)
        output = self.transformer_decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_layer(output)
|
|