EvoConvo / evo_model.py
HemanM's picture
Update evo_model.py
1070f67 verified
raw
history blame
688 Bytes
import torch
import torch.nn as nn
class EvoDecoderModel(nn.Module):
    """Transformer decoder that maps target token ids to vocabulary logits.

    Target token ids are embedded, passed through a stack of
    ``nn.TransformerDecoderLayer`` blocks that cross-attend to ``memory``,
    and projected back to ``vocab_size`` logits.

    NOTE(review): no positional encoding is added to the embeddings, so the
    decoder only sees token order through ``tgt_mask`` (when one is passed).
    Adding one would change the outputs of already-trained weights, so
    confirm intent before changing.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        d_model: Embedding / hidden size; must be divisible by ``nhead``.
        nhead: Number of attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        dropout: Dropout probability used inside each decoder layer.
        batch_first: Forwarded to ``nn.TransformerDecoderLayer``.
            ``False`` (the default, and the original behaviour) means
            sequence-first tensors: ``tgt`` is ``(tgt_len, batch)`` and
            ``memory`` is ``(mem_len, batch, d_model)``.
            ``True`` swaps the first two axes.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        *,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=batch_first,
        )
        # nn.TransformerDecoder clones `decoder_layer` num_layers times,
        # so the stacked layers do not share weights.
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    @staticmethod
    def causal_mask(size: int, device: "torch.device | str | None" = None) -> torch.Tensor:
        """Return a ``(size, size)`` boolean mask that blocks attention to later positions.

        ``True`` entries (strictly above the diagonal) mark positions that may
        NOT be attended to, which is the boolean ``tgt_mask`` convention of
        ``nn.TransformerDecoder``.  Pass the result as ``tgt_mask`` to
        :meth:`forward` for autoregressive (left-to-right) decoding.
        """
        return torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        *,
        tgt_mask: "torch.Tensor | None" = None,
        memory_mask: "torch.Tensor | None" = None,
        tgt_key_padding_mask: "torch.Tensor | None" = None,
        memory_key_padding_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Decode ``tgt`` token ids against ``memory`` and return vocabulary logits.

        Args:
            tgt: Integer token ids; layout follows ``batch_first``
                (``(tgt_len, batch)`` by default).
            memory: Encoder output of width ``d_model``; layout follows
                ``batch_first`` (``(mem_len, batch, d_model)`` by default).
            tgt_mask: Optional self-attention mask over target positions,
                e.g. ``EvoDecoderModel.causal_mask(tgt_len)``.  ``None`` (the
                default, matching the original behaviour) lets every target
                position attend to every other, including future ones.
            memory_mask: Optional cross-attention mask over memory positions.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` mask; ``True``
                marks target padding to ignore.
            memory_key_padding_mask: Optional ``(batch, mem_len)`` mask;
                ``True`` marks memory padding to ignore.

        Returns:
            Logits with the shape of ``tgt`` plus a trailing ``vocab_size`` axis.
        """
        embedded = self.embedding(tgt)
        decoded = self.transformer_decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_layer(decoded)