"""
Merge a Chinese SentencePiece vocabulary into the original LLaMA tokenizer.
https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/scripts/merge_tokenizer/merge_tokenizers.py

Usage:
python merge_tokenizers.py \
    --llama_tokenizer_dir llama_tokenizer_dir \
    --chinese_sp_model_file chinese_sp_model_file
"""

import os

# Force the pure-Python protobuf implementation; with some protobuf versions the
# generated SentencePiece proto module cannot be loaded under the C++ backend.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from transformers import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
import sentencepiece as spm
import argparse

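# Command-line arguments: directory of the original LLaMA tokenizer and path to
# the Chinese SentencePiece model whose vocabulary will be merged in.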
parser = argparse.ArgumentParser()
parser.add_argument('--llama_tokenizer_dir', default="../../llama/tokenizer", type=str)
parser.add_argument('--chinese_sp_model_file', default='chinese_sp.model', type=str)
args = parser.parse_args()

llama_tokenizer_dir = args.llama_tokenizer_dir
chinese_sp_model_file = args.chinese_sp_model_file

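# Load the LLaMA tokenizer and the Chinese SentencePiece model, then parse both
# underlying SentencePiece model protos so their pieces can be manipulated.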
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load(chinese_sp_model_file)

llama_spm = sp_pb2_model.ModelProto()
llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
chinese_spm = sp_pb2_model.ModelProto()
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())

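# Report the vocabulary sizes of both tokenizers and LLaMA's special tokens.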
print(len(llama_tokenizer), len(chinese_sp_model))
print(llama_tokenizer.all_special_tokens)
print(llama_tokenizer.all_special_ids)
print(llama_tokenizer.special_tokens_map)

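# Merge: append every Chinese piece that is not already in the LLaMA vocabulary.
# Newly added pieces get a score of 0.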
llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
print(f"Before merge: {len(llama_spm_tokens_set)} pieces")
for p in chinese_spm.pieces:
    piece = p.piece
    if piece not in llama_spm_tokens_set:
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        llama_spm.pieces.append(new_p)
print(f"New model pieces: {len(llama_spm.pieces)}")

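# Save the merged vocabulary as a SentencePiece model, then wrap it in a
# Hugging Face LlamaTokenizer and save that as well.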
output_sp_dir = 'merged_tokenizer_sp'
output_hf_dir = 'merged_tokenizer_hf'
os.makedirs(output_sp_dir, exist_ok=True)
with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
    f.write(llama_spm.SerializeToString())
tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')

tokenizer.save_pretrained(output_hf_dir)
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

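# Test: reload both tokenizers and compare how they tokenize a mixed
# Chinese/English sample.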
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
print(chinese_llama_tokenizer.all_special_tokens)
print(chinese_llama_tokenizer.all_special_ids)
print(chinese_llama_tokenizer.special_tokens_map)
text = '''白日依山尽,黄河入海流。欲穷千里目,更上一层楼。
The primary use of LLaMA is research on large language models, including'''
print("Test text:\n", text)
print(f"Tokenized by LLaMA tokenizer: {llama_tokenizer.tokenize(text)}")
print(f"Tokenized by Chinese-LLaMA tokenizer: {chinese_llama_tokenizer.tokenize(text)}")