|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
import numpy as np |
|
|
|
|
|
class load_ECG_Dataset(Dataset):
    """Map-style dataset over a dict of ECG samples and optional labels.

    ``dataset`` must provide a ``"samples"`` entry and may provide a
    ``"labels"`` entry. Each entry may be a ``numpy.ndarray``, a
    ``torch.Tensor``, or any nested sequence accepted by
    ``torch.as_tensor``. Samples are stored as ``float32`` and labels as
    ``int64``; the first axis of both indexes individual records.
    """

    def __init__(self, dataset):
        """Convert and store the samples and (optional) labels.

        Args:
            dataset: Mapping with a required ``"samples"`` key and an
                optional ``"labels"`` key (read via ``.get``).

        Raises:
            KeyError: if ``"samples"`` is missing from ``dataset``.
            ValueError: if labels are present but their first-axis length
                differs from the number of samples.
        """
        super().__init__()

        # as_tensor shares memory with numpy arrays (like from_numpy) but
        # also accepts existing tensors and plain Python sequences.
        x_data = torch.as_tensor(dataset["samples"])

        y_data = dataset.get("labels")
        if y_data is not None:
            y_data = torch.as_tensor(y_data)
            # Catch misaligned data at construction time instead of an
            # IndexError (or silently ignored labels) during iteration.
            if y_data.shape[:1] != x_data.shape[:1]:
                raise ValueError(
                    f"labels length {tuple(y_data.shape[:1])} does not match "
                    f"samples length {tuple(x_data.shape[:1])}"
                )

        self.x_data = x_data.float()
        self.y_data = y_data.long() if y_data is not None else None
        self.len = x_data.shape[0]

    def get_labels(self):
        """Return the ``int64`` label tensor, or ``None`` if unlabeled."""
        return self.y_data

    def __getitem__(self, idx):
        """Return the record at ``idx`` as a dict.

        ``"samples"`` holds the sample tensor with a trailing size-1
        dimension (if any) squeezed away. ``"labels"`` holds the label as a
        Python ``int`` and is present only when the dataset has labels.
        """
        sample = {'samples': self.x_data[idx].squeeze(-1)}
        # Unlabeled datasets are allowed by __init__, so only attach a label
        # when one exists rather than indexing into None.
        if self.y_data is not None:
            sample['labels'] = int(self.y_data[idx])
        return sample

    def __len__(self):
        """Return the number of records (first-axis length of samples)."""
        return self.len
|
|
|
|