File size: 1,036 Bytes
9691248
 
6f7f115
 
 
9691248
6f7f115
 
9691248
 
 
 
 
 
 
6f7f115
9691248
 
 
 
453bf0e
4e95a0e
9691248
 
6f7f115
9691248
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import itertools
from typing import List

import torch

from .utils import compute_time_delta


class PriorsDataset:
    """Thin wrapper around a dataset that adds prior-study bookkeeping.

    Indexing, ``len`` and unknown attribute lookups are forwarded to the
    wrapped ``dataset``. Requesting priors (``history`` truthy) is not
    supported in this release and makes ``__getitem__`` raise
    ``NotImplementedError``.

    Args:
        dataset: Wrapped dataset. Must support ``len(dataset)``,
            ``dataset['study_id']`` (an iterable of study ids) and
            ``dataset[idx]``.
        history: When truthy, ``__getitem__`` raises ``NotImplementedError``.
        time_delta_map: Callable applied to time deltas; it is called once
            here with ``float('inf')`` and the result is cached.
    """

    def __init__(self, dataset, history, time_delta_map):
        self.dataset = dataset
        self.history = history
        # study id -> position; assumes the 'study_id' column is in row order.
        self.study_id_to_index = dict(zip(dataset['study_id'], range(len(dataset))))
        self.time_delta_map = time_delta_map
        # Mapped value for an infinite delta (presumably "no prior" — TODO confirm).
        self.inf_time_delta_value = time_delta_map(float('inf'))

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]`` unchanged.

        Raises:
            NotImplementedError: if ``history`` is truthy.
        """
        batch = self.dataset[idx]

        if self.history:
            raise NotImplementedError("Priors were made not available in the public release.")

        return batch

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getattr__(self, name):
        """Forward attribute lookups that fail normally to the wrapped dataset.

        ``__getattr__`` only runs after regular lookup fails. Reading
        ``self.dataset`` here would re-enter ``__getattr__`` whenever the
        instance has no ``dataset`` attribute yet (e.g. on an object
        allocated via ``__new__`` while ``pickle``/``copy`` probe for
        ``__setstate__``), recursing until ``RecursionError``. Read it from
        ``__dict__`` directly and report a plain ``AttributeError`` instead.
        """
        try:
            dataset = self.__dict__['dataset']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(dataset, name)

    def __getitems__(self, keys: List):
        """Fetch several examples at once and return them as a list of dicts.

        ``self[keys]`` is expected to return a column-oriented mapping
        (column name -> indexable sequence, one entry per example); it is
        transposed into one ``{column: value}`` dict per example.
        """
        batch = self.__getitem__(keys)
        try:
            first_col = next(iter(batch))
        except StopIteration:
            # No columns to transpose: emit one empty example per key rather
            # than leaking StopIteration to the caller.
            return [{} for _ in keys]
        n_examples = len(batch[first_col])
        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]