PBusienei committed on
Commit
0b249e3
1 Parent(s): e0734d9

added train script

Files changed (1)
  1. train_script.py +361 -0
train_script.py ADDED
@@ -0,0 +1,361 @@
"""
Train script for a single file

Need to set the TPU address first:
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
"""

import torch.multiprocessing as mp
import threading
import time
import random
import sys
import argparse
import gzip
import json
import logging
import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch_xla
import torch_xla.core
import torch_xla.core.functions
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import os
from shutil import copyfile


from transformers import (
    AdamW,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
)

class AutoModelForSentenceEmbedding(nn.Module):
    def __init__(self, model_name, tokenizer, args):
        super(AutoModelForSentenceEmbedding, self).__init__()

        assert args.pooling in ['mean', 'cls']

        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = not args.no_normalize
        self.tokenizer = tokenizer
        self.pooling = args.pooling

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        if self.pooling == 'mean':
            embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
        elif self.pooling == 'cls':
            embeddings = self.cls_pooling(model_output, kwargs['attention_mask'])

        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def cls_pooling(self, model_output, attention_mask):
        # Embedding of the first ([CLS]) token
        return model_output[0][:, 0]

    def save_pretrained(self, output_path):
        if xm.is_master_ordinal():
            self.tokenizer.save_pretrained(output_path)
            self.model.config.save_pretrained(output_path)

        xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))


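# Illustrative sketch (not used by the training loop): how the wrapper above can be used
# on CPU to embed a few sentences. The model name mirrors the --model default further
# below; the Namespace fields and the example sentences are assumptions for this sketch.
def _example_encode_on_cpu(model_name='nreimers/MiniLM-L6-H384-uncased'):
    example_args = argparse.Namespace(pooling='mean', no_normalize=False)
    example_tokenizer = AutoTokenizer.from_pretrained(model_name)
    example_model = AutoModelForSentenceEmbedding(model_name, example_tokenizer, example_args)
    encoded = example_tokenizer(["How do I train a sentence embedding model?",
                                 "Training sentence embeddings on TPU"],
                                return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        embeddings = example_model(**encoded)   # shape (2, hidden_size), L2-normalized
    return embeddings

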
def train_function(index, args, queue):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSentenceEmbedding(args.model, tokenizer, args)


    ### Train Loop
    device = xm.xla_device()
    model = model.to(device)

    # Instantiate optimizer
    optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=args.steps,
    )

    # Now we train the model
    cross_entropy_loss = nn.CrossEntropyLoss()
    max_grad_norm = 1

    model.train()

    for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
        #### Get the batch data
        batch = queue.get()
        #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))


        if len(batch[0]) == 2:  #(anchor, positive) pairs
            text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length_a, truncation=True, padding="max_length")
            text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length_b, truncation=True, padding="max_length")

            ### Compute embeddings
            embeddings_a = model(**text1.to(device))
            embeddings_b = model(**text2.to(device))

            ### Gather all embeddings from the other TPU cores
            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)

            ### Compute similarity scores 512 x 512
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device)  # Example a[i] should match with b[i]

            ## Symmetric loss as in CLIP
            loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2

        else:  #(anchor, positive, negative) triplets
            text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length_a, truncation=True, padding="max_length")
            text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length_b, truncation=True, padding="max_length")
            text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length_b, truncation=True, padding="max_length")

            embeddings_a = model(**text1.to(device))
            embeddings_b1 = model(**text2.to(device))
            embeddings_b2 = model(**text3.to(device))

            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
            embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)

            embeddings_b = torch.cat([embeddings_b1, embeddings_b2])

            ### Compute similarity scores 512 x 1024
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device)  # Example a[i] should match with b[i]

            ## One-way loss
            loss = cross_entropy_loss(scores, labels)


        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        xm.optimizer_step(optimizer, barrier=True)
        lr_scheduler.step()


        # Save model
        if (global_step+1) % args.save_steps == 0:
            output_path = os.path.join(args.output, str(global_step+1))
            xm.master_print("save model: "+output_path)
            model.save_pretrained(output_path)


    output_path = os.path.join(args.output, "final")
    xm.master_print("save model final: "+ output_path)
    model.save_pretrained(output_path)


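# Illustrative sketch (defined but never called): the in-batch-negative objective used in
# train_function on a toy example. With n gathered (anchor, positive) pairs, scores is an
# n x n matrix and entry [i][i] should be the largest in row i, hence labels = 0..n-1.
# The sizes and the scale value below are assumptions chosen for readability.
def _example_in_batch_negatives(n=4, dim=8, scale=20.0):
    cross_entropy = nn.CrossEntropyLoss()
    embeddings_a = torch.nn.functional.normalize(torch.randn(n, dim), p=2, dim=1)
    embeddings_b = torch.nn.functional.normalize(torch.randn(n, dim), p=2, dim=1)
    scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * scale  # n x n cosine similarities
    labels = torch.arange(n, dtype=torch.long)                             # a[i] should match b[i]
    # Symmetric loss as in CLIP: classify a -> b and b -> a
    return (cross_entropy(scores, labels) + cross_entropy(scores.transpose(0, 1), labels)) / 2

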
def produce_data(args, queue, filepaths, dataset_indices):
    global_batch_size = args.batch_size*args.nprocs     #Global batch size
    num_same_dataset = int(args.nprocs / args.datasets_per_batch)
    print("producer", "global_batch_size", global_batch_size)
    print("producer", "num_same_dataset", num_same_dataset)

    datasets = []
    for filepath in filepaths:
        if "reddit_" in filepath:       #Special dataset class for Reddit files
            data_obj = RedditDataset(filepath)
        else:
            data_obj = Dataset(filepath)
        datasets.append(iter(data_obj))

    # Store if dataset is in a 2 col or 3 col format
    num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}

    while True:
        texts_in_batch = set()
        batch_format = None     #2 vs 3 col format for this batch

        #Add data from several sub datasets
        for _ in range(args.datasets_per_batch):
            valid_dataset = False   #Check that datasets have the same 2/3 col format
            while not valid_dataset:
                data_idx = random.choice(dataset_indices)
                if batch_format is None:
                    batch_format = num_cols[data_idx]
                    valid_dataset = True
                else:   #Check that this dataset has the same format
                    valid_dataset = (batch_format == num_cols[data_idx])

            #Get data from this dataset
            dataset = datasets[data_idx]
            local_batch_size = args.batch_size
            if batch_format == 3 and args.batch_size_triplets is not None:
                local_batch_size = args.batch_size_triplets

            for _ in range(num_same_dataset):
                for _ in range(args.nprocs):
                    batch_device = []   #A batch for one device
                    while len(batch_device) < local_batch_size:
                        sample = next(dataset)
                        in_batch = False
                        for text in sample:
                            if text in texts_in_batch:
                                in_batch = True
                                break

                        if not in_batch:
                            for text in sample:
                                texts_in_batch.add(text)
                            batch_device.append(sample)

                    queue.put(batch_device)


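# Note on sizes (derived from the code above, added for clarity): each queue entry is the
# batch for a single TPU core (local_batch_size samples). At every training step each of the
# nprocs workers pops one entry, and after all_gather the pool of in-batch negatives holds
# batch_size * nprocs samples -- 64 * 8 = 512 with the defaults, which is where the
# "512 x 512" score matrix in train_function comes from. The texts_in_batch set avoids
# repeating a text within one producer round, so in-batch negatives are not accidental
# duplicates of a positive.

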
class RedditDataset:
    """
    A class that handles the reddit data files
    """
    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        while True:
            with gzip.open(self.filepath, "rt") as fIn:
                for line in fIn:
                    data = json.loads(line)

                    if "response" in data and "context" in data:
                        yield [data["response"], data["context"]]


class Dataset:
    """
    A class that handles one dataset
    """
    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        max_dataset_size = 20*1000*1000     #Cache small datasets in memory
        dataset = []
        data_format = None

        while dataset is None or len(dataset) == 0:
            with gzip.open(self.filepath, "rt") as fIn:
                for line in fIn:
                    data = json.loads(line)
                    if isinstance(data, dict):
                        data = data['texts']

                    if data_format is None:
                        data_format = len(data)

                    #Ensure that all entries are of the same 2/3 col format
                    assert len(data) == data_format

                    if dataset is not None:
                        dataset.append(data)
                        if len(dataset) >= max_dataset_size:
                            dataset = None

                    yield data

        # Data loaded. Now stream to the queue
        # Shuffle for each epoch
        while True:
            random.shuffle(dataset)
            for data in dataset:
                yield data


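# Illustrative note on the expected data files (implied by the readers above; the example
# contents are made up): every file is gzip-compressed with one JSON document per line.
# For Dataset, a line is either a list of 2 or 3 texts, or a dict with a 'texts' key, e.g.
#   ["how do I flash my BIOS?", "Hold F2 during boot, then ..."]
#   {"texts": ["anchor text", "positive text", "hard negative text"]}
# For RedditDataset (filenames containing 'reddit_'), a line is a dict with
# 'response' and 'context' keys, e.g.
#   {"context": "What editor do you use?", "response": "Mostly vim with a few plugins."}

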
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
    parser.add_argument('--steps', type=int, default=2000)
    parser.add_argument('--save_steps', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--batch_size_triplets', type=int, default=None)
    parser.add_argument('--max_length_a', type=int, default=128)
    parser.add_argument('--max_length_b', type=int, default=128)
    parser.add_argument('--nprocs', type=int, default=8)
    parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
    parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
    parser.add_argument('--no_normalize', action="store_true", default=False, help="If set: Embeddings are not normalized")
    parser.add_argument('--pooling', default='mean')
    parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
    parser.add_argument('data_config', help="A data_config.json file")
    parser.add_argument('output')
    args = parser.parse_args()

    # Ensure nprocs is divisible by datasets_per_batch
    assert (args.nprocs % args.datasets_per_batch) == 0


    logging.info("Output: "+args.output)
    if os.path.exists(args.output):
        print("Output folder already exists.")
        input("Continue?")

    # Write train script to output path
    os.makedirs(args.output, exist_ok=True)

    data_config_path = os.path.join(args.output, 'data_config.json')
    copyfile(args.data_config, data_config_path)

    train_script_path = os.path.join(args.output, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))



    #Load data config
    with open(args.data_config) as fIn:
        data_config = json.load(fIn)

    queue = mp.Queue(maxsize=100*args.nprocs)

    filepaths = []
    dataset_indices = []
    for idx, data in enumerate(data_config):
        filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
        dataset_indices.extend([idx]*data['weight'])

    # Start producer
    p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
    p.start()

    # Run training
    print("Start processes:", args.nprocs)
    xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
    print("Training done")
    print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
    print("With 'pkill python' you can kill all remaining python processes")
    p.kill()
    exit()




# Script was called via:
#python train_many_data_files_v2.py --steps 200000 --batch_size 128 --model nreimers/MiniLM-L6-H384-uncased --max_length_a 64 --max_length_b 250 train_data_configs/multi-qa_v1.json output/multi-qa_v1-MiniLM-L6-mean_cos
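
# Illustrative example of the expected data_config.json (inferred from the loop over
# data_config above; the filenames are made up). Each entry names a gzip file inside
# --data_folder and an integer 'weight' that controls how often that dataset is sampled:
# [
#     {"name": "stackexchange_title_body.jsonl.gz", "weight": 2},
#     {"name": "reddit_2019.jsonl.gz", "weight": 1}
# ]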