Create train.py
detector/train.py
ADDED
@@ -0,0 +1,305 @@
"""Training code for the detector model"""

import argparse
import os
import subprocess
import sys
from itertools import count
from multiprocessing import Process

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from tqdm import tqdm
from transformers import *  # provides RobertaTokenizer, RobertaForSequenceClassification and tokenization_utils used below

from .dataset import Corpus, EncodedDataset
from .download import download
from .utils import summary, distributed

def setup_distributed(port=29500):
    if not dist.is_available() or not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
        return 0, 1

    if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ:
        from mpi4py import MPI
        mpi_rank = MPI.COMM_WORLD.Get_rank()
        mpi_size = MPI.COMM_WORLD.Get_size()

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(port)

        dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
        return mpi_rank, mpi_size

    dist.init_process_group(backend="nccl", init_method="env://")
    return dist.get_rank(), dist.get_world_size()

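# Note: the env:// branch above expects MASTER_ADDR, MASTER_PORT, RANK and
# WORLD_SIZE to be present in the environment; the __main__ block at the bottom
# of this file sets them before spawning one worker per GPU. The MPI branch
# covers launchers that export MPIR_CVAR_CH3_INTERFACE_HOSTNAME instead.
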
def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
    if fake_dataset == 'TWO':
        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
    elif fake_dataset == 'THREE':
        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
    else:
        download(real_dataset, fake_dataset, data_dir=data_dir)

    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset == "TWO":
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    elif fake_dataset == "THREE":
        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in
                        ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    else:
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)

        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None
    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
                                   epoch_size, token_dropout, seed)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader

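# The special fake_dataset values 'TWO' and 'THREE' mix several GPT-2 output
# sets; the real corpus is repeated 2x or 3x above so that real and fake
# examples stay roughly balanced against the concatenated fake corpora.
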
def accuracy_sum(logits, labels):
    if list(logits.shape) == list(labels.shape) + [2]:
        # 2-d outputs
        classification = (logits[..., 0] < logits[..., 1]).long().flatten()
    else:
        classification = (logits > 0).long().flatten()
    assert classification.shape == labels.shape
    return (classification == labels).float().sum().item()

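# Illustrative example: for two-column logits the prediction is the argmax of
# the last dimension, so
#   accuracy_sum(torch.tensor([[0.2, 1.3], [2.0, -1.0]]), torch.tensor([1, 0]))
# returns 2.0 (both predictions agree with their labels).
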
def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
    model.train()

    train_accuracy = 0
    train_epoch_size = 0
    train_loss = 0

    with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
        for texts, masks, labels in loop:

            texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
            batch_size = texts.shape[0]

            optimizer.zero_grad()
            # older transformers versions return a (loss, logits) tuple when labels are passed
            loss, logits = model(texts, attention_mask=masks, labels=labels)
            loss.backward()
            optimizer.step()

            batch_accuracy = accuracy_sum(logits, labels)
            train_accuracy += batch_accuracy
            train_epoch_size += batch_size
            train_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)

    return {
        "train/accuracy": train_accuracy,
        "train/epoch_size": train_epoch_size,
        "train/loss": train_loss
    }

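# train() and validate() both return raw sums (correct predictions, examples
# seen, loss * batch_size); run() reduces these across workers and only then
# divides by epoch_size to report averages.
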
def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
    model.eval()

    validation_accuracy = 0
    validation_epoch_size = 0
    validation_loss = 0

    records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
                                                               disable=distributed() and dist.get_rank() > 0)]
    records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]

    with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
        for example in loop:
            losses = []
            logit_votes = []

            for texts, masks, labels in example:
                texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
                batch_size = texts.shape[0]

                loss, logits = model(texts, attention_mask=masks, labels=labels)
                losses.append(loss)
                logit_votes.append(logits)

            loss = torch.stack(losses).mean(dim=0)
            logits = torch.stack(logit_votes).mean(dim=0)

            batch_accuracy = accuracy_sum(logits, labels)
            validation_accuracy += batch_accuracy
            validation_epoch_size += batch_size
            validation_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)

    return {
        "validation/accuracy": validation_accuracy,
        "validation/epoch_size": validation_epoch_size,
        "validation/loss": validation_loss
    }

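# With votes > 1 the loader is iterated several times and the per-example
# losses and logits are averaged across passes before scoring; the default of
# a single vote reduces to an ordinary validation sweep.
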
def _all_reduce_dict(d, device):
    # wrap each value in a tensor and sum it across workers;
    # single-process runs have no process group, so just copy the dict
    if not distributed():
        return dict(d)
    output_d = {}
    for (key, value) in sorted(d.items()):
        tensor_input = torch.tensor([[value]]).to(device)
        torch.distributed.all_reduce(tensor_input)
        output_d[key] = tensor_input.item()
    return output_d

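# Example: with two workers reporting {"train/loss": 10.0} and
# {"train/loss": 14.0}, every rank ends up with {"train/loss": 24.0} (the
# default all_reduce op is a sum); run() then divides by the reduced
# epoch_size to turn these sums into global averages.
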
def run(max_epochs=None,
        device=None,
        batch_size=24,
        max_sequence_length=128,
        random_sequence_length=False,
        epoch_size=None,
        seed=None,
        data_dir='data',
        real_dataset='webtext',
        fake_dataset='xl-1542M-nucleus',
        token_dropout=None,
        large=False,
        learning_rate=2e-5,
        weight_decay=0,
        **kwargs):
    args = locals()
    rank, world_size = setup_distributed()

    if device is None:
        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

    print('rank:', rank, 'world_size:', world_size, 'device:', device)

    # let rank 0 download the datasets and pretrained weights first;
    # the other ranks wait at the barrier and then read from the cache
    if distributed() and rank > 0:
        dist.barrier()

    model_name = 'roberta-large' if large else 'roberta-base'
    tokenization_utils.logger.setLevel('ERROR')
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)

    if rank == 0:
        summary(model)
        if distributed():
            dist.barrier()

    if world_size > 1:
        model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)

    train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                                                    max_sequence_length, random_sequence_length, epoch_size,
                                                    token_dropout, seed)

    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)

    logdir = os.environ.get("OPENAI_LOGDIR", "logs")
    os.makedirs(logdir, exist_ok=True)

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(logdir) if rank == 0 else None
    best_validation_accuracy = 0

    for epoch in epoch_loop:
        if world_size > 1:
            train_loader.sampler.set_epoch(epoch)
            validation_loader.sampler.set_epoch(epoch)

        train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
        validation_metrics = validate(model, device, validation_loader)

        combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)

        combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
        combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
        combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
        combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]

        if rank == 0:
            for key, value in combined_metrics.items():
                writer.add_scalar(key, value, global_step=epoch)

            if combined_metrics["validation/accuracy"] > best_validation_accuracy:
                best_validation_accuracy = combined_metrics["validation/accuracy"]

                model_to_save = model.module if hasattr(model, 'module') else model
                torch.save(dict(
                        epoch=epoch,
                        model_state_dict=model_to_save.state_dict(),
                        optimizer_state_dict=optimizer.state_dict(),
                        args=args
                    ),
                    os.path.join(logdir, "best-model.pt")
                )

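# The "best-model.pt" checkpoint written above is a plain dict with the keys
# 'epoch', 'model_state_dict', 'optimizer_state_dict' and 'args'; to reuse the
# classifier elsewhere, load 'model_state_dict' into a fresh
# RobertaForSequenceClassification of the same size.
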
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--max-epochs', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=24)
    parser.add_argument('--max-sequence-length', type=int, default=128)
    parser.add_argument('--random-sequence-length', action='store_true')
    parser.add_argument('--epoch-size', type=int, default=None)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--real-dataset', type=str, default='webtext')
    parser.add_argument('--fake-dataset', type=str, default='xl-1542M-k40')
    parser.add_argument('--token-dropout', type=float, default=None)

    parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
    parser.add_argument('--learning-rate', type=float, default=2e-5)
    parser.add_argument('--weight-decay', type=float, default=0)
    args = parser.parse_args()

    nproc = int(subprocess.check_output([sys.executable, '-c', "import torch;"
                "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
    if nproc > 1:
        print(f'Launching {nproc} processes ...', file=sys.stderr)

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(29500)
        os.environ['WORLD_SIZE'] = str(nproc)
        os.environ['OMP_NUM_THREADS'] = str(1)  # keep each worker to a single OpenMP thread
        subprocesses = []

        # one worker process per GPU; each child picks up its rank from the environment
        for i in range(nproc):
            os.environ['RANK'] = str(i)
            os.environ['LOCAL_RANK'] = str(i)
            process = Process(target=run, kwargs=vars(args))
            process.start()
            subprocesses.append(process)

        for process in subprocesses:
            process.join()
    else:
        run(**vars(args))
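# Example invocation (a sketch, assuming this module lives in a `detector`
# package as the relative imports above imply; the flags are the ones defined
# in the parser):
#
#   python -m detector.train --max-epochs 5 --batch-size 24 \
#       --real-dataset webtext --fake-dataset xl-1542M-k40 --large
#
# On a multi-GPU machine the __main__ block spawns one training process per
# GPU and wraps the model in DistributedDataParallel; the best checkpoint is
# written to $OPENAI_LOGDIR (default "logs") as best-model.pt.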