eubinecto committed
Commit dd810eb
1 Parent(s): e3c7b5a

[#2] Eval script added. It still needs testing

config.yaml CHANGED
@@ -12,9 +12,10 @@ train:
 upload:
   idioms:
     ver: d-1-2
-    description: the set of idioms in the traning set of literal2idiomatic:d-1-2
+    description: the set of idioms in the training set of literal2idiomatic:d-1-2.
   literal2idiomatic:
     ver: d-1-2
-    description: PIE data split into train & test set (80 / 20 split)
+    description: PIE data split into train & test set (80 / 20 split). There is no validation set, because I don't intend to
+      do hyperparameter tuning on this set.
     train_ratio: 0.8
     seed: 104
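For context, a minimal sketch, not part of this commit and with a hypothetical split helper, of the reproducible 80/20 split that train_ratio and seed describe:

import random

def split(pairs: list, train_ratio: float = 0.8, seed: int = 104) -> tuple:
    pairs = list(pairs)                    # copy, so the caller's list is untouched
    random.Random(seed).shuffle(pairs)     # fixed seed -> the same split every run
    cutoff = int(len(pairs) * train_ratio)
    return pairs[:cutoff], pairs[cutoff:]  # (train set, test set)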
explore/explore_torchmetrics_bleu.py ADDED
@@ -0,0 +1,28 @@
+
+from torchmetrics import BLEUScore
+from transformers import BartTokenizer
+
+
+pairs = [
+    ("I knew you could do it", "I knew you could do it"),
+    ("I knew you could do it", "you knew you could do it")
+]
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
+    metric = BLEUScore()
+    preds = tokenizer([pred for pred, _ in pairs])['input_ids']
+    targets = tokenizer([target for _, target in pairs])['input_ids']
+    print(preds)
+    print(targets)
+    print(metric(preds, targets))
+    # arghhh, so bleu score does not support tensors...
+    """
+    AttributeError: 'int' object has no attribute 'split'
+    """
+    # let's just go for the accuracies then.
+
+
+if __name__ == '__main__':
+    main()
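The failure above is expected: torchmetrics' BLEUScore operates on whitespace-tokenized strings, not token ids, so feeding it input_ids makes it call .split on integers. A minimal sketch, not part of this commit, of the string-based call that does work (argument names vary slightly across torchmetrics versions):

from torchmetrics import BLEUScore

metric = BLEUScore()
preds = ["I knew you could do it"]       # one string per prediction
targets = [["I knew you could do it"]]   # a list of reference strings per prediction
print(metric(preds, targets))            # tensor(1.) for an exact match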
idiomify/metrics.py DELETED
@@ -1,4 +0,0 @@
-"""
-you may want to include bleu score.
-and more metrics for paraphrasing.
-"""
idiomify/models.py CHANGED
@@ -7,7 +7,7 @@ from torch.nn import functional as F
 import pytorch_lightning as pl
 from transformers import BartForConditionalGeneration, BartTokenizer
 from idiomify.builders import SourcesBuilder
-
+from torchmetrics import Accuracy

 class Idiomifier(pl.LightningModule):  # noqa
     """
@@ -15,8 +15,11 @@ class Idiomifier(pl.LightningModule):  # noqa
     """
     def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int):  # noqa
         super().__init__()
-        self.bart = bart
         self.save_hyperparameters(ignore=["bart"])
+        self.bart = bart
+        # metrics (using accuracies as of right now)
+        self.acc_train = Accuracy(ignore_index=pad_token_id)
+        self.acc_test = Accuracy(ignore_index=pad_token_id)

     def forward(self, srcs: torch.Tensor, tgts_r: torch.Tensor) -> torch.Tensor:
         """
@@ -40,13 +43,27 @@ class Idiomifier(pl.LightningModule):  # noqa
         logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
         loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
                    .sum()  # (N, L, |V|), (N, L) -> (N,) -> (1,)
+        self.acc_train.update(logits.detach(), target=tgts.detach())
         return {
             "loss": loss
         }

-    def on_train_batch_end(self, outputs: dict, *args, **kwargs):
+    def on_train_batch_end(self, outputs: dict, **kwargs):
         self.log("Train/Loss", outputs['loss'])

+    def on_train_epoch_end(self, *args, **kwargs) -> None:
+        self.log("Train/Accuracy", self.acc_train.compute())
+        self.acc_train.reset()
+
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], **kwargs):
+        srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
+        logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
+        self.acc_test.update(logits.detach(), target=tgts.detach())
+
+    def on_test_end(self):
+        self.log("Test/Accuracy", self.acc_test.compute())
+        self.acc_test.reset()
+
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
         Instantiates and returns the optimizer to be used for this model
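For reference, a minimal sketch, not from this commit, of how the token-level accuracy wired in above consumes the (N, |V|, L) logits and (N, L) targets while skipping padded positions. It assumes a torchmetrics version where Accuracy(ignore_index=...) is accepted without a task argument, matching the code above; newer releases require Accuracy(task="multiclass", num_classes=..., ignore_index=...):

import torch
from torchmetrics import Accuracy

PAD = 1                               # illustrative pad token id (BART uses 1)
acc = Accuracy(ignore_index=PAD)      # padded positions are excluded

logits = torch.randn(2, 8, 5)         # (N, |V|, L): N=2, |V|=8, L=5
tgts = torch.randint(0, 8, (2, 5))    # (N, L) gold token ids
tgts[:, -1] = PAD                     # pretend the last position is padding

acc.update(logits, target=tgts)       # argmax over the |V| dim happens internally
print(acc.compute())                  # fraction of non-pad tokens predicted correctly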
main_eval.py ADDED
@@ -0,0 +1,34 @@
+import torch
+import argparse
+import os
+import wandb
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from transformers import BartTokenizer
+from idiomify.data import IdiomifyDataModule
+from idiomify.fetchers import fetch_config, fetch_idiomifier
+from paths import ROOT_DIR
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
+    args = parser.parse_args()
+    config = fetch_config()['train']
+    config.update(vars(args))
+    # prepare the tokenizer
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    # fetch the model & the datamodule within a wandb run
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        model = fetch_idiomifier(config['ver'], run)
+        datamodule = IdiomifyDataModule(config, tokenizer, run)
+        logger = WandbLogger(log_model=False)
+        trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
+                             gpus=torch.cuda.device_count(),
+                             default_root_dir=str(ROOT_DIR),
+                             logger=logger)
+        trainer.test(model, datamodule)
+
+
+if __name__ == '__main__':
+    main()
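The script leaves the evaluation loop itself to Lightning: trainer.test(model, datamodule) iterates the datamodule's test dataloader and drives the test_step / on_test_end hooks added to Idiomifier in this commit. A schematic sketch of that control flow (all names here are hypothetical):

import pytorch_lightning as pl

class SketchModule(pl.LightningModule):
    def test_step(self, batch, batch_idx):
        # called once per test batch; accumulate metrics here,
        # the way Idiomifier.test_step updates self.acc_test
        ...

    def on_test_end(self):
        # called once after the loop; report the aggregated metric,
        # the way Idiomifier.on_test_end logs "Test/Accuracy"
        ...

# trainer = pl.Trainer()
# trainer.test(SketchModule(), datamodule=some_datamodule)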