FZH1996 committed on
Commit
e7d695a
1 Parent(s): fe45bc3

update fed-lora

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. examples/NLG/eval/GenerationEval/bleurt +1 -0
  2. examples/NLG/eval/GenerationEval/metrics/bleurt +1 -0
  3. examples/NLG/eval/e2e/metrics/__pycache__/__init__.cpython-37.pyc +0 -0
  4. examples/NLG/eval/e2e/metrics/__pycache__/pymteval.cpython-37.pyc +0 -0
  5. examples/NLG/eval/e2e/pycocoevalcap/__pycache__/__init__.cpython-37.pyc +0 -0
  6. examples/NLG/eval/e2e/pycocoevalcap/__pycache__/eval.cpython-37.pyc +0 -0
  7. examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/__init__.cpython-37.pyc +0 -0
  8. examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu.cpython-37.pyc +0 -0
  9. examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-37.pyc +0 -0
  10. examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/__init__.cpython-37.pyc +0 -0
  11. examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider.cpython-37.pyc +0 -0
  12. examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider_scorer.cpython-37.pyc +0 -0
  13. examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/__init__.cpython-37.pyc +0 -0
  14. examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/meteor.cpython-37.pyc +0 -0
  15. examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/__init__.cpython-37.pyc +0 -0
  16. examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/rouge.cpython-37.pyc +0 -0
  17. examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/__init__.cpython-37.pyc +0 -0
  18. examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/ptbtokenizer.cpython-37.pyc +0 -0
  19. examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-36.pyc +0 -0
  20. examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-37.pyc +0 -0
  21. examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-36.pyc +0 -0
  22. examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-37.pyc +0 -0
  23. examples/NLG/src/.DS_Store +0 -0
  24. examples/NLG/src/__pycache__/data_utils.cpython-310.pyc +0 -0
  25. examples/NLG/src/__pycache__/data_utils.cpython-36.pyc +0 -0
  26. examples/NLG/src/__pycache__/data_utils.cpython-37.pyc +0 -0
  27. examples/NLG/src/__pycache__/encoder.cpython-37.pyc +0 -0
  28. examples/NLG/src/__pycache__/exp_utils.cpython-310.pyc +0 -0
  29. examples/NLG/src/__pycache__/exp_utils.cpython-37.pyc +0 -0
  30. examples/NLG/src/__pycache__/gpu.cpython-310.pyc +0 -0
  31. examples/NLG/src/__pycache__/gpu.cpython-36.pyc +0 -0
  32. examples/NLG/src/__pycache__/gpu.cpython-37.pyc +0 -0
  33. examples/NLG/src/__pycache__/model.cpython-310.pyc +0 -0
  34. examples/NLG/src/__pycache__/model.cpython-36.pyc +0 -0
  35. examples/NLG/src/__pycache__/model.cpython-37.pyc +0 -0
  36. examples/NLG/src/__pycache__/optimizer.cpython-36.pyc +0 -0
  37. examples/NLG/src/__pycache__/optimizer.cpython-37.pyc +0 -0
  38. examples/NLG/src/data_utils.py +282 -0
  39. examples/NLG/src/encoder.py +132 -0
  40. examples/NLG/src/exp_utils.py +46 -0
  41. examples/NLG/src/format_converting_dart.py +43 -0
  42. examples/NLG/src/format_converting_e2e.py +20 -0
  43. examples/NLG/src/format_converting_webnlg.py +68 -0
  44. examples/NLG/src/gpt2_beam.py +419 -0
  45. examples/NLG/src/gpt2_decode.py +187 -0
  46. examples/NLG/src/gpt2_encode.py +70 -0
  47. examples/NLG/src/gpt2_ft.py +385 -0
  48. examples/NLG/src/gpu.py +129 -0
  49. examples/NLG/src/model.log +698 -0
  50. examples/NLG/src/model.py +460 -0
examples/NLG/eval/GenerationEval/bleurt ADDED
@@ -0,0 +1 @@
1
+ Subproject commit cebe7e6f996b40910cfaa520a63db47807e3bf5c
examples/NLG/eval/GenerationEval/metrics/bleurt ADDED
@@ -0,0 +1 @@
1
+ Subproject commit cebe7e6f996b40910cfaa520a63db47807e3bf5c
examples/NLG/eval/e2e/metrics/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (169 Bytes). View file
 
examples/NLG/eval/e2e/metrics/__pycache__/pymteval.cpython-37.pyc ADDED
Binary file (12.9 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (195 Bytes). View file
 
examples/NLG/eval/e2e/pycocoevalcap/__pycache__/eval.cpython-37.pyc ADDED
Binary file (2.57 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (200 Bytes). View file
 
examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu.cpython-37.pyc ADDED
Binary file (1.24 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-37.pyc ADDED
Binary file (8.07 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (201 Bytes). View file
 
examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider.cpython-37.pyc ADDED
Binary file (1.67 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/cider/__pycache__/cider_scorer.cpython-37.pyc ADDED
Binary file (7.85 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (202 Bytes). View file
 
examples/NLG/eval/e2e/pycocoevalcap/meteor/__pycache__/meteor.cpython-37.pyc ADDED
Binary file (2.75 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (203 Bytes). View file
 
examples/NLG/eval/e2e/pycocoevalcap/rouge/__pycache__/rouge.cpython-37.pyc ADDED
Binary file (3.75 kB). View file
 
examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (205 Bytes). View file
 
examples/NLG/eval/e2e/pycocoevalcap/tokenizer/__pycache__/ptbtokenizer.cpython-37.pyc ADDED
Binary file (2.18 kB). View file
 
examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (189 Bytes). View file
 
examples/NLG/eval/e2e/pycocotools/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (193 Bytes). View file
 
examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-36.pyc ADDED
Binary file (13.4 kB). View file
 
examples/NLG/eval/e2e/pycocotools/__pycache__/coco.cpython-37.pyc ADDED
Binary file (13.4 kB). View file
 
examples/NLG/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
examples/NLG/src/__pycache__/data_utils.cpython-310.pyc ADDED
Binary file (8.49 kB). View file
 
examples/NLG/src/__pycache__/data_utils.cpython-36.pyc ADDED
Binary file (8.58 kB). View file
 
examples/NLG/src/__pycache__/data_utils.cpython-37.pyc ADDED
Binary file (8.58 kB). View file
 
examples/NLG/src/__pycache__/encoder.cpython-37.pyc ADDED
Binary file (5.1 kB). View file
 
examples/NLG/src/__pycache__/exp_utils.cpython-310.pyc ADDED
Binary file (1.49 kB). View file
 
examples/NLG/src/__pycache__/exp_utils.cpython-37.pyc ADDED
Binary file (1.44 kB). View file
 
examples/NLG/src/__pycache__/gpu.cpython-310.pyc ADDED
Binary file (3.58 kB). View file
 
examples/NLG/src/__pycache__/gpu.cpython-36.pyc ADDED
Binary file (3.53 kB). View file
 
examples/NLG/src/__pycache__/gpu.cpython-37.pyc ADDED
Binary file (3.54 kB). View file
 
examples/NLG/src/__pycache__/model.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
examples/NLG/src/__pycache__/model.cpython-36.pyc ADDED
Binary file (13.7 kB). View file
 
examples/NLG/src/__pycache__/model.cpython-37.pyc ADDED
Binary file (13.5 kB). View file
 
examples/NLG/src/__pycache__/optimizer.cpython-36.pyc ADDED
Binary file (11.4 kB). View file
 
examples/NLG/src/__pycache__/optimizer.cpython-37.pyc ADDED
Binary file (11.4 kB). View file
 
examples/NLG/src/data_utils.py ADDED
@@ -0,0 +1,282 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import os, sys
6
+ import glob
7
+ import random
8
+ from collections import Counter, OrderedDict
9
+ import numpy as np
10
+ import torch
11
+ import json
12
+
13
+ import torch
14
+ from torch.utils.data import Dataset
15
+ from torch.utils.data import DataLoader
16
+
17
+
18
+ class LMOrderedIterator(object):
19
+ def __init__(self, data, bsz, bptt, eval_len=None, device='cpu', world_size=1, rank=0):
20
+ """
21
+ data -- LongTensor -- the LongTensor is strictly ordered
22
+ """
23
+ self.data = data
24
+ self.bsz = bsz
25
+ self.world_size = world_size
26
+ self.rank = rank
27
+ self.bptt = bptt # tgt_len
28
+ # existing len.
29
+ self.eval_len = bptt if eval_len is None else eval_len
30
+
31
+ self.device = device
32
+
33
+ self.global_bsz = bsz * world_size
34
+ # Work out how cleanly we can divide the dataset into bsz parts.
35
+ self.n_step = len(data) // self.global_bsz # bsz
36
+
37
+ self.split_data = torch.tensor(
38
+ data[rank * self.n_step * bsz : (rank + 1) * self.n_step * bsz],
39
+ dtype=torch.long, device=self.device
40
+ ) # data.view(-1)
41
+
42
+ self.split_data = self.split_data.view(bsz, -1)
43
+
44
+ def __iter__(self):
45
+ return self.get_fixlen_iter()
46
+
47
+ def get_batch(self, i, bptt, eval_len):
48
+ beg_idx = i
49
+ end_idx = i + bptt # seq_len
50
+
51
+ # batch_size, length
52
+ _input = self.split_data[:, beg_idx : end_idx].contiguous()
53
+ _target = self.split_data[:, beg_idx+1 : end_idx+1].contiguous()
54
+
55
+ _msk = torch.cat(
56
+ [
57
+ torch.zeros(bptt-eval_len, dtype=torch.float, device=self.device),
58
+ torch.ones(eval_len, dtype=torch.float, device=self.device)
59
+ ]
60
+ )
61
+ _msk = _msk.unsqueeze(0).expand_as(_input) # .unsqueeze(-1) # length, 1;
62
+ return _input, _target, _msk
63
+
64
+ def get_fixlen_iter(self, start=0):
65
+ self.data_len = self.split_data.size(1)
66
+ _eval_cursor = 0
67
+ for i in range(start, self.data_len - 1, self.eval_len):
68
+ bptt = min(self.bptt, self.data_len - i - 1)
69
+ _end_idx = i + bptt
70
+ yield self.get_batch(i, bptt, _end_idx - _eval_cursor)
71
+ _eval_cursor = _end_idx
72
+
73
+
74
+ class Corpus(object):
75
+ def __init__(self, path):
76
+ self.path = path
77
+ self.num_words = 0
78
+ self.tokens = []
79
+ with open(self.path, "r") as reader:
80
+ for line in reader:
81
+ items = json.loads(line.strip())
82
+ book = items['book']
83
+ tokens = items['tokens']
84
+ num_words = items['num_words']
85
+
86
+ self.num_words += num_words
87
+ self.tokens.extend(tokens)
88
+
89
+
90
+ class BinLMOrderedIterator(object):
91
+ def __init__(self, corpus, bsz, bptt, eval_len=None, device='cpu', world_size=1, rank=0):
92
+ """
93
+ data -- LongTensor -- the LongTensor is strictly ordered
94
+ """
95
+ self.corpus = corpus
96
+ self.bsz = bsz
97
+ self.world_size = world_size
98
+ self.rank = rank
99
+ self.bptt = bptt # tgt_len
100
+ # existing len.
101
+ self.eval_len = bptt if eval_len is None else eval_len
102
+ self.device = device
103
+ self.global_bsz = bsz * world_size
104
+ # Work out how cleanly we can divide the dataset into bsz parts.
105
+ self.n_step = corpus.length // self.global_bsz # bsz
106
+
107
+ self.offset = [(rank * bsz + _b) * self.n_step for _b in range(bsz)]
108
+
109
+ def __iter__(self):
110
+ return self.get_fixlen_iter()
111
+
112
+ def get_batch(self, i, bptt, eval_len):
113
+ # batch_size, length
114
+ _inputs = []
115
+ _targets = []
116
+ for _b in range(0, self.bsz):
117
+ _input = self.corpus.get_tokens(self.offset[_b] + i, bptt)
118
+ _target = self.corpus.get_tokens(self.offset[_b] + i + 1, bptt)
119
+
120
+ _inputs.append(_input)
121
+ _targets.append(_target)
122
+
123
+ _input = torch.tensor(_inputs, dtype=torch.int64, device=self.device).contiguous()
124
+ _target = torch.tensor(_targets, dtype=torch.int64, device=self.device).contiguous()
125
+
126
+ _msk = torch.cat(
127
+ [
128
+ torch.zeros(bptt-eval_len, dtype=torch.float, device=self.device),
129
+ torch.ones(eval_len, dtype=torch.float, device=self.device)
130
+ ]
131
+ )
132
+ _msk = _msk.unsqueeze(0).expand_as(_input) # .unsqueeze(-1) # length, 1;
133
+ return _input, _target, _msk
134
+
135
+ def get_fixlen_iter(self, start=0):
136
+ #self.data_len = self.split_data.size(1)
137
+ _eval_cursor = 0
138
+ for i in range(start, self.n_step - 1, self.eval_len):
139
+ bptt = min(self.bptt, self.n_step - i - 1)
140
+ _end_idx = i + bptt
141
+ yield self.get_batch(i, bptt, _end_idx - _eval_cursor)
142
+ _eval_cursor = _end_idx
143
+
144
+
145
+ class BinCorpus(object):
146
+ def __init__(self, path):
147
+ self.path = path
148
+
149
+ self.book_token_span = []
150
+ self.book_token_span.append(0)
151
+ tokens_sum = 0
152
+ self.num_words = 0
153
+
154
+ with open(path+'.info', 'r') as info_reader:
155
+ for line in info_reader:
156
+ items = json.loads(line.strip())
157
+ book = items['book']
158
+ num_tokens = items['num_subtokens']
159
+ num_words = items['num_words']
160
+
161
+ tokens_sum += num_tokens
162
+ self.book_token_span.append(tokens_sum)
163
+ self.num_words += num_words
164
+
165
+ self.length = self.book_token_span[-1]
166
+ self.bin_reader = open(path+'.bin', 'rb')
167
+
168
+ def get_tokens(self, offset, count):
169
+ INT64_SIZE = 8
170
+ self.bin_reader.seek(offset * INT64_SIZE)
171
+ x = np.fromfile(self.bin_reader, count=count, dtype=np.int64)  # np.int alias is removed in recent NumPy; tokens are stored as 8-byte ints
172
+ return x
173
+
174
+
175
+ def get_lm_corpus(data):
176
+ print('Producing dataset {}...'.format(data))
177
+ corpus = Corpus(data)
178
+ return corpus
179
+
180
+
181
+ def padding_tokens(tokens, max_seq_length, pad_token, direct, max_context_length=0):
182
+
183
+ if max_context_length == 0:
184
+ max_context_length = max_seq_length
185
+
186
+ if len(tokens) > max_context_length:
187
+ if direct > 0:
188
+ pad_tokens = tokens[:max_context_length]
189
+ else:
190
+ pad_tokens = tokens[-max_context_length:]
191
+ else:
192
+ pad_tokens = tokens
193
+ token_len = len(pad_tokens)
194
+ pad_tokens = pad_tokens + [pad_token for _ in range(max_seq_length - token_len)]
195
+ return pad_tokens, token_len
196
+
197
+
198
+ class FT_Dataset(Dataset):
199
+ def __init__(self, ft_file, batch_size, max_seq_length,
200
+ max_eval_length=0, joint_lm=False, prefix_len=0, infix_len=0,
201
+ prefix_cursor=1000000, infix_cursor=2000000):
202
+ self.ft_file = ft_file
203
+ self.ft_samples = self.read_ft_file(ft_file)
204
+ self.batch_size = batch_size
205
+ self.num_examples = len(self.ft_samples)
206
+ self.max_seq_length = max_seq_length
207
+ self.max_eval_length = max_eval_length
208
+ self.rng = random.Random(911)
209
+ self.joint_lm = joint_lm
210
+
211
+ self.num_batches = int((self.num_examples + self.batch_size - 1) / self.batch_size)
212
+
213
+ self.prefix_len = prefix_len
214
+ self.infix_len = infix_len
215
+ self.prefix_cursor = prefix_cursor
216
+ self.infix_cursor = infix_cursor
217
+
218
+ def __len__(self):
219
+ return self.num_batches * self.batch_size
220
+
221
+ def __getitem__(self, item):
222
+ if(item >= self.num_examples):
223
+ item = self.rng.randint(0, self.num_examples - 1)
224
+
225
+ example = self.ft_samples[item]
226
+ context = example[0]
227
+ completion = example[1]
228
+
229
+ pretokens = [i + self.prefix_cursor for i in range(0, self.prefix_len)]
230
+ intokens = [i + self.infix_cursor for i in range(0, self.infix_len)]
231
+
232
+ conditions = pretokens + context + intokens
233
+ _input, _input_len = padding_tokens(conditions + completion, self.max_seq_length, 0, 1)
234
+
235
+ pad_targets = [0 for i in range(0, self.prefix_len)] + context + [0 for i in range(0, self.infix_len)] + completion
236
+ _target, _ = padding_tokens(pad_targets[1:], self.max_seq_length, 0, 1)
237
+
238
+ if not self.joint_lm:
239
+ _msk = [0.0] * (len(conditions) - 1) + [1.0] * (_input_len - len(conditions))
240
+ else:
241
+ _msk = [1.0] * (_input_len - 1)
242
+
243
+ _msk, _ = padding_tokens(_msk, self.max_seq_length, 0.0, 1)
244
+
245
+ output = {}
246
+ output["id"] = torch.tensor(item, dtype=torch.long)
247
+
248
+ _query, _query_len = padding_tokens(
249
+ conditions, self.max_seq_length, 0, -1,
250
+ max_context_length = self.max_seq_length - self.max_eval_length
251
+ )
252
+ output["query"] = torch.tensor(_query, dtype=torch.long)
253
+ output["query_len"] = torch.tensor(_query_len, dtype=torch.long)
254
+
255
+ output["input"] = torch.tensor(_input, dtype=torch.long)
256
+ output["target"] = torch.tensor(_target, dtype=torch.long)
257
+
258
+ output["mask"] = torch.tensor(_msk, dtype=torch.float)
259
+ return output
260
+
261
+ def read_ft_file(self, ft_file):
262
+ ft_samples = []
263
+ with open(ft_file, 'r') as reader:
264
+ for line in reader:
265
+ items = json.loads(line.strip())
266
+ context = items['context']
267
+ completion = items['completion']
268
+ ft_samples.append([context, completion])
269
+ return ft_samples
270
+
271
+ def get_item_list(self, start, interval):
272
+ start = min(start, self.num_examples-1)
273
+ start = max(0,start)
274
+ if(start + interval >= self.num_examples):
275
+ end = self.num_examples
276
+ else:
277
+ end = start + interval
278
+ samples = []
279
+ for index in range(start, end):
280
+ output = self.__getitem__(index)
281
+ samples.append(output)
282
+ return samples
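
A minimal usage sketch for the FT_Dataset class added above (illustrative only, not part of this commit; the data path and sizes are placeholders). It reads a JSONL file of already-BPE-encoded context/completion id lists and yields padded tensors plus a float mask over the completion tokens that is used to weight the LM loss:

    from torch.utils.data import DataLoader
    from data_utils import FT_Dataset

    train_data = FT_Dataset('./data/e2e/train.jsonl', batch_size=8, max_seq_length=512)  # placeholder path
    train_loader = DataLoader(train_data, batch_size=8, shuffle=False)
    for batch in train_loader:
        # 'input'/'target' are shifted token ids; 'mask' marks the completion span
        print(batch['input'].shape, batch['target'].shape, batch['mask'].shape)
        break
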
examples/NLG/src/encoder.py ADDED
@@ -0,0 +1,132 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import os
6
+ import json
7
+ import regex as re
8
+ from functools import lru_cache
9
+
10
+
11
+ @lru_cache()
12
+ def bytes_to_unicode():
13
+ """
14
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
15
+ The reversible bpe codes work on unicode strings.
16
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
17
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
18
+ This is a significant percentage of your normal, say, 32K bpe vocab.
19
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
20
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
21
+ """
22
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
23
+ cs = bs[:]
24
+ n = 0
25
+ for b in range(2**8):
26
+ if b not in bs:
27
+ bs.append(b)
28
+ cs.append(2**8+n)
29
+ n += 1
30
+ cs = [chr(n) for n in cs]
31
+ return dict(zip(bs, cs))
32
+
33
+
34
+ def get_pairs(word):
35
+ """Return set of symbol pairs in a word.
36
+ Word is represented as tuple of symbols (symbols being variable-length strings).
37
+ """
38
+ pairs = set()
39
+ prev_char = word[0]
40
+ for char in word[1:]:
41
+ pairs.add((prev_char, char))
42
+ prev_char = char
43
+ return pairs
44
+
45
+
46
+ class Encoder:
47
+
48
+ def __init__(self, encoder, bpe_merges, errors='replace'):
49
+ self.encoder = encoder
50
+ self.decoder = {v:k for k,v in self.encoder.items()}
51
+ self.errors = errors # how to handle errors in decoding
52
+ self.byte_encoder = bytes_to_unicode()
53
+ self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
54
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
55
+ self.cache = {}
56
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
57
+ try:
58
+ import regex as re
59
+ self.re = re
60
+ except ImportError:
61
+ raise ImportError('Please install regex with: pip install regex')
62
+
63
+
64
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
65
+
66
+ def bpe(self, token):
67
+ if token in self.cache:
68
+ return self.cache[token]
69
+ word = tuple(token)
70
+ pairs = get_pairs(word)
71
+
72
+ if not pairs:
73
+ return token
74
+
75
+ while True:
76
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
77
+ if bigram not in self.bpe_ranks:
78
+ break
79
+ first, second = bigram
80
+ new_word = []
81
+ i = 0
82
+ while i < len(word):
83
+ try:
84
+ j = word.index(first, i)
85
+ new_word.extend(word[i:j])
86
+ i = j
87
+ except:
88
+ new_word.extend(word[i:])
89
+ break
90
+
91
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
92
+ new_word.append(first+second)
93
+ i += 2
94
+ else:
95
+ new_word.append(word[i])
96
+ i += 1
97
+ new_word = tuple(new_word)
98
+ word = new_word
99
+ if len(word) == 1:
100
+ break
101
+ else:
102
+ pairs = get_pairs(word)
103
+ word = ' '.join(word)
104
+ self.cache[token] = word
105
+ return word
106
+
107
+ def encode(self, text):
108
+ bpe_tokens = []
109
+ tokens = []
110
+ for token in re.findall(self.pat, text):
111
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
112
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
113
+ if token:
114
+ tokens.append(token)
115
+ return bpe_tokens, tokens
116
+
117
+ def decode(self, tokens):
118
+ text = ''.join([self.decoder[token] for token in tokens])
119
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
120
+ return text
121
+
122
+
123
+ def get_encoder(models_dir):
124
+ with open(os.path.join(models_dir, 'encoder.json'), 'r') as f:
125
+ encoder = json.load(f)
126
+ with open(os.path.join(models_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
127
+ bpe_data = f.read()
128
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
129
+ return Encoder(
130
+ encoder=encoder,
131
+ bpe_merges=bpe_merges,
132
+ )
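
A short round-trip sketch for the byte-level BPE Encoder defined above (illustrative only; 'vocab_dir' stands for any directory containing encoder.json and vocab.bpe):

    from encoder import get_encoder

    enc = get_encoder('vocab_dir')  # placeholder directory
    bpe_ids, raw_tokens = enc.encode('name : Blue Spice | food : Chinese')
    assert enc.decode(bpe_ids) == 'name : Blue Spice | food : Chinese'
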
examples/NLG/src/exp_utils.py ADDED
@@ -0,0 +1,46 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import functools
6
+ import os, shutil
7
+ import numpy as np
8
+
9
+ import torch
10
+
11
+
12
+ def logging(s, log_path, print_=True, log_=True):
13
+ if print_:
14
+ print(s)
15
+ if log_:
16
+ with open(log_path, 'a+') as f_log:
17
+ f_log.write(s + '\n')
18
+
19
+
20
+ def get_logger(log_path, **kwargs):
21
+ return functools.partial(logging, log_path=log_path, **kwargs)
22
+
23
+
24
+ def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
25
+ if debug:
26
+ print('Debug Mode : no experiment dir created')
27
+ return functools.partial(logging, log_path=None, log_=False)
28
+
29
+ if not os.path.exists(dir_path):
30
+ os.makedirs(dir_path)
31
+
32
+ print('Experiment dir : {}'.format(dir_path))
33
+ if scripts_to_save is not None:
34
+ script_path = os.path.join(dir_path, 'scripts')
35
+ if not os.path.exists(script_path):
36
+ os.makedirs(script_path)
37
+ for script in scripts_to_save:
38
+ dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
39
+ shutil.copyfile(script, dst_file)
40
+
41
+ return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
42
+
43
+
44
+ def save_checkpoint(model, optimizer, path, epoch):
45
+ torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
46
+ torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
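
The value returned by create_exp_dir above is a functools.partial around logging(), so it can be called like a function; a small sketch with placeholder paths:

    from exp_utils import create_exp_dir

    log = create_exp_dir('run_dir', scripts_to_save=['src/gpt2_ft.py'])  # placeholder paths
    log('starting fine-tuning')  # prints the message and appends it to run_dir/log.txt
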
examples/NLG/src/format_converting_dart.py ADDED
@@ -0,0 +1,43 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import sys
6
+ import io
7
+ import json
8
+
9
+
10
+ with open(sys.argv[1], 'r', encoding='utf8') as reader, \
11
+ open(sys.argv[2], 'w', encoding='utf8') as writer :
12
+ lines_dict = json.load(reader)
13
+
14
+ full_rela_lst = []
15
+ full_src_lst = []
16
+ full_tgt_lst = []
17
+ unique_src = 0
18
+
19
+ for example in lines_dict:
20
+ rela_lst = []
21
+ temp_triples = ''
22
+ for i, tripleset in enumerate(example['tripleset']):
23
+ subj, rela, obj = tripleset
24
+ rela = rela.lower()
25
+ rela_lst.append(rela)
26
+ if i > 0:
27
+ temp_triples += ' | '
28
+ temp_triples += '{} : {} : {}'.format(subj, rela, obj)
29
+
30
+ unique_src += 1
31
+
32
+ for sent in example['annotations']:
33
+ full_tgt_lst.append(sent['text'])
34
+ full_src_lst.append(temp_triples)
35
+ full_rela_lst.append(rela_lst)
36
+
37
+ print('unique source is', unique_src)
38
+
39
+ for src, tgt in zip(full_src_lst, full_tgt_lst):
40
+ x = {}
41
+ x['context'] = src # context #+ '||'
42
+ x['completion'] = tgt #completion
43
+ writer.write(json.dumps(x)+'\n')
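
For illustration (values invented, not taken from the DART data): each tripleset is flattened into a single ' | '-joined string of 'subject : relation : object' segments with the relation lowercased, and that string is written once per annotated sentence, so an output line looks roughly like:

    {"context": "Mars : moons : 2 | Mars : orbital period : 687 days", "completion": "Mars has two moons and an orbital period of 687 days."}
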
examples/NLG/src/format_converting_e2e.py ADDED
@@ -0,0 +1,20 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import sys
6
+ import io
7
+ import json
8
+
9
+
10
+ with open(sys.argv[1], 'r', encoding='utf8') as reader, \
11
+ open(sys.argv[2], 'w', encoding='utf8') as writer :
12
+ for line in reader:
13
+ items = line.strip().split('||')
14
+ context = items[0]
15
+ completion = items[1].strip('\n')
16
+ x = {}
17
+ x['context'] = context #+ '||'
18
+ x['completion'] = completion
19
+ writer.write(json.dumps(x)+'\n')
20
+
examples/NLG/src/format_converting_webnlg.py ADDED
@@ -0,0 +1,68 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import sys
6
+ import io
7
+ import json
8
+
9
+
10
+ with open(sys.argv[1], 'r', encoding='utf8') as reader, \
11
+ open(sys.argv[2], 'w', encoding='utf8') as writer :
12
+ lines_dict = json.load(reader)
13
+
14
+ full_rela_lst = []
15
+ full_src_lst = []
16
+ full_tgt_lst = []
17
+ full_cate_lst = []
18
+
19
+ seen = [
20
+ 'Airport',
21
+ 'Astronaut',
22
+ 'Building',
23
+ 'City',
24
+ 'ComicsCharacter',
25
+ 'Food',
26
+ 'Monument',
27
+ 'SportsTeam',
28
+ 'University',
29
+ 'WrittenWork'
30
+ ]
31
+
32
+ cate_dict = {}
33
+ for i, example in enumerate(lines_dict['entries']):
34
+ sents = example[str(i+1)]['lexicalisations']
35
+ triples = example[str(i + 1)]['modifiedtripleset']
36
+ cate = example[str(i + 1)]['category']
37
+
38
+ if not cate in cate_dict:
39
+ cate_dict[cate] = 0
40
+ cate_dict[cate] += 1
41
+
42
+ rela_lst = []
43
+ temp_triples = ''
44
+ for i, tripleset in enumerate(triples):
45
+ subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
46
+ rela_lst.append(rela)
47
+ if i > 0:
48
+ temp_triples += ' | '
49
+ temp_triples += '{} : {} : {}'.format(subj, rela, obj)
50
+
51
+ for sent in sents:
52
+ if sent["comment"] == 'good':
53
+ full_tgt_lst.append(sent['lex'])
54
+ full_src_lst.append(temp_triples)
55
+ full_rela_lst.append(rela_lst)
56
+ full_cate_lst.append(cate)
57
+
58
+ for cate in cate_dict:
59
+ print('cate', cate, cate_dict[cate])
60
+
61
+ #edited_sents = []
62
+ for src, tgt, cate in zip(full_src_lst, full_tgt_lst, full_cate_lst):
63
+ x = {}
64
+ x['context'] = src # context #+ '||'
65
+ x['completion'] = tgt #completion
66
+ x['cate'] = cate in seen
67
+ writer.write(json.dumps(x)+'\n')
68
+
examples/NLG/src/gpt2_beam.py ADDED
@@ -0,0 +1,419 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+
6
+ # python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_beam.py \
7
+ # --data ./data/e2e/test.jsonl \
8
+ # --batch_size 1 \
9
+ # --seq_len 512 \
10
+ # --eval_len 64 \
11
+ # --model_card gpt2.md \
12
+ # --platform local \
13
+ # --beam 10 \
14
+ # --length_penalty 0.8 \
15
+ # --no_repeat_ngram_size 4 \
16
+ # --repetition_penalty 1.0 \
17
+ # --eos_token_id 628 \
18
+ # --lora_dim 4 \
19
+ # --lora_alpha 32 \
20
+ # --work_dir ./trained_models/GPT2_M/e2e \
21
+ # --output_file predict.26290.jsonl \
22
+ # --init_checkpoint ./trained_models/GPT2_M/e2e/model.26290.pt
23
+
24
+
25
+ import argparse
26
+ import time
27
+ import math
28
+ import os, sys
29
+ import json
30
+ import itertools
31
+ from typing import Callable, Dict, Iterable, List, Optional, Tuple
32
+
33
+ import torch
34
+ from torch import Tensor, device, dtype, nn
35
+ from torch.nn import CrossEntropyLoss
36
+ from torch.nn import functional as F
37
+ from torch.utils.data import DataLoader
38
+ import torch.nn.functional as F
39
+ torch.set_printoptions(threshold=100000)
40
+
41
+ import numpy as np
42
+
43
+ from gpu import (
44
+ add_gpu_params,
45
+ parse_gpu,
46
+ distributed_opt,
47
+ distributed_gather,
48
+ distributed_sync,
49
+ cleanup
50
+ )
51
+
52
+ from exp_utils import create_exp_dir
53
+
54
+ from data_utils import FT_Dataset
55
+ from model import GPT2Config, GPT2LMModel
56
+
57
+
58
+ parser = argparse.ArgumentParser(description='PyTorch GPT2 beam decoding')
59
+
60
+ add_gpu_params(parser)
61
+
62
+ parser.add_argument('--data', type=str, default='../data/wikitext-103',
63
+ help='location of the data corpus')
64
+
65
+ parser.add_argument('--batch_size', type=int, default=10,
66
+ help='batch size')
67
+
68
+ parser.add_argument('--seq_len', type=int, default=512,
69
+ help='number of tokens to predict')
70
+
71
+ parser.add_argument('--eval_len', type=int, default=256,
72
+ help='evaluation length')
73
+
74
+ parser.add_argument('--min_length', type=int, default=0,
75
+ help='minimum generation length')
76
+
77
+ parser.add_argument('--model_card', default='gpt2.sm', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
78
+ help='model names')
79
+
80
+ parser.add_argument('--init_checkpoint', default=None, type=str, help='initial checkpoint')
81
+
82
+ parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')
83
+
84
+ parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')
85
+
86
+ parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
87
+ help='working folder')
88
+
89
+ parser.add_argument('--beam', type=int, default=1, help='beam search size')
90
+
91
+ parser.add_argument('--length_penalty', type=float, default=1.0, help='length penalty')
92
+
93
+ parser.add_argument('--no_repeat_ngram_size', type=int, default=4, help='no_repeat_ngram_size')
94
+
95
+ parser.add_argument('--repetition_penalty', type=float, default=1.0, help='repetition_penalty')
96
+
97
+ parser.add_argument('--eos_token_id', action='append', type=int, default=[50256],
98
+ help='eos token id')
99
+
100
+ parser.add_argument('--output_file', type=str, default='beam_prediction.jsonl',
101
+ help='output file name')
102
+
103
+
104
+ def print_args(args):
105
+ if args.rank == 0:
106
+ print('=' * 100)
107
+ for k, v in args.__dict__.items():
108
+ print(' - {} : {}'.format(k, v))
109
+ print('=' * 100)
110
+
111
+
112
+ def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
113
+ return tuple(layer_past.index_select(1, beam_idx).contiguous().detach() for layer_past in past)
114
+
115
+
116
+ def _calc_banned_ngram_tokens(
117
+ prev_input_ids: Tensor,
118
+ num_hypos: int,
119
+ no_repeat_ngram_size: int,
120
+ cur_len: int
121
+ ) -> None:
122
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
123
+ if cur_len + 1 < no_repeat_ngram_size:
124
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
125
+ return [[] for _ in range(num_hypos)]
126
+
127
+ generated_ngrams = [{} for _ in range(num_hypos)]
128
+ for idx in range(num_hypos):
129
+ gen_tokens = prev_input_ids[idx].tolist()
130
+ generated_ngram = generated_ngrams[idx]
131
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
132
+ prev_ngram_tuple = tuple(ngram[:-1])
133
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
134
+
135
+ def _get_generated_ngrams(hypo_idx):
136
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
137
+ start_idx = cur_len + 1 - no_repeat_ngram_size
138
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
139
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
140
+
141
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
142
+ return banned_tokens
143
+
144
+
145
+ def _enforce_repetition_penalty_(
146
+ lprobs,
147
+ batch_size,
148
+ num_beams,
149
+ prev_output_tokens,
150
+ repetition_penalty
151
+ ):
152
+ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
153
+
154
+ for i in range(batch_size * num_beams):
155
+ print('prev_output_tokens.shape', prev_output_tokens.shape)
156
+ print('prev_output_tokens[i].shape', prev_output_tokens[i].shape)
157
+
158
+ for previous_token in set(prev_output_tokens[i].tolist()):
159
+ # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
160
+ if lprobs[i, previous_token] < 0:
161
+ lprobs[i, previous_token] *= repetition_penalty
162
+ else:
163
+ lprobs[i, previous_token] /= repetition_penalty
164
+
165
+ def _postprocess_next_token_scores(
166
+ scores,
167
+ history,
168
+ cur_len,
169
+ batch_size,
170
+ num_beams,
171
+ repetition_penalty=1.0,
172
+ no_repeat_ngram_size=4,
173
+ bad_words_ids=None,
174
+ min_length=0,
175
+ max_length=100,
176
+ eos_token_id=None,
177
+ ):
178
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
179
+ if repetition_penalty != 1.0 and history is not None:
180
+ _enforce_repetition_penalty_(scores, batch_size, num_beams, history, repetition_penalty)
181
+
182
+ # score: batch_size * beam, vocab
183
+ # set eos token prob to zero if min_length is not reached
184
+ if eos_token_id is not None and cur_len < min_length:
185
+ for eos in eos_token_id:
186
+ scores[:, eos] = -float("inf")
187
+
188
+ if no_repeat_ngram_size > 0 and history is not None:
189
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
190
+ num_batch_hypotheses = batch_size * num_beams
191
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
192
+ banned_batch_tokens = _calc_banned_ngram_tokens(
193
+ history, num_batch_hypotheses, no_repeat_ngram_size, cur_len
194
+ )
195
+
196
+ for i, banned_tokens in enumerate(banned_batch_tokens):
197
+ scores[i, banned_tokens] = -float("inf")
198
+
199
+ return scores
200
+
201
+
202
+ def _add_beam_candidate(
203
+ best_score,
204
+ best_sequence,
205
+ batch_size,
206
+ num_beams,
207
+ beam_scores,
208
+ history,
209
+ eos_token_id=None
210
+ ):
211
+ last_tokens = history[:, -1]
212
+ for _i in range(batch_size * num_beams):
213
+ if eos_token_id is None or last_tokens[_i] in eos_token_id:
214
+ cur_len = history.shape[-1]
215
+ _score = beam_scores.view(-1)[_i] / cur_len ** args.length_penalty
216
+
217
+ batch_id = _i // num_beams
218
+
219
+ if not batch_id in best_score or best_score[batch_id] < _score:
220
+ best_score[batch_id] = _score
221
+ best_sequence[batch_id][:cur_len] = history[_i]
222
+
223
+ beam_scores.view(-1)[_i] = -float("inf")
224
+
225
+
226
+ def beam(model, data_iter, args):
227
+ model.eval()
228
+ total_loss = 0.
229
+ start_time = time.time()
230
+
231
+ all_predictions = {}
232
+ with torch.no_grad():
233
+ for idx, data in enumerate(data_iter):
234
+ data = {key: value for key, value in data.items()}
235
+
236
+ _id = data['id'].to(args.device)
237
+ _query = data['query'].to(args.device)
238
+ _query_len = data['query_len'].to(args.device)
239
+
240
+ ## local adaptation start.
241
+
242
+ ## local adaptation end.
243
+
244
+
245
+ output = None
246
+ score = None
247
+
248
+ batch_size = _id.size(0)
249
+ num_beams = args.beam
250
+ length_penalty = args.length_penalty
251
+
252
+ _batch = torch.arange(0, _id.size(0), device=args.device, dtype=torch.long)
253
+
254
+ past = None
255
+ len_past = None
256
+
257
+ _query = _query.repeat(1, num_beams).view(batch_size * num_beams, -1)
258
+ _query_len = _query_len.unsqueeze(-1).repeat(1, num_beams).view(-1)
259
+
260
+ _bbatch = _batch.unsqueeze(-1).repeat(1, num_beams).view(-1)
261
+
262
+ # scores for each sentence in the beam
263
+ beam_scores = torch.zeros(
264
+ (batch_size, num_beams), dtype=torch.float, device=_query.device
265
+ )
266
+
267
+ best_sequence = torch.zeros(
268
+ (batch_size, args.eval_len), dtype=torch.long, device=_query.device
269
+ )
270
+ best_score = {}
271
+
272
+ history = None
273
+ with torch.no_grad():
274
+ for i in range(0, args.eval_len):
275
+ if i == 0:
276
+ logits, past = model(_query)
277
+ logits = logits[_bbatch, (_query_len-1).long(), :] # batch_size * beam, vocab
278
+ else:
279
+ #print('token_id.shape', token_id.shape, token_id)
280
+ #print('past.shape', past[0].shape)
281
+ #print('len_past.shape', len_past.shape, len_past)
282
+
283
+ logits, past = model(token_id, past=past, len_past=len_past)
284
+ logits = logits[:, -1, :] # batch_size * beam, vocab
285
+
286
+ logits = _postprocess_next_token_scores(
287
+ logits,
288
+ history,
289
+ i,
290
+ batch_size,
291
+ num_beams,
292
+ repetition_penalty=args.repetition_penalty,
293
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
294
+ min_length=args.min_length,
295
+ eos_token_id=args.eos_token_id,
296
+ )
297
+
298
+ softmax_probs = F.softmax(logits, dim=-1)
299
+ ##_prob, _w_idx = torch.topk(softmax_probs, num_beams) # batch_size, beam
300
+
301
+ vocab_size = softmax_probs.shape[-1]
302
+
303
+
304
+ _logprob = torch.log(softmax_probs) # batch_size * beam, vocab
305
+ if i == 0:
306
+ next_scores = _logprob.view(batch_size, num_beams, -1)[:, 0, :] # batch_size, vocab
307
+
308
+ else:
309
+ next_scores = beam_scores.unsqueeze(-1) + _logprob.view(batch_size, num_beams, -1)
310
+ next_scores = next_scores.view(batch_size, -1) # batch_size, beam * vocab
311
+
312
+ next_scores, next_tokens = torch.topk(
313
+ next_scores, num_beams, dim=1, largest=True, sorted=True
314
+ ) # batch_size, num_beams
315
+
316
+ beam_id = (next_tokens // vocab_size).view(-1) # batch_size * num_beams
317
+ token_id = (next_tokens % vocab_size).view(-1).unsqueeze(-1) # batch_size, num_beams
318
+
319
+ beam_idx = beam_id.view(batch_size, num_beams) + (_batch * num_beams).unsqueeze(-1)
320
+ past = _reorder_cache(past, beam_idx.view(-1))
321
+ beam_scores = next_scores # batch_size, num_beams
322
+ len_past = (_query_len + i).long()
323
+
324
+ if history is None:
325
+ history = token_id.detach()
326
+ else:
327
+ history = torch.cat((history[beam_idx.view(-1)], token_id.detach()), dim=1).detach()
328
+
329
+ _add_beam_candidate(
330
+ best_score, best_sequence, batch_size, num_beams, beam_scores, history,
331
+ eos_token_id=args.eos_token_id
332
+ )
333
+
334
+ _add_beam_candidate(
335
+ best_score, best_sequence, batch_size, num_beams, beam_scores, history
336
+ )
337
+
338
+
339
+ with torch.no_grad():
340
+ _id = distributed_gather(args, _id)
341
+ output = distributed_gather(args, best_sequence)
342
+ #score = distributed_gather(args, score)
343
+ distributed_sync(args)
344
+
345
+ if args.rank == 0:
346
+ _id = _id.view(-1).cpu()
347
+ output = output.view(-1, output.shape[-1]).cpu()
348
+ #score = score.view(-1, score.shape[-1]).cpu()
349
+
350
+ for _b in range(0, _id.shape[-1]):
351
+ _i = int(_id[_b].item())
352
+ all_predictions[_i] = {}
353
+ all_predictions[_i]['id'] = _i
354
+ all_predictions[_i]['predict'] = output[_b].tolist()
355
+ #all_predictions[_i]['score'] = score[_b].tolist()
356
+
357
+ if idx % 10 == 0:
358
+ print('inference samples', idx)
359
+ # pred_file = os.path.join(args.work_dir, args.output_file)
360
+ # print('saving prediction file', pred_file)
361
+ # with open(pred_file, 'w') as writer:
362
+ # for _i in all_predictions:
363
+ # writer.write(json.dumps(all_predictions[_i]) + '\n')
364
+
365
+ if args.rank == 0:
366
+ pred_file = os.path.join(args.work_dir, args.output_file)
367
+ print('saving prediction file', pred_file)
368
+ with open(pred_file, 'w') as writer:
369
+ for _i in all_predictions:
370
+ writer.write(json.dumps(all_predictions[_i]) + '\n')
371
+
372
+
373
+ if __name__ == '__main__':
374
+ args = parser.parse_args()
375
+ parse_gpu(args)
376
+ print_args(args)
377
+
378
+ if args.rank == 0:
379
+ args.logging = create_exp_dir(args.work_dir)
380
+
381
+ valid_data = FT_Dataset(
382
+ args.data, args.batch_size, args.seq_len, args.eval_len,
383
+ )
384
+ valid_data = valid_data.get_item_list(0, 1000)
385
+ valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data)
386
+ valid_loader = DataLoader(
387
+ valid_data, batch_size=args.batch_size, num_workers=0, shuffle=False,
388
+ pin_memory=False, drop_last=False, sampler=valid_sampler
389
+ )
390
+
391
+ if args.model_card == 'gpt2.sm':
392
+ config = GPT2Config(
393
+ n_embd=768, n_layer=12, n_head=12,
394
+ lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
395
+ )
396
+ elif args.model_card == 'gpt2.md':
397
+ config = GPT2Config(
398
+ n_embd=1024, n_layer=24, n_head=16,
399
+ lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
400
+ )
401
+ elif args.model_card == 'gpt2.lg':
402
+ config = GPT2Config(
403
+ n_embd=1280, n_layer=36, n_head=20,
404
+ lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
405
+ )
406
+
407
+ lm_net = GPT2LMModel(config)
408
+ if args.init_checkpoint is not None:
409
+ print('loading model pretrained weight.')
410
+ cp = torch.load(args.init_checkpoint, map_location=torch.device('cpu'))
411
+ lm_net.load_weight(cp)
412
+ lm_net = lm_net.cuda()
413
+ print(lm_net.transformer.h[0].mlp)
414
+
415
+ print('model sampling ...')
416
+ beam(lm_net, valid_loader, args)
417
+ distributed_sync(args)
418
+ print('cleanup dist ...')
419
+ cleanup(args)
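
The prediction file written by the rank-0 process above is a JSONL of raw token ids, one object per evaluated example (field names as in the code; the ids themselves are placeholders here); gpt2_decode.py later maps 'predict' back to text with the BPE encoder:

    {"id": 0, "predict": [<generated BPE ids, zero-padded up to eval_len>]}
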
examples/NLG/src/gpt2_decode.py ADDED
@@ -0,0 +1,187 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+
6
+ # python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_beam.py \
7
+ # --data ./data/e2e/test.jsonl \
8
+ # --batch_size 1 \
9
+ # --seq_len 512 \
10
+ # --eval_len 64 \
11
+ # --model_card gpt2.md \
12
+ # --platform local \
13
+ # --beam 10 \
14
+ # --length_penalty 0.8 \
15
+ # --no_repeat_ngram_size 4 \
16
+ # --repetition_penalty 1.0 \
17
+ # --eos_token_id 628 \
18
+ # --lora_dim 4 \
19
+ # --lora_alpha 32 \
20
+ # --work_dir ./trained_models/GPT2_M/e2e \
21
+ # --output_file predict.26290.jsonl \
22
+ # --init_checkpoint ./trained_models/GPT2_M/e2e/model.26290.pt
23
+
24
+
25
+ import json
26
+ import numpy as np
27
+ import argparse
28
+ import os
29
+ import sys
30
+ import re
31
+ import json
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.parallel
36
+ import torch.backends.cudnn as cudnn
37
+ import torch.optim as optim
38
+ import torch.utils.data
39
+
40
+ import encoder
41
+
42
+
43
+ parser = argparse.ArgumentParser()
44
+
45
+ parser.add_argument('--vocab', type=str, default=None, help='vocab path')
46
+
47
+ parser.add_argument('--sample_file', default=None, type=str, help='ft sample file')
48
+ parser.add_argument('--input_file', default=None, type=str, help='ft input file')
49
+
50
+ parser.add_argument('--output_ref_file', default=None, type=str, help='output reference file')
51
+ parser.add_argument('--output_pred_file', default=None, type=str, help='output prediction file')
52
+
53
+ parser.add_argument('--ref_unique_file', default=None, type=str, help='reference unique id file')
54
+
55
+ parser.add_argument('--ref_type', default='e2e', choices=['e2e', 'webnlg', 'dart'],
56
+ help='e2e style reference type; webnlg style reference type.')
57
+ parser.add_argument('--ref_num', default=4, type=int, help='number of references.')
58
+
59
+
60
+ parser.add_argument('--tokenize', action='store_true', help='')
61
+ parser.add_argument('--lower', action='store_true', help='')
62
+
63
+ parser.add_argument('--filter', default='all', choices=['all', 'seen', 'unseen'],
64
+ help='for webnlg only, filter categories that are seen during training, unseen, or all')
65
+
66
+ args = parser.parse_args()
67
+
68
+
69
+ def stardard_tokenize(sent):
70
+ sent = ' '.join(re.split(r'(\W)', sent))
71
+ sent = sent.split()
72
+ sent = ' '.join(sent)
73
+ return sent
74
+
75
+
76
+ def post_process(sent, is_tokenize, is_lower):
77
+ if is_lower:
78
+ sent = sent.lower()
79
+ if is_tokenize:
80
+ sent = stardard_tokenize(sent)
81
+
82
+ return sent
83
+
84
+
85
+ if __name__ == "__main__":
86
+ enc = encoder.get_encoder(args.vocab)
87
+
88
+ ref_unique = None
89
+
90
+ if args.ref_unique_file is not None:
91
+ print('reading ref_unique_file.')
92
+ ref_unique = []
93
+ uniques = {}
94
+ with open(args.ref_unique_file, 'r') as ref_unique_reader:
95
+ for line in ref_unique_reader:
96
+ _id = int(line.strip())
97
+ ref_unique.append(_id)
98
+ uniques[_id] = 1
99
+ print('len refer dict', len(ref_unique), 'unique', len(uniques))
100
+
101
+ with open(args.sample_file, 'r') as sample_reader, \
102
+ open(args.input_file, 'r', encoding='utf8') as input_reader, \
103
+ open(args.output_pred_file, 'w', encoding='utf8') as pred_writer:
104
+
105
+ refer_dict = {}
106
+ context_list = []
107
+ line_id = 0
108
+ for line in input_reader:
109
+ items = json.loads(line.strip())
110
+ context = items['context']
111
+ completion = items['completion']
112
+
113
+ context_list.append(context)
114
+
115
+ keep = False
116
+
117
+ if args.filter == 'all':
118
+ keep = True
119
+ if args.filter == 'seen' and items['cate']:
120
+ keep = True
121
+ if args.filter == 'unseen' and not items['cate']:
122
+ keep = True
123
+
124
+ if ref_unique is None:
125
+ _key = context
126
+ else:
127
+ _key = ref_unique[line_id]
128
+
129
+ if keep:
130
+ if not _key in refer_dict:
131
+ refer_dict[_key] = {}
132
+ refer_dict[_key]['references'] = []
133
+ refer_dict[_key]['references'].append(completion.split('<|endoftext|>')[0].split('\n\n')[0].strip())
134
+
135
+ line_id += 1
136
+ if line_id==1000:
137
+ break
138
+
139
+ print('unique refer dict', len(refer_dict))
140
+
141
+ for line in sample_reader:
142
+ items = json.loads(line.strip())
143
+ _id = items['id']
144
+ _pred_tokens = items['predict']
145
+
146
+ if ref_unique is None:
147
+ _key = context_list[_id]
148
+ else:
149
+ _key = ref_unique[_id]
150
+
151
+ #assert _key in refer_dict
152
+ # if _key in refer_dict:
153
+ if not _key in refer_dict:
154
+ refer_dict[_key] = {}
155
+ refer_dict[_key]['sample'] = []
156
+ refer_dict[_key]['sample'] = enc.decode(_pred_tokens).split('<|endoftext|>')[0].split('\n\n')[0].strip()
157
+
158
+ references = [refer_dict[s]['references'] for s in refer_dict]
159
+ hypothesis = [refer_dict[s]['sample'] for s in refer_dict]
160
+
161
+ if args.ref_type == 'e2e':
162
+ with open(args.output_ref_file, 'w', encoding='utf8') as ref_writer:
163
+ for ref, hyp in zip(references, hypothesis):
164
+ for r in ref:
165
+ ref_writer.write(post_process(r, args.tokenize, args.lower) + '\n')
166
+ ref_writer.write('\n')
167
+ pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')
168
+
169
+ elif args.ref_type in ['webnlg', 'dart']:
170
+ if not os.path.exists(args.output_ref_file):
171
+ os.makedirs(args.output_ref_file)
172
+
173
+ reference_writers = [
174
+ open(os.path.join(args.output_ref_file, f'reference{fid}'), 'w', encoding='utf8')
175
+ for fid in range(0, args.ref_num)
176
+ ]
177
+
178
+ for ref, hyp in zip(references, hypothesis):
179
+ for fid in range(0, args.ref_num):
180
+ if len(ref) > fid:
181
+ reference_writers[fid].write(post_process(ref[fid], args.tokenize, args.lower) + '\n')
182
+ else:
183
+ reference_writers[fid].write(post_process(ref[0], args.tokenize, args.lower) + '\n')
184
+ pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')
185
+
186
+ for writer in reference_writers:
187
+ writer.close()
examples/NLG/src/gpt2_encode.py ADDED
@@ -0,0 +1,70 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import json
6
+ import numpy as np
7
+
8
+ import encoder
9
+
10
+ import argparse
11
+ import os
12
+ import random
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.parallel
16
+ import torch.backends.cudnn as cudnn
17
+ import torch.optim as optim
18
+ import torch.utils.data
19
+
20
+ import numpy
21
+ import io
22
+ import sys
23
+ import threading
24
+ import math
25
+ import random
26
+
27
+ import json
28
+ import collections
29
+ from collections import Counter
30
+ from collections import OrderedDict
31
+ from progress.bar import Bar as Bar
32
+
33
+
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument('--input', default=None, type=str, help='ft input file')
36
+ parser.add_argument('--vocab', type=str, default=None, help='vocab path')
37
+ parser.add_argument('--output', default=None, type=str, help='ft output file')
38
+ parser.add_argument('--add_bos', action='store_true', help='')
39
+ parser.add_argument('--add_eos', action='store_true', help='')
40
+ args = parser.parse_args()
41
+
42
+
43
+ if __name__ == "__main__":
44
+ enc = encoder.get_encoder(args.vocab)
45
+
46
+ writer = open(args.output, 'w')
47
+
48
+ with open(args.input, 'r') as reader:
49
+ line_idx = 0
50
+ for line in reader:
51
+ items = json.loads(line.strip())
52
+ context = items['context']
53
+ completion = items['completion']
54
+
55
+ bos = 50256
56
+ eos = 50256
57
+ context_bpes, _ = enc.encode(context)
58
+ context_bpes += [bos] if args.add_bos else []
59
+
60
+ completion_bpes, _ = enc.encode(' ' + completion)
61
+ completion_bpes += [eos] if args.add_eos else []
62
+
63
+ ft_json = {}
64
+ ft_json['context'] = context_bpes
65
+ ft_json['completion'] = completion_bpes
66
+ writer.write(json.dumps(ft_json)+'\n')
67
+
68
+ line_idx += 1
69
+
70
+ writer.close()
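
A sketch of the data contract for this script (example values only): the input JSONL comes from the format_converting_* scripts and carries plain strings, while the output replaces them with BPE token-id lists that FT_Dataset consumes directly:

    input line : {"context": "name : Blue Spice | food : Chinese", "completion": "Blue Spice serves Chinese food."}
    output line: {"context": [<context BPE ids>], "completion": [<completion BPE ids>, 50256]}  # 50256 appended only with --add_eos
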
examples/NLG/src/gpt2_ft.py ADDED
@@ -0,0 +1,385 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import argparse
6
+ import time
7
+ import math
8
+ import os, sys
9
+ import numpy as np
10
+ import itertools
11
+
12
+ import torch
13
+ import random
14
+ from torch.utils.data import DataLoader
15
+ torch.set_printoptions(threshold=100000)
16
+
17
+ from gpu import (
18
+ add_gpu_params,
19
+ parse_gpu,
20
+ distributed_opt,
21
+ distributed_gather,
22
+ distributed_sync,
23
+ cleanup
24
+ )
25
+ from optimizer import (
26
+ create_adam_optimizer,
27
+ create_optimizer_scheduler,
28
+ add_optimizer_params,
29
+ create_adam_optimizer_from_args
30
+ )
31
+
32
+ from data_utils import FT_Dataset
33
+ from model import GPT2Config, GPT2LMModel
34
+ from exp_utils import create_exp_dir
35
+
36
+ import loralib as lora
37
+
38
+ parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')
39
+
40
+ add_gpu_params(parser)
41
+ add_optimizer_params(parser)
42
+
43
+ parser.add_argument('--train_data', required=True, help='location of training data corpus')
44
+
45
+ parser.add_argument('--valid_data', required=True, help='location of validation data corpus')
46
+
47
+ parser.add_argument('--train_batch_size', type=int, default=8, help='training batch size')
48
+
49
+ parser.add_argument('--valid_batch_size', type=int, default=4, help='validation batch size')
50
+
51
+ parser.add_argument('--grad_acc', type=int, default=1, help='gradient accumulation steps')
52
+
53
+ parser.add_argument('--clip', type=float, default=0.0, help='gradient clip')
54
+
55
+ parser.add_argument('--seq_len', type=int, default=512, help='number of tokens to predict.')
56
+
57
+ parser.add_argument('--model_card', default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
58
+ help='model names')
59
+
60
+ parser.add_argument('--init_checkpoint', default=None, help='pretrained checkpoint path')
61
+
62
+ parser.add_argument('--fp16', action='store_true', help='train model with fp16')
63
+
64
+ parser.add_argument('--log_interval', type=int, default=100, help='log interval')
65
+
66
+ parser.add_argument('--eval_interval', type=int, default=2000, help='eval interval')
67
+
68
+ parser.add_argument('--save_interval', type=int, default=500, help='save interval')
69
+
70
+ parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
71
+ help='working folder.')
72
+
73
+ parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')
74
+
75
+ parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')
76
+
77
+ parser.add_argument('--obj', default='clm', choices=['jlm', 'clm'],
78
+ help='language model training objective')
79
+
80
+ parser.add_argument('--lora_dropout', default=0.0, type=float,
81
+ help='dropout probability for lora layers')
82
+
83
+ parser.add_argument('--label_smooth', default=0.0, type=float, help='label smoothing')
84
+
85
+ parser.add_argument('--roll_interval', type=int, default=-1, help='rolling interval')
86
+
87
+ parser.add_argument('--roll_lr', type=float, default=0.00001, help='rolling learning rate')
88
+
89
+ parser.add_argument('--roll_step', type=int, default=100, help='rolling step')
90
+
91
+ parser.add_argument('--eval_epoch', type=int, default=1, help='eval per number of epochs')
92
+
93
+ # print all parsed arguments on the rank-0 process.
94
+ def print_args(args):
95
+ if args.rank == 0:
96
+ print('=' * 100)
97
+ for k, v in args.__dict__.items():
98
+ print(f' - {k} : {v}')
99
+ print('=' * 100)
100
+
101
+
102
+ class AverageMeter(object):
103
+ """Computes and stores the average and current value
104
+ Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
105
+ """
106
+ def __init__(self):
107
+ self.reset()
108
+
109
+ def reset(self):
110
+ self.val = 0
111
+ self.avg = 0
112
+ self.sum = 0
113
+ self.count = 0
114
+
115
+ def update(self, val, n=1):
116
+ self.val = val
117
+ self.sum += val * n
118
+ self.count += n
119
+ self.avg = self.sum / self.count
120
+
121
+
122
+ def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update=True):
123
+ if args.fp16:
124
+ with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
125
+ _scaled_loss.backward()
126
+ else:
127
+ _loss.backward()
128
+
129
+ # for name, param in _model.named_parameters():
130
+ # if param.requires_grad and param.grad is not None:
131
+ # print(f"Parameter name: {name}")
132
+ # print(f"Gradient value: {param.grad}")
133
+
134
+ if is_update:
135
+ if args.clip > 0:
136
+ if args.fp16:
137
+ torch.nn.utils.clip_grad_norm_(amp.master_params(_optimizer), args.clip)
138
+ else:
139
+ torch.nn.utils.clip_grad_norm_(_model.parameters(), args.clip)
140
+
141
+ _optimizer.step()
142
+ _optimizer.zero_grad()
143
+
144
+ if _schedule is not None:
145
+ _schedule.step()
146
+
147
+ # print(f"query[0].lora_B = {_model.module.transformer.h[0].attn.c_attn.lora_B}")
148
+
149
+
150
+ def evaluate(model, valid_loader, args):
151
+ model.eval()
152
+ total_loss = 0.
153
+ start_time = time.time()
154
+
155
+ avg_lm_loss = AverageMeter()
156
+
157
+ with torch.no_grad():
158
+ for idx, data in enumerate(valid_loader):
159
+ data = {key: value for key, value in data.items()}
160
+
161
+ _input = data['input'].to(args.device)
162
+ _target = data['target'].to(args.device)
163
+ _msk = data['mask'].to(args.device)
164
+
165
+ _lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
166
+ loss = _loss.mean()
167
+ # print(f"logits={_lm_logits}, _loss={_loss}")
168
+
169
+ avg_lm_loss.update(loss.item())
170
+
171
+ if idx % 100 == 0:
172
+ print('eval samples:', idx, 'loss:', loss.float())
173
+
174
+ total_time = time.time() - start_time
175
+ print('average loss', avg_lm_loss.avg)
176
+ return avg_lm_loss.avg, math.exp(avg_lm_loss.avg)
177
+
178
+
179
+ def train_validate(
180
+ model,
181
+ optimizer,
182
+ scheduler,
183
+ train_loader,
184
+ valid_loader,
185
+ args,
186
+ train_step=0,
187
+ epoch=0
188
+ ):
189
+ model.train()
190
+ avg_lm_loss = AverageMeter()
191
+ print('start to train the model................', epoch)
192
+ log_start_time = time.time()
193
+ best_val_ppl = None
194
+
195
+ # train_loader.sampler.set_epoch(epoch)
196
+
197
+ for idx, data in enumerate(train_loader):
198
+ data = {key: value for key, value in data.items()}
199
+
200
+ _input = data['input'].to(args.device)
201
+ _target = data['target'].to(args.device)
202
+ _msk = data['mask'].to(args.device)
203
+
204
+ _lm_logits, _lm_loss = model(
205
+ _input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth
206
+ )
207
+ # print(_input[0])
208
+
209
+ _lm_loss = _lm_loss.mean()
210
+
211
+ train_step += 1
212
+ is_update = True if train_step % args.grad_acc == 0 else False
213
+ avg_lm_loss.update(_lm_loss.item())
214
+ optimizer_step(
215
+ _lm_loss/(args.grad_acc), optimizer, model, scheduler, args, is_update=is_update
216
+ )
217
+
218
+ if train_step % args.log_interval == 0:
219
+ print(f"_lm_loss = {_lm_loss}")
220
+ print(f"layer[0].lora_A = {model.module.transformer.h[0].attn.c_attn.lora_A[0,:100]}")
221
+ elapsed = time.time() - log_start_time
222
+ lr = optimizer.param_groups[0]['lr']
223
+ log_str = f'| epoch {epoch:3d} step {train_step:>8d} | { idx + 1:>6d} batches | ' \
224
+ f'lr {lr:.3g} | ms/batch {elapsed * 1000 / args.log_interval:5.2f} | ' \
225
+ f'loss {avg_lm_loss.val:5.2f} | avg loss {avg_lm_loss.avg:5.2f} | ' \
226
+ f'ppl {math.exp(avg_lm_loss.avg):5.2f}'
227
+
228
+ if args.rank == 0:
229
+ print(log_str)
230
+ log_start_time = time.time()
231
+ avg_lm_loss.reset()
232
+
233
+ if train_step % args.save_interval == 0:
234
+ if args.rank == 0:
235
+ model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
236
+ print('saving checkpoint', model_path)
237
+ torch.save({'model_state_dict': lora.lora_state_dict(model)}, model_path)
238
+ distributed_sync(args)
239
+
240
+ # evaluation interval
241
+ if train_step % args.eval_interval == 0:
242
+ eval_start_time = time.time()
243
+
244
+ valid_loss, valid_ppl = evaluate(model, valid_loader, args)
245
+
246
+ if best_val_ppl is None or valid_ppl < best_val_ppl:
247
+ best_val_ppl = valid_ppl
248
+
249
+ log_str = f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | ' \
250
+ f'time: {time.time() - eval_start_time:5.2f}s | valid loss {valid_loss:5.2f} | ' \
251
+ f'valid ppl {valid_ppl:5.2f} | best ppl {best_val_ppl:5.2f} '
252
+
253
+ if args.rank == 0:
254
+ print('-' * 100)
255
+ print(log_str)
256
+ print('-' * 100)
257
+
258
+ model.train()
259
+ distributed_sync(args)
260
+
261
+ if train_step == args.max_step:
262
+ break
263
+
264
+ if args.rank == 0:
265
+ model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
266
+ print('saving checkpoint', model_path)
267
+ torch.save({'model_state_dict': model.state_dict()}, model_path)  # end-of-epoch save keeps the full state dict; interval saves above keep only the LoRA weights
268
+ distributed_sync(args)
269
+ return train_step
270
+
271
+
272
+ if __name__ == '__main__':
273
+ args = parser.parse_args()
274
+ parse_gpu(args)
275
+ print_args(args)
276
+
277
+ if args.fp16:
278
+ try:
279
+ from apex import amp
280
+ except Exception as e:
281
+ warnings.warn('Could not import amp, apex may not be installed')
282
+
283
+ torch.manual_seed(args.random_seed)
284
+ random.seed(args.random_seed)
285
+
286
+ if args.rank == 0:
287
+ args.logging = create_exp_dir(args.work_dir)
288
+
289
+ train_data = FT_Dataset(
290
+ args.train_data, args.train_batch_size, args.seq_len,
291
+ joint_lm=args.obj=='jlm'
292
+ )
293
+
294
+ valid_data = FT_Dataset(
295
+ args.valid_data, args.valid_batch_size, args.seq_len,
296
+ )
297
+
298
+ train_loader = DataLoader(
299
+ train_data, batch_size=args.train_batch_size, num_workers=0,
300
+ shuffle=False, pin_memory=False, drop_last=True,
301
+ # sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed)
302
+ )
303
+
304
+ valid_loader = DataLoader(
305
+ valid_data, batch_size=args.valid_batch_size, num_workers=0,
306
+ shuffle=False, pin_memory=False, drop_last=False,
307
+ # sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
308
+ )
309
+ print(f"train_loader={len(train_loader)}, train_data={len(train_data)}")
310
+ print(f"valid_loader={len(valid_loader)}, valid_data={len(valid_data)}")
311
+
312
+ if args.model_card == 'gpt2.sm':
313
+ config = GPT2Config(
314
+ n_embd=768, n_layer=12, n_head=12,
315
+ lora_attn_dim=args.lora_dim,
316
+ lora_attn_alpha=args.lora_alpha,
317
+ lora_dropout=args.lora_dropout,
318
+ )
319
+ elif args.model_card == 'gpt2.md':
320
+ config = GPT2Config(
321
+ n_embd=1024, n_layer=24, n_head=16,
322
+ lora_attn_dim=args.lora_dim,
323
+ lora_attn_alpha=args.lora_alpha,
324
+ lora_dropout=args.lora_dropout,
325
+ )
326
+ elif args.model_card == 'gpt2.lg':
327
+ config = GPT2Config(
328
+ n_embd=1280, n_layer=36, n_head=20,
329
+ lora_attn_dim=args.lora_dim,
330
+ lora_attn_alpha=args.lora_alpha,
331
+ lora_dropout=args.lora_dropout,
332
+ )
333
+
334
+ lm_net = GPT2LMModel(config)
335
+ if args.init_checkpoint is not None:
336
+ print('loading model pretrained weight.')
337
+ lm_net.load_weight(torch.load(args.init_checkpoint))
338
+
339
+ lm_net = lm_net.cuda()
340
+
341
+ if args.lora_dim > 0:
342
+ lora.mark_only_lora_as_trainable(lm_net)
343
+
344
+ print(lm_net)
345
+ print(lm_net.transformer.h[0].attn.c_attn.weight.shape)
346
+ print(lm_net.transformer.h[0].attn.c_attn.lora_A.shape)
347
+ print(lm_net.transformer.h[0].attn.c_attn.lora_B.shape)
348
+ config_dict = vars(config)
349
+ for param, value in config_dict.items():
350
+ print(f"{param}: {value}")
351
+ print(args)
352
+ optimizer = create_adam_optimizer_from_args(lm_net, args)
353
+ print("optimizer: " + str(optimizer))
354
+
355
+ if args.max_step is None:
356
+ args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
357
+ print('set max_step:', args.max_step)
358
+ print('train_data.num_batches:', train_data.num_batches)
359
+
360
+ scheduler = create_optimizer_scheduler(optimizer, args)
361
+ if args.fp16:
362
+ lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
363
+ lm_net, optimizer = distributed_opt(args, lm_net, optimizer, grad_acc=args.grad_acc)
364
+
365
+ try:
366
+ train_step = 0
367
+ for epoch in itertools.count(start=1):
368
+ train_step = train_validate(
369
+ lm_net, optimizer, scheduler, train_loader, valid_loader, args,
370
+ train_step=train_step, epoch=epoch
371
+ )
372
+
373
+ if train_step >= args.max_step or (args.max_epoch is not None and epoch >= args.max_epoch):
374
+ if args.rank == 0:
375
+ print('-' * 100)
376
+ print('End of training')
377
+ break
378
+ except KeyboardInterrupt:
379
+ if args.rank == 0:
380
+ print('-' * 100)
381
+ print('Exiting from training early')
382
+
383
+ distributed_sync(args)
384
+ print('cleanup dist ...')
385
+ cleanup(args)
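
Note: the interval checkpoints written by train_validate() contain only the LoRA weights (lora.lora_state_dict), so they have to be loaded on top of the pretrained GPT-2 weights with a non-strict state-dict load. A minimal sketch, assuming the gpt2.md configuration and the checkpoint paths shown in model.log below; illustrative only, not part of the repo:

import torch
import loralib as lora
from model import GPT2Config, GPT2LMModel

# Rebuild GPT-2 medium with the same LoRA settings used for training.
config = GPT2Config(
    n_embd=1024, n_layer=24, n_head=16,
    lora_attn_dim=4, lora_attn_alpha=32, lora_dropout=0.1,
)
model = GPT2LMModel(config)
model.load_weight(torch.load('./pretrained_checkpoints/gpt2-medium-pytorch_model.bin'))  # base weights
lora.mark_only_lora_as_trainable(model)  # freeze everything except lora_A / lora_B

# Interval checkpoints hold only LoRA tensors, hence strict=False.
ckpt = torch.load('./trained_models/GPT2_M/e2e/model.1000.pt')
model.load_state_dict(ckpt['model_state_dict'], strict=False)
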
examples/NLG/src/gpu.py ADDED
@@ -0,0 +1,129 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import argparse
6
+ import time
7
+ import math
8
+ import os, sys
9
+ import itertools
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.optim as optim
16
+ import torch.distributed as dist
17
+
18
+
19
+ gpu_offset = 4  # offsets the CUDA device index used by every rank; set to 0 for the default GPU mapping
20
+
21
+ def add_gpu_params(parser: argparse.ArgumentParser):
22
+ parser.add_argument("--platform", default='k8s', type=str, help='platform cloud')
23
+ parser.add_argument("--local_rank", default=0, type=int, help='local rank')
24
+ parser.add_argument("--rank", default=0, type=int, help='rank')
25
+ parser.add_argument("--device", default=0, type=int, help='device')
26
+ parser.add_argument("--world_size", default=0, type=int, help='world size')
27
+ parser.add_argument("--random_seed", default=10, type=int, help='random seed')
28
+
29
+
30
+ def distributed_opt(args, model, opt, grad_acc=1):
31
+ if args.platform == 'azure':
32
+ args.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
33
+ opt = args.hvd.DistributedOptimizer(
34
+ opt, named_parameters=model.named_parameters(), backward_passes_per_step=grad_acc
35
+ )
36
+ elif args.platform == 'philly' or args.platform == 'k8s' or args.platform == 'local':
37
+ model = torch.nn.parallel.DistributedDataParallel(
38
+ model, device_ids=[args.local_rank+gpu_offset], output_device=args.local_rank+gpu_offset, # change
39
+ find_unused_parameters=False, broadcast_buffers=False
40
+ )
41
+ return model, opt
42
+
43
+
44
+ def distributed_gather(args, tensor):
45
+ g_y = [torch.zeros_like(tensor) for _ in range(args.world_size)]
46
+ torch.distributed.all_gather(g_y, tensor, async_op=False)
47
+ return torch.stack(g_y)
48
+
49
+
50
+ def distributed_sync(args):
51
+ if args.platform == 'azure':
52
+ args.hvd.allreduce(torch.tensor(0), name='barrier')
53
+ else:
54
+ args.dist.barrier()
55
+
56
+
57
+ def parse_gpu(args):
58
+ torch.manual_seed(args.random_seed)
59
+
60
+ if args.platform == 'local':
61
+ dist.init_process_group(backend='nccl')
62
+ local_rank = torch.distributed.get_rank()
63
+ torch.cuda.set_device(local_rank+gpu_offset) # change
64
+ device = torch.device('cuda', local_rank+gpu_offset) # change
65
+ args.rank = local_rank
66
+ args.device = device
67
+ args.world_size = torch.distributed.get_world_size()
68
+ args.dist = dist
69
+
70
+ elif args.platform == 'azure':
71
+ import horovod.torch as hvd
72
+ hvd.init()
73
+ print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
74
+ local_rank = hvd.local_rank()
75
+ torch.cuda.set_device(local_rank)
76
+ device = torch.device('cuda', local_rank)
77
+ rank = hvd.rank()
78
+ world_size = hvd.size()
79
+
80
+ args.local_rank = local_rank
81
+ args.rank = rank
82
+ args.device = device
83
+ args.world_size = world_size
84
+ args.hvd = hvd
85
+
86
+ elif args.platform == 'philly':
87
+ local_rank = args.local_rank
88
+ torch.cuda.set_device(local_rank)
89
+ dist.init_process_group(backend='nccl')
90
+ rank = dist.get_rank()
91
+ world_size = torch.distributed.get_world_size()
92
+ device = torch.device('cuda', local_rank)
93
+
94
+ args.rank = rank
95
+ args.device = device
96
+ args.world_size = world_size
97
+ args.dist = dist
98
+ elif args.platform == 'k8s':
99
+ master_uri = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
100
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
101
+ args.local_rank = local_rank
102
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
103
+ world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
104
+ rank = world_rank
105
+ torch.cuda.set_device(local_rank)
106
+
107
+ dist.init_process_group(
108
+ backend='nccl',
109
+ init_method=master_uri,
110
+ world_size=world_size,
111
+ rank=world_rank,
112
+ )
113
+ device = torch.device("cuda", local_rank)
114
+ args.rank = rank
115
+ args.device = device
116
+ args.world_size = world_size
117
+ args.dist = dist
118
+ print(
119
+ 'myrank:', args.rank,
120
+ 'local_rank:', args.local_rank,
121
+ 'device_count:', torch.cuda.device_count(),
122
+ 'world_size:', args.world_size,
123
+ 'device:', device
124
+ )
125
+
126
+
127
+ def cleanup(args):
128
+ if args.platform == 'k8s' or args.platform == 'philly':
129
+ args.dist.destroy_process_group()
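
Note: with --platform local, parse_gpu() initialises torch.distributed through the default env:// rendezvous, so the usual rendezvous variables must be set when the training script is started directly rather than through torchrun / torch.distributed.launch. A minimal single-process sketch (an assumption for illustration, not repo code):

import os
from argparse import ArgumentParser
from gpu import add_gpu_params, parse_gpu

# env:// rendezvous variables for a single-process run.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')

parser = ArgumentParser()
add_gpu_params(parser)
args = parser.parse_args(['--platform', 'local'])
parse_gpu(args)  # selects cuda:{rank + gpu_offset}; with gpu_offset = 4 this needs at least 5 visible GPUs
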
examples/NLG/src/model.log ADDED
@@ -0,0 +1,698 @@
1
+ myrank: 0 local_rank: 0 device_count: 8 world_size: 1 device: cuda:4
2
+ ====================================================================================================
3
+ - platform : local
4
+ - local_rank : 0
5
+ - rank : 0
6
+ - device : cuda:4
7
+ - world_size : 1
8
+ - random_seed : 110
9
+ - lr : 0.0002
10
+ - weight_decay : 0.01
11
+ - correct_bias : True
12
+ - adam_epislon : 1e-06
13
+ - no_decay_bias : False
14
+ - adam_beta1 : 0.9
15
+ - adam_beta2 : 0.999
16
+ - scheduler : linear
17
+ - max_step : None
18
+ - max_epoch : 5
19
+ - warmup_step : 500
20
+ - i_steps : 0
21
+ - i_lrs : 0.00025
22
+ - train_data : ./data/e2e/train.jsonl
23
+ - valid_data : ./data/e2e/valid.jsonl
24
+ - train_batch_size : 8
25
+ - valid_batch_size : 4
26
+ - grad_acc : 1
27
+ - clip : 0.0
28
+ - seq_len : 512
29
+ - model_card : gpt2.md
30
+ - init_checkpoint : ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin
31
+ - fp16 : False
32
+ - log_interval : 100
33
+ - eval_interval : 2000
34
+ - save_interval : 1000
35
+ - work_dir : ./trained_models/GPT2_M/e2e
36
+ - lora_dim : 4
37
+ - lora_alpha : 32
38
+ - obj : clm
39
+ - lora_dropout : 0.1
40
+ - label_smooth : 0.1
41
+ - roll_interval : -1
42
+ - roll_lr : 1e-05
43
+ - roll_step : 100
44
+ - eval_epoch : 1
45
+ - dist : <module 'torch.distributed' from '/home/inc/miniconda3/envs/fedadp-new/lib/python3.7/site-packages/torch/distributed/__init__.py'>
46
+ ====================================================================================================
47
+ Experiment dir : ./trained_models/GPT2_M/e2e
48
+ train_loader=5258, train_data=42064
49
+ valid_loader=1168, valid_data=4672
50
+ scaling = 8.0
51
+ loading model pretrained weight.
52
+ GPT2LMModel(
53
+ (transformer): GPT2Model(
54
+ (wte): Embedding(50257, 1024)
55
+ (wpe): Embedding(1024, 1024)
56
+ (h): ModuleList(
57
+ (0): Block(
58
+ (ln_1): LayerNorm()
59
+ (attn): Attention(
60
+ (c_attn): MergedLinear(
61
+ in_features=1024, out_features=3072, bias=True
62
+ (lora_dropout): Dropout(p=0.1, inplace=False)
63
+ )
64
+ (c_proj): Conv1D()
65
+ )
66
+ (ln_2): LayerNorm()
67
+ (mlp): MLP(
68
+ (c_fc): Conv1D()
69
+ (c_proj): Conv1D()
70
+ )
71
+ )
72
+ (1): Block(
73
+ (ln_1): LayerNorm()
74
+ (attn): Attention(
75
+ (c_attn): MergedLinear(
76
+ in_features=1024, out_features=3072, bias=True
77
+ (lora_dropout): Dropout(p=0.1, inplace=False)
78
+ )
79
+ (c_proj): Conv1D()
80
+ )
81
+ (ln_2): LayerNorm()
82
+ (mlp): MLP(
83
+ (c_fc): Conv1D()
84
+ (c_proj): Conv1D()
85
+ )
86
+ )
87
+ (2): Block(
88
+ (ln_1): LayerNorm()
89
+ (attn): Attention(
90
+ (c_attn): MergedLinear(
91
+ in_features=1024, out_features=3072, bias=True
92
+ (lora_dropout): Dropout(p=0.1, inplace=False)
93
+ )
94
+ (c_proj): Conv1D()
95
+ )
96
+ (ln_2): LayerNorm()
97
+ (mlp): MLP(
98
+ (c_fc): Conv1D()
99
+ (c_proj): Conv1D()
100
+ )
101
+ )
102
+ (3): Block(
103
+ (ln_1): LayerNorm()
104
+ (attn): Attention(
105
+ (c_attn): MergedLinear(
106
+ in_features=1024, out_features=3072, bias=True
107
+ (lora_dropout): Dropout(p=0.1, inplace=False)
108
+ )
109
+ (c_proj): Conv1D()
110
+ )
111
+ (ln_2): LayerNorm()
112
+ (mlp): MLP(
113
+ (c_fc): Conv1D()
114
+ (c_proj): Conv1D()
115
+ )
116
+ )
117
+ (4): Block(
118
+ (ln_1): LayerNorm()
119
+ (attn): Attention(
120
+ (c_attn): MergedLinear(
121
+ in_features=1024, out_features=3072, bias=True
122
+ (lora_dropout): Dropout(p=0.1, inplace=False)
123
+ )
124
+ (c_proj): Conv1D()
125
+ )
126
+ (ln_2): LayerNorm()
127
+ (mlp): MLP(
128
+ (c_fc): Conv1D()
129
+ (c_proj): Conv1D()
130
+ )
131
+ )
132
+ (5): Block(
133
+ (ln_1): LayerNorm()
134
+ (attn): Attention(
135
+ (c_attn): MergedLinear(
136
+ in_features=1024, out_features=3072, bias=True
137
+ (lora_dropout): Dropout(p=0.1, inplace=False)
138
+ )
139
+ (c_proj): Conv1D()
140
+ )
141
+ (ln_2): LayerNorm()
142
+ (mlp): MLP(
143
+ (c_fc): Conv1D()
144
+ (c_proj): Conv1D()
145
+ )
146
+ )
147
+ (6): Block(
148
+ (ln_1): LayerNorm()
149
+ (attn): Attention(
150
+ (c_attn): MergedLinear(
151
+ in_features=1024, out_features=3072, bias=True
152
+ (lora_dropout): Dropout(p=0.1, inplace=False)
153
+ )
154
+ (c_proj): Conv1D()
155
+ )
156
+ (ln_2): LayerNorm()
157
+ (mlp): MLP(
158
+ (c_fc): Conv1D()
159
+ (c_proj): Conv1D()
160
+ )
161
+ )
162
+ (7): Block(
163
+ (ln_1): LayerNorm()
164
+ (attn): Attention(
165
+ (c_attn): MergedLinear(
166
+ in_features=1024, out_features=3072, bias=True
167
+ (lora_dropout): Dropout(p=0.1, inplace=False)
168
+ )
169
+ (c_proj): Conv1D()
170
+ )
171
+ (ln_2): LayerNorm()
172
+ (mlp): MLP(
173
+ (c_fc): Conv1D()
174
+ (c_proj): Conv1D()
175
+ )
176
+ )
177
+ (8): Block(
178
+ (ln_1): LayerNorm()
179
+ (attn): Attention(
180
+ (c_attn): MergedLinear(
181
+ in_features=1024, out_features=3072, bias=True
182
+ (lora_dropout): Dropout(p=0.1, inplace=False)
183
+ )
184
+ (c_proj): Conv1D()
185
+ )
186
+ (ln_2): LayerNorm()
187
+ (mlp): MLP(
188
+ (c_fc): Conv1D()
189
+ (c_proj): Conv1D()
190
+ )
191
+ )
192
+ (9): Block(
193
+ (ln_1): LayerNorm()
194
+ (attn): Attention(
195
+ (c_attn): MergedLinear(
196
+ in_features=1024, out_features=3072, bias=True
197
+ (lora_dropout): Dropout(p=0.1, inplace=False)
198
+ )
199
+ (c_proj): Conv1D()
200
+ )
201
+ (ln_2): LayerNorm()
202
+ (mlp): MLP(
203
+ (c_fc): Conv1D()
204
+ (c_proj): Conv1D()
205
+ )
206
+ )
207
+ (10): Block(
208
+ (ln_1): LayerNorm()
209
+ (attn): Attention(
210
+ (c_attn): MergedLinear(
211
+ in_features=1024, out_features=3072, bias=True
212
+ (lora_dropout): Dropout(p=0.1, inplace=False)
213
+ )
214
+ (c_proj): Conv1D()
215
+ )
216
+ (ln_2): LayerNorm()
217
+ (mlp): MLP(
218
+ (c_fc): Conv1D()
219
+ (c_proj): Conv1D()
220
+ )
221
+ )
222
+ (11): Block(
223
+ (ln_1): LayerNorm()
224
+ (attn): Attention(
225
+ (c_attn): MergedLinear(
226
+ in_features=1024, out_features=3072, bias=True
227
+ (lora_dropout): Dropout(p=0.1, inplace=False)
228
+ )
229
+ (c_proj): Conv1D()
230
+ )
231
+ (ln_2): LayerNorm()
232
+ (mlp): MLP(
233
+ (c_fc): Conv1D()
234
+ (c_proj): Conv1D()
235
+ )
236
+ )
237
+ (12): Block(
238
+ (ln_1): LayerNorm()
239
+ (attn): Attention(
240
+ (c_attn): MergedLinear(
241
+ in_features=1024, out_features=3072, bias=True
242
+ (lora_dropout): Dropout(p=0.1, inplace=False)
243
+ )
244
+ (c_proj): Conv1D()
245
+ )
246
+ (ln_2): LayerNorm()
247
+ (mlp): MLP(
248
+ (c_fc): Conv1D()
249
+ (c_proj): Conv1D()
250
+ )
251
+ )
252
+ (13): Block(
253
+ (ln_1): LayerNorm()
254
+ (attn): Attention(
255
+ (c_attn): MergedLinear(
256
+ in_features=1024, out_features=3072, bias=True
257
+ (lora_dropout): Dropout(p=0.1, inplace=False)
258
+ )
259
+ (c_proj): Conv1D()
260
+ )
261
+ (ln_2): LayerNorm()
262
+ (mlp): MLP(
263
+ (c_fc): Conv1D()
264
+ (c_proj): Conv1D()
265
+ )
266
+ )
267
+ (14): Block(
268
+ (ln_1): LayerNorm()
269
+ (attn): Attention(
270
+ (c_attn): MergedLinear(
271
+ in_features=1024, out_features=3072, bias=True
272
+ (lora_dropout): Dropout(p=0.1, inplace=False)
273
+ )
274
+ (c_proj): Conv1D()
275
+ )
276
+ (ln_2): LayerNorm()
277
+ (mlp): MLP(
278
+ (c_fc): Conv1D()
279
+ (c_proj): Conv1D()
280
+ )
281
+ )
282
+ (15): Block(
283
+ (ln_1): LayerNorm()
284
+ (attn): Attention(
285
+ (c_attn): MergedLinear(
286
+ in_features=1024, out_features=3072, bias=True
287
+ (lora_dropout): Dropout(p=0.1, inplace=False)
288
+ )
289
+ (c_proj): Conv1D()
290
+ )
291
+ (ln_2): LayerNorm()
292
+ (mlp): MLP(
293
+ (c_fc): Conv1D()
294
+ (c_proj): Conv1D()
295
+ )
296
+ )
297
+ (16): Block(
298
+ (ln_1): LayerNorm()
299
+ (attn): Attention(
300
+ (c_attn): MergedLinear(
301
+ in_features=1024, out_features=3072, bias=True
302
+ (lora_dropout): Dropout(p=0.1, inplace=False)
303
+ )
304
+ (c_proj): Conv1D()
305
+ )
306
+ (ln_2): LayerNorm()
307
+ (mlp): MLP(
308
+ (c_fc): Conv1D()
309
+ (c_proj): Conv1D()
310
+ )
311
+ )
312
+ (17): Block(
313
+ (ln_1): LayerNorm()
314
+ (attn): Attention(
315
+ (c_attn): MergedLinear(
316
+ in_features=1024, out_features=3072, bias=True
317
+ (lora_dropout): Dropout(p=0.1, inplace=False)
318
+ )
319
+ (c_proj): Conv1D()
320
+ )
321
+ (ln_2): LayerNorm()
322
+ (mlp): MLP(
323
+ (c_fc): Conv1D()
324
+ (c_proj): Conv1D()
325
+ )
326
+ )
327
+ (18): Block(
328
+ (ln_1): LayerNorm()
329
+ (attn): Attention(
330
+ (c_attn): MergedLinear(
331
+ in_features=1024, out_features=3072, bias=True
332
+ (lora_dropout): Dropout(p=0.1, inplace=False)
333
+ )
334
+ (c_proj): Conv1D()
335
+ )
336
+ (ln_2): LayerNorm()
337
+ (mlp): MLP(
338
+ (c_fc): Conv1D()
339
+ (c_proj): Conv1D()
340
+ )
341
+ )
342
+ (19): Block(
343
+ (ln_1): LayerNorm()
344
+ (attn): Attention(
345
+ (c_attn): MergedLinear(
346
+ in_features=1024, out_features=3072, bias=True
347
+ (lora_dropout): Dropout(p=0.1, inplace=False)
348
+ )
349
+ (c_proj): Conv1D()
350
+ )
351
+ (ln_2): LayerNorm()
352
+ (mlp): MLP(
353
+ (c_fc): Conv1D()
354
+ (c_proj): Conv1D()
355
+ )
356
+ )
357
+ (20): Block(
358
+ (ln_1): LayerNorm()
359
+ (attn): Attention(
360
+ (c_attn): MergedLinear(
361
+ in_features=1024, out_features=3072, bias=True
362
+ (lora_dropout): Dropout(p=0.1, inplace=False)
363
+ )
364
+ (c_proj): Conv1D()
365
+ )
366
+ (ln_2): LayerNorm()
367
+ (mlp): MLP(
368
+ (c_fc): Conv1D()
369
+ (c_proj): Conv1D()
370
+ )
371
+ )
372
+ (21): Block(
373
+ (ln_1): LayerNorm()
374
+ (attn): Attention(
375
+ (c_attn): MergedLinear(
376
+ in_features=1024, out_features=3072, bias=True
377
+ (lora_dropout): Dropout(p=0.1, inplace=False)
378
+ )
379
+ (c_proj): Conv1D()
380
+ )
381
+ (ln_2): LayerNorm()
382
+ (mlp): MLP(
383
+ (c_fc): Conv1D()
384
+ (c_proj): Conv1D()
385
+ )
386
+ )
387
+ (22): Block(
388
+ (ln_1): LayerNorm()
389
+ (attn): Attention(
390
+ (c_attn): MergedLinear(
391
+ in_features=1024, out_features=3072, bias=True
392
+ (lora_dropout): Dropout(p=0.1, inplace=False)
393
+ )
394
+ (c_proj): Conv1D()
395
+ )
396
+ (ln_2): LayerNorm()
397
+ (mlp): MLP(
398
+ (c_fc): Conv1D()
399
+ (c_proj): Conv1D()
400
+ )
401
+ )
402
+ (23): Block(
403
+ (ln_1): LayerNorm()
404
+ (attn): Attention(
405
+ (c_attn): MergedLinear(
406
+ in_features=1024, out_features=3072, bias=True
407
+ (lora_dropout): Dropout(p=0.1, inplace=False)
408
+ )
409
+ (c_proj): Conv1D()
410
+ )
411
+ (ln_2): LayerNorm()
412
+ (mlp): MLP(
413
+ (c_fc): Conv1D()
414
+ (c_proj): Conv1D()
415
+ )
416
+ )
417
+ )
418
+ (ln_f): LayerNorm()
419
+ )
420
+ (lm_head): GPT2LMHead(
421
+ (decoder): Linear(in_features=1024, out_features=50257, bias=False)
422
+ )
423
+ )
424
+ vocab_size: 50257
425
+ n_ctx: 1024
426
+ n_positions: 1024
427
+ n_embd: 1024
428
+ n_layer: 24
429
+ n_head: 16
430
+ layer_norm_epsilon: 1e-05
431
+ initializer_range: 0.02
432
+ lora_attn_dim: 4
433
+ lora_attn_alpha: 32
434
+ lora_dropout: 0.1
435
+ lora_r_dropout: 0.0
436
+ fix_dropout: 0.0
437
+ Namespace(adam_beta1=0.9, adam_beta2=0.999, adam_epislon=1e-06, clip=0.0, correct_bias=True, device=device(type='cuda', index=4), dist=<module 'torch.distributed' from '/home/inc/miniconda3/envs/fedadp-new/lib/python3.7/site-packages/torch/distributed/__init__.py'>, eval_epoch=1, eval_interval=2000, fp16=False, grad_acc=1, i_lrs='0.00025', i_steps='0', init_checkpoint='./pretrained_checkpoints/gpt2-medium-pytorch_model.bin', label_smooth=0.1, local_rank=0, log_interval=100, logging=functools.partial(<function logging at 0x7f90cac2ae60>, log_path='./trained_models/GPT2_M/e2e/log.txt'), lora_alpha=32, lora_dim=4, lora_dropout=0.1, lr=0.0002, max_epoch=5, max_step=None, model_card='gpt2.md', no_decay_bias=False, obj='clm', platform='local', random_seed=110, rank=0, roll_interval=-1, roll_lr=1e-05, roll_step=100, save_interval=1000, scheduler='linear', seq_len=512, train_batch_size=8, train_data='./data/e2e/train.jsonl', valid_batch_size=4, valid_data='./data/e2e/valid.jsonl', warmup_step=500, weight_decay=0.01, work_dir='./trained_models/GPT2_M/e2e', world_size=1)
438
+ optimizer: AdamW (
439
+ Parameter Group 0
440
+ betas: (0.9, 0.999)
441
+ correct_bias: True
442
+ eps: 1e-06
443
+ lr: 0.0002
444
+ weight_decay: 0.01
445
+ )
446
+ set max_step: 26290
447
+ train_data.num_batches: 5258
448
+ start to train the model................ 1
449
+ /home/inc/Documents/fzh/python/LoRA-main/examples/NLG/src/optimizer.py:117: UserWarning: This overload of addcdiv_ is deprecated:
450
+ addcdiv_(Number value, Tensor tensor1, Tensor tensor2)
451
+ Consider using one of the following signatures instead:
452
+ addcdiv_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1050.)
453
+ p.data.addcdiv_(-step_size, exp_avg, denom)
454
+
455
+
456
+ | epoch 1 step 100 | 100 batches | lr 4e-05 | ms/batch 612.69 | loss 5.06 | avg loss 5.52 | ppl 250.72
457
+ | epoch 1 step 200 | 200 batches | lr 8e-05 | ms/batch 608.52 | loss 3.21 | avg loss 3.70 | ppl 40.58
458
+ | epoch 1 step 300 | 300 batches | lr 0.00012 | ms/batch 609.77 | loss 2.98 | avg loss 3.08 | ppl 21.74
459
+ | epoch 1 step 400 | 400 batches | lr 0.00016 | ms/batch 610.18 | loss 3.11 | avg loss 2.98 | ppl 19.63
460
+ | epoch 1 step 500 | 500 batches | lr 0.0002 | ms/batch 610.03 | loss 2.84 | avg loss 2.89 | ppl 18.03
461
+ | epoch 1 step 600 | 600 batches | lr 0.000199 | ms/batch 608.84 | loss 2.77 | avg loss 2.83 | ppl 16.93
462
+ | epoch 1 step 700 | 700 batches | lr 0.000198 | ms/batch 611.37 | loss 2.88 | avg loss 2.80 | ppl 16.37
463
+ | epoch 1 step 800 | 800 batches | lr 0.000198 | ms/batch 611.10 | loss 2.48 | avg loss 2.76 | ppl 15.76
464
+ | epoch 1 step 900 | 900 batches | lr 0.000197 | ms/batch 610.61 | loss 2.50 | avg loss 2.75 | ppl 15.59
465
+ | epoch 1 step 1000 | 1000 batches | lr 0.000196 | ms/batch 610.44 | loss 3.19 | avg loss 2.77 | ppl 15.95
466
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.1000.pt
467
+ | epoch 1 step 1100 | 1100 batches | lr 0.000195 | ms/batch 612.14 | loss 2.76 | avg loss 2.73 | ppl 15.41
468
+ | epoch 1 step 1200 | 1200 batches | lr 0.000195 | ms/batch 608.16 | loss 3.02 | avg loss 2.76 | ppl 15.84
469
+ | epoch 1 step 1300 | 1300 batches | lr 0.000194 | ms/batch 610.06 | loss 2.55 | avg loss 2.75 | ppl 15.62
470
+ | epoch 1 step 1400 | 1400 batches | lr 0.000193 | ms/batch 609.24 | loss 2.35 | avg loss 2.70 | ppl 14.93
471
+ | epoch 1 step 1500 | 1500 batches | lr 0.000192 | ms/batch 607.91 | loss 2.53 | avg loss 2.72 | ppl 15.24
472
+ | epoch 1 step 1600 | 1600 batches | lr 0.000191 | ms/batch 608.62 | loss 2.53 | avg loss 2.67 | ppl 14.50
473
+ | epoch 1 step 1700 | 1700 batches | lr 0.000191 | ms/batch 608.92 | loss 2.66 | avg loss 2.71 | ppl 14.99
474
+ | epoch 1 step 1800 | 1800 batches | lr 0.00019 | ms/batch 608.44 | loss 2.55 | avg loss 2.69 | ppl 14.75
475
+ | epoch 1 step 1900 | 1900 batches | lr 0.000189 | ms/batch 609.27 | loss 2.43 | avg loss 2.66 | ppl 14.31
476
+ | epoch 1 step 2000 | 2000 batches | lr 0.000188 | ms/batch 607.05 | loss 2.71 | avg loss 2.66 | ppl 14.36
477
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.2000.pt
478
+ /home/inc/miniconda3/envs/fedadp-new/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
479
+ warnings.warn(warning.format(ret))
480
+ eval samples: 0 loss: tensor(1.1374, device='cuda:4')
481
+ eval samples: 100 loss: tensor(1.0985, device='cuda:4')
482
+ eval samples: 200 loss: tensor(1.2215, device='cuda:4')
483
+ eval samples: 300 loss: tensor(1.2918, device='cuda:4')
484
+ eval samples: 400 loss: tensor(1.6716, device='cuda:4')
485
+ eval samples: 500 loss: tensor(1.9854, device='cuda:4')
486
+ eval samples: 600 loss: tensor(1.2216, device='cuda:4')
487
+ eval samples: 700 loss: tensor(1.0347, device='cuda:4')
488
+ eval samples: 800 loss: tensor(1.5289, device='cuda:4')
489
+ eval samples: 900 loss: tensor(1.5743, device='cuda:4')
490
+ eval samples: 1000 loss: tensor(1.3339, device='cuda:4')
491
+ eval samples: 1100 loss: tensor(1.3198, device='cuda:4')
492
+ average loss 1.3344345796496084
493
+ ----------------------------------------------------------------------------------------------------
494
+ | Eval 1 at step 2000 | time: 137.89s | valid loss 1.33 | valid ppl 3.80 | best ppl 3.80
495
+ ----------------------------------------------------------------------------------------------------
496
+ | epoch 1 step 2100 | 2100 batches | lr 0.000188 | ms/batch 1988.14 | loss 2.64 | avg loss 2.68 | ppl 14.57
497
+ | epoch 1 step 2200 | 2200 batches | lr 0.000187 | ms/batch 608.77 | loss 2.45 | avg loss 2.66 | ppl 14.34
498
+ | epoch 1 step 2300 | 2300 batches | lr 0.000186 | ms/batch 610.52 | loss 2.60 | avg loss 2.67 | ppl 14.38
499
+ | epoch 1 step 2400 | 2400 batches | lr 0.000185 | ms/batch 608.14 | loss 2.70 | avg loss 2.67 | ppl 14.49
500
+ | epoch 1 step 2500 | 2500 batches | lr 0.000184 | ms/batch 607.87 | loss 2.52 | avg loss 2.64 | ppl 14.05
501
+ | epoch 1 step 2600 | 2600 batches | lr 0.000184 | ms/batch 608.44 | loss 2.54 | avg loss 2.70 | ppl 14.85
502
+ | epoch 1 step 2700 | 2700 batches | lr 0.000183 | ms/batch 608.49 | loss 2.87 | avg loss 2.69 | ppl 14.72
503
+ | epoch 1 step 2800 | 2800 batches | lr 0.000182 | ms/batch 608.82 | loss 2.44 | avg loss 2.66 | ppl 14.26
504
+ | epoch 1 step 2900 | 2900 batches | lr 0.000181 | ms/batch 609.19 | loss 2.69 | avg loss 2.68 | ppl 14.52
505
+ | epoch 1 step 3000 | 3000 batches | lr 0.000181 | ms/batch 609.05 | loss 2.73 | avg loss 2.64 | ppl 13.99
506
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.3000.pt
507
+ | epoch 1 step 3100 | 3100 batches | lr 0.00018 | ms/batch 609.17 | loss 2.63 | avg loss 2.64 | ppl 14.04
508
+ | epoch 1 step 3200 | 3200 batches | lr 0.000179 | ms/batch 609.50 | loss 2.57 | avg loss 2.66 | ppl 14.28
509
+ | epoch 1 step 3300 | 3300 batches | lr 0.000178 | ms/batch 607.31 | loss 2.47 | avg loss 2.62 | ppl 13.76
510
+ | epoch 1 step 3400 | 3400 batches | lr 0.000178 | ms/batch 604.83 | loss 2.54 | avg loss 2.60 | ppl 13.49
511
+ | epoch 1 step 3500 | 3500 batches | lr 0.000177 | ms/batch 607.92 | loss 2.62 | avg loss 2.63 | ppl 13.90
512
+ | epoch 1 step 3600 | 3600 batches | lr 0.000176 | ms/batch 608.49 | loss 2.41 | avg loss 2.62 | ppl 13.78
513
+ | epoch 1 step 3700 | 3700 batches | lr 0.000175 | ms/batch 605.91 | loss 2.58 | avg loss 2.59 | ppl 13.36
514
+ | epoch 1 step 3800 | 3800 batches | lr 0.000174 | ms/batch 607.54 | loss 2.46 | avg loss 2.64 | ppl 13.97
515
+ | epoch 1 step 3900 | 3900 batches | lr 0.000174 | ms/batch 610.01 | loss 2.68 | avg loss 2.66 | ppl 14.24
516
+ | epoch 1 step 4000 | 4000 batches | lr 0.000173 | ms/batch 607.98 | loss 2.78 | avg loss 2.64 | ppl 14.04
517
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.4000.pt
518
+ eval samples: 0 loss: tensor(1.1133, device='cuda:4')
519
+ eval samples: 100 loss: tensor(1.0210, device='cuda:4')
520
+ eval samples: 200 loss: tensor(1.1742, device='cuda:4')
521
+ eval samples: 300 loss: tensor(1.2072, device='cuda:4')
522
+ eval samples: 400 loss: tensor(1.6256, device='cuda:4')
523
+ eval samples: 500 loss: tensor(1.9378, device='cuda:4')
524
+ eval samples: 600 loss: tensor(1.0971, device='cuda:4')
525
+ eval samples: 700 loss: tensor(1.0210, device='cuda:4')
526
+ eval samples: 800 loss: tensor(1.4538, device='cuda:4')
527
+ eval samples: 900 loss: tensor(1.5298, device='cuda:4')
528
+ eval samples: 1000 loss: tensor(1.2354, device='cuda:4')
529
+ eval samples: 1100 loss: tensor(1.2567, device='cuda:4')
530
+ average loss 1.2714025441506138
531
+ ----------------------------------------------------------------------------------------------------
532
+ | Eval 2 at step 4000 | time: 138.19s | valid loss 1.27 | valid ppl 3.57 | best ppl 3.57
533
+ ----------------------------------------------------------------------------------------------------
534
+ | epoch 1 step 4100 | 4100 batches | lr 0.000172 | ms/batch 1990.32 | loss 2.81 | avg loss 2.62 | ppl 13.78
535
+ | epoch 1 step 4200 | 4200 batches | lr 0.000171 | ms/batch 608.76 | loss 3.11 | avg loss 2.61 | ppl 13.57
536
+ | epoch 1 step 4300 | 4300 batches | lr 0.000171 | ms/batch 610.45 | loss 2.46 | avg loss 2.61 | ppl 13.63
537
+ | epoch 1 step 4400 | 4400 batches | lr 0.00017 | ms/batch 610.84 | loss 2.96 | avg loss 2.62 | ppl 13.74
538
+ | epoch 1 step 4500 | 4500 batches | lr 0.000169 | ms/batch 611.36 | loss 2.78 | avg loss 2.61 | ppl 13.58
539
+ | epoch 1 step 4600 | 4600 batches | lr 0.000168 | ms/batch 612.08 | loss 2.81 | avg loss 2.57 | ppl 13.07
540
+ | epoch 1 step 4700 | 4700 batches | lr 0.000167 | ms/batch 615.36 | loss 2.90 | avg loss 2.63 | ppl 13.91
541
+ | epoch 1 step 4800 | 4800 batches | lr 0.000167 | ms/batch 611.17 | loss 2.99 | avg loss 2.61 | ppl 13.55
542
+ | epoch 1 step 4900 | 4900 batches | lr 0.000166 | ms/batch 608.81 | loss 2.73 | avg loss 2.60 | ppl 13.47
543
+ | epoch 1 step 5000 | 5000 batches | lr 0.000165 | ms/batch 609.73 | loss 2.50 | avg loss 2.58 | ppl 13.26
544
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.5000.pt
545
+ | epoch 1 step 5100 | 5100 batches | lr 0.000164 | ms/batch 609.36 | loss 2.27 | avg loss 2.59 | ppl 13.33
546
+ | epoch 1 step 5200 | 5200 batches | lr 0.000164 | ms/batch 611.66 | loss 2.39 | avg loss 2.62 | ppl 13.78
547
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.5258.pt
548
+ start to train the model................ 2
549
+ | epoch 2 step 5300 | 42 batches | lr 0.000163 | ms/batch 256.06 | loss 2.41 | avg loss 2.61 | ppl 13.53
550
+ | epoch 2 step 5400 | 142 batches | lr 0.000162 | ms/batch 609.01 | loss 2.63 | avg loss 2.61 | ppl 13.58
551
+ | epoch 2 step 5500 | 242 batches | lr 0.000161 | ms/batch 612.10 | loss 2.45 | avg loss 2.59 | ppl 13.30
552
+ | epoch 2 step 5600 | 342 batches | lr 0.00016 | ms/batch 611.07 | loss 2.67 | avg loss 2.59 | ppl 13.27
553
+ | epoch 2 step 5700 | 442 batches | lr 0.00016 | ms/batch 611.19 | loss 2.52 | avg loss 2.64 | ppl 13.95
554
+ | epoch 2 step 5800 | 542 batches | lr 0.000159 | ms/batch 611.61 | loss 2.87 | avg loss 2.57 | ppl 13.10
555
+ | epoch 2 step 5900 | 642 batches | lr 0.000158 | ms/batch 612.67 | loss 3.17 | avg loss 2.58 | ppl 13.25
556
+ | epoch 2 step 6000 | 742 batches | lr 0.000157 | ms/batch 610.88 | loss 2.45 | avg loss 2.59 | ppl 13.32
557
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.6000.pt
558
+ eval samples: 0 loss: tensor(1.0454, device='cuda:4')
559
+ eval samples: 100 loss: tensor(0.9909, device='cuda:4')
560
+ eval samples: 200 loss: tensor(1.1352, device='cuda:4')
561
+ eval samples: 300 loss: tensor(1.1335, device='cuda:4')
562
+ eval samples: 400 loss: tensor(1.5766, device='cuda:4')
563
+ eval samples: 500 loss: tensor(2.0034, device='cuda:4')
564
+ eval samples: 600 loss: tensor(1.1043, device='cuda:4')
565
+ eval samples: 700 loss: tensor(0.9965, device='cuda:4')
566
+ eval samples: 800 loss: tensor(1.4912, device='cuda:4')
567
+ eval samples: 900 loss: tensor(1.5128, device='cuda:4')
568
+ eval samples: 1000 loss: tensor(1.1385, device='cuda:4')
569
+ eval samples: 1100 loss: tensor(1.2201, device='cuda:4')
570
+ average loss 1.239899498908079
571
+ ----------------------------------------------------------------------------------------------------
572
+ | Eval 3 at step 6000 | time: 138.83s | valid loss 1.24 | valid ppl 3.46 | best ppl 3.46
573
+ ----------------------------------------------------------------------------------------------------
574
+ | epoch 2 step 6100 | 842 batches | lr 0.000157 | ms/batch 1999.78 | loss 2.55 | avg loss 2.61 | ppl 13.54
575
+ | epoch 2 step 6200 | 942 batches | lr 0.000156 | ms/batch 612.01 | loss 2.72 | avg loss 2.60 | ppl 13.48
576
+ | epoch 2 step 6300 | 1042 batches | lr 0.000155 | ms/batch 611.75 | loss 2.61 | avg loss 2.58 | ppl 13.26
577
+ | epoch 2 step 6400 | 1142 batches | lr 0.000154 | ms/batch 612.29 | loss 2.48 | avg loss 2.58 | ppl 13.15
578
+ | epoch 2 step 6500 | 1242 batches | lr 0.000153 | ms/batch 613.03 | loss 2.90 | avg loss 2.62 | ppl 13.67
579
+ | epoch 2 step 6600 | 1342 batches | lr 0.000153 | ms/batch 611.04 | loss 3.07 | avg loss 2.58 | ppl 13.16
580
+ | epoch 2 step 6700 | 1442 batches | lr 0.000152 | ms/batch 611.17 | loss 2.79 | avg loss 2.56 | ppl 12.96
581
+ | epoch 2 step 6800 | 1542 batches | lr 0.000151 | ms/batch 614.47 | loss 2.50 | avg loss 2.56 | ppl 12.95
582
+ | epoch 2 step 6900 | 1642 batches | lr 0.00015 | ms/batch 610.47 | loss 2.71 | avg loss 2.56 | ppl 12.99
583
+ | epoch 2 step 7000 | 1742 batches | lr 0.00015 | ms/batch 608.59 | loss 2.56 | avg loss 2.59 | ppl 13.37
584
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.7000.pt
585
+ | epoch 2 step 7100 | 1842 batches | lr 0.000149 | ms/batch 610.96 | loss 2.32 | avg loss 2.57 | ppl 13.01
586
+ | epoch 2 step 7200 | 1942 batches | lr 0.000148 | ms/batch 610.97 | loss 2.41 | avg loss 2.53 | ppl 12.50
587
+ | epoch 2 step 7300 | 2042 batches | lr 0.000147 | ms/batch 611.57 | loss 2.48 | avg loss 2.57 | ppl 13.10
588
+ | epoch 2 step 7400 | 2142 batches | lr 0.000146 | ms/batch 610.40 | loss 2.39 | avg loss 2.56 | ppl 12.89
589
+ | epoch 2 step 7500 | 2242 batches | lr 0.000146 | ms/batch 610.66 | loss 2.63 | avg loss 2.57 | ppl 13.04
590
+ | epoch 2 step 7600 | 2342 batches | lr 0.000145 | ms/batch 610.52 | loss 2.63 | avg loss 2.58 | ppl 13.26
591
+ | epoch 2 step 7700 | 2442 batches | lr 0.000144 | ms/batch 608.69 | loss 2.22 | avg loss 2.54 | ppl 12.73
592
+ | epoch 2 step 7800 | 2542 batches | lr 0.000143 | ms/batch 609.99 | loss 2.35 | avg loss 2.57 | ppl 13.07
593
+ | epoch 2 step 7900 | 2642 batches | lr 0.000143 | ms/batch 609.05 | loss 2.72 | avg loss 2.60 | ppl 13.47
594
+ | epoch 2 step 8000 | 2742 batches | lr 0.000142 | ms/batch 609.02 | loss 2.57 | avg loss 2.59 | ppl 13.30
595
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.8000.pt
596
+ eval samples: 0 loss: tensor(1.0535, device='cuda:4')
597
+ eval samples: 100 loss: tensor(0.9691, device='cuda:4')
598
+ eval samples: 200 loss: tensor(1.1137, device='cuda:4')
599
+ eval samples: 300 loss: tensor(1.1214, device='cuda:4')
600
+ eval samples: 400 loss: tensor(1.5688, device='cuda:4')
601
+ eval samples: 500 loss: tensor(1.9425, device='cuda:4')
602
+ eval samples: 600 loss: tensor(1.0476, device='cuda:4')
603
+ eval samples: 700 loss: tensor(0.9898, device='cuda:4')
604
+ eval samples: 800 loss: tensor(1.4776, device='cuda:4')
605
+ eval samples: 900 loss: tensor(1.5046, device='cuda:4')
606
+ eval samples: 1000 loss: tensor(1.1689, device='cuda:4')
607
+ eval samples: 1100 loss: tensor(1.1641, device='cuda:4')
608
+ average loss 1.2270236368456933
609
+ ----------------------------------------------------------------------------------------------------
610
+ | Eval 4 at step 8000 | time: 138.04s | valid loss 1.23 | valid ppl 3.41 | best ppl 3.41
611
+ ----------------------------------------------------------------------------------------------------
612
+ | epoch 2 step 8100 | 2842 batches | lr 0.000141 | ms/batch 1991.53 | loss 2.46 | avg loss 2.56 | ppl 12.98
613
+ | epoch 2 step 8200 | 2942 batches | lr 0.00014 | ms/batch 609.84 | loss 2.50 | avg loss 2.60 | ppl 13.49
614
+ | epoch 2 step 8300 | 3042 batches | lr 0.00014 | ms/batch 610.87 | loss 2.47 | avg loss 2.54 | ppl 12.72
615
+ | epoch 2 step 8400 | 3142 batches | lr 0.000139 | ms/batch 610.92 | loss 2.41 | avg loss 2.57 | ppl 13.03
616
+ | epoch 2 step 8500 | 3242 batches | lr 0.000138 | ms/batch 611.04 | loss 2.81 | avg loss 2.56 | ppl 12.89
617
+ | epoch 2 step 8600 | 3342 batches | lr 0.000137 | ms/batch 612.82 | loss 2.40 | avg loss 2.55 | ppl 12.87
618
+ | epoch 2 step 8700 | 3442 batches | lr 0.000136 | ms/batch 611.25 | loss 2.47 | avg loss 2.52 | ppl 12.43
619
+ | epoch 2 step 8800 | 3542 batches | lr 0.000136 | ms/batch 611.59 | loss 2.57 | avg loss 2.55 | ppl 12.86
620
+ | epoch 2 step 8900 | 3642 batches | lr 0.000135 | ms/batch 611.43 | loss 2.33 | avg loss 2.54 | ppl 12.62
621
+ | epoch 2 step 9000 | 3742 batches | lr 0.000134 | ms/batch 610.78 | loss 2.96 | avg loss 2.55 | ppl 12.78
622
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.9000.pt
623
+ | epoch 2 step 9100 | 3842 batches | lr 0.000133 | ms/batch 608.39 | loss 2.67 | avg loss 2.55 | ppl 12.81
624
+ | epoch 2 step 9200 | 3942 batches | lr 0.000133 | ms/batch 611.72 | loss 2.65 | avg loss 2.58 | ppl 13.17
625
+ | epoch 2 step 9300 | 4042 batches | lr 0.000132 | ms/batch 611.24 | loss 2.60 | avg loss 2.58 | ppl 13.15
626
+ | epoch 2 step 9400 | 4142 batches | lr 0.000131 | ms/batch 613.45 | loss 2.58 | avg loss 2.56 | ppl 12.95
627
+ | epoch 2 step 9500 | 4242 batches | lr 0.00013 | ms/batch 611.51 | loss 2.40 | avg loss 2.54 | ppl 12.71
628
+ | epoch 2 step 9600 | 4342 batches | lr 0.000129 | ms/batch 613.03 | loss 2.62 | avg loss 2.53 | ppl 12.55
629
+ | epoch 2 step 9700 | 4442 batches | lr 0.000129 | ms/batch 612.45 | loss 2.26 | avg loss 2.54 | ppl 12.74
630
+ | epoch 2 step 9800 | 4542 batches | lr 0.000128 | ms/batch 610.95 | loss 2.78 | avg loss 2.55 | ppl 12.82
631
+ | epoch 2 step 9900 | 4642 batches | lr 0.000127 | ms/batch 608.32 | loss 2.61 | avg loss 2.52 | ppl 12.37
632
+ | epoch 2 step 10000 | 4742 batches | lr 0.000126 | ms/batch 610.72 | loss 2.45 | avg loss 2.54 | ppl 12.73
633
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.10000.pt
634
+ eval samples: 0 loss: tensor(1.0123, device='cuda:4')
635
+ eval samples: 100 loss: tensor(1.0022, device='cuda:4')
636
+ eval samples: 200 loss: tensor(1.0972, device='cuda:4')
637
+ eval samples: 300 loss: tensor(1.1317, device='cuda:4')
638
+ eval samples: 400 loss: tensor(1.5788, device='cuda:4')
639
+ eval samples: 500 loss: tensor(1.9430, device='cuda:4')
640
+ eval samples: 600 loss: tensor(1.0426, device='cuda:4')
641
+ eval samples: 700 loss: tensor(0.9720, device='cuda:4')
642
+ eval samples: 800 loss: tensor(1.4556, device='cuda:4')
643
+ eval samples: 900 loss: tensor(1.4790, device='cuda:4')
644
+ eval samples: 1000 loss: tensor(1.1323, device='cuda:4')
645
+ eval samples: 1100 loss: tensor(1.1691, device='cuda:4')
646
+ average loss 1.2222425683006033
647
+ ----------------------------------------------------------------------------------------------------
648
+ | Eval 5 at step 10000 | time: 139.05s | valid loss 1.22 | valid ppl 3.39 | best ppl 3.39
649
+ ----------------------------------------------------------------------------------------------------
650
+ | epoch 2 step 10100 | 4842 batches | lr 0.000126 | ms/batch 2003.85 | loss 2.46 | avg loss 2.55 | ppl 12.79
651
+ | epoch 2 step 10200 | 4942 batches | lr 0.000125 | ms/batch 609.56 | loss 2.62 | avg loss 2.56 | ppl 12.88
652
+ | epoch 2 step 10300 | 5042 batches | lr 0.000124 | ms/batch 610.36 | loss 2.85 | avg loss 2.51 | ppl 12.28
653
+ | epoch 2 step 10400 | 5142 batches | lr 0.000123 | ms/batch 610.63 | loss 2.40 | avg loss 2.57 | ppl 13.05
654
+ | epoch 2 step 10500 | 5242 batches | lr 0.000122 | ms/batch 613.64 | loss 2.43 | avg loss 2.52 | ppl 12.45
655
+ saving checkpoint ./trained_models/GPT2_M/e2e/model.10516.pt
656
+ start to train the model................ 3
657
+ | epoch 3 step 10600 | 84 batches | lr 0.000122 | ms/batch 510.61 | loss 2.63 | avg loss 2.53 | ppl 12.61
658
+ | epoch 3 step 10700 | 184 batches | lr 0.000121 | ms/batch 613.48 | loss 2.67 | avg loss 2.56 | ppl 13.00
659
+ | epoch 3 step 10800 | 284 batches | lr 0.00012 | ms/batch 608.43 | loss 2.48 | avg loss 2.52 | ppl 12.39
660
+ | epoch 3 step 10900 | 384 batches | lr 0.000119 | ms/batch 611.59 | loss 2.69 | avg loss 2.56 | ppl 12.91
661
+
662
+
663
+
664
+
665
+
666
+ Running MS-COCO evaluator...
667
+ creating index...
668
+ index created!
669
+ Loading and preparing results...
670
+ DONE (t=0.00s)
671
+ creating index...
672
+ index created!
673
+ tokenization...
674
+ PTBTokenizer tokenized 22530 tokens at 184928.37 tokens per second.
675
+ PTBTokenizer tokenized 2122 tokens at 21442.98 tokens per second.
676
+ setting up scorers...
677
+ computing METEOR score...
678
+ METEOR: 0.485
679
+ computing Rouge score...
680
+ ROUGE_L: 0.761
681
+ computing CIDEr score...
682
+ CIDEr: 3.314
683
+ Running Py-MTEval metrics...
684
+ SCORES:
685
+ ==============
686
+ BLEU: 0.7401
687
+ NIST: 8.6766
688
+ METEOR: 0.4851
689
+ ROUGE_L: 0.7614
690
+ CIDEr: 3.3144
691
+ === lora.Linear, model.5258.pt ===
692
+
693
+ BLEU: 0.7905
694
+ NIST: 9.1684
695
+ METEOR: 0.5016
696
+ ROUGE_L: 0.7865
697
+ CIDEr: 3.4686
698
+ === lora.MergedLinear, model.26290.pt ===
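
Note: the "valid ppl" column in the log above is simply the exponential of the averaged validation loss computed in evaluate() (gpt2_ft.py). A one-line check against the Eval 5 entry:

import math

avg_valid_loss = 1.2222425683006033       # "average loss" printed at Eval 5
print(f"{math.exp(avg_valid_loss):.2f}")  # 3.39, matching "valid ppl 3.39"
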
examples/NLG/src/model.py ADDED
@@ -0,0 +1,460 @@
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+ import logging
6
+ import math
7
+ import os
8
+ from collections import OrderedDict
9
+ import copy
10
+ import math
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss, MSELoss
15
+ import torch.nn.functional as F
16
+ from torch.optim import Optimizer
17
+ from torch.optim.lr_scheduler import LambdaLR
18
+ from torch.nn.parameter import Parameter
19
+
20
+ import loralib as lora
21
+
22
+
23
+ def gelu(x):
24
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
25
+
26
+
27
+ def gelu_fast(x):
28
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
29
+
30
+
31
+ def gelu_new(x):
32
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
33
+ Also see https://arxiv.org/abs/1606.08415
34
+ """
35
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
36
+
37
+
38
+ def swish(x):
39
+ return x * torch.sigmoid(x)
40
+
41
+
42
+ def _gelu_python(x):
43
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
44
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
45
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
46
+ This is now written in C in torch.nn.functional
47
+ Also see https://arxiv.org/abs/1606.08415
48
+ """
49
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
50
+
51
+
52
+ class LayerNorm(nn.Module):
53
+ def __init__(self, hidden_size, eps=1e-12):
54
+ """Construct a layernorm module in the TF style (epsilon inside the square root)."""
55
+ super(LayerNorm, self).__init__()
56
+ self.weight = nn.Parameter(torch.ones(hidden_size))
57
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
58
+ self.variance_epsilon = eps
59
+
60
+ def forward(self, x):
61
+ u = x.mean(-1, keepdim=True)
62
+ s = (x - u).pow(2).mean(-1, keepdim=True)
63
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
64
+ return self.weight * x + self.bias
65
+
66
+
67
+ class Conv1D(nn.Module):
68
+ def __init__(self, nf, nx):
69
+ super(Conv1D, self).__init__()
70
+ self.nf = nf
71
+ w = torch.empty(nx, nf)
72
+ nn.init.normal_(w, std=0.02)
73
+ self.weight = Parameter(w)
74
+ self.bias = Parameter(torch.zeros(nf))
75
+
76
+ def forward(self, x):
77
+ size_out = x.size()[:-1] + (self.nf,)
78
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
79
+ x = x.view(*size_out)
80
+ return x
81
+
82
+
83
+ class Attention(nn.Module):
84
+ def __init__(self, nx, n_ctx, config, scale=False):
85
+ super(Attention, self).__init__()
86
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
87
+ # [switch nx => n_state from Block to Attention to keep identical to TF implem]
88
+
89
+ assert n_state % config.n_head == 0
90
+ self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
91
+ self.n_head = config.n_head
92
+ self.split_size = n_state
93
+ self.scale = scale
94
+ self.c_attn = Conv1D(n_state * 3, nx)  # immediately overridden by the LoRA MergedLinear below
95
+ self.c_attn = lora.MergedLinear(
96
+ nx, n_state * 3,
97
+ r=config.lora_attn_dim,
98
+ lora_alpha=config.lora_attn_alpha,
99
+ lora_dropout=config.lora_dropout,
100
+ enable_lora=[True, False, True],
101
+ fan_in_fan_out=True,
102
+ merge_weights=False
103
+ )
104
+ # self.c_attn = lora.Linear(
105
+ # nx, n_state * 3,
106
+ # r=config.lora_attn_dim,
107
+ # lora_alpha=config.lora_attn_alpha,
108
+ # lora_dropout=config.lora_dropout,
109
+ # fan_in_fan_out=True,
110
+ # merge_weights=False
111
+ # )
112
+ print(f"scaling = {config.lora_attn_alpha / config.lora_attn_dim}")
113
+ self.c_proj = Conv1D(n_state, nx)
114
+
115
+ self.config = config
116
+
117
+ def _attn(self, q, k, v, len_kv=None):
118
+ w = torch.matmul(q, k)
119
+ if self.scale:
120
+ w = w / math.sqrt(v.size(-1))
121
+ nd, ns = w.size(-2), w.size(-1)
122
+ b = self.bias[:, :, ns-nd:ns, :ns]
123
+ w = w * b - 1e10 * (1 - b)
124
+
125
+ # q : (batch, head, q_seq_length, head_features)
126
+ # k : (batch, head, head_features, kv_seq_length)
127
+ # w : (batch, head, q_seq_length, kv_seq_length)
128
+ # v : (batch, head, kv_seq_length, head_features)
129
+ if len_kv is not None:
130
+ _len = torch.arange(k.size(-1), device=k.device)
131
+ _input_msk = _len[None, :] >= (len_kv)[:, None]
132
+ w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10)
133
+
134
+ w = nn.Softmax(dim=-1)(w)
135
+ return torch.matmul(w, v)
136
+
137
+ def merge_heads(self, x):
138
+ x = x.permute(0, 2, 1, 3).contiguous()
139
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
140
+ return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
141
+
142
+ def split_heads(self, x, k=False):
143
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
144
+ x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
145
+ if k:
146
+ return x.permute(0, 2, 3, 1).contiguous() # (batch, head, head_features, seq_length)
147
+ else:
148
+ return x.permute(0, 2, 1, 3).contiguous() # (batch, head, seq_length, head_features)
149
+
150
+    def forward(self, x, history=None, layer_past=None, len_past=None):
+        hidden_states = x
+
+        x = self.c_attn(x)
+        query, key, value = x.split(self.split_size, dim=2)
+
+        query = self.split_heads(query)
+        key = self.split_heads(key, k=True)
+        value = self.split_heads(value)
+
+        #_input_msk = None
+
+        len_kv = None
+
+        if layer_past is not None:
+            # key : (batch, head, head_features, seq_length)
+            # value : (batch, head, seq_length, head_features)
+            # layer_past, key : (batch, head, seq_length, head_features)
+            if len_past is None:
+                past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
+                key = torch.cat((past_key, key), dim=-1)
+                value = torch.cat((past_value, value), dim=-2)
+            else:
+                key_seq = key.shape[-1]
+                assert key_seq == 1
+
+                _batch = torch.arange(0, key.shape[0], dtype=torch.long, device=key.device)
+
+                past_key, past_value = layer_past[0], layer_past[1]
+
+                past_key[_batch, :, len_past, :] = key.squeeze(-1)
+                past_value[_batch, :, len_past, :] = value.squeeze(-2)
+
+                key = past_key.transpose(-2, -1)
+                value = past_value
+
+                len_kv = len_past + 1
+
+        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+        a = self._attn(query, key, value, len_kv=len_kv)
+        a = self.merge_heads(a)
+        a = self.c_proj(a)
+        # logging.info(f"attention forward: {a[0,0,:100]}, present: {present[0,0,0,:]}")
+        return a, present
+
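To see how the layer_past branch above is meant to be used, here is a minimal incremental-decoding sketch with illustrative sizes: prefill once, then feed one token at a time while reusing the stacked key/value cache returned as `present`.

import torch

config = GPT2Config(lora_attn_dim=4)
attn = Attention(config.n_embd, config.n_ctx, config, scale=True)

x = torch.randn(1, 5, config.n_embd)          # 5-token prefix
_, present = attn(x)                          # present: (2, batch, head, 5, head_features)

x_next = torch.randn(1, 1, config.n_embd)     # one new token
out, present = attn(x_next, layer_past=present)
assert out.shape == (1, 1, config.n_embd)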
+
+class MLP(nn.Module):
+    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+        super(MLP, self).__init__()
+        nx = config.n_embd
+        self.c_fc = Conv1D(n_state, nx)
+        self.c_proj = Conv1D(nx, n_state)
+        self.act = gelu
+
+    def forward(self, x):
+        h = self.act(self.c_fc(x))
+        h2 = self.c_proj(h)
+        return h2
+
+
+class Block(nn.Module):
+    def __init__(self, n_ctx, config, scale=False):
+        super(Block, self).__init__()
+        nx = config.n_embd
+        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
+        self.attn = Attention(nx, n_ctx, config, scale)
+        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
+        self.mlp = MLP(4 * nx, config)
+
+    def forward(self, x, layer_past=None, len_past=None):
+        a, present = self.attn(self.ln_1(x), layer_past=layer_past, len_past=len_past)
+        x = x + a
+        m = self.mlp(self.ln_2(x))
+        x = x + m
+        return x, present
+
+
+class GPT2Model(nn.Module):
+    def __init__(self, config):
+        super(GPT2Model, self).__init__()
+        self.n_layer = config.n_layer
+        self.n_embd = config.n_embd
+        self.n_vocab = config.vocab_size
+
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+        block = Block(config.n_ctx, config, scale=True)
+        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+        self.config = config
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        token_type_ids=None,
+        past=None,
+        len_past=None
+    ):
+        if past is None:
+            past_length = 0
+            past = [None] * len(self.h)
+        elif len_past is None:
+            # equal size for past. []
+            past_length = past[0][0].size(-2)
+
+        if position_ids is None and len_past is None:
+            position_ids = torch.arange(
+                past_length, input_ids.size(-1) + past_length,
+                dtype=torch.long, device=input_ids.device
+            )
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+        elif len_past is not None:
+            position_ids = (len_past).unsqueeze(1)  # .long()
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        position_ids = position_ids.view(-1, position_ids.size(-1))
+
+        inputs_embeds = self.wte(input_ids)
+
+        position_embeds = self.wpe(position_ids)
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
+            token_type_embeds = self.wte(token_type_ids)
+        else:
+            token_type_embeds = 0
+        hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        presents = []
+        for block, layer_past in zip(self.h, past):
+            hidden_states, present = block(hidden_states, layer_past=layer_past, len_past=len_past)
+            presents.append(present)
+        hidden_states = self.ln_f(hidden_states)
+        output_shape = input_shape + (hidden_states.size(-1),)
+        return hidden_states.view(*output_shape), presents
+
+
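The same caching pattern at the model level: forward() returns one `present` per block, and that list can be fed back as `past` on the next step. An illustrative sketch with a deliberately tiny configuration:

import torch

config = GPT2Config(n_layer=2, lora_attn_dim=4)   # tiny, illustrative
model = GPT2Model(config)

ids = torch.randint(0, config.vocab_size, (1, 8))
hidden, presents = model(ids)                     # hidden: (1, 8, n_embd)
next_id = torch.randint(0, config.vocab_size, (1, 1))
hidden_step, presents = model(next_id, past=presents)
assert hidden_step.shape == (1, 1, config.n_embd)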
+class GPT2LMHead(nn.Module):
+    def __init__(self, model_embeddings_weights, config):
+        super(GPT2LMHead, self).__init__()
+        self.n_embd = config.n_embd
+        self.set_embeddings_weights(model_embeddings_weights)
+
+    def set_embeddings_weights(self, model_embeddings_weights):
+        embed_shape = model_embeddings_weights.shape
+        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
+        self.decoder.weight = model_embeddings_weights  # Tied weights
+
+    def forward(self, hidden_state):
+        # Truncated Language modeling logits (we remove the last token)
+        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
+        lm_logits = self.decoder(hidden_state)
+        return lm_logits
+
+
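The LM head reuses the token-embedding matrix rather than owning its own weight, so the logits are effectively hidden_state @ wte.weight.T. A quick illustrative check (GPT2Config is defined just below; sizes are hypothetical):

import torch
import torch.nn as nn

emb = nn.Embedding(50257, 768)
head = GPT2LMHead(emb.weight, GPT2Config())
assert head.decoder.weight.data_ptr() == emb.weight.data_ptr()   # shared storage
assert head(torch.randn(1, 4, 768)).shape == (1, 4, 50257)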
+class GPT2Config(object):
+    def __init__(
+        self,
+        vocab_size_or_config_json_file=50257,
+        n_positions=1024,
+        n_ctx=1024,
+        n_embd=768,
+        n_layer=12,
+        n_head=12,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        lora_attn_dim=0,
+        lora_attn_alpha=128,
+        lora_dropout=0.0,
+        lora_r_dropout=0.0,
+        fix_dropout=0.0,
+    ):
+        self.vocab_size = vocab_size_or_config_json_file
+        self.n_ctx = n_ctx
+        self.n_positions = n_positions
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+        self.lora_attn_dim = lora_attn_dim
+        self.lora_attn_alpha = lora_attn_alpha
+        self.lora_dropout = lora_dropout
+        self.lora_r_dropout = lora_r_dropout
+
+        self.fix_dropout = fix_dropout
+
+
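For orientation, a hedged example of a GPT-2-medium-sized configuration with LoRA enabled on the attention projection. The hyperparameters actually used come from the training scripts, so the values below are only illustrative:

config = GPT2Config(
    n_embd=1024, n_layer=24, n_head=16,       # GPT-2 medium geometry
    lora_attn_dim=4,                          # r passed to lora.MergedLinear
    lora_attn_alpha=32,                       # scaling = lora_attn_alpha / lora_attn_dim = 8
    lora_dropout=0.1,
)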
+class GPT2LMModel(nn.Module):
+    def __init__(self, config):
+        super(GPT2LMModel, self).__init__()
+        self.transformer = GPT2Model(config)
+        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
+        self.apply(self._init_weights)
+
+    def set_tied(self):
+        """ Make sure we are sharing the embeddings"""
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+
+    def forward(
+        self,
+        input_ids,
+        lm_labels=None,
+        lm_mask=None,
+        past=None,
+        len_past=None,
+        label_smooth=0.0,
+        is_report_accuracy=False
+    ):
+        _batch, _len = input_ids.shape
+        hidden_states, presents = self.transformer(input_ids, past=past, len_past=len_past)
+
+        # batch, seq, vocab
+        lm_logits = self.lm_head(hidden_states)
+
+        if lm_labels is not None:
+
+            if is_report_accuracy:
+                _pred_token = torch.argmax(lm_logits, dim=-1)
+                _hit = (_pred_token == lm_labels) * lm_mask
+
+                _t1_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)
+                _all_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)
+
+                for _b in range(0, _batch):
+                    for _i in range(0, _len):
+                        if lm_mask[_b, _i] >= 1.0:
+                            if _hit[_b, _i] > 0:
+                                _t1_acc[_b] = 1.0
+                                break
+
+                    _is_succ = True
+                    for _i in range(0, _len):
+                        if lm_mask[_b, _i] >= 1.0:
+                            if _hit[_b, _i] <= 0:
+                                _is_succ = False
+                                break
+
+                    if _is_succ:
+                        _all_acc[_b] = 1.0
+
+                #_t1_acc = _t1_acc * 1.0 / _batch
+                #_all_acc = _all_acc * 1.0 / _batch
+
+            if label_smooth > 0.0001:
+                logprobs = torch.nn.functional.log_softmax(lm_logits.view(-1, lm_logits.size(-1)), dim=-1)
+                nll_loss = -logprobs.gather(dim=-1, index=lm_labels.view(-1).unsqueeze(1))
+                nll_loss = nll_loss.squeeze(1)
+                smooth_loss = -logprobs.mean(dim=-1)
+                loss = (1.0 - label_smooth) * nll_loss + label_smooth * smooth_loss
+                loss = loss.view(_batch, _len)
+            else:
+                loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
+                loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)).view(_batch, _len)
+
+            if lm_mask is None:
+                lm_mask = torch.ones(loss.shape, dtype=loss.dtype, device=loss.device)
+            loss = loss * lm_mask
+
+            loss = loss.sum() / (lm_mask.sum() + 0.0001)
+
+            if is_report_accuracy:
+                return lm_logits, loss, _t1_acc, _all_acc
+            else:
+                return lm_logits, loss
+        return lm_logits, presents
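A minimal sketch of calling forward() with labels and a loss mask (illustrative sizes and label smoothing; the real batches are built by the training script):

import torch

config = GPT2Config(n_layer=2, lora_attn_dim=4)
lm = GPT2LMModel(config)

ids = torch.randint(0, config.vocab_size, (2, 8))
mask = torch.ones(2, 8)
logits, loss = lm(ids, lm_labels=ids.clone(), lm_mask=mask, label_smooth=0.1)
assert logits.shape == (2, 8, config.vocab_size) and loss.dim() == 0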
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+    def load_weight(self, state_dict):
+        if 'model_state_dict' in state_dict:
+            state_dict = state_dict['model_state_dict']
+
+        state_dict_tmp = copy.deepcopy(state_dict)
+        old_keys = []
+        new_keys = []
+        for key in state_dict_tmp:
+            new_key = None
+            if key.endswith(".g"):
+                new_key = key[:-2] + ".weight"
+            elif key.endswith(".b"):
+                new_key = key[:-2] + ".bias"
+            elif key.endswith(".w"):
+                new_key = key[:-2] + ".weight"
+
+            if key.startswith("module.transformer."):
+                new_key = key[len("module.transformer."):]
+
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
+        for n, p in self.transformer.named_parameters():
+            if n not in state_dict:
+                state_dict[n] = p
+
+        self.transformer.load_state_dict(state_dict, strict=False)
+        self.set_tied()
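Putting the pieces together, a hedged end-to-end sketch of how this class can be wired up for LoRA fine-tuning. The checkpoint path, hyperparameters, and optimizer choice here are illustrative, not taken from this commit:

import torch
import loralib as lora

config = GPT2Config(n_embd=1024, n_layer=24, n_head=16,
                    lora_attn_dim=4, lora_attn_alpha=32, lora_dropout=0.1)
model = GPT2LMModel(config)

# load_weight() renames .g/.b/.w checkpoint keys and re-ties the LM head.
state = torch.load('gpt2-medium-pytorch_model.bin', map_location='cpu')   # illustrative path
model.load_weight(state)

lora.mark_only_lora_as_trainable(model)        # train only the LoRA factors
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=2e-4)

torch.save(lora.lora_state_dict(model), 'lora_only.pt')   # keep just the adapter weights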