eubinecto committed on
Commit
64a6414
1 Parent(s): 6fd648a

[#1] refactoring: Alpha -> Seq2Seq. We rely on git tags for versioning models.

config.yaml CHANGED
@@ -1,8 +1,8 @@
1
- alpha:
2
- overfit:
3
- bart: facebook/bart-base
4
- lr: 0.0001
5
- literal2idiomatic_ver: pie_v0
6
- max_epochs: 100
7
- batch_size: 100
8
- shuffle: true
 
1
+ tag011:
2
+ desc: just overfitting
3
+ bart: facebook/bart-base
4
+ lr: 0.0001
5
+ literal2idiomatic_ver: tag01
6
+ max_epochs: 100
7
+ batch_size: 100
8
+ shuffle: true
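A minimal sketch of how the flattened config above might be consumed, assuming the YAML has already been parsed into a plain dict. The `CONFIG` literal and the `lookup` helper are hypothetical stand-ins for `fetch_config()`, whose body is not shown in this commit. The point is only that a version tag such as `tag011` is now the top-level key, where the old layout nested a version under a model name.

```python
# Hypothetical in-memory copy of the new config.yaml; values are taken from the diff above.
CONFIG = {
    "tag011": {
        "desc": "just overfitting",
        "bart": "facebook/bart-base",
        "lr": 0.0001,
        "literal2idiomatic_ver": "tag01",
        "max_epochs": 100,
        "batch_size": 100,
        "shuffle": True,
    }
}


def lookup(config: dict, ver: str) -> dict:
    """Return a copy of the hyperparameters for one tag, failing loudly on unknown tags."""
    if ver not in config:
        raise KeyError(f"unknown version tag: {ver!r}; known tags: {sorted(config)}")
    return dict(config[ver])  # copy, so later updates do not mutate CONFIG


# Mirrors the `config = fetch_config()[args.ver]; config.update(vars(args))` pattern.
params = lookup(CONFIG, "tag011")
params.update({"ver": "tag011"})
print(params["bart"], params["literal2idiomatic_ver"])
```

Returning a copy is a design choice made for this sketch only; it keeps command-line overrides from leaking back into the shared config.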
explore/explore_bert_base_multilingual_tokenizer.py DELETED
@@ -1,44 +0,0 @@
1
- from idiomify.fetchers import fetch_idiom2def
2
- from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast
3
-
4
-
5
- def main():
6
- tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
7
- idiom2def = fetch_idiom2def("d") # eng2kor
8
-
9
- for idiom, definition in idiom2def:
10
- print(tokenizer.decode(tokenizer(idiom)['input_ids']),
11
- tokenizer.decode(tokenizer(definition)['input_ids']))
12
-
13
- # right, the tokenizer knows Korean, which is great.
14
- """
15
- /opt/homebrew/Caskroom/miniforge/base/envs/idiomify-demo/bin/python /Users/eubinecto/Desktop/Projects/Toy/idiomify-demo/explore/explore_mbert_tokenizer.py
16
- [CLS] beat around the bush [SEP] [CLS] 불쾌하거나 민감한 주제에 대해 직접적으로 이야기하는 것을 피하기 위해 모호하거나 완곡하게 말한다. [SEP]
17
- [CLS] beat around the bush [SEP] [CLS] 단어나 태도가 우회적이다 [SEP]
18
- [CLS] beat around the bush [SEP] [CLS] 우물쭈물하다 [SEP]
19
- [CLS] beat around the bush [SEP] [CLS] 우회적으로 접근하다 [SEP]
20
- [CLS] backhanded compliment [SEP] [CLS] 칭찬으로 가장한 모욕적이거나 부정적인 논평 [SEP]
21
- [CLS] backhanded compliment [SEP] [CLS] 의도하지 않거나 애매한 칭찬 [SEP]
22
- [CLS] backhanded compliment [SEP] [CLS] 누군가를 칭찬하는 것 같지만 비판으로도 이해될 수 있는 말 [SEP]
23
- [CLS] backhanded compliment [SEP] [CLS] 남을 기쁘게 하는 말 같지만 모욕이 될 수도 있는 말 [SEP]
24
- [CLS] backhanded compliment [SEP] [CLS] 감탄하는 듯 하면서도 모욕으로 이해될 수 있는 말 [SEP]
25
- [CLS] steer clear of [SEP] [CLS] 누군가나 뭔가를 피하다 [SEP]
26
- [CLS] steer clear of [SEP] [CLS] 떨어져 지내다 [SEP]
27
- [CLS] steer clear of [SEP] [CLS] 피하거나 멀리하도록 주의하다 [SEP]
28
- [CLS] steer clear of [SEP] [CLS] 불쾌하거나 위험하거나 문제를 일으킬 것 같은 사람이나 물건을 피하다 [SEP]
29
- [CLS] steer clear of [SEP] [CLS] 일부러 피하다 [SEP]
30
- [CLS] dish it out [SEP] [CLS] 가혹한 생각, 비판, 또는 모욕의 목소리를 내는 것. [SEP]
31
- [CLS] dish it out [SEP] [CLS] 누군가 또는 무언가에 대해 험담하는 것 [SEP]
32
- [CLS] dish it out [SEP] [CLS] 어떤 것을 주거나 정보나 당신의 의견과 같은 것을 말하는 것 [SEP]
33
- [CLS] dish it out [SEP] [CLS] 다른 사람을 쉽게 비판하지만 다른 사람이 자신을 비판할때는 좋아하지 않음 [SEP]
34
- [CLS] dish it out [SEP] [CLS] 다른 사람을 비판하다 [SEP]
35
- [CLS] make headway [SEP] [CLS] 성취하고자 하는 어떤 것에 진척이 생기다 [SEP]
36
- [CLS] make headway [SEP] [CLS] 특히 이것이 느리거나 어려울 때, 진전을 이루다. [SEP]
37
- [CLS] make headway [SEP] [CLS] 전진하다 [SEP]
38
- [CLS] make headway [SEP] [CLS] 앞으로 나아가거나 진전을 이루다 [SEP]
39
- [CLS] make headway [SEP] [CLS] 성공하기 시작하다 [SEP]
40
- """
41
-
42
-
43
- if __name__ == '__main__':
44
- main()
explore/explore_bert_base_tokenizer.py DELETED
@@ -1,45 +0,0 @@
1
- from idiomify.fetchers import fetch_idiom2def
2
- from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast
3
-
4
-
5
- def main():
6
-
7
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
8
- idiom2def = fetch_idiom2def("c") # eng2eng
9
- for idiom, definition in idiom2def:
10
- print(tokenizer.decode(tokenizer(idiom)['input_ids']),
11
- tokenizer.decode(tokenizer(definition)['input_ids']))
12
-
13
- """
14
- /opt/homebrew/Caskroom/miniforge/base/envs/idiomify-demo/bin/python /Users/eubinecto/Desktop/Projects/Toy/idiomify-demo/explore/explore_bert_base_tokenizer.py
15
- Downloading: 100%|██████████| 226k/226k [00:00<00:00, 298kB/s]
16
- Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 8.27kB/s]
17
- Downloading: 100%|██████████| 455k/455k [00:01<00:00, 449kB/s]
18
- [CLS] beat around the bush [SEP] [CLS] to speak vaguely or euphemistically so as to avoid talkingdirectly about an unpleasant or sensitive topic [SEP]
19
- [CLS] beat around the bush [SEP] [CLS] indirection in word or deed [SEP]
20
- [CLS] beat around the bush [SEP] [CLS] to shilly - shally [SEP]
21
- [CLS] beat around the bush [SEP] [CLS] to approach something in a roundabout way [SEP]
22
- [CLS] backhanded compliment [SEP] [CLS] an insulting or negative comment disguised as praise. [SEP]
23
- [CLS] backhanded compliment [SEP] [CLS] an unintended or ambiguous compliment. [SEP]
24
- [CLS] backhanded compliment [SEP] [CLS] a remark which seems to be praising someone or something but which could also be understood as criticism [SEP]
25
- [CLS] backhanded compliment [SEP] [CLS] a remark that seems to say something pleasant about a person but could also be an insult [SEP]
26
- [CLS] backhanded compliment [SEP] [CLS] a remark that seems to express admiration but could also be understood as an insult [SEP]
27
- [CLS] steer clear of [SEP] [CLS] to avoid someone or something. [SEP]
28
- [CLS] steer clear of [SEP] [CLS] stay away from [SEP]
29
- [CLS] steer clear of [SEP] [CLS] take care to avoid or keep away from [SEP]
30
- [CLS] steer clear of [SEP] [CLS] to avoid someone or something that seems unpleasant, dangerous, or likely to cause problems [SEP]
31
- [CLS] steer clear of [SEP] [CLS] deliberately avoid someone [SEP]
32
- [CLS] dish it out [SEP] [CLS] to voice harsh thoughts, criticisms, or insults. [SEP]
33
- [CLS] dish it out [SEP] [CLS] to gossip about someone or something [SEP]
34
- [CLS] dish it out [SEP] [CLS] to give something, or to tell something such as information or your opinions [SEP]
35
- [CLS] dish it out [SEP] [CLS] someone easily criticizes other people but does not like it when other people criticize him or her [SEP]
36
- [CLS] dish it out [SEP] [CLS] to criticize other people [SEP]
37
- [CLS] make headway [SEP] [CLS] make progress with something that you are trying to achieve. [SEP]
38
- [CLS] make headway [SEP] [CLS] make progress, especially when this is slow or difficult [SEP]
39
- [CLS] make headway [SEP] [CLS] to advance. [SEP]
40
- [CLS] make headway [SEP] [CLS] to move forward or make progress [SEP]
41
- [CLS] make headway [SEP] [CLS] to begin to succeed [SEP]
42
- """
43
-
44
- if __name__ == '__main__':
45
- main()
explore/explore_fetch_epie_counts.py DELETED
@@ -1,19 +0,0 @@
1
- from idiomify.fetchers import fetch_epie
2
-
3
-
4
- def main():
5
- idioms = set([
6
- idiom
7
- for idiom, _, _ in fetch_epie()
8
- ])
9
- contexts = [
10
- context
11
- for _, _, context in fetch_epie()
12
- ]
13
- print("Total number of idioms:", len(idioms))
14
- # This should learn... this - what I need for now is building a datamodule out of this
15
- print("Total number of contexts:", len(contexts))
16
-
17
-
18
- if __name__ == '__main__':
19
- main()
explore/{explore_fetch_alpha.py → explore_fetch_seq2seq.py} RENAMED
@@ -1,8 +1,8 @@
1
- from idiomify.fetchers import fetch_alpha
2
 
3
 
4
  def main():
5
- model = fetch_alpha("overfit")
6
  print(model.bart.config)
7
 
8
 
 
1
+ from idiomify.fetchers import fetch_seq2seq
2
 
3
 
4
  def main():
5
+ model = fetch_seq2seq("overfit")
6
  print(model.bart.config)
7
 
8
 
explore/{explore_fetch_alpha_predict.py → explore_fetch_seq2seq_predict.py} RENAMED
@@ -1,10 +1,10 @@
1
  from transformers import BartTokenizer
2
  from builders import SourcesBuilder
3
- from fetchers import fetch_alpha
4
 
5
 
6
  def main():
7
- model = fetch_alpha("overfit")
8
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
9
  lit2idi = [
10
  ("my man", ""),
 
1
  from transformers import BartTokenizer
2
  from builders import SourcesBuilder
3
+ from fetchers import fetch_seq2seq
4
 
5
 
6
  def main():
7
+ model = fetch_seq2seq("overfit")
8
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
9
  lit2idi = [
10
  ("my man", ""),
explore/explore_nlpaug.py ADDED
@@ -0,0 +1,21 @@
1
+
2
+ import nlpaug.augmenter.word as naw
3
+ import nlpaug.augmenter.sentence as nas
4
+
5
+ import nltk
6
+
7
+
8
+ sent = "I am really happy with the new job and I mean that with sincere feeling"
9
+
10
+
11
+ def main():
12
+ nltk.download("omw-1.4")
13
+ # this seems legit! I could definitely use this to increase the accuracy of the model
14
+ # for a few idioms (possibly ten, ten very different but frequent idioms)
15
+ aug = naw.ContextualWordEmbsAug()
16
+ augmented = aug.augment(sent, n=10)
17
+ print(augmented)
18
+
19
+
20
+ if __name__ == '__main__':
21
+ main()
idiomify/builders.py CHANGED
@@ -4,7 +4,7 @@ builders must accept device as one of the parameters.
4
  """
5
  import torch
6
  from typing import List, Tuple
7
- from transformers import BertTokenizer, BartTokenizer
8
 
9
 
10
  class TensorBuilder:
@@ -61,7 +61,9 @@ class SourcesBuilder(TensorBuilder):
61
 
62
 
63
  class TargetsRightShiftedBuilder(TensorBuilder):
64
-
 
 
65
  def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
66
  encodings = self.tokenizer([
67
  self.tokenizer.bos_token + idiomatic # starts with bos, but does not end with eos (right-shifted)
@@ -73,9 +75,7 @@ class TargetsRightShiftedBuilder(TensorBuilder):
73
 
74
 
75
  class TargetsBuilder(TensorBuilder):
76
- """
77
- This is to be used only for training. As for inference, we don't need this.
78
- """
79
  def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
80
  encodings = self.tokenizer([
81
  idiomatic + self.tokenizer.eos_token # no bos, but ends with eos
 
4
  """
5
  import torch
6
  from typing import List, Tuple
7
+ from transformers import BartTokenizer
8
 
9
 
10
  class TensorBuilder:
 
61
 
62
 
63
  class TargetsRightShiftedBuilder(TensorBuilder):
64
+ """
65
+ This is to be used only for training. As for inference, we don't need this.
66
+ """
67
  def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
68
  encodings = self.tokenizer([
69
  self.tokenizer.bos_token + idiomatic # starts with bos, but does not end with eos (right-shifted)
 
75
 
76
 
77
  class TargetsBuilder(TensorBuilder):
78
+
 
 
79
  def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
80
  encodings = self.tokenizer([
81
  idiomatic + self.tokenizer.eos_token # no bos, but ends with eos
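The two target builders differ only in where the special token goes: the right-shifted one prepends `bos` and omits `eos`, while the plain one omits `bos` and appends `eos`. A rough, tokenizer-free sketch of that relationship follows. The token strings, the `build_decoder_pairs` name, and the padding scheme are illustrative assumptions, not the repository's implementation, which goes through `BartTokenizer`.

```python
from typing import List, Tuple

# Placeholder special tokens, assumed for illustration only.
BOS, EOS, PAD = "<s>", "</s>", "<pad>"


def build_decoder_pairs(idiomatic: List[str], max_len: int) -> Tuple[List[str], List[str]]:
    """Build (decoder inputs, labels) for teacher forcing from one target sequence."""
    decoder_inputs = [BOS] + idiomatic  # starts with bos, does not end with eos
    labels = idiomatic + [EOS]          # no bos, but ends with eos

    def pad(seq: List[str]) -> List[str]:
        # Pad on the right, then cut to max_len (overly long sequences are truncated).
        return (seq + [PAD] * max_len)[:max_len]

    return pad(decoder_inputs), pad(labels)


inputs, labels = build_decoder_pairs(["beat", "around"], max_len=4)
print(inputs)
print(labels)
```

At each position `t`, the decoder sees `inputs[:t + 1]` and is trained to predict `labels[t]`, which is why the right-shifted tensor is only needed during training.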
idiomify/fetchers.py CHANGED
@@ -5,43 +5,10 @@ import wandb
5
  import requests
6
  from typing import Tuple, List
7
  from wandb.sdk.wandb_run import Run
8
- from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, alpha_dir
9
- from idiomify.urls import (
10
- EPIE_IMMUTABLE_IDIOMS_URL,
11
- EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
12
- EPIE_IMMUTABLE_IDIOMS_TAGS_URL,
13
- EPIE_MUTABLE_IDIOMS_URL,
14
- EPIE_MUTABLE_IDIOMS_CONTEXTS_URL,
15
- EPIE_MUTABLE_IDIOMS_TAGS_URL,
16
- PIE_URL
17
- )
18
  from transformers import AutoModelForSeq2SeqLM, AutoConfig
19
- from models import Alpha
20
-
21
-
22
- def fetch_epie(ver: str) -> List[Tuple[str, str, str]]:
23
- """
24
- It fetches the EPIE idioms, contexts, and tags from the web
25
- :param ver: str
26
- :type ver: str
27
- :return: A list of tuples. Each tuple contains three strings: an idiom, a context, and a tag.
28
- """
29
- if ver == "immutable":
30
- idioms_url = EPIE_IMMUTABLE_IDIOMS_URL
31
- contexts_url = EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL
32
- tags_url = EPIE_IMMUTABLE_IDIOMS_TAGS_URL
33
- elif ver == "mutable":
34
- idioms_url = EPIE_MUTABLE_IDIOMS_URL
35
- contexts_url = EPIE_MUTABLE_IDIOMS_CONTEXTS_URL
36
- tags_url = EPIE_MUTABLE_IDIOMS_TAGS_URL
37
- else:
38
- raise ValueError
39
- idioms = requests.get(idioms_url).text
40
- contexts = requests.get(contexts_url).text
41
- tags = requests.get(tags_url).text
42
- return list(zip(idioms.strip().split("\n"),
43
- contexts.strip().split("\n"),
44
- tags.strip().split("\n")))
45
 
46
 
47
  def fetch_pie() -> list:
@@ -86,16 +53,16 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
86
  return [(row[0], row[1]) for row in reader]
87
 
88
 
89
- def fetch_alpha(ver: str, run: Run = None) -> Alpha:
90
  if run:
91
- artifact = run.use_artifact(f"alpha:{ver}", type="model")
92
  else:
93
- artifact = wandb.Api().artifact(f"eubinecto/idiomify/alpha:{ver}", type="model")
94
  config = artifact.metadata
95
- artifact_dir = artifact.download(root=alpha_dir(ver))
96
  ckpt_path = path.join(artifact_dir, "model.ckpt")
97
  bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
98
- alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
99
  return alpha
100
 
101
 
 
5
  import requests
6
  from typing import Tuple, List
7
  from wandb.sdk.wandb_run import Run
8
+ from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, seq2seq_dir
9
+ from idiomify.urls import PIE_URL
 
10
  from transformers import AutoModelForSeq2SeqLM, AutoConfig
11
+ from idiomify.models import Seq2Seq
12
 
13
 
14
  def fetch_pie() -> list:
 
53
  return [(row[0], row[1]) for row in reader]
54
 
55
 
56
+ def fetch_seq2seq(ver: str, run: Run = None) -> Seq2Seq:
57
  if run:
58
+ artifact = run.use_artifact(f"seq2seq:{ver}", type="model")
59
  else:
60
+ artifact = wandb.Api().artifact(f"eubinecto/idiomify/seq2seq:{ver}", type="model")
61
  config = artifact.metadata
62
+ artifact_dir = artifact.download(root=seq2seq_dir(ver))
63
  ckpt_path = path.join(artifact_dir, "model.ckpt")
64
  bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
65
+ alpha = Seq2Seq.load_from_checkpoint(ckpt_path, bart=bart)
66
  return alpha
67
 
68
 
idiomify/idiomifier.py DELETED
@@ -1,22 +0,0 @@
1
- from transformers import BartTokenizer
2
- from builders import SourcesBuilder
3
- from models import Alpha
4
-
5
-
6
- class Idiomifier:
7
-
8
- def __init__(self, model: Alpha, tokenizer: BartTokenizer):
9
- self.model = model
10
- self.builder = SourcesBuilder(tokenizer)
11
- self.model.eval()
12
-
13
- def __call__(self, src: str, max_length=100) -> str:
14
- srcs = self.builder(literal2idiomatic=[(src, "")])
15
- pred_ids = self.model.bart.generate(
16
- inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
17
- attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
18
- decoder_start_token_id=self.model.hparams['bos_token_id'],
19
- max_length=max_length,
20
- ).squeeze() # -> (N, L_t) -> (L_t)
21
- tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
22
- return tgt
idiomify/models.py CHANGED
@@ -5,12 +5,14 @@ from typing import Tuple
5
  import torch
6
  from torch.nn import functional as F
7
  import pytorch_lightning as pl
8
- from transformers import BartForConditionalGeneration
 
9
 
10
 
11
- class Alpha(pl.LightningModule): # noqa
 
12
  """
13
- the baseline.
14
  """
15
  def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int): # noqa
16
  super().__init__()
@@ -54,3 +56,23 @@ class Alpha(pl.LightningModule): # noqa
54
  """
55
  # The authors used Adam, so we might as well use it as well.
56
  return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
 
5
  import torch
6
  from torch.nn import functional as F
7
  import pytorch_lightning as pl
8
+ from transformers import BartForConditionalGeneration, BartTokenizer
9
+ from idiomify.builders import SourcesBuilder
10
 
11
 
12
+ # for training
13
+ class Seq2Seq(pl.LightningModule): # noqa
14
  """
15
+ the baseline is in here.
16
  """
17
  def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int): # noqa
18
  super().__init__()
 
56
  """
57
  # The authors used Adam, so we might as well use it as well.
58
  return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
59
+
60
+
61
+ # for inference
62
+ class Idiomifier:
63
+
64
+ def __init__(self, model: Seq2Seq, tokenizer: BartTokenizer):
65
+ self.model = model
66
+ self.builder = SourcesBuilder(tokenizer)
67
+ self.model.eval()
68
+
69
+ def __call__(self, src: str, max_length=100) -> str:
70
+ srcs = self.builder(literal2idiomatic=[(src, "")])
71
+ pred_ids = self.model.bart.generate(
72
+ inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
73
+ attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
74
+ decoder_start_token_id=self.model.hparams['bos_token_id'],
75
+ max_length=max_length,
76
+ ).squeeze() # -> (N, L_t) -> (L_t)
77
+ tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
78
+ return tgt
idiomify/paths.py CHANGED
@@ -6,12 +6,12 @@ CONFIG_YAML = ROOT_DIR / "config.yaml"
6
 
7
 
8
  def idioms_dir(ver: str) -> Path:
9
- return ARTIFACTS_DIR / f"idioms_{ver}"
10
 
11
 
12
  def literal2idiomatic(ver: str) -> Path:
13
- return ARTIFACTS_DIR / f"literal2idiomatic_{ver}"
14
 
15
 
16
- def alpha_dir(ver: str) -> Path:
17
- return ARTIFACTS_DIR / f"alpha_{ver}"
 
6
 
7
 
8
  def idioms_dir(ver: str) -> Path:
9
+ return ARTIFACTS_DIR / f"idioms-{ver}"
10
 
11
 
12
  def literal2idiomatic(ver: str) -> Path:
13
+ return ARTIFACTS_DIR / f"literal2idiomatic-{ver}"
14
 
15
 
16
+ def seq2seq_dir(ver: str) -> Path:
17
+ return ARTIFACTS_DIR / f"seq2seq-{ver}"
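As a small illustration of the renamed helpers, the sketch below pairs the hyphenated local directory name with the colon-separated artifact reference that `fetch_seq2seq` passes to W&B. The `ARTIFACTS_DIR` value and the `seq2seq_artifact_name` helper are assumptions made for this example; only the two f-string formats come from the diff.

```python
from pathlib import Path

# Assumed location; the real ARTIFACTS_DIR is defined above the hunk shown here.
ARTIFACTS_DIR = Path("artifacts")


def seq2seq_dir(ver: str) -> Path:
    """Local download directory for one tagged model, e.g. .../seq2seq-tag011."""
    return ARTIFACTS_DIR / f"seq2seq-{ver}"


def seq2seq_artifact_name(ver: str, entity_project: str = "eubinecto/idiomify") -> str:
    """Hypothetical helper: the 'name:alias' reference string used in fetchers.py."""
    return f"{entity_project}/seq2seq:{ver}"


print(seq2seq_dir("tag011").name)
print(seq2seq_artifact_name("tag011"))
```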
idiomify/urls.py CHANGED
@@ -11,6 +11,3 @@ EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Co
11
  # https://aclanthology.org/2021.mwe-1.5/
12
  # right, let's just work on it.
13
  PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
14
-
15
-
16
-
 
11
  # https://aclanthology.org/2021.mwe-1.5/
12
  # right, let's just work on it.
13
  PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
 
 
 
main_infer.py CHANGED
@@ -1,27 +1,24 @@
1
  import argparse
2
- from termcolor import colored
3
- from idiomifier import Idiomifier
4
- from idiomify.fetchers import fetch_config, fetch_alpha
5
  from transformers import BartTokenizer
6
 
7
 
8
  def main():
9
  parser = argparse.ArgumentParser()
10
- parser.add_argument("--model", type=str,
11
- default="alpha")
12
- parser.add_argument("--ver", type=str,
13
- default="overfit")
14
  parser.add_argument("--src", type=str,
15
- default="If there's any benefits to losing my job, it's that I'll now be able to go to school full-time and finish my degree earlier.")
 
16
  args = parser.parse_args()
17
- config = fetch_config()[args.model][args.ver]
18
  config.update(vars(args))
19
- model = fetch_alpha(config['ver'])
20
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
21
  idiomifier = Idiomifier(model, tokenizer)
22
  src = config['src']
23
  tgt = idiomifier(src=config['src'])
24
- print(src, "\n->", colored(tgt, "blue"))
25
 
26
 
27
  if __name__ == '__main__':
 
1
  import argparse
2
+ from idiomify.models import Idiomifier
3
+ from idiomify.fetchers import fetch_config, fetch_seq2seq
 
4
  from transformers import BartTokenizer
5
 
6
 
7
  def main():
8
  parser = argparse.ArgumentParser()
9
+ parser.add_argument("--ver", type=str, default="tag011")
 
 
 
10
  parser.add_argument("--src", type=str,
11
+ default="If there's any good to loosing my job,"
12
+ " it's that I'll now be able to go to school full-time and finish my degree earlier.")
13
  args = parser.parse_args()
14
+ config = fetch_config()[args.ver]
15
  config.update(vars(args))
16
+ model = fetch_seq2seq(config['ver'])
17
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
18
  idiomifier = Idiomifier(model, tokenizer)
19
  src = config['src']
20
  tgt = idiomifier(src=config['src'])
21
+ print(src, "\n->", tgt)
22
 
23
 
24
  if __name__ == '__main__':
main_train.py CHANGED
@@ -8,20 +8,19 @@ from pytorch_lightning.loggers import WandbLogger
8
  from transformers import BartTokenizer, BartForConditionalGeneration
9
  from idiomify.data import IdiomifyDataModule
10
  from idiomify.fetchers import fetch_config
11
- from idiomify.models import Alpha
12
  from idiomify.paths import ROOT_DIR
13
 
14
 
15
  def main():
16
  parser = argparse.ArgumentParser()
17
- parser.add_argument("--model", type=str, default="alpha")
18
- parser.add_argument("--ver", type=str, default="overfit ")
19
  parser.add_argument("--num_workers", type=int, default=os.cpu_count())
20
  parser.add_argument("--log_every_n_steps", type=int, default=1)
21
  parser.add_argument("--fast_dev_run", action="store_true", default=False)
22
  parser.add_argument("--upload", dest='upload', action='store_true', default=False)
23
  args = parser.parse_args()
24
- config = fetch_config()[args.model][args.ver]
25
  config.update(vars(args))
26
  if not config['upload']:
27
  print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
@@ -29,12 +28,8 @@ def main():
29
  # prepare the model
30
  bart = BartForConditionalGeneration.from_pretrained(config['bart'])
31
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
32
- if config['model'] == "alpha":
33
- model = Alpha(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
34
- else:
35
- raise NotImplementedError
36
  # prepare the datamodule
37
-
38
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
39
  datamodule = IdiomifyDataModule(config, tokenizer, run)
40
  logger = WandbLogger(log_model=False)
 
8
  from transformers import BartTokenizer, BartForConditionalGeneration
9
  from idiomify.data import IdiomifyDataModule
10
  from idiomify.fetchers import fetch_config
11
+ from idiomify.models import Seq2Seq
12
  from idiomify.paths import ROOT_DIR
13
 
14
 
15
  def main():
16
  parser = argparse.ArgumentParser()
17
+ parser.add_argument("--ver", type=str, default="tag011")
 
18
  parser.add_argument("--num_workers", type=int, default=os.cpu_count())
19
  parser.add_argument("--log_every_n_steps", type=int, default=1)
20
  parser.add_argument("--fast_dev_run", action="store_true", default=False)
21
  parser.add_argument("--upload", dest='upload', action='store_true', default=False)
22
  args = parser.parse_args()
23
+ config = fetch_config()[args.ver]
24
  config.update(vars(args))
25
  if not config['upload']:
26
  print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
 
28
  # prepare the model
29
  bart = BartForConditionalGeneration.from_pretrained(config['bart'])
30
  tokenizer = BartTokenizer.from_pretrained(config['bart'])
31
+ model = Seq2Seq(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
 
 
 
32
  # prepare the datamodule
 
33
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
34
  datamodule = IdiomifyDataModule(config, tokenizer, run)
35
  logger = WandbLogger(log_model=False)
main_upload_idioms.py CHANGED
@@ -11,17 +11,13 @@ import wandb
11
 
12
  def main():
13
  parser = argparse.ArgumentParser()
14
- parser.add_argument("--ver", type=str, default="pie_v0",
15
- choices=["pie_v0", "pie_v1"])
16
  config = vars(parser.parse_args())
17
 
18
  # get the idioms here
19
- if config['ver'] == "pie_v0":
20
  # only the first 106, and this is for piloting
21
  idioms = set([row[0] for row in fetch_pie()[:106]])
22
- elif config['ver'] == "pie_v1":
23
- # just include all
24
- idioms = set([row[0] for row in fetch_pie()])
25
  else:
26
  raise NotImplementedError
27
  idioms = list(idioms)
 
11
 
12
  def main():
13
  parser = argparse.ArgumentParser()
14
+ parser.add_argument("--ver", type=str, default="tag01")
 
15
  config = vars(parser.parse_args())
16
 
17
  # get the idioms here
18
+ if config['ver'] == "tag01":
19
  # only the first 106, and this is for piloting
20
  idioms = set([row[0] for row in fetch_pie()[:106]])
 
 
 
21
  else:
22
  raise NotImplementedError
23
  idioms = list(idioms)
main_upload_literal2idiomatic.py CHANGED
@@ -12,21 +12,15 @@ import wandb
12
 
13
  def main():
14
  parser = argparse.ArgumentParser()
15
- parser.add_argument("--ver", type=str, default="pie_v0",
16
- choices=["pie_v0", "pie_v1"])
17
  config = vars(parser.parse_args())
18
 
19
  # get the idioms here
20
- if config['ver'] == "pie_v0":
21
  # only the first 106, and we use this just for piloting
22
  literal2idiom = [
23
  (row[3], row[2]) for row in fetch_pie()[:106]
24
  ]
25
- elif config['ver'] == "pie_v1":
26
- # just include all
27
- literal2idiom = [
28
- (row[3], row[2]) for row in fetch_pie()
29
- ]
30
  else:
31
  raise NotImplementedError
32
 
 
12
 
13
  def main():
14
  parser = argparse.ArgumentParser()
15
+ parser.add_argument("--ver", type=str, default="tag01")
 
16
  config = vars(parser.parse_args())
17
 
18
  # get the idioms here
19
+ if config['ver'] == "tag01":
20
  # only the first 106, and we use this just for piloting
21
  literal2idiom = [
22
  (row[3], row[2]) for row in fetch_pie()[:106]
23
  ]
 
 
 
 
 
24
  else:
25
  raise NotImplementedError
26