import torch
from torch.utils.data import Dataset
import numpy as np


class load_ECG_Dataset(Dataset):
    """Map-style dataset built from a dict of samples and optional labels.

    The constructor reads ``dataset["samples"]`` (required) and
    ``dataset.get("labels")`` (optional). Samples are stored as float32
    tensors and labels, when present, as int64 tensors.
    """

    def __init__(self, dataset):
        """Convert the raw arrays to tensors and cache them on the instance.

        Args:
            dataset: Mapping with a ``"samples"`` entry and, optionally, a
                ``"labels"`` entry. Each may be a NumPy array, a torch
                tensor, or anything ``torch.as_tensor`` accepts. The first
                axis of ``"samples"`` is taken as the number of items.

        Raises:
            KeyError: If ``"samples"`` is missing from ``dataset``.
        """
        # as_tensor shares memory with NumPy arrays (like from_numpy) and
        # passes existing tensors through without copying.
        x_data = torch.as_tensor(dataset["samples"])

        y_data = dataset.get("labels")
        if y_data is not None:
            y_data = torch.as_tensor(y_data)

        self.x_data = x_data.float()
        self.y_data = y_data.long() if y_data is not None else None
        self.len = x_data.shape[0]

    def get_labels(self):
        """Return the full label tensor, or ``None`` for unlabeled data."""
        return self.y_data

    def __getitem__(self, idx):
        """Return the item at ``idx`` as a dict.

        The dict always has a ``'samples'`` tensor (with a trailing axis of
        size 1 squeezed away, if there is one). A ``'labels'`` entry holding
        a Python ``int`` is included only when the dataset was built with
        labels.
        """
        sample = {'samples': self.x_data[idx].squeeze(-1)}
        # Unlabeled datasets are allowed by __init__; indexing a missing
        # label tensor used to raise TypeError here. The key is omitted
        # rather than set to None because PyTorch's default collate
        # function cannot batch None values.
        if self.y_data is not None:
            sample['labels'] = int(self.y_data[idx])
        return sample

    def __len__(self):
        """Return the number of items (size of the first samples axis)."""
        return self.len