from collections import Counter import datasets import transformers from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS from transformers.utils import logging logging.set_verbosity_info() TOKENIZER_CLASSES = { name: (getattr(transformers, name), getattr(transformers, name + "Fast")) for name in SLOW_TO_FAST_CONVERTERS } dataset = datasets.load_dataset("xnli", split="test+validation") total = 0 perfect = 0 imperfect = 0 wrong = 0 def check_diff(spm_diff, tok_diff, slow, fast): if spm_diff == list(reversed(tok_diff)): # AAA -> AA+A vs A+AA case. return True elif len(spm_diff) == len(tok_diff) and fast.decode(spm_diff) == fast.decode(tok_diff): # Second order OK # Barrich -> Barr + ich vs Bar + rich return True spm_reencoded = slow.encode(slow.decode(spm_diff)) tok_reencoded = fast.encode(fast.decode(spm_diff)) if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded: # Type 3 error. # Snehagatha -> # Sne, h, aga, th, a # Sne, ha, gat, ha # Encoding the wrong with sp does not even recover what spm gave us # It fits tokenizer however... return True return False def check_LTR_mark(line, idx, fast): enc = fast.encode_plus(line)[0] offsets = enc.offsets curr, prev = offsets[idx], offsets[idx - 1] if curr is not None and line[curr[0] : curr[1]] == "\u200f": return True if prev is not None and line[prev[0] : prev[1]] == "\u200f": return True def check_details(line, spm_ids, tok_ids, slow, fast): # Encoding can be the same with same result AAA -> A + AA vs AA + A # We can check that we use at least exactly the same number of tokens. for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)): if spm_id != tok_id: break first = i for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))): if spm_id != tok_id: break last = len(spm_ids) - i spm_diff = spm_ids[first:last] tok_diff = tok_ids[first:last] if check_diff(spm_diff, tok_diff, slow, fast): return True if check_LTR_mark(line, first, fast): return True if last - first > 5: # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems spms = Counter(spm_ids[first:last]) toks = Counter(tok_ids[first:last]) removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si} min_width = 3 for i in range(last - first - min_width): if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)): possible_matches = [ k for k in range(last - first - min_width) if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width] ] for j in possible_matches: if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], sp, tok) and check_details( line, spm_ids[first + i : last], tok_ids[first + j : last], slow, fast, ): return True print(f"Spm: {[fast.decode([spm_ids[i]]) for i in range(first, last)]}") try: print(f"Tok: {[fast.decode([tok_ids[i]]) for i in range(first, last)]}") except Exception: pass ok_start = fast.decode(spm_ids[:first]) ok_end = fast.decode(spm_ids[last:]) wrong = fast.decode(spm_ids[first:last]) print() print(wrong) return False def test_string(slow, fast, text): global perfect global imperfect global wrong global total slow_ids = slow.encode(text) fast_ids = fast.encode(text) skip_assert = False total += 1 if slow_ids != fast_ids: if check_details(text, slow_ids, fast_ids, slow, fast): skip_assert = True imperfect += 1 else: wrong += 1 else: perfect += 1 if total % 10000 == 0: print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})") if skip_assert: return assert ( slow_ids == fast_ids ), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}" def test_tokenizer(slow, fast): global batch_total for i in range(len(dataset)): # premise, all languages for text in dataset[i]["premise"].values(): test_string(slow, fast, text) # hypothesis, all languages for text in dataset[i]["hypothesis"]["translation"]: test_string(slow, fast, text) if __name__ == "__main__": for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items(): checkpoint_names = list(slow_class.max_model_input_sizes.keys()) for checkpoint in checkpoint_names: imperfect = 0 perfect = 0 wrong = 0 total = 0 print(f"========================== Checking {name}: {checkpoint} ==========================") slow = slow_class.from_pretrained(checkpoint, force_download=True) fast = fast_class.from_pretrained(checkpoint, force_download=True) test_tokenizer(slow, fast) print(f"Accuracy {perfect * 100 / total:.2f}")