"""
Minimal (byte-level) Byte Pair Encoding tokenizer.
Algorithmically follows along with the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
But:
- Does not handle the regular expression splitting pattern.
- Does not handle any special tokens.
"""
from .base import Tokenizer, get_stats, merge
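
# For intuition, one BPE merge step on a tiny byte sequence (illustrative example only):
#   "aaab".encode("utf-8")        -> [97, 97, 97, 98]
#   most frequent pair (97, 97)   -> minted as the next token, e.g. 256
#   merge(ids, (97, 97), 256)     -> [256, 97, 98]
# Repeating this `vocab_size - 256` times builds the merge table that encode() replays.

# The original reference implementation is kept commented out below for comparison.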
# class BasicTokenizer(Tokenizer):
#
# def __init__(self):
# super().__init__()
#
# def train(self, text, vocab_size, verbose=False):
# assert vocab_size >= 256
# num_merges = vocab_size - 256
#
# # input text preprocessing
# text_bytes = text.encode("utf-8") # raw bytes
# ids = list(text_bytes) # list of integers in range 0..255
#
# # iteratively merge the most common pairs to create new tokens
# merges = {} # (int, int) -> int
# vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
# for i in range(num_merges):
# # count up the number of times every consecutive pair appears
# stats = get_stats(ids)
# # find the pair with the highest count
# pair = max(stats, key=stats.get)
# # mint a new token: assign it the next available id
# idx = 256 + i
# # replace all occurrences of pair in ids with idx
# ids = merge(ids, pair, idx)
# # save the merge
# merges[pair] = idx
# vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# # prints
# if verbose:
# print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
#
# # save class variables
# self.merges = merges # used in encode()
# self.vocab = vocab # used in decode()
#
# def decode(self, ids):
# # given ids (list of integers), return Python string
# text_bytes = b"".join(self.vocab[idx] for idx in ids)
# text = text_bytes.decode("utf-8", errors="replace")
# return text
#
# def encode(self, text):
# # given a string text, return the token ids
# text_bytes = text.encode("utf-8") # raw bytes
# ids = list(text_bytes) # list of integers in range 0..255
# while len(ids) >= 2:
# # find the pair with the lowest merge index
# stats = get_stats(ids)
# pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# # subtle: if there are no more merges available, the key will
# # result in an inf for every single pair, and the min will be
# # just the first pair in the list, arbitrarily
# # we can detect this terminating case by a membership check
# if pair not in self.merges:
# break # nothing else can be merged anymore
# # otherwise let's merge the best pair (lowest merge index)
# idx = self.merges[pair]
# ids = merge(ids, pair, idx)
# return ids
class BasicTokenizer(Tokenizer):
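    """
    Byte-level BPE tokenizer that supports incremental training: repeated
    calls to train() reuse and extend the same `merges` and `vocab`, so a new
    batch of text only mints tokens for pairs that were not merged before.
    `merge_counter` tracks how many new tokens have been minted across batches.
    """
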
def __init__(self):
super().__init__()
        self.merge_counter = 0

    def train(self, text, vocab_size, verbose=False):
        # hard check: the target vocab size must at least cover the 256 raw byte tokens
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        current_batch_merge_counter = 0  # new tokens minted in this batch (may be fewer than `num_merges`)
# input text preprocessing
        text_bytes = text.encode("utf-8")  # encode to get the raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        # reuse the merge dict from previous train() calls if it exists
        self.merges = {} if self.merges is None else self.merges  # holds all merges: (int, int) -> int
        # likewise reuse this Tokenizer's vocab if it exists; vocab: int -> bytes
        self.vocab = {idx: bytes([idx]) for idx in range(256)} if self.vocab is None else self.vocab
# iteratively merge the MOST COMMON pair from the text
for i in range(num_merges):
            # count how often every consecutive pair appears in the current ids
            stats = get_stats(ids)
            # pick the most frequent pair
            pair = max(stats, key=stats.get)
            while pair in self.merges:
                # this pair was already merged in an earlier batch: reuse its stored
                # token id to update ids; nothing new is added to merges or vocab
                already_merged_idx = self.merges[pair]
                ids = merge(ids, pair, already_merged_idx)
                stats = get_stats(ids)
                if stats and len(ids) >= 2:
                    pair = max(stats, key=stats.get)
                else:
                    # nothing left to merge in this incoming data batch
                    print("stopping merges: no new byte pair found in the current batch")
                    pair = None
                    break
            if pair is None:
                # the batch is exhausted; stop training on it
                break
            # this most frequent pair has not been merged in any batch yet:
            # mint a new token id right after the ones generated so far for this tokenizer
            idx = len(self.vocab)
            # count it towards self.merge_counter, which is updated after the loop
            current_batch_merge_counter += 1
            # replace all occurrences of `pair` in `ids` with the new `idx` token
            # and record it in merges and vocab
            ids = merge(ids, pair, idx)
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {stats[pair]} occurrences")
        self.merge_counter += current_batch_merge_counter

def decode(self, ids):
# given ids (list of integers), return Python string
text_bytes = b"".join(self.vocab[idx] for idx in ids)
text = text_bytes.decode("utf-8", errors="replace")
        return text

def encode(self, text):
        # given a string, return the token ids
text_bytes = text.encode("utf-8")
ids = list(text_bytes)
while len(ids) >= 2:
# here find the pair with the lowest merge index
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if none of the remaining pairs is in the merges dict,
            # the key returns `inf` for every pair and min() just picks the
            # first pair arbitrarily; detect this terminating case with a
            # membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise merge the best pair (the one with the lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
return ids
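

# A minimal usage sketch (illustrative, not part of the module): the sample text and
# vocab sizes below are arbitrary. Because of the relative import of `.base`, this
# needs to run as part of the package, e.g. `python -m src.Basictokenizer` from the
# repository root (assuming `src` is a package).
if __name__ == "__main__":
    tokenizer = BasicTokenizer()
    sample_text = "नमस्ते दुनिया! यह एक छोटा सा उदाहरण है।"  # any UTF-8 text works
    tokenizer.train(sample_text, vocab_size=300, verbose=True)

    ids = tokenizer.encode(sample_text)
    print("token ids:", ids)
    # byte-level BPE round-trips any valid UTF-8 input
    assert tokenizer.decode(ids) == sample_text

    # a second train() call on new text extends the same merges/vocab
    tokenizer.train("और अधिक पाठ पर प्रशिक्षण जारी रहता है।", vocab_size=330)
    print("new tokens minted so far:", tokenizer.merge_counter)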