[#1] IdiomifyDataModule implemented (srcs, tgts_r, tgts)
- explore/{explore_fetch_literal2idiom.py → explore_fetch_literal2idiomatic.py} +2 -2
- explore/explore_idiomifydatamodule.py +26 -0
- explore/explore_src_builder.py +2 -2
- explore/explore_tgt_builder.py +2 -2
- idiomify/builders.py +20 -15
- idiomify/datamodules.py +21 -59
- idiomify/fetchers.py +4 -27
- idiomify/paths.py +2 -2
- main_train.py +2 -2
- main_upload_literal2idiom.py → main_upload_literal2idiomatic.py +1 -1
explore/{explore_fetch_literal2idiom.py → explore_fetch_literal2idiomatic.py}
RENAMED
@@ -1,8 +1,8 @@
-from idiomify.fetchers import fetch_literal2idiom
+from idiomify.fetchers import fetch_literal2idiomatic


 def main():
-    for src, tgt in fetch_literal2idiom("pie_v0"):
+    for src, tgt in fetch_literal2idiomatic("pie_v0"):
         print(src, "->", tgt)


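For reference, each pair prints on its own line; a hypothetical output line (illustrative only, the real rows come from the pie_v0 artifact) would look like:

I could die at any moment -> I could kick the bucket at any moment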
explore/explore_idiomifydatamodule.py
ADDED
@@ -0,0 +1,26 @@
+from transformers import BartTokenizer
+from idiomify.datamodules import IdiomifyDataModule
+
+
+CONFIG = {
+    "literal2idiomatic_ver": "pie_v0",
+    "batch_size": 20,
+    "num_workers": 4,
+    "shuffle": True
+}
+
+
+def main():
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
+    datamodule.prepare_data()
+    datamodule.setup()
+    for batch in datamodule.train_dataloader():
+        srcs, tgts_r, tgts = batch
+        print(srcs.shape)
+        print(tgts_r.shape)
+        print(tgts.shape)
+
+
+if __name__ == '__main__':
+    main()
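With batch_size set to 20, each printed shape should take the form below (a sketch; L stands for the padded length of the longest sequence, so it varies per batch):

torch.Size([20, 2, L])  # srcs: input_ids and attention_mask stacked along dim 1
torch.Size([20, 2, L])  # tgts_r: right-shifted targets
torch.Size([20, 2, L])  # tgts: targets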
explore/explore_src_builder.py
CHANGED
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.builders import SRCBuilder
+from idiomify.builders import SourcesBuilder

 BATCH = [
     ("I could die at any moment", "I could kick the bucket at any moment"),
@@ -9,7 +9,7 @@ BATCH = [

 def main():
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
-    builder = SRCBuilder(tokenizer)
+    builder = SourcesBuilder(tokenizer)
     src = builder(BATCH)
     print(src)

explore/explore_tgt_builder.py
CHANGED
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.builders import TGTBuilder
+from idiomify.builders import TargetsBuilder

 BATCH = [
     ("I could die at any moment", "I could kick the bucket at any moment"),
@@ -9,7 +9,7 @@ BATCH = [

 def main():
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
-    builder = TGTBuilder(tokenizer)
+    builder = TargetsBuilder(tokenizer)
     tgt_r, tgt = builder(BATCH)
     print(tgt_r)
     print(tgt)
idiomify/builders.py
CHANGED
@@ -4,12 +4,12 @@ builders must accept device as one of the parameters.
 """
 import torch
 from typing import List, Tuple
-from transformers import BertTokenizer
+from transformers import BertTokenizer, BartTokenizer


 class TensorBuilder:

-    def __init__(self, tokenizer: BertTokenizer):
+    def __init__(self, tokenizer: BartTokenizer):
         self.tokenizer = tokenizer

     def __call__(self, *args, **kwargs) -> torch.Tensor:
@@ -45,7 +45,7 @@ class Idiom2SubwordsBuilder(TensorBuilder):
         return input_ids


-class SRCBuilder(TensorBuilder):
+class SourcesBuilder(TensorBuilder):
     """
     to be used for both training and inference
     """
@@ -60,24 +60,29 @@ class SRCBuilder(TensorBuilder):
         return src  # (N, 2, L)


-class TGTBuilder(TensorBuilder):
-    """
-    This is to be used only for training. As for inference, we don't need this.
-    """
-    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> Tuple[torch.Tensor, torch.Tensor]:
-        encodings_r = self.tokenizer([
+class TargetsRightShiftedBuilder(TensorBuilder):
+
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
+        encodings = self.tokenizer([
             self.tokenizer.bos_token + idiomatic  # starts with bos, but does not end with eos (right-shifted)
             for _, idiomatic in literal2idiomatic
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
+        tgts_r = torch.stack([encodings['input_ids'],
+                              encodings['attention_mask']], dim=1)  # (N, 2, L)
+        return tgts_r
+
+
+class TargetsBuilder(TensorBuilder):
+    """
+    This is to be used only for training. As for inference, we don't need this.
+    """
+    def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
         encodings = self.tokenizer([
             idiomatic + self.tokenizer.eos_token  # no bos, but ends with eos
             for _, idiomatic in literal2idiomatic
         ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
-        tgt_r = torch.stack([encodings_r['input_ids'],
-                             encodings_r['attention_mask']], dim=1)  # (N, 2, L)
-        tgt = torch.stack([encodings['input_ids'],
-                           encodings['attention_mask']], dim=1)  # (N, 2, L)
-        return tgt_r, tgt
-


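The split into TargetsRightShiftedBuilder and TargetsBuilder mirrors how seq2seq training pairs decoder inputs with labels: the decoder reads the target shifted one step to the right and must predict the unshifted target. A minimal sketch of the idea with a plain BartTokenizer (not part of the commit; the sentence is borrowed from the explore scripts above):

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
idiomatic = "I could kick the bucket at any moment"

# decoder input: starts with <s>, does not end with </s> (right-shifted)
shifted = tokenizer(tokenizer.bos_token + idiomatic, add_special_tokens=False)
# label: does not start with <s>, ends with </s>
label = tokenizer(idiomatic + tokenizer.eos_token, add_special_tokens=False)

# the i-th label token is what the decoder should emit after reading
# the first i+1 shifted tokens, hence the one-token offset
print(shifted['input_ids'][:3])  # begins with the <s> id
print(label['input_ids'][-3:])   # ends with the </s> id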
idiomify/datamodules.py
CHANGED
@@ -2,35 +2,38 @@ import torch
 from typing import Tuple, Optional, List
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
-from idiomify.fetchers import fetch_epie
-from idiomify.builders import Idiom2ContextBuilder, LabelsBuilder
-from transformers import BertTokenizer
+from idiomify.fetchers import fetch_literal2idiomatic
+from idiomify.builders import SourcesBuilder, TargetsBuilder, TargetsRightShiftedBuilder
+from transformers import BartTokenizer


 class IdiomifyDataset(Dataset):
     def __init__(self,
-                 X: torch.Tensor,
-                 y: torch.LongTensor):
-        self.X = X
-        self.y = y
+                 srcs: torch.Tensor,
+                 tgts_r: torch.Tensor,
+                 tgts: torch.Tensor):
+        self.srcs = srcs
+        self.tgts_r = tgts_r
+        self.tgts = tgts

     def __len__(self) -> int:
         """
         Returning the size of the dataset
         :return:
         """
-        return self.X.shape[0]
+        assert self.srcs.shape[0] == self.tgts_r.shape[0] == self.tgts.shape[0]
+        return self.srcs.shape[0]

-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.LongTensor]:
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
         """
         Returns features & the label
         :param idx:
         :return:
         """
-        return self.X[idx], self.y[idx]
+        return self.srcs[idx], self.tgts_r[idx], self.tgts[idx]


-class Idiom2DefDataModule(LightningDataModule):
+class IdiomifyDataModule(LightningDataModule):

     # boilerplate - just ignore these
     def test_dataloader(self):
@@ -44,21 +47,19 @@ class Idiom2DefDataModule(LightningDataModule):

     def __init__(self,
                  config: dict,
-                 tokenizer: BertTokenizer,
-                 idioms: List[str]):
+                 tokenizer: BartTokenizer):
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
-        self.idioms = idioms
         # --- to be downloaded & built --- #
-        self.
+        self.literal2idiomatic: Optional[List[Tuple[str, str]]] = None
         self.dataset: Optional[IdiomifyDataset] = None

     def prepare_data(self):
         """
         prepare: download all data needed for this from wandb to local.
         """
-        self.
+        self.literal2idiomatic = fetch_literal2idiomatic(self.config['literal2idiomatic_ver'])

     def setup(self, stage: Optional[str] = None):
         """
@@ -66,50 +67,11 @@ class Idiom2DefDataModule(LightningDataModule):
         """
         # --- set up the builders --- #
         # build the datasets
-
-
-
+        srcs = SourcesBuilder(self.tokenizer)(self.literal2idiomatic)
+        tgts_r = TargetsRightShiftedBuilder(self.tokenizer)(self.literal2idiomatic)
+        tgts = TargetsBuilder(self.tokenizer)(self.literal2idiomatic)
+        self.dataset = IdiomifyDataset(srcs, tgts_r, tgts)

     def train_dataloader(self) -> DataLoader:
         return DataLoader(self.dataset, batch_size=self.config['batch_size'],
                           shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
-
-
-class Idiom2ContextsDataModule(LightningDataModule):
-
-    # boilerplate - just ignore these
-    def test_dataloader(self):
-        pass
-
-    def val_dataloader(self):
-        pass
-
-    def predict_dataloader(self):
-        pass
-
-    def __init__(self, config: dict, tokenizer: BertTokenizer, idioms: List[str]):
-        super().__init__()
-        self.config = config
-        self.tokenizer = tokenizer
-        self.idioms = idioms
-        self.idiom2context: Optional[List[Tuple[str, str]]] = None
-        self.dataset: Optional[IdiomifyDataset] = None
-
-    def prepare_data(self):
-        """
-        prepare: download all data needed for this from wandb to local.
-        """
-        self.idiom2context = [
-            (idiom, context)
-            for idiom, _, context in fetch_epie()
-        ]
-
-    def setup(self, stage: Optional[str] = None):
-        # build the datasets
-        X = Idiom2ContextBuilder(self.tokenizer)(self.idiom2context)
-        y = LabelsBuilder(self.tokenizer)(self.idiom2context, self.idioms)
-        self.dataset = IdiomifyDataset(X, y)
-
-    def train_dataloader(self):
-        return DataLoader(self.dataset, batch_size=self.config['batch_size'],
-                          shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
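Since every element of a batch is a stacked (N, 2, L) tensor, a consumer has to split it back into input_ids and attention_mask before feeding BART. A small helper sketch (hypothetical; the name and usage below are not part of the commit):

import torch
from typing import Tuple


def unstack(pairs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Splits a stacked (N, 2, L) tensor into (input_ids, attention_mask)."""
    return pairs[:, 0], pairs[:, 1]


# with a batch from IdiomifyDataModule.train_dataloader():
# src_ids, src_mask = unstack(srcs)        # encoder inputs
# tgt_r_ids, tgt_r_mask = unstack(tgts_r)  # decoder inputs (right-shifted)
# label_ids, _ = unstack(tgts)             # labels for the LM head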
idiomify/fetchers.py
CHANGED
@@ -5,10 +5,7 @@ import wandb
 import requests
 from typing import Tuple, List
 from wandb.sdk.wandb_run import Run
-from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
-from idiomify.builders import Idiom2SubwordsBuilder
-from idiomify.models import Alpha, RD
-from idiomify.paths import CONFIG_YAML, idioms_dir, alpha_dir, literal2idiom
+from idiomify.paths import CONFIG_YAML, idioms_dir, literal2idiomatic
 from idiomify.urls import (
     EPIE_IMMUTABLE_IDIOMS_URL,
     EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
@@ -74,40 +71,20 @@ def fetch_idioms(ver: str, run: Run = None) -> List[str]:
         return [line.strip() for line in fh]


-def fetch_literal2idiom(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+def fetch_literal2idiomatic(ver: str, run: Run = None) -> List[Tuple[str, str]]:
     # if run object is given, we track the lineage of the data.
     # if not, we get the dataset via wandb Api.
     if run:
         artifact = run.use_artifact("literal2idiom", type="dataset", aliases=ver)
     else:
-        artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiom:{ver}", type="dataset")
-    artifact_dir = artifact.download(root=literal2idiom(ver))
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiomatic:{ver}", type="dataset")
+    artifact_dir = artifact.download(root=literal2idiomatic(ver))
     tsv_path = path.join(artifact_dir, "all.tsv")
     with open(tsv_path, 'r') as fh:
         reader = csv.reader(fh, delimiter="\t")
         return [(row[0], row[1]) for row in reader]


-def fetch_rd(model: str, ver: str) -> RD:
-    artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/{model}:{ver}", type="model")
-    config = artifact.metadata
-    artifact_path = alpha_dir(ver)
-    artifact.download(root=str(artifact_path))
-    mlm = AutoModelForMaskedLM.from_config(AutoConfig.from_pretrained(config['bert']))
-    ckpt_path = artifact_path / "rd.ckpt"
-    idioms = fetch_idioms(config['idioms_ver'])
-    tokenizer = BertTokenizer.from_pretrained(config['bert'])
-    idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, config['k'])
-    # if model == Alpha.name():
-    #     rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
-    # elif model == Gamma.name():
-    #     rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
-    # else:
-    #     raise ValueError
-    rd = ...
-    return rd
-
-
 def fetch_config() -> dict:
     with open(str(CONFIG_YAML), 'r', encoding="utf-8") as fh:
         return yaml.safe_load(fh)
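fetch_literal2idiomatic expects the downloaded artifact to contain a two-column all.tsv, one literal/idiomatic pair per row. A self-contained sketch of the same parsing logic (the row content is a made-up example):

import csv
import io

# a stand-in for the downloaded all.tsv
tsv = "I could die at any moment\tI could kick the bucket at any moment\n"
reader = csv.reader(io.StringIO(tsv), delimiter="\t")
pairs = [(row[0], row[1]) for row in reader]
print(pairs)  # [('I could die at any moment', 'I could kick the bucket at any moment')]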
idiomify/paths.py
CHANGED
@@ -9,8 +9,8 @@ def idioms_dir(ver: str) -> Path:
     return ARTIFACTS_DIR / f"idioms_{ver}"


-def literal2idiom(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"literal2idiom_{ver}"
+def literal2idiomatic(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"literal2idiomatic_{ver}"


 def alpha_dir(ver: str) -> Path:
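A quick sanity check of the renamed helper (assuming ARTIFACTS_DIR is defined at the top of idiomify/paths.py, as the other helpers imply):

from idiomify.paths import literal2idiomatic

print(literal2idiomatic("pie_v0"))  # e.g. <ARTIFACTS_DIR>/literal2idiomatic_pie_v0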
main_train.py
CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from termcolor import colored
 from transformers import BertForMaskedLM, BertTokenizer
-from idiomify.datamodules import
+from idiomify.datamodules import IdiomifyDataModule
 from idiomify.fetchers import fetch_config, fetch_idioms
 from idiomify.models import Alpha, Gamma
 from idiomify.paths import ROOT_DIR
@@ -40,7 +40,7 @@ def main():
     else:
         raise ValueError
     # prepare datamodule
-    datamodule =
+    datamodule = IdiomifyDataModule(config, tokenizer, idioms)

     with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
         logger = WandbLogger(log_model=False)
main_upload_literal2idiom.py → main_upload_literal2idiomatic.py
RENAMED
@@ -31,7 +31,7 @@ def main():
         raise NotImplementedError

     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
-        artifact = wandb.Artifact(name="literal2idiom", type="dataset")
+        artifact = wandb.Artifact(name="literal2idiomatic", type="dataset")
         tsv_path = ROOT_DIR / "all.tsv"
         with open(tsv_path, 'w') as fh:
             writer = csv.writer(fh, delimiter="\t")
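For context, the renamed artifact slots into the usual wandb upload pattern; a hypothetical, self-contained version of that step (the row is a placeholder, and the add_file/log_artifact calls come after the snippet shown in the diff):

import csv
import wandb

pairs = [("I could die at any moment", "I could kick the bucket at any moment")]
with wandb.init(entity="eubinecto", project="idiomify") as run:
    artifact = wandb.Artifact(name="literal2idiomatic", type="dataset")
    with open("all.tsv", 'w') as fh:
        writer = csv.writer(fh, delimiter="\t")
        writer.writerows(pairs)
    artifact.add_file("all.tsv")  # attach the tsv to the artifact
    run.log_artifact(artifact, aliases=["pie_v0"])  # version it under an alias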