import os
import glob
from typing import List

import numpy as np
import torch
from torch.utils.data import Dataset

class fMRIDataset(Dataset):
    """Loads preprocessed fMRI volumes stored as .npz files and returns
    fixed-length temporal crops as (T, X, Y, Z) tensors."""

    def __init__(self, data_root: str, datasets: List[str], split_suffixes: List[str],
                 crop_length: int = 40, downstream: bool = False):
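        # Assumed on-disk layout (inferred from the traversal below; the
        # names here are illustrative, not taken from the original source):
        #   <data_root>/<dataset_name>_<suffix>/.../<scan>.npz
        # with each .npz storing one 4D array of shape (X, Y, Z, T).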
        self.file_paths = []
        self.crop_length = crop_length
        self.downstream = downstream  # Reserved for downstream-task use; unused here.

        # Gather .npz files for every requested dataset/split combination.
        for dataset_name in datasets:
            for suffix in split_suffixes:
                folder_name = f"{dataset_name}_{suffix}"
                folder_path = os.path.join(data_root, folder_name)
                if not os.path.exists(folder_path):
                    print(f"Warning: Folder not found: {folder_path}")
                    continue

                for root, _dirs, _files in os.walk(folder_path):
                    npz_files = glob.glob(os.path.join(root, "*.npz"))
                    if len(npz_files) > 1:
                        # When a directory holds several runs, keep only the
                        # first file in sorted order so it contributes once.
                        npz_files = sorted(npz_files)[:1]
                    self.file_paths.extend(npz_files)

        print(f"Dataset loaded. Total files found: {len(self.file_paths)}")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        try:
            with np.load(file_path) as data_file:
                # Each archive is assumed to hold a single array; take the first key.
                key = list(data_file.keys())[0]
                fmri_data = data_file[key].astype(np.float32)
        except Exception as e:
            print(f"Error loading file {file_path}: {e}")
            # Callers must filter out None samples (e.g. via a custom collate_fn).
            return None

        # Randomly crop a crop_length-frame window along the time (last) axis.
        total_time_frames = fmri_data.shape[-1]
        if total_time_frames > self.crop_length:
            start_idx = np.random.randint(0, total_time_frames - self.crop_length + 1)
            end_idx = start_idx + self.crop_length
            cropped_data = fmri_data[..., start_idx:end_idx]
        else:
            # Shorter recordings pass through unpadded.
            cropped_data = fmri_data[..., :self.crop_length]

        data_tensor = torch.from_numpy(cropped_data)
        # Reorder (X, Y, Z, T) -> (T, X, Y, Z) so time comes first.
        data_tensor = data_tensor.permute(3, 0, 1, 2)

        return data_tensor
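
# --- Usage sketch (not part of the original module) -------------------------
# A minimal example, assuming hypothetical paths and dataset names, showing
# how to pair fMRIDataset with a DataLoader. Because __getitem__ returns None
# for unreadable files, the collate_fn below filters those out before
# stacking. It also assumes every recording has at least crop_length frames,
# so all crops share the same shape and torch.stack succeeds.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    def skip_broken_collate(batch):
        # Drop failed loads; an all-broken batch collapses to None.
        batch = [sample for sample in batch if sample is not None]
        return torch.stack(batch) if batch else None

    demo_dataset = fMRIDataset(
        data_root="/path/to/fmri_npz",   # hypothetical location
        datasets=["HCP"],                # hypothetical dataset name
        split_suffixes=["train"],
        crop_length=40,
    )
    demo_loader = DataLoader(demo_dataset, batch_size=2, shuffle=True,
                             collate_fn=skip_broken_collate)
    for batch in demo_loader:
        if batch is None:
            continue  # every sample in this batch failed to load
        print(batch.shape)  # -> (batch, T, X, Y, Z)
        break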