eubinecto committed
Commit 207cddf
1 Parent(s): 3646bbf

saving this branch
- explore/explore_fetch_epie.py +27 -0
- explore/explore_fetch_epie_counts.py +20 -0
- explore/explore_idiom2subwords.py +0 -0
- idiomify/builders.py +84 -0
- idiomify/datamodules.py +45 -5
- idiomify/fetchers.py +37 -3
- idiomify/models.py +0 -98
- idiomify/tensors.py +0 -56
- idiomify/urls.py +11 -0
- main_train.py +2 -2
- main_upload_idiom2context.py +11 -0
- main_upload_idioms.py +13 -0
- main_upload_tokenizer.py +13 -0
- requirements.txt +3 -66
explore/explore_fetch_epie.py
ADDED
@@ -0,0 +1,27 @@
+
+from idiomify.fetchers import fetch_epie
+
+
+def main():
+    epie = fetch_epie()
+    idioms = set([
+        idiom
+        for idiom, _, _ in epie
+    ])
+
+    # so, what do you want? you want to build an idiom-masked language modeling?
+    for idiom, context, tag in epie:
+        print(context)
+
+    for idx, idiom in enumerate(idioms):
+        print(idx, idiom)
+
+    # isn't it better to just leave the idiom there, and have it guess what meaning it has?
+    # in that case, It may be better to use a generative model?
+    # but what would happen if you let it... just guess it?
+    # the problem with non-masking is that ... you give the model the answer.
+    # what you should rather do is... do something like... find similar words.
+
+
+if __name__ == '__main__':
+    main()
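
Editor's note: the comments in this exploration script weigh masking the idiom against leaving it in place for a generative model. As a rough, hedged illustration of the masking option they describe (not code from this commit; the idiom, context, and printed output are made up):

# Hypothetical sketch of masking an idiom inside its context; not part of this commit.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

idiom = "spill the beans"  # made-up example idiom
context = "Don't spill the beans about the surprise party."  # made-up example context

# replace every subword of the idiom with [MASK], which is what the notes above call "idiom-masked"
idiom_subwords = tokenizer.tokenize(idiom)
masked = context.replace(idiom, " ".join(["[MASK]"] * len(idiom_subwords)))
print(masked)  # e.g. "Don't [MASK] [MASK] [MASK] about the surprise party."
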
explore/explore_fetch_epie_counts.py
ADDED
@@ -0,0 +1,20 @@
+
+from idiomify.fetchers import fetch_epie
+
+
+def main():
+    idioms = set([
+        idiom
+        for idiom, _, _ in fetch_epie()
+    ])
+    contexts = [
+        context
+        for _, _, context in fetch_epie()
+    ]
+    print("Total number of idioms:", len(idioms))
+    # This should learn... this - what I need for now is building a datamodule out of this
+    print("Total number of contexts:", len(contexts))
+
+
+if __name__ == '__main__':
+    main()
explore/explore_idiom2subwords.py
ADDED
File without changes
idiomify/builders.py
ADDED
@@ -0,0 +1,84 @@
+"""
+all the functions for building tensors are defined here.
+builders must accept device as one of the parameters.
+"""
+import torch
+from typing import List, Tuple
+from transformers import BertTokenizer
+
+
+class TensorBuilder:
+
+    def __init__(self, tokenizer: BertTokenizer):
+        self.tokenizer = tokenizer
+
+    def __call__(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class Idiom2SubwordsBuilder(TensorBuilder):
+
+    def __call__(self, idioms: List[str], k: int) -> torch.Tensor:
+        mask_id = self.tokenizer.mask_token_id
+        pad_id = self.tokenizer.pad_token_id
+        # temporarily disable single-token status of the idioms
+        idioms = [idiom.split(" ") for idiom in idioms]
+        encodings = self.tokenizer(text=idioms,
+                                   add_special_tokens=False,
+                                   # should set this to True, as we already have the idioms split.
+                                   is_split_into_words=True,
+                                   padding='max_length',
+                                   max_length=k,  # set to k
+                                   return_tensors="pt")
+        input_ids = encodings['input_ids']
+        input_ids[input_ids == pad_id] = mask_id  # replace them with masks
+        return input_ids
+
+
+class Idiom2DefBuilder(TensorBuilder):
+
+    def __call__(self, idiom2def: List[Tuple[str, str]], k: int) -> torch.Tensor:
+        defs = [definition for _, definition in idiom2def]
+        lefts = [" ".join(["[MASK]"] * k)] * len(defs)
+        encodings = self.tokenizer(text=lefts,
+                                   text_pair=defs,
+                                   return_tensors="pt",
+                                   add_special_tokens=True,
+                                   truncation=True,
+                                   padding=True,
+                                   verbose=True)
+        input_ids: torch.Tensor = encodings['input_ids']
+        cls_id: int = self.tokenizer.cls_token_id
+        sep_id: int = self.tokenizer.sep_token_id
+        mask_id: int = self.tokenizer.mask_token_id
+        wisdom_mask = torch.where(input_ids == mask_id, 1, 0)
+        desc_mask = torch.where(((input_ids != cls_id) & (input_ids != sep_id) & (input_ids != mask_id)), 1, 0)
+        return torch.stack([input_ids,
+                            encodings['token_type_ids'],
+                            encodings['attention_mask'],
+                            wisdom_mask,
+                            desc_mask], dim=1)
+
+
+class Idiom2ContextBuilder(TensorBuilder):
+
+    def __call__(self, idiom2context: List[Tuple[str, str]]):
+        contexts = [context for _, context in idiom2context]
+        encodings = self.tokenizer(text=contexts,
+                                   return_tensors="pt",
+                                   add_special_tokens=True,
+                                   truncation=True,
+                                   padding=True,
+                                   verbose=True)
+        return torch.stack([encodings['input_ids'],
+                            encodings['token_type_ids'],
+                            encodings['attention_mask']], dim=1)
+
+
+class TargetsBuilder(TensorBuilder):
+
+    def __call__(self, idiom2sent: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
+        return torch.LongTensor([
+            idioms.index(idiom)
+            for idiom, _ in idiom2sent
+        ])
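
Editor's note: a minimal usage sketch of the builders added above, calling the classes exactly as this commit defines them; the idiom list, definition, and k value below are placeholders, not data from the repo:

# Hypothetical usage of idiomify/builders.py; idioms, idiom2def and k are made up.
from transformers import BertTokenizer
from idiomify.builders import Idiom2SubwordsBuilder, Idiom2DefBuilder, TargetsBuilder

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
idioms = ["beat around the bush", "piece of cake"]
idiom2def = [("piece of cake", "something that is very easy to do")]
k = 11  # fixed number of [MASK] slots

idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, k)  # (|W|, k) subword ids, padded with [MASK]
X = Idiom2DefBuilder(tokenizer)(idiom2def, k)                 # (N, 5, L): ids, type ids, attn, wisdom mask, desc mask
y = TargetsBuilder(tokenizer)(idiom2def, idioms)              # (N,) index of each idiom in the idiom list
print(idiom2subwords.shape, X.shape, y.shape)
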
idiomify/datamodules.py
CHANGED
@@ -2,8 +2,8 @@ import torch
 from typing import Tuple, Optional, List
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
-from idiomify.fetchers import fetch_idiom2def
-from idiomify import
+from idiomify.fetchers import fetch_idiom2def, fetch_epie
+from idiomify.builders import Idiom2DefBuilder, Idiom2ContextBuilder, TargetsBuilder
 from transformers import BertTokenizer


@@ -30,7 +30,7 @@ class IdiomifyDataset(Dataset):
         return self.X[idx], self.y[idx]


-class
+class Idiom2DefDataModule(LightningDataModule):

     # boilerplate - just ignore these
     def test_dataloader(self):
@@ -66,10 +66,50 @@ class IdiomifyDataModule(LightningDataModule):
         """
         # --- set up the builders --- #
         # build the datasets
-        X =
-        y =
+        X = Idiom2DefBuilder(self.tokenizer)(self.idiom2def, self.config['k'])
+        y = TargetsBuilder(self.tokenizer)(self.idiom2def, self.idioms)
         self.dataset = IdiomifyDataset(X, y)

     def train_dataloader(self) -> DataLoader:
         return DataLoader(self.dataset, batch_size=self.config['batch_size'],
                           shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
+
+
+class Idiom2ContextsDataModule(LightningDataModule):
+
+    # boilerplate - just ignore these
+    def test_dataloader(self):
+        pass
+
+    def val_dataloader(self):
+        pass
+
+    def predict_dataloader(self):
+        pass
+
+    def __init__(self, config: dict, tokenizer: BertTokenizer, idioms: List[str]):
+        super().__init__()
+        self.config = config
+        self.tokenizer = tokenizer
+        self.idioms = idioms
+        self.idiom2context: Optional[List[Tuple[str, str]]] = None
+        self.dataset: Optional[IdiomifyDataset] = None
+
+    def prepare_data(self):
+        """
+        prepare: download all data needed for this from wandb to local.
+        """
+        self.idiom2context = [
+            (idiom, context)
+            for idiom, _, context in fetch_epie()
+        ]
+
+    def setup(self, stage: Optional[str] = None):
+        # build the datasets
+        X = Idiom2ContextBuilder(self.tokenizer)(self.idiom2context)
+        y = TargetsBuilder(self.tokenizer)(self.idiom2context, self.idioms)
+        self.dataset = IdiomifyDataset(X, y)
+
+    def train_dataloader(self):
+        return DataLoader(self.dataset, batch_size=self.config['batch_size'],
+                          shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
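
Editor's note: a rough sketch of how the new Idiom2ContextsDataModule could be driven end to end. The config values are placeholders and the idiom list is rebuilt from fetch_epie() so that TargetsBuilder can index every idiom it sees; none of this driver code is part of the commit:

# Hypothetical driver for Idiom2ContextsDataModule; config values are placeholders.
from transformers import BertTokenizer
from idiomify.datamodules import Idiom2ContextsDataModule
from idiomify.fetchers import fetch_epie

config = {"batch_size": 32, "shuffle": True, "num_workers": 0}  # placeholder hyperparameters
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# full idiom inventory, so TargetsBuilder's idioms.index(...) never misses
idioms = sorted(set(idiom for idiom, _, _ in fetch_epie()))

datamodule = Idiom2ContextsDataModule(config, tokenizer, idioms)
datamodule.prepare_data()  # pulls the EPIE rows via fetch_epie()
datamodule.setup()
X, y = next(iter(datamodule.train_dataloader()))
print(X.shape, y.shape)  # (batch_size, 3, L) and (batch_size,)
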
idiomify/fetchers.py
CHANGED
@@ -1,13 +1,47 @@
 import csv
 import yaml
 import wandb
+import requests
 from typing import Tuple, List
-
+
+from wandb.sdk.wandb_run import Run
+
+from idiomify.models import Alpha, RD
 from idiomify.paths import idiom2def_dir, CONFIG_YAML, idioms_dir, alpha_dir
-from idiomify import
+from idiomify.urls import (
+    EPIE_IMMUTABLE_IDIOMS_URL,
+    EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
+    EPIE_IMMUTABLE_IDIOMS_TAGS_URL,
+    EPIE_MUTABLE_IDIOMS_URL,
+    EPIE_MUTABLE_IDIOMS_CONTEXTS_URL,
+    EPIE_MUTABLE_IDIOMS_TAGS_URL
+)
+from idiomify.builders import Idiom2SubwordsBuilder
 from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer


+# sources for dataset
+def fetch_epie() -> List[Tuple[str, str, str]]:
+    idioms = requests.get(EPIE_IMMUTABLE_IDIOMS_URL).text \
+             + requests.get(EPIE_MUTABLE_IDIOMS_URL).text
+    contexts = requests.get(EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL).text \
+               + requests.get(EPIE_MUTABLE_IDIOMS_CONTEXTS_URL).text
+    tags = requests.get(EPIE_IMMUTABLE_IDIOMS_TAGS_URL).text \
+           + requests.get(EPIE_MUTABLE_IDIOMS_TAGS_URL).text
+    return list(zip(idioms.strip().split("\n"),
+                    contexts.strip().split("\n"),
+                    tags.strip().split("\n")))
+
+
+# you should somehow get this from... wandb.
+def fetch_idiom2context(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+    """
+    include run if you want to track the lineage
+    """
+    if run:
+        pass
+
+
 # dataset
 def fetch_idiom2def(ver: str) -> List[Tuple[str, str]]:
     artifact = wandb.Api().artifact(f"eubinecto/idiomify-demo/idiom2def:{ver}", type="dataset")
@@ -45,7 +79,7 @@ def fetch_rd(model: str, ver: str) -> RD:
     ckpt_path = artifact_path / "rd.ckpt"
     idioms = fetch_idioms(config['idioms_ver'])
     tokenizer = BertTokenizer.from_pretrained(config['bert'])
-    idiom2subwords =
+    idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, config['k'])
     if model == Alpha.name():
         rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
     elif model == Gamma.name():
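
Editor's note: a quick sanity check of the new fetch_epie(), sketched here assuming the EPIE GitHub URLs above are reachable; the printed values are illustrative, not recorded output:

# Hypothetical sanity check for fetch_epie(); nothing here is recorded output.
from idiomify.fetchers import fetch_epie

epie = fetch_epie()            # List[Tuple[str, str, str]]
idiom, context, tag = epie[0]  # each row zips one line from each of the three EPIE files
print(len(epie))               # number of parallel (idiom, context, tag) rows
print(idiom, "|", context[:60], "|", tag[:60])
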
idiomify/models.py
CHANGED
@@ -174,101 +174,3 @@ class Alpha(RD):
         H_k = self.H_k(H_all)  # (N, L, H) -> (N, K, H)
         S_wisdom = self.S_wisdom_literal(H_k)  # (N, K, H) -> (N, |W|)
         return S_wisdom
-
-
-class BiLSTMPooler(torch.nn.Module):
-    def __init__(self, hidden_size: int):
-        super().__init__()
-        self.lstm = torch.nn.LSTM(input_size=hidden_size, hidden_size=hidden_size // 2, batch_first=True,
-                                  num_layers=1, bidirectional=True)
-
-    def forward(self, X: torch.Tensor) -> torch.Tensor:
-        hiddens, _ = self.lstm(X)
-        return hiddens[:, -1]
-
-
-class Gamma(RD):
-    """
-    @eubinecto
-    S_wisdom = S_wisdom_literal + S_wisdom_figurative
-    but the way we get S_wisdom_figurative is much simplified, compared with RDBeta.
-    """
-
-    def __init__(self, mlm: BertForMaskedLM, idiom2subwords: torch.Tensor, k: int, lr: float):
-        super().__init__(mlm, idiom2subwords, k, lr)
-        # a pooler is a multilayer perceptron that pools wisdom_embeddings from idiom2subwords_embeddings
-        self.pooler = BiLSTMPooler(self.mlm.config.hidden_size)
-        # --- to be used to compute attentions --- #
-        self.attention_mask: Optional[torch.Tensor] = None
-
-    def forward(self, X: torch.Tensor) -> torch.Tensor:
-        """
-        :param X: (N, 4, L);
-        (num samples, 0=input_ids/1=token_type_ids/2=attention_mask/3=wisdom_mask, the maximum length)
-        :return: (N, L, H); (num samples, k, the size of the vocabulary of subwords)
-        """
-        input_ids = X[:, 0]  # (N, 4, L) -> (N, L)
-        token_type_ids = X[:, 1]  # (N, 4, L) -> (N, L)
-        self.attention_mask = X[:, 2]  # (N, 4, L) -> (N, L)
-        self.wisdom_mask = X[:, 3]  # (N, 4, L) -> (N, L)
-        self.desc_mask = X[:, 4]  # (N, 4, L) -> (N, L)
-        H_all = self.mlm.bert.forward(input_ids, self.attention_mask, token_type_ids)[0]  # (N, 3, L) -> (N, L, H)
-        return H_all
-
-    def H_desc_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
-        """
-        this is needed mask the padding tokens
-        :param attention_mask: (N, L)
-        """
-        N, L = attention_mask.size()
-        H_desc_attention_mask = torch.masked_select(attention_mask, self.desc_mask.bool())
-        H_desc_attention_mask = H_desc_attention_mask.reshape(N, L - (self.hparams['k'] + 3))
-        return H_desc_attention_mask
-
-    def S_wisdom(self, H_all: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        S_wisdom_literal = self.S_wisdom_literal(self.H_k(H_all))
-        S_wisdom_figurative = self.S_wisdom_figurative(H_all)
-        S_wisdom = S_wisdom_literal + S_wisdom_figurative
-        return S_wisdom, S_wisdom_literal, S_wisdom_figurative
-
-    def S_wisdom_figurative(self, H_all: torch.Tensor) -> torch.Tensor:
-        # --- draw the embeddings for wisdoms from the embeddings of idiom2subwords -- #
-        # this is to use as less of newly initialised weights as possible
-        idiom2subwords_embeddings = self.mlm.bert \
-            .embeddings.word_embeddings(self.idiom2subwords)  # (W, K) -> (W, K, H)
-        wisdom_embeddings = self.pooler(idiom2subwords_embeddings).squeeze()  # (W, H, K) -> (W, H, 1) -> (W, H)
-        # --- draw H_wisdom from H_desc with attention --- #
-        H_cls = H_all[:, 0]  # (N, L, H) -> (N, H)
-        H_desc = self.H_desc(H_all)  # (N, L, H) -> (N, D, H)
-        H_desc_attention_mask = self.H_desc_attention_mask(self.attention_mask)  # (N, L) -> (N, D)
-        scores = torch.einsum("...h,...dh->...d", H_cls, H_desc)  # (N, D)
-        # ignore the padding tokens
-        scores = torch.masked_fill(scores, H_desc_attention_mask != 1, float("-inf"))  # (N, D)
-        attentions = torch.softmax(scores, dim=1)  # over D
-        H_wisdom = torch.einsum("...d,...dh->...h", attentions, H_desc)  # -> (N, H)
-        # --- now compare H_wisdom with all the wisdoms --- #
-        S_wisdom_figurative = torch.einsum("...h,wh->...w", H_wisdom, wisdom_embeddings)  # (N, H) * (W, H) -> (N, W)
-        return S_wisdom_figurative
-
-    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
-        X, y = batch
-        H_all = self.forward(X)  # (N, 3, L) -> (N, L, H)
-        S_wisdom, S_wisdom_literal, S_wisdom_figurative = self.S_wisdom(H_all)  # (N, L, H) -> (N, |W|)
-        loss_all = F.cross_entropy(S_wisdom, y).sum()  # (N, |W|), (N,) -> (N,) -> (1,)
-        loss_literal = F.cross_entropy(S_wisdom_literal, y).sum()  # (N, |W|), (N,) -> (N,) -> (1,)
-        loss_figurative = F.cross_entropy(S_wisdom_figurative, y).sum()  # (N, |W|), (N,) -> (N,) -> (1,)
-        loss = loss_all + loss_literal + loss_figurative  # unweighted multi-task learning
-        return {
-            # you cannot change the keyword for the loss
-            "loss": loss,
-        }
-
-    def P_wisdom(self, X: torch.Tensor) -> torch.Tensor:
-        """
-        :param X: (N, 3, L)
-        :return P_wisdom: (N, |W|), normalized over dim 1.
-        """
-        H_all = self.forward(X)  # (N, 3, L) -> (N, L, H)
-        S_wisdom, _, _ = self.S_wisdom(H_all)  # (N, L, H) -> (N, W)
-        P_wisdom = F.softmax(S_wisdom, dim=1)  # (N, W) -> (N, W)
-        return P_wisdom
idiomify/tensors.py
DELETED
@@ -1,56 +0,0 @@
-"""
-all the functions for building tensors are defined here.
-builders must accept device as one of the parameters.
-"""
-import torch
-from typing import List, Tuple
-from transformers import BertTokenizer
-
-
-def idiom2subwords(idioms: List[str], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
-    mask_id = tokenizer.mask_token_id
-    pad_id = tokenizer.pad_token_id
-    # temporarily disable single-token status of the idioms
-    idioms = [idiom.split(" ") for idiom in idioms]
-    encodings = tokenizer(text=idioms,
-                          add_special_tokens=False,
-                          # should set this to True, as we already have the idioms split.
-                          is_split_into_words=True,
-                          padding='max_length',
-                          max_length=k,  # set to k
-                          return_tensors="pt")
-    input_ids = encodings['input_ids']
-    input_ids[input_ids == pad_id] = mask_id  # replace them with masks
-    return input_ids
-
-
-def inputs(idiom2def: List[Tuple[str, str]], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
-    defs = [definition for _, definition in idiom2def]
-    lefts = [" ".join(["[MASK]"] * k)] * len(defs)
-    encodings = tokenizer(text=lefts,
-                          text_pair=defs,
-                          return_tensors="pt",
-                          add_special_tokens=True,
-                          truncation=True,
-                          padding=True,
-                          verbose=True)
-    input_ids: torch.Tensor = encodings['input_ids']
-    cls_id: int = tokenizer.cls_token_id
-    sep_id: int = tokenizer.sep_token_id
-    mask_id: int = tokenizer.mask_token_id
-
-    wisdom_mask = torch.where(input_ids == mask_id, 1, 0)
-    desc_mask = torch.where(((input_ids != cls_id) & (input_ids != sep_id) & (input_ids != mask_id)), 1, 0)
-    return torch.stack([input_ids,
-                        encodings['token_type_ids'],
-                        encodings['attention_mask'],
-                        wisdom_mask,
-                        desc_mask], dim=1)
-
-
-def targets(idiom2def: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
-    return torch.LongTensor([
-        idioms.index(idiom)
-        for idiom, _ in idiom2def
-    ])
-
idiomify/urls.py
ADDED
@@ -0,0 +1,11 @@
+
+# EPIE dataset
+EPIE_IMMUTABLE_IDIOMS_TAGS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Static_Idioms_Corpus/Static_Idioms_Tags.txt"  # noqa
+EPIE_IMMUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Static_Idioms_Corpus/Static_Idioms_Candidates.txt"  # noqa
+EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Static_Idioms_Corpus/Static_Idioms_Words.txt"  # noqa
+EPIE_MUTABLE_IDIOMS_TAGS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Tags.txt"  # noqa
+EPIE_MUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Candidates.txt"  # noqa
+EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Corpus/blob/master/Formal_Idioms_Corpus/Formal_Idioms_Words.txt"  # noqa
+
+
+
main_train.py
CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from termcolor import colored
 from transformers import BertForMaskedLM, BertTokenizer
-from idiomify.datamodules import
+from idiomify.datamodules import Idiom2DefDataModule
 from idiomify.fetchers import fetch_config, fetch_idioms
 from idiomify.models import Alpha, Gamma
 from idiomify.paths import ROOT_DIR
@@ -40,7 +40,7 @@ def main():
     else:
         raise ValueError
     # prepare datamodule
-    datamodule =
+    datamodule = Idiom2DefDataModule(config, tokenizer, idioms)

     with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
         logger = WandbLogger(log_model=False)
main_upload_idiom2context.py
ADDED
@@ -0,0 +1,11 @@
+"""
+Build and upload an idiom2context dataset to wandb.
+"""
+
+
+def main():
+    pass
+
+
+if __name__ == '__main__':
+    main()
main_upload_idioms.py
ADDED
@@ -0,0 +1,13 @@
+"""
+Here,
+ver a: Compatible with the first version
+ver b:
+"""
+
+
+def main():
+    pass
+
+
+if __name__ == '__main__':
+    main()
main_upload_tokenizer.py
ADDED
@@ -0,0 +1,13 @@
+"""
+Build & upload a tokenizer to wandb.
+You need this if you were to add more tokens there.
+"""
+
+
+def main():
+    pass
+    # TODO: fetch the dataset from wandb first!
+
+
+if __name__ == '__main__':
+    main()
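
Editor's note: main_upload_tokenizer.py is still a stub. One possible shape for it, sketched purely as an illustration of the wandb artifact workflow the repo already uses elsewhere; the artifact name, artifact type, extra token, and local path are placeholders, not choices made in this commit:

# Hypothetical sketch only; artifact name/type, the added token, and paths are placeholders.
import wandb
from transformers import BertTokenizer


def main():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["[IDIOM]"])  # example of "adding more tokens there"
    tokenizer.save_pretrained("artifacts/tokenizer")  # serialise to a local directory
    with wandb.init(entity="eubinecto", project="idiomify-demo") as run:
        artifact = wandb.Artifact("tokenizer", type="other")
        artifact.add_dir("artifacts/tokenizer")
        run.log_artifact(artifact, aliases=["latest"])


if __name__ == '__main__':
    main()
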
requirements.txt
CHANGED
@@ -1,66 +1,3 @@
-
-
-
-async-timeout==4.0.2
-attrs==21.4.0
-cachetools==4.2.4
-certifi==2021.10.8
-charset-normalizer==2.0.10
-click==8.0.3
-configparser==5.2.0
-docker-pycreds==0.4.0
-filelock==3.4.2
-frozenlist==1.3.0
-fsspec==2022.1.0
-future==0.18.2
-gitdb==4.0.9
-GitPython==3.1.26
-google-auth==2.3.3
-google-auth-oauthlib==0.4.6
-grpcio==1.43.0
-huggingface-hub==0.4.0
-idna==3.3
-importlib-metadata==4.10.1
-joblib==1.1.0
-Markdown==3.3.6
-multidict==5.2.0
-numpy==1.22.1
-oauthlib==3.1.1
-packaging==21.3
-pathtools==0.1.2
-promise==2.3
-protobuf==3.19.3
-psutil==5.9.0
-pyasn1==0.4.8
-pyasn1-modules==0.2.8
-pyDeprecate==0.3.1
-pyparsing==3.0.6
-python-dateutil==2.8.2
-pytorch-lightning==1.5.8
-PyYAML==6.0
-regex==2022.1.18
-requests==2.27.1
-requests-oauthlib==1.3.0
-rsa==4.8
-sacremoses==0.0.47
-sentry-sdk==1.5.2
-shortuuid==1.0.8
-six==1.16.0
-smmap==5.0.0
-subprocess32==3.5.4
-tensorboard==2.7.0
-tensorboard-data-server==0.6.1
-tensorboard-plugin-wit==1.8.1
-termcolor==1.1.0
-tokenizers==0.10.3
-torch==1.10.1
-torchmetrics==0.7.0
-tqdm==4.62.3
-transformers==4.15.0
-typing_extensions==4.0.1
-urllib3==1.26.8
-wandb==0.12.9
-Werkzeug==2.0.2
-yarl==1.7.2
-yaspin==2.1.0
-zipp==3.7.0
+pytorch-lightning==1.5.10
+transformers==4.16.2
+wandb==0.12.10