FRIENDS-GPT / feedforward.py
bala1802's picture
Upload 7 files
dabde41
import torch.nn as nn
import gpt_config as config
class FeedFoward(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    The first linear layer expands the last dimension from ``n_embd`` to
    ``hidden_mult * n_embd``. ReLU is applied next. The second linear layer
    projects back to ``n_embd``, and dropout is applied to the result.

    Args:
        n_embd: Size of the input and output feature (last) dimension.
        hidden_mult: Width multiplier for the hidden layer. Defaults to 4,
            the value that was previously hard-coded.
        dropout: Dropout probability applied to the output. When ``None``
            (the default), it is read from ``config.dropout`` at construction
            time, as before.
    """

    def __init__(self, n_embd, hidden_mult=4, dropout=None):
        super().__init__()
        if dropout is None:
            # Preserve the original behaviour: take the global config value.
            dropout = config.dropout
        hidden = hidden_mult * n_embd
        # The attribute name ``net`` and the layer order are kept unchanged so
        # that existing state_dict keys (net.0.*, net.2.*) still load.
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x``, whose last dimension must be ``n_embd``.

        Returns a tensor with the same shape as ``x``. ``nn.Linear`` acts on
        the last dimension only, so leading dimensions pass through.
        """
        return self.net(x)


# Correctly spelled alias; the original ``FeedFoward`` name is kept for callers.
FeedForward = FeedFoward