# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import shutil
import struct
from functools import lru_cache

import numpy as np
import torch
from fairseq.dataclass.constants import DATASET_IMPL_CHOICES
from fairseq.data.fasta_dataset import FastaDataset
from fairseq.file_io import PathManager
from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex

from . import FairseqDataset

from typing import Union


def best_fitting_int_dtype(
    max_int_to_represent,
) -> Union[np.uint16, np.uint32, np.int64]:
    """Return the narrowest integer dtype able to hold ``max_int_to_represent``.

    Args:
        max_int_to_represent: largest value that must be stored, or ``None``
            when it is unknown.

    Returns:
        ``np.uint16``, ``np.uint32`` or ``np.int64`` (the dtype class itself).
    """
    if max_int_to_represent is None:
        return np.uint32  # Safe guess
    elif max_int_to_represent < 65500:
        return np.uint16
    elif max_int_to_represent < 4294967295:
        return np.uint32
    else:
        return np.int64
        # we avoid np.uint64 because it doesn't save space and its type
        # promotion behaves unexpectedly
        # https://github.com/numpy/numpy/issues/5745


def get_available_dataset_impl():
    """Return the names in ``DATASET_IMPL_CHOICES`` as a list of strings."""
    return list(map(str, DATASET_IMPL_CHOICES))


def infer_dataset_impl(path):
    """Guess the dataset implementation stored at ``path``.

    The raw-text check runs first, then the ``.idx`` header magic is
    inspected, and finally the FASTA check.

    Returns:
        ``"raw"``, ``"cached"``, ``"mmap"``, ``"huffman"``, ``"fasta"``, or
        ``None`` when nothing matches.
    """
    if IndexedRawTextDataset.exists(path):
        return "raw"
    elif IndexedDataset.exists(path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return "cached"
            # Only the first 8 bytes were read, so compare against a prefix.
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return "mmap"
            elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]:
                return "huffman"
            else:
                return None
    elif FastaDataset.exists(path):
        return "fasta"
    else:
        return None


def make_builder(out_file, impl, vocab_size=None):
    """Create a dataset builder writing to ``out_file`` for ``impl``.

    Args:
        out_file: path of the data file to create.
        impl: dataset implementation name.
        vocab_size: used to pick the integer dtype of an ``"mmap"`` builder.

    Raises:
        NotImplementedError: for ``impl == "fasta"``.
        ValueError: for ``impl == "huffman"``.
    """
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(
            out_file, dtype=best_fitting_int_dtype(vocab_size)
        )
    elif impl == "fasta":
        raise NotImplementedError
    elif impl == "huffman":
        raise ValueError(
            "Use HuffmanCodeBuilder directly as it has a different interface."
        )
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
    """Open the dataset at ``path`` using implementation ``impl``.

    Returns:
        The dataset object, or ``None`` when ``impl`` is unknown or the
        files that implementation needs do not exist.
    """
    if impl == "raw" and IndexedRawTextDataset.exists(path):
        assert dictionary is not None
        return IndexedRawTextDataset(path, dictionary)
    elif impl == "lazy" and IndexedDataset.exists(path):
        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path)
    elif impl == "fasta" and FastaDataset.exists(path):
        from fairseq.data.fasta_dataset import EncodedFastaDataset

        return EncodedFastaDataset(path, dictionary)
    elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path):
        return HuffmanMMapIndexedDataset(path)
    return None


def dataset_exists(path, impl):
    """Return whether the files required by ``impl`` exist for ``path``."""
    if impl == "raw":
        return IndexedRawTextDataset.exists(path)
    elif impl == "mmap":
        return MMapIndexedDataset.exists(path)
    elif impl == "huffman":
        return HuffmanMMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    """Read ``n`` int64 values from binary file ``f`` into a new array."""
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    """Write the sequence ``a`` to binary file ``f`` as int64 values."""
    f.write(np.array(a, dtype=np.int64))


# Maps the dtype code stored in index headers to the element type.
_code_to_dtype = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    # Was ``np.float``: that name was only an alias of the builtin ``float``
    # and was removed in NumPy 1.24, which made this module fail at import.
    6: float,
    7: np.double,
    8: np.uint16,
    9: np.uint32,
    10: np.uint64,
}


def _dtype_header_code(dtype) -> int:
    """Return the header code registered for ``dtype``.

    Raises:
        ValueError: if ``dtype`` has no entry in ``_code_to_dtype``.
    """
    for code, known_dtype in _code_to_dtype.items():
        if known_dtype == dtype:
            return code
    raise ValueError(dtype)


def index_file_path(prefix_path):
    """Return the index file name for a dataset prefix."""
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    """Return the data file name for a dataset prefix."""
    return prefix_path + ".bin"


class IndexedDataset(FairseqDataset):
    """Loader for TorchNet IndexedDataset"""

    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.path = path
        self.fix_lua_indexing = fix_lua_indexing
        # Opened lazily on first item access (see __getitem__).
        self.data_file = None
        self.read_index(path)

    # NOTE(review): the text of read_index (after the version read),
    # read_data and the head of check_index was truncated in the copy under
    # review. They are restored here to mirror IndexedDatasetBuilder.finalize
    # and the attributes used by the rest of this class; diff against version
    # control before merging.
    def read_index(self, path):
        """Parse the ``.idx`` file and load offsets and sizes into memory."""
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack("<Q", version) == (1,)
            code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = _code_to_dtype[code]
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)

    def read_data(self, path):
        """Open the ``.bin`` data file for unbuffered binary reads."""
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``."""
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    @lru_cache(maxsize=8)
    def __getitem__(self, i) -> torch.Tensor:
        """Read item ``i`` from disk and return it as a long tensor."""
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        # data_offsets counts elements; the file position is in bytes.
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        """Return whether both the index and the data file exist."""
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):
    """IndexedDataset variant that serves items from an in-memory cache.

    ``prefetch`` must be called with the wanted indices before those items
    are read through ``__getitem__``.
    """

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        self.cache = None
        # Maps item index -> start position of that item inside self.cache.
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        """Load the items at ``indices`` into one contiguous cache array.

        Does nothing when every requested index is already cached; otherwise
        the previous cache is replaced.
        """
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            # Slice view: readinto fills the cache array in place.
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return a copy of cached item ``i`` as a long tensor."""
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        ptx = self.cache_index[i]
        np.copyto(a, self.cache[ptx : ptx + a.size])
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item


class IndexedRawTextDataset(FairseqDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        # NOTE(review): this instance attribute shadows the ``size(index)``
        # method defined below, so ``dataset.size(i)`` is not callable on
        # instances. Left unchanged because ``check_index``/``__len__`` and
        # possibly external callers read ``.size`` as an int.
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        """Read ``path`` line by line and encode each line with ``dictionary``."""
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                self.lines.append(line.strip("\n"))
                tokens = dictionary.encode_line(
                    line,
                    add_if_not_exist=False,
                    append_eos=self.append_eos,
                    reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``."""
        if i < 0 or i >= self.size:
            raise IndexError("index out of range")

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        """Return line ``i`` of the input file without its trailing newline."""
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return PathManager.exists(path)


class IndexedDatasetBuilder:
    """Writes tensors to a ``.bin`` file and, on ``finalize``, the index
    that ``IndexedDataset`` reads."""

    # Bytes per element recorded in the index header for each dtype.
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        # Was ``np.float`` (alias of builtin ``float``, removed in NumPy 1.24).
        float: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        """Append ``tensor`` to the data file and record its offsets and shape."""
        # +1 for Lua compatibility
        written = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        # Integer division keeps the offsets as ints; true division made
        # every offset a float.
        self.data_offsets.append(self.data_offsets[-1] + written // self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        """Append the dataset with prefix ``another_file`` to this builder."""
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        # Shift the other dataset's offsets so they continue after ours.
        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self.out_file)

    def finalize(self, index_file):
        """Close the data file and write the index to ``index_file``."""
        self.out_file.close()
        # NOTE(review): everything after the magic bytes was truncated in the
        # copy under review; restored as the mirror image of
        # IndexedDataset.read_index. Diff against version control.
        with open(index_file, "wb") as index:
            index.write(b"TNTIDX\x00\x00")
            index.write(struct.pack("<Q", 1))
            index.write(
                struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)
            )
            index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
            write_longs(index, self.dim_offsets)
            write_longs(index, self.data_offsets)
            write_longs(index, self.sizes)


# NOTE(review): ``_warmup_mmap_file`` and ``MMapIndexedDataset`` were missing
# from the copy under review (the text was truncated), although four functions
# in this module reference ``MMapIndexedDataset``. They are restored here to
# satisfy those call sites (``Index._HDR_MAGIC``, ``Index(path).dtype/.sizes``,
# ``Index.writer(path, dtype).write(sizes)``, ``exists``). Diff against version
# control before merging.
def _warmup_mmap_file(path):
    """Read ``path`` once from start to end, discarding the data."""
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    """Dataset of 1-D integer sequences backed by memory-mapped files."""

    class Index:
        """Reader (and, through ``writer``, writer) of the ``.idx`` file."""

        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            """Return a context manager that writes an index file to ``path``."""

            class _Writer:
                def __enter__(self):
                    self._file = open(path, "wb")

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", _dtype_header_code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    """Return the byte offset of each item in the data file."""
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = _code_to_dtype[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path)

    def __getstate__(self):
        # Only the path is pickled; the maps are re-opened on unpickle.
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path):
        self._path = path
        self._index = self.Index(index_file_path(self._path))

        _warmup_mmap_file(data_file_path(self._path))
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        ptr, size = self._index[i]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )
        if self._index.dtype != np.int64:
            np_array = np_array.astype(np.int64)

        return torch.from_numpy(np_array)

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )


# NOTE(review): the ``def`` line below was truncated in the copy under review
# (only ``-> str:`` and the body survived); the name and the single ``path``
# parameter are restored from the body's usage. Confirm against version
# control.
def get_indexed_dataset_to_local(path) -> str:
    """Fetch the index and data files for ``path`` and return the local prefix.

    Returns:
        The local path prefix shared by the ``.idx`` and ``.bin`` files.
    """
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{local_index_path} and {local_data_path}"
    )

    local_path = local_data_path[:-4]  # stripping suffix ".bin"
    assert local_path == local_index_path[:-4]  # stripping suffix ".idx"
    return local_path


class MMapIndexedDatasetBuilder:
    """Writes items to a ``.bin`` file and, on ``finalize``, the index that
    ``MMapIndexedDataset.Index`` reads."""

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []

    def add_item(self, tensor):
        """Append ``tensor`` (converted to the builder dtype) to the data file."""
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def merge_file_(self, another_file):
        """Append the dataset with prefix ``another_file`` to this builder."""
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        self._sizes.extend(index.sizes)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        """Close the data file and write the index to ``index_file``."""
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes)