JMalott committed on
Commit
da99fb7
1 Parent(s): 3514e2a

Upload text_tokenizer.py

Browse files
Files changed (1) hide show
  1. min_dalle/text_tokenizer.py +41 -0
min_dalle/text_tokenizer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import inf
2
+ from typing import List, Tuple
3
+ from emoji import demojize
4
+
5
class TextTokenizer:
    """Byte-pair-encoding (BPE) text tokenizer.

    Maps a text string to a list of integer token ids by greedily
    merging adjacent character pairs, lowest merge-rank first, against
    a fixed merge table.
    """

    def __init__(self, vocab: dict, merges: List[str]):
        """
        vocab: mapping from subword string to integer token id; expected
            to contain the special tokens '<s>', '</s>' and '<unk>'
            (required by `tokenize`).
        merges: each entry is a whitespace-separated subword pair,
            e.g. "\u0120 a"; earlier entries merge first (lower rank).
        """
        self.token_from_subword = vocab
        pairs = [tuple(pair.split()) for pair in merges]
        # Position in the merges list is the pair's rank: lower == merged earlier.
        self.rank_from_pair = dict(zip(pairs, range(len(pairs))))

    def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
        """Return the token ids for `text`, wrapped as [<s>] ... [</s>].

        Emoji are replaced by their textual names (delimiters stripped),
        then the text is lowercased and reduced to ASCII before the BPE
        pass. Subwords missing from the vocab map to the '<unk>' id.
        """
        sep_token = self.token_from_subword['</s>']
        cls_token = self.token_from_subword['<s>']
        unk_token = self.token_from_subword['<unk>']
        text = demojize(text, delimiters=['', ''])
        text = text.lower().encode("ascii", errors="ignore").decode()
        tokens = [
            self.token_from_subword.get(subword, unk_token)
            for word in text.split(" ") if len(word) > 0
            for subword in self.get_byte_pair_encoding(word, is_verbose)
        ]
        return [cls_token] + tokens + [sep_token]

    def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
        """Split `word` into subwords by greedy lowest-rank-first merging.

        The word is prefixed with chr(288) ('\u0120', the GPT-2-style
        word-boundary marker), then adjacent subword pairs are merged
        until no pair remains in the merge table.
        """
        def get_pair_rank(pair: Tuple[str, str]) -> float:
            # Unknown pairs rank +inf so they are never chosen while any
            # known pair remains. (Return type is float because of inf.)
            return self.rank_from_pair.get(pair, inf)

        subwords = [chr(ord(" ") + 256)] + list(word)
        while len(subwords) > 1:
            pairs = list(zip(subwords[:-1], subwords[1:]))
            # Take the index together with the pair so we don't re-scan
            # the list with pairs.index() afterwards; min() keeps the
            # first occurrence on ties, matching the original behavior.
            i, pair_to_merge = min(
                enumerate(pairs),
                key=lambda indexed_pair: get_pair_rank(indexed_pair[1])
            )
            if pair_to_merge not in self.rank_from_pair:
                break  # best remaining pair is unknown: nothing left to merge
            # Plain slices already yield [] at the boundaries (i == 0 or
            # i + 2 == len), so no conditional guards are needed.
            subwords = (
                subwords[:i]
                + [subwords[i] + subwords[i + 1]]
                + subwords[i + 2:]
            )

        if is_verbose:
            print(subwords)
        return subwords