# idiomify / main_infer.py
# Commit e9d1a5a by eubinecto: "[#1] checkpoint before amending builders.py" (file size 1.46 kB)
# This inference script is disabled for now; the original code is kept below, commented out.
# import argparse
# from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
# from transformers import BertTokenizer
# from termcolor import colored
#
#
# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model", type=str,
#                         default="alpha")
#     parser.add_argument("--ver", type=str,
#                         default="eng2eng")
#     parser.add_argument("--sent", type=str,
#                         default="to avoid getting to the point")
#     args = parser.parse_args()
#     config = fetch_config()[args.model][args.ver]
#     config.update(vars(args))
#     idioms = fetch_idioms(config['idioms_ver'])
#     rd = fetch_rd(config['model'], config['ver'])
#     rd.eval()
#     tokenizer = BertTokenizer.from_pretrained(config['bert'])
#     # NOTE(review): `T` is not imported in this file, so the next line raises
#     # NameError if uncommented as-is; restore its import before re-enabling.
#     X = T.inputs([("", config['sent'])], tokenizer, config['k'])
#     probs = rd.P_wisdom(X).squeeze().tolist()
#     wisdom2prob = [
#         (wisdom, prob)
#         for wisdom, prob in zip(idioms, probs)
#     ]
#     # sort (idiom, prob) pairs by probability, highest first
#     res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
#     print(f"query: {colored(text=config['sent'], color='blue')}")
#     for idx, (idiom, prob) in enumerate(res):
#         print(idx, idiom, prob)
#
#
# if __name__ == '__main__':
#     main()