upload
Browse files- README.md +6 -0
- config.json +24 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- train_script.py +398 -0
- train_steps.log +100 -0
README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DistilBERT with 256k token embeddings
|
2 |
+
|
3 |
+
This model was initialized with a word2vec token embedding matrix with 256k entries, but these token embeddings were updated during MLM. The word2vec was trained on 100GB data from C4, MSMARCO, News, Wikipedia, S2ORC, for 3 epochs.
|
4 |
+
|
5 |
+
Then the model was trained on this dataset with MLM for 250k steps (batch size 64). The token embeddings were updated during MLM.
|
6 |
+
|
config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/",
|
3 |
+
"activation": "gelu",
|
4 |
+
"architectures": [
|
5 |
+
"DistilBertForMaskedLM"
|
6 |
+
],
|
7 |
+
"attention_dropout": 0.1,
|
8 |
+
"dim": 768,
|
9 |
+
"dropout": 0.1,
|
10 |
+
"hidden_dim": 3072,
|
11 |
+
"initializer_range": 0.02,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "distilbert",
|
14 |
+
"n_heads": 12,
|
15 |
+
"n_layers": 6,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"qa_dropout": 0.1,
|
18 |
+
"seq_classif_dropout": 0.2,
|
19 |
+
"sinusoidal_pos_embds": false,
|
20 |
+
"tie_weights_": true,
|
21 |
+
"torch_dtype": "float32",
|
22 |
+
"transformers_version": "4.17.0",
|
23 |
+
"vocab_size": 256000
|
24 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3cfeeeaf234f887a7982d25adc2cd95c0bf0b6c2c22077413a445de3eef5b2b0
|
3 |
+
size 961553391
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"model_max_length": 512, "unk_token": "[UNK]", "cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]", "mask_token": "[MASK]", "model_input_names": ["input_ids", "attention_mask"], "special_tokens_map_file": "c4_msmarco_news_s2orc_wiki/tokenizer-256k/special_tokens_map.json", "name_or_path": "train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/", "tokenizer_class": "PreTrainedTokenizerFast"}
|
train_script.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import argparse
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
from datetime import datetime
|
7 |
+
import datasets
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
import sys
|
12 |
+
import transformers
|
13 |
+
from accelerate import Accelerator, DistributedType
|
14 |
+
from shutil import copyfile
|
15 |
+
import wandb
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from transformers import (
|
19 |
+
MODEL_MAPPING,
|
20 |
+
AutoModelForMaskedLM,
|
21 |
+
AutoTokenizer,
|
22 |
+
DataCollatorForLanguageModeling,
|
23 |
+
SchedulerType,
|
24 |
+
get_scheduler
|
25 |
+
)
|
26 |
+
from transformers.utils.versions import require_version
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
class TrainDataset(torch.utils.data.IterableDataset):
|
31 |
+
def __init__(self, filepath, tokenizer, max_length, batch_size, train_samples):
|
32 |
+
self.tokenizer = tokenizer
|
33 |
+
self.fIn = open(filepath)
|
34 |
+
self.max_length = max_length
|
35 |
+
self.batch_size = batch_size
|
36 |
+
self.train_samples = train_samples
|
37 |
+
|
38 |
+
def __iter__(self):
|
39 |
+
batch = []
|
40 |
+
for sent in self.fIn:
|
41 |
+
batch.append(sent.strip()[0:1000])
|
42 |
+
|
43 |
+
if len(batch) >= self.batch_size:
|
44 |
+
#Use multi process tokenization
|
45 |
+
encoded = self.tokenizer(batch, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True, padding=True)
|
46 |
+
#print(len(encoded['input_ids'][0]))
|
47 |
+
for idx in range(len(batch)):
|
48 |
+
single_sample = {key: encoded[key][idx] for key in encoded}
|
49 |
+
yield single_sample
|
50 |
+
|
51 |
+
batch = []
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return self.train_samples
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
## Dev dataset
|
61 |
+
class DevDataset(torch.utils.data.Dataset):
|
62 |
+
def __init__(self, filepath, tokenizer, max_length):
|
63 |
+
self.tokenizer = tokenizer
|
64 |
+
self.max_length = max_length
|
65 |
+
with open(filepath) as fIn:
|
66 |
+
sentences = [sent.strip() for sent in fIn]
|
67 |
+
|
68 |
+
self.num_sentences = len(sentences)
|
69 |
+
self.tokenized = self.tokenizer(sentences, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)
|
70 |
+
|
71 |
+
def __getitem__(self, idx):
|
72 |
+
return {key: self.tokenized[key][idx] for key in self.tokenized}
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return self.num_sentences
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
logger = logging.getLogger(__name__)
|
80 |
+
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
81 |
+
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
82 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
83 |
+
|
84 |
+
|
85 |
+
def parse_args():
|
86 |
+
parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task")
|
87 |
+
parser.add_argument(
|
88 |
+
"--dataset_config_name",
|
89 |
+
type=str,
|
90 |
+
default=None,
|
91 |
+
help="The configuration name of the dataset to use (via the datasets library).",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--train_file", type=str, default=None, help="A text file data (1 text per line).."
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--dev_file", type=str, default=None, help="A text file data (1 text per line)."
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--model_name",
|
101 |
+
default="nicoladecao/msmarco-word2vec256000-distilbert-base-uncased",
|
102 |
+
type=str,
|
103 |
+
help="Path to pretrained model or model identifier from huggingface.co/models."
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--per_device_batch_size",
|
107 |
+
type=int,
|
108 |
+
default=16,
|
109 |
+
help="Batch size (per device) for the training dataloader.",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--learning_rate",
|
113 |
+
type=float,
|
114 |
+
default=5e-5,
|
115 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
116 |
+
)
|
117 |
+
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
|
118 |
+
parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
|
119 |
+
parser.add_argument(
|
120 |
+
"--max_train_steps",
|
121 |
+
type=int,
|
122 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--gradient_accumulation_steps",
|
126 |
+
type=int,
|
127 |
+
default=1,
|
128 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--lr_scheduler_type",
|
132 |
+
type=SchedulerType,
|
133 |
+
default="linear",
|
134 |
+
help="The scheduler type to use.",
|
135 |
+
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--num_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--model_type",
|
142 |
+
type=str,
|
143 |
+
default=None,
|
144 |
+
help="Model type to use if training from scratch.",
|
145 |
+
choices=MODEL_TYPES,
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--max_seq_length",
|
149 |
+
type=int,
|
150 |
+
default=256,
|
151 |
+
help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--line_by_line",
|
155 |
+
type=bool,
|
156 |
+
default=True,
|
157 |
+
help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
|
164 |
+
)
|
165 |
+
parser.add_argument("--mixed_precision", default="fp16")
|
166 |
+
parser.add_argument("--train_samples", required=True, type=int)
|
167 |
+
parser.add_argument("--eval_steps", default=10000, type=int)
|
168 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float)
|
169 |
+
parser.add_argument("--project", default="bert-word2vec")
|
170 |
+
parser.add_argument("--freeze_emb_layer", default=False, action='store_true')
|
171 |
+
parser.add_argument("--log_interval", default=1000, type=int)
|
172 |
+
parser.add_argument("--ckp_steps", default=50000, type=int)
|
173 |
+
|
174 |
+
args = parser.parse_args()
|
175 |
+
|
176 |
+
|
177 |
+
return args
|
178 |
+
|
179 |
+
|
180 |
+
def main():
|
181 |
+
args = parse_args()
|
182 |
+
|
183 |
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
184 |
+
accelerator = Accelerator(mixed_precision=args.mixed_precision)
|
185 |
+
# Make one log on every process with the configuration for debugging.
|
186 |
+
logging.basicConfig(
|
187 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
188 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
189 |
+
level=logging.INFO,
|
190 |
+
)
|
191 |
+
logger.info(accelerator.state)
|
192 |
+
|
193 |
+
# Setup logging, we only want one process per machine to log things on the screen.
|
194 |
+
# accelerator.is_local_main_process is only True for one process per machine.
|
195 |
+
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
|
196 |
+
if accelerator.is_local_main_process:
|
197 |
+
datasets.utils.logging.set_verbosity_warning()
|
198 |
+
transformers.utils.logging.set_verbosity_info()
|
199 |
+
else:
|
200 |
+
datasets.utils.logging.set_verbosity_error()
|
201 |
+
transformers.utils.logging.set_verbosity_error()
|
202 |
+
|
203 |
+
|
204 |
+
accelerator.wait_for_everyone()
|
205 |
+
|
206 |
+
|
207 |
+
#Load model
|
208 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
209 |
+
model = AutoModelForMaskedLM.from_pretrained(args.model_name)
|
210 |
+
|
211 |
+
#Freeze emb layer
|
212 |
+
if args.freeze_emb_layer:
|
213 |
+
model.distilbert.embeddings.word_embeddings.requires_grad_(False)
|
214 |
+
|
215 |
+
# Logging & Co on main process
|
216 |
+
if accelerator.is_main_process:
|
217 |
+
exp_name = f'{args.model_name.replace("/", "-")}-{"freeze_emb" if args.freeze_emb_layer else "update_emb"}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
|
218 |
+
output_dir = os.path.join("output-mlm", exp_name)
|
219 |
+
wandb.init(project=args.project, name=exp_name, config=args)
|
220 |
+
|
221 |
+
os.makedirs(output_dir, exist_ok=False)
|
222 |
+
|
223 |
+
#Save tokenizer
|
224 |
+
tokenizer.save_pretrained(output_dir)
|
225 |
+
|
226 |
+
#Save train script
|
227 |
+
train_script_path = os.path.join(output_dir, 'train_script.py')
|
228 |
+
copyfile(__file__, train_script_path)
|
229 |
+
with open(train_script_path, 'a') as fOut:
|
230 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
231 |
+
|
232 |
+
|
233 |
+
total_batch_size = args.per_device_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
234 |
+
|
235 |
+
train_dataset = TrainDataset(args.train_file, tokenizer, args.max_seq_length, batch_size=total_batch_size, train_samples=args.train_samples)
|
236 |
+
eval_dataset = DevDataset(args.dev_file, tokenizer, args.max_seq_length)
|
237 |
+
|
238 |
+
|
239 |
+
# Data collator
|
240 |
+
# This one will take care of randomly masking the tokens.
|
241 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability)
|
242 |
+
|
243 |
+
# DataLoaders creation:
|
244 |
+
train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)
|
245 |
+
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)
|
246 |
+
|
247 |
+
# Optimizer
|
248 |
+
# Split weights in two groups, one with weight decay and the other not.
|
249 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
250 |
+
optimizer_grouped_parameters = [
|
251 |
+
{
|
252 |
+
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
253 |
+
"weight_decay": args.weight_decay,
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
257 |
+
"weight_decay": 0.0,
|
258 |
+
},
|
259 |
+
]
|
260 |
+
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
261 |
+
|
262 |
+
# Prepare everything with our `accelerator`.
|
263 |
+
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)
|
264 |
+
|
265 |
+
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
266 |
+
if accelerator.distributed_type == DistributedType.TPU:
|
267 |
+
model.tie_weights()
|
268 |
+
|
269 |
+
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
270 |
+
# shorter in multiprocess)
|
271 |
+
|
272 |
+
# Scheduler and math around the number of training steps.
|
273 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
274 |
+
if args.max_train_steps is None:
|
275 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
276 |
+
else:
|
277 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
278 |
+
|
279 |
+
lr_scheduler = get_scheduler(
|
280 |
+
name=args.lr_scheduler_type,
|
281 |
+
optimizer=optimizer,
|
282 |
+
num_warmup_steps=args.num_warmup_steps,
|
283 |
+
num_training_steps=args.max_train_steps,
|
284 |
+
)
|
285 |
+
|
286 |
+
|
287 |
+
# Train!
|
288 |
+
logger.info("***** Running training *****")
|
289 |
+
logger.info(f" Num examples = {args.train_samples}")
|
290 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
291 |
+
logger.info(f" Instantaneous batch size per device = {args.per_device_batch_size}")
|
292 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
293 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
294 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
295 |
+
# Only show the progress bar once on each machine.
|
296 |
+
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, smoothing=0.05)
|
297 |
+
completed_steps = 0
|
298 |
+
train_loss_values = []
|
299 |
+
|
300 |
+
best_eval_loss = 999999
|
301 |
+
if accelerator.is_main_process:
|
302 |
+
best_ckp_dir = os.path.join(output_dir, "best")
|
303 |
+
tokenizer.save_pretrained(best_ckp_dir)
|
304 |
+
|
305 |
+
for epoch in range(args.num_train_epochs):
|
306 |
+
logger.info(f"Start epoch {epoch}")
|
307 |
+
model.train()
|
308 |
+
for step, batch in enumerate(train_dataloader):
|
309 |
+
outputs = model(**batch)
|
310 |
+
loss = outputs.loss
|
311 |
+
loss = loss / args.gradient_accumulation_steps
|
312 |
+
|
313 |
+
if accelerator.is_main_process:
|
314 |
+
train_loss_values.append(loss.cpu().item())
|
315 |
+
|
316 |
+
accelerator.backward(loss)
|
317 |
+
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
318 |
+
if step % args.gradient_accumulation_steps == 0:
|
319 |
+
optimizer.step()
|
320 |
+
lr_scheduler.step()
|
321 |
+
optimizer.zero_grad()
|
322 |
+
progress_bar.update(1)
|
323 |
+
completed_steps += 1
|
324 |
+
|
325 |
+
### Do logging
|
326 |
+
if accelerator.is_main_process:
|
327 |
+
if completed_steps % args.log_interval == 0:
|
328 |
+
wandb.log({"train/loss": np.mean(train_loss_values)}, step=completed_steps)
|
329 |
+
train_loss_values = []
|
330 |
+
|
331 |
+
|
332 |
+
if completed_steps % args.eval_steps == 0:
|
333 |
+
model.eval()
|
334 |
+
losses = []
|
335 |
+
for step, batch in enumerate(eval_dataloader):
|
336 |
+
with torch.no_grad():
|
337 |
+
outputs = model(**batch)
|
338 |
+
|
339 |
+
loss = outputs.loss
|
340 |
+
losses.append(accelerator.gather(loss.repeat(args.per_device_batch_size)))
|
341 |
+
|
342 |
+
losses = torch.cat(losses)
|
343 |
+
losses = losses[: len(eval_dataset)]
|
344 |
+
try:
|
345 |
+
eval_loss = torch.mean(losses)
|
346 |
+
except OverflowError:
|
347 |
+
eval_loss = float("inf")
|
348 |
+
|
349 |
+
logger.info(f"step {completed_steps}: perplexity: {eval_loss}")
|
350 |
+
if accelerator.is_main_process:
|
351 |
+
wandb.log({"eval/loss": eval_loss}, step=completed_steps)
|
352 |
+
|
353 |
+
model.train()
|
354 |
+
|
355 |
+
#Save model
|
356 |
+
accelerator.wait_for_everyone()
|
357 |
+
if accelerator.is_main_process:
|
358 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
359 |
+
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
|
360 |
+
with open(os.path.join(output_dir, "train_steps.log"), 'a') as fOut:
|
361 |
+
fOut.write(f"{completed_steps}: {eval_loss}\n")
|
362 |
+
|
363 |
+
#Save best model
|
364 |
+
if eval_loss < best_eval_loss:
|
365 |
+
best_eval_loss = eval_loss
|
366 |
+
unwrapped_model.save_pretrained(best_ckp_dir, save_function=accelerator.save)
|
367 |
+
with open(os.path.join(best_ckp_dir, "train_steps.log"), 'a') as fOut:
|
368 |
+
fOut.write(f"{completed_steps}: {eval_loss}\n")
|
369 |
+
|
370 |
+
if accelerator.is_main_process and completed_steps % args.ckp_steps == 0:
|
371 |
+
ckp_dir = os.path.join(output_dir, f"ckp-{int(completed_steps/1000)}k")
|
372 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
373 |
+
unwrapped_model.save_pretrained(ckp_dir, save_function=accelerator.save)
|
374 |
+
tokenizer.save_pretrained(ckp_dir)
|
375 |
+
with open(os.path.join(ckp_dir, "train_steps.log"), 'a') as fOut:
|
376 |
+
fOut.write(f"{completed_steps}: {eval_loss}\n")
|
377 |
+
|
378 |
+
|
379 |
+
if completed_steps >= args.max_train_steps:
|
380 |
+
break
|
381 |
+
|
382 |
+
if args.output_dir is not None:
|
383 |
+
accelerator.wait_for_everyone()
|
384 |
+
if accelerator.is_main_process:
|
385 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
386 |
+
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
|
387 |
+
with open(os.path.join(output_dir, "train_steps.log"), 'a') as fOut:
|
388 |
+
fOut.write(f"{completed_steps}\n")
|
389 |
+
|
390 |
+
|
391 |
+
|
392 |
+
|
393 |
+
if __name__ == "__main__":
|
394 |
+
main()
|
395 |
+
|
396 |
+
|
397 |
+
# Script was called via:
|
398 |
+
#python train_mlm-iterable.py --train_file data/c4_msmarco_news_s2orc_wiki_train.txt --dev_file data/c4_msmarco_news_s2orc_wiki_dev.txt --train_samples 100000000 --model_name train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/
|
train_steps.log
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
10000: 2.9510703086853027
|
2 |
+
20000: 2.6985881328582764
|
3 |
+
30000: 2.6190617084503174
|
4 |
+
40000: 2.554388999938965
|
5 |
+
50000: 2.5151987075805664
|
6 |
+
60000: 2.4846017360687256
|
7 |
+
70000: 2.456629514694214
|
8 |
+
80000: 2.4532623291015625
|
9 |
+
90000: 2.4096949100494385
|
10 |
+
100000: 2.390449047088623
|
11 |
+
110000: 2.379936695098877
|
12 |
+
120000: 2.3351876735687256
|
13 |
+
130000: 2.360316276550293
|
14 |
+
140000: 2.343963623046875
|
15 |
+
150000: 2.343773365020752
|
16 |
+
160000: 2.326107978820801
|
17 |
+
170000: 2.308845043182373
|
18 |
+
180000: 2.319910764694214
|
19 |
+
190000: 2.269522190093994
|
20 |
+
200000: 2.306370735168457
|
21 |
+
210000: 2.256661891937256
|
22 |
+
220000: 2.285386562347412
|
23 |
+
230000: 2.2685086727142334
|
24 |
+
240000: 2.2510366439819336
|
25 |
+
250000: 2.280529737472534
|
26 |
+
260000: 2.2705845832824707
|
27 |
+
270000: 2.2605068683624268
|
28 |
+
280000: 2.24812388420105
|
29 |
+
290000: 2.2429370880126953
|
30 |
+
300000: 2.249992609024048
|
31 |
+
310000: 2.245483636856079
|
32 |
+
320000: 2.2558467388153076
|
33 |
+
330000: 2.2274234294891357
|
34 |
+
340000: 2.2205779552459717
|
35 |
+
350000: 2.220641851425171
|
36 |
+
360000: 2.224605083465576
|
37 |
+
370000: 2.2072439193725586
|
38 |
+
380000: 2.207751750946045
|
39 |
+
390000: 2.187239170074463
|
40 |
+
400000: 2.1820247173309326
|
41 |
+
410000: 2.191732168197632
|
42 |
+
420000: 2.1815905570983887
|
43 |
+
430000: 2.188547372817993
|
44 |
+
440000: 2.1916229724884033
|
45 |
+
450000: 2.175816774368286
|
46 |
+
460000: 2.188457489013672
|
47 |
+
470000: 2.1955676078796387
|
48 |
+
480000: 2.1778831481933594
|
49 |
+
490000: 2.188725233078003
|
50 |
+
500000: 2.186866044998169
|
51 |
+
510000: 2.165159225463867
|
52 |
+
520000: 2.179119825363159
|
53 |
+
530000: 2.1783058643341064
|
54 |
+
540000: 2.1477503776550293
|
55 |
+
550000: 2.183104991912842
|
56 |
+
560000: 2.1740524768829346
|
57 |
+
570000: 2.1488537788391113
|
58 |
+
580000: 2.1921725273132324
|
59 |
+
590000: 2.1152608394622803
|
60 |
+
600000: 2.134315252304077
|
61 |
+
610000: 2.1524155139923096
|
62 |
+
620000: 2.1292426586151123
|
63 |
+
630000: 2.125551462173462
|
64 |
+
640000: 2.167358875274658
|
65 |
+
650000: 2.146320343017578
|
66 |
+
660000: 2.144334077835083
|
67 |
+
670000: 2.141312599182129
|
68 |
+
680000: 2.11755108833313
|
69 |
+
690000: 2.1170132160186768
|
70 |
+
700000: 2.1400225162506104
|
71 |
+
710000: 2.0839574337005615
|
72 |
+
720000: 2.1461777687072754
|
73 |
+
730000: 2.1319420337677
|
74 |
+
740000: 2.127019166946411
|
75 |
+
750000: 2.13002347946167
|
76 |
+
760000: 2.1098318099975586
|
77 |
+
770000: 2.12105131149292
|
78 |
+
780000: 2.1280393600463867
|
79 |
+
790000: 2.1270456314086914
|
80 |
+
800000: 2.092860221862793
|
81 |
+
810000: 2.1097354888916016
|
82 |
+
820000: 2.122591495513916
|
83 |
+
830000: 2.0946450233459473
|
84 |
+
840000: 2.0991997718811035
|
85 |
+
850000: 2.084550619125366
|
86 |
+
860000: 2.1101889610290527
|
87 |
+
870000: 2.0824830532073975
|
88 |
+
880000: 2.0871477127075195
|
89 |
+
890000: 2.086862802505493
|
90 |
+
900000: 2.0873560905456543
|
91 |
+
910000: 2.0829179286956787
|
92 |
+
920000: 2.1007256507873535
|
93 |
+
930000: 2.0886971950531006
|
94 |
+
940000: 2.0912179946899414
|
95 |
+
950000: 2.0809004306793213
|
96 |
+
960000: 2.0886764526367188
|
97 |
+
970000: 2.0796029567718506
|
98 |
+
980000: 2.049144983291626
|
99 |
+
990000: 2.0718026161193848
|
100 |
+
1000000: 2.082963705062866
|