MoE_ECGFormer / data /dataloader.py
splendor1811's picture
Upload 14 files
bcf646b verified
import torch
from torch.utils.data import DataLoader
from data.dataset import load_ECG_Dataset
import os
import numpy as np
import math
def get_class_weight(labels_dict):
    """Compute log-scaled class weights from per-class sample counts.

    Each class gets ``ln(max_count / count)`` floored at 1.0, so the most
    frequent classes receive weight 1.0 and rarer classes receive larger
    weights.

    Args:
        labels_dict: Mapping of class label -> number of samples.

    Returns:
        A dict mapping each class label to its float weight (always >= 1.0).
        An empty mapping yields an empty dict.

    Raises:
        ZeroDivisionError: If any class has a count of zero.
    """
    if not labels_dict:
        # Nothing to weight; avoids max() raising on an empty sequence.
        return {}

    total = sum(labels_dict.values())
    max_num = max(labels_dict.values())
    mu = 1.0 / (total / max_num)
    # Algebraically mu * total == max_num; the original arithmetic is kept so
    # results stay bit-for-bit identical. Hoisted because it is loop-invariant.
    scaled_total = mu * total

    return {
        key: max(math.log(scaled_total / float(value)), 1.0)
        for key, value in labels_dict.items()
    }
class ECGDataloader:
    """Build train / validation / test ``DataLoader`` objects from saved splits.

    The splits are read from ``<data_path>/<data_type>/train.pt``,
    ``val.pt`` and ``test.pt``. Each file is loaded with ``torch.load`` and
    wrapped with ``load_ECG_Dataset`` before being handed to a ``DataLoader``.
    """

    testdata_path: str
    traindata_path: str
    validdata_path: str

    def __init__(self, data_path, data_type, hparams):
        """Resolve the split file paths and read the batch size.

        Args:
            data_path: Root directory containing the datasets.
            data_type: Sub-directory of ``data_path`` holding the split files.
            hparams: Mapping that must contain a ``'batch_size'`` entry.
        """
        split_dir = os.path.join(data_path, data_type)
        self.traindata_path = os.path.join(split_dir, 'train.pt')
        self.testdata_path = os.path.join(split_dir, 'test.pt')
        self.validdata_path = os.path.join(split_dir, 'val.pt')
        self.batch_size = hparams['batch_size']

    @staticmethod
    def _load_split(path):
        """Load one saved split from ``path`` and wrap it as a dataset."""
        # NOTE(security): torch.load unpickles the file, so only point this
        # at split files that come from a trusted source.
        return load_ECG_Dataset(torch.load(path))

    @staticmethod
    def _class_counts(labels):
        """Return ``{label: occurrences}`` for the labels actually present.

        Counting via ``np.unique`` is a single pass and stays correct when the
        label ids are not the contiguous range ``0..k-1``.
        """
        values, counts = np.unique(labels, return_counts=True)
        return {int(value): int(count) for value, count in zip(values, counts)}

    def train_dataloader(self):
        """Create the shuffled training loader and its class weights.

        Returns:
            A ``(train_loader, class_weight)`` tuple, where ``class_weight``
            is the dict produced by ``get_class_weight`` from the per-class
            sample counts of the training labels (``y_data``).
        """
        train_dataset = self._load_split(self.traindata_path)
        label_counts = self._class_counts(train_dataset.y_data.numpy())
        train_loader = DataLoader(dataset=train_dataset, batch_size=self.batch_size,
                                  shuffle=True, drop_last=True, num_workers=4)
        return train_loader, get_class_weight(label_counts)

    def test_dataloader(self):
        """Create the test loader (original order, no samples dropped)."""
        test_dataset = self._load_split(self.testdata_path)
        return DataLoader(dataset=test_dataset, batch_size=self.batch_size,
                          shuffle=False, drop_last=False, num_workers=4)

    def valid_dataloader(self):
        """Create the validation loader (no samples dropped)."""
        valid_dataset = self._load_split(self.validdata_path)
        # NOTE(review): validation batches are shuffled, unlike the test
        # loader; kept as-is to preserve behaviour -- confirm it is intended.
        return DataLoader(dataset=valid_dataset, batch_size=self.batch_size,
                          shuffle=True, drop_last=False, num_workers=4)