File size: 1,036 Bytes
9691248 6f7f115 9691248 6f7f115 9691248 6f7f115 9691248 453bf0e 4e95a0e 9691248 6f7f115 9691248 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import itertools
from typing import List
import torch
from .utils import compute_time_delta
class PriorsDataset:
    """Thin wrapper around a column-indexable dataset that may add prior studies.

    Item access, ``len`` and unknown attribute lookups are forwarded to the
    wrapped ``dataset``. When ``history`` is truthy, fetching an item raises
    ``NotImplementedError`` because priors are not part of this release.

    Args:
        dataset: Wrapped dataset. It must support ``len(dataset)``,
            ``dataset['study_id']`` (a sequence of study ids) and
            ``dataset[idx]``. Presumably a Hugging Face ``datasets.Dataset``
            -- TODO confirm against callers.
        history: Truthy to request prior studies (unsupported here).
        time_delta_map: Callable applied to a time delta; it is called once
            with ``float('inf')`` at construction time.
    """

    def __init__(self, dataset, history, time_delta_map):
        self.dataset = dataset
        self.history = history
        # Study id -> row position; with duplicate ids the last row wins.
        self.study_id_to_index = {
            study_id: index for index, study_id in enumerate(dataset['study_id'])
        }
        self.time_delta_map = time_delta_map
        # Mapped value for an infinite time delta, computed once up front
        # (presumably the "no prior available" value -- confirm with callers).
        self.inf_time_delta_value = time_delta_map(float('inf'))

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]``, raising if prior studies were requested.

        Raises:
            NotImplementedError: If ``self.history`` is truthy.
        """
        batch = self.dataset[idx]
        if self.history:
            raise NotImplementedError("Priors were made not available in the public release.")
        return batch

    def __len__(self):
        """Return ``len`` of the wrapped dataset."""
        return len(self.dataset)

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the wrapped dataset.

        ``__getattr__`` only runs after normal lookup fails. If ``dataset``
        itself is missing (e.g. on an instance rebuilt by ``copy``/``pickle``
        before its state is restored), looking it up here would re-enter
        ``__getattr__`` and recurse until ``RecursionError``. Raising
        ``AttributeError`` instead keeps ``hasattr``/``getattr(..., default)``
        probes such as the ``__setstate__`` check working.
        """
        if name == 'dataset':
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'dataset'"
            )
        return getattr(self.dataset, name)

    def __getitems__(self, keys: List):
        """Fetch several rows at once and return them as a list of per-row dicts.

        ``self[keys]`` is expected to return a column-oriented mapping
        ``{column: sequence_of_values}``. It is transposed into one dict per row.
        """
        batch = self.__getitem__(keys)
        if not batch:
            # No columns at all: each requested row is an empty record
            # (previously this raised StopIteration from next(iter({}))).
            return [{} for _ in keys]
        n_examples = len(batch[next(iter(batch))])
        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
|