import os
from dataclasses import replace

import jax
import wandb
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
from datasets import load_dataset
from flax import jax_utils
from transformers import BigBirdTokenizerFast
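
# Hypothetical invocation (the actual entry-point file name may differ); data
# paths and hyperparameters come from the defaults in Args or a wandb sweep:
#   TRAIN_ON_SMALL=true python train.py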


if __name__ == "__main__":
    print("#################### AVAILABLE DEVICES ####################")
    print(jax.devices())
    print("###########################################################")
    args = Args()
    logger = wandb.init(project="bigbird-natural-questions", config=args.__dict__)
    wandb_args = dict(logger.config)
    del wandb_args["batch_size"]
    args = replace(args, **wandb_args)
    base_dir = args.base_dir + "-" + wandb.run.id
    args = replace(args, base_dir=base_dir)
    print(args)

    tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"]
    val_dataset = load_dataset("json", data_files=args.val_data_path)["train"]
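
    # shuffle, then drop trailing examples so each split's length is an exact
    # multiple of batch_size (avoids a partial final batch)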
    indices = range(len(tr_dataset) - len(tr_dataset) % args.batch_size)
    tr_dataset = tr_dataset.shuffle().select(indices)
    indices = range(len(val_dataset) - len(val_dataset) % args.batch_size)
    val_dataset = val_dataset.shuffle().select(indices)
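
    # optionally restrict to a smaller subset for quick debugging runs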
    if os.environ.get("TRAIN_ON_SMALL", "false") == "true":
        tr_dataset = tr_dataset.shuffle().select(range(80000))
        val_dataset = val_dataset.shuffle().select(range(8000))

    print(tr_dataset)
    print(val_dataset)
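
    # block_size and num_random_blocks configure BigBird's block-sparse
    # attention (see BigBirdConfig in transformers)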
    model = FlaxBigBirdForNaturalQuestions.from_pretrained(
        args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks
    )
    tokenizer = BigBirdTokenizerFast.from_pretrained(args.model_id)
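    # pad every example to BigBird's 4096-token maximum so batch shapes stay
    # fixed (fixed shapes avoid repeated XLA recompilation under jit/pmap)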
    data_collator = DataCollator(pad_id=tokenizer.pad_token_id, max_length=4096)
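
    # optimizer and learning-rate schedule; judging by the argument names,
    # build_tx presumably warms the LR up from init_lr to lr over warmup_steps
    # and applies weight decay across num_train_steps total updates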
    tx_args = {
        "lr": args.lr,
        "init_lr": args.init_lr,
        "warmup_steps": args.warmup_steps,
        "num_train_steps": args.max_epochs * (len(tr_dataset) // args.batch_size),
        "weight_decay": args.weight_decay,
    }
    tx, lr = build_tx(**tx_args)
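
    # wire everything together; train_step / val_step come from bigbird_flax
    # and are applied by the Trainer (presumably via jax.pmap, given the
    # replicated state handled below)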
    trainer = Trainer(
        args=args,
        data_collator=data_collator,
        model_save_fn=model.save_pretrained,
        train_step_fn=train_step,
        val_step_fn=val_step,
        logger=logger,
        scheduler_fn=lr,
    )
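
    # ckpt_dir is None for a fresh run; point it at a saved checkpoint
    # directory to (presumably) resume training from that state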
    ckpt_dir = None
    state = trainer.create_state(model, tx, num_train_steps=tx_args["num_train_steps"], ckpt_dir=ckpt_dir)
    try:
        trainer.train(state, tr_dataset, val_dataset)
    except KeyboardInterrupt:
        print("Oops; training was interrupted")
    print("SAVING WEIGHTS IN `final-weights`")
    params = jax_utils.unreplicate(state.params)
    model.save_pretrained(os.path.join(args.base_dir, "final-weights"), params=params)