|
import os |
|
import torch.cuda |
|
import wandb |
|
import argparse |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from termcolor import colored |
|
from transformers import BertForMaskedLM, BertTokenizer |
|
from idiomify.datamodules import Idiom2DefDataModule |
|
from idiomify.fetchers import fetch_config, fetch_idioms |
|
from idiomify.models import Alpha, Gamma |
|
from idiomify.paths import ROOT_DIR |
|
from idiomify import tensors as T |
|
|
|
|
|
def main():
    """Train an idiomify reverse-dictionary model (Alpha or Gamma) and,
    when requested, upload the final checkpoint to Weights & Biases.

    CLI flags:
        --model: which model class to train ("alpha" or "gamma").
        --ver: config version key (e.g. "eng2eng").
        --num_workers: dataloader workers (defaults to cpu count).
        --log_every_n_steps: logging frequency for the Trainer.
        --fast_dev_run: run a single sanity batch instead of full training.
        --upload: if set, save the trained checkpoint as a W&B artifact.

    Raises:
        ValueError: if --model does not match a known model name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="alpha")
    parser.add_argument("--ver", type=str, default="eng2eng")
    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
    parser.add_argument("--log_every_n_steps", type=int, default=1)
    parser.add_argument("--fast_dev_run", action="store_true", default=False)
    parser.add_argument("--upload", dest='upload', action='store_true', default=False)
    args = parser.parse_args()
    # CLI flags override whatever the fetched config specifies
    config = fetch_config()[args.model][args.ver]
    config.update(vars(args))
    if not config['upload']:
        print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))

    # build the backbone, tokenizer and the idiom -> subword-ids tensor
    mlm = BertForMaskedLM.from_pretrained(config['bert'])
    tokenizer = BertTokenizer.from_pretrained(config['bert'])
    idioms = fetch_idioms(config['idioms_ver'])
    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])

    if config['model'] == Alpha.name():
        rd = Alpha(mlm, idiom2subwords, config['k'], config['lr'])
    elif config['model'] == Gamma.name():
        rd = Gamma(mlm, idiom2subwords, config['k'], config['lr'])
    else:
        # fail loudly with the offending value rather than a bare ValueError
        raise ValueError(f"Invalid model: {config['model']}")

    datamodule = Idiom2DefDataModule(config, tokenizer, idioms)

    with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
        # checkpoints are uploaded manually below, so the logger must not auto-log them
        logger = WandbLogger(log_model=False)
        trainer = pl.Trainer(max_epochs=config['max_epochs'],
                             fast_dev_run=config['fast_dev_run'],
                             log_every_n_steps=config['log_every_n_steps'],
                             gpus=torch.cuda.device_count(),
                             default_root_dir=str(ROOT_DIR),
                             enable_checkpointing=False,
                             logger=logger)
        trainer.fit(model=rd, datamodule=datamodule)

        # Upload only when the user asked for it AND training actually completed.
        # The original code ignored config['upload'] here, uploading even after
        # the "YOU CHOSE NOT TO UPLOAD" warning was printed.
        # NOTE(review): the epoch check assumes current_epoch == max_epochs - 1
        # after a full fit; this depends on the pytorch-lightning version — confirm.
        if config['upload'] and not config['fast_dev_run'] \
                and trainer.current_epoch == config['max_epochs'] - 1:
            ckpt_path = ROOT_DIR / "rd.ckpt"
            trainer.save_checkpoint(str(ckpt_path))
            try:
                artifact = wandb.Artifact(name=config['model'], type="model", metadata=config)
                artifact.add_file(str(ckpt_path))
                run.log_artifact(artifact, aliases=["latest", config['ver']])
            finally:
                # the local checkpoint exists only to be uploaded; never leave it behind
                os.remove(str(ckpt_path))
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|