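"""Train a new SentencePiece BPE tokenizer on a local JSON dataset and wrap it as a LlamaTokenizerFast.

The script trains a fresh vocabulary with the `tokenizers` library, copies the
normalizer, pre-tokenizer, post-processor, decoder and model flags from a
reference Llama tokenizer so the new tokenizer processes text the same way, and
saves the result in Hugging Face format.
"""
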
import argparse
import json
import os

from datasets import load_dataset
from tokenizers import SentencePieceBPETokenizer
from transformers import AutoTokenizer, LlamaTokenizerFast

def main(args):

	# Load the dataset from a local directory of JSON files and prepare it for training
	if args.dataset_name is not None:
		data_files = [os.path.join(args.dataset_name, f) for f in os.listdir(args.dataset_name)]
		print(f"Found {len(data_files)} data files")
		dataset = load_dataset(
			"json",
			data_files=data_files,
			split=args.dataset_split,
			token=args.hub_token,
		)
		print(dataset)

	else:
		raise ValueError("No dataset directory provided")

	# Remove non text columns
	dataset = dataset.remove_columns([col for col in dataset.column_names if col != "text"])

	# Shuffle the dataset and, if requested, keep only `num_samples` examples
	dataset = dataset.shuffle(seed=args.seed)
	if args.num_samples is not None:
		dataset = dataset.select(range(min(args.num_samples, len(dataset))))

	# Create a SentencePieceBPETokenizer
	tokenizer = SentencePieceBPETokenizer()

	# Train the SentencePieceBPETokenizer on the dataset
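	# The special tokens are assigned IDs in the order listed: <unk>=0, <s>=1,
	# </s>=2, <pad>=3, matching the conventional Llama special-token IDs.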
	tokenizer.train_from_iterator(
		iterator=dataset['text'],
		vocab_size=args.vocab_size,
		show_progress=True,
		special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
	)

	# Save the tokenizer
	tokenizer.save("new-sentencepiece-tokenizer.json", pretty=True)

	# Load reference tokenizer
	if args.reference_tokenizer is not None and args.hub_token is not None:
		reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_tokenizer, token=args.hub_token)
		reference_tokenizer.save_pretrained("reference-tokenizer")
	else:
		raise ValueError("No reference tokenizer name or no hub token provided. Try using `--reference_tokenizer 'meta-llama/Llama-2-7b-hf'`")

	# Read and dump the json file for the new tokenizer and the reference tokenizer
	with open("new-sentencepiece-tokenizer.json") as f:
		new_llama_tokenizer_json = json.load(f)

	with open("reference-tokenizer/tokenizer.json") as f:
		reference_tokenizer_json = json.load(f)

	# Add the reference tokenizer's config to the new tokenizer's config
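	# so that the new tokenizer normalizes, pre-tokenizes, post-processes and decodes
	# text the same way as the reference Llama tokenizer, with matching fuse_unk and
	# byte_fallback model flags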
	new_llama_tokenizer_json["normalizer"] = reference_tokenizer_json["normalizer"]
	new_llama_tokenizer_json["pre_tokenizer"] = reference_tokenizer_json["pre_tokenizer"]
	new_llama_tokenizer_json["post_processor"] = reference_tokenizer_json["post_processor"]
	new_llama_tokenizer_json["decoder"] = reference_tokenizer_json["decoder"]
	new_llama_tokenizer_json["model"]['fuse_unk'] = reference_tokenizer_json["model"]['fuse_unk']
	new_llama_tokenizer_json["model"]['byte_fallback'] = reference_tokenizer_json["model"]['byte_fallback']

	# Dump the new tokenizer's config
	with open("new-sentencepiece-tokenizer.json", "w") as f:
		json.dump(new_llama_tokenizer_json, f, indent=2, ensure_ascii=False)

	# Load the new tokenizer as a LlamaTokenizerFast
	# (the special token IDs are already fixed by the training order above)
	new_llama_tokenizer = LlamaTokenizerFast(
		tokenizer_file="new-sentencepiece-tokenizer.json",
		name_or_path=args.reference_tokenizer + "-tokenizer",
		unk_token="<unk>",
		bos_token="<s>",
		eos_token="</s>",
		pad_token="<pad>",
		padding_side="right",
	)

	# Save the new tokenizer
	new_llama_tokenizer.save_pretrained("new-llama-tokenizer")
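	# The saved tokenizer can later be loaded with
	# AutoTokenizer.from_pretrained("new-llama-tokenizer")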

if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="Train a new Llama tokenizer")
	parser.add_argument(
		"--dataset_name",
		type=str,
		default=None,
		help="The name of the dataset to be tokenized",
	)
	parser.add_argument(
		"--dataset_split",
		type=str,
		default=None,
		help="The split of the dataset to be tokenized",
	)
	parser.add_argument(
		"--hub_token",
		type=str,
		default=None,
		help="The token to access the dataset on the hub",
	)
	parser.add_argument(
		"--reference_tokenizer",
		type=str,
		default=None,
		help="The name of the reference tokenizer to use",
	)
	parser.add_argument(
		"--seed",
		type=int,
		default=123,
		help="set random seed",
	)
	parser.add_argument(
		"--num_samples",
		type=int,
		default=None,
		help="Number of samples to use from the dataset",
	)
	parser.add_argument(
		"--vocab_size",
		type=int,
		default=None,
		help="Vocabulary size to use for the tokenizer",
	)
	args = parser.parse_args()
	main(args)

# How to run:
# python tokenizer_train.py --dataset_name /mimir/dataset/delivery/mimir_base/data/ --dataset_split train --reference_tokenizer meta-llama/Llama-2-7b-hf --vocab_size 32768 --hub_token hf_IIbKlx.... --num_samples 6000000