[#1] Alpha implemented
- config.yaml +6 -44
- explore/explore_bart_logits_shape.py +39 -0
- explore/explore_idiomifydatamodule.py +1 -1
- idiomify/builders.py +2 -3
- idiomify/{datamodules.py → data.py} +9 -8
- idiomify/fetchers.py +2 -2
- idiomify/metrics.py +4 -0
- idiomify/models.py +40 -58
- main_train.py +17 -23
config.yaml
CHANGED
@@ -1,46 +1,8 @@
 alpha:
-  eng2eng:
-    bert: bert-base-uncased
-    desc:
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: c
-    k: 11
-    lr: 0.00001
-    max_epochs: 10
-    batch_size: 64
-    shuffle: true
-  kor2eng:
-    bert: bert-base-multilingual-uncased
-    desc:
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: d
-    k: 11
-    lr: 0.00001
-    max_epochs: 20
-    batch_size: 64
-    num_workers: 4
-    shuffle: true
-gamma:
-  eng2eng:
-    bert: bert-base-uncased
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: c
-    k: 11
-    lr: 0.00001
-    max_epochs: 50
-    batch_size: 64
-    shuffle: true
-  kor2eng:
-    bert: bert-base-multilingual-uncased
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: d
-    k: 11
-    lr: 0.00001
-    max_epochs: 50
-    batch_size: 64
-    num_workers: 4
+  overfit:
+    bart: facebook/bart-base
+    lr: 0.0001
+    literal2idiomatic_ver: pie_v0
+    max_epochs: 100
+    batch_size: 100
     shuffle: true
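The new alpha.overfit block is what fetch_config (imported in main_train.py below) is expected to surface. A minimal sketch of such a loader, assuming a top-level config.yaml and the model/ver nesting shown above; the committed fetch_config in idiomify/fetchers.py may read the file differently:

# hypothetical loader in the spirit of fetch_config; the file path and
# the model/ver nesting are assumptions, not the committed implementation
import yaml

def load_config(model: str, ver: str) -> dict:
    with open("config.yaml") as fh:
        configs = yaml.safe_load(fh)
    # e.g. model="alpha", ver="overfit"
    # -> {"bart": "facebook/bart-base", "lr": 0.0001, ...}
    return configs[model][ver]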
explore/explore_bart_logits_shape.py
ADDED
@@ -0,0 +1,39 @@
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+from data import IdiomifyDataModule
+
+
+CONFIG = {
+    "literal2idiomatic_ver": "pie_v0",
+    "batch_size": 20,
+    "num_workers": 4,
+    "shuffle": True
+}
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
+    datamodule.prepare_data()
+    datamodule.setup()
+    for batch in datamodule.train_dataloader():
+        srcs, tgts_r, tgts = batch
+        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]  # noqa
+        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]
+        outputs = bart(input_ids=input_ids,
+                       attention_mask=attention_mask,
+                       decoder_input_ids=decoder_input_ids,
+                       decoder_attention_mask=decoder_attention_mask)
+        logits = outputs[0]
+        print(logits.shape)
+        """
+        torch.Size([20, 47, 50265])
+        (N, L, |V|)
+        """
+
+        break
+
+
+if __name__ == '__main__':
+    main()
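The (N, L, |V|) shape this script verifies is exactly what training_step in idiomify/models.py consumes: F.cross_entropy wants the class axis second, hence the transpose there. A self-contained sketch of that step with dummy tensors (pad id 1 is BART's):

import torch
from torch.nn import functional as F

N, L, V = 20, 47, 50265             # batch size, target length, vocab size
logits = torch.randn(N, L, V)       # what bart(...) returns as outputs[0]
tgts = torch.randint(0, V, (N, L))  # gold token ids

# cross_entropy expects (N, C, ...) with the class axis second,
# so (N, L, |V|) is transposed to (N, |V|, L) before the call
loss = F.cross_entropy(logits.transpose(1, 2), tgts, ignore_index=1)  # 1 = BART's pad id
print(loss)  # a scalar under the default "mean" reduction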
explore/explore_idiomifydatamodule.py
CHANGED
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.datamodules import IdiomifyDataModule
+from idiomify.data import IdiomifyDataModule
 
 
 CONFIG = {
idiomify/builders.py
CHANGED
@@ -81,8 +81,7 @@ class TargetsBuilder(TensorBuilder):
             idiomatic + self.tokenizer.eos_token  # no bos, but ends with eos
             for _, idiomatic in literal2idiomatic
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
-        tgts =
-
-        return tgts
+        tgts = encodings['input_ids']
+        return tgts  # (N, L)
 
 
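For reference, a standalone sketch of what the fixed TargetsBuilder body now returns; the literal/idiomatic pair here is made up for illustration:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
literal2idiomatic = [("I knew you would betray me.", "I knew you would stab me in the back.")]
encodings = tokenizer([idiomatic + tokenizer.eos_token  # no bos, but ends with eos
                       for _, idiomatic in literal2idiomatic],
                      return_tensors="pt", add_special_tokens=False,
                      padding=True, truncation=True)
tgts = encodings["input_ids"]
print(tgts.shape)  # (N, L); the last non-pad token is </s>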
idiomify/{datamodules.py → data.py}
RENAMED
@@ -2,6 +2,8 @@ import torch
 from typing import Tuple, Optional, List
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
+from wandb.sdk.wandb_run import Run
+
 from idiomify.fetchers import fetch_literal2idiomatic
 from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
 from transformers import BartTokenizer
@@ -12,9 +14,9 @@ class IdiomifyDataset(Dataset):
                  srcs: torch.Tensor,
                  tgts_r: torch.Tensor,
                  tgts: torch.Tensor):
-        self.srcs = srcs
-        self.tgts_r = tgts_r
-        self.tgts = tgts
+        self.srcs = srcs  # (N, 2, L)
+        self.tgts_r = tgts_r  # (N, 2, L)
+        self.tgts = tgts  # (N, L)
 
     def __len__(self) -> int:
         """
@@ -47,10 +49,12 @@ class IdiomifyDataModule(LightningDataModule):
 
     def __init__(self,
                  config: dict,
-                 tokenizer: BartTokenizer):
+                 tokenizer: BartTokenizer,
+                 run: Run = None):
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
+        self.run = run
         # --- to be downloaded & built --- #
         self.literal2idiomatic: Optional[List[Tuple[str, str]]] = None
         self.dataset: Optional[IdiomifyDataset] = None
@@ -59,12 +63,9 @@ class IdiomifyDataModule(LightningDataModule):
         """
         prepare: download all data needed for this from wandb to local.
        """
-        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'])
+        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'], self.run)
 
     def setup(self, stage: Optional[str] = None):
-        """
-        setup the builders.
-        """
         # --- set up the builders --- #
         # build the datasets
         srcs = SourcesBuilder(self.tokenizer)(self.literal2idiomatic)
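The new shape comments make the teacher-forcing layout explicit: tgts_r is the decoder input and tgts the label, offset by one step. A toy illustration using BART's special ids (bos=0, pad=1, eos=2) and made-up word ids; the real tensors come from TargetsRightShiftedBuilder and TargetsBuilder, and only the input_ids channel is shown here (the real tgts_r also stacks an attention mask, hence (N, 2, L)):

import torch

# decoder input: starts with <s>, no </s>  (an inferred reading of "right-shifted")
tgts_r_ids = torch.tensor([[0, 11, 12, 13]])
# labels: the same tokens one step ahead, ends with </s>, no <s>
tgts = torch.tensor([[11, 12, 13, 2]])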
idiomify/fetchers.py
CHANGED
@@ -62,7 +62,7 @@ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
-        artifact = run.use_artifact("idioms", type="dataset")
+        artifact = run.use_artifact(f"idioms:{ver}", type="dataset")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/idioms:{ver}", type="dataset")
     artifact_dir = artifact.download(root=idioms_dir(ver))
@@ -75,7 +75,7 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
-        artifact = run.use_artifact("literal2idiomatic", type="dataset")
+        artifact = run.use_artifact(f"literal2idiomatic:{ver}", type="dataset")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
     artifact_dir = artifact.download(root=literal2idiomatic(ver))
idiomify/metrics.py
ADDED
@@ -0,0 +1,4 @@
+"""
+you may want to include bleu score.
+and more metrics for paraphrasing.
+"""
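The stub only gestures at BLEU; the commit does not pin an implementation. One possible shape for it, assuming NLTK (an assumption, not the committed choice):

# a sketch only -- nltk is an assumption here
from nltk.translate.bleu_score import sentence_bleu

ref = "He let the cat out of the bag .".split()
hyp = "He let the cat out of the bag .".split()
print(sentence_bleu([ref], hyp))  # 1.0 for an exact match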
idiomify/models.py
CHANGED
@@ -1,56 +1,56 @@
 """
 The reverse dictionary models below are based off of: https://github.com/yhcc/BertForRD/blob/master/mono/model/bert.py
 """
 from typing import Tuple
 import torch
 from torch.nn import functional as F
 import pytorch_lightning as pl
-from transformers import BertForMaskedLM
+from transformers import BartForConditionalGeneration
 
 
-class Idiomifier(pl.LightningModule):
+class Alpha(pl.LightningModule):  # noqa
     """
-    The superclass of all the reverse-dictionaries. This class houses any methods that are required by
-    whatever reverse-dictionaries we define.
+    the baseline.
     """
-
-    def train_dataloader(self):
-        pass
-
-    def val_dataloader(self):
-        pass
-
-    def predict_dataloader(self):
-        pass
-
-    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):  # noqa
-        """
-        :param mlm: a bert model for masked language modeling
-        :param idiom2subwords: (|W|, K)
-        :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
-        """
-
-    def forward(self, X: torch.Tensor) -> torch.Tensor:
-        """
-        :param X: (N, 3, L). input_ids, token_type_ids, and what was the last one...?
-        :return: (N, L, H)
-        """
+    def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int):  # noqa
+        super().__init__()
+        self.bart = bart
+        self.save_hyperparameters(ignore=["bart"])
 
+    def forward(self, srcs: torch.Tensor, tgts_r: torch.Tensor) -> torch.Tensor:
+        """
+        as for using bart for CG, refer to:
+        https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartForQuestionAnswering.forward
+        param srcs: (N, 2, L_s)
+        param tgts_r: (N, 2, L_t)
+        return: (N, L, |V|)
+        """
+        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]
+        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]
+        outputs = self.bart(input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            decoder_input_ids=decoder_input_ids,
+                            decoder_attention_mask=decoder_attention_mask)
+        logits = outputs[0]  # (N, L, |V|)
+        return logits
+
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
+        srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, L_t)
+        logits = self.forward(srcs, tgts_r)  # -> (N, L, |V|)
+        logits = logits.transpose(1, 2)  # (N, L, |V|) -> (N, |V|, L)
+        loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
+            .sum()  # (N, L, |V|), (N, L) -> (N,) -> (1,)
+        return {
+            "loss": loss
+        }
+
+    def predict(self, srcs: torch.Tensor) -> torch.Tensor:
+        pred_ids = self.bart.generate(
+            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
+            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
+            decoder_start_token_id=self.hparams['bos_token_id'],
+        )
+        return pred_ids  # (N, L)
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
@@ -59,21 +59,3 @@ class Idiomifier(pl.LightningModule):
         """
         # The authors used Adam, so we might as well use it as well.
         return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
-
-    @classmethod
-    def name(cls) -> str:
-        return cls.__name__.lower()
-
-
-class Alpha(Idiomifier):
-    """
-    @eubinecto
-    The first prototype.
-    S_wisdom = S_wisdom_literal
-    trained on: wisdom2def only.
-    """
-
-    def S_wisdom(self, H_all: torch.Tensor) -> torch.Tensor:
-        H_k = self.H_k(H_all)  # (N, L, H) -> (N, K, H)
-        S_wisdom = self.S_wisdom_literal(H_k)  # (N, K, H) -> (N, |W|)
-        return S_wisdom
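A usage sketch for the new Alpha.predict. The srcs layout ((N, 2, L): input_ids stacked with attention_mask, as SourcesBuilder produces) is taken from data.py; the decoding step at the end is an illustration, not part of the commit:

import torch
from transformers import BartTokenizer, BartForConditionalGeneration
from idiomify.models import Alpha

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model = Alpha(bart, lr=0.0001,
              bos_token_id=tokenizer.bos_token_id,
              pad_token_id=tokenizer.pad_token_id)

enc = tokenizer(["I knew you would betray me."], return_tensors="pt")
srcs = torch.stack([enc["input_ids"], enc["attention_mask"]], dim=1)  # (N, 2, L)
pred_ids = model.predict(srcs)  # (N, L)
print(tokenizer.batch_decode(pred_ids, skip_special_tokens=True))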
main_train.py
CHANGED
@@ -3,20 +3,19 @@ import torch.cuda
 import wandb
 import argparse
 import pytorch_lightning as pl
-from pytorch_lightning.loggers import WandbLogger
 from termcolor import colored
-from
-from
-from idiomify.
-from idiomify.
+from pytorch_lightning.loggers import WandbLogger
+from transformers import BartTokenizer, BartForConditionalGeneration
+from idiomify.data import IdiomifyDataModule
+from idiomify.fetchers import fetch_config
+from idiomify.models import Alpha
 from idiomify.paths import ROOT_DIR
-from idiomify import tensors as T
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="alpha")
-    parser.add_argument("--ver", type=str, default="
+    parser.add_argument("--ver", type=str, default="overfit")
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
     parser.add_argument("--log_every_n_steps", type=int, default=1)
     parser.add_argument("--fast_dev_run", action="store_true", default=False)
@@ -27,22 +26,17 @@ def main():
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
 
-    # prepare
-
-    tokenizer =
-
-
-    # choose the model to train
-    if config['model'] == Alpha.name():
-        rd = Alpha(mlm, idiom2subwords, config['k'], config['lr'])
-    elif config['model'] == Gamma.name():
-        rd = Gamma(mlm, idiom2subwords, config['k'], config['lr'])
+    # prepare the model
+    bart = BartForConditionalGeneration.from_pretrained(config['bart'])
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    if config['model'] == "alpha":
+        model = Alpha(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
     else:
-        raise
-    # prepare datamodule
-    datamodule = IdiomifyDataModule(config, tokenizer, idioms)
+        raise NotImplementedError
+    # prepare the datamodule
 
-    with wandb.init(entity="eubinecto", project="idiomify
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        datamodule = IdiomifyDataModule(config, tokenizer, run)
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(max_epochs=config['max_epochs'],
                              fast_dev_run=config['fast_dev_run'],
@@ -52,10 +46,10 @@ def main():
                              enable_checkpointing=False,
                              logger=logger)
         # start training
-        trainer.fit(model=
+        trainer.fit(model=model, datamodule=datamodule)
         # upload the model to wandb only if the training is properly done #
         if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
-            ckpt_path = ROOT_DIR / "
+            ckpt_path = ROOT_DIR / "model.ckpt"
             trainer.save_checkpoint(str(ckpt_path))
             artifact = wandb.Artifact(name=config['model'], type="model", metadata=config)
             artifact.add_file(str(ckpt_path))