import torch
import torch.nn as nn

import gpt_config as config
from head import Head


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # num_heads independent attention heads, plus a learned projection
        # from the concatenated head outputs back to the embedding dimension
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # each head returns (B, T, head_size); concatenate along the channel dim
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # project back to (B, T, n_embd), then apply dropout
        out = self.dropout(self.proj(out))
        return out
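

# --- usage sketch (illustrative only, not part of the module above) ---
# Assumes gpt_config defines n_embd and dropout (both referenced above) and that
# Head(head_size) maps a (B, T, n_embd) input to (B, T, head_size), as the
# concatenation in forward() implies. The sequence length chosen here is a
# hypothetical value; if Head applies a causal mask sized by config.block_size,
# T must not exceed that.
if __name__ == "__main__":
    B, T = 4, 8                           # hypothetical batch size and sequence length
    num_heads = 4
    x = torch.randn(B, T, config.n_embd)  # dummy token embeddings
    mha = MultiHeadAttention(num_heads, head_size=config.n_embd // num_heads)
    out = mha(x)
    print(out.shape)                      # expected: torch.Size([B, T, config.n_embd])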