|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
from . import BaseWrapperDataset |
|
|
|
|
|
class PrependDataset(BaseWrapperDataset):
    """Wrapper dataset that overwrites the first element of each source item.

    Despite the name, nothing is inserted: position 0 of the source sequence
    is replaced with the value returned by ``prepend_getter``, so the length
    of the sequence is unchanged.
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        """Store the wrapped dataset and the replacement configuration.

        Args:
            dataset: Dataset to wrap; forwarded to the base-class constructor.
            prepend_getter: Callable invoked as ``prepend_getter(dataset, idx)``
                whose result is written to position 0 of the source. The
                result is asserted to be an ``int``.
            ensure_first_token_is: When not ``None``, the value that the first
                element of every source is asserted to equal before it is
                overwritten.
        """
        super().__init__(dataset)
        self.ensure_first_token = ensure_first_token_is
        self.prepend_getter = prepend_getter

    def __getitem__(self, idx):
        """Return item ``idx`` with the first element of its source replaced.

        The wrapped item may be either the source sequence itself or a tuple
        whose first member is the source sequence; the same shape is returned.

        Raises:
            AssertionError: If ``ensure_first_token`` is set and the current
                first element differs from it, or if ``prepend_getter`` does
                not return an ``int``. Both checks are plain ``assert``
                statements and are therefore skipped under ``python -O``.
        """
        # NOTE(review): ``self.dataset`` is presumably set by
        # ``BaseWrapperDataset.__init__`` -- confirm against the base class.
        sample = self.dataset[idx]
        wrapped_in_tuple = isinstance(sample, tuple)
        if wrapped_in_tuple:
            source = sample[0]
        else:
            source = sample

        expected_first = self.ensure_first_token
        assert expected_first is None or source[0] == expected_first

        replacement = self.prepend_getter(self.dataset, idx)
        assert isinstance(replacement, int)

        # The write goes into the object handed back by the wrapped dataset;
        # no copy is made before position 0 is overwritten.
        source[0] = replacement

        if not wrapped_in_tuple:
            return source
        # Rebuild the tuple with the (mutated) source followed by the
        # untouched remaining members.
        return (source,) + sample[1:]
|
|