[#2] Support for d-1-2 done. Support for m-1-2 partially done (need to implement the testing logic with some valid metrics)
- README.md +12 -1
- config.yaml +16 -4
- explore/{explore_fetch_seq2seq.py → explore_fetch_idiomifier.py} +2 -2
- explore/{explore_fetch_seq2seq_predict.py → explore_fetch_idiomifier_predict.py} +2 -2
- explore/explore_fetch_idioms.py +1 -1
- explore/explore_fetch_literal2idiomatic.py +3 -2
- explore/explore_fetch_pie.py +3 -5
- explore/explore_fetch_pie_df_select.py +12 -0
- explore/explore_idiomifydatamodule.py +10 -2
- idiomify/builders.py +3 -3
- idiomify/data.py +24 -13
- idiomify/fetchers.py +18 -20
- idiomify/models.py +3 -5
- idiomify/paths.py +3 -3
- idiomify/preprocess.py +31 -0
- main_infer.py +6 -6
- main_train.py +4 -5
- main_upload_idioms.py +14 -18
- main_upload_literal2idiomatic.py +27 -26
- requirements.txt +2 -1
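
The deferred piece noted in the commit message (testing logic with valid metrics) could start out as the sketch below. It is a sketch only: it assumes Idiomifier computes its loss through a shared helper just like the training step does, and `self.step` plus the logged "test/loss" key are hypothetical names, not code from this commit.

import pytorch_lightning as pl


class Idiomifier(pl.LightningModule):  # noqa
    ...

    def test_step(self, batch, batch_idx):
        srcs, tgts_r, tgts = batch  # the triple that IdiomifyDataModule yields
        loss = self.step(srcs, tgts_r, tgts)  # hypothetical shared loss helper, mirroring training
        self.log("test/loss", loss, on_epoch=True)  # placeholder until real metrics are chosen
        return loss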
README.md
CHANGED
@@ -10,4 +10,15 @@ A human-inspired Idiomifier based on BERT
 - wandb
 - pytorch-lightning
 - transformers
-
+
+## Versions
+### models
+format: `m-a-b`
+- a: used to indicate a change in the architecture, or a revision of the final product
+- b: used to indicate a different version of the same architecture (with a different set of hyperparameters)
+
+
+### datasets
+format: `d-a-b`
+- a: used to indicate a change in the dataset we are using
+- b: used to indicate a different version of the dataset
config.yaml
CHANGED
@@ -1,8 +1,20 @@
-…
-…
+train:
+  ver: m-1-2
+  desc: just overfitting the model, but on the entire PIE dataset.
   bart: facebook/bart-base
   lr: 0.0001
-  literal2idiomatic_ver: …
+  literal2idiomatic_ver: d-1-2
   max_epochs: 100
   batch_size: 100
-  shuffle: true
+  shuffle: true
+
+# for building & uploading datasets or others
+upload:
+  idioms:
+    ver: d-1-2
+    description: the set of idioms in the training set of literal2idiomatic:d-1-2
+  literal2idiomatic:
+    ver: d-1-2
+    description: PIE data split into train & test set (80 / 20 split)
+    train_ratio: 0.8
+    seed: 104
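
For orientation, the nested blocks map directly onto how the scripts in this commit read the config:

from idiomify.fetchers import fetch_config

train_config = fetch_config()['train']                      # m-1-2, bart, lr, max_epochs, ...
idioms_config = fetch_config()['upload']['idioms']          # d-1-2
l2i_config = fetch_config()['upload']['literal2idiomatic']  # d-1-2, train_ratio, seed

These three reads appear verbatim in main_train.py, main_upload_idioms.py, and main_upload_literal2idiomatic.py below.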
explore/{explore_fetch_seq2seq.py → explore_fetch_idiomifier.py}
RENAMED
@@ -1,8 +1,8 @@
-from idiomify.fetchers import fetch_seq2seq
+from idiomify.fetchers import fetch_idiomifier
 
 
 def main():
-    model = fetch_seq2seq(…)
+    model = fetch_idiomifier("m-1-2")
     print(model.bart.config)
 
 
explore/{explore_fetch_seq2seq_predict.py → explore_fetch_idiomifier_predict.py}
RENAMED
@@ -1,10 +1,10 @@
 from transformers import BartTokenizer
 from builders import SourcesBuilder
-from fetchers import fetch_seq2seq
+from fetchers import fetch_idiomifier
 
 
 def main():
-    model = fetch_seq2seq(…)
+    model = fetch_idiomifier("m-1-2")
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
     lit2idi = [
         ("my man", ""),
explore/explore_fetch_idioms.py
CHANGED
@@ -2,7 +2,7 @@ from idiomify.fetchers import fetch_idioms
 
 
 def main():
-    print(fetch_idioms("…"))
+    print(fetch_idioms("d-1-2"))
 
 
 if __name__ == '__main__':
explore/explore_fetch_literal2idiomatic.py
CHANGED
@@ -2,8 +2,9 @@ from idiomify.fetchers import fetch_literal2idiomatic
 
 
 def main():
-…
-…
+    train_df, test_df = fetch_literal2idiomatic("d-1-2")
+    print(train_df.size)  # 12408 rows
+    print(test_df.size)  # 3102 rows
 
 
 if __name__ == '__main__':
explore/explore_fetch_pie.py
CHANGED
@@ -3,11 +3,9 @@ from idiomify.fetchers import fetch_pie
 
 
 def main():
-…
-…
-…
-        if idx == 105:
-            break
+    pie_df = fetch_pie()
+    for idx, row in pie_df.iterrows():
+        print(row)
 
 
 if __name__ == '__main__':
explore/explore_fetch_pie_df_select.py
ADDED
@@ -0,0 +1,12 @@
+from fetchers import fetch_pie
+
+
+def main():
+    pie_df = fetch_pie()
+    print(pie_df.columns)
+    pie_df = pie_df[["Literal_Sent", "Idiomatic_Sent"]]
+    print(pie_df.head(5))
+
+
+if __name__ == '__main__':
+    main()
explore/explore_idiomifydatamodule.py
CHANGED
@@ -3,7 +3,7 @@ from idiomify.data import IdiomifyDataModule
 
 
 CONFIG = {
-    "literal2idiomatic_ver": "…",
+    "literal2idiomatic_ver": "d-1-2",
     "batch_size": 20,
     "num_workers": 4,
     "shuffle": True
@@ -11,7 +11,7 @@ CONFIG = {
 
 
 def main():
-    tokenizer = BartTokenizer.from_pretrained("facebook/bart-…")
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
    datamodule.prepare_data()
    datamodule.setup()
@@ -20,6 +20,14 @@
         print(srcs.shape)
         print(tgts_r.shape)
         print(tgts.shape)
+        break
+
+    for batch in datamodule.test_dataloader():
+        srcs, tgts_r, tgts = batch
+        print(srcs.shape)
+        print(tgts_r.shape)
+        print(tgts.shape)
+        break
 
 
 if __name__ == '__main__':
idiomify/builders.py
CHANGED
@@ -55,9 +55,9 @@ class SourcesBuilder(TensorBuilder):
                                 padding=True,
                                 truncation=True,
                                 add_special_tokens=True)
-…
-…
-        return …
+        srcs = torch.stack([encodings['input_ids'],
+                            encodings['attention_mask']], dim=1)  # (N, 2, L)
+        return srcs  # (N, 2, L)
 
 
 class TargetsRightShiftedBuilder(TensorBuilder):
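
Stacking input_ids and attention_mask into one (N, 2, L) tensor keeps the dataset a plain tensor triple; consumers then split it back along dim 1. The unpacking site is not shown in this diff, so the helper below is an assumption about the convention, not code from the repo:

import torch


def unpack_srcs(srcs: torch.Tensor):
    # channel 0 holds input_ids, channel 1 the attention mask, per SourcesBuilder
    input_ids, attention_mask = srcs[:, 0], srcs[:, 1]  # each (N, L)
    return input_ids, attention_mask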
idiomify/data.py
CHANGED
@@ -1,9 +1,9 @@
+import pandas as pd
 import torch
-from typing import Tuple, Optional
+from typing import Tuple, Optional
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
 from wandb.sdk.wandb_run import Run
-
 from idiomify.fetchers import fetch_literal2idiomatic
 from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
 from transformers import BartTokenizer
@@ -38,9 +38,6 @@ class IdiomifyDataset(Dataset):
 class IdiomifyDataModule(LightningDataModule):
 
     # boilerplate - just ignore these
-    def test_dataloader(self):
-        pass
-
     def val_dataloader(self):
         pass
 
@@ -56,23 +53,37 @@ class IdiomifyDataModule(LightningDataModule):
         self.tokenizer = tokenizer
         self.run = run
         # --- to be downloaded & built --- #
-        self.…
-        self.…
+        self.train_df: Optional[pd.DataFrame] = None
+        self.test_df: Optional[pd.DataFrame] = None
+        self.train_dataset: Optional[IdiomifyDataset] = None
+        self.test_dataset: Optional[IdiomifyDataset] = None
 
     def prepare_data(self):
         """
         prepare: download all data needed for this from wandb to local.
         """
-        self.…
+        self.train_df, self.test_df = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'], self.run)
 
     def setup(self, stage: Optional[str] = None):
         # --- set up the builders --- #
         # build the datasets
-…
-…
-…
-…
+        self.train_dataset = self.build_dataset(self.train_df)
+        self.test_dataset = self.build_dataset(self.test_df)
+
+    def build_dataset(self, df: pd.DataFrame) -> IdiomifyDataset:
+        literal2idiomatic = [
+            (row['Literal_Sent'], row['Idiomatic_Sent'])
+            for _, row in df.iterrows()
+        ]
+        srcs = SourcesBuilder(self.tokenizer)(literal2idiomatic)
+        tgts_r = TargetsRightShiftedBuilder(self.tokenizer)(literal2idiomatic)
+        tgts = TargetsBuilder(self.tokenizer)(literal2idiomatic)
+        return IdiomifyDataset(srcs, tgts_r, tgts)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.…
+        return DataLoader(self.train_dataset, batch_size=self.config['batch_size'],
                           shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
+
+    def test_dataloader(self):
+        return DataLoader(self.test_dataset, batch_size=self.config['batch_size'],
+                          shuffle=False, num_workers=self.config['num_workers'])
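
With both splits now wired through prepare_data and setup, the new test_dataloader can be consumed end to end; a sketch, assuming the trainer, model, and datamodule are built as in main_train.py (there is no evaluation entry point in this commit):

trainer.test(model, datamodule=datamodule)  # runs test_step over the d-1-2 test split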
idiomify/fetchers.py
CHANGED
@@ -1,25 +1,18 @@
-import csv
-from os import path
 import yaml
 import wandb
-import …
+from os import path
+import pandas as pd
 from typing import Tuple, List
 from wandb.sdk.wandb_run import Run
 from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, seq2seq_dir
 from idiomify.urls import PIE_URL
 from transformers import AutoModelForSeq2SeqLM, AutoConfig
-from idiomify.models import Seq2Seq
+from idiomify.models import Idiomifier
 
 
-def fetch_pie() -> …
-…
-…
-    reader = csv.reader(lines)
-    next(reader)  # skip the header
-    return [
-        row
-        for row in reader
-    ]
+def fetch_pie() -> pd.DataFrame:
+    # fetch & parse it directly from the web
+    return pd.read_csv(PIE_URL)
 
 
 # --- from wandb --- #
@@ -39,7 +32,7 @@ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
     return [line.strip() for line in fh]
 
 
-def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+def fetch_literal2idiomatic(ver: str, run: Run = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
@@ -47,13 +40,18 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
     artifact_dir = artifact.download(root=literal2idiomatic(ver))
-…
-…
-…
-…
+    train_path = path.join(artifact_dir, "train.tsv")
+    test_path = path.join(artifact_dir, "test.tsv")
+    train_df = pd.read_csv(train_path, sep="\t")
+    test_df = pd.read_csv(test_path, sep="\t")
+    return train_df, test_df
 
 
-def fetch_seq2seq(ver: str, run: Run = None) -> Seq2Seq:
+def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
+    """
+    you may want to change the name to Idiomifier.
+    The current Idiomifier then turns into a pipeline.
+    """
     if run:
         artifact = run.use_artifact(f"seq2seq:{ver}", type="model")
     else:
@@ -62,7 +60,7 @@ def fetch_seq2seq(ver: str, run: Run = None) -> Seq2Seq:
     artifact_dir = artifact.download(root=seq2seq_dir(ver))
     ckpt_path = path.join(artifact_dir, "model.ckpt")
     bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
-    alpha = Seq2Seq.load_from_checkpoint(ckpt_path, bart=bart)
+    alpha = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
     return alpha
 
 
idiomify/models.py
CHANGED
@@ -9,8 +9,7 @@ from transformers import BartForConditionalGeneration, BartTokenizer
 from idiomify.builders import SourcesBuilder
 
 
-# …
-class Seq2Seq(pl.LightningModule):  # noqa
+class Idiomifier(pl.LightningModule):  # noqa
     """
     the baseline is in here.
     """
@@ -58,12 +57,11 @@ class Seq2Seq(pl.LightningModule):  # noqa
 
 
 # for inference
-class Idiomifier:
+class Pipeline:
 
-    def __init__(self, model: Seq2Seq, tokenizer: BartTokenizer):
+    def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
         self.model = model
         self.builder = SourcesBuilder(tokenizer)
-        self.model.eval()
 
     def __call__(self, src: str, max_length=100) -> str:
         srcs = self.builder(literal2idiomatic=[(src, "")])
idiomify/paths.py
CHANGED
@@ -6,12 +6,12 @@ CONFIG_YAML = ROOT_DIR / "config.yaml"
 
 
 def idioms_dir(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"…"
+    return ARTIFACTS_DIR / f"idioms_{ver}"
 
 
 def literal2idiomatic(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"…"
+    return ARTIFACTS_DIR / f"literal2idiomatic_{ver}"
 
 
 def seq2seq_dir(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"…"
+    return ARTIFACTS_DIR / f"seq2seq_{ver}"
idiomify/preprocess.py
ADDED
@@ -0,0 +1,31 @@
+from typing import Tuple
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+
+def upsample(df: pd.DataFrame, seed: int) -> pd.DataFrame:
+    # TODO: implement upsampling later
+    return df
+
+
+def cleanse(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    :param df:
+    :return:
+    """
+    # TODO: implement cleansing
+    return df
+
+
+def stratified_split(df: pd.DataFrame, ratio: float, seed: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """
+    stratified-split the given df into two df's.
+    """
+    total = len(df)
+    ratio_size = int(total * ratio)
+    other_size = total - ratio_size
+    ratio_df, other_df = train_test_split(df, train_size=ratio_size,
+                                          stratify=df['Idiom'],
+                                          test_size=other_size, random_state=seed,
+                                          shuffle=True)
+    return ratio_df, other_df
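
Since train_test_split stratifies on the 'Idiom' column, each idiom is represented in both splits in roughly the given ratio; that is what lets main_upload_idioms.py derive the idiom set from the training split alone. A toy check (the data below is made up for illustration):

import pandas as pd
from idiomify.preprocess import stratified_split

df = pd.DataFrame({
    "Idiom": ["beat around the bush"] * 10 + ["call it a day"] * 10,
    "Literal_Sent": [f"literal {i}" for i in range(20)],
    "Idiomatic_Sent": [f"idiomatic {i}" for i in range(20)],
})
train_df, test_df = stratified_split(df, ratio=0.8, seed=104)
assert set(train_df["Idiom"]) == set(test_df["Idiom"])  # 8/2 split per idiom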
main_infer.py
CHANGED
@@ -1,21 +1,21 @@
 import argparse
-from idiomify.models import Idiomifier
-from idiomify.fetchers import fetch_config, fetch_seq2seq
+from idiomify.models import Idiomifier, Pipeline
+from idiomify.fetchers import fetch_config, fetch_idiomifier
 from transformers import BartTokenizer
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ver", type=str, default="tag011")
     parser.add_argument("--src", type=str,
                         default="If there's any good to loosing my job,"
                                 " it's that I'll now be able to go to school full-time and finish my degree earlier.")
     args = parser.parse_args()
-    config = fetch_config()[…]
+    config = fetch_config()['infer']
     config.update(vars(args))
-    model = fetch_seq2seq(…)
+    model = fetch_idiomifier(config['ver'])
+    model.eval()  # this is crucial
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
-    idiomifier = Idiomifier(model, tokenizer)
+    idiomifier = Pipeline(model, tokenizer)
     src = config['src']
     tgt = idiomifier(src=config['src'])
     print(src, "\n->", tgt)
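
The model.eval() call that moved out of the pipeline and into this entry point disables dropout, which is what the "this is crucial" comment refers to: generation stays deterministic. Note that config.yaml above defines only train and upload blocks, so an infer block (with ver and bart keys) still has to be added before this script can run. Wrapping the call in torch.no_grad() to skip autograd bookkeeping would be a natural companion (a suggestion reusing the names from main_infer.py, not part of this commit):

import torch

model.eval()  # disable dropout for deterministic generation
with torch.no_grad():  # no gradients needed at inference time
    tgt = idiomifier(src=config['src'])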
main_train.py
CHANGED
@@ -8,19 +8,18 @@ from pytorch_lightning.loggers import WandbLogger
 from transformers import BartTokenizer, BartForConditionalGeneration
 from idiomify.data import IdiomifyDataModule
 from idiomify.fetchers import fetch_config
-from idiomify.models import Seq2Seq
+from idiomify.models import Idiomifier
 from idiomify.paths import ROOT_DIR
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ver", type=str, default="tag011")
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
     parser.add_argument("--log_every_n_steps", type=int, default=1)
     parser.add_argument("--fast_dev_run", action="store_true", default=False)
     parser.add_argument("--upload", dest='upload', action='store_true', default=False)
     args = parser.parse_args()
-    config = fetch_config()[…]
+    config = fetch_config()['train']
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
@@ -28,7 +27,7 @@ def main():
     # prepare the model
     bart = BartForConditionalGeneration.from_pretrained(config['bart'])
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
-    model = Seq2Seq(…)
+    model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
     # prepare the datamodule
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
         datamodule = IdiomifyDataModule(config, tokenizer, run)
@@ -46,7 +45,7 @@ def main():
     if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
         ckpt_path = ROOT_DIR / "model.ckpt"
         trainer.save_checkpoint(str(ckpt_path))
-        artifact = wandb.Artifact(name="seq2seq", …)
+        artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
         artifact.add_file(str(ckpt_path))
         run.log_artifact(artifact, aliases=["latest", config['ver']])
         os.remove(str(ckpt_path))  # make sure you remove it after you are done with uploading it
main_upload_idioms.py
CHANGED
@@ -1,35 +1,31 @@
 """
-…
-…
+will do this when I need to.
+Is it absolutely necessary to keep track of idioms separately?
 """
 import os
-from idiomify.paths import ROOT_DIR
-from idiomify.fetchers import fetch_pie
-import argparse
 import wandb
+from idiomify.fetchers import fetch_literal2idiomatic, fetch_config
+from idiomify.paths import ROOT_DIR
 
 
 def main():
-…
-…
-…
-…
-    # get the idioms here
-    if config['ver'] == "tag01":
-        # only the first 106, and this is for piloting
-        idioms = set([row[0] for row in fetch_pie()[:106]])
-    else:
-        raise NotImplementedError
-    idioms = list(idioms)
+    config = fetch_config()['upload']['idioms']
+    train_df, _ = fetch_literal2idiomatic(config['ver'])
+    idioms = train_df['Idiom'].tolist()
+    idioms = list(set(idioms))
 
-    with wandb.init(entity="eubinecto", project="idiomify"…
-…
+    with wandb.init(entity="eubinecto", project="idiomify") as run:
+        # the paths to write datasets in
         txt_path = ROOT_DIR / "all.txt"
         with open(txt_path, 'w') as fh:
             for idiom in idioms:
                 fh.write(idiom + "\n")
+        artifact = wandb.Artifact(name="idioms", type="dataset", description=config['description'],
+                                  metadata=config)
         artifact.add_file(txt_path)
+        # then, we just log them here.
         run.log_artifact(artifact, aliases=["latest", config['ver']])
+        # don't forget to remove them
        os.remove(txt_path)
 
 
main_upload_literal2idiomatic.py
CHANGED
@@ -1,39 +1,40 @@
 """
-…
-just upload all idioms here - name it as epie.
+literal2idiomatic ver: d-1-2
 """
-import csv
 import os
 from idiomify.paths import ROOT_DIR
-from idiomify.fetchers import fetch_pie
-import argparse
+from idiomify.fetchers import fetch_pie, fetch_config
+from idiomify.preprocess import upsample, cleanse, stratified_split
 import wandb
 
 
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--ver", type=str, default="tag01")
-    config = vars(parser.parse_args())
 
-    # …
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+    # here, we use all of them, while splitting them into train & test
+    pie_df = fetch_pie()
+    config = fetch_config()['upload']['literal2idiomatic']
+    train_df, test_df = pie_df.pipe(cleanse)\
+        .pipe(upsample, seed=config['seed'])\
+        .pipe(stratified_split, ratio=config['train_ratio'], seed=config['seed'])
+    # why don't you just "select" the columns? yeah, stop using csv library. just select them.
+    train_df = train_df[["Idiom", "Literal_Sent", "Idiomatic_Sent"]]
+    test_df = test_df[["Idiom", "Literal_Sent", "Idiomatic_Sent"]]
+    dfs = (train_df, test_df)
+    with wandb.init(entity="eubinecto", project="idiomify") as run:
+        # the paths to write datasets in
+        train_path = ROOT_DIR / "train.tsv"
+        test_path = ROOT_DIR / "test.tsv"
+        paths = (train_path, test_path)
+        artifact = wandb.Artifact(name="literal2idiomatic", type="dataset", description=config['description'],
+                                  metadata=config)
+        for tsv_path, df in zip(paths, dfs):
+            df.to_csv(tsv_path, sep="\t")
+            artifact.add_file(tsv_path)
+        # then, we just log them here.
         run.log_artifact(artifact, aliases=["latest", config['ver']])
-…
+        # don't forget to remove them
+        for tsv_path in paths:
+            os.remove(tsv_path)
 
 
 if __name__ == '__main__':
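
One detail to watch in the upload loop: DataFrame.to_csv writes the row index by default, so the uploaded TSVs carry an extra unnamed leading column that fetch_literal2idiomatic will read back as data. Passing index=False would keep exactly the three selected columns (a suggestion, not part of this commit):

df.to_csv(tsv_path, sep="\t", index=False)  # drop the pandas index from the artifact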
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
 pytorch-lightning==1.5.10
 transformers==4.16.2
-wandb==0.12.10
+wandb==0.12.10
+scikit-learn==1.0.2