|
import torch |
|
from torch.utils.data import DataLoader |
|
from data.dataset import load_ECG_Dataset |
|
|
|
import os |
|
import numpy as np |
|
import math |
|
|
|
|
|
def get_class_weight(labels_dict):
    """Compute log-scaled class weights from per-class sample counts.

    Every class receives ``max(1.0, ln(largest_count / count))``: classes
    that are at least ``1/e`` as frequent as the largest class keep the
    neutral weight 1.0, rarer classes get a weight that grows
    logarithmically with their rarity.

    Args:
        labels_dict: Mapping from class label to the number of samples of
            that class.

    Returns:
        A dict with the same keys as ``labels_dict`` whose values are float
        weights, each >= 1.0. An empty mapping yields an empty dict.
    """
    # max() on an empty sequence raises ValueError; nothing to weight anyway.
    if not labels_dict:
        return {}

    max_num = max(labels_dict.values())

    class_weight = {}
    for key, value in labels_dict.items():
        if value <= 0:
            # A class without samples would divide by zero (or hit the log
            # domain error); give it the neutral floor weight instead.
            class_weight[key] = 1.0
            continue
        # Ratio of the largest class to this class; always >= 1 here, so the
        # logarithm is >= 0 and is then floored at 1.0.
        score = math.log(max_num / float(value))
        class_weight[key] = max(score, 1.0)
    return class_weight
|
|
|
|
|
class ECGDataloader:
    """Build train / test / validation ``DataLoader`` objects for ECG data.

    The three serialized splits are looked up as ``train.pt``, ``test.pt``
    and ``val.pt`` inside ``<data_path>/<data_type>/``. Nothing is read from
    disk at construction time; each ``*_dataloader`` method loads its file
    when it is called.
    """

    testdata_path: str
    traindata_path: str
    # Matches the attribute actually assigned in __init__.
    validdata_path: str

    def __init__(self, data_path, data_type, hparams):
        """Record the split file locations and the batch size.

        Args:
            data_path: Root directory containing the per-type sub-folders.
            data_type: Name of the sub-folder holding the ``.pt`` files.
            hparams: Mapping that provides the ``'batch_size'`` entry.
        """
        self.traindata_path = os.path.join(data_path, data_type, 'train.pt')
        self.testdata_path = os.path.join(data_path, data_type, 'test.pt')
        self.validdata_path = os.path.join(data_path, data_type, 'val.pt')
        self.batch_size = hparams['batch_size']

    @staticmethod
    def _load_dataset(path):
        """Load the serialized split at *path* and wrap it as an ECG dataset."""
        # NOTE: torch.load relies on pickle; only point this at trusted files.
        return load_ECG_Dataset(torch.load(path))

    def _build_loader(self, dataset, shuffle, drop_last):
        """Create a ``DataLoader`` over *dataset* with the configured batch size."""
        return DataLoader(dataset=dataset, batch_size=self.batch_size,
                          shuffle=shuffle, drop_last=drop_last, num_workers=4)

    @staticmethod
    def _count_labels(labels):
        """Return ``{label: occurrences}`` for each distinct value in *labels*."""
        # One pass over the labels; works for any set of label values, not
        # only the contiguous range 0..k-1.
        # Assumes labels are integer class indices -- TODO confirm.
        values, counts = np.unique(labels, return_counts=True)
        return {int(value): int(count) for value, count in zip(values, counts)}

    def train_dataloader(self):
        """Return the shuffled training loader and the class-weight dict.

        Returns:
            A ``(loader, class_weight)`` tuple, where ``class_weight`` is the
            result of ``get_class_weight`` applied to the label counts of the
            training set.
        """
        train_dataset = self._load_dataset(self.traindata_path)
        label_counts = self._count_labels(train_dataset.y_data.numpy())
        # drop_last=True: the final incomplete batch is discarded.
        train_loader = self._build_loader(train_dataset, shuffle=True,
                                          drop_last=True)
        return train_loader, get_class_weight(label_counts)

    def test_dataloader(self):
        """Return the test loader (fixed order, all samples kept)."""
        test_dataset = self._load_dataset(self.testdata_path)
        return self._build_loader(test_dataset, shuffle=False, drop_last=False)

    def valid_dataloader(self):
        """Return the validation loader (shuffled, all samples kept)."""
        valid_dataset = self._load_dataset(self.validdata_path)
        # NOTE(review): shuffle=True is kept from the existing behaviour;
        # confirm whether validation really needs shuffling.
        return self._build_loader(valid_dataset, shuffle=True, drop_last=False)
|
|