nreimers's picture
upload
8f180fd
raw
history blame
4.2 kB
"""
This file runs Masked Language Model. You provide a training file. Each line is interpreted as a sentence / paragraph.
Optionally, you can also provide a dev file.
The fine-tuned model is stored in the output/model_name folder.
python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt]
"""
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask
from transformers import Trainer, TrainingArguments
import sys
import gzip
from datetime import datetime
import wandb
wandb.init(project="bert-word2vec")

model_name = "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased"
per_device_train_batch_size = 16
save_steps = 5000           # Save a checkpoint every N training steps
eval_steps = 1000           # Evaluate on the dev set every N steps (only if a dev set exists)
num_train_epochs = 3
use_fp16 = True             # Set to True, if your GPU supports FP16 operations
max_length = 250            # Max length (in tokens) for a text input; longer inputs are truncated
do_whole_word_mask = True   # If True, all sub-tokens of a selected word are masked together
# Probability that a token (or a whole word, with whole-word masking) is selected
# for MLM prediction. The data collators expect a value in [0, 1], e.g. 0.15 for 15%.
mlm_prob = 0.15

# Fail fast on a misconfigured probability instead of crashing (or masking
# everything) deep inside the data collator once training has started.
if not 0.0 <= mlm_prob <= 1.0:
    raise ValueError(f"mlm_prob must be a probability in [0, 1], got {mlm_prob}")

model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

## Freeze embedding layer
# Assigning `module.requires_grad = False` only creates a plain Python attribute on
# the nn.Module and leaves its parameters trainable. `requires_grad_(False)` sets
# the flag on every parameter of the module recursively, which actually freezes it.
model.distilbert.embeddings.requires_grad_(False)

output_dir = "output/{}-{}".format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
print("Save checkpoints to:", output_dir)
##### Load our training datasets
def _read_sentences(path, min_length=10):
    """Read one sentence / paragraph per line from a plain or gzip-compressed text file.

    Args:
        path: file path; a name ending in ``.gz`` is opened with gzip in text mode.
        min_length: lines (after stripping surrounding whitespace) with fewer
            characters than this are skipped, which also drops empty lines.

    Returns:
        List of the stripped lines that passed the length filter, in file order.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rt', encoding='utf8') as fIn:
        return [line for line in (raw.strip() for raw in fIn) if len(line) >= min_length]


train_path = 'data/train.txt'
train_sentences = _read_sentences(train_path)
print("Train sentences:", len(train_sentences))

dev_path = 'data/dev.txt'
try:
    dev_sentences = _read_sentences(dev_path)
except FileNotFoundError:
    # The dev file is optional: without it, training runs without evaluation
    # (an empty dev list leads to no eval dataset further down).
    print("No dev file found at:", dev_path)
    dev_sentences = []
print("Dev sentences:", len(dev_sentences))
# A dataset wrapper that tokenizes sentences lazily, when an item is requested
class TokenizedSentencesDataset:
    """Map-style dataset (implements ``__getitem__``/``__len__``) over raw sentences.

    Args:
        sentences: list of raw text entries. When ``cache_tokenization`` is True,
            this list is modified in place: an entry that is still a ``str`` is
            replaced by its tokenizer output the first time it is accessed.
        tokenizer: callable invoked as ``tokenizer(text, add_special_tokens=True,
            truncation=True, max_length=..., return_special_tokens_mask=True)``.
        max_length: truncation length forwarded to the tokenizer.
        cache_tokenization: if True, keep tokenized entries so that each sentence
            is tokenized at most once; otherwise tokenize on every access.
    """

    def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.max_length = max_length
        self.cache_tokenization = cache_tokenization

    def _encode(self, text):
        """Tokenize ``text`` with special tokens, truncation and a special-tokens mask."""
        return self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            return_special_tokens_mask=True,
        )

    def __getitem__(self, item):
        entry = self.sentences[item]
        if not self.cache_tokenization:
            return self._encode(entry)
        # Caching mode: only raw strings still need tokenizing; anything else
        # has already been replaced by its cached tokenizer output.
        if isinstance(entry, str):
            entry = self._encode(entry)
            self.sentences[item] = entry
        return entry

    def __len__(self):
        return len(self.sentences)
##### Wrap the sentences as datasets; tokenization happens lazily on access
train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length)
if len(dev_sentences) > 0:
    # The dev set is evaluated repeatedly (every eval_steps), so cache its tokenization.
    dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length, cache_tokenization=True)
else:
    dev_dataset = None

##### Training arguments
# Pick the masking collator: whole-word masking vs. plain per-token masking.
collator_class = DataCollatorForWholeWordMask if do_whole_word_mask else DataCollatorForLanguageModeling
data_collator = collator_class(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)

# Periodic evaluation is only enabled when a dev set is available.
evaluation_mode = "no" if dev_dataset is None else "steps"
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    evaluation_strategy=evaluation_mode,
    per_device_train_batch_size=per_device_train_batch_size,
    eval_steps=eval_steps,
    save_steps=save_steps,
    save_total_limit=1,          # limit the number of checkpoints kept on disk
    prediction_loss_only=True,   # evaluation only needs the MLM loss
    fp16=use_fp16,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)
trainer.train()

# Persist the final model and its tokenizer to output_dir.
print("Save model to:", output_dir)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print("Training done")