# MoE_ECGFormer/data/dataset.py
# Source: Hugging Face upload by splendor1811 ("Upload 14 files", commit bcf646b, verified).
import torch
from torch.utils.data import Dataset
import numpy as np
class load_ECG_Dataset(Dataset):
    """Map-style PyTorch dataset over pre-loaded ECG samples and optional labels.

    ``dataset`` is a mapping with a required ``"samples"`` entry and an
    optional ``"labels"`` entry. Each entry may be a NumPy array, a torch
    tensor, or anything else ``torch.as_tensor`` accepts (e.g. nested lists).
    Samples are stored as ``float32`` and labels as ``int64``. Items are
    indexed along the first axis of the samples.
    """

    def __init__(self, dataset: dict) -> None:
        """Convert and store the samples/labels held in ``dataset``.

        Args:
            dataset: Mapping with key ``"samples"`` and, optionally,
                ``"labels"`` (one label per sample along the first axis).

        Raises:
            KeyError: if ``"samples"`` is missing from ``dataset``.
            ValueError: if labels are provided but their count differs from
                the number of samples.
        """
        # torch.as_tensor shares memory with NumPy arrays (as torch.from_numpy
        # did) and also accepts existing tensors and plain Python sequences.
        x_data = torch.as_tensor(dataset["samples"])

        # Labels are optional (e.g. for unlabeled inference data).
        y_data = dataset.get("labels")
        if y_data is not None:
            y_data = torch.as_tensor(y_data)
            # Fail early instead of hitting an IndexError mid-epoch or
            # silently ignoring extra labels.
            if y_data.shape[0] != x_data.shape[0]:
                raise ValueError(
                    f"number of labels ({y_data.shape[0]}) does not match "
                    f"number of samples ({x_data.shape[0]})"
                )

        self.x_data = x_data.float()
        self.y_data = y_data.long() if y_data is not None else None
        self.len = x_data.shape[0]

    def get_labels(self) -> "torch.Tensor | None":
        """Return the full ``int64`` label tensor, or ``None`` if unlabeled."""
        return self.y_data

    def __getitem__(self, idx: int) -> dict:
        """Return the item at ``idx`` as a dict.

        The dict always contains ``'samples'``: the sample tensor with its
        last dimension removed if that dimension has size 1 (presumably a
        trailing channel axis - confirm against the data pipeline). It
        contains ``'labels'`` (a Python ``int``) only when the dataset was
        built with labels.
        """
        sample = {"samples": self.x_data[idx].squeeze(-1)}
        # Previously this always indexed y_data and crashed with a TypeError
        # for datasets constructed without labels.
        if self.y_data is not None:
            sample["labels"] = int(self.y_data[idx])
        return sample

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.len