Yi3852 committed
Commit 6669548 · verified · 1 Parent(s): f6f01fa
Files changed (2)
  1. model.pt +3 -0
  2. model.py +159 -0
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52e34a24fbae8f947251a220055b37575ce6096c6be7f3fc908f9db8afa674f8
+ size 1084213418
model.py ADDED
@@ -0,0 +1,159 @@
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         # different from the original GPT-2 MLP: the hidden dimension is expanded to 5 * n_embd
+         self.c_fc = nn.Linear(config.n_embd, 5 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(5 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024 # max sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+     n_layer: int = 12 # number of layers
+     n_head: int = 12 # number of heads
+     n_embd: int = 768 # embedding dimension
+
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # weight sharing scheme
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # init params
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                 # scale down residual projections by 1/sqrt(2 * n_layer)
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # forward the token and position embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x) # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     def generate(self, input_ids, topk=50, max_length=100):
+         self.eval()
+         device = input_ids.device
+         sample_rng = torch.Generator(device=device)
+         xgen = input_ids
+         while xgen.size(1) < max_length:
+             # forward the model to get the logits
+             with torch.no_grad():
+                 logits, _ = self(xgen) # (B, T, vocab_size)
+             # take the logits at the last position
+             logits = logits[:, -1, :] # (B, vocab_size)
+             # get the probabilities
+             probs = F.softmax(logits, dim=-1)
+             topk_probs, topk_indices = torch.topk(probs, topk, dim=-1)
+             # select a token from the top-k probabilities
+             # note: multinomial does not demand the input to sum to 1
+             ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
+             # gather the corresponding indices
+             xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
+             # append to the sequence
+             xgen = torch.cat((xgen, xcol), dim=1)
+         # collect the generated token ids for every sequence in the batch
+         sequences = [xgen[i, :max_length].tolist() for i in range(xgen.size(0))]
+         # return a single list for a single prompt, otherwise one list per prompt
+         return sequences[0] if len(sequences) == 1 else sequences
+
+     @classmethod
+     def from_pretrained(cls, path):
+         # expects a checkpoint saved as {'config': GPTConfig(...), 'model': state_dict}
+         ckpt = torch.load(path, map_location='cpu')
+         model = cls(ckpt['config'])
+         model.load_state_dict(ckpt['model'])
+         return model
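For context, a minimal usage sketch (not part of the commit): it assumes model.pt above is a checkpoint in the {'config', 'model'} layout that from_pretrained reads, and it assumes prompts are tokenized with the GPT-2 tiktoken encoding; the tokenizer choice and the prompt text are assumptions, since no tokenizer ships with this commit.

    import torch
    import tiktoken  # assumed tokenizer; not included in this commit
    from model import GPT

    # load the LFS-tracked checkpoint added in this commit
    model = GPT.from_pretrained("model.pt")

    # encode a prompt with the GPT-2 BPE vocabulary (assumption about how the model was trained)
    enc = tiktoken.get_encoding("gpt2")
    input_ids = torch.tensor([enc.encode("Hello, I'm a language model,")], dtype=torch.long)

    # sample up to 40 tokens with top-k sampling and decode the result
    tokens = model.generate(input_ids, topk=50, max_length=40)
    print(enc.decode(tokens))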