|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals |
|
|
|
import argparse |
|
|
|
import sentencepiece as spm |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--model", required=True, help="sentencepiece model to use for decoding" |
|
) |
|
parser.add_argument("--input", required=True, help="input file to decode") |
|
parser.add_argument("--input_format", choices=["piece", "id"], default="piece") |
|
args = parser.parse_args() |
|
|
|
sp = spm.SentencePieceProcessor() |
|
sp.Load(args.model) |
|
|
|
if args.input_format == "piece": |
|
|
|
def decode(l): |
|
return "".join(sp.DecodePieces(l)) |
|
|
|
elif args.input_format == "id": |
|
|
|
def decode(l): |
|
return "".join(sp.DecodeIds(l)) |
|
|
|
else: |
|
raise NotImplementedError |
|
|
|
def tok2int(tok): |
|
|
|
return int(tok) if tok != "<<unk>>" else 0 |
|
|
|
with open(args.input, "r", encoding="utf-8") as h: |
|
for line in h: |
|
if args.input_format == "id": |
|
print(decode(list(map(tok2int, line.rstrip().split())))) |
|
elif args.input_format == "piece": |
|
print(decode(line.rstrip().split())) |
|
|
|
|
|
# Script entry point: decode only when run directly, never on import.
if "__main__" == __name__:
    main()
|
|