nreimers commited on
Commit
81804c7
1 Parent(s): 7b02ded
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ ---
8
+
9
+ # stackoverflow_mpnet-base
10
+
11
+ This is a microsoft/mpnet-base model trained on 18,562,443 (title, body) pairs from StackOverflow. See data_config.json and train_script.py in this respository how the model was trained and which datasets have been used.
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "microsoft/mpnet-base",
3
+ "architectures": [
4
+ "MPNetForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "eos_token_id": 2,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "mpnet",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "relative_attention_num_buckets": 32,
21
+ "transformers_version": "4.8.2",
22
+ "vocab_size": 30527
23
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.6.1",
5
+ "pytorch": "1.8.1"
6
+ }
7
+ }
data_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "stackexchange_title_body/stackoverflow.com-Posts.jsonl.gz",
4
+ "lines": 18562443,
5
+ "weight": 100
6
+ }
7
+ ]
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dcdca5dfbca9becac708f520fa11b36e7c615bec56e033f1566eb8745e2d9d3
3
+ size 438011953
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ {
2
+ "max_seq_length": 512,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"do_lower_case": true, "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "[UNK]", "pad_token": "<pad>", "mask_token": "<mask>", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "microsoft/mpnet-base", "tokenizer_class": "MPNetTokenizer"}
train_script.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
+ class AutoModelForSentenceEmbedding(nn.Module):
41
+ def __init__(self, model_name, tokenizer, normalize=True):
42
+ super(AutoModelForSentenceEmbedding, self).__init__()
43
+
44
+ self.model = AutoModel.from_pretrained(model_name)
45
+ self.normalize = normalize
46
+ self.tokenizer = tokenizer
47
+
48
+ def forward(self, **kwargs):
49
+ model_output = self.model(**kwargs)
50
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
51
+ if self.normalize:
52
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
53
+
54
+ return embeddings
55
+
56
+ def mean_pooling(self, model_output, attention_mask):
57
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
58
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
59
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
60
+
61
+ def save_pretrained(self, output_path):
62
+ if xm.is_master_ordinal():
63
+ self.tokenizer.save_pretrained(output_path)
64
+ self.model.config.save_pretrained(output_path)
65
+
66
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
67
+
68
+
69
+
70
+
71
+ def train_function(index, args, queues, dataset_indices):
72
+ dataset_rnd = random.Random(index % args.data_word_size) #Defines which dataset to use in every step
73
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
74
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer)
75
+
76
+
77
+ ### Train Loop
78
+ device = xm.xla_device()
79
+ model = model.to(device)
80
+
81
+ # Instantiate optimizer
82
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
83
+
84
+ lr_scheduler = get_linear_schedule_with_warmup(
85
+ optimizer=optimizer,
86
+ num_warmup_steps=500,
87
+ num_training_steps=args.steps,
88
+ )
89
+
90
+ # Now we train the model
91
+ cross_entropy_loss = nn.CrossEntropyLoss()
92
+ max_grad_norm = 1
93
+
94
+ model.train()
95
+
96
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
97
+
98
+ #### Get the batch data
99
+ dataset_idx = dataset_rnd.choice(dataset_indices)
100
+ text1 = []
101
+ text2 = []
102
+ for _ in range(args.batch_size):
103
+ example = queues[dataset_idx].get()
104
+ text1.append(example[0])
105
+ text2.append(example[1])
106
+
107
+ #print(index, f"dataset {dataset_idx}", text1[0:3])
108
+
109
+ text1 = tokenizer(text1, return_tensors="pt", max_length=128, truncation=True, padding="max_length")
110
+ text2 = tokenizer(text2, return_tensors="pt", max_length=128, truncation=True, padding="max_length")
111
+
112
+ ### Compute embeddings
113
+ #print(index, "compute embeddings")
114
+ embeddings_a = model(**text1.to(device))
115
+ embeddings_b = model(**text2.to(device))
116
+
117
+
118
+ ### Gather all embedings
119
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
120
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
121
+
122
+ ### Compute similarity scores
123
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
124
+
125
+ ### Compute cross-entropy loss
126
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
127
+
128
+ ## One-way loss
129
+ #loss = cross_entropy_loss(scores, labels)
130
+
131
+ ## Symmetric loss as in CLIP
132
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
133
+
134
+ # Backward pass
135
+ optimizer.zero_grad()
136
+ loss.backward()
137
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
138
+
139
+ xm.optimizer_step(optimizer, barrier=True)
140
+ lr_scheduler.step()
141
+
142
+ #Save model
143
+ if (global_step+1) % args.save_steps == 0:
144
+ output_path = os.path.join(args.output, str(global_step+1))
145
+ xm.master_print("save model: "+output_path)
146
+ model.save_pretrained(output_path)
147
+
148
+
149
+ output_path = os.path.join(args.output)
150
+ xm.master_print("save model final: "+ output_path)
151
+ model.save_pretrained(output_path)
152
+
153
+
154
+
155
+ def load_data(path, queue):
156
+ dataset = []
157
+
158
+ with gzip.open(path, "rt") as fIn:
159
+ for line in fIn:
160
+ data = json.loads(line)
161
+ if isinstance(data, dict):
162
+ data = data['texts']
163
+
164
+ #Only use two columns
165
+ dataset.append(data[0:2])
166
+ queue.put(data[0:2])
167
+
168
+ # Data loaded. Now stream to the queue
169
+ # Shuffle for each epoch
170
+ while True:
171
+ random.shuffle(dataset)
172
+ for data in dataset:
173
+ queue.put(data)
174
+
175
+
176
+ if __name__ == "__main__":
177
+ parser = argparse.ArgumentParser()
178
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
179
+ parser.add_argument('--steps', type=int, default=2000)
180
+ parser.add_argument('--save_steps', type=int, default=10000)
181
+ parser.add_argument('--batch_size', type=int, default=32)
182
+ parser.add_argument('--nprocs', type=int, default=8)
183
+ parser.add_argument('--data_word_size', type=int, default=2, help="How many different dataset should be included in every train step. Cannot be larger than nprocs")
184
+ parser.add_argument('--scale', type=float, default=20)
185
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
186
+ parser.add_argument('data_config', help="A data_config.json file")
187
+ parser.add_argument('output')
188
+ args = parser.parse_args()
189
+
190
+ logging.info("Output: "+args.output)
191
+ if os.path.exists(args.output):
192
+ print("Output folder already exists. Exit!")
193
+ exit()
194
+
195
+ # Write train script to output path
196
+ os.makedirs(args.output, exist_ok=True)
197
+
198
+ data_config_path = os.path.join(args.output, 'data_config.json')
199
+ copyfile(args.data_config, data_config_path)
200
+
201
+ train_script_path = os.path.join(args.output, 'train_script.py')
202
+ copyfile(__file__, train_script_path)
203
+ with open(train_script_path, 'a') as fOut:
204
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
205
+
206
+
207
+
208
+ #Load data config
209
+ with open(args.data_config) as fIn:
210
+ data_config = json.load(fIn)
211
+
212
+ threads = []
213
+ queues = []
214
+ dataset_indices = []
215
+ for data in data_config:
216
+ data_idx = len(queues)
217
+ queue = mp.Queue(maxsize=args.nprocs*args.batch_size)
218
+ th = threading.Thread(target=load_data, daemon=True, args=(os.path.join(os.path.expanduser(args.data_folder), data['name']), queue))
219
+ th.start()
220
+ threads.append(th)
221
+ queues.append(queue)
222
+ dataset_indices.extend([data_idx]*data['weight'])
223
+
224
+
225
+ print("Start processes:", args.nprocs)
226
+ xmp.spawn(train_function, args=(args, queues, dataset_indices), nprocs=args.nprocs, start_method='fork')
227
+ print("Training done")
228
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
229
+ print("With 'pkill python' you can kill all remaining python processes")
230
+ exit()
231
+
232
+
233
+
234
+ # Script was called via:
235
+ #python train_many_data_files.py --steps 100000 --batch_size 64 --model microsoft/mpnet-base train_data_configs/stackoverflow.json output/stackoverflow_mpnet-base
vocab.txt ADDED
The diff for this file is too large to render. See raw diff