Merge pull request #3 from eubinecto/issue_2
- README.md +16 -7
- config.yaml +19 -6
- explore/explore_bart_logits_shape.py +1 -1
- explore/{explore_fetch_seq2seq.py → explore_fetch_idiomifier.py} +2 -2
- explore/{explore_fetch_seq2seq_predict.py → explore_fetch_idiomifier_predict.py} +2 -2
- explore/explore_fetch_idioms.py +1 -1
- explore/explore_fetch_literal2idiomatic.py +3 -2
- explore/explore_fetch_pie.py +3 -5
- explore/explore_fetch_pie_df_select.py +12 -0
- explore/explore_idiomifydatamodule.py +11 -3
- explore/explore_torchmetrics_bleu.py +28 -0
- idiomify/builders.py +3 -5
- idiomify/{data.py → datamodules.py} +24 -13
- idiomify/fetchers.py +24 -25
- idiomify/metrics.py +0 -4
- idiomify/models.py +21 -25
- idiomify/paths.py +4 -4
- idiomify/pipeline.py +22 -0
- idiomify/preprocess.py +31 -0
- main_deploy.py +43 -0
- main_eval.py +34 -0
- main_infer.py +12 -9
- main_train.py +5 -7
- main_upload_idioms.py +14 -18
- main_upload_literal2idiomatic.py +27 -26
- requirements.txt +5 -1
README.md
CHANGED
@@ -1,13 +1,22 @@
 # Idiomify
+[](https://share.streamlit.io/eubinecto/idiomify/issue_2/main_deploy.py)
 
-A human-inspired Idiomifier based on BERT
 
-
+Grammarly for idioms. A human-inspired Idiomifier based on BART.
 
 
 
-
-
-
-
-
+<img width="764" alt="image" src="https://user-images.githubusercontent.com/56193069/156941205-830b53aa-a3e6-4263-be03-e568124a256e.png">
+
+
+## Versions
+### models
+format: `m-a-b`
+- a: used to indicate a change in the architecture, or a revision of the final product
+- b: used to indicate a different version of the same architecture (with a different set of hyperparameters)
+
+
+### datasets
+format: `d-a-b`
+- a: used to indicate a change in the dataset we are using
+- b: used to indicate a different version of the dataset
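For example, under this scheme the `m-1-2` model pinned in config.yaml below reads as the second hyperparameter set (`b = 2`) of the first architecture (`a = 1`), and `d-1-2` is the second version of the first dataset.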
config.yaml
CHANGED
@@ -1,8 +1,21 @@
-
-
+idiomifier:
+  ver: m-1-2
+  desc: just overfitting the model, but on the entire PIE dataset.
   bart: facebook/bart-base
   lr: 0.0001
-  literal2idiomatic_ver: ...
-
-
-
+  literal2idiomatic_ver: d-1-2
+  idioms_ver: d-1-2
+  max_epochs: 2
+  batch_size: 40
+  shuffle: true
+  seed: 104
+
+# for building & uploading datasets or tokenizer
+idioms:
+  ver: d-1-2
+  description: the set of idioms in the training set of literal2idiomatic_d-1-2.
+literal2idiomatic:
+  ver: d-1-2
+  description: PIE data split into train & test set (80 / 20 split). There is no validation set because I don't intend to do any hyperparameter tuning on this thing.
+  train_ratio: 0.8
+  seed: 104
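For context, these pinned versions are consumed through `fetch_config()` in idiomify/fetchers.py (see below). Its body is not part of this diff; a minimal sketch, assuming a plain `yaml.safe_load` over `CONFIG_YAML` from idiomify/paths.py:

```python
# a sketch, not necessarily the repo's exact implementation:
# fetch_config's body is not shown in this diff.
import yaml
from idiomify.paths import CONFIG_YAML

def fetch_config() -> dict:
    with open(str(CONFIG_YAML), 'r') as fh:
        return yaml.safe_load(fh)

config = fetch_config()['idiomifier']
print(config['ver'], config['literal2idiomatic_ver'])  # m-1-2 d-1-2
```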
explore/explore_bart_logits_shape.py
CHANGED
@@ -1,6 +1,6 @@
 from transformers import BartTokenizer, BartForConditionalGeneration
 
-from data import IdiomifyDataModule
+from datamodules import IdiomifyDataModule
 
 
 CONFIG = {
explore/{explore_fetch_seq2seq.py → explore_fetch_idiomifier.py}
RENAMED
@@ -1,8 +1,8 @@
-from idiomify.fetchers import fetch_seq2seq
+from idiomify.fetchers import fetch_idiomifier
 
 
 def main():
-    model = fetch_seq2seq(...)
+    model = fetch_idiomifier("m-1-2")
     print(model.bart.config)
 
 
explore/{explore_fetch_seq2seq_predict.py → explore_fetch_idiomifier_predict.py}
RENAMED
@@ -1,10 +1,10 @@
 from transformers import BartTokenizer
 from builders import SourcesBuilder
-from fetchers import fetch_seq2seq
+from fetchers import fetch_idiomifier
 
 
 def main():
-    model = fetch_seq2seq(...)
+    model = fetch_idiomifier("m-1-2")
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
     lit2idi = [
         ("my man", ""),
explore/explore_fetch_idioms.py
CHANGED
@@ -2,7 +2,7 @@ from idiomify.fetchers import fetch_idioms
 
 
 def main():
-    print(fetch_idioms(...))
+    print(fetch_idioms("d-1-2"))
 
 
 if __name__ == '__main__':
explore/explore_fetch_literal2idiomatic.py
CHANGED
@@ -2,8 +2,9 @@ from idiomify.fetchers import fetch_literal2idiomatic
 
 
 def main():
-
-
+    train_df, test_df = fetch_literal2idiomatic("d-1-2")
+    print(train_df.size)  # 12408 rows
+    print(test_df.size)  # 3102 rows
 
 
 if __name__ == '__main__':
explore/explore_fetch_pie.py
CHANGED
@@ -3,11 +3,9 @@ from idiomify.fetchers import fetch_pie
 
 
 def main():
-
-
-
-        if idx == 105:
-            break
+    pie_df = fetch_pie()
+    for idx, row in pie_df.iterrows():
+        print(row)
 
 
 if __name__ == '__main__':
explore/explore_fetch_pie_df_select.py
ADDED
@@ -0,0 +1,12 @@
+from fetchers import fetch_pie
+
+
+def main():
+    pie_df = fetch_pie()
+    print(pie_df.columns)
+    pie_df = pie_df[["Literal_Sent", "Idiomatic_Sent"]]
+    print(pie_df.head(5))
+
+
+if __name__ == '__main__':
+    main()
explore/explore_idiomifydatamodule.py
CHANGED
@@ -1,9 +1,9 @@
 from transformers import BartTokenizer
-from idiomify.data import IdiomifyDataModule
+from idiomify.datamodules import IdiomifyDataModule
 
 
 CONFIG = {
-    "literal2idiomatic_ver": ...,
+    "literal2idiomatic_ver": "d-1-2",
     "batch_size": 20,
     "num_workers": 4,
     "shuffle": True
@@ -11,7 +11,7 @@ CONFIG = {
 
 
 def main():
-    tokenizer = BartTokenizer.from_pretrained("facebook/bart-...")
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
     datamodule = IdiomifyDataModule(CONFIG, tokenizer)
     datamodule.prepare_data()
     datamodule.setup()
@@ -20,6 +20,14 @@ def main():
         print(srcs.shape)
         print(tgts_r.shape)
         print(tgts.shape)
+        break
+
+    for batch in datamodule.test_dataloader():
+        srcs, tgts_r, tgts = batch
+        print(srcs.shape)
+        print(tgts_r.shape)
+        print(tgts.shape)
+        break
 
 
 if __name__ == '__main__':
explore/explore_torchmetrics_bleu.py
ADDED
@@ -0,0 +1,28 @@
+
+from torchmetrics import BLEUScore
+from transformers import BartTokenizer
+
+
+pairs = [
+    ("I knew you could do it", "I knew you could do it"),
+    ("I knew you could do it", "you knew you could do it")
+]
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
+    metric = BLEUScore()
+    preds = tokenizer([pred for pred, _ in pairs])['input_ids']
+    targets = tokenizer([target for _, target in pairs])['input_ids']
+    print(preds)
+    print(targets)
+    print(metric(preds, targets))
+    # arghhh, so bleu score does not support tensors...
+    """
+    AttributeError: 'int' object has no attribute 'split'
+    """
+    # let's just go for the accuracies then.
+
+
+if __name__ == '__main__':
+    main()
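The `AttributeError` above comes from feeding token ids where the metric expects text: `torchmetrics.BLEUScore` tokenizes its inputs itself by calling `.split()` on them. A hedged sketch of the string-based call that should work instead (the exact signature is version-dependent, so verify against the installed torchmetrics):

```python
# assumes a torchmetrics version whose BLEUScore accepts raw strings;
# it whitespace-tokenizes internally, which is what failed on ints above
from torchmetrics import BLEUScore

pairs = [
    ("I knew you could do it", "I knew you could do it"),
    ("I knew you could do it", "you knew you could do it"),
]
metric = BLEUScore()
preds = [pred for pred, _ in pairs]
targets = [[target] for _, target in pairs]  # one reference list per prediction
print(metric(preds, targets))  # 0-dim tensor holding the corpus BLEU
```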
idiomify/builders.py
CHANGED
@@ -55,9 +55,9 @@ class SourcesBuilder(TensorBuilder):
                                  padding=True,
                                  truncation=True,
                                  add_special_tokens=True)
-
-
-        return ...
+        srcs = torch.stack([encodings['input_ids'],
+                            encodings['attention_mask']], dim=1)  # (N, 2, L)
+        return srcs  # (N, 2, L)
 
 
 class TargetsRightShiftedBuilder(TensorBuilder):
@@ -83,5 +83,3 @@ class TargetsBuilder(TensorBuilder):
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
         tgts = encodings['input_ids']
         return tgts  # (N, L)
-
-
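Stacking the ids and the attention mask into one `(N, 2, L)` tensor lets a single Dataset field carry both through the DataLoader; consumers slice them back apart. A minimal sketch (the `("my man", "")` pair is borrowed from the explore script above):

```python
# slicing the (N, 2, L) tensor back into its two (N, L) halves,
# exactly as Pipeline and Idiomifier.forward do downstream
from transformers import BartTokenizer
from idiomify.builders import SourcesBuilder

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
srcs = SourcesBuilder(tokenizer)([("my man", "")])  # (N, 2, L)
input_ids, attention_mask = srcs[:, 0], srcs[:, 1]  # each (N, L)
```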
idiomify/{data.py → datamodules.py}
RENAMED
@@ -1,9 +1,9 @@
 import torch
-
+import pandas as pd
 from torch.utils.data import Dataset, DataLoader
+from typing import Tuple, Optional
 from pytorch_lightning import LightningDataModule
 from wandb.sdk.wandb_run import Run
-
 from idiomify.fetchers import fetch_literal2idiomatic
 from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
 from transformers import BartTokenizer
@@ -38,9 +38,6 @@ class IdiomifyDataset(Dataset):
 class IdiomifyDataModule(LightningDataModule):
 
     # boilerplate - just ignore these
-    def test_dataloader(self):
-        pass
-
     def val_dataloader(self):
         pass
 
@@ -56,23 +53,37 @@ class IdiomifyDataModule(LightningDataModule):
         self.tokenizer = tokenizer
         self.run = run
         # --- to be downloaded & built --- #
-        self. ...
-        self. ...
+        self.train_df: Optional[pd.DataFrame] = None
+        self.test_df: Optional[pd.DataFrame] = None
+        self.train_dataset: Optional[IdiomifyDataset] = None
+        self.test_dataset: Optional[IdiomifyDataset] = None
 
     def prepare_data(self):
         """
         prepare: download all data needed for this from wandb to local.
         """
-        self. ...
+        self.train_df, self.test_df = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'], self.run)
 
     def setup(self, stage: Optional[str] = None):
         # --- set up the builders --- #
         # build the datasets
-
-
-
-
+        self.train_dataset = self.build_dataset(self.train_df)
+        self.test_dataset = self.build_dataset(self.test_df)
+
+    def build_dataset(self, df: pd.DataFrame) -> IdiomifyDataset:
+        literal2idiomatic = [
+            (row['Literal_Sent'], row['Idiomatic_Sent'])
+            for _, row in df.iterrows()
+        ]
+        srcs = SourcesBuilder(self.tokenizer)(literal2idiomatic)
+        tgts_r = TargetsRightShiftedBuilder(self.tokenizer)(literal2idiomatic)
+        tgts = TargetsBuilder(self.tokenizer)(literal2idiomatic)
+        return IdiomifyDataset(srcs, tgts_r, tgts)
 
     def train_dataloader(self) -> DataLoader:
-        return DataLoader(self. ...
+        return DataLoader(self.train_dataset, batch_size=self.config['batch_size'],
                           shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_dataset, batch_size=self.config['batch_size'],
+                          shuffle=False, num_workers=self.config['num_workers'])
idiomify/fetchers.py
CHANGED
@@ -1,25 +1,19 @@
-import csv
-from os import path
 import yaml
 import wandb
-import ...
+from os import path
+import pandas as pd
 from typing import Tuple, List
 from wandb.sdk.wandb_run import Run
-from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, ...
+from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic, idiomifier_dir
 from idiomify.urls import PIE_URL
 from transformers import AutoModelForSeq2SeqLM, AutoConfig
-from idiomify.models import Seq2Seq
+from idiomify.models import Idiomifier
 
 
-
-
-
-
-        next(reader)  # skip the header
-        return [
-            row
-            for row in reader
-        ]
+# --- from the web --- #
+def fetch_pie() -> pd.DataFrame:
+    # fetch & parse it directly from the web
+    return pd.read_csv(PIE_URL)
 
 
 # --- from wandb --- #
@@ -39,7 +33,7 @@ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
     return [line.strip() for line in fh]
 
 
-def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+def fetch_literal2idiomatic(ver: str, run: Run = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
@@ -47,23 +41,28 @@ def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
     else:
         artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
     artifact_dir = artifact.download(root=literal2idiomatic(ver))
-
-
-
-
+    train_path = path.join(artifact_dir, "train.tsv")
+    test_path = path.join(artifact_dir, "test.tsv")
+    train_df = pd.read_csv(train_path, sep="\t")
+    test_df = pd.read_csv(test_path, sep="\t")
+    return train_df, test_df
 
 
-def fetch_seq2seq(...)
+def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
+    """
+    you may want to change the name to Idiomifier.
+    The current Idiomifier then turns into a pipeline.
+    """
     if run:
-        artifact = run.use_artifact(f"seq2seq:{ver}", ...)
+        artifact = run.use_artifact(f"idiomifier:{ver}", type="model")
     else:
-        artifact = wandb.Api().artifact(f"eubinecto/idiomify/seq2seq:{ver}", ...)
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/idiomifier:{ver}", type="model")
     config = artifact.metadata
-    artifact_dir = artifact.download(root=...)
+    artifact_dir = artifact.download(root=idiomifier_dir(ver))
     ckpt_path = path.join(artifact_dir, "model.ckpt")
     bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
-
-    return ...
+    model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
+    return model
 
 
 def fetch_config() -> dict:
idiomify/metrics.py
DELETED
@@ -1,4 +0,0 @@
-"""
-you may want to include bleu score.
-and more metrics for paraphrasing.
-"""
idiomify/models.py
CHANGED
@@ -7,17 +7,19 @@ from torch.nn import functional as F
 import pytorch_lightning as pl
 from transformers import BartForConditionalGeneration, BartTokenizer
 from idiomify.builders import SourcesBuilder
+from torchmetrics import Accuracy
 
-
-# for training
-class Seq2Seq(pl.LightningModule):  # noqa
+class Idiomifier(pl.LightningModule):  # noqa
     """
     the baseline is in here.
     """
     def __init__(self, bart: BartForConditionalGeneration, lr: float, bos_token_id: int, pad_token_id: int):  # noqa
         super().__init__()
-        self.bart = bart
         self.save_hyperparameters(ignore=["bart"])
+        self.bart = bart
+        # metrics (using accuracies as of right now)
+        self.acc_train = Accuracy(ignore_index=pad_token_id)
+        self.acc_test = Accuracy(ignore_index=pad_token_id)
 
     def forward(self, srcs: torch.Tensor, tgts_r: torch.Tensor) -> torch.Tensor:
         """
@@ -38,10 +40,10 @@ class Seq2Seq(pl.LightningModule):  # noqa
 
     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
         srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
-        logits = self.forward(srcs, tgts_r)  # -> (N, L, |V|)
-        logits = logits.transpose(1, 2)  # (N, L, |V|) -> (N, |V|, L)
+        logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
         loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
                 .sum()  # (N, L, |V|), (N, L) -> (N,) -> (1,)
+        self.acc_train.update(logits.detach(), target=tgts.detach())
         return {
             "loss": loss
         }
@@ -49,6 +51,19 @@ class Seq2Seq(pl.LightningModule):  # noqa
     def on_train_batch_end(self, outputs: dict, *args, **kwargs):
         self.log("Train/Loss", outputs['loss'])
 
+    def on_train_epoch_end(self, *args, **kwargs) -> None:
+        self.log("Train/Accuracy", self.acc_train.compute())
+        self.acc_train.reset()
+
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs):
+        srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
+        logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
+        self.acc_test.update(logits.detach(), target=tgts.detach())
+
+    def on_test_epoch_end(self, *args, **kwargs) -> None:
+        self.log("Test/Accuracy", self.acc_test.compute())
+        self.acc_test.reset()
+
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
         Instantiates and returns the optimizer to be used for this model
@@ -57,22 +72,3 @@ class Seq2Seq(pl.LightningModule):  # noqa
         # The authors used Adam, so we might as well use it as well.
         return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
 
-
-# for inference
-class Idiomifier:
-
-    def __init__(self, model: Seq2Seq, tokenizer: BartTokenizer):
-        self.model = model
-        self.builder = SourcesBuilder(tokenizer)
-        self.model.eval()
-
-    def __call__(self, src: str, max_length=100) -> str:
-        srcs = self.builder(literal2idiomatic=[(src, "")])
-        pred_ids = self.model.bart.generate(
-            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
-            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
-            decoder_start_token_id=self.model.hparams['bos_token_id'],
-            max_length=max_length,
-        ).squeeze()  # -> (N, L_t) -> (L_t)
-        tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
-        return tgt
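On the `transpose(1, 2)` above: `F.cross_entropy` expects the class dimension second, i.e. `(N, |V|, L)` logits against `(N, L)` integer targets, with pad positions masked out via `ignore_index`. A self-contained toy check (all shapes and the pad id here are made up):

```python
# toy shapes only - demonstrates the (N, |V|, L) vs (N, L) convention
# and the ignore_index masking used in training_step
import torch
from torch.nn import functional as F

N, V, L, PAD = 2, 10, 5, 1
logits = torch.randn(N, V, L)       # class dim must come second
tgts = torch.randint(2, V, (N, L))
tgts[:, -1] = PAD                   # pretend the last position is padding
loss = F.cross_entropy(logits, tgts, ignore_index=PAD)
print(loss)  # a scalar; pad positions contribute nothing
```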
idiomify/paths.py
CHANGED
@@ -6,12 +6,12 @@ CONFIG_YAML = ROOT_DIR / "config.yaml"
 
 
 def idioms_dir(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"..."
+    return ARTIFACTS_DIR / f"idioms_{ver}"
 
 
 def literal2idiomatic(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"..."
+    return ARTIFACTS_DIR / f"literal2idiomatic_{ver}"
 
 
-def seq2seq_dir(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"..."
+def idiomifier_dir(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"idiomifier_{ver}"
idiomify/pipeline.py
ADDED
@@ -0,0 +1,22 @@
+from typing import List
+from transformers import BartTokenizer
+from idiomify.builders import SourcesBuilder
+from idiomify.models import Idiomifier
+
+
+class Pipeline:
+
+    def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
+        self.model = model
+        self.builder = SourcesBuilder(tokenizer)
+
+    def __call__(self, sents: List[str], max_length=100) -> List[str]:
+        srcs = self.builder(literal2idiomatic=[(sent, "") for sent in sents])
+        pred_ids = self.model.bart.generate(
+            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
+            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
+            decoder_start_token_id=self.model.hparams['bos_token_id'],
+            max_length=max_length,
+        )  # -> (N, L_t)
+        tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        return tgts
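A hedged usage sketch, stitching the fetchers and the pipeline together the same way main_infer.py below does. Note that unlike the old inference-time `Idiomifier` class, `Pipeline` no longer calls `model.eval()` itself, hence the "this is crucial" comment in main_infer.py:

```python
from transformers import BartTokenizer
from idiomify.fetchers import fetch_idiomifier
from idiomify.pipeline import Pipeline

model = fetch_idiomifier("m-1-2")
model.eval()  # Pipeline does not do this for you anymore
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
pipeline = Pipeline(model, tokenizer)
print(pipeline(["Just remember there will always be a hope even when things look black"]))
```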
idiomify/preprocess.py
ADDED
@@ -0,0 +1,31 @@
+from typing import Tuple
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+
+def upsample(df: pd.DataFrame, seed: int) -> pd.DataFrame:
+    # TODO: implement upsampling later
+    return df
+
+
+def cleanse(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    :param df:
+    :return:
+    """
+    # TODO: implement cleansing
+    return df
+
+
+def stratified_split(df: pd.DataFrame, ratio: float, seed: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """
+    stratified-split the given df into two df's.
+    """
+    total = len(df)
+    ratio_size = int(total * ratio)
+    other_size = total - ratio_size
+    ratio_df, other_df = train_test_split(df, train_size=ratio_size,
+                                          stratify=df['Idiom'],
+                                          test_size=other_size, random_state=seed,
+                                          shuffle=True)
+    return ratio_df, other_df
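A small sketch of the split on a toy frame (idioms and sentences made up for illustration). Stratifying on the `Idiom` column keeps every idiom represented in both splits, which is what lets main_upload_idioms.py below treat the training set's idioms as the full inventory:

```python
# toy data - two idioms, five rows each; an 80/20 stratified split
# then puts four rows of each idiom in train and one in test
import pandas as pd
from idiomify.preprocess import stratified_split

df = pd.DataFrame({
    "Idiom": ["beat around the bush"] * 5 + ["call it a day"] * 5,
    "Literal_Sent": [f"literal {i}" for i in range(10)],
    "Idiomatic_Sent": [f"idiomatic {i}" for i in range(10)],
})
train_df, test_df = stratified_split(df, ratio=0.8, seed=104)
print(len(train_df), len(test_df))      # 8 2
print(test_df["Idiom"].value_counts())  # one test row per idiom
```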
main_deploy.py
ADDED
@@ -0,0 +1,43 @@
+"""
+we deploy the pipeline via streamlit.
+"""
+from typing import Tuple, List
+import streamlit as st
+from transformers import BartTokenizer
+from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms
+from idiomify.pipeline import Pipeline
+from idiomify.models import Idiomifier
+
+
+@st.cache(allow_output_mutation=True)
+def fetch_resources() -> Tuple[dict, Idiomifier, BartTokenizer, List[str]]:
+    config = fetch_config()['idiomifier']
+    model = fetch_idiomifier(config['ver'])
+    idioms = fetch_idioms(config['idioms_ver'])
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    return config, model, tokenizer, idioms
+
+
+def main():
+    # fetch a pre-trained model
+    config, model, tokenizer, idioms = fetch_resources()
+    pipeline = Pipeline(model, tokenizer)
+    st.title("Idiomify Demo")
+    st.markdown(f"Author: `Eu-Bin KIM`")
+    st.markdown(f"Version: `{config['ver']}`")
+    text = st.text_area("Type sentences here",
+                        value="Just remember there will always be a hope even when things look black")
+    with st.sidebar:
+        st.subheader("Supported idioms")
+        st.write(" / ".join(idioms))
+
+    if st.button(label="Idiomify"):
+        with st.spinner("Please wait..."):
+            sents = [sent for sent in text.split(".") if sent]
+            sents = pipeline(sents, max_length=200)
+            # highlight the rule & honorifics that were applied
+            st.write(". ".join(sents))
+
+
+if __name__ == '__main__':
+    main()
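The Streamlit badge added to the README points share.streamlit.io at this script; locally the same demo would presumably be started with `streamlit run main_deploy.py` (watchdog, added to requirements.txt below, is the file-watcher Streamlit recommends).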
main_eval.py
ADDED
@@ -0,0 +1,34 @@
+import torch
+import argparse
+import os
+import wandb
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from transformers import BartTokenizer
+from idiomify.datamodules import IdiomifyDataModule
+from idiomify.fetchers import fetch_config, fetch_idiomifier
+from idiomify.paths import ROOT_DIR
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
+    parser.add_argument("--fast_dev_run", action="store_true", default=False)
+    args = parser.parse_args()
+    config = fetch_config()['idiomifier']
+    config.update(vars(args))
+    tokenizer = BartTokenizer.from_pretrained(config['bart'])
+    # prepare the datamodule
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        model = fetch_idiomifier(config['ver'], run)  # fetch a pre-trained model
+        datamodule = IdiomifyDataModule(config, tokenizer, run)
+        logger = WandbLogger(log_model=False)
+        trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
+                             gpus=torch.cuda.device_count(),
+                             default_root_dir=str(ROOT_DIR),
+                             logger=logger)
+        trainer.test(model, datamodule)
+
+
+if __name__ == '__main__':
+    main()
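This evaluation entry point re-uses IdiomifyDataModule and calls `trainer.test()`, which drives the new `test_step`/`on_test_epoch_end` hooks added to models.py above, so `Test/Accuracy` lands in the same wandb project as the training logs.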
main_infer.py
CHANGED
@@ -1,23 +1,26 @@
+"""
+This is for just a simple sanity check on the inference.
+"""
 import argparse
-from idiomify.models import Idiomifier
-from idiomify.fetchers import fetch_config, fetch_seq2seq
+from idiomify.pipeline import Pipeline
+from idiomify.fetchers import fetch_config, fetch_idiomifier
 from transformers import BartTokenizer
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ver", ...)
-    parser.add_argument("--src", type=str,
+    parser.add_argument("--sent", type=str,
                         default="If there's any good to loosing my job,"
                                 " it's that I'll now be able to go to school full-time and finish my degree earlier.")
     args = parser.parse_args()
-    config = fetch_config()[...]
+    config = fetch_config()['idiomifier']
     config.update(vars(args))
-    model = ...
+    model = fetch_idiomifier(config['ver'])
+    model.eval()  # this is crucial
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
-
-    src = config['src']
-    tgt = ...
+    pipeline = Pipeline(model, tokenizer)
+    src = config['sent']
+    tgt = pipeline(sents=[config['sent']])
     print(src, "\n->", tgt)
 
 
main_train.py
CHANGED
@@ -6,29 +6,27 @@ import pytorch_lightning as pl
 from termcolor import colored
 from pytorch_lightning.loggers import WandbLogger
 from transformers import BartTokenizer, BartForConditionalGeneration
-from idiomify.data import IdiomifyDataModule
+from idiomify.datamodules import IdiomifyDataModule
 from idiomify.fetchers import fetch_config
-from idiomify.models import Seq2Seq
+from idiomify.models import Idiomifier
 from idiomify.paths import ROOT_DIR
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ver", type=str, default="tag011")
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
     parser.add_argument("--log_every_n_steps", type=int, default=1)
     parser.add_argument("--fast_dev_run", action="store_true", default=False)
    parser.add_argument("--upload", dest='upload', action='store_true', default=False)
     args = parser.parse_args()
-    config = fetch_config()[...]
+    config = fetch_config()['idiomifier']
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
-
     # prepare the model
     bart = BartForConditionalGeneration.from_pretrained(config['bart'])
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
-    model = Seq2Seq(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
+    model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
     # prepare the datamodule
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
         datamodule = IdiomifyDataModule(config, tokenizer, run)
@@ -46,7 +44,7 @@ def main():
     if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
         ckpt_path = ROOT_DIR / "model.ckpt"
         trainer.save_checkpoint(str(ckpt_path))
-        artifact = wandb.Artifact(name="seq2seq", ...)
+        artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
         artifact.add_file(str(ckpt_path))
         run.log_artifact(artifact, aliases=["latest", config['ver']])
         os.remove(str(ckpt_path))  # make sure you remove it after you are done with uploading it
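With the `--ver` flag gone, the model version now comes solely from the `idiomifier.ver` field of config.yaml, so something like `python main_train.py --upload` would train for `max_epochs` and log the checkpoint as the `idiomifier` artifact aliased with that version.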
main_upload_idioms.py
CHANGED
@@ -1,35 +1,31 @@
 """
-
-
+will do this when I need to.
+Is it absolutely necessary to keep track of idioms separately?
 """
 import os
-from idiomify.paths import ROOT_DIR
-from idiomify.fetchers import fetch_pie
-import argparse
 import wandb
+from idiomify.fetchers import fetch_literal2idiomatic, fetch_config
+from idiomify.paths import ROOT_DIR
 
 
 def main():
-
-
-
-
-    # get the idioms here
-    if config['ver'] == "tag01":
-        # only the first 106, and this is for piloting
-        idioms = set([row[0] for row in fetch_pie()[:106]])
-    else:
-        raise NotImplementedError
-    idioms = list(idioms)
+    config = fetch_config()['idioms']
+    train_df, _ = fetch_literal2idiomatic(config['ver'])
+    idioms = train_df['Idiom'].tolist()
+    idioms = list(set(idioms))
 
-    with wandb.init(entity="eubinecto", project="idiomify" ...
-
+    with wandb.init(entity="eubinecto", project="idiomify") as run:
+        # the paths to write datasets in
         txt_path = ROOT_DIR / "all.txt"
         with open(txt_path, 'w') as fh:
             for idiom in idioms:
                 fh.write(idiom + "\n")
+        artifact = wandb.Artifact(name="idioms", type="dataset", description=config['description'],
+                                  metadata=config)
         artifact.add_file(txt_path)
+        # then, we just log them here.
        run.log_artifact(artifact, aliases=["latest", config['ver']])
+        # don't forget to remove them
        os.remove(txt_path)
 
 
main_upload_literal2idiomatic.py
CHANGED
@@ -1,39 +1,40 @@
 """
-
-just upload all idioms here - name it as epie.
+literal2idiomatic ver: d-1-2
 """
-import csv
 import os
 from idiomify.paths import ROOT_DIR
-from idiomify.fetchers import fetch_pie
-import ...
+from idiomify.fetchers import fetch_pie, fetch_config
+from idiomify.preprocess import upsample, cleanse, stratified_split
 import wandb
 
 
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--ver", type=str, default="tag01")
-    config = vars(parser.parse_args())
 
-    # ...
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # here, we use all of them, while splitting them into train & test
+    pie_df = fetch_pie()
+    config = fetch_config()['literal2idiomatic']
+    train_df, test_df = pie_df.pipe(cleanse)\
+        .pipe(upsample, seed=config['seed'])\
+        .pipe(stratified_split, ratio=config['train_ratio'], seed=config['seed'])
+    # why don't you just "select" the columns? yeah, stop using csv library. just select them.
+    train_df = train_df[["Idiom", "Literal_Sent", "Idiomatic_Sent"]]
+    test_df = test_df[["Idiom", "Literal_Sent", "Idiomatic_Sent"]]
+    dfs = (train_df, test_df)
+    with wandb.init(entity="eubinecto", project="idiomify") as run:
+        # the paths to write datasets in
+        train_path = ROOT_DIR / "train.tsv"
+        test_path = ROOT_DIR / "test.tsv"
+        paths = (train_path, test_path)
+        artifact = wandb.Artifact(name="literal2idiomatic", type="dataset", description=config['description'],
+                                  metadata=config)
+        for tsv_path, df in zip(paths, dfs):
+            df.to_csv(tsv_path, sep="\t")
+            artifact.add_file(tsv_path)
+        # then, we just log them here.
         run.log_artifact(artifact, aliases=["latest", config['ver']])
-
+        # don't forget to remove them
+        for tsv_path in paths:
+            os.remove(tsv_path)
 
 
 if __name__ == '__main__':
requirements.txt
CHANGED
@@ -1,3 +1,7 @@
 pytorch-lightning==1.5.10
 transformers==4.16.2
-wandb==0.12.10
+wandb==0.12.10
+scikit-learn==1.0.2
+pandas==1.3.5
+streamlit==1.7.0
+watchdog==2.1.6