File size: 932 Bytes
08409ff
 
 
6fd648a
08409ff
e3c7b5a
6fd648a
 
 
 
 
08409ff
64a6414
 
6fd648a
642d911
6fd648a
e3c7b5a
 
6fd648a
642d911
08409ff
 
64a6414
6fd648a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
A simple sanity check of idiomifier inference: run the model on one sentence and print the result.
"""
import argparse
from idiomify.pipeline import Pipeline
from idiomify.fetchers import fetch_config, fetch_idiomifier
from transformers import BartTokenizer


def main():
    """Idiomify a single sentence and print the input alongside the output.

    The sentence comes from ``--sent`` (a built-in example sentence by default).
    Parsed command-line arguments are merged into the ``idiomifier`` section of
    the fetched config; that section also supplies the model version (``ver``)
    and the tokenizer name passed to ``BartTokenizer.from_pretrained``
    (``bart``). Nothing is returned; the result is printed to stdout.
    """
    default_sent = ("If there's any good to loosing my job,"
                    " it's that I'll now be able to go to school full-time and finish my degree earlier.")
    cli = argparse.ArgumentParser()
    cli.add_argument("--sent", type=str, default=default_sent)
    # Parse before fetching anything, so `--help` or a bad flag exits early.
    cli_args = vars(cli.parse_args())
    settings = fetch_config()['idiomifier']
    # Command-line values override config entries with the same key.
    settings.update(cli_args)
    idiomifier = fetch_idiomifier(settings['ver'])
    # Put the model in evaluation mode before generating (the original author
    # marked this as crucial). Presumably a torch module, so eval() turns off
    # training-only behaviour such as dropout — TODO confirm.
    idiomifier.eval()
    tokenizer = BartTokenizer.from_pretrained(settings['bart'])
    sentence = settings['sent']
    idiomified = Pipeline(idiomifier, tokenizer)(sents=[sentence])
    print(sentence, "\n->", idiomified)


# Run the sanity check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()