eubinecto commited on
Commit
e9d1a5a
1 Parent(s): 207cddf

[#1] checkpoint before amending builders.py

Browse files
explore/explore_bart.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartTokenizer, BartModel
2
+
3
+
4
+ def main():
5
+
6
+ tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
7
+ model = BartModel.from_pretrained('facebook/bart-large')
8
+
9
+ inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
10
+ outputs = model(**inputs)
11
+ H_all = outputs.last_hidden_state # noqa
12
+ print(H_all.shape) # (1, 8, 1024)
13
+
14
+
15
+ if __name__ == '__main__':
16
+ main()
main_upload_idiom2context.py → explore/explore_bart_for_conditional_generation.py RENAMED
@@ -1,6 +1,5 @@
1
- """
2
- Build and upload an idiom2context dataset to wandb.
3
- """
4
 
5
 
6
  def main():
@@ -8,4 +7,4 @@ def main():
8
 
9
 
10
  if __name__ == '__main__':
11
- main()
 
1
+
2
+ from transformers import BartTokenizer, BartForConditionalGeneration
 
3
 
4
 
5
  def main():
 
7
 
8
 
9
  if __name__ == '__main__':
10
+ main()
explore/explore_fetch_epie.py CHANGED
@@ -11,7 +11,7 @@ def main():
11
 
12
  # so, what do you want? you want to build an idiom-masked language modeling?
13
  for idiom, context, tag in epie:
14
- print(context)
15
 
16
  for idx, idiom in enumerate(idioms):
17
  print(idx, idiom)
 
11
 
12
  # so, what do you want? you want to build an idiom-masked language modeling?
13
  for idiom, context, tag in epie:
14
+ print(idiom, context)
15
 
16
  for idx, idiom in enumerate(idioms):
17
  print(idx, idiom)
explore/explore_fetch_epie_counts.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from idiomify.fetchers import fetch_epie
3
 
4
 
 
 
1
  from idiomify.fetchers import fetch_epie
2
 
3
 
explore/explore_fetch_idiom2def.py DELETED
@@ -1,15 +0,0 @@
1
- from idiomify.fetchers import fetch_idiom2def
2
-
3
-
4
- def main():
5
- idiom2def = fetch_idiom2def("c")
6
- for idiom, definition in idiom2def:
7
- print(idiom, definition)
8
-
9
- df = fetch_idiom2def("d")
10
- for idiom, definition in idiom2def:
11
- print(idiom, definition)
12
-
13
-
14
- if __name__ == '__main__':
15
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
explore/explore_fetch_idioms.py CHANGED
@@ -2,7 +2,7 @@ from idiomify.fetchers import fetch_idioms
2
 
3
 
4
  def main():
5
- print(fetch_idioms("c"))
6
 
7
 
8
  if __name__ == '__main__':
 
2
 
3
 
4
  def main():
5
+ print(fetch_idioms("pie_v0"))
6
 
7
 
8
  if __name__ == '__main__':
explore/explore_fetch_literal2idiom.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from idiomify.fetchers import fetch_literal2idiom
2
+
3
+
4
+ def main():
5
+ for src, tgt in fetch_literal2idiom("pie_v0"):
6
+ print(src, "->", tgt)
7
+
8
+
9
+ if __name__ == '__main__':
10
+ main()
explore/explore_fetch_pie.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from idiomify.fetchers import fetch_pie
3
+
4
+
5
+ def main():
6
+ for idx, row in enumerate(fetch_pie()):
7
+ print(idx, row)
8
+ # the first 105 = V0.
9
+ if idx == 105:
10
+ break
11
+
12
+
13
+ if __name__ == '__main__':
14
+ main()
idiomify/builders.py CHANGED
@@ -19,6 +19,16 @@ class TensorBuilder:
19
  class Idiom2SubwordsBuilder(TensorBuilder):
20
 
21
  def __call__(self, idioms: List[str], k: int) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
22
  mask_id = self.tokenizer.mask_token_id
23
  pad_id = self.tokenizer.pad_token_id
24
  # temporarily disable single-token status of the idioms
@@ -31,38 +41,20 @@ class Idiom2SubwordsBuilder(TensorBuilder):
31
  max_length=k, # set to k
32
  return_tensors="pt")
33
  input_ids = encodings['input_ids']
34
- input_ids[input_ids == pad_id] = mask_id # replace them with masks
35
  return input_ids
36
 
37
 
38
- class Idiom2DefBuilder(TensorBuilder):
39
-
40
- def __call__(self, idiom2def: List[Tuple[str, str]], k: int) -> torch.Tensor:
41
- defs = [definition for _, definition in idiom2def]
42
- lefts = [" ".join(["[MASK]"] * k)] * len(defs)
43
- encodings = self.tokenizer(text=lefts,
44
- text_pair=defs,
45
- return_tensors="pt",
46
- add_special_tokens=True,
47
- truncation=True,
48
- padding=True,
49
- verbose=True)
50
- input_ids: torch.Tensor = encodings['input_ids']
51
- cls_id: int = self.tokenizer.cls_token_id
52
- sep_id: int = self.tokenizer.sep_token_id
53
- mask_id: int = self.tokenizer.mask_token_id
54
- wisdom_mask = torch.where(input_ids == mask_id, 1, 0)
55
- desc_mask = torch.where(((input_ids != cls_id) & (input_ids != sep_id) & (input_ids != mask_id)), 1, 0)
56
- return torch.stack([input_ids,
57
- encodings['token_type_ids'],
58
- encodings['attention_mask'],
59
- wisdom_mask,
60
- desc_mask], dim=1)
61
-
62
-
63
  class Idiom2ContextBuilder(TensorBuilder):
64
 
65
  def __call__(self, idiom2context: List[Tuple[str, str]]):
 
 
 
 
 
 
 
66
  contexts = [context for _, context in idiom2context]
67
  encodings = self.tokenizer(text=contexts,
68
  return_tensors="pt",
@@ -78,6 +70,14 @@ class Idiom2ContextBuilder(TensorBuilder):
78
  class TargetsBuilder(TensorBuilder):
79
 
80
  def __call__(self, idiom2sent: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
 
 
 
 
 
 
 
 
81
  return torch.LongTensor([
82
  idioms.index(idiom)
83
  for idiom, _ in idiom2sent
 
19
  class Idiom2SubwordsBuilder(TensorBuilder):
20
 
21
  def __call__(self, idioms: List[str], k: int) -> torch.Tensor:
22
+ """
23
+ 1. The function takes in a list of idioms, and a maximum length of the input sequence.
24
+ 2. It then splits the idioms into words, and pads the sequence to the maximum length.
25
+ 3. It masks the padding tokens, and returns the input ids
26
+ :param idioms: a list of idioms, each of which is a list of tokens
27
+ :type idioms: List[str]
28
+ :param k: the maximum length of the idioms
29
+ :type k: int
30
+ :return: The input_ids of the idioms, with the pad tokens replaced by the mask token.
31
+ """
32
  mask_id = self.tokenizer.mask_token_id
33
  pad_id = self.tokenizer.pad_token_id
34
  # temporarily disable single-token status of the idioms
 
41
  max_length=k, # set to k
42
  return_tensors="pt")
43
  input_ids = encodings['input_ids']
44
+ input_ids[input_ids == pad_id] = mask_id
45
  return input_ids
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  class Idiom2ContextBuilder(TensorBuilder):
49
 
50
  def __call__(self, idiom2context: List[Tuple[str, str]]):
51
+ """
52
+ Given a list of tuples of idiom and context,
53
+ it returns a tensor of shape (batch_size, 3, max_seq_len)
54
+ :param idiom2context: List[Tuple[str, str]], a list of tuples of idiom and context
55
+ :type idiom2context: List[Tuple[str, str]]
56
+ :return: The input_ids, token_type_ids, and attention_mask for each context.
57
+ """
58
  contexts = [context for _, context in idiom2context]
59
  encodings = self.tokenizer(text=contexts,
60
  return_tensors="pt",
 
70
  class TargetsBuilder(TensorBuilder):
71
 
72
  def __call__(self, idiom2sent: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
73
+ """
74
+ Given a list of idioms and a list of sentences, return a list of indices of the idioms in the sentences
75
+ :param idiom2sent: A list of tuples, where each tuple is an idiom and its corresponding sentence
76
+ :type idiom2sent: List[Tuple[str, str]]
77
+ :param idioms: A list of idioms
78
+ :type idioms: List[str]
79
+ :return: A tensor of indices of the idioms in the list of idioms.
80
+ """
81
  return torch.LongTensor([
82
  idioms.index(idiom)
83
  for idiom, _ in idiom2sent
idiomify/fetchers.py CHANGED
@@ -1,73 +1,91 @@
1
  import csv
 
2
  import yaml
3
  import wandb
4
  import requests
5
  from typing import Tuple, List
6
-
7
  from wandb.sdk.wandb_run import Run
8
-
 
9
  from idiomify.models import Alpha, RD
10
- from idiomify.paths import idiom2def_dir, CONFIG_YAML, idioms_dir, alpha_dir
11
  from idiomify.urls import (
12
  EPIE_IMMUTABLE_IDIOMS_URL,
13
  EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
14
  EPIE_IMMUTABLE_IDIOMS_TAGS_URL,
15
  EPIE_MUTABLE_IDIOMS_URL,
16
  EPIE_MUTABLE_IDIOMS_CONTEXTS_URL,
17
- EPIE_MUTABLE_IDIOMS_TAGS_URL
 
18
  )
19
- from idiomify.builders import Idiom2SubwordsBuilder
20
- from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
21
 
22
 
23
  # sources for dataset
24
- def fetch_epie() -> List[Tuple[str, str, str]]:
25
- idioms = requests.get(EPIE_IMMUTABLE_IDIOMS_URL).text \
26
- + requests.get(EPIE_MUTABLE_IDIOMS_URL).text
27
- contexts = requests.get(EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL).text \
28
- + requests.get(EPIE_MUTABLE_IDIOMS_CONTEXTS_URL).text
29
- tags = requests.get(EPIE_IMMUTABLE_IDIOMS_TAGS_URL).text \
30
- + requests.get(EPIE_MUTABLE_IDIOMS_TAGS_URL).text
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return list(zip(idioms.strip().split("\n"),
32
  contexts.strip().split("\n"),
33
  tags.strip().split("\n")))
34
 
35
 
36
- # you should somehow get this from... wandb.
37
- def fetch_idiom2context(ver: str, run: Run = None) -> List[Tuple[str, str]]:
 
 
 
 
 
 
 
 
 
 
 
38
  """
39
- include run if you want to track the lineage
40
  """
 
 
41
  if run:
42
- pass
43
-
44
-
45
- # dataset
46
- def fetch_idiom2def(ver: str) -> List[Tuple[str, str]]:
47
- artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/idiom2def:{ver}", type="dataset")
48
- artifact_path = idiom2def_dir(ver)
49
- artifact.download(root=str(artifact_path))
50
- tsv_path = artifact_path / "all.tsv"
51
- with open(tsv_path, 'r') as fh:
52
- reader = csv.reader(fh, delimiter="\t")
53
- return [
54
- (row[0], row[1])
55
- for row in reader
56
- ]
57
 
58
 
59
- def fetch_idioms(ver: str) -> List[str]:
60
- artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/idioms:{ver}", type="dataset")
61
- artifact_path = idioms_dir(ver)
62
- artifact.download(root=str(artifact_path))
63
- tsv_path = artifact_path / "all.tsv"
 
 
 
 
64
  with open(tsv_path, 'r') as fh:
65
  reader = csv.reader(fh, delimiter="\t")
66
- next(reader)
67
- return [
68
- row[0]
69
- for row in reader
70
- ]
71
 
72
 
73
  def fetch_rd(model: str, ver: str) -> RD:
@@ -80,12 +98,13 @@ def fetch_rd(model: str, ver: str) -> RD:
80
  idioms = fetch_idioms(config['idioms_ver'])
81
  tokenizer = BertTokenizer.from_pretrained(config['bert'])
82
  idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, config['k'])
83
- if model == Alpha.name():
84
- rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
85
- elif model == Gamma.name():
86
- rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
87
- else:
88
- raise ValueError
 
89
  return rd
90
 
91
 
 
1
  import csv
2
+ from os import path
3
  import yaml
4
  import wandb
5
  import requests
6
  from typing import Tuple, List
 
7
  from wandb.sdk.wandb_run import Run
8
+ from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
9
+ from idiomify.builders import Idiom2SubwordsBuilder
10
  from idiomify.models import Alpha, RD
11
+ from idiomify.paths import CONFIG_YAML, idioms_dir, alpha_dir, literal2idiom
12
  from idiomify.urls import (
13
  EPIE_IMMUTABLE_IDIOMS_URL,
14
  EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
15
  EPIE_IMMUTABLE_IDIOMS_TAGS_URL,
16
  EPIE_MUTABLE_IDIOMS_URL,
17
  EPIE_MUTABLE_IDIOMS_CONTEXTS_URL,
18
+ EPIE_MUTABLE_IDIOMS_TAGS_URL,
19
+ PIE_URL
20
  )
 
 
21
 
22
 
23
  # sources for dataset
24
+ def fetch_epie(ver: str) -> List[Tuple[str, str, str]]:
25
+ """
26
+ It fetches the EPIE idioms, contexts, and tags from the web
27
+ :param ver: str
28
+ :type ver: str
29
+ :return: A list of tuples. Each tuple contains three strings: an idiom, a context, and a tag.
30
+ """
31
+ if ver == "immutable":
32
+ idioms_url = EPIE_IMMUTABLE_IDIOMS_URL
33
+ contexts_url = EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL
34
+ tags_url = EPIE_IMMUTABLE_IDIOMS_TAGS_URL
35
+ elif ver == "mutable":
36
+ idioms_url = EPIE_MUTABLE_IDIOMS_URL
37
+ contexts_url = EPIE_MUTABLE_IDIOMS_CONTEXTS_URL
38
+ tags_url = EPIE_MUTABLE_IDIOMS_TAGS_URL
39
+ else:
40
+ raise ValueError
41
+ idioms = requests.get(idioms_url).text
42
+ contexts = requests.get(contexts_url).text
43
+ tags = requests.get(tags_url).text
44
  return list(zip(idioms.strip().split("\n"),
45
  contexts.strip().split("\n"),
46
  tags.strip().split("\n")))
47
 
48
 
49
+ def fetch_pie() -> list:
50
+ text = requests.get(PIE_URL).text
51
+ lines = (line for line in text.split("\n") if line)
52
+ reader = csv.reader(lines)
53
+ next(reader) # skip the header
54
+ return [
55
+ row
56
+ for row in reader
57
+ ]
58
+
59
+
60
+ # --- from wandb --- #
61
+ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
62
  """
63
+ why do you need this? -> you need this to have access to the idiom embeddings.
64
  """
65
+ # if run object is given, we track the lineage of the data.
66
+ # if not, we get the dataset via wandb Api.
67
  if run:
68
+ artifact = run.use_artifact("idioms", type="dataset", aliases=ver)
69
+ else:
70
+ artifact = wandb.Api().artifact(f"eubinecto/idiomify/idioms:{ver}", type="dataset")
71
+ artifact_dir = artifact.download(root=idioms_dir(ver))
72
+ txt_path = path.join(artifact_dir, "all.txt")
73
+ with open(txt_path, 'r') as fh:
74
+ return [line.strip() for line in fh]
 
 
 
 
 
 
 
 
75
 
76
 
77
+ def fetch_literal2idiom(ver: str, run: Run = None) -> List[Tuple[str, str]]:
78
+ # if run object is given, we track the lineage of the data.
79
+ # if not, we get the dataset via wandb Api.
80
+ if run:
81
+ artifact = run.use_artifact("literal2idiom", type="dataset", aliases=ver)
82
+ else:
83
+ artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiom:{ver}", type="dataset")
84
+ artifact_dir = artifact.download(root=literal2idiom(ver))
85
+ tsv_path = path.join(artifact_dir, "all.tsv")
86
  with open(tsv_path, 'r') as fh:
87
  reader = csv.reader(fh, delimiter="\t")
88
+ return [(row[0], row[1]) for row in reader]
 
 
 
 
89
 
90
 
91
  def fetch_rd(model: str, ver: str) -> RD:
 
98
  idioms = fetch_idioms(config['idioms_ver'])
99
  tokenizer = BertTokenizer.from_pretrained(config['bert'])
100
  idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, config['k'])
101
+ # if model == Alpha.name():
102
+ # rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
103
+ # elif model == Gamma.name():
104
+ # rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
105
+ # else:
106
+ # raise ValueError
107
+ rd = ...
108
  return rd
109
 
110
 
idiomify/models.py CHANGED
@@ -8,14 +8,12 @@ import pytorch_lightning as pl
8
  from transformers import BertForMaskedLM
9
 
10
 
11
- class RD(pl.LightningModule):
12
  """
13
  @eubinecto
14
  The superclass of all the reverse-dictionaries. This class houses any methods that are required by
15
  whatever reverse-dictionaries we define.
16
  """
17
-
18
- # --- boilerplate; the loaders are defined in datamodules, so we don't define them here
19
  # passing them to avoid warnings --- #
20
  def train_dataloader(self):
21
  pass
@@ -35,119 +33,24 @@ class RD(pl.LightningModule):
35
  :param idiom2subwords: (|W|, K)
36
  :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
37
  """
38
- super().__init__()
39
- # -- hyper params --- #
40
- # should be saved to self.hparams
41
- # https://github.com/PyTorchLightning/pytorch-lightning/issues/4390#issue-730493746
42
- self.save_hyperparameters(ignore=["mlm", "idiom2subwords"])
43
- # -- the only neural network we need -- #
44
- self.mlm = mlm
45
- # --- to be used for getting H_k --- #
46
- self.wisdom_mask: Optional[torch.Tensor] = None # (N, L)
47
- # --- to be used for getting H_desc --- #
48
- self.desc_mask: Optional[torch.Tensor] = None # (N, L)
49
- # -- constant tensors -- #
50
- self.register_buffer("idiom2subwords", idiom2subwords) # (|W|, K)
51
 
52
  def forward(self, X: torch.Tensor) -> torch.Tensor:
53
  """
54
- :param X: (N, 4, L);
55
- (num samples, 0=input_ids/1=token_type_ids/2=attention_mask/3=wisdom_mask, the maximum length)
56
- :return: (N, L, H); (num samples, k, the size of the vocabulary of subwords)
57
- """
58
- input_ids = X[:, 0] # (N, 4, L) -> (N, L)
59
- token_type_ids = X[:, 1] # (N, 4, L) -> (N, L)
60
- attention_mask = X[:, 2] # (N, 4, L) -> (N, L)
61
- self.wisdom_mask = X[:, 3] # (N, 4, L) -> (N, L)
62
- self.desc_mask = X[:, 4] # (N, 4, L) -> (N, L)
63
- H_all = self.mlm.bert.forward(input_ids, attention_mask, token_type_ids)[0] # (N, 3, L) -> (N, L, H)
64
- return H_all
65
-
66
- def H_k(self, H_all: torch.Tensor) -> torch.Tensor:
67
- """
68
- You may want to override this. (e.g. RDGamma - the k's could be anywhere)
69
- :param H_all (N, L, H)
70
- :return H_k (N, K, H)
71
- """
72
- N, _, H = H_all.size()
73
- # refer to: wisdomify/examples/explore_masked_select.py
74
- wisdom_mask = self.wisdom_mask.unsqueeze(2).expand(H_all.shape) # (N, L) -> (N, L, 1) -> (N, L, H)
75
- H_k = torch.masked_select(H_all, wisdom_mask.bool()) # (N, L, H), (N, L, H) -> (N * K * H)
76
- H_k = H_k.reshape(N, self.hparams['k'], H) # (N * K * H) -> (N, K, H)
77
- return H_k
78
-
79
- def H_desc(self, H_all: torch.Tensor) -> torch.Tensor:
80
  """
81
- :param H_all (N, L, H)
82
- :return H_desc (N, L - (K + 3), H)
83
- """
84
- N, L, H = H_all.size()
85
- desc_mask = self.desc_mask.unsqueeze(2).expand(H_all.shape)
86
- H_desc = torch.masked_select(H_all, desc_mask.bool()) # (N, L, H), (N, L, H) -> (N * (L - (K + 3)) * H)
87
- H_desc = H_desc.reshape(N, L - (self.hparams['k'] + 3), H) # (N * (L - (K + 3)) * H) -> (N, L - (K + 3), H)
88
- return H_desc
89
-
90
- def S_wisdom_literal(self, H_k: torch.Tensor) -> torch.Tensor:
91
- """
92
- To be used for both RDAlpha & RDBeta
93
- :param H_k: (N, K, H)
94
- :return: S_wisdom_literal (N, |W|)
95
- """
96
- S_vocab = self.mlm.cls(H_k) # bmm; (N, K, H) * (H, |V|) -> (N, K, |V|)
97
- indices = self.idiom2subwords.T.repeat(S_vocab.shape[0], 1, 1) # (|W|, K) -> (N, K, |W|)
98
- S_wisdom_literal = S_vocab.gather(dim=-1, index=indices) # (N, K, |V|) -> (N, K, |W|)
99
- S_wisdom_literal = S_wisdom_literal.sum(dim=1) # (N, K, |W|) -> (N, |W|)
100
- return S_wisdom_literal
101
-
102
- def S_wisdom(self, H_all: torch.Tensor) -> torch.Tensor:
103
- """
104
- :param H_all: (N, L, H)
105
- :return S_wisdom: (N, |W|)
106
- """
107
- raise NotImplementedError("An RD class must implement S_wisdom")
108
-
109
- def P_wisdom(self, X: torch.Tensor) -> torch.Tensor:
110
- """
111
- :param X: (N, 3, L)
112
- :return P_wisdom: (N, |W|), normalized over dim 1.
113
- """
114
- H_all = self.forward(X) # (N, 3, L) -> (N, L, H)
115
- S_wisdom = self.S_wisdom(H_all) # (N, L, H) -> (N, W)
116
- P_wisdom = F.softmax(S_wisdom, dim=1) # (N, W) -> (N, W)
117
- return P_wisdom
118
-
119
- def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
120
- X, y = batch
121
- H_all = self.forward(X) # (N, 3, L) -> (N, L, H)
122
- S_wisdom = self.S_wisdom(H_all) # (N, L, H) -> (N, |W|)
123
- loss = F.cross_entropy(S_wisdom, y) # (N, |W|), (N,) -> (N,)
124
- loss = loss.sum() # (N,) -> (1,)
125
- # so that the metrics accumulate over the course of this epoch
126
- # why dict? - just a boilerplate
127
- return {
128
- # you cannot change the keyword for the loss
129
- "loss": loss,
130
- }
131
-
132
- def on_train_batch_end(self, outputs: dict, *args, **kwargs) -> None:
133
- # watch the loss for this batch
134
- self.log("Train/Loss", outputs['loss'])
135
-
136
- def training_epoch_end(self, outputs: List[dict]) -> None:
137
- # to see an average performance over the batches in this specific epoch
138
- avg_loss = torch.stack([output['loss'].detach() for output in outputs]).mean()
139
- self.log("Train/Average Loss", avg_loss)
140
 
141
- def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
142
- return self.training_step(batch, batch_idx)
143
 
144
- def on_validation_batch_end(self, outputs: dict, *args, **kwargs) -> None:
145
- self.log("Validation/Loss", outputs['loss'])
146
 
147
- def validation_epoch_end(self, outputs: List[dict]) -> None:
148
- # to see an average performance over the batches in this specific epoch
149
- avg_loss = torch.stack([output['loss'].detach() for output in outputs]).mean()
150
- self.log("Validation/Average Loss", avg_loss)
151
 
152
  def configure_optimizers(self) -> torch.optim.Optimizer:
153
  """
@@ -162,7 +65,7 @@ class RD(pl.LightningModule):
162
  return cls.__name__.lower()
163
 
164
 
165
- class Alpha(RD):
166
  """
167
  @eubinecto
168
  The first prototype.
 
8
  from transformers import BertForMaskedLM
9
 
10
 
11
+ class Idiomifier(pl.LightningModule):
12
  """
13
  @eubinecto
14
  The superclass of all the reverse-dictionaries. This class houses any methods that are required by
15
  whatever reverse-dictionaries we define.
16
  """
 
 
17
  # passing them to avoid warnings --- #
18
  def train_dataloader(self):
19
  pass
 
33
  :param idiom2subwords: (|W|, K)
34
  :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
35
  """
36
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def forward(self, X: torch.Tensor) -> torch.Tensor:
39
  """
40
+ given a batch, forward returns a batch of hidden vectors
41
+ :param X: (N, 3, L). input_ids, token_type_ids, and what was the last one...?
42
+ :return: (N, L, H)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ def step(self):
47
+ pass
48
 
49
+ def predict(self):
50
+ pass
51
 
52
+ def training_step(self):
53
+ pass
 
 
54
 
55
  def configure_optimizers(self) -> torch.optim.Optimizer:
56
  """
 
65
  return cls.__name__.lower()
66
 
67
 
68
+ class Alpha(Idiomifier):
69
  """
70
  @eubinecto
71
  The first prototype.
idiomify/paths.py CHANGED
@@ -5,14 +5,14 @@ ARTIFACTS_DIR = ROOT_DIR / "artifacts"
5
  CONFIG_YAML = ROOT_DIR / "config.yaml"
6
 
7
 
8
- def idiom2def_dir(ver: str) -> Path:
9
- return ARTIFACTS_DIR / f"idiom2def_{ver}"
10
-
11
-
12
  def idioms_dir(ver: str) -> Path:
13
  return ARTIFACTS_DIR / f"idioms_{ver}"
14
 
15
 
 
 
 
 
16
  def alpha_dir(ver: str) -> Path:
17
  return ARTIFACTS_DIR / f"alpha_{ver}"
18
 
 
5
  CONFIG_YAML = ROOT_DIR / "config.yaml"
6
 
7
 
 
 
 
 
8
  def idioms_dir(ver: str) -> Path:
9
  return ARTIFACTS_DIR / f"idioms_{ver}"
10
 
11
 
12
+ def literal2idiom(ver: str) -> Path:
13
+ return ARTIFACTS_DIR / f"literal2idiom_{ver}"
14
+
15
+
16
  def alpha_dir(ver: str) -> Path:
17
  return ARTIFACTS_DIR / f"alpha_{ver}"
18
 
idiomify/urls.py CHANGED
@@ -7,5 +7,10 @@ EPIE_MUTABLE_IDIOMS_TAGS_URL = "https://raw.githubusercontent.com/prateeksaxena2
7
  EPIE_MUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Candidates.txt" # noqa
8
  EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Corpus/blob/master/Formal_Idioms_Corpus/Formal_Idioms_Words.txt" # noqa
9
 
 
 
 
 
 
10
 
11
 
 
7
  EPIE_MUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Candidates.txt" # noqa
8
  EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Corpus/blob/master/Formal_Idioms_Corpus/Formal_Idioms_Words.txt" # noqa
9
 
10
+ # PIE dataset (Zhou, 2021)
11
+ # https://aclanthology.org/2021.mwe-1.5/
12
+ # right, let's just work on it.
13
+ PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
14
+
15
 
16
 
main_infer.py CHANGED
@@ -1,36 +1,37 @@
1
- import argparse
2
- from idiomify import tensors as T
3
- from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
4
- from transformers import BertTokenizer
5
- from termcolor import colored
6
-
7
- def main():
8
- parser = argparse.ArgumentParser()
9
- parser.add_argument("--model", type=str,
10
- default="alpha")
11
- parser.add_argument("--ver", type=str,
12
- default="eng2eng")
13
- parser.add_argument("--sent", type=str,
14
- default="to avoid getting to the point")
15
- args = parser.parse_args()
16
- config = fetch_config()[args.model][args.ver]
17
- config.update(vars(args))
18
- idioms = fetch_idioms(config['idioms_ver'])
19
- rd = fetch_rd(config['model'], config['ver'])
20
- rd.eval()
21
- tokenizer = BertTokenizer.from_pretrained(config['bert'])
22
- X = T.inputs([("", config['sent'])], tokenizer, config['k'])
23
- probs = rd.P_wisdom(X).squeeze().tolist()
24
- wisdom2prob = [
25
- (wisdom, prob)
26
- for wisdom, prob in zip(idioms, probs)
27
- ]
28
- # sort and append
29
- res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
30
- print(f"query: {colored(text=config['sent'], color='blue')}")
31
- for idx, (idiom, prob) in enumerate(res):
32
- print(idx, idiom, prob)
33
-
34
-
35
- if __name__ == '__main__':
36
- main()
 
 
1
+ # we disable them for now.
2
+ # import argparse
3
+ # from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
4
+ # from transformers import BertTokenizer
5
+ # from termcolor import colored
6
+ #
7
+ #
8
+ # def main():
9
+ # parser = argparse.ArgumentParser()
10
+ # parser.add_argument("--model", type=str,
11
+ # default="alpha")
12
+ # parser.add_argument("--ver", type=str,
13
+ # default="eng2eng")
14
+ # parser.add_argument("--sent", type=str,
15
+ # default="to avoid getting to the point")
16
+ # args = parser.parse_args()
17
+ # config = fetch_config()[args.model][args.ver]
18
+ # config.update(vars(args))
19
+ # idioms = fetch_idioms(config['idioms_ver'])
20
+ # rd = fetch_rd(config['model'], config['ver'])
21
+ # rd.eval()
22
+ # tokenizer = BertTokenizer.from_pretrained(config['bert'])
23
+ # X = T.inputs([("", config['sent'])], tokenizer, config['k'])
24
+ # probs = rd.P_wisdom(X).squeeze().tolist()
25
+ # wisdom2prob = [
26
+ # (wisdom, prob)
27
+ # for wisdom, prob in zip(idioms, probs)
28
+ # ]
29
+ # # sort and append
30
+ # res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
31
+ # print(f"query: {colored(text=config['sent'], color='blue')}")
32
+ # for idx, (idiom, prob) in enumerate(res):
33
+ # print(idx, idiom, prob)
34
+ #
35
+ #
36
+ # if __name__ == '__main__':
37
+ # main()
main_upload_idioms.py CHANGED
@@ -1,12 +1,40 @@
1
  """
2
- Here,
3
- ver a: Compatible with the first version
4
- ver b:
5
  """
 
 
 
 
 
6
 
7
 
8
  def main():
9
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  if __name__ == '__main__':
 
1
  """
2
+ Here, what should you do here?
3
+ just upload all idioms here - name it as epie.
 
4
  """
5
+ import os
6
+ from idiomify.paths import ROOT_DIR
7
+ from idiomify.fetchers import fetch_pie
8
+ import argparse
9
+ import wandb
10
 
11
 
12
  def main():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--ver", type=str, default="pie_v0",
15
+ choices=["pie_v0", "pie_v1"])
16
+ config = vars(parser.parse_args())
17
+
18
+ # get the idioms here
19
+ if config['ver'] == "pie_v0":
20
+ # only the first 106, and this is for piloting
21
+ idioms = set([row[0] for row in fetch_pie()[:106]])
22
+ elif config['ver'] == "pie_v1":
23
+ # just include all
24
+ idioms = set([row[0] for row in fetch_pie()])
25
+ else:
26
+ raise NotImplementedError
27
+ idioms = list(idioms)
28
+
29
+ with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
30
+ artifact = wandb.Artifact(name="idioms", type="dataset")
31
+ txt_path = ROOT_DIR / "all.txt"
32
+ with open(txt_path, 'w') as fh:
33
+ for idiom in idioms:
34
+ fh.write(idiom + "\n")
35
+ artifact.add_file(txt_path)
36
+ run.log_artifact(artifact, aliases=["latest", config['ver']])
37
+ os.remove(txt_path)
38
 
39
 
40
  if __name__ == '__main__':
main_upload_literal2idiom.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Here, what should you do here?
3
+ just upload all idioms here - name it as epie.
4
+ """
5
+ import csv
6
+ import os
7
+ from idiomify.paths import ROOT_DIR
8
+ from idiomify.fetchers import fetch_pie
9
+ import argparse
10
+ import wandb
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument("--ver", type=str, default="pie_v0",
16
+ choices=["pie_v0", "pie_v1"])
17
+ config = vars(parser.parse_args())
18
+
19
+ # get the idioms here
20
+ if config['ver'] == "pie_v0":
21
+ # only the first 106, and we use this just for piloting
22
+ literal2idiom = [
23
+ (row[3], row[2]) for row in fetch_pie()[:106]
24
+ ]
25
+ elif config['ver'] == "pie_v1":
26
+ # just include all
27
+ literal2idiom = [
28
+ (row[3], row[2]) for row in fetch_pie()
29
+ ]
30
+ else:
31
+ raise NotImplementedError
32
+
33
+ with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
34
+ artifact = wandb.Artifact(name="literal2idiom", type="dataset")
35
+ tsv_path = ROOT_DIR / "all.tsv"
36
+ with open(tsv_path, 'w') as fh:
37
+ writer = csv.writer(fh, delimiter="\t")
38
+ for row in literal2idiom:
39
+ writer.writerow(row)
40
+ artifact.add_file(tsv_path)
41
+ run.log_artifact(artifact, aliases=["latest", config['ver']])
42
+ os.remove(tsv_path)
43
+
44
+
45
+ if __name__ == '__main__':
46
+ main()
main_upload_tokenizer.py DELETED
@@ -1,13 +0,0 @@
1
- """
2
- Build & upload a tokenizer to wandb.
3
- You need this if you were to add more tokens there.
4
- """
5
-
6
-
7
- def main():
8
- pass
9
- # TODO: fetch the dataset from wandb first!
10
-
11
-
12
- if __name__ == '__main__':
13
- main()