[#1] main_infer.py implemented
Browse files- idiomify/fetchers.py +1 -2
- idiomify/idiomifier.py +22 -0
- idiomify/models.py +0 -8
- main_infer.py +28 -37
idiomify/fetchers.py
CHANGED
@@ -95,8 +95,7 @@ def fetch_alpha(ver: str, run: Run = None) -> Alpha:
|
|
95 |
artifact_dir = artifact.download(root=alpha_dir(ver))
|
96 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
97 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
98 |
-
|
99 |
-
alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
|
100 |
return alpha
|
101 |
|
102 |
|
|
|
95 |
artifact_dir = artifact.download(root=alpha_dir(ver))
|
96 |
ckpt_path = path.join(artifact_dir, "model.ckpt")
|
97 |
bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
|
98 |
+
alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
|
|
|
99 |
return alpha
|
100 |
|
101 |
|
idiomify/idiomifier.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BartTokenizer
|
2 |
+
from builders import SourcesBuilder
|
3 |
+
from models import Alpha
|
4 |
+
|
5 |
+
|
6 |
+
class Idiomifier:
    """Turns a literal English sentence into its idiomatic paraphrase
    using a trained Alpha (BART seq2seq) model."""

    def __init__(self, model: Alpha, tokenizer: BartTokenizer):
        self.model = model
        self.builder = SourcesBuilder(tokenizer)
        # inference only — disable dropout etc.
        self.model.eval()

    def __call__(self, src: str, max_length=100) -> str:
        # the builder expects (literal, idiomatic) pairs; the target side is unused here
        encoded = self.builder(literal2idiomatic=[(src, "")])
        input_ids, attention_mask = encoded[:, 0], encoded[:, 1]  # (N, 2, L) -> two (N, L)
        pred_ids = self.model.bart.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=self.model.hparams['bos_token_id'],
            max_length=max_length,
        ).squeeze()  # (N, L_t) -> (L_t), since N == 1
        return self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
|
idiomify/models.py
CHANGED
@@ -47,14 +47,6 @@ class Alpha(pl.LightningModule): # noqa
|
|
47 |
    def on_train_batch_end(self, outputs: dict, *args, **kwargs):
        # Lightning hook: after each training batch, log that batch's loss
        # (as returned in `outputs` by training_step) under "Train/Loss".
        self.log("Train/Loss", outputs['loss'])
|
49 |
|
50 |
-
def predict(self, srcs: torch.Tensor) -> torch.Tensor:
|
51 |
-
pred_ids = self.bart.generate(
|
52 |
-
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|
53 |
-
attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
|
54 |
-
decoder_start_token_id=self.hparams['bos_token_id'],
|
55 |
-
)
|
56 |
-
return pred_ids # (N, L)
|
57 |
-
|
58 |
def configure_optimizers(self) -> torch.optim.Optimizer:
|
59 |
"""
|
60 |
Instantiates and returns the optimizer to be used for this model
|
|
|
47 |
def on_train_batch_end(self, outputs: dict, *args, **kwargs):
|
48 |
self.log("Train/Loss", outputs['loss'])
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def configure_optimizers(self) -> torch.optim.Optimizer:
|
51 |
"""
|
52 |
Instantiates and returns the optimizer to be used for this model
|
main_infer.py
CHANGED
@@ -1,37 +1,28 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
# # sort and append
|
30 |
-
# res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
|
31 |
-
# print(f"query: {colored(text=config['sent'], color='blue')}")
|
32 |
-
# for idx, (idiom, prob) in enumerate(res):
|
33 |
-
# print(idx, idiom, prob)
|
34 |
-
#
|
35 |
-
#
|
36 |
-
# if __name__ == '__main__':
|
37 |
-
# main()
|
|
|
1 |
+
import argparse
|
2 |
+
from termcolor import colored
|
3 |
+
from idiomifier import Idiomifier
|
4 |
+
from idiomify.fetchers import fetch_config, fetch_alpha
|
5 |
+
from transformers import BartTokenizer
|
6 |
+
|
7 |
+
|
8 |
+
def main():
    """CLI entry point: fetch a trained Alpha model and print the idiomatic
    paraphrase of the sentence given via --src."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str,
                        default="alpha")
    parser.add_argument("--ver", type=str,
                        default="overfit")
    parser.add_argument("--src", type=str,
                        default="If there's any benefits to losing my job, it's that I'll now be able to go to school full-time and finish my degree earlier.")
    args = parser.parse_args()
    # CLI flags override the fetched configuration for this model/version
    config = fetch_config()[args.model][args.ver]
    config.update(vars(args))
    model = fetch_alpha(config['ver'])
    tokenizer = BartTokenizer.from_pretrained(config['bart'])
    idiomifier = Idiomifier(model, tokenizer)
    src = config['src']
    # fix: reuse the local binding instead of a second, redundant dict lookup
    tgt = idiomifier(src=src)
    print(src, "\n->", colored(tgt, "blue"))


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|