Datasets documentation

Use with PyTorch

You are viewing v2.6.1 version. A newer version v3.1.0 is available.

This document is a quick introduction to using datasets with PyTorch, with a particular focus on how to get torch.Tensor objects out of our datasets, and how to use a PyTorch DataLoader and a Hugging Face Dataset with the best performance.

Dataset format

By default, datasets return regular Python objects: integers, floats, strings, lists, etc.

To get PyTorch tensors instead, set the format of the dataset to torch using Dataset.with_format():

>>> from datasets import Dataset
>>> data = [[1, 2],[3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("torch")
>>> ds[0]
{'data': tensor([1, 2])}
>>> ds[:2]
{'data': tensor([[1, 2],
         [3, 4]])}

A Dataset object is a wrapper of an Arrow table, which allows fast zero-copy reads from arrays in the dataset to PyTorch tensors.

To load the data as tensors on a GPU, specify the device argument:

>>> import torch
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> ds = ds.with_format("torch", device=device)
>>> ds[0]
{'data': tensor([1, 2], device='cuda:0')}

N-dimensional arrays

If your dataset consists of N-dimensional arrays, they are treated as nested lists by default. In particular, a PyTorch formatted dataset outputs nested lists instead of a single tensor:

>>> from datasets import Dataset
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("torch")
>>> ds[0]
{'data': [tensor([1, 2]), tensor([3, 4])]}

To get a single tensor, you must explicitly use one of the Array feature types (for example Array2D) and specify the shape of your tensors:

>>> from datasets import Dataset, Features, Array2D
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> ds = ds.with_format("torch")
>>> ds[0]
{'data': tensor([[1, 2],
         [3, 4]])}
>>> ds[:2]
{'data': tensor([[[1, 2],
          [3, 4]],
 
         [[5, 6],
          [7, 8]]])}

Other feature types

ClassLabel data are properly converted to tensors:

>>> from datasets import Dataset, Features, ClassLabel
>>> labels = [0, 0, 1]
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"label": labels}, features=features) 
>>> ds = ds.with_format("torch")  
>>> ds[:3]
{'label': tensor([0, 0, 1])}

String and binary objects are unchanged, since PyTorch only supports numbers.

The Image and Audio feature types are also supported:

>>> from datasets import Dataset, Features, Image
>>> images = ["path/to/image.png"] * 10
>>> features = Features({"image": Image()})
>>> ds = Dataset.from_dict({"image": images}, features=features) 
>>> ds = ds.with_format("torch")
>>> ds[0]["image"].shape
torch.Size([512, 512, 4])
>>> ds[0]
{'image': tensor([[[255, 215, 106, 255],
         [255, 215, 106, 255],
         ...,
         [255, 255, 255, 255],
         [255, 255, 255, 255]]], dtype=torch.uint8)}
>>> ds[:2]["image"].shape
torch.Size([2, 512, 512, 4])
>>> ds[:2]
{'image': tensor([[[[255, 215, 106, 255],
          [255, 215, 106, 255],
          ...,
          [255, 255, 255, 255],
          [255, 255, 255, 255]]]], dtype=torch.uint8)}

>>> from datasets import Dataset, Features, Audio
>>> audio = ["path/to/audio.wav"] * 10
>>> features = Features({"audio": Audio()})
>>> ds = Dataset.from_dict({"audio": audio}, features=features) 
>>> ds = ds.with_format("torch")  
>>> ds[0]["audio"]["array"]
tensor([ 6.1035e-05,  1.5259e-05,  1.6785e-04,  ..., -1.5259e-05,
        -1.5259e-05,  1.5259e-05])
>>> ds[0]["audio"]["sampling_rate"]
tensor(44100)

Data loading

Like torch.utils.data.Dataset objects, a Dataset can be passed directly to a PyTorch DataLoader:

>>> import numpy as np
>>> from datasets import Dataset 
>>> from torch.utils.data import DataLoader
>>> data = np.random.rand(16)
>>> label = np.random.randint(0, 2, size=16)
>>> ds = Dataset.from_dict({"data": data, "label": label}).with_format("torch")
>>> dataloader = DataLoader(ds, batch_size=4)
>>> for batch in dataloader:
...     print(batch)
{'data': tensor([0.0047, 0.4979, 0.6726, 0.8105]), 'label': tensor([0, 1, 0, 1])}
{'data': tensor([0.4832, 0.2723, 0.4259, 0.2224]), 'label': tensor([0, 0, 0, 0])}
{'data': tensor([0.5837, 0.3444, 0.4658, 0.6417]), 'label': tensor([0, 1, 0, 0])}
{'data': tensor([0.7022, 0.1225, 0.7228, 0.8259]), 'label': tensor([1, 1, 1, 1])}

Optimize data loading

There are several ways to speed up data loading, which can save you time, especially when working with large datasets. You can parallelize data loading, retrieve batches of indices at once instead of one index at a time, and stream datasets to download them progressively.

Use multiple workers

You can parallelize data loading with the num_workers argument of a PyTorch DataLoader and get a higher throughput.

Under the hood, the DataLoader starts num_workers processes. Each process reloads the dataset passed to the DataLoader and is used to query examples. Reloading the dataset inside a worker doesn’t fill up your RAM, since it simply memory-maps the dataset again from your disk.

>>> import numpy as np
>>> from datasets import Dataset, load_from_disk
>>> from torch.utils.data import DataLoader
>>> data = np.random.rand(10_000)
>>> Dataset.from_dict({"data": data}).save_to_disk("my_dataset")
>>> ds = load_from_disk("my_dataset").with_format("torch")
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)

Use a BatchSampler

By default, the PyTorch DataLoader loads the examples of a batch from a dataset one by one, like this:

batch = [dataset[idx] for idx in range(start, end)]

Unfortunately, this performs one read operation per example. It is more efficient to query a whole batch of examples at once using a slice or a list of indices:

batch = dataset[start:end]
# or
batch = dataset[list_of_indices]
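To make the difference concrete, here is a small pure-Python sketch. CountingDataset is a hypothetical stand-in, not part of any library: it only counts how many times it is indexed, so the two access patterns above can be compared.

```python
class CountingDataset:
    """Hypothetical dataset that counts how many times it is indexed."""

    def __init__(self, values):
        self.values = list(values)
        self.reads = 0

    def __len__(self):
        return len(self.values)

    def __getitem__(self, key):
        self.reads += 1  # one read operation per __getitem__ call
        if isinstance(key, list):
            return [self.values[i] for i in key]
        return self.values[key]  # handles both ints and slices


dataset = CountingDataset(range(100))
start, end = 0, 32

# One read per example.
one_by_one = [dataset[idx] for idx in range(start, end)]
reads_one_by_one = dataset.reads

# A single read for the whole batch.
dataset.reads = 0
batched = dataset[start:end]
reads_batched = dataset.reads

print(reads_one_by_one, reads_batched)
```

Both patterns return the same examples; only the number of calls into the dataset differs.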

For the PyTorch DataLoader to query batches using a list, you can use a BatchSampler:

>>> from torch.utils.data.sampler import BatchSampler, RandomSampler
>>> sampler = BatchSampler(RandomSampler(ds), batch_size=32, drop_last=False)
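>>> # batch_size=None disables automatic batching: each list of indices is passed to the dataset as-is
>>> dataloader = DataLoader(ds, batch_size=None, sampler=sampler)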

This is particularly useful if you used set_transform to apply a transform on-the-fly when examples are accessed. You must use a BatchSampler if you want the transform to receive full batches instead of being called batch_size times with a single example each.

Stream data

Loading a dataset in streaming mode is useful to progressively download the data you need while iterating over the dataset. When you set the format of a streaming dataset to torch, it inherits from torch.utils.data.IterableDataset, so you can pass it to a DataLoader:

>>> import numpy as np
>>> from datasets import Dataset, load_dataset
>>> from torch.utils.data import DataLoader
>>> data = np.random.rand(10_000)
>>> Dataset.from_dict({"data": data}).push_to_hub("<username>/my_dataset")  # Upload to the Hugging Face Hub
>>> ds = load_dataset("<username>/my_dataset", streaming=True, split="train").with_format("torch")
>>> dataloader = DataLoader(ds, batch_size=32)

If the dataset is split into several shards (i.e. if the dataset consists of multiple data files), then you can stream in parallel using num_workers:

>>> ds = load_dataset("c4", "en", streaming=True, split="train").with_format("torch")
>>> ds.n_shards
1024
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)

In this case, each worker is given a subset of the list of shards to stream from.
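To give an intuition for how a list of shards can be divided, the sketch below splits shard indices across workers in a round-robin fashion. This is an illustrative assumption, not necessarily the exact scheme the library uses, and shards_for_worker is a hypothetical helper.

```python
def shards_for_worker(n_shards, worker_id, num_workers):
    """Hypothetical round-robin split: worker i takes shards i, i + num_workers, ..."""
    return list(range(worker_id, n_shards, num_workers))


n_shards, num_workers = 1024, 4
assignment = {
    worker_id: shards_for_worker(n_shards, worker_id, num_workers)
    for worker_id in range(num_workers)
}

for worker_id, shards in assignment.items():
    print(worker_id, len(shards), shards[:3])
```

Under this scheme every shard is read by exactly one worker, so no example is streamed twice. Workers beyond the number of shards would receive nothing to do.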