|
""" |
|
bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into |
|
sequences of integers, where each integer represents small chunks of commonly |
|
occuring characters. This implementation is based on openai's gpt2 encoder.py: |
|
https://github.com/openai/gpt-2/blob/master/src/encoder.py |
|
but was mildly modified because the original implementation is a bit confusing. |
|
I also tried to add as many comments as possible, my own understanding of what's |
|
going on. |
|
""" |
|
|
|
import os |
|
import json |
|
import regex as re |
|
import requests |
|
|
|
import torch |
|
|
|
|
|
|
|
def bytes_to_unicode(): |
|
""" |
|
Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode |
|
character that represents it visually. Some bytes have their appearance preserved |
|
because they don't cause any trouble. These are defined in list bs. For example: |
|
chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!". |
|
However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these |
|
bytes, into new characters in a range where chr() returns a single nice character. |
|
So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8). |
|
In particular, the space character is 32, which we can see by ord(' '). Instead, |
|
this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'. |
|
So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters |
|
that "look nice", either in their original form, or a funny shifted character |
|
like 'Ā', or 'Ġ', etc. |
|
""" |
|
|
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) |
|
cs = bs[:] |
|
|
|
|
|
n = 0 |
|
for b in range(2**8): |
|
if b not in bs: |
|
|
|
bs.append(b) |
|
cs.append(2**8+n) |
|
n += 1 |
|
cs = [chr(n) for n in cs] |
|
d = dict(zip(bs, cs)) |
|
return d |
|
|
|
def get_pairs(word): |
|
""" |
|
Return all bigrams as a set of tuples, of consecutive elements in the iterable word. |
|
""" |
|
pairs = set() |
|
prev_char = word[0] |
|
for char in word[1:]: |
|
pairs.add((prev_char, char)) |
|
prev_char = char |
|
return pairs |
|
|
|
class Encoder: |
|
|
|
def __init__(self, encoder, bpe_merges): |
|
|
|
self.byte_encoder = bytes_to_unicode() |
|
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} |
|
|
|
self.encoder = encoder |
|
self.decoder = {v:k for k,v in self.encoder.items()} |
|
|
|
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) |
|
|
|
|
|
""" |
|
ok so what is this regex looking for, exactly? |
|
python re reference: https://docs.python.org/3/library/re.html |
|
- the vertical bars | is OR, so re.findall will chunkate text as the pieces match, from left to right |
|
- '\'s' would split up things like Andrej's -> (Andrej, 's) |
|
- ' ?\p{L}': optional space followed by 1+ unicode code points in the category "letter" |
|
- ' ?\p{N}': optional space followed by 1+ unicode code points in the category "number" |
|
- ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number |
|
- '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) UNLESS they are followed by non-whitespace |
|
so this will consume whitespace characters in a sequence but exclude the last whitespace in |
|
that sequence. that last whitespace has the opportunity to then match the optional ' ?' in |
|
earlier patterns. |
|
- '\s+': 1+ whitespace characters, intended probably to catch a full trailing sequence of whitespaces at end of string |
|
So TLDR: |
|
- we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens |
|
- we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces |
|
""" |
|
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") |
|
self.cache = {} |
|
|
|
def bpe(self, token): |
|
""" |
|
this function uses self.bpe_ranks to iteratively merge all the possible bpe tokens |
|
up the tree. token is a string of one individual 'word' (after regex tokenization) |
|
and after byte encoding, e.g. 'Ġthere'. |
|
""" |
|
|
|
|
|
|
|
if token in self.cache: |
|
return self.cache[token] |
|
|
|
word = tuple(token) |
|
pairs = get_pairs(word) |
|
|
|
if not pairs: |
|
return token |
|
|
|
while True: |
|
|
|
|
|
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) |
|
if bigram not in self.bpe_ranks: |
|
break |
|
first, second = bigram |
|
|
|
|
|
|
|
new_word = [] |
|
i = 0 |
|
while i < len(word): |
|
|
|
|
|
try: |
|
j = word.index(first, i) |
|
new_word.extend(word[i:j]) |
|
i = j |
|
except: |
|
new_word.extend(word[i:]) |
|
break |
|
|
|
|
|
if word[i] == first and i < len(word)-1 and word[i+1] == second: |
|
new_word.append(first+second) |
|
i += 2 |
|
else: |
|
new_word.append(word[i]) |
|
i += 1 |
|
|
|
|
|
new_word = tuple(new_word) |
|
word = new_word |
|
if len(word) == 1: |
|
break |
|
else: |
|
pairs = get_pairs(word) |
|
|
|
|
|
|
|
|
|
word = ' '.join(word) |
|
|
|
|
|
self.cache[token] = word |
|
return word |
|
|
|
def encode(self, text): |
|
""" string goes in, list of integers comes out """ |
|
bpe_idx = [] |
|
|
|
tokens = re.findall(self.pat, text) |
|
|
|
for token in tokens: |
|
|
|
token_bytes = token.encode('utf-8') |
|
|
|
token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) |
|
|
|
token_merged = self.bpe(token_translated).split(' ') |
|
|
|
token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] |
|
|
|
bpe_idx.extend(token_ix) |
|
return bpe_idx |
|
|
|
def encode_and_show_work(self, text): |
|
""" debugging function, same as encode but returns all intermediate work """ |
|
bpe_idx = [] |
|
parts = [] |
|
tokens = re.findall(self.pat, text) |
|
for token in tokens: |
|
token_bytes = token.encode('utf-8') |
|
token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) |
|
token_merged = self.bpe(token_translated).split(' ') |
|
token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] |
|
bpe_idx.extend(token_ix) |
|
parts.append({ |
|
'token': token, |
|
'token_bytes': token_bytes, |
|
'token_translated': token_translated, |
|
'token_merged': token_merged, |
|
'token_ix': token_ix, |
|
}) |
|
out = { |
|
'bpe_idx': bpe_idx, |
|
'tokens': tokens, |
|
'parts': parts, |
|
} |
|
return out |
|
|
|
def decode(self, bpe_idx): |
|
""" list of integers comes in, string comes out """ |
|
|
|
tokens_merged = [self.decoder[token] for token in bpe_idx] |
|
|
|
tokens_flat = ''.join(tokens_merged) |
|
tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat]) |
|
|
|
text = tokens_bytes.decode('utf-8', errors='replace') |
|
return text |
|
|
|
def get_file(local_file, remote_file): |
|
""" downloads remote_file to local_file if necessary """ |
|
if not os.path.isfile(local_file): |
|
print(f"downloading {remote_file} to {local_file}") |
|
response = requests.get(remote_file) |
|
open(local_file, "wb").write(response.content) |
|
|
|
def get_encoder(): |
|
""" |
|
Returns an instance of the GPT BPE Encoder/Decoder |
|
and handles caching of "database" files. |
|
""" |
|
home_dir = os.path.expanduser('~') |
|
cache_dir = os.path.join(home_dir, '.cache', 'mingpt') |
|
os.makedirs(cache_dir, exist_ok=True) |
|
|
|
|
|
encoder_local_file = os.path.join(cache_dir, 'encoder.json') |
|
encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json' |
|
get_file(encoder_local_file, encoder_remote_file) |
|
with open(encoder_local_file, 'r') as f: |
|
encoder = json.load(f) |
|
assert len(encoder) == 50257 |
|
|
|
|
|
|
|
vocab_local_file = os.path.join(cache_dir, 'vocab.bpe') |
|
vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe' |
|
get_file(vocab_local_file, vocab_remote_file) |
|
with open(vocab_local_file, 'r', encoding="utf-8") as f: |
|
bpe_data = f.read() |
|
|
|
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] |
|
assert len(bpe_merges) == 50000 |
|
|
|
|
|
enc = Encoder(encoder, bpe_merges) |
|
return enc |
|
|
|
|
|
|
|
class BPETokenizer: |
|
""" PyTorch-aware class that wraps the Encoder above """ |
|
|
|
def __init__(self): |
|
self.encoder = get_encoder() |
|
|
|
def __call__(self, text, return_tensors='pt'): |
|
|
|
assert return_tensors == 'pt' |
|
|
|
assert isinstance(text, str) |
|
|
|
idx = [self.encoder.encode(text)] |
|
|
|
out = torch.tensor(idx, dtype=torch.long) |
|
return out |
|
|
|
def decode(self, idx): |
|
|
|
assert idx.ndim == 1 |
|
|
|
text = self.encoder.decode(idx.tolist()) |
|
return text |
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗" |
|
e = get_encoder() |
|
r = e.encode_and_show_work(text) |
|
|
|
print("Original text is:") |
|
print(text) |
|
print("First the text gets pre-tokenized, broken up into chunks, the outcome is:") |
|
print(r['tokens']) |
|
|
|
print("Then we iterate over each chunk and process them in turn...") |
|
for part in r['parts']: |
|
print(part) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("and the final outcome is concatenating and flattening all the token_ix:") |
|
print(r['bpe_idx']) |
|
|
|
|
|
print("ready to feed into a Transformer!") |
|
|