pt-sk committed
Commit
28dc58b
1 Parent(s): a969f99

Upload 7 files

Files changed (7)
  1. bert_dataset.py +115 -0
  2. bert_model.py +250 -0
  3. data.py +30 -0
  4. optimizer_schedule.py +33 -0
  5. tokenizer.py +59 -0
  6. train.ipynb +71 -0
  7. trainer.py +106 -0
bert_dataset.py ADDED
@@ -0,0 +1,115 @@
import torch
import random
import itertools
from torch.utils.data import Dataset
from transformers import BertTokenizer
from data import get_data

tokenizer = BertTokenizer.from_pretrained("bert-it-1/bert-it-vocab.txt")


class BERTDataset(Dataset):

    def __init__(self, tokenizer: BertTokenizer = tokenizer, data_pair: list = get_data('datasets/movie_conversations.txt', "datasets/movie_lines.txt"), seq_len: int = 128) -> None:
        super().__init__()

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines

    def __getitem__(self, item):

        # Step 1: get a sentence pair, either negative or positive (saved as is_next_label)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: replace random words in each sentence with mask / random tokens
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3: add CLS and SEP tokens to the start and end of the sentences,
        # and PAD tokens at the corresponding positions in the labels
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4: combine sentence 1 and 2 into one input,
        # then pad with PAD tokens up to seq_len
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, sentence):
        tokens = sentence.split()
        output_label = []
        output = []

        # 15% of the tokens are replaced
        for token in tokens:
            prob = random.random()

            # tokenize the word and drop the surrounding [CLS] and [SEP] ids
            token_id = self.tokenizer(token)['input_ids'][1:-1]

            if prob < 0.15:
                prob /= 0.15

                # 80% chance: replace the token with the mask token
                if prob < 0.8:
                    for _ in range(len(token_id)):
                        output.append(self.tokenizer.vocab['[MASK]'])

                # 10% chance: replace the token with a random token
                elif prob < 0.9:
                    for _ in range(len(token_id)):
                        output.append(random.randrange(len(self.tokenizer.vocab)))

                # 10% chance: keep the current token
                else:
                    output.append(token_id)

                # the label keeps the original token id so the model has to predict it
                output_label.append(token_id)

            else:
                output.append(token_id)
                for _ in range(len(token_id)):
                    output_label.append(0)

        # flatten the mix of ints and sub-token id lists
        output = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output]))
        output_label = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output_label]))
        assert len(output) == len(output_label)
        return output, output_label

    def get_sent(self, index):
        '''return a sentence pair, randomly positive or negative'''
        t1, t2 = self.get_corpus_line(index)

        # negative or positive pair, for next sentence prediction
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        '''return the sentence pair at the given index'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return a random single sentence'''
        return self.lines[random.randrange(len(self.lines))][1]
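
A quick way to sanity-check the dataset output (a minimal sketch, assuming the vocab file and the two corpus files referenced above already exist at those paths):

from bert_dataset import BERTDataset

ds = BERTDataset(seq_len=128)
sample = ds[0]
# every field is padded/truncated to seq_len; is_next is a scalar 0/1 flag
print(sample["bert_input"].shape)     # torch.Size([128])
print(sample["bert_label"].shape)     # torch.Size([128])
print(sample["segment_label"].shape)  # torch.Size([128])
print(sample["is_next"].item())       # 1 (real next sentence) or 0 (random sentence)
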
bert_model.py ADDED
@@ -0,0 +1,250 @@
import torch
import torch.nn.functional as F
import math


class PositionalEmbedding(torch.nn.Module):

    def __init__(self, d_model, max_len=128):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()

        for pos in range(max_len):
            # for each dimension of each position
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

        # add the batch dimension and register as a buffer: it is saved with the
        # model and moved to the right device, but never updated during training
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # (1, seq_len, d_model), truncated to the input length
        return self.pe[:, :x.size(1)]


class BERTEmbedding(torch.nn.Module):
    """
    BERT embedding, built from three components:
      1. TokenEmbedding : normal embedding matrix
      2. PositionalEmbedding : adds positional information using sin, cos
      3. SegmentEmbedding : adds sentence segment info (sent_A: 1, sent_B: 2)
    The sum of these three is the output of BERTEmbedding.
    """

    def __init__(self, vocab_size, embed_size, seq_len=128, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param seq_len: maximum sequence length (must match the dataset seq_len)
        :param dropout: dropout rate
        """

        super().__init__()
        self.embed_size = embed_size
        # (m, seq_len) --> (m, seq_len, embed_size)
        # padding_idx is not updated during training, remains as fixed pad (0)
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)


### attention layers
class MultiHeadedAttention(torch.nn.Module):

    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, max_len, max_len)
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # fill the masked (pad) positions with a very large negative number
        # so they get ~zero weight after the softmax
        # (batch_size, h, max_len, max_len)
        scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, h, max_len, max_len)
        # softmax puts the attention weight on all non-pad tokens
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)


class FeedForward(torch.nn.Module):
    "Implements the position-wise feed-forward network."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out


class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
    ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, max_len, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual connection + layer norm
        interacted = self.layernorm(interacted + embeddings)
        # position-wise feed-forward
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded


class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, dropout=0.1):
        """
        :param vocab_size: total number of words in the vocab
        :param d_model: BERT model hidden size
        :param n_layers: number of Transformer blocks (layers)
        :param heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # the paper uses 4 * hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model)

        # multi-layer transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded tokens
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embed the indexed sequence into a sequence of vectors
        x = self.embedding(x, segment_info)

        # run over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x


class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token, which is the [CLS]
        return self.softmax(self.linear(x[:, 0]))


class MaskedLanguageModel(torch.nn.Module):
    """
    predicting the original token from the masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of the BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))


class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model that should be trained
        :param vocab_size: total vocab size for the masked LM head
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)
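
A shape check of the full model with dummy tensors (a sketch; the vocab size of 30000 and the batch/sequence sizes below are illustrative placeholders, not values from this repo):

import torch
from bert_model import BERT, BERTLM

bert = BERT(vocab_size=30000, d_model=768, n_layers=2, heads=12, dropout=0.1)
model = BERTLM(bert, vocab_size=30000)
tokens = torch.randint(1, 30000, (4, 128))   # (batch, seq_len); id 0 is reserved for [PAD]
segments = torch.randint(1, 3, (4, 128))     # 1 for sentence A, 2 for sentence B
nsp_log_probs, mlm_log_probs = model(tokens, segments)
print(nsp_log_probs.shape)   # torch.Size([4, 2])
print(mlm_log_probs.shape)   # torch.Size([4, 128, 30000])
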
data.py ADDED
@@ -0,0 +1,30 @@
from ast import literal_eval


def get_data(conversations: str, movie_lines: str, max_len: int = 64) -> list:
    '''Build (utterance, reply) pairs from the movie dialog corpus files.'''

    with open(conversations, 'r', encoding='iso-8859-1') as c:
        conv = c.readlines()
    with open(movie_lines, 'r', encoding='iso-8859-1') as l:
        lines = l.readlines()

    ### split each line on the special separator: the first field is the line id,
    ### the last field is the utterance text
    lines_dic = {}
    for line in lines:
        objects = line.split(" +++$+++ ")
        lines_dic[objects[0]] = objects[-1]

    ### generate question-answer pairs from consecutive lines of a conversation
    pairs = []
    for con in conv:
        # the last field of a conversation line is a Python-style list of line ids
        ids = literal_eval(con.split(" +++$+++ ")[-1])
        for i in range(len(ids) - 1):
            qa_pairs = []

            first = lines_dic[ids[i]].strip()
            second = lines_dic[ids[i + 1]].strip()

            qa_pairs.append(' '.join(first.split()[:max_len]))
            qa_pairs.append(' '.join(second.split()[:max_len]))
            pairs.append(qa_pairs)
    return pairs
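
Each returned pair is simply two consecutive utterances of the same conversation, whitespace-split and truncated to max_len words. A usage sketch (assuming the corpus files sit in a local datasets/ folder, as in the other scripts):

from data import get_data

pairs = get_data('datasets/movie_conversations.txt', 'datasets/movie_lines.txt')
print(len(pairs))   # number of (utterance, reply) pairs
print(pairs[0])     # ['first utterance ...', 'its reply ...']
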
optimizer_schedule.py ADDED
@@ -0,0 +1,33 @@
import numpy as np


class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
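
This is the warmup-then-decay ("Noam") schedule from the Transformer paper: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), i.e. the learning rate rises for n_warmup_steps and then decays as 1/sqrt(step). A standalone illustration of the curve (the step values below are illustrative):

import numpy as np

def noam_lr(step, d_model=768, warmup=10000):
    return np.power(d_model, -0.5) * min(np.power(step, -0.5), step * np.power(warmup, -1.5))

for step in (1, 100, 1000, 10000, 100000):
    print(step, noam_lr(step))
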
tokenizer.py ADDED
@@ -0,0 +1,59 @@
import os
from pathlib import Path
from tokenizers import BertWordPieceTokenizer
import tqdm

from data import get_data


pairs = get_data('datasets/movie_conversations.txt', "datasets/movie_lines.txt")

# WordPiece tokenizer

### save the data as txt files, 10K sentences per file
os.makedirs('data', exist_ok=True)
text_data = []
file_count = 0

for sample in tqdm.tqdm([x[0] for x in pairs]):
    text_data.append(sample)

    # once we hit the 10K mark, save to file
    if len(text_data) == 10000:
        with open(f'data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))
        text_data = []
        file_count += 1

# write whatever is left over (fewer than 10K sentences)
if text_data:
    with open(f'data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(text_data))

paths = [str(x) for x in Path('data').glob('**/*.txt')]


### train our own tokenizer
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True
)

tokenizer.train(
    files=paths,
    min_frequency=5,
    limit_alphabet=1000,
    wordpieces_prefix="##",
    special_tokens=["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
)

os.makedirs("bert-it-1", exist_ok=True)
tokenizer.save_model("bert-it-1", "bert-it")
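
After training, the saved vocab can be loaded back with the Hugging Face BertTokenizer, which is exactly what bert_dataset.py does (a minimal sketch; the sentence is just an example):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-it-1/bert-it-vocab.txt")
enc = tokenizer("are you crazy?")
print(enc["input_ids"])                                   # WordPiece ids, wrapped in [CLS] ... [SEP]
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))  # the corresponding tokens
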
train.ipynb ADDED
@@ -0,0 +1,71 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_dataset import BERTDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from bert_model import BERT, BERTLM\n",
    "from trainer import BERTTrainer\n",
    "from transformers import BertTokenizer\n",
    "from data import get_data\n",
    "\n",
    "MAX_LEN = 128\n",
    "\n",
    "pairs = get_data('datasets/movie_conversations.txt', \"datasets/movie_lines.txt\")\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-it-1/bert-it-vocab.txt\")\n",
    "\n",
    "train_data = BERTDataset()\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    train_data, batch_size=32, shuffle=True, pin_memory=True)\n",
    "\n",
    "bert_model = BERT(\n",
    "    vocab_size=len(tokenizer.vocab),\n",
    "    d_model=768,\n",
    "    n_layers=2,\n",
    "    heads=12,\n",
    "    dropout=0.1\n",
    ")\n",
    "\n",
    "bert_lm = BERTLM(bert=bert_model, vocab_size=len(tokenizer.vocab))\n",
    "bert_trainer = BERTTrainer(bert_lm, train_loader, device='cpu')\n",
    "epochs = 20\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    bert_trainer.train(epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
trainer.py ADDED
@@ -0,0 +1,106 @@
import torch
import tqdm
from torch.optim import Adam

from optimizer_schedule import ScheduledOptim


class BERTTrainer:
    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr=1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
    ):

        self.device = device
        # move the model to the same device the batches will be sent to
        self.model = model.to(device)
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # set up the Adam optimizer with its hyper-parameters
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
        )

        # negative log likelihood loss for predicting the masked tokens;
        # ignore_index=0 skips the [PAD] label positions
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            # 0. send the batch to the device (GPU or CPU)
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model(data["bert_input"], data["segment_label"])

            # 2-1. NLL (negative log likelihood) loss of the is_next classification result
            next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting the masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # equivalently: criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. add next_loss and mask_loss : section 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward pass and optimization, only in train mode
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        )
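
For reference on the loss shapes in iteration(): NLLLoss expects the class dimension in position 1, so the (batch, seq_len, vocab_size) log-probabilities are transposed before being compared against the (batch, seq_len) targets, and ignore_index=0 skips every position whose label is the [PAD]/0 id, i.e. every token that was not masked. A minimal sketch with made-up sizes:

import torch

criterion = torch.nn.NLLLoss(ignore_index=0)
log_probs = torch.log_softmax(torch.randn(4, 128, 30000), dim=-1)  # (batch, seq_len, vocab)
targets = torch.randint(0, 30000, (4, 128))                        # 0 wherever the token was not masked
loss = criterion(log_probs.transpose(1, 2), targets)               # class dim must be dim 1
print(loss.item())
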