Stream

Dataset streaming lets you get started with a dataset without waiting for the entire dataset to download. The data is downloaded progressively as you iterate over the dataset. This is especially helpful when:

  • You don’t want to wait for an extremely large dataset to download.
  • The dataset size exceeds the amount of disk space on your computer.

For example, the English split of the OSCAR dataset is 1.2 terabytes, but you can use it instantly with streaming. Stream a dataset by setting streaming=True in datasets.load_dataset() as shown below:

>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> print(next(iter(dataset)))
{'text': 'Mtendere Village was inspired by the vision of Chief Napoleon Dzombe, which he shared with John Blanchard during his first visit to Malawi. Chief Napoleon conveyed the desperate need for a program to intervene and care for the orphans and vulnerable children (OVC) in Malawi, and John committed to help...

Loading a dataset in streaming mode creates a new dataset type instance, a datasets.IterableDataset, instead of the classic datasets.Dataset object. This special type of dataset has its own set of processing methods, shown below.

A datasets.IterableDataset is useful for iterative jobs like training a model. You shouldn’t use a datasets.IterableDataset for jobs that require random access to examples, because you have to iterate over it sequentially with a for loop. Getting the last example in an iterable dataset requires iterating over all the previous examples.
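For instance, there is no indexing: to reach the nth example you have to consume the n-1 examples before it. A minimal sketch of reaching the fifth example with itertools (the index is arbitrary):

>>> from itertools import islice
>>> # dataset[4] raises a TypeError because an IterableDataset is not subscriptable
>>> fifth_example = next(islice(iter(dataset), 4, 5))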


Shuffle

Like a regular datasets.Dataset object, you can also shuffle a datasets.IterableDataset with datasets.IterableDataset.shuffle().

The buffer_size argument controls the size of the buffer to randomly sample examples from. Let’s say your dataset has one million examples, and you set the buffer_size to ten thousand. datasets.IterableDataset.shuffle() will randomly select examples from the first ten thousand examples in the buffer. Selected examples in the buffer are replaced with new examples. By default, the buffer size is 1,000.

>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10_000)

datasets.IterableDataset.shuffle() will also shuffle the order of the shards if the dataset is sharded into multiple files.


Reshuffle

Sometimes you may want to reshuffle the dataset after each epoch, which requires you to set a different seed for each epoch. Use datasets.IterableDataset.set_epoch() in between epochs to tell the dataset what epoch you’re on.

Your seed effectively becomes: initial seed + current epoch.

>>> for epoch in range(epochs):
...     shuffled_dataset.set_epoch(epoch)
...     for example in shuffled_dataset:
...         ...

Split dataset

You can split your dataset one of two ways:

  • datasets.IterableDataset.take() returns the first n examples in a dataset:

>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> dataset_head = dataset.take(2)
>>> list(dataset_head)
[{'id': 0, 'text': 'Mtendere Village was...'}, {'id': 1, 'text': 'Lily James cannot fight the music...'}]

  • datasets.IterableDataset.skip() omits the first n examples in a dataset and returns the remaining examples:

>>> train_dataset = shuffled_dataset.skip(1000)

take and skip prevent future calls to shuffle because they lock in the order of the shards. You should shuffle your dataset before splitting it.
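For example, shuffle first and then carve off an evaluation split (a minimal sketch; the split size of 1,000 examples is arbitrary):

>>> # Shuffle before splitting so take/skip don't lock in the original shard order
>>> shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10_000)
>>> eval_dataset = shuffled_dataset.take(1000)
>>> train_dataset = shuffled_dataset.skip(1000)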


Interleave

datasets.interleave_datasets() can combine a datasets.IterableDataset with other datasets. The combined dataset returns alternating examples from each of the original datasets.

>>> from datasets import interleave_datasets
>>> en_dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> fr_dataset = load_dataset('oscar', "unshuffled_deduplicated_fr", split='train', streaming=True)

>>> multilingual_dataset = interleave_datasets([en_dataset, fr_dataset])
>>> list(multilingual_dataset.take(2))
[{'text': 'Mtendere Village was inspired by the vision...'}, {'text': "Média de débat d'idées, de culture et de littérature..."}]

Define sampling probabilities for each of the original datasets for more control over how each of them is sampled and combined. Set the probabilities argument with your desired sampling probabilities:

>>> multilingual_dataset_with_oversampling = interleave_datasets([en_dataset, fr_dataset], probabilities=[0.8, 0.2], seed=42)
>>> list(multilingual_dataset_with_oversampling.take(2))
[{'text': 'Mtendere Village was inspired by the vision...'}, {'text': 'Lily James cannot fight the music...'}]

Around 80% of the final dataset is made of the en_dataset, and 20% of the fr_dataset.

Rename, remove, and cast

The following methods allow you to modify the columns of a dataset. These methods are useful for renaming or removing columns and changing columns to a new set of features.


Rename

Use datasets.IterableDataset.rename_column() when you need to rename a column in your dataset. Features associated with the original column are actually moved under the new column name, instead of just replacing the original column in-place.

Provide datasets.IterableDataset.rename_column() with the name of the original column and the new column name:

>>> from datasets import load_dataset
>>> dataset = load_dataset('mc4', 'en', streaming=True, split='train')
>>> dataset = dataset.rename_column("text", "content")


Remove

When you need to remove one or more columns, give datasets.IterableDataset.remove_columns() the name of the column to remove. Remove more than one column by providing a list of column names:

>>> from datasets import load_dataset
>>> dataset = load_dataset('mc4', 'en', streaming=True, split='train')
>>> dataset = dataset.remove_columns('timestamp')
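For instance, to drop several columns at once, pass a list (a sketch on a freshly loaded mC4 stream, which has both a timestamp and a url column):

>>> dataset = load_dataset('mc4', 'en', streaming=True, split='train')
>>> dataset = dataset.remove_columns(['timestamp', 'url'])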


Cast

datasets.IterableDataset.cast() changes the feature type of one or more columns. This method takes your new datasets.Features as its argument. The following sample code shows how to change the feature types of datasets.ClassLabel and datasets.Value:

>>> from datasets import load_dataset
>>> dataset = load_dataset('glue', 'mrpc', split='train', streaming=True)
>>> dataset.features
{'sentence1': Value(dtype='string', id=None),
'sentence2': Value(dtype='string', id=None),
'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None),
'idx': Value(dtype='int32', id=None)}

>>> from datasets import ClassLabel, Value
>>> new_features = dataset.features.copy()
>>> new_features["label"] = ClassLabel(names=['negative', 'positive'])
>>> new_features["idx"] = Value('int64')
>>> dataset = dataset.cast(new_features)
>>> dataset.features
{'sentence1': Value(dtype='string', id=None),
'sentence2': Value(dtype='string', id=None),
'label': ClassLabel(num_classes=2, names=['negative', 'positive'], names_file=None, id=None),
'idx': Value(dtype='int64', id=None)}

Casting only works if the original feature type and new feature type are compatible. For example, you can cast a column with the feature type Value('int32') to Value('bool') if the original column only contains ones and zeros.
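As a sketch of that rule, assuming a hypothetical dataset whose flag column has type Value('int32') and contains only 0s and 1s:

>>> from datasets import Value
>>> new_features = dataset.features.copy()
>>> new_features['flag'] = Value('bool')  # valid only because every value is 0 or 1
>>> dataset = dataset.cast(new_features)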

Use datasets.IterableDataset.cast_column() to change the feature type of just one column. Pass the column name and its new feature type as arguments:

>>> dataset.features
{'audio': Audio(sampling_rate=44100, mono=True, id=None)}

>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
>>> dataset.features
{'audio': Audio(sampling_rate=16000, mono=True, id=None)}


Map

Similar to the function for a regular datasets.Dataset, 🤗 Datasets features for processing a datasets.IterableDataset. applies processing on-the-fly as examples are streamed.

It allows you to apply a processing function to each example in a dataset, independently or in batches. This function can even create new rows and columns.

The following example demonstrates how to add a prefix to each text in a datasets.IterableDataset. The function needs to accept and output a dict:

>>> def add_prefix(example):
...     example['text'] = 'My text: ' + example['text']
...     return example

Next, apply this function to the dataset with

>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', 'unshuffled_deduplicated_en', streaming=True, split='train')
>>> updated_dataset =
>>> list(updated_dataset.take(3))
[{'id': 0, 'text': 'My text: Mtendere Village was inspired by...'},
 {'id': 1, 'text': 'My text: Lily James cannot fight the music...'},
 {'id': 2, 'text': 'My text: "I\'d love to help kickstart...'}]

Let’s take a look at another example, except this time, you will remove a column with When you remove a column, it is only removed after the example has been provided to the mapped function. This allows the mapped function to use the content of the columns before they are removed.

Specify the column to remove with the remove_columns argument in

>>> updated_dataset =, remove_columns=["id"])
>>> list(updated_dataset.take(3))
[{'text': 'My text: Mtendere Village was inspired by...'},
 {'text': 'My text: Lily James cannot fight the music...'},
 {'text': 'My text: "I\'d love to help kickstart...'}]

 also supports working with batches of examples. Operate on batches by setting batched=True. The default batch size is 1000, but you can adjust it with the batch_size argument. This opens the door to many interesting applications such as tokenization, splitting long sentences into shorter chunks, and data augmentation. For example, you can tokenize a batch of streamed examples with a 🤗 Transformers tokenizer:

>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer
>>> dataset = load_dataset("mc4", "en", streaming=True, split="train")
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> def encode(examples):
...     return tokenizer(examples['text'], truncation=True, padding='max_length')
>>> dataset =, batched=True, remove_columns=["text", "timestamp", "url"])
>>> next(iter(dataset))
{'input_ids': [101, 8466, 1018, 1010, 4029, 2475, 2062, 18558, 3100, 2061, ..., 1106, 3739, 102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1]}

See other examples of batch processing in the batched map processing documentation. They work the same for iterable datasets.
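As a sketch of the “splitting long sentences into shorter chunks” use case, a batched function may return a different number of rows than it received, as long as the input columns are removed (the 50-character chunk size and the chunks column name are arbitrary):

>>> def chunk_examples(examples):
...     chunks = []
...     for text in examples['text']:
...         chunks += [text[i:i + 50] for i in range(0, len(text), 50)]
...     return {'chunks': chunks}
>>> dataset = load_dataset('oscar', 'unshuffled_deduplicated_en', streaming=True, split='train')
>>> chunked_dataset =, batched=True, remove_columns=['id', 'text'])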


Filter

You can filter rows in the dataset based on a predicate function using datasets.IterableDataset.filter(). It returns rows that match a specified condition:

>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', 'unshuffled_deduplicated_en', streaming=True, split='train')
>>> start_with_ar = dataset.filter(lambda example: example['text'].startswith('Ar'))
>>> next(iter(start_with_ar))
{'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)?...'}

datasets.IterableDataset.filter() can also filter by indices if you set with_indices=True:

>>> even_dataset = dataset.filter(lambda example, idx: idx % 2 == 0, with_indices=True)
>>> list(even_dataset.take(3))
[{'id': 0, 'text': 'Mtendere Village was inspired by the vision of Chief Napoleon Dzombe, ...'},
 {'id': 2, 'text': '"I\'d love to help kickstart continued development! And 0 EUR/month...'},
 {'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)? Normally, ...'}]

Stream in a training loop

datasets.IterableDataset can be integrated into a training loop. First, shuffle the dataset:

>>> seed, buffer_size = 42, 10_000
>>> dataset = dataset.shuffle(seed, buffer_size=buffer_size)

Lastly, create a simple training loop and start training:

>>> import torch
>>> from import DataLoader
>>> from transformers import AutoModelForMaskedLM, DataCollatorForLanguageModeling
>>> from tqdm import tqdm
>>> dataset = dataset.with_format("torch")
>>> dataloader = DataLoader(dataset, collate_fn=DataCollatorForLanguageModeling(tokenizer))
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu' 
>>> model = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")
>>> model.train().to(device)
>>> optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
>>> for epoch in range(3):
...     dataset.set_epoch(epoch)
...     for i, batch in enumerate(tqdm(dataloader, total=5)):
...         if i == 5:
...             break
...         batch = {k: for k, v in batch.items()}
...         outputs = model(**batch)
...         loss = outputs[0]
...         loss.backward()
...         optimizer.step()
...         optimizer.zero_grad()
...         if i % 10 == 0:
...             print(f"loss: {loss}")