import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        # batch_first=True so inputs are (batch, seq_len, embed_size),
        # matching the shapes produced by MathTransformer.forward below.
        self.attention = nn.MultiheadAttention(embed_size, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        # Self-attention, then residual connection and layer norm.
        attn_output, _ = self.attention(query, key, value)
        x = self.dropout(self.norm1(attn_output + query))
        # Position-wise feed-forward network, again with residual + norm.
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class MathTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=128, num_layers=2, heads=4,
                 forward_expansion=4, max_length=512, dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        # Learned absolute position embeddings, one per position up to max_length.
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch_size, seq_length) of token indices.
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        # Each block performs self-attention, so value, key and query are all `out`.
        for layer in self.layers:
            out = layer(out, out, out)
        # Project back to vocabulary logits: (batch_size, seq_length, vocab_size).
        out = self.fc_out(out)
        return out
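
# --- Usage sketch (not part of the original Space code) ---------------------
# A minimal smoke test showing the expected tensor shapes; the vocabulary size,
# batch size, and sequence length below are illustrative assumptions only.
if __name__ == "__main__":
    vocab_size = 1000                                 # hypothetical tokenizer vocabulary size
    model = MathTransformer(vocab_size)
    tokens = torch.randint(0, vocab_size, (2, 16))    # (batch_size, seq_length)
    logits = model(tokens)
    print(logits.shape)                               # expected: torch.Size([2, 16, 1000])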