"""DataLoader factory and class-weight helper for the ECG datasets."""

import math
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

from data.dataset import load_ECG_Dataset


def get_class_weight(labels_dict):
    """Compute a log-scaled weight for every class from its sample count.

    For each class the weight is ``log(mu * total / count)`` where
    ``mu = max_count / total``; this is ``ln(max_count / count)`` up to
    floating-point rounding. Weights are floored at 1.0, so the most
    frequent class (and any class whose score is not above 1.0) gets 1.0.

    Args:
        labels_dict: Mapping of class label -> number of samples.

    Returns:
        dict mapping each key of ``labels_dict`` to its float weight.

    Raises:
        ValueError: If ``labels_dict`` is empty.
        ZeroDivisionError: If any class has a count of zero.
    """
    total = sum(labels_dict.values())
    max_num = max(labels_dict.values())
    mu = 1.0 / (total / max_num)

    class_weight = {}
    for key, value in labels_dict.items():
        score = math.log(mu * total / float(value))
        # Never down-weight a class below 1.0.
        class_weight[key] = max(score, 1.0)
    return class_weight


class ECGDataloader:
    """Build train/test/validation ``DataLoader`` objects from ``.pt`` files.

    The files are expected at ``<data_path>/<data_type>/{train,test,val}.pt``.
    """

    testdata_path: str
    traindata_path: str
    validdata_path: str

    # Number of worker processes used by every DataLoader built here.
    _NUM_WORKERS = 4

    def __init__(self, data_path, data_type, hparams):
        """Store the split file paths and the batch size.

        Args:
            data_path: Root directory that contains the dataset folders.
            data_type: Sub-directory of ``data_path`` holding the split files.
            hparams: Subscriptable hyper-parameter container; only
                ``hparams['batch_size']`` is read.
        """
        split_dir = os.path.join(data_path, data_type)
        self.traindata_path = os.path.join(split_dir, 'train.pt')
        self.testdata_path = os.path.join(split_dir, 'test.pt')
        self.validdata_path = os.path.join(split_dir, 'val.pt')
        self.batch_size = hparams['batch_size']

    @staticmethod
    def _load_dataset(path):
        """Load a serialized split from ``path`` and wrap it as a dataset."""
        # NOTE(review): torch.load unpickles the file; only use trusted data.
        return load_ECG_Dataset(torch.load(path))

    def train_dataloader(self):
        """Build the training loader and the per-class weights.

        Returns:
            Tuple ``(train_loader, class_weight)`` where ``class_weight`` is
            the dict produced by ``get_class_weight`` from the label counts
            of the training split.
        """
        train_dataset = self._load_dataset(self.traindata_path)

        # Count every label that is actually present in one pass. This also
        # works when labels are not the contiguous range 0..k-1, which would
        # otherwise yield zero counts and a division by zero downstream.
        labels, counts = np.unique(
            train_dataset.y_data.numpy(), return_counts=True
        )
        cw_dict = {
            int(label): int(count) for label, count in zip(labels, counts)
        }

        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self._NUM_WORKERS,
        )
        return train_loader, get_class_weight(cw_dict)

    def test_dataloader(self):
        """Build the test loader (no shuffling, keeps the last partial batch)."""
        test_dataset = self._load_dataset(self.testdata_path)
        test_loader = DataLoader(
            dataset=test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self._NUM_WORKERS,
        )
        return test_loader

    def valid_dataloader(self):
        """Build the validation loader (shuffled, keeps the last partial batch)."""
        valid_dataset = self._load_dataset(self.validdata_path)
        # NOTE(review): shuffle=True is kept from the original; confirm that
        # shuffling the validation split is intended.
        valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self._NUM_WORKERS,
        )
        return valid_loader