|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import unittest |
|
|
from functools import lru_cache |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
from transformers import BloomTokenizerFast |
|
|
from transformers.testing_utils import require_jinja, require_tokenizers |
|
|
|
|
|
from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible |
|
|
|
|
|
|
|
|
@require_tokenizers
class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    """Tokenization tests for `BloomTokenizerFast`.

    Only the fast (Rust) tokenizer is exercised: `slow_tokenizer_class` is `None` and
    `test_slow_tokenizer` is `False`, since Bloom does not ship a slow tokenizer.
    """

    from_pretrained_id = "bigscience/tokenizer"
    slow_tokenizer_class = None
    rust_tokenizer_class = BloomTokenizerFast
    tokenizer_class = BloomTokenizerFast
    test_rust_tokenizer = True
    test_slow_tokenizer = False
    from_pretrained_vocab_key = "tokenizer_file"
    # Default special tokens handed to every tokenizer built by `get_rust_tokenizer`.
    special_tokens_map = {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}

    @classmethod
    def setUpClass(cls):
        """Load the reference tokenizer once and save it into `cls.tmpdirname`."""
        super().setUpClass()
        # `cls.tmpdirname` is not defined in this class; presumably the parent `setUpClass`
        # creates it - confirm against `TokenizerTesterMixin`.
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/tokenizer")
        tokenizer.save_pretrained(cls.tmpdirname)

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_rust_tokenizer(cls, pretrained_name=None, **kwargs):
        """Build a `BloomTokenizerFast` with the class-level special tokens applied.

        Args:
            pretrained_name: Name or path to load from. Falls back to `cls.tmpdirname` when falsy.
            **kwargs: Extra `from_pretrained` arguments; they take precedence over the entries
                of `special_tokens_map`.

        Returns:
            The tokenizer returned by `BloomTokenizerFast.from_pretrained`.
        """
        # Work on a copy so the shared class-level mapping is never mutated.
        merged_kwargs = copy.deepcopy(cls.special_tokens_map)
        merged_kwargs.update(kwargs)
        source = pretrained_name or cls.tmpdirname
        return BloomTokenizerFast.from_pretrained(source, **merged_kwargs)

    @unittest.skip(reason="This needs a slow tokenizer. Bloom does not have one!")
    def test_encode_decode_with_spaces(self):
        """Skipped: per the skip reason, this test needs a slow tokenizer."""
        return

    def test_encodings_from_sample_data(self):
        """
        Assert that the created tokens are the same as the hard-coded ones
        """
        tokenizer = self.get_rust_tokenizer()

        input_sentences = ["The quick brown fox</s>", "jumps over the lazy dog</s>"]
        target_tokens = [[2175, 23714, 73173, 144252, 2], [77, 132619, 3478, 368, 109586, 35433, 2]]

        computed_tokens = tokenizer.batch_encode_plus(input_sentences)["input_ids"]
        self.assertListEqual(target_tokens, computed_tokens)

        # Decoding the ids must give back the original sentences.
        decoded_tokens = tokenizer.batch_decode(computed_tokens)
        self.assertListEqual(decoded_tokens, input_sentences)

    def test_padding(self, max_length=6):
        """Check `max_length` handling with a pad token and padding errors without one.

        Args:
            max_length: Value passed as `max_length` to every encoding call.
        """
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)

                # Simple input: a single string and a batch of strings.
                s = "This is a simple input"
                s2 = ["This is a simple input 1", "This is a simple input 2"]
                # Pair input: a single pair and a batch of pairs.
                p = ("This is a simple input", "This is a pair")
                p2 = [
                    ("This is a simple input 1", "This is a simple input 2"),
                    ("This is a simple pair 1", "This is a simple pair 2"),
                ]

                # Passing only `max_length` (no padding strategy) must not raise.
                try:
                    tokenizer_r.encode(s, max_length=max_length)
                    tokenizer_r.encode_plus(s, max_length=max_length)

                    tokenizer_r.batch_encode_plus(s2, max_length=max_length)
                    tokenizer_r.encode(p, max_length=max_length)
                    tokenizer_r.batch_encode_plus(p2, max_length=max_length)
                except ValueError:
                    self.fail("Bloom Tokenizer should be able to deal with padding")

                # With the pad token removed, every request for `padding="max_length"` is
                # expected to raise `ValueError`, for both simple and pair inputs.
                tokenizer_r.pad_token = None
                failing_calls = (
                    (tokenizer_r.encode, s),
                    (tokenizer_r.encode_plus, s),
                    (tokenizer_r.batch_encode_plus, s2),
                    (tokenizer_r.encode, p),
                    (tokenizer_r.encode_plus, p),
                    (tokenizer_r.batch_encode_plus, p2),
                )
                for encode_fn, sample in failing_calls:
                    self.assertRaises(ValueError, encode_fn, sample, max_length=max_length, padding="max_length")

    def test_encodings_from_xnli_dataset(self):
        """
        Tests the tokenizer downloaded from here:
            - https://huggingface.co/bigscience/tokenizer/
        """
        tokenizer = self.get_rust_tokenizer()
        ds = load_dataset("facebook/xnli", "all_languages", split="test", streaming=True)

        # Only the first streamed example is used; its "premise" field is a mapping whose
        # values are the texts to round-trip (presumably one per language - confirm).
        sample_data = next(iter(ds))["premise"]
        input_text = list(sample_data.values())

        output_tokens = [tokenizer.encode(text) for text in input_text]
        predicted_text = [tokenizer.decode(x, clean_up_tokenization_spaces=False) for x in output_tokens]
        self.assertListEqual(predicted_text, input_text)

    @require_jinja
    def test_tokenization_for_chat(self):
        """Check `apply_chat_template` output ids against hard-coded expectations."""
        tokenizer = self.get_rust_tokenizer()
        # Template: each message's content followed by the EOS token; roles are ignored.
        tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
        test_chats = [
            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
            [
                {"role": "system", "content": "You are a helpful chatbot."},
                {"role": "user", "content": "Hello!"},
                {"role": "assistant", "content": "Nice to meet you."},
            ],
            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
        ]
        tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
        expected_tokens = [
            [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
            [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
            [229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
        ]
        # The loop target gets its own name so it does not rebind the list being iterated.
        for tokenized_chat, expected_chat_tokens in zip(tokenized_chats, expected_tokens):
            self.assertListEqual(tokenized_chat, expected_chat_tokens)

    def test_add_prefix_space_fast(self):
        """Tokenizing "Hey" must give different tokens with and without `add_prefix_space`."""
        tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
        tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
        tokens_w_prefix = tokenizer_w_prefix.tokenize("Hey")
        tokens_wo_prefix = tokenizer_wo_prefix.tokenize("Hey")
        self.assertNotEqual(tokens_w_prefix, tokens_wo_prefix)
|
|
|