Source code for transformers.models.transfo_xl.tokenization_transfo_xl

# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
 Tokenization classes for the Transformer-XL model. Adapted from https://github.com/kimiyoung/transformer-xl.
"""


import glob
import os
import pickle
import re
from collections import Counter, OrderedDict
from typing import List, Optional, Tuple

import numpy as np

import sacremoses as sm

from ...file_utils import cached_path, is_torch_available, torch_only_method
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "pretrained_vocab_file": "vocab.pkl",
    "pretrained_vocab_file_torch": "vocab.bin",
    "vocab_file": "vocab.txt",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "pretrained_vocab_file": {
        "transfo-xl-wt103": "https://huggingface.co/transfo-xl-wt103/resolve/main/vocab.pkl",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "transfo-xl-wt103": None,
}
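# Note (editorial, not in the upstream file): Transformer-XL uses relative positional
# embeddings plus a memory of previous segments, so there is no fixed maximum input
# length to record here, hence the `None` entry above.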

PRETRAINED_CORPUS_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://huggingface.co/transfo-xl-wt103/resolve/main/corpus.bin",
}
CORPUS_NAME = "corpus.bin"

MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ "
DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")]
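# `MATCH_NUMBERS` and `DETOKENIZE_NUMBERS` are (pattern, replacement) pairs: the first
# inserts " @,@ " / " @.@ " between digits, the second reverses that substitution.
# They are used by `tokenize_numbers` and `detokenize_numbers` below.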


def tokenize_numbers(text_array: List[str]) -> List[str]:
    """
    Splits large comma-separated numbers and floating point values. This is done by replacing commas with ' @,@ ' and
    dots with ' @.@ '.

    Args:
        text_array: An already tokenized text, as a list of strings.

    Returns:
        A list of strings with tokenized numbers.

    Example::
        >>> tokenize_numbers(["$", "5,000", "1.73", "m"])
        ["$", "5", "@,@", "000", "1", "@.@", "73", "m"]
    """
    tokenized = []
    for i in range(len(text_array)):
        reg, sub = MATCH_NUMBERS
        replaced = re.sub(reg, sub, text_array[i]).split()
        tokenized.extend(replaced)

    return tokenized


def detokenize_numbers(text: str) -> str:
    """
    Inverts the operation of `tokenize_numbers`. This is done by replacing ' @,@ ' and ' @.@ ' with ',' and '.'.

    Args:
        text: A string in which the numbers should be detokenized.

    Returns:
        A detokenized string.

    Example::
        >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m")
        "$ 5,000 1.73 m"
    """
    for reg, sub in DETOKENIZE_NUMBERS:
        text = re.sub(reg, sub, text)
    return text



class TransfoXLTokenizer(PreTrainedTokenizer):
    """
    Construct a Transformer-XL tokenizer adapted from the Vocab class in `the original code
    <https://github.com/kimiyoung/transformer-xl>`__. The Transformer-XL tokenizer is a word-level tokenizer (no
    sub-word tokenization).

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        special (:obj:`List[str]`, `optional`):
            A list of special tokens (to be treated by the original implementation of this tokenizer).
        min_freq (:obj:`int`, `optional`, defaults to 0):
            The minimum number of times a token has to be present in order to be kept in the vocabulary (otherwise it
            will be mapped to :obj:`unk_token`).
        max_size (:obj:`int`, `optional`):
            The maximum size of the vocabulary. If left unset, it will default to the size of the vocabulary found
            after excluding the tokens according to the :obj:`min_freq` rule.
        lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to lowercase the input when tokenizing.
        delimiter (:obj:`str`, `optional`):
            The delimiter used between tokens.
        vocab_file (:obj:`str`, `optional`):
            File containing the vocabulary (from the original implementation).
        pretrained_vocab_file (:obj:`str`, `optional`):
            File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
        never_split (:obj:`List[str]`, `optional`):
            List of tokens that should never be split. If no list is specified, will simply use the existing special
            tokens.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        eos_token (:obj:`str`, `optional`, defaults to :obj:`"<eos>"`):
            The end of sequence token.
        additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<formula>"]`):
            A list of additional special tokens (for the HuggingFace functionality).
        language (:obj:`str`, `optional`, defaults to :obj:`"en"`):
            The language of this tokenizer (used for Moses preprocessing).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids"]

    def __init__(
        self,
        special=None,
        min_freq=0,
        max_size=None,
        lower_case=False,
        delimiter=None,
        vocab_file=None,
        pretrained_vocab_file: str = None,
        never_split=None,
        unk_token="<unk>",
        eos_token="<eos>",
        additional_special_tokens=["<formula>"],
        language="en",
        **kwargs
    ):
        super().__init__(
            special=special,
            min_freq=min_freq,
            max_size=max_size,
            lower_case=lower_case,
            delimiter=delimiter,
            vocab_file=vocab_file,
            pretrained_vocab_file=pretrained_vocab_file,
            never_split=never_split,
            unk_token=unk_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            language=language,
            **kwargs,
        )

        if never_split is None:
            never_split = self.all_special_tokens
        if special is None:
            special = []

        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file
        self.never_split = never_split
        self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
        self.punction_without_space_before_pattern = re.compile(rf"[^\s][{self.punctuation_symbols}]")
        self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
        self.language = language
        self.moses_punct_normalizer = sm.MosesPunctNormalizer(language)
        self.moses_tokenizer = sm.MosesTokenizer(language)
        self.moses_detokenizer = sm.MosesDetokenizer(language)

        # This try... catch... is not beautiful but honestly this tokenizer was not made to be used
        # in a library like ours, at all.
        try:
            vocab_dict = None
            if pretrained_vocab_file is not None:
                # Priority on pickle files (support PyTorch and TF)
                with open(pretrained_vocab_file, "rb") as f:
                    vocab_dict = pickle.load(f)

                # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer.
                # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
                # We therefore load it with torch, if it's available.
                if type(vocab_dict) == int:
                    if not is_torch_available():
                        raise ImportError(
                            "Not trying to load dict with PyTorch as you need to install pytorch to load "
                            "from a PyTorch pretrained vocabulary, "
                            "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
                        )
                    vocab_dict = torch.load(pretrained_vocab_file)

            if vocab_dict is not None:
                for key, value in vocab_dict.items():
                    if key not in self.__dict__:
                        self.__dict__[key] = value
            elif vocab_file is not None:
                self.build_vocab()

        except Exception as e:
            raise ValueError(
                f"Unable to parse file {pretrained_vocab_file}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizerFast, "
                "please note they are not compatible."
            ) from e

        if vocab_file is not None:
            self.build_vocab()

    @property
    def do_lower_case(self):
        return self.lower_case

    def _compile_space_around_punctuation_pattern(self):
        look_ahead_for_special_token = f"(?=[{self.punctuation_symbols}])"
        look_ahead_to_match_all_except_space = r"(?=[^\s])"
        return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose:
            logger.info(f"counting file {path} ...")
        assert os.path.exists(path), f"Input file {path} not found"

        sents = []
        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.info(f" line {idx}")
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """
        sents : a list of sentences, each a list of tokenized symbols
        """
        if verbose:
            logger.info(f"counting {len(sents)} sents ...")
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                logger.info(f" line {idx}")
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, "r", encoding="utf-8") as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        if "<UNK>" in self.sym2idx:
            self.unk_idx = self.sym2idx["<UNK>"]
        elif "<unk>" in self.sym2idx:
            self.unk_idx = self.sym2idx["<unk>"]
        else:
            raise ValueError("No <unknown> token in vocabulary")
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory,
                (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["pretrained_vocab_file"],
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "wb") as f:
            pickle.dump(self.__dict__, f)
        return (vocab_file,)
    def build_vocab(self):
        if self.vocab_file:
            logger.info(f"building vocab from {self.vocab_file}")
            self._build_from_file(self.vocab_file)
            logger.info(f"final vocab size {len(self)}")
        else:
            logger.info(f"building vocab with min_freq={self.min_freq}, max_size={self.max_size}")
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq:
                    break
                self.add_symbol(sym)

            logger.info(f"final vocab size {len(self)} from {len(self.counter)} unique tokens")

    @torch_only_method
    def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
        if verbose:
            logger.info(f"encoding file {path} ...")
        assert os.path.exists(path), f"Output file {path} not found"
        encoded = []
        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.info(f" line {idx}")
                symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    @torch_only_method
    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose:
            logger.info(f"encoding {len(sents)} sents ...")
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                logger.info(f" line {idx}")
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, f"{sym.strip('<>')}_idx", self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def move_added_token(self, token: str, target_idx: int):
        """
        Moves an added token to a specific position in the vocab. This method should be used when resizing an
        embedding layer other than the last one in the `AdaptiveEmbedding` in order to move the token in the tokenizer
        from the default position (at the very end) to the desired one.

        Args:
            token: The token to move to a specific position in the vocab.
            target_idx: The position where the token should be moved to.
        """
        assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token"
        assert token not in self.idx2sym, "Token which should be moved is already in vocab"

        # Insert sym into vocab
        self.idx2sym.insert(target_idx, token)
        self.sym2idx[token] = target_idx

        # Shift following indices in sym2idx
        for idx in range(target_idx + 1, len(self.idx2sym)):
            current_sym = self.idx2sym[idx]
            self.sym2idx[current_sym] = idx

        # Delete token from added_tokens
        old_index = self.added_tokens_encoder[token]
        del self.added_tokens_decoder[old_index]
        del self.added_tokens_encoder[token]

    def moses_punct_norm(self, text):
        return self.moses_punct_normalizer.normalize(text)

    def moses_tokenize(self, text):
        return self.moses_tokenizer.tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False, protected_patterns=self.never_split
        )

    def moses_pipeline(self, text: str) -> List[str]:
        """
        Does basic tokenization using :class:`sacremoses.MosesPunctNormalizer` and :class:`sacremoses.MosesTokenizer`
        with `aggressive_dash_splits=True` (see :func:`sacremoses.tokenize.MosesTokenizer.tokenize`). Additionally,
        large comma-separated numbers and floating point values are split. E.g. "23,000 people are 1.80m tall" ->
        "23 @,@ 000 people are 1 @.@ 80m tall".

        Args:
            text: Text to be tokenized.

        Returns:
            A list of tokenized strings.

        Example::

            >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
            >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall")
            ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall']
        """
        text = self.moses_punct_norm(text)
        text = self.moses_tokenize(text)
        text = tokenize_numbers(text)

        return text

    def _convert_id_to_token(self, idx):
        """Converts an id to a token (str) using the vocab."""
        assert 0 <= idx < len(self), f"Index {idx} out of vocabulary range"
        return self.idx2sym[idx]

    def _convert_token_to_id(self, sym):
        """Converts a token (str) to an id using the vocab."""
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            # logger.info(f'encounter unk {sym}')
            # assert '<eos>' not in sym
            if hasattr(self, "unk_idx"):
                return self.sym2idx.get(sym, self.unk_idx)
            # Backward compatibility with pre-trained models
            elif "<unk>" in self.sym2idx:
                return self.sym2idx["<unk>"]
            elif "<UNK>" in self.sym2idx:
                return self.sym2idx["<UNK>"]
            else:
                raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")

    def convert_tokens_to_string(self, tokens):
        """
        Converts a sequence of tokens (strings) into a single string. Additionally, the split numbers are converted
        back into their original form.
        """
        out_string = self.moses_detokenizer.detokenize(tokens)
        return detokenize_numbers(out_string).strip()

    @torch_only_method
    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    @property
    def vocab_size(self):
        return len(self.idx2sym)

    def get_vocab(self):
        return dict(self.sym2idx, **self.added_tokens_encoder)

    def _tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False
        if self.delimiter == "":
            symbols = line
        else:
            symbols = self.moses_pipeline(line)

        if add_double_eos:  # lm1b
            return ["<S>"] + symbols + ["<S>"]
        elif add_eos:
            return symbols + ["<eos>"]
        else:
            return symbols


class LMOrderedIterator(object):
    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
        """
        data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device

        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, self.n_step * bsz)

        # Evenly divide the data across the bsz batches.
        self.data = data.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        if bptt is None:
            bptt = self.bptt

        seq_len = min(bptt, self.data.size(0) - 1 - i)

        end_idx = i + seq_len
        beg_idx = max(0, i - self.ext_len)

        data = self.data[beg_idx:end_idx]
        target = self.data[i + 1 : i + 1 + seq_len]

        data_out = data.transpose(0, 1).contiguous().to(self.device)
        target_out = target.transpose(0, 1).contiguous().to(self.device)

        return data_out, target_out, seq_len

    def get_fixlen_iter(self, start=0):
        for i in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(i)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()
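
# Illustrative sketch (not part of the original module) of how `LMOrderedIterator` lays
# out a flat token stream: the stream is folded into `bsz` contiguous columns, and each
# batch is a [bsz, seq_len] slice whose targets are shifted by one position.
#
#     >>> it = LMOrderedIterator(torch.arange(10), bsz=2, bptt=3)
#     >>> data, target, seq_len = it.get_batch(0)
#     >>> data.tolist(), target.tolist(), seq_len
#     ([[0, 1, 2], [5, 6, 7]], [[1, 2, 3], [6, 7, 8]], 3)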


class LMShuffledIterator(object):
    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        """
        data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    @torch_only_method
    def stream_iterator(self, sent_stream):
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new]
                        target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    valid_batch = False
                    break

            if not valid_batch:
                return

            data_out = data.transpose(0, 1).contiguous().to(self.device)
            target_out = target.transpose(0, 1).contiguous().to(self.device)

            yield data_out, target_out, self.bptt

            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch


class LMMultiFileIterator(LMShuffledIterator):
    def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch


class TransfoXLCorpus(object):
    @classmethod
    @torch_only_method
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.
        """
        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list "
                f"({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}). "
                f"We assumed '{pretrained_model_name_or_path}' was a path or url but couldn't find files {corpus_file} "
                "at this path or url."
            )
            return None
        if resolved_corpus_file == corpus_file:
            logger.info(f"loading corpus file {corpus_file}")
        else:
            logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}")

        # Instantiate tokenizer.
        corpus = cls(*inputs, **kwargs)
        corpus_dict = torch.load(resolved_corpus_file)
        for key, value in corpus_dict.items():
            corpus.__dict__[key] = value
        corpus.vocab = vocab
        if corpus.train is not None:
            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
        if corpus.valid is not None:
            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
        if corpus.test is not None:
            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
        return corpus

    def __init__(self, *args, **kwargs):
        self.vocab = TransfoXLTokenizer(*args, **kwargs)
        self.dataset = None
        self.train = None
        self.valid = None
        self.test = None

    def build_corpus(self, path, dataset):
        self.dataset = dataset

        if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
            self.vocab.count_file(os.path.join(path, "train.txt"))
            self.vocab.count_file(os.path.join(path, "valid.txt"))
            self.vocab.count_file(os.path.join(path, "test.txt"))
        elif self.dataset == "wt103":
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path,
                "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled",
                "news.en-*",
            )
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)
        elif self.dataset in ["enwik8", "text8"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False)
        elif self.dataset == "lm1b":
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == "train":
            if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == "lm1b":
                kwargs["shuffle"] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ["valid", "test"]:
            data = self.valid if split == "valid" else self.test
            if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == "lm1b":
                data_iter = LMShuffledIterator(data, *args, **kwargs)
        else:
            data_iter = None
            raise ValueError(f"Split not recognized: {split}")

        return data_iter


@torch_only_method
def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, "cache.pt")
    fn_pickle = os.path.join(datadir, "cache.pkl")
    if os.path.exists(fn):
        logger.info("Loading cached dataset...")
        corpus = torch.load(fn_pickle)
    elif os.path.exists(fn):
        logger.info("Loading cached dataset from pickle...")
        with open(fn, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        logger.info(f"Producing dataset {dataset}...")
        kwargs = {}
        if dataset in ["wt103", "wt2"]:
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = False
        elif dataset == "ptb":
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = True
        elif dataset == "lm1b":
            kwargs["special"] = []
            kwargs["lower_case"] = False
            kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
        elif dataset in ["enwik8", "text8"]:
            pass

        corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus
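

# Illustrative end-to-end sketch (not part of the original module; outputs indicative
# only). The word-level tokenizer and the pre-processed corpus are loaded from the
# "transfo-xl-wt103" entries declared in PRETRAINED_VOCAB_FILES_MAP and
# PRETRAINED_CORPUS_ARCHIVE_MAP above.
#
#     tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
#     tokens = tokenizer.tokenize("The price rose to 23,000 dollars .")
#     # e.g. ['The', 'price', 'rose', 'to', '23', '@,@', '000', 'dollars', '.']
#     ids = tokenizer.convert_tokens_to_ids(tokens)      # unknown words map to <unk>
#     text = tokenizer.convert_tokens_to_string(tokens)  # e.g. "The price rose to 23,000 dollars."
#
#     corpus = TransfoXLCorpus.from_pretrained("transfo-xl-wt103")
#     valid_iter = corpus.get_iterator("valid", 10, 128, device="cpu")
#     for data, target, seq_len in valid_iter:
#         pass  # data and target are LongTensors of shape [bsz, seq_len]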