"""Utility to add special tokens to existing sentencepiece models.

Before running, generate sentencepiece_model_pb2.py in the directory of this script:
run `protoc --python_out=<path_to_NeMo>/scripts/tokenizers/ sentencepiece_model.proto`
from the src folder of the sentencepiece repo.
Refer: https://github.com/google/sentencepiece/issues/121

Usage:
python edit_spt_model.py \
    --input_file <path to the input sentencepiece model file> \
    --output_file <path to the output sentencepiece model file> \
    --tokens <space separated special tokens>

Example:
python edit_spt_model.py \
    --input_file test.model \
    --output_file test.model \
    --tokens [CLS] [SEP]
"""
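
# After running this script, the result can be sanity-checked independently.
# A minimal sketch (assumes the edited model was written to test.model and
# that [CLS] was one of the added tokens):
#
#   import sentencepiece as spm
#   sp = spm.SentencePieceProcessor(model_file='test.model')
#   print(sp.piece_to_id('[CLS]'))   # id assigned to the new piece
#   print(sp.get_piece_size())       # updated vocabulary size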

import logging
import sys
from argparse import ArgumentParser

import sentencepiece as spm

try:
    import sentencepiece_model_pb2 as spt
except (ImportError, ModuleNotFoundError):
    raise Exception("Ensure that sentencepiece_model_pb2.py has been generated by the protoc compiler")


def edit_spt_model():
    parser = ArgumentParser()
    parser.add_argument(
        "--input_file", type=str, required=True, help="Path to the input sentencepiece model file",
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path to write the edited sentencepiece model file",
    )
    parser.add_argument(
        "--tokens", type=str, nargs='+', required=True, help="Special tokens to add to the tokenizer",
    )
    parser.add_argument(
        "--is_userdefined", action="store_true", help="When set, the new tokens are added as USER_DEFINED pieces instead of CONTROL pieces",
    )
    args = parser.parse_args()
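
    # For reference, the relevant piece-type enum in sentencepiece_model.proto
    # is (a shortened excerpt; see the sentencepiece repo for the full definition):
    #
    #   enum Type {
    #     NORMAL = 1; UNKNOWN = 2; CONTROL = 3;
    #     USER_DEFINED = 4; UNUSED = 5; BYTE = 6;
    #   }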

    # Default to CONTROL pieces; --is_userdefined switches to USER_DEFINED.
    token_type = 3
    if args.is_userdefined:
        token_type = 4

    model = spt.ModelProto()
    with open(args.input_file, 'rb') as f:
        model.ParseFromString(f.read())

    for token in args.tokens:
        piece = model.SentencePiece(piece=token, score=0.0, type=token_type)
        # Compare by piece text so a token that already exists with a different
        # score or type is also caught.
        if any(p.piece == token for p in model.pieces):
            logging.error(f"Special Token '{token}' already exists in the input model!")
            sys.exit(1)
        model.pieces.append(piece)

    # Sanity check: load the edited proto and confirm the new pieces resolve to ids.
    sp = spm.SentencePieceProcessor()
    try:
        sp.LoadFromSerializedProto(model.SerializeToString())
        for token in args.tokens:
            token_id = sp.piece_to_id(token)
            logging.info(f"Created token '{token}' at ID {token_id}")
        logging.info(f"New tokenizer vocab size: {sp.get_piece_size()}")
    except Exception:
        logging.error("Could not appropriately configure the new tokenizer. Verify that the special tokens do not already exist.")
        sys.exit(1)

    with open(args.output_file, 'wb') as outf:
        outf.write(model.SerializeToString())

    logging.info(f"Created new tokenizer at: {args.output_file}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    edit_spt_model()