# chatlm/models.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Embeddings(nn.Module):
"""
Implements embeddings of the words and adds their positional encodings.
"""
    def __init__(self, vocab_size, d_model, max_len=50):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = self.create_positional_encoding(max_len, self.d_model)
        self.dropout = nn.Dropout(0.1)

    def create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len):  # for each position in the sequence
            for i in range(0, d_model, 2):  # for each sin/cos dimension pair
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
        pe = pe.unsqueeze(0)  # add a batch dimension: (1, max_len, d_model)
        return pe
    def forward(self, encoded_words):
        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)  # scale embeddings by sqrt(d_model)
        embedding = embedding + self.pe[:, :embedding.size(1)]  # positional encodings broadcast over the batch dimension
        embedding = self.dropout(embedding)
        return embedding
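
# A minimal shape-check sketch for Embeddings (illustrative only: the vocabulary
# size, batch size, and sequence length below are assumptions, not values taken
# from the training configuration):
#
#     emb = Embeddings(vocab_size=100, d_model=512).to(device)
#     tokens = torch.randint(0, 100, (4, 20)).to(device)  # (batch_size, seq_len)
#     emb(tokens).shape                                    # torch.Size([4, 20, 512])
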
class MultiHeadAttention(nn.Module):
def __init__(self, heads, d_model):
super(MultiHeadAttention, self).__init__()
assert d_model % heads == 0
self.d_k = d_model // heads
self.heads = heads
self.dropout = nn.Dropout(0.1)
self.query = nn.Linear(d_model, d_model)
self.key = nn.Linear(d_model, d_model)
self.value = nn.Linear(d_model, d_model)
self.concat = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask):
"""
query, key, value of shape: (batch_size, max_len, 512)
mask of shape: (batch_size, 1, 1, max_words)
"""
# (batch_size, max_len, 512)
query = self.query(query)
key = self.key(key)
value = self.value(value)
        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
# (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
scores = torch.matmul(query, key.permute(0,1,3,2)) / math.sqrt(query.size(-1))
scores = scores.masked_fill(mask == 0, -1e9) # (batch_size, h, max_len, max_len)
weights = F.softmax(scores, dim = -1) # (batch_size, h, max_len, max_len)
weights = self.dropout(weights)
# (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
context = torch.matmul(weights, value)
# (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
context = context.permute(0,2,1,3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
# (batch_size, max_len, h * d_k)
interacted = self.concat(context)
return interacted
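
# Illustrative sketch of how the attention mask is expected to broadcast (the
# shapes here are assumptions for demonstration, not part of the original
# pipeline): a key-padding mask of shape (batch_size, 1, 1, max_len) broadcasts
# against the score tensor of shape (batch_size, heads, max_len, max_len), so
# every head and every query position shares the same padding pattern.
#
#     mha = MultiHeadAttention(heads=8, d_model=512).to(device)
#     x = torch.randn(2, 10, 512).to(device)
#     mask = torch.ones(2, 1, 1, 10, dtype=torch.long).to(device)  # 1 = attend, 0 = masked out
#     mha(x, x, x, mask).shape   # torch.Size([2, 10, 512])
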
class FeedForward(nn.Module):
def __init__(self, d_model, middle_dim = 2048):
super(FeedForward, self).__init__()
self.fc1 = nn.Linear(d_model, middle_dim)
self.fc2 = nn.Linear(middle_dim, d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, x):
out = F.relu(self.fc1(x))
out = self.fc2(self.dropout(out))
return out
class EncoderLayer(nn.Module):
def __init__(self, d_model, heads):
super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)  # note: a single LayerNorm instance (shared weights) is reused after both sub-layers
self.self_multihead = MultiHeadAttention(heads, d_model)
self.feed_forward = FeedForward(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, embeddings, mask):
interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
interacted = self.layernorm(interacted + embeddings)
feed_forward_out = self.dropout(self.feed_forward(interacted))
encoded = self.layernorm(feed_forward_out + interacted)
return encoded
class DecoderLayer(nn.Module):
def __init__(self, d_model, heads):
super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)  # note: a single LayerNorm instance (shared weights) is reused after all three sub-layers
self.self_multihead = MultiHeadAttention(heads, d_model)
self.src_multihead = MultiHeadAttention(heads, d_model)
self.feed_forward = FeedForward(d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, embeddings, encoded, src_mask, target_mask):
query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask))
query = self.layernorm(query + embeddings)
interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
interacted = self.layernorm(interacted + query)
feed_forward_out = self.dropout(self.feed_forward(interacted))
decoded = self.layernorm(feed_forward_out + interacted)
return decoded
class Transformer(nn.Module):
def __init__(self, d_model, heads, num_layers, word_map):
super(Transformer, self).__init__()
self.d_model = d_model
self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)  # embedding table shared by encoder and decoder
self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
self.logit = nn.Linear(d_model, self.vocab_size)
def encode(self, src_words, src_mask):
src_embeddings = self.embed(src_words)
for layer in self.encoder:
src_embeddings = layer(src_embeddings, src_mask)
return src_embeddings
def decode(self, target_words, target_mask, src_embeddings, src_mask):
tgt_embeddings = self.embed(target_words)
for layer in self.decoder:
tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
return tgt_embeddings
def forward(self, src_words, src_mask, target_words, target_mask):
encoded = self.encode(src_words, src_mask)
decoded = self.decode(target_words, target_mask, encoded, src_mask)
        out = F.log_softmax(self.logit(decoded), dim=2)  # (batch_size, max_len, vocab_size) log-probabilities
return out
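

if __name__ == "__main__":
    # A minimal smoke test of the full model. The tiny word_map, sequence
    # lengths, and hyperparameters below are placeholders chosen for
    # illustration; the real vocabulary and masks come from the training
    # script, not from this sketch.
    word_map = {"<pad>": 0, "<start>": 1, "<end>": 2, "hello": 3, "there": 4}
    model = Transformer(d_model=512, heads=8, num_layers=2, word_map=word_map).to(device)

    src = torch.randint(1, len(word_map), (2, 10)).to(device)   # (batch_size, src_len)
    tgt = torch.randint(1, len(word_map), (2, 12)).to(device)   # (batch_size, tgt_len)

    # Padding mask for encoder self-attention and decoder cross-attention:
    # shape (batch_size, 1, 1, src_len), 1 where a real token is present.
    src_mask = (src != word_map["<pad>"]).unsqueeze(1).unsqueeze(1)

    # Decoder self-attention mask: padding mask combined with a lower-triangular
    # (causal) mask so position t cannot attend to positions > t.
    tgt_pad_mask = (tgt != word_map["<pad>"]).unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, tgt_len)
    causal_mask = torch.tril(torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=device))
    tgt_mask = tgt_pad_mask & causal_mask  # broadcasts to (batch_size, 1, tgt_len, tgt_len)

    out = model(src, src_mask, tgt, tgt_mask)
    print(out.shape)  # expected: torch.Size([2, 12, len(word_map)])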