Spaces:
Build error
Build error
| import glob, json, os | |
| import torch | |
| import warnings | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
class BaseDataset2(Dataset):
    """Base template shared by the project's datasets.

    Wraps a feature array and a target array as float32 torch tensors and
    serves them through the ``torch.utils.data.Dataset`` protocol.
    """

    def __init__(self, x, y):
        """Store features and targets as float32 tensors.

        Args:
            x(ndarray): Input features; the first axis indexes samples.
            y(ndarray): Targets, indexed by the same sample index as ``x``.
        """
        self.data = torch.from_numpy(x).float()
        self.targets = torch.from_numpy(y).float()

        # Optional metadata slots; this base class leaves them empty/None.
        self.latents = None
        self.labels = None
        self.is_radial = []
        self.partition = True

    def __getitem__(self, index):
        """Return ``(features, target, index)`` for the sample at ``index``."""
        features = self.data[index]
        target = self.targets[index]
        return features, target, index

    def __len__(self):
        """Return the number of samples held in ``data``."""
        return len(self.data)

    def numpy(self, idx=None):
        """Return the dataset (or a subset of it) as NumPy arrays.

        Features are flattened to two dimensions: one row per sample.

        Args:
            idx(int, optional): Index or indices of the samples to return.
                When omitted, every sample is returned.

        Returns:
            tuple: ``(features, targets)`` as ndarrays.
        """
        flat = self.data.numpy().reshape((len(self), -1))
        if idx is not None:
            return flat[idx], self.targets[idx].numpy()
        return flat, self.targets.numpy()

    def get_latents(self):
        """Return the latent variables attached to the dataset.

        Returns:
            latents(ndarray): Latent variables per sample, or ``None`` when
            nothing has assigned them (the base class initialises ``None``).
        """
        return self.latents
def load_json(file_path):
    """Read and decode a JSON file.

    Args:
        file_path (str or os.PathLike): Path of the JSON file to read.

    Returns:
        The decoded JSON value (for example a list or dict).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's
    # locale-dependent default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Sensor channels, in feature order: AX1..AX4, AY1..AY4, ..., GZ1..GZ4.
_SENSOR_KEYS = tuple(
    f"{channel}{i}"
    for channel in ("AX", "AY", "AZ", "GX", "GY", "GZ")
    for i in range(1, 5)
)
# Scalar time-difference fields, appended after the sensor channels.
_TIME_KEYS = (
    'GZ1_precise_time_diff',
    'GZ2_precise_time_diff',
    'GZ3_precise_time_diff',
    'GZ4_precise_time_diff',
    'precise_time_diff',
)
_ALL_KEYS = _SENSOR_KEYS + _TIME_KEYS
_VALUES_PER_SENSOR = 64
# 24 sensor channels * 64 values each + 5 time scalars = 1541.
_EXPECTED_SAMPLE_LEN = len(_SENSOR_KEYS) * _VALUES_PER_SENSOR + len(_TIME_KEYS)


def _flatten_sample(sample, file):
    """Flatten one sample into a flat feature list; return None to skip it silently."""
    row = []
    for key in _ALL_KEYS:
        if key not in sample:
            # A missing key shortens the row, so the caller's length check
            # will normally reject the sample as well.
            warnings.warn(f"KeyError: {key} not found in JSON file: {file}")
            continue
        value = sample[key]
        if value is None:
            # Null fields (time scalars or whole sensor channels) skip the sample.
            return None
        if key in _TIME_KEYS:
            # NOTE(review): the factor 20 is unexplained here — presumably a
            # unit/scale conversion; confirm. round() is Python's half-to-even.
            row.append(round(value) * 20)
        else:
            row.extend(value)
    return row


def read_json_files(file):
    """Load the samples of one JSON file into feature and label tensors.

    For each sample (presumably a dict; verify against the data producer),
    the values of the 24 sensor channels AX1..GZ4 are concatenated. Then the
    five time-difference scalars are appended, each rounded and multiplied
    by 20.

    A sample is skipped when:
      - any of its fields is null (silently);
      - its flattened length differs from 24*64 + 5 (with a warning).
    A warning is also emitted for every missing key.

    Args:
        file (str or os.PathLike): Path of the JSON file to read.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: float32 features of shape
        ``(n_valid, 1541)`` and int64 labels of shape ``(n_valid,)``.
        Every label is 1.

    Raises:
        ValueError: If the file contains no valid samples.
    """
    rows = []
    for sample in load_json(file):
        row = _flatten_sample(sample, file)
        if row is None:
            continue
        if len(row) != _EXPECTED_SAMPLE_LEN:
            warnings.warn(f"Incomplete sample in JSON file: {file}")
            continue
        # Convert each row immediately: a float32 tensor is far more compact
        # than a list of Python numbers when the file is large.
        rows.append(torch.tensor(row, dtype=torch.float32))

    if not rows:
        raise ValueError(f"No valid samples found in JSON file: {file}")

    # NOTE(review): every sample gets label 1 — presumably this file holds a
    # single class; confirm with the caller.
    labels = torch.ones(len(rows), dtype=torch.long)
    return torch.stack(rows), labels