# cxrmate-ed / dataset.py
# Page header captured with this file: anicolson's picture | Update dataset.py | 4e95a0e verified
import itertools
from typing import List
import torch
from .utils import compute_time_delta
class PriorsDataset:
    """Thin wrapper that forwards indexing and attribute access to a dataset.

    The wrapped object must support ``dataset['study_id']``, ``len(dataset)``
    and ``dataset[idx]``; when indexed with a list of keys it must return a
    mapping of column name to a sequence of values (see ``__getitems__``).
    NOTE(review): presumably a Hugging Face ``datasets.Dataset`` -- confirm
    against the caller.

    Attributes:
        dataset: The wrapped dataset.
        history: Flag stored as given; when truthy, ``__getitem__`` raises
            ``NotImplementedError``.
        study_id_to_index: Maps each value of the ``'study_id'`` column to its
            position in that column.
        time_delta_map: Callable stored as given.
        inf_time_delta_value: Result of ``time_delta_map(float('inf'))``.
    """

    def __init__(self, dataset, history, time_delta_map):
        """Store the wrapped dataset and precompute the lookup values.

        Args:
            dataset: Object to wrap (see the class docstring for what it must
                support).
            history: Truthy to request priors, which this class does not
                provide; every ``__getitem__`` call will then raise.
            time_delta_map: Callable that accepts ``float('inf')``.
        """
        self.dataset = dataset
        self.history = history
        self.study_id_to_index = dict(zip(dataset['study_id'], range(len(dataset))))
        self.time_delta_map = time_delta_map
        self.inf_time_delta_value = time_delta_map(float('inf'))

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]``.

        Raises:
            NotImplementedError: If ``self.history`` is truthy. The wrapped
                dataset is indexed first, so an indexing error from it takes
                precedence.
        """
        batch = self.dataset[idx]
        if self.history:
            raise NotImplementedError("Priors were made not available in the public release.")
        return batch

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getattr__(self, name):
        """Forward attributes that are not found on this wrapper to the dataset.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: If the wrapped dataset has not been set on this
                instance yet, or if it does not have ``name`` either.
        """
        # Read the wrapped object straight from the instance dict. Going
        # through ``self.dataset`` would re-enter ``__getattr__`` whenever the
        # attribute is missing (e.g. pickle/copy probing ``__setstate__`` on a
        # freshly created, still-empty instance) and end in RecursionError.
        try:
            wrapped = self.__dict__['dataset']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def __getitems__(self, keys: List):
        """Fetch several examples at once and return them as a list of rows.

        ``self[keys]`` is expected to return a mapping of column name to a
        sequence with one entry per requested example; that column-wise batch
        is transposed into one ``{column: value}`` dict per example.
        NOTE(review): ``__getitems__`` looks like the batched-fetch hook used
        by PyTorch's ``DataLoader`` -- confirm against the torch version in use.

        Args:
            keys: Indices passed unchanged to ``__getitem__``.

        Returns:
            A list of per-example dicts; empty if the batch has no columns.
        """
        batch = self.__getitem__(keys)
        if not batch:
            # No columns: nothing to transpose (``next(iter(batch))`` would
            # otherwise leak StopIteration to the caller).
            return []
        n_examples = len(batch[next(iter(batch))])
        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]