saintyboy committed on
Commit
dc49c0d
1 Parent(s): cb69e53

Create model.py

Files changed (1)
  1. model.py +150 -0
model.py ADDED
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from torch.utils.checkpoint import checkpoint
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        # Normalize by the root-mean-square over the last dimension, then rescale.
+        mean_square = torch.mean(x ** 2, dim=-1, keepdim=True)
+        normalized_x = x / torch.sqrt(mean_square + self.eps)
+        return self.weight * normalized_x
+
+class RotaryPositionalEmbedding(nn.Module):
+    # Note: despite the name, this adds a fixed sinusoidal embedding to the
+    # input rather than rotating query/key pairs; the rotation itself lives
+    # in apply_rotary_emb below, which the model does not currently call.
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        max_len = x.size(1)
+        freqs = torch.arange(0, self.dim // 2, dtype=torch.float32, device=x.device)
+        inv_freq = 1.0 / (10000 ** (freqs / (self.dim // 2)))
+        t = torch.arange(max_len, dtype=torch.float32, device=x.device)
+        sinusoid_inp = torch.outer(t, inv_freq)
+        sin_inp = sinusoid_inp.sin()
+        cos_inp = sinusoid_inp.cos()
+        emb_sin_cos = torch.stack((sin_inp, cos_inp), dim=-1).view(max_len, -1)
+        return x + emb_sin_cos[:max_len, :self.dim].unsqueeze(0)
+
+def apply_rotary_emb(xq, xk, freqs_cis):
+    # Rotate (batch, seq, heads, head_dim) queries/keys by complex
+    # multiplication with precomputed frequencies (a usage sketch follows the diff).
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+def reshape_for_broadcast(freqs_cis, x):
+    # Reshape freqs_cis so it broadcasts over the batch and head dimensions of x.
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+class SwiGLU(nn.Module):
+    # Note: this is a simplified self-gated feed-forward, silu(x) * x over a
+    # single up-projection, rather than the two-projection SwiGLU gate
+    # (see the note after the diff).
+    def __init__(self, embed_size, expansion_factor=4, dropout=0.1):
+        super().__init__()
+        self.fc1 = nn.Linear(embed_size, expansion_factor * embed_size)
+        self.fc2 = nn.Linear(expansion_factor * embed_size, embed_size)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.silu(x) * x
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return x
+
+class SelfAttention(nn.Module):
+    def __init__(self, embed_size, heads):
+        super().__init__()
+        self.embed_size = embed_size
+        self.heads = heads
+        self.head_dim = embed_size // heads
+
+        assert embed_size % heads == 0, "Embed size must be divisible by heads"
+
+        self.values = nn.Linear(embed_size, embed_size, bias=False)
+        self.keys = nn.Linear(embed_size, embed_size, bias=False)
+        self.queries = nn.Linear(embed_size, embed_size, bias=False)
+        self.fc_out = nn.Linear(embed_size, embed_size)
+
+    def forward(self, values, keys, queries, mask=None):
+        N = queries.shape[0]
+        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
+
+        # Project and split into heads: (N, heads, seq_len, head_dim).
+        values = self.values(values).view(N, value_len, self.heads, self.head_dim).transpose(1, 2)
+        keys = self.keys(keys).view(N, key_len, self.heads, self.head_dim).transpose(1, 2)
+        queries = self.queries(queries).view(N, query_len, self.heads, self.head_dim).transpose(1, 2)
+
+        # Scaled dot-product scores: (N, heads, query_len, key_len).
+        energy = torch.einsum("bhqd,bhkd->bhqk", queries, keys)
+
+        if mask is not None:
+            energy = energy.masked_fill(mask == 0, float('-inf'))
+
+        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=-1)
+
+        # Weighted sum over values, then merge the heads back into embed_size.
+        out = torch.einsum("bhqk,bhkd->bhqd", attention, values).transpose(1, 2).reshape(N, query_len, self.embed_size)
+        return self.fc_out(out)
+
+class TransformerBlock(nn.Module):
+    def __init__(self, embed_size, heads, expansion_factor=4, dropout=0.1, use_checkpoint=False):
+        super().__init__()
+        self.attention = SelfAttention(embed_size, heads)
+        self.feed_forward = SwiGLU(embed_size, expansion_factor, dropout)
+        self.norm1 = RMSNorm(embed_size)
+        self.norm2 = RMSNorm(embed_size)
+        self.rotary_pos_emb = RotaryPositionalEmbedding(embed_size)
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, value, mask=None):
+        def forward_fn(value, mask):
+            value = self.rotary_pos_emb(value)
+            attention = self.attention(value, value, value, mask)
+            x = self.norm1(attention + value)
+            forward = self.feed_forward(x)
+            out = self.norm2(forward + x)
+            return out
+
+        if self.use_checkpoint:
+            # Trade compute for memory: recompute activations during backward.
+            return checkpoint(forward_fn, value, mask, use_reentrant=False)
+        else:
+            return forward_fn(value, mask)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size, embed_size, num_layers, heads, max_length,
+                 expansion_factor=4, dropout=0.1, use_checkpoint=False):
+        super().__init__()
+        self.word_embedding = nn.Embedding(vocab_size, embed_size)
+        self.position_embedding = nn.Embedding(max_length, embed_size)
+
+        self.src_vocab_size = vocab_size
+
+        self.layers = nn.ModuleList(
+            [TransformerBlock(embed_size, heads, expansion_factor, dropout, use_checkpoint)
+             for _ in range(num_layers)]
+        )
+        self.norm = RMSNorm(embed_size)
+        self.fc_out = nn.Linear(embed_size, vocab_size)
+
+    def forward(self, x, mask=None):
+        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
+        x = self.word_embedding(x) + self.position_embedding(positions)
+
+        for layer in self.layers:
+            x = layer(x, mask)
+
+        x = self.norm(x)
+        return self.fc_out(x)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = GPT(vocab_size=10000, embed_size=768, num_layers=20, heads=16, max_length=512, use_checkpoint=True)
+model.to(device)
+
+inputs = torch.randint(0, 10000, (1, 100), device=device)
+outputs = model(inputs)  # no mask passed here; a causal-mask sketch follows the diff
+print(outputs.shape)  # torch.Size([1, 100, 10000])
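
apply_rotary_emb and reshape_for_broadcast are committed but never called by the model above (the blocks add sinusoidal embeddings instead). A minimal sketch of how they could be driven, assuming a Llama-style precompute_freqs_cis helper; the helper name, theta value, and the (batch, seq, heads, head_dim) layout are illustrative, not part of the commit, and the rotation would have to happen before SelfAttention transposes to (batch, heads, seq, head_dim):

def precompute_freqs_cis(head_dim, seq_len, theta=10000.0):
    # Complex rotation factors of shape (seq_len, head_dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), freqs)
    return torch.polar(torch.ones_like(angles), angles)

xq = torch.randn(2, 100, 16, 48)  # (batch, seq, heads, head_dim)
xk = torch.randn(2, 100, 16, 48)
freqs_cis = precompute_freqs_cis(48, 100)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rot.shape)  # torch.Size([2, 100, 16, 48])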
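
For comparison, the standard SwiGLU feed-forward (Shazeer, 2020) gates one up-projection with the SiLU of another, whereas the committed block self-gates a single projection. A sketch of the two-projection form; the class name and layer names are illustrative:

class SwiGLUFFN(nn.Module):
    def __init__(self, embed_size, hidden_size):
        super().__init__()
        self.w1 = nn.Linear(embed_size, hidden_size, bias=False)  # gate projection
        self.w3 = nn.Linear(embed_size, hidden_size, bias=False)  # value projection
        self.w2 = nn.Linear(hidden_size, embed_size, bias=False)  # down projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))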
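
The demo runs without a mask, so every position attends to every other. For autoregressive use, SelfAttention fills positions where mask == 0 with -inf over the (N, heads, query_len, key_len) scores, so a lower-triangular mask that broadcasts over batch and heads works. A minimal sketch reusing the demo's variables; the helper name is illustrative:

def make_causal_mask(seq_len, device=None):
    # 1 where attention is allowed (key position <= query position), else 0.
    return torch.tril(torch.ones(seq_len, seq_len, device=device)).view(1, 1, seq_len, seq_len)

mask = make_causal_mask(inputs.size(1), device=device)
logits = model(inputs, mask)  # (1, 100, 10000), causally masked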