from transformers import PreTrainedModel
import torch
import torch.nn as nn
from torch.nn import functional as F
from .configuration_medts import MedTSConfig


class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size, n_embd, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B, T, C = x.shape
        k = self.key(x)    # (B,T,hs)
        q = self.query(x)  # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5             # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)                                  # (B, T, T)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,hs)
        out = wei @ v      # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, n_embd, dropout, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head, dropout, block_size):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, dropout, block_size)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # pre-norm residual connections around attention and feed-forward
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class PatientsTimeSeriesModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, device, n_layer, n_head, dropout):
        '''
        args:
        - vocab_size: int, the number of unique tokens in the vocabulary, i.e. the number of distinct test results per time step
        - n_embd: int, the dimension of the internal embedding the test results are projected into
        - block_size: int, the length of the context window
        '''
        super().__init__()
        # learned positional embeddings, kept in the input (vocab_size) dimension so they can be
        # added to the raw test-result vectors before projection
        self.position_embedding_table = nn.Embedding(block_size, vocab_size)
        # self.sa = Head(n_embd, n_embd, block_size)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, dropout, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_prefix = nn.Linear(vocab_size, n_embd)  # linear layer to project the input test results to the embedding dimension
        self.lm_head = nn.Linear(n_embd, vocab_size)    # linear layer to project the embeddings back to the vocabulary size
        self.device = device
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tok_emb, targets=None):
        # tok_emb and targets are both (B,T,C) tensors,
        # where B is the batch size, T is the number of time steps and C is the number of test results
        B, T, C = tok_emb.shape
        pos_emb = self.position_embedding_table(torch.arange(T, device=self.device))  # (T,vocab_size)
        x = tok_emb + pos_emb     # (B,T,vocab_size)
        x = self.lm_prefix(x)     # (B,T,n_embd)
        x = self.blocks(x)        # (B,T,n_embd)
        x = self.ln_f(x)          # (B,T,n_embd)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        if targets is None:
            return {"logits": logits}
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T, C)
            # TODO: Add padding mask to the loss computation
            # loss = F.mse_loss(logits, targets)
            loss = self.mse_loss(logits, targets, reduction="mean")
            return {"logits": logits, "loss": loss}

    def mse_loss(self, out, target, reduction):
        # masked MSE: positions where the target is exactly 0 are treated as padding and ignored
        mask = (target == 0)
        loss = (out[~mask] - target[~mask]) ** 2
        if reduction == "mean":
            return loss.mean()
        elif reduction == "none":
            return loss
        else:
            raise ValueError(f"unsupported reduction: {reduction}")


class MedTSModel(PreTrainedModel):
    config_class = MedTSConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = PatientsTimeSeriesModel(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            block_size=config.block_size,
            device='mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu',
            n_layer=config.n_layer,
            n_head=config.n_head,
            dropout=config.dropout,
        )

    def forward(self, tensor, targets=None):
        return self.model(tensor, targets)
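

# Minimal usage sketch. Assumptions: MedTSConfig accepts the hyperparameters below as keyword
# arguments (the usual pattern for a PretrainedConfig subclass; the field names mirror the
# attributes read in MedTSModel.__init__), and the batch/context sizes are made up for the demo.
if __name__ == "__main__":
    config = MedTSConfig(
        vocab_size=32,   # number of distinct test results per time step
        n_embd=64,       # internal embedding dimension
        block_size=16,   # maximum context length (time steps)
        n_layer=2,
        n_head=4,
        dropout=0.1,
    )
    model = MedTSModel(config)
    device = model.model.device  # device string chosen at construction ('mps', 'cuda' or 'cpu')
    model.to(device)

    # a batch of 4 patients, block_size time steps, vocab_size test results per step
    x = torch.rand(4, config.block_size, config.vocab_size, device=device)
    targets = torch.rand(4, config.block_size, config.vocab_size, device=device)

    out = model(x, targets)
    print(out["logits"].shape)  # (B*T, vocab_size) -- logits are flattened when targets are given
    print(out["loss"].item())   # masked MSE over the non-zero target entries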