nreimers committed
Commit 77b49f5 · 1 Parent(s): 6dd993a
Files changed (1)
  1. train_mlm.py +129 -0
train_mlm.py ADDED
@@ -0,0 +1,129 @@
+ """
+ This script fine-tunes a model with Masked Language Modeling (MLM). You provide a training file; each line is interpreted as a sentence / paragraph.
+ Optionally, you can also provide a dev file.
+
+ The fine-tuned model is stored in the output-mlm/<model_name>-<timestamp> folder.
+
+ python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt]
+
+ Note: in this version the model name and the train/dev file paths are set directly in the script below.
+ """
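+
+ # Illustrative example of the expected training / dev file format (one sentence or paragraph per line; example content only):
+ #   The quick brown fox jumps over the lazy dog.
+ #   Paris is the capital and most populous city of France.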
+
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask
+ from transformers import Trainer, TrainingArguments
+ import sys
+ import os
+ import gzip
+ from datetime import datetime
+ import wandb
+ from shutil import copyfile
+
+ wandb.init(project="bert-word2vec")
+
+ model_name = "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased"
+ per_device_train_batch_size = 16
+ save_steps = 5000
+ eval_steps = 1000
+ num_train_epochs = 3
+ use_fp16 = True             # Set to True if your GPU supports FP16 operations
+ max_length = 250            # Max length for a text input
+ do_whole_word_mask = False  # If set to True, whole words are masked
+ mlm_prob = 0.15             # Probability that a token is replaced by a [MASK] token
+
+
+ model = AutoModelForMaskedLM.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ ## Freeze embedding layer
+ #model.distilbert.embeddings.requires_grad = False
+ model.distilbert.embeddings.word_embeddings.requires_grad_(False)
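+ # Note: only the input word embeddings are frozen here; the transformer layers and the MLM head remain trainable.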
+
+ output_dir = "output-mlm/{}-{}".format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+ print("Save checkpoints to:", output_dir)
+
+ # Write a copy of this script to the output path
+ os.makedirs(output_dir, exist_ok=True)
+
+ train_script_path = os.path.join(output_dir, 'train_script.py')
+ copyfile(__file__, train_script_path)
+ with open(train_script_path, 'a') as fOut:
+     fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+ ##### Load our training datasets
+
+ train_sentences = []
+ train_path = 'data/train.txt'
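+ # The file may be plain text or gzip-compressed (.gz); lines shorter than 10 characters are skipped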
+ with gzip.open(train_path, 'rt', encoding='utf8') if train_path.endswith('.gz') else open(train_path, 'r', encoding='utf8') as fIn:
+     for line in fIn:
+         line = line.strip()
+         if len(line) >= 10:
+             train_sentences.append(line)
+
+ print("Train sentences:", len(train_sentences))
+
+ dev_sentences = []
+
+ dev_path = 'data/dev.txt'
+ with gzip.open(dev_path, 'rt', encoding='utf8') if dev_path.endswith('.gz') else open(dev_path, 'r', encoding='utf8') as fIn:
+     for line in fIn:
+         line = line.strip()
+         if len(line) >= 10:
+             dev_sentences.append(line)
+
+ print("Dev sentences:", len(dev_sentences))
+
+ # A dataset wrapper that tokenizes our data on-the-fly
+ class TokenizedSentencesDataset:
+     def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False):
+         self.tokenizer = tokenizer
+         self.sentences = sentences
+         self.max_length = max_length
+         self.cache_tokenization = cache_tokenization
+
+     def __getitem__(self, item):
+         if not self.cache_tokenization:
+             return self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)
+
+         # With caching enabled, replace the raw string with its tokenized form the first time it is requested
+         if isinstance(self.sentences[item], str):
+             self.sentences[item] = self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)
+         return self.sentences[item]
+
+     def __len__(self):
+         return len(self.sentences)
+
+ train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length)
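+ # For the dev set, cache_tokenization=True so the sentences are tokenized only once and reused at every evaluation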
+ dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length, cache_tokenization=True) if len(dev_sentences) > 0 else None
+
+
+ ##### Training arguments
+
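+ # Both collators mask tokens dynamically in every batch: each token is selected with probability mlm_prob and,
+ # following BERT, a selected token is replaced by [MASK] 80% of the time, by a random token 10% of the time,
+ # and left unchanged 10% of the time.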
+ if do_whole_word_mask:
+     data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)
+ else:
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)
+
+ training_args = TrainingArguments(
+     output_dir=output_dir,
+     overwrite_output_dir=True,
+     num_train_epochs=num_train_epochs,
+     evaluation_strategy="steps" if dev_dataset is not None else "no",
+     per_device_train_batch_size=per_device_train_batch_size,
+     eval_steps=eval_steps,
+     save_steps=save_steps,
+     save_total_limit=1,
+     prediction_loss_only=True,
+     fp16=use_fp16
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     data_collator=data_collator,
+     train_dataset=train_dataset,
+     eval_dataset=dev_dataset
+ )
+
+ trainer.train()
+
+ print("Save model to:", output_dir)
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ print("Training done")
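
For reference, a minimal sketch of how the saved checkpoint could be used afterwards for masked-token prediction. checkpoint_dir is a placeholder for whatever path the script printed as output_dir, and the fill-mask pipeline is one possible way to query the model, not part of the training script above.

from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

# Placeholder: use the folder printed as "Save model to: ..." during training
checkpoint_dir = "output-mlm/nicoladecao_msmarco-word2vec256000-distilbert-base-uncased-<timestamp>"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForMaskedLM.from_pretrained(checkpoint_dir)

# Predict the most likely tokens for a masked position
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
print(fill_mask("Paris is the " + tokenizer.mask_token + " of France."))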