gagan001 committed
Commit 9e82102 · 1 Parent(s): 33aa743

Added new architecture supporting GQA

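For context on the change: the new MultiHeadAttention in my_gpt.py implements grouped-query attention (GQA), where the n_head = 6 query heads share kv_heads = 3 key/value heads that repeat_kv expands back to 6 at attention time. A standalone sketch of that shape flow (illustrative only, with hypothetical tensor sizes; not code from this repo):

import torch

# Sizes mirroring the config added in my_gpt.py: 6 query heads, 3 KV heads, head_dim 64.
bsz, qlen, head_dim = 2, 16, 64
num_attn_heads, num_kv_heads = 6, 3
n_rep = num_attn_heads // num_kv_heads   # each KV head serves 2 query heads

q = torch.randn(bsz, num_attn_heads, qlen, head_dim)
k = torch.randn(bsz, num_kv_heads, qlen, head_dim)
v = torch.randn(bsz, num_kv_heads, qlen, head_dim)

# repeat_kv: duplicate each KV head n_rep times so shapes line up with the query heads.
k = k[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, qlen, head_dim).reshape(bsz, num_attn_heads, qlen, head_dim)
v = v[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, qlen, head_dim).reshape(bsz, num_attn_heads, qlen, head_dim)

scores = torch.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1)  # (bsz, 6, qlen, qlen)
out = scores @ v                                                           # (bsz, 6, qlen, head_dim)
print(out.shape)  # torch.Size([2, 6, 16, 64])

The sketch scales by sqrt(head_dim), the usual convention; the committed code divides by sqrt(n_embed) instead. The point of GQA is that keys and values are projected to only kv_heads * head_dim = 192 features instead of 384, shrinking the K/V parameters and activations.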
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -4,7 +4,7 @@ from my_gpt import my_gpt
 from tokenizer.tokenizer import BPE
 
 ##Load model
-model = my_gpt.load_pretrained("model/model_1000_cpu.bin")
+model = my_gpt.load_pretrained("model/model_1000_.bin")
 # model.to(torch.device("cpu"))
 # model.save_pretrained("model/model_1000_cpu.bin")
 # exit()
@@ -21,7 +21,7 @@ def generate(input_text):
 iface = gr.Interface(fn=generate,
                      inputs="text",
                      outputs="text",
-                     title="GPT - 1000 steps",
-                     description="""This model is trained for 1000 steps only. It is not
-                     able to generate perfect sentences/words. However, it has learnt a gist of the English language""")
+                     title="NoobGPT - 1000 steps",
+                     description="""This 13M param model is trained for 1000 steps only and has seen only 1M tokens. It is not
+                     able to generate perfect sentences/words but has acquired a rudimentary understanding of the English language""")
 iface.launch()
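For reference, a minimal sketch of how the renamed checkpoint and the Gradio handler fit together (the tokenizer construction and the model.generate signature are assumptions inferred from the imports and the sampling loop further down, not code from this commit):

import torch
from my_gpt import my_gpt
from tokenizer.tokenizer import BPE

model = my_gpt.load_pretrained("model/model_1000_.bin")
tokenizer = BPE()  # assumed default constructor; the real setup may differ

def generate(input_text):
    # Encode the prompt, sample a few new tokens, and decode back to text.
    ids = torch.tensor([tokenizer.encode(input_text)], dtype=torch.long)  # assumed encode() API
    out = model.generate(ids, max_new_tokens=50)                          # assumed method name
    return tokenizer.decode(out[0].tolist())                              # assumed decode() API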
model/.DS_Store CHANGED
Binary files a/model/.DS_Store and b/model/.DS_Store differ
 
model/{model_1000_cpu.bin → model_1000_.bin} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f88eff25b6947e11a832f96e2bc914c6818989045539c327438c3e490b184cc9
-size 56951293
+oid sha256:5877a72287e65e61deab89188115afa2eb7dade01cbde49c3103fa40b468a1c8
+size 56607625
my_gpt.py CHANGED
@@ -4,13 +4,14 @@ from torch.nn import functional as F
 import json
 import logging
 
-
-block_size = 256
+block_size = 128
 vocab_size = 500
 n_embed = 384
 dropout = 0.2
 n_head = 6
 n_layer = 6
+kv_heads = 3
+max_position_embeddings = 128
 
 class Head(nn.Module):
     def __init__(self, head_size=16):
@@ -40,18 +41,102 @@ class Head(nn.Module):
 
         return out
 
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class MultiHeadAttention(nn.Module):
-    def __init__(self,num_heads, head_size) :
+    def __init__(self,num_heads, head_dim) :
         super().__init__()
+        assert num_heads%kv_heads == 0
+        self.n_embed = n_embed
+        self.num_attn_heads = num_heads
+        self.head_dim = head_dim
+        self.kv_heads = kv_heads
+        # self.kv_out_proj = head_dim * self.kv_heads #Doubt
+        self.num_kv_groups = self.num_attn_heads // self.kv_heads
 
-        self.heads = nn.ModuleList(Head(head_size=head_size) for _ in range(num_heads))
-        self.proj = nn.Linear(head_size * num_heads, n_embed)
+
+        self.heads = nn.ModuleList(Head(head_size=head_dim) for _ in range(num_heads))
+        ##Only self attention
+
+        #For num_attn_heads number of heads
+        self.Wq = nn.Linear(self.n_embed, self.num_attn_heads*self.head_dim)
+        #For kv_heads number of heads
+        self.Wk = nn.Linear(self.n_embed, self.kv_heads * self.head_dim)
+        self.Wv = nn.Linear(self.n_embed, self.kv_heads * self.head_dim)
+
+        self.o_proj = nn.Linear(self.head_dim * self.num_attn_heads, self.n_embed)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
-        out = torch.cat([h(x) for h in self.heads], dim=-1)
-        out = self.dropout(self.proj(out))
-        return out
+        # self.attention_mask = torch.zeros((bsz, self.num_attn_heads, qlen, qlen))
+        # self.attention_mask[:, :, :, qlen:] = float('-inf') # Mask out positions beyond the key sequence length
+
+
+    def forward(self, x, attn_mask= None):
+        """
+        Parameters:
+        x (bsz, qlen, embed) : input
+        """
+        # out = torch.cat([h(x) for h in self.heads], dim=-1)
+        # attn_output = self.dropout(self.o_proj(out))
+
+        # ################ Experiment
+
+
+        bsz, qlen, embed = x.size()
+        # print("input size", x.size())
+
+        q = self.Wq(x) ##(B,T,head_dim * num_heads)
+        k = self.Wk(x) ##(B,T,head_dim * kv_heads)
+        v = self.Wv(x) ##(B,T,head_dim * kv_heads)
+
+
+
+        q = q.view(bsz, qlen, self.num_attn_heads, self.head_dim).transpose(2,1) ##(B,T,head_dim * num_heads)
+        k = k.view(bsz, qlen, self.kv_heads, self.head_dim).transpose(2,1) ##(B,T,head_dim * kv_heads)
+        v = v.view(bsz, qlen, self.kv_heads, self.head_dim).transpose(2,1) ##(B,T,head_dim * kv_heads)
+
+        # print("k-shape before ",k.shape)
+        k = repeat_kv(k, self.num_kv_groups) ##(B, n_kvheads * nrep, qlen, head_dim)
+        v = repeat_kv(v, self.num_kv_groups)
+
+        attn_scores = q @ k.transpose(-1,-2)/torch.sqrt(torch.tensor(self.n_embed)) ##(B, T, block_size)
+
+        ################
+        # print("Q-shape ", q.shape)
+        # print("k-shape ",k.shape)
+        # print(k.shape[-2])
+        # print(attn_scores.shape)
+
+        if attn_mask is not None:
+            # causal_mask = attn_mask[:, :, :, : k.shape[-2]]
+            # attn_scores = attn_scores + causal_mask
+            attn_scores = attn_scores.masked_fill(
                attn_mask[None, None, :qlen, :qlen]==0 , float("-inf")
+            )
+
+
+        attn_scores = F.softmax(attn_scores, dim=-1)
+        attn_scores = F.dropout(attn_scores) ##Why this dropout is required??
+
+        attn_output = torch.matmul(attn_scores, v) ##(B, n_heads, qlen, hidden_size)
+        attn_output = attn_output.transpose(1,2).contiguous()
+        attn_output = attn_output.view(bsz, qlen, self.n_embed)
+
+        attn_output = self.o_proj(attn_output)
+        return attn_output
+
 
 class FeedForward(nn.Module):
     def __init__(self,n_embed) -> None:
@@ -68,26 +153,33 @@ class FeedForward(nn.Module):
         return x
 
 class decoder_block(nn.Module):
-    def __init__(self, n_embed, n_heads):
+    def __init__(self, n_embed, n_heads, attn_mask=None):
         super().__init__()
-        self.sa = MultiHeadAttention(n_heads,n_embed//n_heads)
+        # Assume 0 for allowed positions and -inf for masked positions
+
+        self.sa = MultiHeadAttention(n_heads,n_embed//n_head)
         self.ln1 = nn.LayerNorm(n_embed)
         self.ln2 = nn.LayerNorm(n_embed)
         self.ffwd = FeedForward(n_embed)
+        # self.causal_mask = torch.tril(torch.ones(block_size,block_size))
+        self.register_buffer('causal_mask',torch.tril(torch.ones(block_size,block_size)))
+
+
 
     def forward(self, x):
-        x = x + self.sa(self.ln1(x))
+        x = x + self.sa(self.ln1(x), attn_mask = self.causal_mask)
         x = x + self.ffwd(self.ln2(x))
         return x
 
 
 
 class my_gpt(nn.Module):
-    def __init__(self, block_size = 256):
+    def __init__(self, device='cpu', block_size = 128):
         super().__init__()
+        self.device = device
         self.block_size = block_size ##context window size
         self.token_embed = nn.Embedding(vocab_size, n_embed)
-        self.pos_embed = nn.Embedding(vocab_size, n_embed)
+        self.pos_embed = nn.Embedding(max_position_embeddings, n_embed)
         self.lm_head = nn.Linear(n_embed, vocab_size)
         self.sa_head = Head(vocab_size)
         self.d_blocks = nn.Sequential(*[decoder_block(n_embed=n_embed,n_heads=n_head) for _ in range(n_layer)])
@@ -103,19 +195,20 @@ class my_gpt(nn.Module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, idx, targets = None):
+    def forward(self, x, targets = None):
         """
         Args:
-        idx: int(B,T) Token ids
+        x: int(B,T) Token ids
         targets :
 
         Returns:
        logits
         """
         # print("idx ", idx)
-        B, T = idx.shape ##
-        tok_emd = self.token_embed(idx) ##(B,T,C)
-        pos_emd = self.pos_embed(idx)
+        B, T = x.size() ##
+        tok_emd = self.token_embed(x) ##(B,T,C)
+        position_ids = torch.arange(T, device = self.device )
+        pos_emd = self.pos_embed(position_ids)
 
 
         x = tok_emd + pos_emd
@@ -154,6 +247,7 @@ class my_gpt(nn.Module):
         for _ in range(max_new_tokens):
             ##Take only last allowed number of tokens
            idx_tokens = context[:, -self.block_size:]
+            # print(f"idx tokens {idx_tokens.shape}")
 
             ##generate the next token
             logits, loss = self(idx_tokens)
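As a quick sanity check on the new GQA attention path, a minimal smoke-test sketch against the MultiHeadAttention class added above (shapes assume the module-level config in my_gpt.py: n_embed = 384, n_head = 6, kv_heads = 3, block_size = 128, and that Head's constructor is unchanged; the snippet is illustrative, not part of the commit):

import torch
from my_gpt import MultiHeadAttention

# 6 query heads sharing 3 KV heads, head_dim = n_embed // n_head = 384 // 6 = 64
attn = MultiHeadAttention(num_heads=6, head_dim=64)

x = torch.randn(2, 128, 384)                    # (bsz, qlen, n_embed)
causal_mask = torch.tril(torch.ones(128, 128))  # 1 = may attend, 0 = masked, as in decoder_block

out = attn(x, attn_mask=causal_mask)
print(out.shape)  # expected: torch.Size([2, 128, 384])

With kv_heads = 3, Wk and Wv project to 192 features instead of 384, roughly halving the key/value projection width relative to a full multi-head layout, and repeat_kv expands the 3 KV heads back to 6 only at attention time.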