[#1] refactoring: Alpha -> Seq2Seq. We rely on git tags for versioning models.
Browse files
- config.yaml +8 -8
- explore/explore_bert_base_multilingual_tokenizer.py +0 -44
- explore/explore_bert_base_tokenizer.py +0 -45
- explore/explore_fetch_epie_counts.py +0 -19
- explore/{explore_fetch_alpha.py → explore_fetch_seq2seq.py} +2 -2
- explore/{explore_fetch_alpha_predict.py → explore_fetch_seq2seq_predict.py} +2 -2
- explore/explore_nlpaug.py +21 -0
- idiomify/builders.py +5 -5
- idiomify/fetchers.py +8 -41
- idiomify/idiomifier.py +0 -22
- idiomify/models.py +25 -3
- idiomify/paths.py +4 -4
- idiomify/urls.py +0 -3
- main_infer.py +8 -11
- main_train.py +4 -9
- main_upload_idioms.py +2 -6
- main_upload_literal2idiomatic.py +2 -8
config.yaml
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
1 |
+
tag011:
|
2 |
+
desc: just overfitting
|
3 |
+
bart: facebook/bart-base
|
4 |
+
lr: 0.0001
|
5 |
+
literal2idiomatic_ver: tag01
|
6 |
+
max_epochs: 100
|
7 |
+
batch_size: 100
|
8 |
+
shuffle: true
|
explore/explore_bert_base_multilingual_tokenizer.py
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
from idiomify.fetchers import fetch_idiom2def
|
2 |
-
from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast
|
3 |
-
|
4 |
-
|
5 |
-
def main():
|
6 |
-
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
|
7 |
-
idiom2def = fetch_idiom2def("d") # eng2kor
|
8 |
-
|
9 |
-
for idiom, definition in idiom2def:
|
10 |
-
print(tokenizer.decode(tokenizer(idiom)['input_ids']),
|
11 |
-
tokenizer.decode(tokenizer(definition)['input_ids']))
|
12 |
-
|
13 |
-
# right, the tokenizer knows Korean, which is great.
|
14 |
-
"""
|
15 |
-
/opt/homebrew/Caskroom/miniforge/base/envs/idiomify-demo/bin/python /Users/eubinecto/Desktop/Projects/Toy/idiomify-demo/explore/explore_mbert_tokenizer.py
|
16 |
-
[CLS] beat around the bush [SEP] [CLS] 불쾌하거나 민감한 주제에 대해 직접적으로 이야기하는 것을 피하기 위해 모호하거나 완곡하게 말한다. [SEP]
|
17 |
-
[CLS] beat around the bush [SEP] [CLS] 단어나 태도가 우회적이다 [SEP]
|
18 |
-
[CLS] beat around the bush [SEP] [CLS] 우물쭈물하다 [SEP]
|
19 |
-
[CLS] beat around the bush [SEP] [CLS] 우회적으로 접근하다 [SEP]
|
20 |
-
[CLS] backhanded compliment [SEP] [CLS] 칭찬으로 가장한 모욕적이거나 부정적인 논평 [SEP]
|
21 |
-
[CLS] backhanded compliment [SEP] [CLS] 의도하지 않거나 애매한 칭찬 [SEP]
|
22 |
-
[CLS] backhanded compliment [SEP] [CLS] 누군가를 칭찬하는 것 같지만 비판으로도 이해될 수 있는 말 [SEP]
|
23 |
-
[CLS] backhanded compliment [SEP] [CLS] 남을 기쁘게 하는 말 같지만 모욕이 될 수도 있는 말 [SEP]
|
24 |
-
[CLS] backhanded compliment [SEP] [CLS] 감탄하는 듯 하면서도 모욕으로 이해될 수 있는 말 [SEP]
|
25 |
-
[CLS] steer clear of [SEP] [CLS] 누군가나 뭔가를 피하다 [SEP]
|
26 |
-
[CLS] steer clear of [SEP] [CLS] 떨어져 지내다 [SEP]
|
27 |
-
[CLS] steer clear of [SEP] [CLS] 피하거나 멀리하도록 주의하다 [SEP]
|
28 |
-
[CLS] steer clear of [SEP] [CLS] 불쾌하거나 위험하거나 문제를 일으킬 것 같은 사람이나 물건을 피하다 [SEP]
|
29 |
-
[CLS] steer clear of [SEP] [CLS] 일부러 피하다 [SEP]
|
30 |
-
[CLS] dish it out [SEP] [CLS] 가혹한 생각, 비판, 또는 모욕의 목소리를 내는 것. [SEP]
|
31 |
-
[CLS] dish it out [SEP] [CLS] 누군가 또는 무언가에 대해 험담하는 것 [SEP]
|
32 |
-
[CLS] dish it out [SEP] [CLS] 어떤 것을 주거나 정보나 당신의 의견과 같은 것을 말하는 것 [SEP]
|
33 |
-
[CLS] dish it out [SEP] [CLS] 다른 사람을 쉽게 비판하지만 다른 사람이 자신을 비판할때는 좋아하지 않음 [SEP]
|
34 |
-
[CLS] dish it out [SEP] [CLS] 다른 사람을 비판하다 [SEP]
|
35 |
-
[CLS] make headway [SEP] [CLS] 성취하고자 하는 어떤 것에 진척이 생기다 [SEP]
|
36 |
-
[CLS] make headway [SEP] [CLS] 특히 이것이 느리거나 어려울 때, 진전을 이루다. [SEP]
|
37 |
-
[CLS] make headway [SEP] [CLS] 전진하다 [SEP]
|
38 |
-
[CLS] make headway [SEP] [CLS] 앞으로 나아가거나 진전을 이루다 [SEP]
|
39 |
-
[CLS] make headway [SEP] [CLS] 성공하기 시작하다 [SEP]
|
40 |
-
"""
|
41 |
-
|
42 |
-
|
43 |
-
if __name__ == '__main__':
|
44 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explore/explore_bert_base_tokenizer.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
from idiomify.fetchers import fetch_idiom2def
|
2 |
-
from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast
|
3 |
-
|
4 |
-
|
5 |
-
def main():
|
6 |
-
|
7 |
-
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
8 |
-
idiom2def = fetch_idiom2def("c") # eng2eng
|
9 |
-
for idiom, definition in idiom2def:
|
10 |
-
print(tokenizer.decode(tokenizer(idiom)['input_ids']),
|
11 |
-
tokenizer.decode(tokenizer(definition)['input_ids']))
|
12 |
-
|
13 |
-
"""
|
14 |
-
/opt/homebrew/Caskroom/miniforge/base/envs/idiomify-demo/bin/python /Users/eubinecto/Desktop/Projects/Toy/idiomify-demo/explore/explore_bert_base_tokenizer.py
|
15 |
-
Downloading: 100%|██████████| 226k/226k [00:00<00:00, 298kB/s]
|
16 |
-
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 8.27kB/s]
|
17 |
-
Downloading: 100%|██████████| 455k/455k [00:01<00:00, 449kB/s]
|
18 |
-
[CLS] beat around the bush [SEP] [CLS] to speak vaguely or euphemistically so as to avoid talkingdirectly about an unpleasant or sensitive topic [SEP]
|
19 |
-
[CLS] beat around the bush [SEP] [CLS] indirection in word or deed [SEP]
|
20 |
-
[CLS] beat around the bush [SEP] [CLS] to shilly - shally [SEP]
|
21 |
-
[CLS] beat around the bush [SEP] [CLS] to approach something in a roundabout way [SEP]
|
22 |
-
[CLS] backhanded compliment [SEP] [CLS] an insulting or negative comment disguised as praise. [SEP]
|
23 |
-
[CLS] backhanded compliment [SEP] [CLS] an unintended or ambiguous compliment. [SEP]
|
24 |
-
[CLS] backhanded compliment [SEP] [CLS] a remark which seems to be praising someone or something but which could also be understood as criticism [SEP]
|
25 |
-
[CLS] backhanded compliment [SEP] [CLS] a remark that seems to say something pleasant about a person but could also be an insult [SEP]
|
26 |
-
[CLS] backhanded compliment [SEP] [CLS] a remark that seems to express admiration but could also be understood as an insult [SEP]
|
27 |
-
[CLS] steer clear of [SEP] [CLS] to avoid someone or something. [SEP]
|
28 |
-
[CLS] steer clear of [SEP] [CLS] stay away from [SEP]
|
29 |
-
[CLS] steer clear of [SEP] [CLS] take care to avoid or keep away from [SEP]
|
30 |
-
[CLS] steer clear of [SEP] [CLS] to avoid someone or something that seems unpleasant, dangerous, or likely to cause problems [SEP]
|
31 |
-
[CLS] steer clear of [SEP] [CLS] deliberately avoid someone [SEP]
|
32 |
-
[CLS] dish it out [SEP] [CLS] to voice harsh thoughts, criticisms, or insults. [SEP]
|
33 |
-
[CLS] dish it out [SEP] [CLS] to gossip about someone or something [SEP]
|
34 |
-
[CLS] dish it out [SEP] [CLS] to give something, or to tell something such as information or your opinions [SEP]
|
35 |
-
[CLS] dish it out [SEP] [CLS] someone easily criticizes other people but does not like it when other people criticize him or her [SEP]
|
36 |
-
[CLS] dish it out [SEP] [CLS] to criticize other people [SEP]
|
37 |
-
[CLS] make headway [SEP] [CLS] make progress with something that you are trying to achieve. [SEP]
|
38 |
-
[CLS] make headway [SEP] [CLS] make progress, especially when this is slow or difficult [SEP]
|
39 |
-
[CLS] make headway [SEP] [CLS] to advance. [SEP]
|
40 |
-
[CLS] make headway [SEP] [CLS] to move forward or make progress [SEP]
|
41 |
-
[CLS] make headway [SEP] [CLS] to begin to succeed [SEP]
|
42 |
-
"""
|
43 |
-
|
44 |
-
if __name__ == '__main__':
|
45 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explore/explore_fetch_epie_counts.py
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
from idiomify.fetchers import fetch_epie
|
2 |
-
|
3 |
-
|
4 |
-
def main():
|
5 |
-
idioms = set([
|
6 |
-
idiom
|
7 |
-
for idiom, _, _ in fetch_epie()
|
8 |
-
])
|
9 |
-
contexts = [
|
10 |
-
context
|
11 |
-
for _, _, context in fetch_epie()
|
12 |
-
]
|
13 |
-
print("Total number of idioms:", len(idioms))
|
14 |
-
# This should learn... this - what I need for now is building a datamodule out of this
|
15 |
-
print("Total number of contexts:", len(contexts))
|
16 |
-
|
17 |
-
|
18 |
-
if __name__ == '__main__':
|
19 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explore/{explore_fetch_alpha.py → explore_fetch_seq2seq.py}
RENAMED
@@ -1,8 +1,8 @@
|
|
1 |
-
from idiomify.fetchers import
|
2 |
|
3 |
|
4 |
def main():
|
5 |
-
model =
|
6 |
print(model.bart.config)
|
7 |
|
8 |
|
|
|
1 |
+
from idiomify.fetchers import fetch_seq2seq
|
2 |
|
3 |
|
4 |
def main():
|
5 |
+
model = fetch_seq2seq("overfit")
|
6 |
print(model.bart.config)
|
7 |
|
8 |
|
explore/{explore_fetch_alpha_predict.py → explore_fetch_seq2seq_predict.py}
RENAMED
@@ -1,10 +1,10 @@
|
|
1 |
from transformers import BartTokenizer
|
2 |
from builders import SourcesBuilder
|
3 |
-
from fetchers import
|
4 |
|
5 |
|
6 |
def main():
|
7 |
-
model =
|
8 |
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
9 |
lit2idi = [
|
10 |
("my man", ""),
|
|
|
1 |
from transformers import BartTokenizer
|
2 |
from builders import SourcesBuilder
|
3 |
+
from fetchers import fetch_seq2seq
|
4 |
|
5 |
|
6 |
def main():
|
7 |
+
model = fetch_seq2seq("overfit")
|
8 |
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
9 |
lit2idi = [
|
10 |
("my man", ""),
|
explore/explore_nlpaug.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import nlpaug.augmenter.word as naw
|
3 |
+
import nlpaug.augmenter.sentence as nas
|
4 |
+
|
5 |
+
import nltk
|
6 |
+
|
7 |
+
|
8 |
+
sent = "I am really happy with the new job and I mean that with sincere feeling"
|
9 |
+
|
10 |
+
|
11 |
+
def main():
|
12 |
+
nltk.download("omw-1.4")
|
13 |
+
# this seems legit! I could definitely use this to increase the accuracy of the model
|
14 |
+
# for a few idioms (possibly ten, ten very different but frequent idioms)
|
15 |
+
aug = naw.ContextualWordEmbsAug()
|
16 |
+
augmented = aug.augment(sent, n=10)
|
17 |
+
print(augmented)
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == '__main__':
|
21 |
+
main()
|
idiomify/builders.py
CHANGED
@@ -4,7 +4,7 @@ builders must accept device as one of the parameters.
|
|
4 |
"""
|
5 |
import torch
|
6 |
from typing import List, Tuple
|
7 |
-
from transformers import
|
8 |
|
9 |
|
10 |
class TensorBuilder:
|
@@ -61,7 +61,9 @@ class SourcesBuilder(TensorBuilder):
|
|
61 |
|
62 |
|
63 |
class TargetsRightShiftedBuilder(TensorBuilder):
|
64 |
-
|
|
|
|
|
65 |
def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
|
66 |
encodings = self.tokenizer([
|
67 |
self.tokenizer.bos_token + idiomatic # starts with bos, but does not end with eos (right-shifted)
|
@@ -73,9 +75,7 @@ class TargetsRightShiftedBuilder(TensorBuilder):
|
|
73 |
|
74 |
|
75 |
class TargetsBuilder(TensorBuilder):
|
76 |
-
|
77 |
-
This is to be used only for training. As for inference, we don't need this.
|
78 |
-
"""
|
79 |
def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
|
80 |
encodings = self.tokenizer([
|
81 |
idiomatic + self.tokenizer.eos_token # no bos, but ends with eos
|
|
|
4 |
"""
|
5 |
import torch
|
6 |
from typing import List, Tuple
|
7 |
+
from transformers import BartTokenizer
|
8 |
|
9 |
|
10 |
class TensorBuilder:
|
|
|
61 |
|
62 |
|
63 |
class TargetsRightShiftedBuilder(TensorBuilder):
|
64 |
+
"""
|
65 |
+
This is to be used only for training. As for inference, we don't need this.
|
66 |
+
"""
|
67 |
def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
|
68 |
encodings = self.tokenizer([
|
69 |
self.tokenizer.bos_token + idiomatic # starts with bos, but does not end with eos (right-shifted)
|
|
|
75 |
|
76 |
|
77 |
class TargetsBuilder(TensorBuilder):
|
78 |
+
|
|
|
|
|
79 |
def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
|
80 |
encodings = self.tokenizer([
|
81 |
idiomatic + self.tokenizer.eos_token # no bos, but ends with eos
|
idiomify/fetchers.py
CHANGED
@@ -5,43 +5,10 @@ import wandb
|
|
5 |
import requests
|
6 |
from typing import Tuple, List
|
7 |
from wandb.sdk.wandb_run import Run
|
8 |
-
from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic,
|
9 |
-
from idiomify.urls import
|
10 |
-
EPIE_IMMUTABLE_IDIOMS_URL,
|
11 |
-
EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
|
12 |
-
EPIE_IMMUTABLE_IDIOMS_TAGS_URL,
|
13 |
-
EPIE_MUTABLE_IDIOMS_URL,
|
14 |
-
EPIE_MUTABLE_IDIOMS_CONTEXTS_URL,
|
15 |
-
EPIE_MUTABLE_IDIOMS_TAGS_URL,
|
16 |
-
PIE_URL
|
17 |
-
)
|
18 |
from transformers import AutoModelForSeq2SeqLM, AutoConfig
|
19 |
-
from models import
|
20 |
-
|
21 |
-
|
22 |
-
def fetch_epie(ver: str) -> List[Tuple[str, str, str]]:
|
23 |
-
"""
|
24 |
-
It fetches the EPIE idioms, contexts, and tags from the web
|
25 |
-
:param ver: str
|
26 |
-
:type ver: str
|
27 |
-
:return: A list of tuples. Each tuple contains three strings: an idiom, a context, and a tag.
|
28 |
-
"""
|
29 |
-
if ver == "immutable":
|
30 |
-
idioms_url = EPIE_IMMUTABLE_IDIOMS_URL
|
31 |
-
contexts_url = EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL
|
32 |
-
tags_url = EPIE_IMMUTABLE_IDIOMS_TAGS_URL
|
33 |
-
elif ver == "mutable":
|
34 |
-
idioms_url = EPIE_MUTABLE_IDIOMS_URL
|
35 |
-
contexts_url = EPIE_MUTABLE_IDIOMS_CONTEXTS_URL
|
36 |
-
tags_url = EPIE_MUTABLE_IDIOMS_TAGS_URL
|
37 |
-
else:
|
38 |
-
raise ValueError
|
39 |
-
idioms = requests.get(idioms_url).text
|
40 |
-
contexts = requests.get(contexts_url).text
|
41 |
-
tags = requests.get(tags_url).text
|
42 |
-
return list(zip(idioms.strip().split("\n"),
|
43 |
-
contexts.strip().split("\n"),
|
44 |
-
tags.strip().split("\n")))
|
45 |
|
46 |
|
47 |
def fetch_pie() -> list:
|
@@ -86,16 +53,16 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
|
|
86 |
return [(row[0], row[1]) for row in reader]
|
87 |
|
88 |
|
89 |
-
def
|
90 |
if run:
|
91 |
-
artifact = run.use_artifact(f"
|
92 |
else:
|
93 |
-
artifact = wandb.Api().artifact(f"eubinecto/idiomify/
|
94 |
config = artifact.metadata
|
95 |
-
artifact_dir = artifact.download(root=
|
96 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
97 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
98 |
-
alpha =
|
99 |
return alpha
|
100 |
|
101 |
|
|
|
5 |
import requests
|
6 |
from typing import Tuple, List
|
7 |
from wandb.sdk.wandb_run import Run
|
8 |
+
from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, seq2seq_dir
|
9 |
+
from idiomify.urls import PIE_URL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from transformers import AutoModelForSeq2SeqLM, AutoConfig
|
11 |
+
from idiomify.models import Seq2Seq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
def fetch_pie() -> list:
|
|
|
53 |
return [(row[0], row[1]) for row in reader]
|
54 |
|
55 |
|
56 |
+
def fetch_seq2seq(ver: str, run: Run = None) -> Seq2Seq:
|
57 |
if run:
|
58 |
+
artifact = run.use_artifact(f"seq2seq:{ver}", type="model")
|
59 |
else:
|
60 |
+
artifact = wandb.Api().artifact(f"eubinecto/idiomify/seq2seq:{ver}", type="model")
|
61 |
config = artifact.metadata
|
62 |
+
artifact_dir = artifact.download(root=seq2seq_dir(ver))
|
63 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
64 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
65 |
+
alpha = Seq2Seq.load_from_checkpoint(ckpt_path, bart=bart)
|
66 |
return alpha
|
67 |
|
68 |
|
idiomify/idiomifier.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
from transformers import BartTokenizer
|
2 |
-
from builders import SourcesBuilder
|
3 |
-
from models import Alpha
|
4 |
-
|
5 |
-
|
6 |
-
class Idiomifier:
|
7 |
-
|
8 |
-
def __init__(self, model: Alpha, tokenizer: BartTokenizer):
|
9 |
-
self.model = model
|
10 |
-
self.builder = SourcesBuilder(tokenizer)
|
11 |
-
self.model.eval()
|
12 |
-
|
13 |
-
def __call__(self, src: str, max_length=100) -> str:
|
14 |
-
srcs = self.builder(literal2idiomatic=[(src, "")])
|
15 |
-
pred_ids = self.model.bart.generate(
|
16 |
-
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|
17 |
-
attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
|
18 |
-
decoder_start_token_id=self.model.hparams['bos_token_id'],
|
19 |
-
max_length=max_length,
|
20 |
-
).squeeze() # -> (N, L_t) -> (L_t)
|
21 |
-
tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
|
22 |
-
return tgt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
idiomify/models.py
CHANGED
@@ -5,12 +5,14 @@ from typing import Tuple
|
|
5 |
import torch
|
6 |
from torch.nn import functional as F
|
7 |
import pytorch_lightning as pl
|
8 |
-
from transformers import BartForConditionalGeneration
|
|
|
9 |
|
10 |
|
11 |
-
|
|
|
12 |
"""
|
13 |
-
the baseline.
|
14 |
"""
|
15 |
def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int): # noqa
|
16 |
super().__init__()
|
@@ -54,3 +56,23 @@ class Alpha(pl.LightningModule): # noqa
|
|
54 |
"""
|
55 |
# The authors used Adam, so we might as well use it as well.
|
56 |
return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch
|
6 |
from torch.nn import functional as F
|
7 |
import pytorch_lightning as pl
|
8 |
+
from transformers import BartForConditionalGeneration, BartTokenizer
|
9 |
+
from idiomify.builders import SourcesBuilder
|
10 |
|
11 |
|
12 |
+
# for training
|
13 |
+
class Seq2Seq(pl.LightningModule): # noqa
|
14 |
"""
|
15 |
+
the baseline is in here.
|
16 |
"""
|
17 |
def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int): # noqa
|
18 |
super().__init__()
|
|
|
56 |
"""
|
57 |
# The authors used Adam, so we might as well use it as well.
|
58 |
return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
|
59 |
+
|
60 |
+
|
61 |
+
# for inference
|
62 |
+
class Idiomifier:
|
63 |
+
|
64 |
+
def __init__(self, model: Seq2Seq, tokenizer: BartTokenizer):
|
65 |
+
self.model = model
|
66 |
+
self.builder = SourcesBuilder(tokenizer)
|
67 |
+
self.model.eval()
|
68 |
+
|
69 |
+
def __call__(self, src: str, max_length=100) -> str:
|
70 |
+
srcs = self.builder(literal2idiomatic=[(src, "")])
|
71 |
+
pred_ids = self.model.bart.generate(
|
72 |
+
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|
73 |
+
attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
|
74 |
+
decoder_start_token_id=self.model.hparams['bos_token_id'],
|
75 |
+
max_length=max_length,
|
76 |
+
).squeeze() # -> (N, L_t) -> (L_t)
|
77 |
+
tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
|
78 |
+
return tgt
|
idiomify/paths.py
CHANGED
@@ -6,12 +6,12 @@ CONFIG_YAML = ROOT_DIR / "config.yaml"
|
|
6 |
|
7 |
|
8 |
def idioms_dir(ver: str) -> Path:
|
9 |
-
return ARTIFACTS_DIR / f"
|
10 |
|
11 |
|
12 |
def literal2idiomatic(ver: str) -> Path:
|
13 |
-
return ARTIFACTS_DIR / f"
|
14 |
|
15 |
|
16 |
-
def
|
17 |
-
return ARTIFACTS_DIR / f"
|
|
|
6 |
|
7 |
|
8 |
def idioms_dir(ver: str) -> Path:
|
9 |
+
return ARTIFACTS_DIR / f"idioms-{ver}"
|
10 |
|
11 |
|
12 |
def literal2idiomatic(ver: str) -> Path:
|
13 |
+
return ARTIFACTS_DIR / f"literal2idiomatic-{ver}"
|
14 |
|
15 |
|
16 |
+
def seq2seq_dir(ver: str) -> Path:
|
17 |
+
return ARTIFACTS_DIR / f"seq2seq-{ver}"
|
idiomify/urls.py
CHANGED
@@ -11,6 +11,3 @@ EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Co
|
|
11 |
# https://aclanthology.org/2021.mwe-1.5/
|
12 |
# right, let's just work on it.
|
13 |
PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
11 |
# https://aclanthology.org/2021.mwe-1.5/
|
12 |
# right, let's just work on it.
|
13 |
PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
|
|
|
|
|
|
main_infer.py
CHANGED
@@ -1,27 +1,24 @@
|
|
1 |
import argparse
|
2 |
-
from
|
3 |
-
from
|
4 |
-
from idiomify.fetchers import fetch_config, fetch_alpha
|
5 |
from transformers import BartTokenizer
|
6 |
|
7 |
|
8 |
def main():
|
9 |
parser = argparse.ArgumentParser()
|
10 |
-
parser.add_argument("--
|
11 |
-
default="alpha")
|
12 |
-
parser.add_argument("--ver", type=str,
|
13 |
-
default="overfit")
|
14 |
parser.add_argument("--src", type=str,
|
15 |
-
default="If there's any
|
|
|
16 |
args = parser.parse_args()
|
17 |
-
config = fetch_config()[args.
|
18 |
config.update(vars(args))
|
19 |
-
model =
|
20 |
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
21 |
idiomifier = Idiomifier(model, tokenizer)
|
22 |
src = config['src']
|
23 |
tgt = idiomifier(src=config['src'])
|
24 |
-
print(src, "\n->",
|
25 |
|
26 |
|
27 |
if __name__ == '__main__':
|
|
|
1 |
import argparse
|
2 |
+
from idiomify.models import Idiomifier
|
3 |
+
from idiomify.fetchers import fetch_config, fetch_seq2seq
|
|
|
4 |
from transformers import BartTokenizer
|
5 |
|
6 |
|
7 |
def main():
|
8 |
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--ver", type=str, default="tag011")
|
|
|
|
|
|
|
10 |
parser.add_argument("--src", type=str,
|
11 |
+
default="If there's any good to loosing my job,"
|
12 |
+
" it's that I'll now be able to go to school full-time and finish my degree earlier.")
|
13 |
args = parser.parse_args()
|
14 |
+
config = fetch_config()[args.ver]
|
15 |
config.update(vars(args))
|
16 |
+
model = fetch_seq2seq(config['ver'])
|
17 |
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
18 |
idiomifier = Idiomifier(model, tokenizer)
|
19 |
src = config['src']
|
20 |
tgt = idiomifier(src=config['src'])
|
21 |
+
print(src, "\n->", tgt)
|
22 |
|
23 |
|
24 |
if __name__ == '__main__':
|
main_train.py
CHANGED
@@ -8,20 +8,19 @@ from pytorch_lightning.loggers import WandbLogger
|
|
8 |
from transformers import BartTokenizer, BartForConditionalGeneration
|
9 |
from idiomify.data import IdiomifyDataModule
|
10 |
from idiomify.fetchers import fetch_config
|
11 |
-
from idiomify.models import
|
12 |
from idiomify.paths import ROOT_DIR
|
13 |
|
14 |
|
15 |
def main():
|
16 |
parser = argparse.ArgumentParser()
|
17 |
-
parser.add_argument("--
|
18 |
-
parser.add_argument("--ver", type=str, default="overfit ")
|
19 |
parser.add_argument("--num_workers", type=int, default=os.cpu_count())
|
20 |
parser.add_argument("--log_every_n_steps", type=int, default=1)
|
21 |
parser.add_argument("--fast_dev_run", action="store_true", default=False)
|
22 |
parser.add_argument("--upload", dest='upload', action='store_true', default=False)
|
23 |
args = parser.parse_args()
|
24 |
-
config = fetch_config()[args.
|
25 |
config.update(vars(args))
|
26 |
if not config['upload']:
|
27 |
print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
|
@@ -29,12 +28,8 @@ def main():
|
|
29 |
# prepare the model
|
30 |
bart = BartForConditionalGeneration.from_pretrained(config['bart'])
|
31 |
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
32 |
-
|
33 |
-
model = Alpha(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
|
34 |
-
else:
|
35 |
-
raise NotImplementedError
|
36 |
# prepare the datamodule
|
37 |
-
|
38 |
with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
|
39 |
datamodule = IdiomifyDataModule(config, tokenizer, run)
|
40 |
logger = WandbLogger(log_model=False)
|
|
|
8 |
from transformers import BartTokenizer, BartForConditionalGeneration
|
9 |
from idiomify.data import IdiomifyDataModule
|
10 |
from idiomify.fetchers import fetch_config
|
11 |
+
from idiomify.models import Seq2Seq
|
12 |
from idiomify.paths import ROOT_DIR
|
13 |
|
14 |
|
15 |
def main():
|
16 |
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument("--ver", type=str, default="tag011")
|
|
|
18 |
parser.add_argument("--num_workers", type=int, default=os.cpu_count())
|
19 |
parser.add_argument("--log_every_n_steps", type=int, default=1)
|
20 |
parser.add_argument("--fast_dev_run", action="store_true", default=False)
|
21 |
parser.add_argument("--upload", dest='upload', action='store_true', default=False)
|
22 |
args = parser.parse_args()
|
23 |
+
config = fetch_config()[args.ver]
|
24 |
config.update(vars(args))
|
25 |
if not config['upload']:
|
26 |
print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
|
|
|
28 |
# prepare the model
|
29 |
bart = BartForConditionalGeneration.from_pretrained(config['bart'])
|
30 |
tokenizer = BartTokenizer.from_pretrained(config['bart'])
|
31 |
+
model = Seq2Seq(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
|
|
|
|
|
|
|
32 |
# prepare the datamodule
|
|
|
33 |
with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
|
34 |
datamodule = IdiomifyDataModule(config, tokenizer, run)
|
35 |
logger = WandbLogger(log_model=False)
|
main_upload_idioms.py
CHANGED
@@ -11,17 +11,13 @@ import wandb
|
|
11 |
|
12 |
def main():
|
13 |
parser = argparse.ArgumentParser()
|
14 |
-
parser.add_argument("--ver", type=str, default="
|
15 |
-
choices=["pie_v0", "pie_v1"])
|
16 |
config = vars(parser.parse_args())
|
17 |
|
18 |
# get the idioms here
|
19 |
-
if config['ver'] == "
|
20 |
# only the first 106, and this is for piloting
|
21 |
idioms = set([row[0] for row in fetch_pie()[:106]])
|
22 |
-
elif config['ver'] == "pie_v1":
|
23 |
-
# just include all
|
24 |
-
idioms = set([row[0] for row in fetch_pie()])
|
25 |
else:
|
26 |
raise NotImplementedError
|
27 |
idioms = list(idioms)
|
|
|
11 |
|
12 |
def main():
|
13 |
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--ver", type=str, default="tag01")
|
|
|
15 |
config = vars(parser.parse_args())
|
16 |
|
17 |
# get the idioms here
|
18 |
+
if config['ver'] == "tag01":
|
19 |
# only the first 106, and this is for piloting
|
20 |
idioms = set([row[0] for row in fetch_pie()[:106]])
|
|
|
|
|
|
|
21 |
else:
|
22 |
raise NotImplementedError
|
23 |
idioms = list(idioms)
|
main_upload_literal2idiomatic.py
CHANGED
@@ -12,21 +12,15 @@ import wandb
|
|
12 |
|
13 |
def main():
|
14 |
parser = argparse.ArgumentParser()
|
15 |
-
parser.add_argument("--ver", type=str, default="
|
16 |
-
choices=["pie_v0", "pie_v1"])
|
17 |
config = vars(parser.parse_args())
|
18 |
|
19 |
# get the idioms here
|
20 |
-
if config['ver'] == "
|
21 |
# only the first 106, and we use this just for piloting
|
22 |
literal2idiom = [
|
23 |
(row[3], row[2]) for row in fetch_pie()[:106]
|
24 |
]
|
25 |
-
elif config['ver'] == "pie_v1":
|
26 |
-
# just include all
|
27 |
-
literal2idiom = [
|
28 |
-
(row[3], row[2]) for row in fetch_pie()
|
29 |
-
]
|
30 |
else:
|
31 |
raise NotImplementedError
|
32 |
|
|
|
12 |
|
13 |
def main():
|
14 |
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument("--ver", type=str, default="tag01")
|
|
|
16 |
config = vars(parser.parse_args())
|
17 |
|
18 |
# get the idioms here
|
19 |
+
if config['ver'] == "tag01":
|
20 |
# only the first 106, and we use this just for piloting
|
21 |
literal2idiom = [
|
22 |
(row[3], row[2]) for row in fetch_pie()[:106]
|
23 |
]
|
|
|
|
|
|
|
|
|
|
|
24 |
else:
|
25 |
raise NotImplementedError
|
26 |
|