eubinecto committed
Commit 25e310b
Parent: 3be6142

[#1] IdiomifyDataModule implemented (srcs, tgts_r, tgts)

explore/{explore_fetch_literal2idiom.py → explore_fetch_literal2idiomatic.py} RENAMED
@@ -1,8 +1,8 @@
-from idiomify.fetchers import fetch_literal2idiom
+from idiomify.fetchers import fetch_literal2idiomatic
 
 
 def main():
-    for src, tgt in fetch_literal2idiom("pie_v0"):
+    for src, tgt in fetch_literal2idiomatic("pie_v0"):
         print(src, "->", tgt)
 
 
explore/explore_idiomifydatamodule.py ADDED
@@ -0,0 +1,26 @@
+from transformers import BartTokenizer
+from idiomify.datamodules import IdiomifyDataModule
+
+
+CONFIG = {
+    "literal2idiomatic_ver": "pie_v0",
+    "batch_size": 20,
+    "num_workers": 4,
+    "shuffle": True
+}
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
+    datamodule.prepare_data()
+    datamodule.setup()
+    for batch in datamodule.train_dataloader():
+        srcs, tgts_r, tgts = batch
+        print(srcs.shape)
+        print(tgts_r.shape)
+        print(tgts.shape)
+
+
+if __name__ == '__main__':
+    main()
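Each of srcs, tgts_r and tgts stacks input_ids and attention_mask along dim 1, so every tensor comes out as (N, 2, L). A minimal sketch of how one such tensor splits back into its two components (the indexing follows the torch.stack(..., dim=1) calls in builders.py below; the helper name is illustrative, not from this commit):

import torch


def split_pair(t: torch.Tensor):
    # t: (N, 2, L) -> input_ids (N, L), attention_mask (N, L)
    input_ids = t[:, 0]        # index 0 of dim 1 holds the token ids
    attention_mask = t[:, 1]   # index 1 of dim 1 holds the attention mask
    return input_ids, attention_mask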
explore/explore_src_builder.py CHANGED
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.builders import SRCBuilder
+from idiomify.builders import SourcesBuilder
 
 BATCH = [
     ("I could die at any moment", "I could kick the bucket at any moment"),
@@ -9,7 +9,7 @@ BATCH = [
 
 def main():
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
-    builder = SRCBuilder(tokenizer)
+    builder = SourcesBuilder(tokenizer)
     src = builder(BATCH)
     print(src)
 
explore/explore_tgt_builder.py CHANGED
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.builders import TGTBuilder
+from idiomify.builders import TargetsBuilder
 
 BATCH = [
     ("I could die at any moment", "I could kick the bucket at any moment"),
@@ -9,7 +9,7 @@ BATCH = [
 
 def main():
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
-    builder = TGTBuilder(tokenizer)
+    builder = TargetsBuilder(tokenizer)
     tgt_r, tgt = builder(BATCH)
     print(tgt_r)
     print(tgt)
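Note that after this commit TargetsBuilder returns a single (N, 2, L) tensor, so the unchanged line `tgt_r, tgt = builder(BATCH)` above no longer yields the pair it used to: it unpacks along the batch dimension instead (and raises a ValueError whenever the batch size is not 2). A sketch of calls consistent with the new builders, assuming the same BATCH and tokenizer as above:

tgt_r = TargetsRightShiftedBuilder(tokenizer)(BATCH)  # (N, 2, L), bos-prefixed targets
tgt = TargetsBuilder(tokenizer)(BATCH)                # (N, 2, L), eos-terminated targets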
idiomify/builders.py CHANGED
@@ -4,12 +4,12 @@ builders must accept device as one of the parameters.
 """
 import torch
 from typing import List, Tuple
-from transformers import BertTokenizer
+from transformers import BertTokenizer, BartTokenizer
 
 
 class TensorBuilder:
 
-    def __init__(self, tokenizer: BertTokenizer):
+    def __init__(self, tokenizer: BartTokenizer):
         self.tokenizer = tokenizer
 
     def __call__(self, *args, **kwargs) -> torch.Tensor:
@@ -45,7 +45,7 @@ class Idiom2SubwordsBuilder(TensorBuilder):
         return input_ids
 
 
-class SRCBuilder(TensorBuilder):
+class SourcesBuilder(TensorBuilder):
     """
     to be used for both training and inference
     """
@@ -60,24 +60,29 @@ class SRCBuilder(TensorBuilder):
         return src  # (N, 2, L)
 
 
-class TGTBuilder(TensorBuilder):
-    """
-    This is to be used only for training. As for inference, we don't need this.
-    """
-    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> Tuple[torch.Tensor, torch.Tensor]:
-        encodings_r = self.tokenizer([
+class TargetsRightShiftedBuilder(TensorBuilder):
+
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
+        encodings = self.tokenizer([
             self.tokenizer.bos_token + idiomatic  # starts with bos, but does not end with eos (right-shifted)
             for _, idiomatic in literal2idiomatic
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
+        tgts_r = torch.stack([encodings['input_ids'],
+                              encodings['attention_mask']], dim=1)  # (N, 2, L)
+        return tgts_r
+
+
+class TargetsBuilder(TensorBuilder):
+    """
+    This is to be used only for training. As for inference, we don't need this.
+    """
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
         encodings = self.tokenizer([
             idiomatic + self.tokenizer.eos_token  # no bos, but ends with eos
             for _, idiomatic in literal2idiomatic
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
-        tgt_r = torch.stack([encodings_r['input_ids'],
-                             encodings_r['attention_mask']], dim=1)  # (N, 2, L)
-        tgt = torch.stack([encodings['input_ids'],
-                           encodings['attention_mask']], dim=1)  # (N, 2, L)
-        return tgt_r, tgt
-
 
 
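The split into two target builders serves teacher forcing: the decoder consumes the bos-prefixed sequence and is trained to predict the eos-terminated one, one step ahead. A minimal sketch of how the three tensors could feed a BART training step (BartForConditionalGeneration and the loss wiring are assumptions for illustration; the model code is not part of this commit):

from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
outputs = model(
    input_ids=srcs[:, 0], attention_mask=srcs[:, 1],                      # encoder inputs
    decoder_input_ids=tgts_r[:, 0], decoder_attention_mask=tgts_r[:, 1],  # <s> + y (right-shifted)
    labels=tgts[:, 0],                                                    # y + </s>
)
loss = outputs.loss  # cross-entropy; in practice pad positions in labels are set to -100 so they are ignored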
idiomify/datamodules.py CHANGED
@@ -2,35 +2,38 @@ import torch
 from typing import Tuple, Optional, List
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
-from idiomify.fetchers import fetch_idiom2def, fetch_epie
-from idiomify.builders import Idiom2DefBuilder, Idiom2ContextBuilder, LabelsBuilder
-from transformers import BertTokenizer
+from idiomify.fetchers import fetch_literal2idiomatic
+from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
+from transformers import BartTokenizer
 
 
 class IdiomifyDataset(Dataset):
     def __init__(self,
-                 X: torch.Tensor,
-                 y: torch.Tensor):
-        self.X = X
-        self.y = y
+                 srcs: torch.Tensor,
+                 tgts_r: torch.Tensor,
+                 tgts: torch.Tensor):
+        self.srcs = srcs
+        self.tgts_r = tgts_r
+        self.tgts = tgts
 
     def __len__(self) -> int:
         """
         Returning the size of the dataset
         :return:
         """
-        return self.y.shape[0]
+        assert self.srcs.shape[0] == self.tgts_r.shape[0] == self.tgts.shape[0]
+        return self.srcs.shape[0]
 
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.LongTensor]:
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
         """
         Returns features & the label
         :param idx:
         :return:
         """
-        return self.X[idx], self.y[idx]
+        return self.srcs[idx], self.tgts_r[idx], self.tgts[idx]
 
 
-class Idiom2DefDataModule(LightningDataModule):
+class IdiomifyDataModule(LightningDataModule):
 
     # boilerplate - just ignore these
     def test_dataloader(self):
@@ -44,21 +47,19 @@ class Idiom2DefDataModule(LightningDataModule):
 
     def __init__(self,
                  config: dict,
-                 tokenizer: BertTokenizer,
-                 idioms: List[str]):
+                 tokenizer: BartTokenizer):
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
-        self.idioms = idioms
         # --- to be downloaded & built --- #
-        self.idiom2def: Optional[List[Tuple[str, str]]] = None
+        self.literal2idiomatic: Optional[List[Tuple[str, str]]] = None
         self.dataset: Optional[IdiomifyDataset] = None
 
     def prepare_data(self):
         """
         prepare: download all data needed for this from wandb to local.
         """
-        self.idiom2def = fetch_idiom2def(self.config['idiom2def_ver'])
+        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'])
 
     def setup(self, stage: Optional[str] = None):
         """
@@ -66,50 +67,11 @@ class Idiom2DefDataModule(LightningDataModule):
         """
         # --- set up the builders --- #
         # build the datasets
-        X = Idiom2DefBuilder(self.tokenizer)(self.idiom2def, self.config['k'])
-        y = LabelsBuilder(self.tokenizer)(self.idiom2def, self.idioms)
-        self.dataset = IdiomifyDataset(X, y)
+        srcs = SourcesBuilder(self.tokenizer)(self.literal2idiomatic)
+        tgts_r = TargetsRightShiftedBuilder(self.tokenizer)(self.literal2idiomatic)
+        tgts = TargetsBuilder(self.tokenizer)(self.literal2idiomatic)
+        self.dataset = IdiomifyDataset(srcs, tgts_r, tgts)
 
     def train_dataloader(self) -> DataLoader:
         return DataLoader(self.dataset, batch_size=self.config['batch_size'],
                           shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
-
-
-class Idiom2ContextsDataModule(LightningDataModule):
-
-    # boilerplate - just ignore these
-    def test_dataloader(self):
-        pass
-
-    def val_dataloader(self):
-        pass
-
-    def predict_dataloader(self):
-        pass
-
-    def __init__(self, config: dict, tokenizer: BertTokenizer, idioms: List[str]):
-        super().__init__()
-        self.config = config
-        self.tokenizer = tokenizer
-        self.idioms = idioms
-        self.idiom2context: Optional[List[Tuple[str, str]]] = None
-        self.dataset: Optional[IdiomifyDataset] = None
-
-    def prepare_data(self):
-        """
-        prepare: download all data needed for this from wandb to local.
-        """
-        self.idiom2context = [
-            (idiom, context)
-            for idiom, _, context in fetch_epie()
-        ]
-
-    def setup(self, stage: Optional[str] = None):
-        # build the datasets
-        X = Idiom2ContextBuilder(self.tokenizer)(self.idiom2context)
-        y = LabelsBuilder(self.tokenizer)(self.idiom2context, self.idioms)
-        self.dataset = IdiomifyDataset(X, y)
-
-    def train_dataloader(self):
-        return DataLoader(self.dataset, batch_size=self.config['batch_size'],
-                          shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
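With the dataset carrying all three tensors, the datamodule drops into a Lightning training loop as usual. A sketch, assuming some LightningModule named `model` (none exists in this commit yet):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)  # Lightning invokes prepare_data() and setup() itself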
idiomify/fetchers.py CHANGED
@@ -5,10 +5,7 @@ import wandb
 import requests
 from typing import Tuple, List
 from wandb.sdk.wandb_run import Run
-from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
-from idiomify.builders import Idiom2SubwordsBuilder
-from idiomify.models import Alpha, RD
-from idiomify.paths import CONFIG_YAML, idioms_dir, alpha_dir, literal2idiom
+from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic
 from idiomify.urls import (
     EPIE_IMMUTABLE_IDIOMS_URL,
     EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
@@ -74,40 +71,20 @@ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
         return [line.strip() for line in fh]
 
 
-def fetch_literal2idiom(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
         artifact = run.use_artifact("literal2idiom", type="dataset", aliases=ver)
     else:
-        artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiom:{ver}", type="dataset")
-    artifact_dir = artifact.download(root=literal2idiom(ver))
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
+    artifact_dir = artifact.download(root=literal2idiomatic(ver))
     tsv_path = path.join(artifact_dir, "all.tsv")
     with open(tsv_path, 'r') as fh:
         reader = csv.reader(fh, delimiter="\t")
         return [(row[0], row[1]) for row in reader]
 
 
-def fetch_rd(model: str, ver: str) -> RD:
-    artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/{model}:{ver}", type="model")
-    config = artifact.metadata
-    artifact_path = alpha_dir(ver)
-    artifact.download(root=str(artifact_path))
-    mlm = AutoModelForMaskedLM.from_config(AutoConfig.from_pretrained(config['bert']))
-    ckpt_path = artifact_path / "rd.ckpt"
-    idioms = fetch_idioms(config['idioms_ver'])
-    tokenizer = BertTokenizer.from_pretrained(config['bert'])
-    idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, config['k'])
-    # if model == Alpha.name():
-    #     rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
-    # elif model == Gamma.name():
-    #     rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
-    # else:
-    #     raise ValueError
-    rd = ...
-    return rd
-
-
 def fetch_config() -> dict:
     with open(str(CONFIG_YAML), 'r', encoding="utf-8") as fh:
         return yaml.safe_load(fh)
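Passing a Run makes use_artifact record the dataset in that run's lineage graph; without one, the public Api fetches it anonymously. A sketch of the lineage-tracking path, with the entity and project names used elsewhere in this commit (note that the run branch above still references the old artifact name "literal2idiom"):

import wandb
from idiomify.fetchers import fetch_literal2idiomatic

with wandb.init(entity="eubinecto", project="idiomify") as run:
    pairs = fetch_literal2idiomatic("pie_v0", run)  # recorded as an input artifact of this run
print(pairs[0])  # (literal sentence, idiomatic sentence)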
idiomify/paths.py CHANGED
@@ -9,8 +9,8 @@ def idioms_dir(ver: str) -> Path:
     return ARTIFACTS_DIR / f"idioms_{ver}"
 
 
-def literal2idiom(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"literal2idiom_{ver}"
+def literal2idiomatic(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"literal2idiomatic_{ver}"
 
 
 def alpha_dir(ver: str) -> Path:
main_train.py CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from termcolor import colored
 from transformers import BertForMaskedLM, BertTokenizer
-from idiomify.datamodules import Idiom2DefDataModule
+from idiomify.datamodules import IdiomifyDataModule
 from idiomify.fetchers import fetch_config, fetch_idioms
 from idiomify.models import Alpha, Gamma
 from idiomify.paths import ROOT_DIR
@@ -40,7 +40,7 @@ def main():
     else:
         raise ValueError
     # prepare datamodule
-    datamodule = Idiom2DefDataModule(config, tokenizer, idioms)
+    datamodule = IdiomifyDataModule(config, tokenizer, idioms)
 
     with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
         logger = WandbLogger(log_model=False)
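Note: IdiomifyDataModule.__init__ now takes only (config, tokenizer), so the updated call above, which still passes idioms, would raise a TypeError. A call consistent with the new signature:

datamodule = IdiomifyDataModule(config, tokenizer)  # idioms is no longer a parameter of this datamodule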
main_upload_literal2idiom.py → main_upload_literal2idiomatic.py RENAMED
@@ -31,7 +31,7 @@ def main():
         raise NotImplementedError
 
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
-        artifact = wandb.Artifact(name="literal2idiom", type="dataset")
+        artifact = wandb.Artifact(name="literal2idiomatic", type="dataset")
         tsv_path = ROOT_DIR / "all.tsv"
         with open(tsv_path, 'w') as fh:
             writer = csv.writer(fh, delimiter="\t")
  writer = csv.writer(fh, delimiter="\t")