Agusbs98's picture
Upload 37 files
bb18256
raw
history blame contribute delete
No virus
1.16 kB
import os, sys
from libs import *
class ECGDataset(torch.utils.data.Dataset):
    """Map-style dataset of ECG recordings stored as per-record ``.npy`` files.

    A metadata CSV (``df_path``) lists the records. Each row's ``"id"`` value
    names the file ``<data_path>/<id>.npy`` that ``__getitem__`` loads.

    Args:
        df_path: Path of the CSV read with ``pandas.read_csv``. It must
            contain an ``"id"`` column.
        data_path: Directory holding the ``<id>.npy`` files.
        config: Mapping that must provide two keys:
            ``"ecg_leads"``, a selector applied to the first axis of the
            loaded array; and ``"ecg_length"``, the target length passed
            to ``pad_sequences``.
        augment: When true, ``__getitem__`` passes each sample through
            ``self.drop_lead``.
    """

    def __init__(
        self,
        df_path,
        data_path,
        config,
        augment=False,
    ):
        self.df_path = df_path
        self.data_path = data_path
        self.df = pandas.read_csv(self.df_path)
        self.config = config
        self.augment = augment

    def __len__(self):
        """Return the number of rows in the metadata CSV."""
        return len(self.df)

    def __getitem__(self, index):
        """Load, lead-select and pad the recording for metadata row ``index``.

        Returns:
            A float32 ``torch`` tensor built from the padded array.

        Raises:
            FileNotFoundError: if ``<data_path>/<id>.npy`` does not exist.
        """
        row = self.df.iloc[index]
        ecg_file = "{}/{}.npy".format(self.data_path, row["id"])
        # Pass allow_pickle directly rather than temporarily monkey-patching
        # the global np.load. The previous approach had no try/finally, so a
        # failed load left np.load patched for the whole process. The global
        # swap was also visible to any other thread calling np.load.
        # SECURITY NOTE(review): allow_pickle=True lets a crafted .npy file
        # execute arbitrary code. Only use it with trusted data, and drop the
        # flag if the files hold plain numeric arrays.
        ecg = np.load(ecg_file, allow_pickle=True)[self.config["ecg_leads"], :]
        # pad_sequences comes from `libs`. It presumably takes Keras-style
        # positional args (maxlen, dtype, padding, truncating), i.e. it pads
        # or truncates at the end to ecg_length as float64.
        # TODO: confirm against libs.
        ecg = pad_sequences(
            ecg, self.config["ecg_length"], "float64",
            "post", "post",
        )
        if self.augment:
            # NOTE(review): drop_lead is not defined in this class, so
            # augment=True raises AttributeError unless a subclass or other
            # code provides it.
            ecg = self.drop_lead(ecg)
        return torch.tensor(ecg).float()