eubinecto committed on
Commit
6fd648a
1 Parent(s): 322e083

[#1] main_infer.py implemented

Browse files
idiomify/fetchers.py CHANGED
@@ -95,8 +95,7 @@ def fetch_alpha(ver: str, run: Run = None) -> Alpha:
95
  artifact_dir = artifact.download(root=alpha_dir(ver))
96
  ckpt_path = path.join(artifact_dir, "model.ckpt")
97
  bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
98
- with open(ckpt_path, 'r') as fh:
99
- alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
100
  return alpha
101
 
102
 
 
95
  artifact_dir = artifact.download(root=alpha_dir(ver))
96
  ckpt_path = path.join(artifact_dir, "model.ckpt")
97
  bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
98
+ alpha = Alpha.load_from_checkpoint(ckpt_path, bart=bart)
 
99
  return alpha
100
 
101
 
idiomify/idiomifier.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartTokenizer
2
+ from builders import SourcesBuilder
3
+ from models import Alpha
4
+
5
+
6
class Idiomifier:
    """
    Wraps a trained Alpha model and a BART tokenizer to rewrite a literal
    English sentence into its idiomatic paraphrase.
    """

    def __init__(self, model: Alpha, tokenizer: BartTokenizer):
        self.model = model
        self.builder = SourcesBuilder(tokenizer)
        # inference only — switch off dropout / batch-norm training behaviour
        self.model.eval()

    def __call__(self, src: str, max_length=100) -> str:
        """Return the idiomatic rewriting of `src`, up to `max_length` tokens."""
        # build a single-example batch; the target side is left empty on purpose
        encoded = self.builder(literal2idiomatic=[(src, "")])
        input_ids = encoded[:, 0]        # (N, 2, L) -> (N, L)
        attention_mask = encoded[:, 1]   # (N, 2, L) -> (N, L)
        generated = self.model.bart.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=self.model.hparams['bos_token_id'],
            max_length=max_length,
        )
        # (N, L_t) -> (L_t); N is 1 here since exactly one sentence went in
        token_ids = generated.squeeze()
        return self.builder.tokenizer.decode(token_ids, skip_special_tokens=True)
idiomify/models.py CHANGED
@@ -47,14 +47,6 @@ class Alpha(pl.LightningModule): # noqa
47
  def on_train_batch_end(self, outputs: dict, *args, **kwargs):
48
  self.log("Train/Loss", outputs['loss'])
49
 
50
- def predict(self, srcs: torch.Tensor) -> torch.Tensor:
51
- pred_ids = self.bart.generate(
52
- inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
53
- attention_mask=srcs[:, 1], # (N, 2, L) -> (N, L)
54
- decoder_start_token_id=self.hparams['bos_token_id'],
55
- )
56
- return pred_ids # (N, L)
57
-
58
  def configure_optimizers(self) -> torch.optim.Optimizer:
59
  """
60
  Instantiates and returns the optimizer to be used for this model
 
47
  def on_train_batch_end(self, outputs: dict, *args, **kwargs):
48
  self.log("Train/Loss", outputs['loss'])
49
 
 
 
 
 
 
 
 
 
50
  def configure_optimizers(self) -> torch.optim.Optimizer:
51
  """
52
  Instantiates and returns the optimizer to be used for this model
main_infer.py CHANGED
@@ -1,37 +1,28 @@
1
- # we disable them for now.
2
- # import argparse
3
- # from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
4
- # from transformers import BertTokenizer
5
- # from termcolor import colored
6
- #
7
- #
8
- # def main():
9
- # parser = argparse.ArgumentParser()
10
- # parser.add_argument("--model", type=str,
11
- # default="alpha")
12
- # parser.add_argument("--ver", type=str,
13
- # default="eng2eng")
14
- # parser.add_argument("--sent", type=str,
15
- # default="to avoid getting to the point")
16
- # args = parser.parse_args()
17
- # config = fetch_config()[args.model][args.ver]
18
- # config.update(vars(args))
19
- # idioms = fetch_idioms(config['idioms_ver'])
20
- # rd = fetch_rd(config['model'], config['ver'])
21
- # rd.eval()
22
- # tokenizer = BertTokenizer.from_pretrained(config['bert'])
23
- # X = T.inputs([("", config['sent'])], tokenizer, config['k'])
24
- # probs = rd.P_wisdom(X).squeeze().tolist()
25
- # wisdom2prob = [
26
- # (wisdom, prob)
27
- # for wisdom, prob in zip(idioms, probs)
28
- # ]
29
- # # sort and append
30
- # res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
31
- # print(f"query: {colored(text=config['sent'], color='blue')}")
32
- # for idx, (idiom, prob) in enumerate(res):
33
- # print(idx, idiom, prob)
34
- #
35
- #
36
- # if __name__ == '__main__':
37
- # main()
 
1
+ import argparse
2
+ from termcolor import colored
3
+ from idiomifier import Idiomifier
4
+ from idiomify.fetchers import fetch_config, fetch_alpha
5
+ from transformers import BartTokenizer
6
+
7
+
8
def main():
    """
    CLI entry point: fetches the versioned Alpha checkpoint and its config,
    builds an Idiomifier, and prints the idiomatic rewriting of --src.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str,
                        default="alpha")
    parser.add_argument("--ver", type=str,
                        default="overfit")
    parser.add_argument("--src", type=str,
                        default="If there's any benefits to losing my job, it's that I'll now be able to go to school full-time and finish my degree earlier.")
    args = parser.parse_args()
    config = fetch_config()[args.model][args.ver]
    # CLI arguments override the fetched config (adds 'src', 'model'; replaces 'ver')
    config.update(vars(args))
    model = fetch_alpha(config['ver'])
    tokenizer = BartTokenizer.from_pretrained(config['bart'])
    idiomifier = Idiomifier(model, tokenizer)
    src = config['src']
    # fix: reuse the local `src` instead of a redundant second config lookup
    tgt = idiomifier(src=src)
    print(src, "\n->", colored(tgt, "blue"))


if __name__ == '__main__':
    main()