simpx commited on
Commit
c2f214d
·
verified ·
1 Parent(s): 0c5ac87

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +144 -0
model.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from transformers import PreTrainedModel, PretrainedConfig
6
+ import os
7
+
8
+ # 超参数
9
+ batch_size = 64 # 一批包含的文本序列个数
10
+ block_size = 256 # 一个文本序列包含的字符个数
11
+ n_embed = 384 # embedding维度
12
+ n_head = 6
13
+ n_layer = 6
14
+ dropout = 0.2
15
+
16
+ # 准备词汇表
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ with open(os.path.join(current_dir, 'input.txt'), 'r', encoding='utf-8') as f:
19
+ text = f.read()
20
+
21
+ chars = sorted(list(set(text)))
22
+ vocab_size = len(chars)
23
+
24
+ # decode、encode函数,在序号和字符间转换
25
+ stoi = { ch:i for i,ch in enumerate(chars) }
26
+ itos = { i:ch for i,ch in enumerate(chars) }
27
+ encode = lambda s: [stoi[c] for c in s]
28
+ decode = lambda l: ''.join([itos[i] for i in l])
29
+
30
+ class NoobConfig(PretrainedConfig):
31
+ model_type = "Noob"
32
+ vocab_size = vocab_size
33
+ n_positions = block_size
34
+ n_embd = n_embed
35
+ n_layer = n_layer
36
+ n_head = n_head
37
+
38
+ class Head(nn.Module):
39
+ """ one head of self-attention """
40
+ def __init__(self, head_size):
41
+ super().__init__()
42
+ self.key = nn.Linear(n_embed, head_size, bias=False)
43
+ self.query = nn.Linear(n_embed, head_size, bias=False)
44
+ self.value = nn.Linear(n_embed, head_size, bias=False)
45
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
46
+ self.dropout = nn.Dropout(dropout)
47
+
48
+ def forward(self, x):
49
+ B, T, C = x.shape
50
+ k = self.key(x)
51
+ q = self.query(x)
52
+ wei = q @ k.transpose(-2, -1) * C**-0.5
53
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
54
+ wei = F.softmax(wei, dim=-1)
55
+ wei = self.dropout(wei)
56
+ v = self.value(x)
57
+ out = wei @ v
58
+ return out
59
+
60
+ class MultiHeadAttention(nn.Module):
61
+ def __init__(self, num_heads, head_size):
62
+ super().__init__()
63
+ self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
64
+ self.proj = nn.Linear(n_embed, n_embed)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ def forward(self, x):
68
+ out = torch.cat([h(x) for h in self.heads], dim=-1)
69
+ out = self.proj(out)
70
+ out = self.dropout(out)
71
+ return out
72
+
73
+ class FeedFoward(nn.Module):
74
+ def __init__(self, n_embed):
75
+ super().__init__()
76
+ self.net = nn.Sequential(
77
+ nn.Linear(n_embed, 4 * n_embed),
78
+ nn.ReLU(),
79
+ nn.Linear(4 * n_embed, n_embed),
80
+ nn.Dropout(dropout),
81
+ )
82
+
83
+ def forward(self, x):
84
+ return self.net(x)
85
+
86
+ class Block(nn.Module):
87
+ """ transformer block: communication followed by computation """
88
+ def __init__(self, n_embed, n_head):
89
+ super().__init__()
90
+ head_size = n_embed // n_head
91
+ self.sa = MultiHeadAttention(n_head, head_size)
92
+ self.ffwd = FeedFoward(n_embed)
93
+ self.ln1 = nn.LayerNorm(n_embed)
94
+ self.ln2 = nn.LayerNorm(n_embed)
95
+
96
+ def forward(self, x):
97
+ x = x + self.sa(self.ln1(x))
98
+ x = x + self.ffwd(self.ln2(x))
99
+ return x
100
+
101
+ class Noob(PreTrainedModel):
102
+ config_class = NoobConfig
103
+
104
+ def __init__(self, config):
105
+ super().__init__(config)
106
+ self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
107
+ self.position_embedding_table = nn.Embedding(config.n_positions, config.n_embd)
108
+ self.blocks = nn.Sequential(*[Block(config.n_embd, config.n_head) for _ in range(config.n_layer)])
109
+ self.ln_final = nn.LayerNorm(config.n_embd)
110
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
111
+
112
+ def forward(self, idx, targets=None):
113
+ B, T = idx.shape
114
+ tok_emb = self.token_embedding_table(idx)
115
+ pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
116
+ x = tok_emb + pos_emb
117
+ x = self.blocks(x)
118
+ x = self.ln_final(x)
119
+ logits = self.lm_head(x)
120
+
121
+ if targets is None:
122
+ loss = None
123
+ else:
124
+ B, T, C = logits.shape
125
+ logits = logits.view(B*T, C)
126
+ targets = targets.view(B*T)
127
+ loss = F.cross_entropy(logits, targets)
128
+
129
+ return logits, loss
130
+
131
+ def generate(self, idx, max_new_tokens):
132
+ for _ in range(max_new_tokens):
133
+ idx_cond = idx[:, -block_size:]
134
+ logits, _ = self(idx_cond)
135
+ logits = logits[:, -1, :]
136
+ probs = F.softmax(logits, dim=-1)
137
+ idx_next = torch.multinomial(probs, num_samples=1)
138
+ idx = torch.cat((idx, idx_next), dim=1)
139
+ return idx
140
+
141
+ def save_pretrained(self, save_directory, **kwargs):
142
+ super().save_pretrained(save_directory, **kwargs)
143
+ with open(f"{save_directory}/vocab.json", "w") as f:
144
+ json.dump(stoi, f)