"""This is an educational implementation of the byte pair encoding algorithm."""
import collections
from typing import Optional

import regex

import tiktoken

class OBITokenizer:
    def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
        """Creates an Encoding object."""
        # A regex pattern string that is used to split the input text
        self.pat_str = pat_str
        # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
        self.mergeable_ranks = mergeable_ranks

        self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()}
        self._pat = regex.compile(pat_str)

    def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
        """Encodes a string into tokens.

        >>> enc.encode("hello world")
        [388, 372]
        """
        # Use the regex to split the text into (approximately) words
        words = self._pat.findall(text)
        tokens = []
        for word in words:
            # Turn each word into tokens, using the byte pair encoding algorithm
            word_bytes = word.encode("utf-8")
            word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise)
            tokens.extend(word_tokens)
        return tokens

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        >>> enc.decode_bytes([388, 372])
        b'hello world'
        """
        return b"".join(self._decoder[token] for token in tokens)

    def decode(self, tokens: list[int]) -> str:
        """Decodes a list of tokens into a string.

        Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
        the invalid bytes with the replacement character "�".

        >>> enc.decode([388, 372])
        'hello world'
        """
        return self.decode_bytes(tokens).decode("utf-8", errors="replace")

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising how a string is tokenised.

        >>> enc.decode_tokens_bytes([388, 372])
        [b'hello', b' world']
        """
        return [self._decoder[token] for token in tokens]

    @staticmethod
    def train(training_data: str, vocab_size: int, pat_str: str):
        """Train a BPE tokeniser on some data!"""
        mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
        return OBITokenizer(pat_str=pat_str, mergeable_ranks=mergeable_ranks)

    @staticmethod
    def from_tiktoken(encoding):
        # Wrap an existing tiktoken Encoding (or the name of one) in an OBITokenizer
        if isinstance(encoding, str):
            encoding = tiktoken.get_encoding(encoding)
        return OBITokenizer(
            pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
        )

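
# An illustrative helper, added for this write-up and not part of the original module:
# it sketches how one might round-trip a string through a tokenizer built from
# tiktoken's "gpt2" encoding. It assumes the "gpt2" vocabulary is available
# (tiktoken may fetch it on first use).
def _example_from_tiktoken() -> None:
    enc = OBITokenizer.from_tiktoken("gpt2")
    tokens = enc.encode("hello world", visualise=None)
    # Decoding the tokens should reproduce the original string exactly
    assert enc.decode(tokens) == "hello world"
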

def bpe_encode(
    mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour"
) -> list[int]:
    # Start with every byte of the input as its own part
    parts = [bytes([b]) for b in input]
    while True:
        # See the intermediate merges play out!
        if visualise:
            if visualise in ["colour", "color"]:
                visualise_tokens(parts)
            elif visualise == "simple":
                print(parts)

        # Iterate over all adjacent pairs and find the lowest-rank (highest-priority) merge
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank

        # If there were no pairs we could merge, we're done!
        if min_rank is None:
            break
        assert min_idx is not None

        # Otherwise, merge that pair and leave the rest unchanged. Then repeat.
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]

    if visualise:
        print()

    tokens = [mergeable_ranks[part] for part in parts]
    return tokens

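
# A tiny illustrative sketch, added and not from the original module: it calls
# bpe_encode directly with a toy vocabulary of all 256 single bytes plus one
# learned merge, b"ab" -> rank 256. The helper name and vocabulary are invented
# for demonstration only.
def _example_bpe_encode() -> None:
    toy_ranks = {bytes([i]): i for i in range(2**8)}
    toy_ranks[b"ab"] = 256
    # "abab" collapses to two b"ab" tokens, i.e. [256, 256]
    assert bpe_encode(toy_ranks, b"abab", visualise=None) == [256, 256]
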

def bpe_train(
    data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour"
) -> dict[bytes, int]:
    # First, add tokens for each individual byte value
    if vocab_size < 2**8:
        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
    ranks = {}
    for i in range(2**8):
        ranks[bytes([i])] = i

    # Splinter up our data into lists of bytes
    # data = "Hello world"
    # words = [
    #     [b'H', b'e', b'l', b'l', b'o'],
    #     [b' ', b'w', b'o', b'r', b'l', b'd']
    # ]
    words: list[list[bytes]] = [
        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
    ]

    # Now, use our data to figure out which merges we should make
    while len(ranks) < vocab_size:
        # Find the most common pair. This will become our next token
        stats = collections.Counter()
        for piece in words:
            for pair in zip(piece[:-1], piece[1:]):
                stats[pair] += 1

        most_common_pair = max(stats, key=lambda x: stats[x])
        token_bytes = most_common_pair[0] + most_common_pair[1]
        token = len(ranks)
        # Add the new token!
        ranks[token_bytes] = token

        # Now merge that most common pair in all the words. That is, update our training data
        # to reflect our decision to make that pair into a new token.
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word) - 1:
                if (word[i], word[i + 1]) == most_common_pair:
                    # We found our pair! Merge it
                    new_word.append(token_bytes)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            if i == len(word) - 1:
                new_word.append(word[i])
            new_words.append(new_word)
        words = new_words

        # See the intermediate merges play out!
        if visualise:
            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
            print(f"So we made {token_bytes} our {len(ranks)}th token")
            if visualise in ["colour", "color"]:
                print("Now the first fifty words in our training data look like:")
                visualise_tokens([token for word in words[:50] for token in word])
            elif visualise == "simple":
                print("Now the first twenty words in our training data look like:")
                for word in words[:20]:
                    print(word)
            print("\n")

    return ranks

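
# An illustrative sketch, added and not from the original module: it trains a few
# merges on a tiny corpus. The corpus string, the r"\S+" split pattern, and the
# 259-token vocabulary size are arbitrary choices made for this example.
def _example_bpe_train() -> None:
    ranks = bpe_train(data="aaabdaaabac", vocab_size=2**8 + 3, pat_str=r"\S+", visualise=None)
    # 256 single-byte tokens plus the 3 learned merges
    assert len(ranks) == 2**8 + 3
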

def visualise_tokens(token_values: list[bytes]) -> None:
    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
    # visualise the token. Here, we'll just use the unicode replacement character to represent some
    # fraction of a character.
    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]

    running_length = 0
    last_color = None
    for token in unicode_token_values:
        color = background[running_length % len(background)]
        if color == last_color:
            color = background[(running_length + 1) % len(background)]
            assert color != last_color
        last_color = color
        running_length += len(token)
        print(color + token, end="")
    print("\u001b[0m")


def train_simple_encoding():
    """Train a BPE tokeniser on this file's own source, then sanity-check it on 'hello world'."""
    gpt2_pattern = (
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )
    with open(__file__, "r") as f:
        data = f.read()

    enc = OBITokenizer.train(data, vocab_size=600, pat_str=gpt2_pattern)

    print("This is the sequence of merges performed in order to encode 'hello world':")
    tokens = enc.encode("hello world")
    assert enc.decode(tokens) == "hello world"
    assert enc.decode_bytes(tokens) == b"hello world"
    assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"]

    return enc
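

# Added convenience entry point, not in the original module: running the file directly
# trains the demo tokenizer on its own source code, printing each merge as it happens.
if __name__ == "__main__":
    train_simple_encoding()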