Spaces:
Runtime error
Runtime error
import torch
import torch.nn as nn

from nets.decoder import Decoder
from nets.encoder import Encoder
from nets.projections import Projections

class Model(nn.Module):
    """Attention model assembled from the project's encoder/decoder modules.

    The constructor builds four sub-modules, all sharing the same number of
    attention heads and the same embedding dimension:

    * ``encoder``         -- an ``Encoder`` with ``num_layers`` layers whose
      ``node_dim`` is ``input_size``.
    * ``decoder``         -- a ``Decoder`` taking ``decoder_input_size``.
    * ``projections``     -- a ``Projections`` module.
    * ``fleet_attention`` -- a single-layer ``Encoder`` whose ``node_dim`` is
      ``embedding_size + 1``.

    Args:
        input_size: Passed to the main encoder as ``node_dim``.
        embedding_size: Passed to every sub-module as the embedding size
            (``embed_dim`` / ``embedding_size``); also stored on
            ``self.embedding_size``.
        decoder_input_size: Passed to the ``Decoder`` unchanged.
        num_heads: Passed to every sub-module as the number of heads.
        num_layers: Number of layers of the main encoder (the fleet
            attention encoder always uses one layer).
        ff_hidden: Passed to both encoders as ``feed_forward_hidden``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Note:
        ``*args`` / ``**kwargs`` are swallowed without being used, so a
        misspelled keyword argument is silently ignored. This is kept for
        backward compatibility with callers that pass extra config keys.
    """

    def __init__(self, input_size, embedding_size,
                 decoder_input_size,
                 num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
        super().__init__()
        self.embedding_size = embedding_size

        # ----------- Encoder -----------
        self.encoder = Encoder(
            n_heads=num_heads,
            embed_dim=embedding_size,
            n_layers=num_layers,
            feed_forward_hidden=ff_hidden,
            node_dim=input_size,
        )

        # ----------- Decoder -----------
        self.decoder = Decoder(
            decoder_input_size=decoder_input_size,
            embedding_size=embedding_size,
            num_heads=num_heads,
        )

        # ----------- Attention Projections -----------
        self.projections = Projections(
            n_heads=num_heads,
            embed_dim=embedding_size,
        )

        # ----------- Fleet Attention Encoder -----------
        # Always constructed (not optional). Its node_dim is one wider than
        # the embedding; presumably forward() concatenates one extra scalar
        # feature onto each embedding before this encoder -- TODO confirm.
        self.fleet_attention = Encoder(
            n_heads=num_heads,
            embed_dim=embedding_size,
            n_layers=1,
            feed_forward_hidden=ff_hidden,
            node_dim=embedding_size + 1,
        )