Create a NIfTI dataset
This page shows how to create and share a dataset of medical images in NIfTI format (.nii / .nii.gz) using the datasets library.
You can share a dataset with your team or with anyone in the community by creating a dataset repository on the Hugging Face Hub:
from datasets import load_dataset
dataset = load_dataset("<username>/my_nifti_dataset")
There are two common ways to create a NIfTI dataset:
- Create a dataset from local NIfTI files in Python and upload it with Dataset.push_to_hub.
- Use a folder-based convention (one file per example) and a small helper to convert it into a Dataset.
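The "small helper" in the second option is not spelled out on this page. A minimal sketch, assuming a <root>/<split>/<file> layout, might collect the file paths per split as follows. The function name collect_nifti_files and the suffix list are hypothetical and not part of the datasets library.

```python
from pathlib import Path

# Assumption: files ending in these suffixes are treated as NIfTI files.
NIFTI_SUFFIXES = (".nii", ".nii.gz")


def collect_nifti_files(root):
    """Map each split folder under `root` to a sorted list of NIfTI file paths."""
    splits = {}
    for split_dir in sorted(Path(root).iterdir()):
        if not split_dir.is_dir():
            continue  # ignore loose files next to the split folders
        files = sorted(
            str(p)
            for p in split_dir.iterdir()
            if p.is_file() and p.name.endswith(NIFTI_SUFFIXES)
        )
        if files:
            splits[split_dir.name] = files
    return splits
```

Each returned list could then be turned into a Dataset with Dataset.from_dict and cast_column, as in the Local files section.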
You can control access to your dataset by requiring users to share their contact information first. Check out the Gated datasets guide for more information.
Local files
If you already have a list of file paths to NIfTI files, the easiest workflow is to create a Dataset from that list and cast the column to the Nifti feature.
from datasets import Dataset
from datasets import Nifti
# simple example: create a dataset from file paths
files = ["/path/to/scan_001.nii.gz", "/path/to/scan_002.nii.gz"]
ds = Dataset.from_dict({"nifti": files}).cast_column("nifti", Nifti())
# access a decoded nibabel image (if decode=True)
# ds[0]["nifti"] will be a nibabel.Nifti1Image object when decode=True
# or a dict {'bytes': None, 'path': '...'} when decode=False
The Nifti feature supports a decode parameter. When decode=True (the default), it loads the NIfTI file into a nibabel.nifti1.Nifti1Image object. You can access the image data as a NumPy array with img.get_fdata(). When decode=False, it returns a dict with the file path and bytes.
from datasets import Dataset, Nifti
ds = Dataset.from_dict({"nifti": ["/path/to/scan.nii.gz"]}).cast_column("nifti", Nifti(decode=True))
img = ds[0]["nifti"] # instance of: nibabel.nifti1.Nifti1Image
arr = img.get_fdata()
After preparing the dataset, you can push it to the Hub:
ds.push_to_hub("<username>/my_nifti_dataset")
This will create a dataset repository containing your NIfTI dataset with a data/ folder of Parquet shards.
Folder conventions and metadata
If you organize your dataset in folders you can create splits automatically (train/test/validation) by following a structure like:
dataset/train/scan_0001.nii
dataset/train/scan_0002.nii
dataset/validation/scan_1001.nii
dataset/test/scan_2001.nii
If you have labels or other metadata, provide a metadata.csv, metadata.jsonl, or metadata.parquet in the folder so files can be linked to metadata rows. The metadata must contain a file_name (or *_file_name) field holding the path of each NIfTI file relative to the metadata file.
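A metadata file of this kind can be written with the standard library. The sketch below assumes the metadata file sits in the same folder as the scans and that every row carries a file_name key. The helper name write_metadata_csv is hypothetical, and the extra columns are whatever labels you have.

```python
import csv
from pathlib import Path


def write_metadata_csv(folder, rows):
    """Write metadata.csv into `folder`, checking that each file_name exists there."""
    folder = Path(folder)
    if not rows:
        raise ValueError("rows must not be empty")
    fieldnames = list(rows[0].keys())
    if "file_name" not in fieldnames:
        raise ValueError("each row needs a 'file_name' field")
    # Fail early if a row points at a file that is not next to the metadata file.
    for row in rows:
        if not (folder / row["file_name"]).is_file():
            raise FileNotFoundError(f"{row['file_name']} not found in {folder}")
    out_path = folder / "metadata.csv"
    with open(out_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return out_path
```

Validating the paths before writing avoids publishing a metadata file whose rows cannot be linked to any scan.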
Example metadata.csv:
file_name,patient_id,age,diagnosis
scan_0001.nii.gz,P001,45,healthy
scan_0002.nii.gz,P002,59,disease_x
The Nifti feature works with zipped datasets too: each zip can contain NIfTI files and a metadata file. This is useful when uploading large datasets as archives.
Compressed (.nii.gz) and uncompressed (.nii) files can also be mixed in the same dataset, so the structure could look like this:
dataset/train/scan_0001.nii.gz
dataset/train/scan_0002.nii
dataset/validation/scan_1001.nii.gz
dataset/test/scan_2001.nii
Converting to PyTorch tensors
Use the set_transform() method to convert the decoded images to PyTorch tensors on the fly, as batches of the dataset are accessed:
import torch

def transform_to_pytorch(example):
    # `example` is a batch: example["nifti"] is a list of decoded nibabel images
    example["nifti_torch"] = [torch.tensor(ex.get_fdata()) for ex in example["nifti"]]
    return example

ds.set_transform(transform_to_pytorch)
Accessing elements now (e.g. ds[0]) will yield torch tensors in the "nifti_torch" key.
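Note that nibabel's get_fdata() returns float64 arrays by default, so the tensors produced this way are float64 as well. If a model expects float32 input, the volume can be converted before it is wrapped in a tensor. The sketch below is one possible preprocessing step. The min-max scaling and the name normalize_volume are illustrative assumptions, not something the datasets library prescribes.

```python
import numpy as np


def normalize_volume(volume):
    """Min-max scale a volume to [0, 1] and return it as float32."""
    volume = np.asarray(volume, dtype=np.float32)
    lo, hi = volume.min(), volume.max()
    if hi == lo:
        # Constant volume: avoid dividing by zero.
        return np.zeros_like(volume)
    return (volume - lo) / (hi - lo)
```

Inside the transform, torch.from_numpy(normalize_volume(ex.get_fdata())) would then give a float32 tensor.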
Usage of Nifti1Image
NIfTI is a format for storing three-dimensional (or four-dimensional) brain scans: three spatial dimensions (x, y, z) and, optionally, a time dimension (t). The voxel indices only describe positions within the image array; the affine matrix stored in the file header (available in nibabel as img.affine) maps them to real-world coordinates.
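In nibabel, the mapping from voxel indices to real-world coordinates is exposed as a 4x4 affine matrix on the image (img.affine). The sketch below applies such a matrix by hand with NumPy. The affine values are hypothetical (2 mm isotropic voxels with a shifted origin) and the helper name voxel_to_world is an assumption for illustration; nibabel also ships nibabel.affines.apply_affine for the same purpose.

```python
import numpy as np

# Hypothetical affine: 2 mm isotropic voxels, origin shifted by (-90, -126, -72) mm.
affine = np.array([
    [2.0, 0.0, 0.0, -90.0],
    [0.0, 2.0, 0.0, -126.0],
    [0.0, 0.0, 2.0, -72.0],
    [0.0, 0.0, 0.0, 1.0],
])


def voxel_to_world(affine, ijk):
    """Apply a 4x4 affine to a voxel index (i, j, k) and return (x, y, z)."""
    homogeneous = np.append(np.asarray(ijk, dtype=float), 1.0)
    return (affine @ homogeneous)[:3]


# For this affine, voxel (45, 63, 36) lands on the world-space origin.
world = voxel_to_world(affine, (45, 63, 36))
```

With a real image, the matrix would come from img.affine instead of being written out by hand.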
You can visualize NIfTI files with matplotlib, for instance, as follows:
import matplotlib.pyplot as plt
from datasets import load_dataset

def show_slices(slices):
    """Display a row of image slices."""
    fig, axes = plt.subplots(1, len(slices))
    for i, image_slice in enumerate(slices):
        axes[i].imshow(image_slice.T, cmap="gray", origin="lower")

# load a single split (assumed here to be "train") so that iterating yields examples, not split names
nifti_ds = load_dataset("<username>/my_nifti_dataset", split="train")
for epi_img in nifti_ds:
    nifti_img = epi_img["nifti"].get_fdata()
    # the fixed indices assume a 3D volume that is large enough along each axis
    show_slices([nifti_img[:, :, 16], nifti_img[26, :, :], nifti_img[:, 30, :]])
    plt.show()
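The fixed slice indices above only work for volumes that are large enough along each axis. As a hedged alternative, the central slice of each axis can be picked from the volume's shape. The helper name middle_slices is hypothetical, and taking the first time point of a 4D volume is an arbitrary choice made for this sketch.

```python
import numpy as np


def middle_slices(volume):
    """Return the central slice along each of the three spatial axes."""
    volume = np.asarray(volume)
    if volume.ndim == 4:
        volume = volume[..., 0]  # keep only the first time point
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D or 4D volume, got {volume.ndim}D")
    i, j, k = (n // 2 for n in volume.shape)
    return [volume[i, :, :], volume[:, j, :], volume[:, :, k]]
```

The returned list can be passed directly to show_slices.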