eubinecto committed · Commit d8d4c8d · 1 Parent(s): 25e310b

[#1] Alpha implemented
config.yaml CHANGED
@@ -1,46 +1,8 @@
 alpha:
-  eng2eng:
-    bert: bert-base-uncased
-    desc:
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: c
-    k: 11
-    lr: 0.00001
-    max_epochs: 10
-    batch_size: 64
-    shuffle: true
-  kor2eng:
-    bert: bert-base-multilingual-uncased
-    desc:
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: d
-    k: 11
-    lr: 0.00001
-    max_epochs: 20
-    batch_size: 64
-    num_workers: 4
-    shuffle: true
-gamma:
-  eng2eng:
-    bert: bert-base-uncased
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: c
-    k: 11
-    lr: 0.00001
-    max_epochs: 50
-    batch_size: 64
-    shuffle: true
-  kor2eng:
-    bert: bert-base-multilingual-uncased
-    seed: 410
-    idioms_ver: c
-    idiom2def_ver: d
-    k: 11
-    lr: 0.00001
-    max_epochs: 50
-    batch_size: 64
-    num_workers: 4
+  overfit:
+    bart: facebook/bart-base
+    lr: 0.0001
+    literal2idiomatic_ver: pie_v0
+    max_epochs: 100
+    batch_size: 100
     shuffle: true
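
For reference, main_train.py below reads this file via fetch_config and indexes the result by model name and --ver; a minimal sketch of how the new overfit entry resolves, assuming fetch_config amounts to a yaml.safe_load of config.yaml (the loader itself is not part of this commit):

    import yaml

    with open("config.yaml") as fh:
        raw = yaml.safe_load(fh)
    config = raw["alpha"]["overfit"]
    print(config["bart"], config["max_epochs"])  # facebook/bart-base 100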
explore/explore_bart_logits_shape.py ADDED
@@ -0,0 +1,39 @@
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+from idiomify.data import IdiomifyDataModule
+
+
+CONFIG = {
+    "literal2idiomatic_ver": "pie_v0",
+    "batch_size": 20,
+    "num_workers": 4,
+    "shuffle": True
+}
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
+    datamodule.prepare_data()
+    datamodule.setup()
+    for batch in datamodule.train_dataloader():
+        srcs, tgts_r, tgts = batch
+        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]  # noqa
+        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]
+        outputs = bart(input_ids=input_ids,
+                       attention_mask=attention_mask,
+                       decoder_input_ids=decoder_input_ids,
+                       decoder_attention_mask=decoder_attention_mask)
+        logits = outputs[0]
+        print(logits.shape)
+        """
+        torch.Size([20, 47, 50265])
+        (N, L, |V|)
+        """
+
+        break
+
+
+if __name__ == '__main__':
+    main()
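
The printed shape spells out the three axes: N=20 is the batch_size from CONFIG, L=47 is the padded target length of this particular batch, and |V|=50265 is the size of BART's subword vocabulary. To eyeball what those logits decode to, one could add a greedy readout inside the same loop:

    pred_ids = logits.argmax(dim=-1)  # greedy readout: (N, L, |V|) -> (N, L)
    print(tokenizer.batch_decode(pred_ids, skip_special_tokens=True))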
explore/explore_idiomifydatamodule.py CHANGED
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.datamodules import IdiomifyDataModule
+from idiomify.data import IdiomifyDataModule


 CONFIG = {
idiomify/builders.py CHANGED
@@ -81,8 +81,7 @@ class TargetsBuilder(TensorBuilder):
             idiomatic + self.tokenizer.eos_token  # no bos, but ends with eos
             for _, idiomatic in literal2idiomatic
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
-        tgts = torch.stack([encodings['input_ids'],
-                            encodings['attention_mask']], dim=1)  # (N, 2, L)
-        return tgts
+        tgts = encodings['input_ids']
+        return tgts  # (N, L)

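Dropping the attention mask from tgts is safe because padding is now masked in the loss instead: models.py below passes ignore_index=pad_token_id to cross_entropy. A minimal demonstration, assuming BART's usual pad id of 1:

    import torch
    from torch.nn import functional as F

    PAD = 1  # BART's pad_token_id
    logits = torch.randn(1, 50265, 4)         # (N, |V|, L)
    tgts = torch.tensor([[42, 7, PAD, PAD]])  # (N, L), with trailing pads
    loss = F.cross_entropy(logits, tgts, ignore_index=PAD)  # pad positions contribute nothing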
idiomify/{datamodules.py → data.py} RENAMED
@@ -2,6 +2,8 @@ import torch
 from typing import Tuple, Optional, List
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
+from wandb.sdk.wandb_run import Run
+
 from idiomify.fetchers import fetch_literal2idiomatic
 from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
 from transformers import BartTokenizer
@@ -12,9 +14,9 @@ class IdiomifyDataset(Dataset):
                  srcs: torch.Tensor,
                  tgts_r: torch.Tensor,
                  tgts: torch.Tensor):
-        self.srcs = srcs
-        self.tgts_r = tgts_r
-        self.tgts = tgts
+        self.srcs = srcs  # (N, 2, L)
+        self.tgts_r = tgts_r  # (N, 2, L)
+        self.tgts = tgts  # (N, L)

     def __len__(self) -> int:
         """
@@ -47,10 +49,12 @@ class IdiomifyDataModule(LightningDataModule):

     def __init__(self,
                  config: dict,
-                 tokenizer: BartTokenizer):
+                 tokenizer: BartTokenizer,
+                 run: Run = None):
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
+        self.run = run
         # --- to be downloaded & built --- #
         self.literal2idiomatic: Optional[List[Tuple[str, str]]] = None
         self.dataset: Optional[IdiomifyDataset] = None
@@ -59,12 +63,9 @@
         """
        prepare: download all data needed for this from wandb to local.
         """
-        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'])
+        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'], self.run)

     def setup(self, stage: Optional[str] = None):
-        """
-        setup the builders.
-        """
         # --- set up the builders --- #
         # build the datasets
         srcs = SourcesBuilder(self.tokenizer)(self.literal2idiomatic)
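
Since run defaults to None, existing call sites such as the explore scripts keep working unchanged; passing a live run, as main_train.py does below, makes wandb record the dataset lineage. A sketch of the intended call pattern:

    with wandb.init(entity="eubinecto", project="idiomify", config=CONFIG) as run:
        datamodule = IdiomifyDataModule(CONFIG, tokenizer, run)
        datamodule.prepare_data()  # fetch_literal2idiomatic(ver, run) is tracked against the run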
idiomify/fetchers.py CHANGED
@@ -62,7 +62,7 @@ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
-        artifact = run.use_artifact("idioms", type="dataset", aliases=ver)
+        artifact = run.use_artifact(f"idioms:{ver}", type="dataset")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/idioms:{ver}", type="dataset")
     artifact_dir = artifact.download(root=idioms_dir(ver))
@@ -75,7 +75,7 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
-        artifact = run.use_artifact("literal2idiom", type="dataset", aliases=ver)
+        artifact = run.use_artifact(f"literal2idiomatic:{ver}", type="dataset")
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
     artifact_dir = artifact.download(root=literal2idiomatic(ver))
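
Both branches now resolve to the same artifact (the old run branch also referred to it as literal2idiom rather than literal2idiomatic); spelled out with the pie_v0 version from config.yaml:

    # inside a run, with lineage tracking:
    artifact = run.use_artifact("literal2idiomatic:pie_v0", type="dataset")
    # outside a run, via the public Api:
    artifact = wandb.Api().artifact("eubinecto/idiomify/literal2idiomatic:pie_v0", type="dataset")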
idiomify/metrics.py ADDED
@@ -0,0 +1,4 @@
+"""
+you may want to include the BLEU score here,
+along with more metrics for paraphrasing.
+"""
idiomify/models.py CHANGED
@@ -1,56 +1,56 @@
 """
 The reverse dictionary models below are based off of: https://github.com/yhcc/BertForRD/blob/master/mono/model/bert.py
 """
-from typing import Tuple, List, Optional
+from typing import Tuple
 import torch
 from torch.nn import functional as F
 import pytorch_lightning as pl
-from transformers import BertForMaskedLM
+from transformers import BartForConditionalGeneration


-class Idiomifier(pl.LightningModule):
+class Alpha(pl.LightningModule):  # noqa
     """
-    @eubinecto
-    The superclass of all the reverse-dictionaries. This class houses any methods that are required by
-    whatever reverse-dictionaries we define.
+    the baseline.
     """
-    # passing them to avoid warnings --- #
-    def train_dataloader(self):
-        pass
+    def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int):  # noqa
+        super().__init__()
+        self.bart = bart
+        self.save_hyperparameters(ignore=["bart"])

-    def test_dataloader(self):
-        pass
-
-    def val_dataloader(self):
-        pass
-
-    def predict_dataloader(self):
-        pass
-
-    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):  # noqa
-        """
-        :param mlm: a bert model for masked language modeling
-        :param idiom2subwords: (|W|, K)
-        :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
+    def forward(self, srcs: torch.Tensor, tgts_r: torch.Tensor) -> torch.Tensor:
         """
-        pass
-
-    def forward(self, X: torch.Tensor) -> torch.Tensor:
-        """
-        given a batch, forward returns a batch of hidden vectors
-        :param X: (N, 3, L). input_ids, token_type_ids, and what was the last one...?
-        :return: (N, L, H)
+        as for using bart for CG, refer to:
+        https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartForConditionalGeneration.forward
+        :param srcs: (N, 2, L_s)
+        :param tgts_r: (N, 2, L_t)
+        :return: (N, L, |V|)
         """
-        pass
-
-    def step(self):
-        pass
-
-    def predict(self):
-        pass
-
-    def training_step(self):
-        pass
+        input_ids, attention_mask = srcs[:, 0], srcs[:, 1]
+        decoder_input_ids, decoder_attention_mask = tgts_r[:, 0], tgts_r[:, 1]
+        outputs = self.bart(input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            decoder_input_ids=decoder_input_ids,
+                            decoder_attention_mask=decoder_attention_mask)
+        logits = outputs[0]  # (N, L, |V|)
+        return logits
+
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
+        srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, L_t)
+        logits = self.forward(srcs, tgts_r)  # -> (N, L, |V|)
+        logits = logits.transpose(1, 2)  # (N, L, |V|) -> (N, |V|, L)
+        loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])
+        # cross_entropy reduces to the mean over non-pad positions: (N, |V|, L), (N, L_t) -> (1,)
+        return {
+            "loss": loss
+        }
+
+    def predict(self, srcs: torch.Tensor) -> torch.Tensor:
+        pred_ids = self.bart.generate(
+            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
+            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
+            decoder_start_token_id=self.hparams['bos_token_id'],
+        )
+        return pred_ids  # (N, L)

     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
@@ -59,21 +59,3 @@ class Idiomifier(pl.LightningModule):
         """
         # The authors used Adam, so we might as well use it as well.
         return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
-
-    @classmethod
-    def name(cls) -> str:
-        return cls.__name__.lower()
-
-
-class Alpha(Idiomifier):
-    """
-    @eubinecto
-    The first prototype.
-    S_wisdom = S_wisdom_literal
-    trained on: wisdom2def only.
-    """
-
-    def S_wisdom(self, H_all: torch.Tensor) -> torch.Tensor:
-        H_k = self.H_k(H_all)  # (N, L, H) -> (N, K, H)
-        S_wisdom = self.S_wisdom_literal(H_k)  # (N, K, H) -> (N, |W|)
-        return S_wisdom
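A usage sketch for Alpha.predict, packing input_ids and attention_mask along dim=1 just as SourcesBuilder does; the example sentence and the manual torch.stack are illustrative stand-ins for the builder:

    encodings = tokenizer(["The exam was a very easy one."], return_tensors="pt", padding=True)
    srcs = torch.stack([encodings['input_ids'], encodings['attention_mask']], dim=1)  # (1, 2, L)
    pred_ids = model.predict(srcs)  # (1, L)
    print(tokenizer.batch_decode(pred_ids, skip_special_tokens=True))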
main_train.py CHANGED
@@ -3,20 +3,19 @@ import torch.cuda
 import wandb
 import argparse
 import pytorch_lightning as pl
-from pytorch_lightning.loggers import WandbLogger
 from termcolor import colored
-from transformers import BertForMaskedLM, BertTokenizer
-from idiomify.datamodules import IdiomifyDataModule
-from idiomify.fetchers import fetch_config, fetch_idioms
-from idiomify.models import Alpha, Gamma
+from pytorch_lightning.loggers import WandbLogger
+from transformers import BartTokenizer, BartForConditionalGeneration
+from idiomify.data import IdiomifyDataModule
+from idiomify.fetchers import fetch_config
+from idiomify.models import Alpha
 from idiomify.paths import ROOT_DIR
-from idiomify import tensors as T


 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="alpha")
-    parser.add_argument("--ver", type=str, default="eng2eng")
+    parser.add_argument("--ver", type=str, default="overfit")
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
     parser.add_argument("--log_every_n_steps", type=int, default=1)
     parser.add_argument("--fast_dev_run", action="store_true", default=False)
@@ -27,22 +26,17 @@ def main():
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))

-    # prepare arguments
-    mlm = BertForMaskedLM.from_pretrained(config['bert'])
-    tokenizer = BertTokenizer.from_pretrained(config['bert'])
-    idioms = fetch_idioms(config['idioms_ver'])
-    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])
-    # choose the model to train
-    if config['model'] == Alpha.name():
-        rd = Alpha(mlm, idiom2subwords, config['k'], config['lr'])
-    elif config['model'] == Gamma.name():
-        rd = Gamma(mlm, idiom2subwords, config['k'], config['lr'])
+    # prepare the model
+    bart = BartForConditionalGeneration.from_pretrained(config['bart'])
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    if config['model'] == "alpha":
+        model = Alpha(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
     else:
-        raise ValueError
-    # prepare datamodule
-    datamodule = IdiomifyDataModule(config, tokenizer, idioms)
+        raise NotImplementedError
+    # prepare the datamodule

-    with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        datamodule = IdiomifyDataModule(config, tokenizer, run)
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(max_epochs=config['max_epochs'],
                              fast_dev_run=config['fast_dev_run'],
@@ -52,10 +46,10 @@ def main():
                              enable_checkpointing=False,
                              logger=logger)
         # start training
-        trainer.fit(model=rd, datamodule=datamodule)
+        trainer.fit(model=model, datamodule=datamodule)
         # upload the model to wandb only if the training is properly done #
         if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
-            ckpt_path = ROOT_DIR / "rd.ckpt"
+            ckpt_path = ROOT_DIR / "model.ckpt"
             trainer.save_checkpoint(str(ckpt_path))
             artifact = wandb.Artifact(name=config['model'], type="model", metadata=config)
             artifact.add_file(str(ckpt_path))
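
With the above in place, the overfit run would be launched roughly as follows (the exact set of flags depends on fetch_config; --fast_dev_run gives a quick one-batch smoke test before committing to the full 100 epochs):

    python3 main_train.py --model alpha --ver overfit --fast_dev_run
    python3 main_train.py --model alpha --ver overfit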