Math_gpt / model.py
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """A single post-norm Transformer block: self-attention and a feed-forward network, each with a residual connection."""

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        # batch_first=True so the attention module accepts (batch, seq_len, embed_size)
        # tensors, matching the shape produced by MathTransformer below.
        self.attention = nn.MultiheadAttention(embed_size, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        # Self-attention, then residual connection and layer norm.
        attn_output, _ = self.attention(query, key, value)
        x = self.dropout(self.norm1(attn_output + query))
        # Position-wise feed-forward network with a second residual connection.
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
class MathTransformer(nn.Module):
    """Small Transformer language model with learned token and positional embeddings."""

    def __init__(self, vocab_size, embed_size=128, num_layers=2, heads=4,
                 forward_expansion=4, max_length=512, dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len) token indices.
        N, seq_length = x.shape
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        # Sum token and position embeddings, then apply dropout.
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        # Note: no attention mask is applied, so attention is bidirectional over the sequence.
        for layer in self.layers:
            out = layer(out, out, out)
        # Project back to vocabulary logits: (batch, seq_len, vocab_size).
        out = self.fc_out(out)
        return out
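

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): the vocabulary size and
    # batch/sequence dimensions below are illustrative assumptions, not values taken from
    # the Math_gpt training setup.
    vocab_size = 1000  # assumed placeholder; use the real tokenizer's vocabulary size
    model = MathTransformer(vocab_size=vocab_size)
    tokens = torch.randint(0, vocab_size, (2, 16))  # (batch=2, seq_len=16) random token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])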