|
import itertools |
|
from typing import List |
|
|
|
import torch |
|
|
|
from .utils import compute_time_delta |
|
|
|
|
|
class PriorsDataset:
    """Thin wrapper around a column-oriented dataset that reserves support for prior studies.

    Item access is delegated to the wrapped ``dataset``, and any attribute not defined
    on this class is forwarded to it as well. When ``history`` is truthy, item access
    raises ``NotImplementedError`` because priors are not part of the public release.

    Args:
        dataset: The wrapped dataset. It must support ``len()``, column access by name
            (``dataset['study_id']``) and row access by index. ``__getitems__`` further
            expects ``dataset[list_of_indices]`` to return a mapping of column name to
            per-row values (presumably a Hugging Face ``datasets.Dataset`` -- TODO confirm).
        history: If truthy, prior-study history is requested, which is unavailable here.
        time_delta_map: Callable applied to a time delta. It is evaluated once at
            ``float('inf')`` and the result is cached in ``inf_time_delta_value``
            (presumably the encoding for "no prior study" -- verify against callers).
    """

    def __init__(self, dataset, history, time_delta_map):
        self.dataset = dataset
        self.history = history
        # Row position of each study id; if an id repeats, its last occurrence wins.
        self.study_id_to_index = {
            study_id: index for index, study_id in enumerate(dataset['study_id'])
        }
        self.time_delta_map = time_delta_map
        self.inf_time_delta_value = time_delta_map(float('inf'))

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]`` unchanged.

        Raises:
            NotImplementedError: If ``history`` is truthy; priors are not available
                in the public release.
        """
        # Fail fast: don't fetch rows that would be discarded anyway.
        if self.history:
            raise NotImplementedError("Priors were made not available in the public release.")
        return self.dataset[idx]

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getattr__(self, name):
        """Forward attribute lookups that fail on this object to the wrapped dataset.

        Python only calls this after normal lookup fails. Dunder names are never
        forwarded, so protocol probes (e.g. ``__setstate__``/``__getstate__`` by
        ``copy``/``pickle``) see this object's own behaviour. Lookups made before
        ``dataset`` has been set (e.g. on an instance created via ``__new__`` while
        unpickling) raise ``AttributeError`` instead of recursing forever through
        ``self.dataset``.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        try:
            # Read the instance dict directly: ``self.dataset`` would re-enter
            # __getattr__ when the attribute is missing.
            dataset = vars(self)['dataset']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(dataset, name)

    def __getitems__(self, keys: List):
        """Fetch several rows at once and return them as a list of per-row dicts.

        ``__getitem__`` is called once with the whole ``keys`` list. Its column-oriented
        result (column name -> values, one per key) is transposed into rows.
        Presumably this exists so torch's DataLoader can fetch a batch in one call.
        """
        columns = self.__getitem__(keys)
        if not columns:
            # No columns to transpose: each requested row is an empty record.
            # Previously next(iter(...)) leaked a StopIteration here.
            return [{} for _ in keys]
        n_examples = len(next(iter(columns.values())))
        return [
            {col: values[i] for col, values in columns.items()}
            for i in range(n_examples)
        ]
|
|