# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
from unittest import TestCase

from transformers import BartTokenizer, BartTokenizerFast, DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_faiss, require_tokenizers, require_torch, slow
from transformers.utils import is_datasets_available, is_faiss_available, is_torch_available
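

# RagConfig/RagTokenizer are only importable when torch, datasets and faiss are all installed,
# so guard the import on the corresponding availability checks.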
if is_torch_available() and is_datasets_available() and is_faiss_available():
    from transformers.models.rag.configuration_rag import RagConfig
    from transformers.models.rag.tokenization_rag import RagTokenizer


@require_faiss
@require_torch
class RagTokenizerTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # Build a tiny WordPiece vocabulary for the DPR question-encoder tokenizer.
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # Build a tiny byte-level BPE vocabulary and merges file for the BART generator tokenizer.
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
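        # BPE merge rules; the first entry is the merges-file version header.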
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)
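
    # Round-trip check: a RagTokenizer saved alongside its RagConfig should reload as the
    # fast question-encoder/generator tokenizers with identical vocabularies.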
    @require_tokenizers
    def test_save_load_pretrained_with_saved_config(self):
        save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
        rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
        rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
        rag_config.save_pretrained(save_dir)
        rag_tokenizer.save_pretrained(save_dir)
        new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
        self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizerFast)
        self.assertEqual(new_rag_tokenizer.question_encoder.get_vocab(), rag_tokenizer.question_encoder.get_vocab())
        self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizerFast)
        self.assertEqual(new_rag_tokenizer.generator.get_vocab(), rag_tokenizer.generator.get_vocab())
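
    # Smoke test: the tokenizer from the token-level RAG hub checkpoint should encode a batch
    # of Natural Questions-style queries without error.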
    @slow
    def test_pretrained_token_nq_tokenizer(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        input_strings = [
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
            "what is the first step in the evolution of the eye",
            "where is gall bladder situated in human body",
            "what is the main mineral in lithium batteries",
            "who is the president of usa right now",
            "where do the greasers live in the outsiders",
            "panda is a national animal of which country",
            "what is the name of manchester united stadium",
        ]

        input_dict = tokenizer(input_strings)
        self.assertIsNotNone(input_dict)
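
    # Same smoke test against the sequence-level RAG checkpoint from the hub.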
    @slow
    def test_pretrained_sequence_nq_tokenizer(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        input_strings = [
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
            "what is the first step in the evolution of the eye",
            "where is gall bladder situated in human body",
            "what is the main mineral in lithium batteries",
            "who is the president of usa right now",
            "where do the greasers live in the outsiders",
            "panda is a national animal of which country",
            "what is the name of manchester united stadium",
        ]

        input_dict = tokenizer(input_strings)
        self.assertIsNotNone(input_dict)