import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset

from ..utils.generic import ModelOutput


class PipelineDataset(Dataset):
    def __init__(self, dataset, process, params):
        self.dataset = dataset
        self.process = process
        self.params = params

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        item = self.dataset[i]
        processed = self.process(item, **self.params)
        return processed


class PipelineIterator(IterableDataset):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                If specified, the items of `loader` are supposed to come as batches, and are unrolled into single
                items here, making it roughly behave as

        ```
        for items in loader:
            for i in range(loader_batch_size):
                item = items[i]
                yield infer(item, **params)
        ```"""
        self.loader = loader
        self.infer = infer
        self.params = params
        if loader_batch_size == 1:
            # Let's spare some time by deactivating altogether
            loader_batch_size = None
        self.loader_batch_size = loader_batch_size

        # Internal bookkeeping
        self._loader_batch_index = None
        self._loader_batch_data = None

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def loader_batch_item(self):
        """
        Return the item located at `loader_batch_index` within the current `loader_batch_data`.
        """
        if isinstance(self._loader_batch_data, torch.Tensor):
            # Batch data is a simple tensor, just fetch the slice
            result = self._loader_batch_data[self._loader_batch_index]
        else:
            # Batch data is assumed to be BaseModelOutput (or dict)
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to tuple first
                    element = element.to_tuple()
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
                    # Those are stored as tuples of tensors, so they need specific unbatching.
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if element is None:
                    # This can happen for optional data that get passed around
                    loader_batched[k] = None
                elif isinstance(element[self._loader_batch_index], torch.Tensor):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                elif isinstance(element[self._loader_batch_index], np.ndarray):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
                else:
                    # This is typically a list, so no need to `unsqueeze`.
                    loader_batched[k] = element[self._loader_batch_index]
            # Recreate the element by reusing the original class to make it look like batch_size=1
            result = self._loader_batch_data.__class__(loader_batched)
        self._loader_batch_index += 1
        return result

    def __next__(self):
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            # We are currently unrolling a batch, so we just need to return
            # the current item within that batch
            return self.loader_batch_item()

        # We're out of items within a batch
        item = next(self.iterator)
        processed = self.infer(item, **self.params)
        # We now have a batch of "inferred things".
        if self.loader_batch_size is not None:
            # Try to infer the size of the batch
            if isinstance(processed, torch.Tensor):
                first_tensor = processed
            else:
                key = list(processed.keys())[0]
                first_tensor = processed[key]
            if isinstance(first_tensor, list):
                observed_batch_size = len(first_tensor)
            else:
                observed_batch_size = first_tensor.shape[0]
            if 0 < observed_batch_size < self.loader_batch_size:
                # This could be the last batch, so we can't unroll as many elements.
                self.loader_batch_size = observed_batch_size
            # Set the internal index to unwrap the batch
            self._loader_batch_data = processed
            self._loader_batch_index = 0
            return self.loader_batch_item()
        else:
            # We're not unrolling batches
            return processed
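# Illustrative sketch (an assumption made for this document, not part of the
# module's API): a hypothetical `_demo_infer` and an in-memory loader showing
# how `PipelineIterator` unrolls batched outputs into batch_size=1 items, and
# how it shrinks `loader_batch_size` on a final incomplete batch.
def _demo_pipeline_iterator():
    from torch.utils.data import DataLoader

    def _demo_infer(batch, **params):
        # Stand-in "inference": returns a dict of batched tensors, shape (batch, 1).
        return {"logits": batch * 2}

    loader = DataLoader(torch.arange(10, dtype=torch.float32).unsqueeze(1), batch_size=4)
    iterator = PipelineIterator(loader, _demo_infer, {}, loader_batch_size=4)
    # Yields 10 items even though the loader produces 3 batches (4 + 4 + 2);
    # each item is a dict whose "logits" tensor has shape (1, 1).
    return [item["logits"].shape for item in iterator]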
class PipelineChunkIterator(PipelineIterator):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for iterator in loader:
            for item in iterator:
                yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or any iterator):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
        """
        super().__init__(loader, infer, params)

    def __iter__(self):
        self.iterator = iter(self.loader)
        self.subiterator = None
        return self

    def __next__(self):
        if self.subiterator is None:
            # A `None` subiterator means we haven't started a `preprocess` iterator, so start it.
            self.subiterator = self.infer(next(self.iterator), **self.params)
        try:
            # Try to return the next item
            processed = next(self.subiterator)
        except StopIteration:
            # When a preprocess iterator ends, we can start looking at the next item.
            # ChunkIterator will keep feeding until ALL elements of iterator
            # have created their subiterator and have been iterated on.
            #
            # Another way to look at it is that we're basically flattening lists of lists
            # into a single list, but with generators.
            self.subiterator = self.infer(next(self.iterator), **self.params)
            processed = next(self.subiterator)
        return processed
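# Illustrative sketch (hypothetical names, not part of the module's API):
# `PipelineChunkIterator` flattens a generator-per-item `infer` into a single
# stream, as if the lists of chunks were concatenated.
def _demo_pipeline_chunk_iterator():
    def _demo_chunk_infer(item, **params):
        # A stand-in for a chunking `preprocess`: yields each chunk of one item.
        yield from item

    chunked_inputs = [[1, 2], [3], [4, 5, 6]]
    iterator = PipelineChunkIterator(chunked_inputs, _demo_chunk_infer, {})
    return list(iterator)  # [1, 2, 3, 4, 5, 6]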
class PipelinePackIterator(PipelineIterator):
    """
    Roughly equivalent to

    ```
    packed = []
    for item in loader:
        packed.append(item)
        if item["is_last"]:
            yield packed
            packed = []
    ```

    but it also handles cases where `item` is batched (meaning it's a dict of Tensors with first dimension > 1). In
    that case it does

    ```
    packed = []
    for batch in loader:
        # item is batched
        for item in batch:
            packed.append(item)
            if item["is_last"]:
                yield packed
                packed = []
    ```

    Arguments:
        loader (`torch.utils.data.DataLoader` or any iterator):
            The iterator that will be used to apply `infer` on.
        infer (any function):
            The function to apply to each element of `loader`.
        params (`dict`):
            The parameters passed to `infer` along with every item.
        loader_batch_size (`int`, *optional*):
            If specified, the items of `loader` are supposed to come as batches, and are unrolled into single items
            here, making it roughly behave as

    ```
    for items in loader:
        for i in range(loader_batch_size):
            item = items[i]
            yield infer(item, **params)
    ```"""

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def __next__(self):
        # Extremely similar to PipelineIterator in its unpacking mechanism,
        # BUT we have an extra required key, which is the presence of `is_last`.
        # Because everything is flattened by `PipelineChunkIterator`, we need
        # to keep track of the original `process` boundaries here, so that
        # `process` and `postprocess` see the same data.

        # This iterator accumulates items (possibly while unbatching) until it
        # hits an item marked `is_last`, then passes the accumulator on to the caller.
        is_last = False
        accumulator = []
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            while self._loader_batch_index < self.loader_batch_size:
                item = self.loader_batch_item()
                is_last = item.pop("is_last")
                accumulator.append(item)
                if is_last:
                    return accumulator

        while not is_last:
            processed = self.infer(next(self.iterator), **self.params)
            if self.loader_batch_size is not None:
                if isinstance(processed, torch.Tensor):
                    first_tensor = processed
                else:
                    key = list(processed.keys())[0]
                    first_tensor = processed[key]
                if isinstance(first_tensor, list):
                    observed_batch_size = len(first_tensor)
                else:
                    observed_batch_size = first_tensor.shape[0]
                if 0 < observed_batch_size < self.loader_batch_size:
                    # This could be the last batch, so we can't unroll as many elements.
                    self.loader_batch_size = observed_batch_size
                self._loader_batch_data = processed
                self._loader_batch_index = 0
                while self._loader_batch_index < self.loader_batch_size:
                    item = self.loader_batch_item()
                    is_last = item.pop("is_last")
                    accumulator.append(item)
                    if is_last:
                        return accumulator
            else:
                item = processed
                is_last = item.pop("is_last")
                accumulator.append(item)
        return accumulator


class KeyDataset(Dataset):
    def __init__(self, dataset: Dataset, key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i][self.key]


class KeyPairDataset(Dataset):
    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}
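# Illustrative sketch (hypothetical data, not part of the module's API):
# `KeyDataset` and `KeyPairDataset` expose one column (or a pair of columns)
# of a dict-based dataset; a plain list of dicts works here by duck typing.
def _demo_key_datasets():
    data = [
        {"premise": "A cat sits.", "hypothesis": "An animal sits."},
        {"premise": "It rains.", "hypothesis": "The sun shines."},
    ]
    texts = KeyDataset(data, "premise")
    pairs = KeyPairDataset(data, "premise", "hypothesis")
    assert texts[0] == "A cat sits."
    assert pairs[1] == {"text": "It rains.", "text_pair": "The sun shines."}
    return texts, pairs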