sooks committed on
Commit 270ff3e · 1 Parent(s): 3118b47

Create train.py

Files changed (1): detector/train.py +305 -0
detector/train.py ADDED
@@ -0,0 +1,305 @@
"""Training code for the detector model"""

import argparse
import os
import subprocess
import sys
from itertools import count
from multiprocessing import Process

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from tqdm import tqdm
from transformers import RobertaForSequenceClassification, RobertaTokenizer, tokenization_utils

from .dataset import Corpus, EncodedDataset
from .download import download
from .utils import summary, distributed


def setup_distributed(port=29500):
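    """Initialize torch.distributed when more than one GPU is available.

    Returns (rank, world_size); single-GPU and CPU runs get (0, 1) with no
    process group. Under an MPI launcher, rank and size come from mpi4py.
    """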
    if not dist.is_available() or not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
        return 0, 1

    if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ:
        # launched by an MPI runner (MPICH sets this variable): take rank/size from mpi4py
        from mpi4py import MPI
        mpi_rank = MPI.COMM_WORLD.Get_rank()
        mpi_size = MPI.COMM_WORLD.Get_size()

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(port)

        dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
        return mpi_rank, mpi_size

    dist.init_process_group(backend="nccl", init_method="env://")
    return dist.get_rank(), dist.get_world_size()


def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
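    """Download any missing corpora and build the train/validation DataLoaders.

    The special names 'TWO' and 'THREE' mix several GPT-2 XL sampling
    strategies as the fake set; the real corpus is duplicated to stay balanced.
    """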
    if fake_dataset == 'TWO':
        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
    elif fake_dataset == 'THREE':
        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
    else:
        download(real_dataset, fake_dataset, data_dir=data_dir)

    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset == "TWO":
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    elif fake_dataset == "THREE":
        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in
                        ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    else:
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)

        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None
    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
                                   epoch_size, token_dropout, seed)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader


def accuracy_sum(logits, labels):
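    """Count correct predictions, for either 2-logit or single-logit outputs."""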
    if list(logits.shape) == list(labels.shape) + [2]:
        # 2-d outputs
        classification = (logits[..., 0] < logits[..., 1]).long().flatten()
    else:
        classification = (logits > 0).long().flatten()
    assert classification.shape == labels.shape
    return (classification == labels).float().sum().item()


def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
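    """Train for one epoch; returns summed accuracy, example count, and loss."""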
    model.train()

    train_accuracy = 0
    train_epoch_size = 0
    train_loss = 0

    with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
        for texts, masks, labels in loop:

            texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
            batch_size = texts.shape[0]

            optimizer.zero_grad()
            # assumes a transformers version where the model returns a
            # (loss, logits) tuple when labels are passed (return_dict=False)
            loss, logits = model(texts, attention_mask=masks, labels=labels)
            loss.backward()
            optimizer.step()

            batch_accuracy = accuracy_sum(logits, labels)
            train_accuracy += batch_accuracy
            train_epoch_size += batch_size
            train_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)

    return {
        "train/accuracy": train_accuracy,
        "train/epoch_size": train_epoch_size,
        "train/loss": train_loss
    }


def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
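    """Evaluate the model. With votes > 1, losses and logits are averaged
    across repeated passes over the loader before scoring.
    """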
    model.eval()

    validation_accuracy = 0
    validation_epoch_size = 0
    validation_loss = 0

    records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
                                                               disable=distributed() and dist.get_rank() > 0)]
    records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]

    with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
        for example in loop:
            losses = []
            logit_votes = []

            for texts, masks, labels in example:
                texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
                batch_size = texts.shape[0]

                loss, logits = model(texts, attention_mask=masks, labels=labels)
                losses.append(loss)
                logit_votes.append(logits)

            loss = torch.stack(losses).mean(dim=0)
            logits = torch.stack(logit_votes).mean(dim=0)

            batch_accuracy = accuracy_sum(logits, labels)
            validation_accuracy += batch_accuracy
            validation_epoch_size += batch_size
            validation_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)

    return {
        "validation/accuracy": validation_accuracy,
        "validation/epoch_size": validation_epoch_size,
        "validation/loss": validation_loss
    }


def _all_reduce_dict(d, device):
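    """Sum each metric across ranks so every process sees global totals."""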
    if not distributed():
        # single-process run: nothing to reduce
        return dict(d)
    # wrap each value in a tensor and all-reduce so every rank gets the sum
    output_d = {}
    for (key, value) in sorted(d.items()):
        tensor_input = torch.tensor([[value]]).to(device)
        torch.distributed.all_reduce(tensor_input)
        output_d[key] = tensor_input.item()
    return output_d


def run(max_epochs=None,
        device=None,
        batch_size=24,
        max_sequence_length=128,
        random_sequence_length=False,
        epoch_size=None,
        seed=None,
        data_dir='data',
        real_dataset='webtext',
        fake_dataset='xl-1542M-nucleus',
        token_dropout=None,
        large=False,
        learning_rate=2e-5,
        weight_decay=0,
        **kwargs):
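    """Fine-tune RoBERTa to classify real vs. model-generated text, optionally
    with one DistributedDataParallel process per GPU, checkpointing the best
    model by validation accuracy under $OPENAI_LOGDIR (default 'logs').
    """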
    args = locals()
    rank, world_size = setup_distributed()

    if device is None:
        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

    print('rank:', rank, 'world_size:', world_size, 'device:', device)

    # let rank 0 download the pretrained weights first; other ranks wait here
    if distributed() and rank > 0:
        dist.barrier()

    model_name = 'roberta-large' if large else 'roberta-base'
    tokenization_utils.logger.setLevel('ERROR')
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)

    if rank == 0:
        summary(model)
        if distributed():
            dist.barrier()

    if world_size > 1:
        model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)

    train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                                                    max_sequence_length, random_sequence_length, epoch_size,
                                                    token_dropout, seed)

    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)

    logdir = os.environ.get("OPENAI_LOGDIR", "logs")
    os.makedirs(logdir, exist_ok=True)

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(logdir) if rank == 0 else None
    best_validation_accuracy = 0

    for epoch in epoch_loop:
        if world_size > 1:
            train_loader.sampler.set_epoch(epoch)
            validation_loader.sampler.set_epoch(epoch)

        train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
        validation_metrics = validate(model, device, validation_loader)

        combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)

        combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
        combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
        combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
        combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]

        if rank == 0:
            for key, value in combined_metrics.items():
                writer.add_scalar(key, value, global_step=epoch)

            if combined_metrics["validation/accuracy"] > best_validation_accuracy:
                best_validation_accuracy = combined_metrics["validation/accuracy"]

                model_to_save = model.module if hasattr(model, 'module') else model
                torch.save(dict(
                    epoch=epoch,
                    model_state_dict=model_to_save.state_dict(),
                    optimizer_state_dict=optimizer.state_dict(),
                    args=args
                ),
                    os.path.join(logdir, "best-model.pt")
                )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--max-epochs', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=24)
    parser.add_argument('--max-sequence-length', type=int, default=128)
    parser.add_argument('--random-sequence-length', action='store_true')
    parser.add_argument('--epoch-size', type=int, default=None)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--real-dataset', type=str, default='webtext')
    parser.add_argument('--fake-dataset', type=str, default='xl-1542M-k40')
    parser.add_argument('--token-dropout', type=float, default=None)

    parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
    parser.add_argument('--learning-rate', type=float, default=2e-5)
    parser.add_argument('--weight-decay', type=float, default=0)
    args = parser.parse_args()
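
    # Probe the GPU count in a child interpreter so this parent process never
    # initializes CUDA; CUDA cannot be re-initialized in forked workers.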
    nproc = int(subprocess.check_output([sys.executable, '-c', "import torch;"
                                         "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
    if nproc > 1:
        print(f'Launching {nproc} processes ...', file=sys.stderr)

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(29500)
        os.environ['WORLD_SIZE'] = str(nproc)
        os.environ['OMP_NUM_THREADS'] = str(1)
        subprocesses = []

        for i in range(nproc):
            os.environ['RANK'] = str(i)
            os.environ['LOCAL_RANK'] = str(i)
            process = Process(target=run, kwargs=vars(args))
            process.start()
            subprocesses.append(process)

        for process in subprocesses:
            process.join()
    else:
        run(**vars(args))
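
# Typical invocation (run as a module so the package-relative imports resolve):
#   python -m detector.train --max-epochs 2 --batch-size 24 --fake-dataset xl-1542M-k40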