eubinecto commited on
Commit
642d911
1 Parent(s): f49863b

[#2] evaluating m-1-2 works. config.yaml simplified.

Browse files
config.yaml CHANGED
@@ -1,20 +1,19 @@
1
- train:
2
  ver: m-1-2
3
  desc: just overfitting the model, but on the entire PIE dataset.
4
  bart: facebook/bart-base
5
  lr: 0.0001
6
  literal2idiomatic_ver: d-1-2
7
- max_epochs: 100
8
- batch_size: 100
9
  shuffle: true
10
 
11
- # for building & uploading datasets or others
12
- upload:
13
- idioms:
14
- ver: d-1-2
15
- description: the set of idioms in the traning set of literal2idiomatic_d-1-2.
16
- literal2idiomatic:
17
- ver: d-1-2
18
- description: PIE data split into train & test set (80 / 20 split). There is no validation set because I don't intend to do any hyperparameter tuning on this thing.
19
- train_ratio: 0.8
20
- seed: 104
 
1
+ idiomifier:
2
  ver: m-1-2
3
  desc: just overfitting the model, but on the entire PIE dataset.
4
  bart: facebook/bart-base
5
  lr: 0.0001
6
  literal2idiomatic_ver: d-1-2
7
+ max_epochs: 2
8
+ batch_size: 40
9
  shuffle: true
10
 
11
+ # for building & uploading datasets or tokenizer
12
+ idioms:
13
+ ver: d-1-2
14
+ description: the set of idioms in the traning set of literal2idiomatic_d-1-2.
15
+ literal2idiomatic:
16
+ ver: d-1-2
17
+ description: PIE data split into train & test set (80 / 20 split). There is no validation set because I don't intend to do any hyperparameter tuning on this thing.
18
+ train_ratio: 0.8
19
+ seed: 104
 
explore/explore_bart_logits_shape.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import BartTokenizer, BartForConditionalGeneration
2
 
3
- from data import IdiomifyDataModule
4
 
5
 
6
  CONFIG = {
 
1
  from transformers import BartTokenizer, BartForConditionalGeneration
2
 
3
+ from datamodules import IdiomifyDataModule
4
 
5
 
6
  CONFIG = {
explore/explore_idiomifydatamodule.py CHANGED
@@ -1,5 +1,5 @@
1
  from transformers import BartTokenizer
2
- from idiomify.data import IdiomifyDataModule
3
 
4
 
5
  CONFIG = {
 
1
  from transformers import BartTokenizer
2
+ from idiomify.datamodules import IdiomifyDataModule
3
 
4
 
5
  CONFIG = {
idiomify/{data.py → datamodules.py} RENAMED
@@ -84,6 +84,6 @@ class IdiomifyDataModule(LightningDataModule):
84
  return DataLoader(self.train_dataset, batch_size=self.config['batch_size'],
85
  shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
86
 
87
- def test_dataloader(self):
88
  return DataLoader(self.test_dataset, batch_size=self.config['batch_size'],
89
  shuffle=False, num_workers=self.config['num_workers'])
 
84
  return DataLoader(self.train_dataset, batch_size=self.config['batch_size'],
85
  shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
86
 
87
+ def test_dataloader(self) -> DataLoader:
88
  return DataLoader(self.test_dataset, batch_size=self.config['batch_size'],
89
  shuffle=False, num_workers=self.config['num_workers'])
idiomify/fetchers.py CHANGED
@@ -53,9 +53,9 @@ def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
53
  The current Idiomifier then turns into a pipeline.
54
  """
55
  if run:
56
- artifact = run.use_artifact(f"seq2seq:{ver}", type="model")
57
  else:
58
- artifact = wandb.Api().artifact(f"eubinecto/idiomify/seq2seq:{ver}", type="model")
59
  config = artifact.metadata
60
  artifact_dir = artifact.download(root=seq2seq_dir(ver))
61
  ckpt_path = path.join(artifact_dir, "model.ckpt")
 
53
  The current Idiomifier then turns into a pipeline.
54
  """
55
  if run:
56
+ artifact = run.use_artifact(f"idiomifier:{ver}", type="model")
57
  else:
58
+ artifact = wandb.Api().artifact(f"eubinecto/idiomify/idiomifier:{ver}", type="model")
59
  config = artifact.metadata
60
  artifact_dir = artifact.download(root=seq2seq_dir(ver))
61
  ckpt_path = path.join(artifact_dir, "model.ckpt")
idiomify/models.py CHANGED
@@ -48,19 +48,19 @@ class Idiomifier(pl.LightningModule): # noqa
48
  "loss": loss
49
  }
50
 
51
- def on_train_batch_end(self, outputs: dict, **kwargs):
52
  self.log("Train/Loss", outputs['loss'])
53
 
54
  def on_train_epoch_end(self, *args, **kwargs) -> None:
55
  self.log("Train/Accuracy", self.acc_train.compute())
56
  self.acc_train.reset()
57
 
58
- def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], **kwargs):
59
  srcs, tgts_r, tgts = batch # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
60
  logits = self.forward(srcs, tgts_r).transpose(1, 2) # ... -> (N, L, |V|) -> (N, |V|, L)
61
  self.acc_test.update(logits.detach(), target=tgts.detach())
62
 
63
- def on_test_end(self):
64
  self.log("Test/Accuracy", self.acc_test.compute())
65
  self.acc_test.reset()
66
 
@@ -72,21 +72,3 @@ class Idiomifier(pl.LightningModule): # noqa
72
  # The authors used Adam, so we might as well use it as well.
73
  return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
74
 
75
-
76
- # for inference
77
- class Pipeline:
78
-
79
- def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
80
- self.model = model
81
- self.builder = SourcesBuilder(tokenizer)
82
-
83
- def __call__(self, src: str, max_length=100) -> str:
84
- srcs = self.builder(literal2idiomatic=[(src, "")])
85
- pred_ids = self.model.bart.generate(
86
- inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
87
- attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
88
- decoder_start_token_id=self.model.hparams['bos_token_id'],
89
- max_length=max_length,
90
- ).squeeze() # -> (N, L_t) -> (L_t)
91
- tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
92
- return tgt
 
48
  "loss": loss
49
  }
50
 
51
+ def on_train_batch_end(self, outputs: dict, *args, **kwargs):
52
  self.log("Train/Loss", outputs['loss'])
53
 
54
  def on_train_epoch_end(self, *args, **kwargs) -> None:
55
  self.log("Train/Accuracy", self.acc_train.compute())
56
  self.acc_train.reset()
57
 
58
+ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs):
59
  srcs, tgts_r, tgts = batch # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
60
  logits = self.forward(srcs, tgts_r).transpose(1, 2) # ... -> (N, L, |V|) -> (N, |V|, L)
61
  self.acc_test.update(logits.detach(), target=tgts.detach())
62
 
63
+ def on_test_epoch_end(self, *args, **kwargs) -> None:
64
  self.log("Test/Accuracy", self.acc_test.compute())
65
  self.acc_test.reset()
66
 
 
72
  # The authors used Adam, so we might as well use it as well.
73
  return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
idiomify/pipeline.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # for inference
3
+ from transformers import BartTokenizer
4
+
5
+ from builders import SourcesBuilder
6
+ from models import Idiomifier
7
+
8
+
9
+ class Pipeline:
10
+
11
+ def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
12
+ self.model = model
13
+ self.builder = SourcesBuilder(tokenizer)
14
+
15
+ def __call__(self, src: str, max_length=100) -> str:
16
+ srcs = self.builder(literal2idiomatic=[(src, "")])
17
+ pred_ids = self.model.bart.generate(
18
+ inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
19
+ attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
20
+ decoder_start_token_id=self.model.hparams['bos_token_id'],
21
+ max_length=max_length,
22
+ ).squeeze() # -> (N, L_t) -> (L_t)
23
+ tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
24
+ return tgt
main_eval.py CHANGED
@@ -5,22 +5,22 @@ import wandb
5
  import pytorch_lightning as pl
6
  from pytorch_lightning.loggers import WandbLogger
7
  from transformers import BartTokenizer
8
- from idiomify.data import IdiomifyDataModule
9
  from idiomify.fetchers import fetch_config, fetch_idiomifier
10
- from paths import ROOT_DIR
11
 
12
 
13
  def main():
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("--num_workers", type=int, default=os.cpu_count())
 
16
  args = parser.parse_args()
17
- config = fetch_config()['train']
18
  config.update(vars(args))
19
- # prepare the model
20
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
21
  # prepare the datamodule
22
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
23
- model = fetch_idiomifier(config['ver'], run)
24
  datamodule = IdiomifyDataModule(config, tokenizer, run)
25
  logger = WandbLogger(log_model=False)
26
  trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
 
5
  import pytorch_lightning as pl
6
  from pytorch_lightning.loggers import WandbLogger
7
  from transformers import BartTokenizer
8
+ from idiomify.datamodules import IdiomifyDataModule
9
  from idiomify.fetchers import fetch_config, fetch_idiomifier
10
+ from idiomify.paths import ROOT_DIR
11
 
12
 
13
  def main():
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("--num_workers", type=int, default=os.cpu_count())
16
+ parser.add_argument("--fast_dev_run", action="store_true", default=False)
17
  args = parser.parse_args()
18
+ config = fetch_config()['idiomifier']
19
  config.update(vars(args))
 
20
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
21
  # prepare the datamodule
22
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
23
+ model = fetch_idiomifier(config['ver'], run) # fetch a pre-trained model
24
  datamodule = IdiomifyDataModule(config, tokenizer, run)
25
  logger = WandbLogger(log_model=False)
26
  trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
main_infer.py CHANGED
@@ -1,5 +1,5 @@
1
  import argparse
2
- from idiomify.models import Idiomifier, Pipeline
3
  from idiomify.fetchers import fetch_config, fetch_idiomifier
4
  from transformers import BartTokenizer
5
 
@@ -10,14 +10,14 @@ def main():
10
  default="If there's any good to loosing my job,"
11
  " it's that I'll now be able to go to school full-time and finish my degree earlier.")
12
  args = parser.parse_args()
13
- config = fetch_config()['infer']
14
  config.update(vars(args))
15
  model = fetch_idiomifier(config['ver'])
16
  model.eval() # this is crucial
17
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
18
- idiomifier = Pipeline(model, tokenizer)
19
  src = config['src']
20
- tgt = idiomifier(src=config['src'])
21
  print(src, "\n->", tgt)
22
 
23
 
 
1
  import argparse
2
+ from idiomify.models import Pipeline
3
  from idiomify.fetchers import fetch_config, fetch_idiomifier
4
  from transformers import BartTokenizer
5
 
 
10
  default="If there's any good to loosing my job,"
11
  " it's that I'll now be able to go to school full-time and finish my degree earlier.")
12
  args = parser.parse_args()
13
+ config = fetch_config()['idiomifier']
14
  config.update(vars(args))
15
  model = fetch_idiomifier(config['ver'])
16
  model.eval() # this is crucial
17
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
18
+ pipeline = Pipeline(model, tokenizer)
19
  src = config['src']
20
+ tgt = pipeline(src=config['src'])
21
  print(src, "\n->", tgt)
22
 
23
 
main_train.py CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
6
  from termcolor import colored
7
  from pytorch_lightning.loggers import WandbLogger
8
  from transformers import BartTokenizer, BartForConditionalGeneration
9
- from idiomify.data import IdiomifyDataModule
10
  from idiomify.fetchers import fetch_config
11
  from idiomify.models import Idiomifier
12
  from idiomify.paths import ROOT_DIR
@@ -19,7 +19,7 @@ def main():
19
  parser.add_argument("--fast_dev_run", action="store_true", default=False)
20
  parser.add_argument("--upload", dest='upload', action='store_true', default=False)
21
  args = parser.parse_args()
22
- config = fetch_config()['train']
23
  config.update(vars(args))
24
  if not config['upload']:
25
  print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
 
6
  from termcolor import colored
7
  from pytorch_lightning.loggers import WandbLogger
8
  from transformers import BartTokenizer, BartForConditionalGeneration
9
+ from idiomify.datamodules import IdiomifyDataModule
10
  from idiomify.fetchers import fetch_config
11
  from idiomify.models import Idiomifier
12
  from idiomify.paths import ROOT_DIR
 
19
  parser.add_argument("--fast_dev_run", action="store_true", default=False)
20
  parser.add_argument("--upload", dest='upload', action='store_true', default=False)
21
  args = parser.parse_args()
22
+ config = fetch_config()['idiomifier']
23
  config.update(vars(args))
24
  if not config['upload']:
25
  print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
main_upload_idioms.py CHANGED
@@ -9,7 +9,7 @@ from idiomify.paths import ROOT_DIR
9
 
10
 
11
  def main():
12
- config = fetch_config()['upload']['idioms']
13
  train_df, _ = fetch_literal2idiomatic(config['ver'])
14
  idioms = train_df['Idiom'].tolist()
15
  idioms = list(set(idioms))
 
9
 
10
 
11
  def main():
12
+ config = fetch_config()['idioms']
13
  train_df, _ = fetch_literal2idiomatic(config['ver'])
14
  idioms = train_df['Idiom'].tolist()
15
  idioms = list(set(idioms))
main_upload_literal2idiomatic.py CHANGED
@@ -12,7 +12,7 @@ def main():
12
 
13
  # here, we use all of them, while splitting them into train & test
14
  pie_df = fetch_pie()
15
- config = fetch_config()['upload']['literal2idiomatic']
16
  train_df, test_df = pie_df.pipe(cleanse)\
17
  .pipe(upsample, seed=config['seed'])\
18
  .pipe(stratified_split, ratio=config['train_ratio'], seed=config['seed'])
 
12
 
13
  # here, we use all of them, while splitting them into train & test
14
  pie_df = fetch_pie()
15
+ config = fetch_config()['literal2idiomatic']
16
  train_df, test_df = pie_df.pipe(cleanse)\
17
  .pipe(upsample, seed=config['seed'])\
18
  .pipe(stratified_split, ratio=config['train_ratio'], seed=config['seed'])