Datasets documentation

Main classes



class datasets.DatasetInfo

( description: str = <factory> citation: str = <factory> homepage: str = <factory> license: str = <factory> features: typing.Optional[datasets.features.features.Features] = None post_processed: typing.Optional[datasets.info.PostProcessedInfo] = None supervised_keys: typing.Optional[datasets.info.SupervisedKeysData] = None builder_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None config_name: typing.Optional[str] = None version: typing.Union[str, datasets.utils.version.Version, NoneType] = None splits: typing.Optional[dict] = None download_checksums: typing.Optional[dict] = None download_size: typing.Optional[int] = None post_processing_size: typing.Optional[int] = None dataset_size: typing.Optional[int] = None size_in_bytes: typing.Optional[int] = None )


  • description (str) — A description of the dataset.
  • citation (str) — A BibTeX citation of the dataset.
  • homepage (str) — A URL to the official homepage for the dataset.
  • license (str) — The dataset’s license. It can be the name of the license or a paragraph containing the terms of the license.
  • features (Features, optional) — The features used to specify the dataset’s column types.
  • post_processed (PostProcessedInfo, optional) — Information regarding the resources of a possible post-processing of a dataset. For example, it can contain the information of an index.
  • supervised_keys (SupervisedKeysData, optional) — Specifies the input feature and the label for supervised learning if applicable for the dataset (legacy from TFDS).
  • builder_name (str, optional) — The name of the GeneratorBasedBuilder subclass used to create the dataset. Usually matched to the corresponding script name. It is also the snake_case version of the dataset builder class name.
  • config_name (str, optional) — The name of the configuration derived from BuilderConfig.
  • version (str or Version, optional) — The version of the dataset.
  • splits (dict, optional) — The mapping between split name and metadata.
  • download_checksums (dict, optional) — The mapping between the URL to download the dataset’s checksums and corresponding metadata.
  • download_size (int, optional) — The size of the files to download to generate the dataset, in bytes.
  • post_processing_size (int, optional) — Size of the dataset in bytes after post-processing, if any.
  • dataset_size (int, optional) — The combined size in bytes of the Arrow tables for all splits.
  • size_in_bytes (int, optional) — The combined size in bytes of all files associated with the dataset (downloaded files + Arrow files).
  • **config_kwargs (additional keyword arguments) — Keyword arguments to be passed to the BuilderConfig and used in the DatasetBuilder.

Information about a dataset.

DatasetInfo documents a dataset, including its name, version, and features. See the constructor arguments and properties for a full list.

Not all fields are known on construction and may be updated later.


from_directory

( dataset_info_dir: str storage_options: typing.Optional[dict] = None )


  • dataset_info_dir (str) — The directory containing the metadata file. This should be the root directory of a specific dataset version.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.9.0

Create DatasetInfo from the JSON file in dataset_info_dir.

This function updates all the dynamically generated fields (num_examples, hash, time of creation,…) of the DatasetInfo.

This will overwrite all previous metadata.


>>> from datasets import DatasetInfo
>>> ds_info = DatasetInfo.from_directory("/path/to/directory/")


write_to_directory

( dataset_info_dir pretty_print = False storage_options: typing.Optional[dict] = None )


  • dataset_info_dir (str) — Destination directory.
  • pretty_print (bool, defaults to False) — If True, the JSON will be pretty-printed with the indent level of 4.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.9.0

Write DatasetInfo and license (if present) as JSON files to dataset_info_dir.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.info.write_to_directory("/path/to/directory/")


The base class Dataset implements a Dataset backed by an Apache Arrow table.

class datasets.Dataset

( arrow_table: Table info: typing.Optional[datasets.info.DatasetInfo] = None split: typing.Optional[datasets.splits.NamedSplit] = None indices_table: typing.Optional[datasets.table.Table] = None fingerprint: typing.Optional[str] = None )

A Dataset backed by an Arrow table.


add_column

( name: str column: typing.Union[list, <built-in function array>] new_fingerprint: str feature: typing.Union[dict, list, tuple, datasets.features.features.Value, datasets.features.features.ClassLabel, datasets.features.translation.Translation, datasets.features.translation.TranslationVariableLanguages, datasets.features.features.LargeList, datasets.features.features.Sequence, datasets.features.features.Array2D, datasets.features.features.Array3D, datasets.features.features.Array4D, datasets.features.features.Array5D,, datasets.features.image.Image,, NoneType] = None )


  • name (str) — Column name.
  • column (list or np.array) — Column data to be added.
  • feature (FeatureType or None, defaults to None) — Column datatype.

Add column to Dataset.

Added in 1.7


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> more_text = ds["text"]
>>> ds.add_column(name="text_2", column=more_text)
Dataset({
    features: ['text', 'label', 'text_2'],
    num_rows: 1066
})


add_item

( item: dict new_fingerprint: str )


  • item (dict) — Item data to be added.

Add item to Dataset.

Added in 1.7


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> new_review = {'label': 0, 'text': 'this movie is the absolute worst thing I have ever seen'}
>>> ds = ds.add_item(new_review)
>>> ds[-1]
{'label': 0, 'text': 'this movie is the absolute worst thing I have ever seen'}


from_file

( filename: str info: typing.Optional[datasets.info.DatasetInfo] = None split: typing.Optional[datasets.splits.NamedSplit] = None indices_filename: typing.Optional[str] = None in_memory: bool = False )


  • filename (str) — File name of the dataset.
  • info (DatasetInfo, optional) — Dataset information, like description, citation, etc.
  • split (NamedSplit, optional) — Name of the dataset split.
  • indices_filename (str, optional) — File names of the indices.
  • in_memory (bool, defaults to False) — Whether to copy the data in-memory.

Instantiate a Dataset backed by an Arrow table at filename.


from_buffer

( buffer: Buffer info: typing.Optional[datasets.info.DatasetInfo] = None split: typing.Optional[datasets.splits.NamedSplit] = None indices_buffer: typing.Optional[pyarrow.lib.Buffer] = None )


  • buffer (pyarrow.Buffer) — Arrow buffer.
  • info (DatasetInfo, optional) — Dataset information, like description, citation, etc.
  • split (NamedSplit, optional) — Name of the dataset split.
  • indices_buffer (pyarrow.Buffer, optional) — Indices Arrow buffer.

Instantiate a Dataset backed by an Arrow buffer.


from_pandas

( df: DataFrame features: typing.Optional[datasets.features.features.Features] = None info: typing.Optional[datasets.info.DatasetInfo] = None split: typing.Optional[datasets.splits.NamedSplit] = None preserve_index: typing.Optional[bool] = None )


  • df (pandas.DataFrame) — Dataframe that contains the dataset.
  • features (Features, optional) — Dataset features.
  • info (DatasetInfo, optional) — Dataset information, like description, citation, etc.
  • split (NamedSplit, optional) — Name of the dataset split.
  • preserve_index (bool, optional) — Whether to store the index as an additional column in the resulting Dataset. The default of None will store the index as a column, except for RangeIndex which is stored as metadata only. Use preserve_index=True to force it to be stored as a column.

Convert pandas.DataFrame to a pyarrow.Table to create a Dataset.

The column types in the resulting Arrow Table are inferred from the dtypes of the pandas.Series in the DataFrame. In the case of non-object Series, the NumPy dtype is translated to its Arrow equivalent. In the case of object, we need to guess the datatype by looking at the Python objects in this Series.

Be aware that Series of the object dtype don’t carry enough information to always lead to a meaningful Arrow type. In the case that we cannot infer a type, e.g. because the DataFrame is of length 0 or the Series only contains None/nan objects, the type is set to null. This behavior can be avoided by constructing explicit features and passing them to this function.

Important: a dataset created with from_pandas() lives in memory and therefore doesn’t have an associated cache directory. This may change in the future, but in the meantime, if you want to reduce memory usage, you should write it back to disk and reload it using e.g. save_to_disk / load_from_disk.


>>> ds = Dataset.from_pandas(df)


from_dict

( mapping: dict features: typing.Optional[datasets.features.features.Features] = None info: typing.Optional[datasets.info.DatasetInfo] = None split: typing.Optional[datasets.splits.NamedSplit] = None )


  • mapping (Mapping) — Mapping of strings to Arrays or Python lists.
  • features (Features, optional) — Dataset features.
  • info (DatasetInfo, optional) — Dataset information, like description, citation, etc.
  • split (NamedSplit, optional) — Name of the dataset split.

Convert dict to a pyarrow.Table to create a Dataset.

Important: a dataset created with from_dict() lives in memory and therefore doesn’t have an associated cache directory. This may change in the future, but in the meantime, if you want to reduce memory usage, you should write it back to disk and reload it using e.g. save_to_disk / load_from_disk.


from_generator

( generator: typing.Callable features: typing.Optional[datasets.features.features.Features] = None cache_dir: str = None keep_in_memory: bool = False gen_kwargs: typing.Optional[dict] = None num_proc: typing.Optional[int] = None split: NamedSplit = NamedSplit('train') **kwargs )


  • generator (Callable) — A generator function that yields examples.
  • features (Features, optional) — Dataset features.
  • cache_dir (str, optional, defaults to "~/.cache/huggingface/datasets") — Directory to cache data.
  • keep_in_memory (bool, defaults to False) — Whether to copy the data in-memory.
  • gen_kwargs (dict, optional) — Keyword arguments to be passed to the generator callable. You can define a sharded dataset by passing the list of shards in gen_kwargs and setting num_proc greater than 1.
  • num_proc (int, optional, defaults to None) — Number of processes when downloading and generating the dataset locally. This is helpful if the dataset is made of multiple files. Multiprocessing is disabled by default. If num_proc is greater than one, then all list values in gen_kwargs must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in gen_kwargs and num_proc.

    Added in 2.7.0

  • split (NamedSplit, defaults to Split.TRAIN) — Split name to be assigned to the dataset.

    Added in 2.21.0

  • **kwargs (additional keyword arguments) — Keyword arguments to be passed to GeneratorConfig.

Create a Dataset from a generator.


>>> from datasets import Dataset
>>> def gen():
...     yield {"text": "Good", "label": 0}
...     yield {"text": "Bad", "label": 1}
>>> ds = Dataset.from_generator(gen)
>>> def gen(shards):
...     for shard in shards:
...         with open(shard) as f:
...             for line in f:
...                 yield {"line": line}
>>> shards = [f"data{i}.txt" for i in range(32)]
>>> ds = Dataset.from_generator(gen, gen_kwargs={"shards": shards})


data

( )

The Apache Arrow table backing the dataset.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.data
text: string
label: int64
text: [["compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .","the soundtrack alone is worth the price of admission .","rodriguez does a splendid job of racial profiling hollywood style--casting excellent latin actors of all ages--a trend long overdue .","beneath the film's obvious determination to shock at any cost lies considerable skill and determination , backed by sheer nerve .","bielinsky is a filmmaker of impressive talent .","so beautifully acted and directed , it's clear that washington most certainly has a new career ahead of him if he so chooses .","a visual spectacle full of stunning images and effects .","a gentle and engrossing character study .","it's enough to watch huppert scheming , with her small , intelligent eyes as steady as any noir villain , and to enjoy the perfectly pitched web of tension that chabrol spins .","an engrossing portrait of uncompromising artists trying to create something original against the backdrop of a corporate music industry that only seems to care about the bottom line .",...,"ultimately , jane learns her place as a girl , softens up and loses some of the intensity that made her an interesting character to begin with .","ah-nuld's action hero days might be over .","it's clear why deuces wild , which was shot two years ago , has been gathering dust on mgm's shelf .","feels like nothing quite so much as a middle-aged moviemaker's attempt to surround himself with beautiful , half-naked women .","when the precise nature of matthew's predicament finally comes into sharp focus , the revelation fails to justify the build-up .","this picture is murder by numbers , and as easy to be bored by as your abc's , despite a few whopping shootouts .","hilarious musical comedy though stymied by accents thick as mud .","if you are into splatter movies , then you will probably have a reasonably good time with the salton sea .","a dull , 
simple-minded and stereotypical tale of drugs , death and mind-numbing indifference on the inner-city streets .","the feature-length stretch . . . strains the show's concept ."]]
label: [[1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0]]


cache_files

( )

The cache files containing the Apache Arrow table backing the dataset.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.cache_files
[{'filename': '/root/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46/rotten_tomatoes_movie_review-validation.arrow'}]


num_columns

( )

Number of columns in the dataset.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.num_columns
2


num_rows

( )

Number of rows in the dataset (same as Dataset.__len__()).


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.num_rows
1066


column_names

( )

Names of the columns in the dataset.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.column_names
['text', 'label']


shape

( )

Shape of the dataset (number of rows, number of columns).


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.shape
(1066, 2)


unique

( column: str ) list


  • column (str) — Column name (list all the column names with column_names).



List of unique elements in the given column.

Return a list of the unique elements in a column.

This is implemented in the low-level backend and as such, very fast.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.unique('label')
[1, 0]


flatten

( new_fingerprint: typing.Optional[str] = None max_depth = 16 ) Dataset


  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.



A copy of the dataset with flattened columns.

Flatten the table. Each column with a struct type is flattened into one column per struct field. Other columns are left unchanged.


>>> from datasets import load_dataset
>>> ds = load_dataset("squad", split="train")
>>> ds.features
{'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None),
 'context': Value(dtype='string', id=None),
 'id': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None),
 'title': Value(dtype='string', id=None)}
>>> ds.flatten()
Dataset({
    features: ['id', 'title', 'context', 'question', 'answers.text', 'answers.answer_start'],
    num_rows: 87599
})


cast

( features: Features batch_size: typing.Optional[int] = 1000 keep_in_memory: bool = False load_from_cache_file: typing.Optional[bool] = None cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 num_proc: typing.Optional[int] = None ) Dataset


  • features (Features) — New features to cast the dataset to. The name of the fields in the features must match the current column names. The type of the data must also be convertible from one type to the other. For non-trivial conversion, e.g. str <-> ClassLabel you should use map() to update the Dataset.
  • batch_size (int, defaults to 1000) — Number of examples per batch provided to cast. If batch_size <= 0 or batch_size == None then provide the full dataset as a single batch to cast.
  • keep_in_memory (bool, defaults to False) — Whether to copy the data in-memory.
  • load_from_cache_file (bool, defaults to True if caching is enabled) — If a cache file storing the current computation from function can be identified, use it instead of recomputing.
  • cache_file_name (str, optional, defaults to None) — Provide the name of a path for the cache file. It is used to store the results of the computation instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups, while a lower value consumes less temporary memory while running map().
  • num_proc (int, optional, defaults to None) — Number of processes for multiprocessing. By default it doesn’t use multiprocessing.



A copy of the dataset with casted features.

Cast the dataset to a new set of features.


>>> from datasets import load_dataset, ClassLabel, Value
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.features
{'label': ClassLabel(names=['neg', 'pos'], id=None),
 'text': Value(dtype='string', id=None)}
>>> new_features = ds.features.copy()
>>> new_features['label'] = ClassLabel(names=['bad', 'good'])
>>> new_features['text'] = Value('large_string')
>>> ds = ds.cast(new_features)
>>> ds.features
{'label': ClassLabel(names=['bad', 'good'], id=None),
 'text': Value(dtype='large_string', id=None)}


cast_column

( column: str feature: typing.Union[dict, list, tuple, datasets.features.features.Value, datasets.features.features.ClassLabel, datasets.features.translation.Translation, datasets.features.translation.TranslationVariableLanguages, datasets.features.features.LargeList, datasets.features.features.Sequence, datasets.features.features.Array2D, datasets.features.features.Array3D, datasets.features.features.Array4D, datasets.features.features.Array5D,, datasets.features.image.Image,] new_fingerprint: typing.Optional[str] = None )


  • column (str) — Column name.
  • feature (FeatureType) — Target feature.
  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

Cast column to feature for decoding.


>>> from datasets import load_dataset, ClassLabel
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.features
{'label': ClassLabel(names=['neg', 'pos'], id=None),
 'text': Value(dtype='string', id=None)}
>>> ds = ds.cast_column('label', ClassLabel(names=['bad', 'good']))
>>> ds.features
{'label': ClassLabel(names=['bad', 'good'], id=None),
 'text': Value(dtype='string', id=None)}


remove_columns

( column_names: typing.Union[str, typing.List[str]] new_fingerprint: typing.Optional[str] = None ) Dataset


  • column_names (Union[str, List[str]]) — Name of the column(s) to remove.
  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.



A copy of the dataset object without the columns to remove.

Remove one or several column(s) in the dataset and the features associated to them.

You can also remove a column using map() with remove_columns but the present method doesn’t copy the data of the remaining columns and is thus faster.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds = ds.remove_columns('label')
>>> ds
Dataset({
    features: ['text'],
    num_rows: 1066
})
>>> # Removing all the columns returns an empty dataset with the `num_rows` property set to 0
>>> ds = ds.remove_columns(column_names=ds.column_names)
>>> ds
Dataset({
    features: [],
    num_rows: 0
})


rename_column

( original_column_name: str new_column_name: str new_fingerprint: typing.Optional[str] = None ) Dataset


  • original_column_name (str) — Name of the column to rename.
  • new_column_name (str) — New name for the column.
  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.



A copy of the dataset with a renamed column.

Rename a column in the dataset, and move the features associated to the original column under the new column name.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds = ds.rename_column('label', 'label_new')
>>> ds
Dataset({
    features: ['text', 'label_new'],
    num_rows: 1066
})


rename_columns

( column_mapping: typing.Dict[str, str] new_fingerprint: typing.Optional[str] = None ) Dataset


  • column_mapping (Dict[str, str]) — A mapping of columns to rename to their new names
  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.



A copy of the dataset with renamed columns

Rename several columns in the dataset, and move the features associated to the original columns under the new column names.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds = ds.rename_columns({'text': 'text_new', 'label': 'label_new'})
>>> ds
Dataset({
    features: ['text_new', 'label_new'],
    num_rows: 1066
})


select_columns

( column_names: typing.Union[str, typing.List[str]] new_fingerprint: typing.Optional[str] = None ) Dataset


  • column_names (Union[str, List[str]]) — Name of the column(s) to keep.
  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.



A copy of the dataset object which only consists of selected columns.

Select one or several column(s) in the dataset and the features associated to them.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.select_columns(['text'])
Dataset({
    features: ['text'],
    num_rows: 1066
})


class_encode_column

( column: str include_nulls: bool = False )


  • column (str) — The name of the column to cast (list all the column names with column_names)
  • include_nulls (bool, defaults to False) — Whether to include null values in the class labels. If True, the null values will be encoded as the "None" class label.

    Added in 1.14.2

Casts the given column as ClassLabel and updates the table.


>>> from datasets import load_dataset
>>> ds = load_dataset("boolq", split="validation")
>>> ds.features
{'answer': Value(dtype='bool', id=None),
 'passage': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None)}
>>> ds = ds.class_encode_column('answer')
>>> ds.features
{'answer': ClassLabel(num_classes=2, names=['False', 'True'], id=None),
 'passage': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None)}


__len__

( )

Number of rows in the dataset.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> len(ds)
1066


__iter__

( )

Iterate through the examples.

If a formatting is set with Dataset.set_format() rows will be returned with the selected format.


iter

( batch_size: int drop_last_batch: bool = False )


  • batch_size (int) — Size of each batch to yield.
  • drop_last_batch (bool, defaults to False) — Whether a last batch smaller than batch_size should be dropped.

Iterate through the batches of size batch_size.

If a formatting is set with Dataset.set_format(), rows will be returned with the selected format.


formatted_as

( type: typing.Optional[str] = None columns: typing.Optional[typing.List] = None output_all_columns: bool = False **format_kwargs )


  • type (str, optional) — Output type selected in [None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']. None means __getitem__ returns python objects (default).
  • columns (List[str], optional) — Columns to format in the output. None means __getitem__ returns all columns (default).
  • output_all_columns (bool, defaults to False) — Keep un-formatted columns as well in the output (as python objects).
  • **format_kwargs (additional keyword arguments) — Keyword arguments passed to the convert function like np.array, torch.tensor or tensorflow.ragged.constant.

To be used in a with statement. Set __getitem__ return format (type and columns).


set_format

( type: typing.Optional[str] = None columns: typing.Optional[typing.List] = None output_all_columns: bool = False **format_kwargs )


  • type (str, optional) — Either output type selected in [None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']. None means __getitem__ returns python objects (default).
  • columns (List[str], optional) — Columns to format in the output. None means __getitem__ returns all columns (default).
  • output_all_columns (bool, defaults to False) — Keep un-formatted columns as well in the output (as python objects).
  • **format_kwargs (additional keyword arguments) — Keyword arguments passed to the convert function like np.array, torch.tensor or tensorflow.ragged.constant.

Set __getitem__ return format (type and columns). The data formatting is applied on-the-fly. The format type (for example “numpy”) is used to format batches when using __getitem__. It’s also possible to use custom transforms for formatting using set_transform().

It is possible to call map() after calling set_format. Since map may add new columns, the list of formatted columns gets updated. In this case, if you apply map on a dataset to add a new column, then this column will be formatted as:

new formatted columns = (all columns - previously unformatted columns)


>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)
>>> ds.set_format(type='numpy', columns=['text', 'label'])
>>> ds.format
{'type': 'numpy',
'format_kwargs': {},
'columns': ['text', 'label'],
'output_all_columns': False}


set_transform

( transform: typing.Optional[typing.Callable] columns: typing.Optional[typing.List] = None output_all_columns: bool = False )


  • transform (Callable, optional) — User-defined formatting transform, replaces the format defined by set_format(). A formatting function is a callable that takes a batch (as a dict) as input and returns a batch. This function is applied right before returning the objects in __getitem__.
  • columns (List[str], optional) — Columns to format in the output. If specified, then the input batch of the transform only contains those columns.
  • output_all_columns (bool, defaults to False) — Keep un-formatted columns as well in the output (as python objects). If set to True, then the other un-formatted columns are kept with the output of the transform.

Set __getitem__ return format using this transform. The transform is applied on-the-fly on batches when __getitem__ is called. As with set_format(), this can be reset using reset_format().


>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
>>> def encode(batch):
...     return tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt')
>>> ds.set_transform(encode)
>>> ds[0]
{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
 1, 1]),
 'input_ids': tensor([  101, 29353,  2135, 15102,  1996,  9428, 20868,  2890,  8663,  6895,
         20470,  2571,  3663,  2090,  4603,  3017,  3008,  1998,  2037, 24211,
         5637,  1998, 11690,  2336,  1012,   102]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0])}


reset_format

( )

Reset __getitem__ return format to python objects and all columns.

Same as self.set_format()


>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)
>>> ds.set_format(type='numpy', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
>>> ds.format
{'columns': ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
 'format_kwargs': {},
 'output_all_columns': False,
 'type': 'numpy'}
>>> ds.reset_format()
>>> ds.format
{'columns': ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
 'format_kwargs': {},
 'output_all_columns': False,
 'type': None}


with_format

( type: typing.Optional[str] = None columns: typing.Optional[typing.List] = None output_all_columns: bool = False **format_kwargs )


  • type (str, optional) — Either output type selected in [None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']. None means __getitem__ returns python objects (default).
  • columns (List[str], optional) — Columns to format in the output. None means __getitem__ returns all columns (default).
  • output_all_columns (bool, defaults to False) — Keep un-formatted columns as well in the output (as python objects).
  • **format_kwargs (additional keyword arguments) — Keyword arguments passed to the convert function like np.array, torch.tensor or tensorflow.ragged.constant.

Set __getitem__ return format (type and columns). The data formatting is applied on-the-fly. The format type (for example “numpy”) is used to format batches when using __getitem__.

It’s also possible to use custom transforms for formatting using with_transform().

Contrary to set_format(), with_format returns a new Dataset object.


>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)
>>> ds.format
{'columns': ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
 'format_kwargs': {},
 'output_all_columns': False,
 'type': None}
>>> ds = ds.with_format("torch")
>>> ds.format
{'columns': ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
 'format_kwargs': {},
 'output_all_columns': False,
 'type': 'torch'}
>>> ds[0]
{'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
 'label': tensor(1),
 'input_ids': tensor([  101, 18027, 16310, 16001,  1103,  9321,   178, 11604,  7235,  6617,
        1742,  2165,  2820,  1206,  6588, 22572, 12937,  1811,  2153,  1105,
        1147, 12890, 19587,  6463,  1105, 15026,  1482,   119,   102,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}



( transform: typing.Optional[typing.Callable] columns: typing.Optional[typing.List] = None output_all_columns: bool = False )


  • transform (Callable, optional) — User-defined formatting transform, replaces the format defined by set_format(). A formatting function is a callable that takes a batch (as a dict) as input and returns a batch. This function is applied right before returning the objects in __getitem__.
  • columns (List[str], optional) — Columns to format in the output. If specified, then the input batch of the transform only contains those columns.
  • output_all_columns (bool, defaults to False) — Keep un-formatted columns as well in the output (as python objects). If set to True, then the other un-formatted columns are kept with the output of the transform.

Set __getitem__ return format using this transform. The transform is applied on-the-fly on batches when __getitem__ is called.

As with set_format(), this can be reset using reset_format().

Contrary to set_transform(), with_transform returns a new Dataset object.
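The on-the-fly behavior can be pictured with a small pure-Python stand-in. The `LazyTransformView` class and the `upper_case` function below are hypothetical and are not part of datasets; they only illustrate, under that assumption, that the stored rows stay untouched and that the transform runs each time an item is read.

```python
from typing import Callable, Dict, List


class LazyTransformView:
    """Hypothetical stand-in: applies `transform` to a one-row batch on every read."""

    def __init__(self, columns: Dict[str, List], transform: Callable[[Dict[str, List]], Dict[str, List]]):
        self._columns = columns      # stored data, never modified
        self._transform = transform  # called lazily in __getitem__
        self.calls = 0               # counts how often the transform ran

    def __getitem__(self, index: int) -> Dict:
        # Build a batch of size 1 (a dict of lists), since the transform works on batches.
        batch = {name: [values[index]] for name, values in self._columns.items()}
        self.calls += 1
        out = self._transform(batch)
        # Unwrap the single row from the transformed batch.
        return {name: values[0] for name, values in out.items()}


def upper_case(batch: Dict[str, List]) -> Dict[str, List]:
    # Example transform: returns a new batch holding only an upper-cased "text" column.
    return {"text": [t.upper() for t in batch["text"]]}


view = LazyTransformView({"text": ["good movie", "bad movie"], "label": [1, 0]}, upper_case)
first = view[0]
print(first)       # only the keys returned by the transform are present
print(view.calls)  # the transform ran once, at read time
```

As in the tokenizer example that follows, the returned row contains only what the transform produced.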


>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> def encode(example):
...     return tokenizer(example["text"], padding=True, truncation=True, return_tensors='pt')
>>> ds = ds.with_transform(encode)
>>> ds[0]
{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
 1, 1, 1, 1, 1]),
 'input_ids': tensor([  101, 18027, 16310, 16001,  1103,  9321,   178, 11604,  7235,  6617,
         1742,  2165,  2820,  1206,  6588, 22572, 12937,  1811,  2153,  1105,
         1147, 12890, 19587,  6463,  1105, 15026,  1482,   119,   102]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0])}



( key )

Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).



( ) → int

Returns: int, the number of removed files.

Clean up all cache files in the dataset cache directory, except the currently used cache file if there is one.

Be careful when running this command that no other process is currently using other cache files.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.cleanup_cache_files()



( function: typing.Optional[typing.Callable] = None with_indices: bool = False with_rank: bool = False input_columns: typing.Union[str, typing.List[str], NoneType] = None batched: bool = False batch_size: typing.Optional[int] = 1000 drop_last_batch: bool = False remove_columns: typing.Union[str, typing.List[str], NoneType] = None keep_in_memory: bool = False load_from_cache_file: typing.Optional[bool] = None cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 features: typing.Optional[datasets.features.features.Features] = None disable_nullable: bool = False fn_kwargs: typing.Optional[dict] = None num_proc: typing.Optional[int] = None suffix_template: str = '_{rank:05d}_of_{num_proc:05d}' new_fingerprint: typing.Optional[str] = None desc: typing.Optional[str] = None )


  • function (Callable) — Function with one of the following signatures:

    • function(example: Dict[str, Any]) -> Dict[str, Any] if batched=False and with_indices=False and with_rank=False
    • function(example: Dict[str, Any], *extra_args) -> Dict[str, Any] if batched=False and with_indices=True and/or with_rank=True (one extra arg for each)
    • function(batch: Dict[str, List]) -> Dict[str, List] if batched=True and with_indices=False and with_rank=False
    • function(batch: Dict[str, List], *extra_args) -> Dict[str, List] if batched=True and with_indices=True and/or with_rank=True (one extra arg for each)

    For advanced usage, the function can also return a pyarrow.Table. Moreover, if your function returns nothing (None), then map will run your function and return the dataset unchanged. If no function is provided, it defaults to the identity function: lambda x: x.

  • with_indices (bool, defaults to False) — Provide example indices to function. Note that in this case the signature of function should be def function(example, idx[, rank]): ....
  • with_rank (bool, defaults to False) — Provide process rank to function. Note that in this case the signature of function should be def function(example[, idx], rank): ....
  • input_columns (Optional[Union[str, List[str]]], defaults to None) — The columns to be passed into function as positional arguments. If None, a dict mapping to all formatted columns is passed as one argument.
  • batched (bool, defaults to False) — Provide batch of examples to function.
  • batch_size (int, optional, defaults to 1000) — Number of examples per batch provided to function if batched=True. If batch_size <= 0 or batch_size == None, provide the full dataset as a single batch to function.
  • drop_last_batch (bool, defaults to False) — Whether a last batch smaller than the batch_size should be dropped instead of being processed by the function.
  • remove_columns (Optional[Union[str, List[str]]], defaults to None) — Remove a selection of columns while doing the mapping. Columns will be removed before updating the examples with the output of function, i.e. if function is adding columns with names in remove_columns, these columns will be kept.
  • keep_in_memory (bool, defaults to False) — Keep the dataset in memory instead of writing it to a cache file.
  • load_from_cache_file (Optional[bool], defaults to True if caching is enabled) — If a cache file storing the current computation from function can be identified, use it instead of recomputing.
  • cache_file_name (str, optional, defaults to None) — Provide the name of a path for the cache file. It is used to store the results of the computation instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.
  • features (Optional[datasets.Features], defaults to None) — Use a specific Features to store the cache file instead of the automatically generated one.
  • disable_nullable (bool, defaults to False) — Disallow null values in the table.
  • fn_kwargs (Dict, optional, defaults to None) — Keyword arguments to be passed to function.
  • num_proc (int, optional, defaults to None) — Max number of processes when generating cache. Already cached shards are loaded sequentially.
  • suffix_template (str) — If cache_file_name is specified, then this suffix will be added at the end of the base name of each cache file. Defaults to "_{rank:05d}_of_{num_proc:05d}". For example, if cache_file_name is "processed.arrow", then for rank=1 and num_proc=4, the resulting file would be "processed_00001_of_00004.arrow" for the default suffix.
  • new_fingerprint (str, optional, defaults to None) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
  • desc (str, optional, defaults to None) — Meaningful description to be displayed alongside the progress bar while mapping examples.
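The suffix_template naming rule can be reproduced with plain string formatting. This is a minimal sketch: `shard_cache_name` is a hypothetical helper written for illustration, not a function of the library, and it assumes the suffix goes between the base name and the extension, as the documented example suggests.

```python
import os

suffix_template = "_{rank:05d}_of_{num_proc:05d}"  # default shown in the signature


def shard_cache_name(cache_file_name: str, rank: int, num_proc: int) -> str:
    # Hypothetical helper: insert the formatted suffix between base name and extension.
    base, ext = os.path.splitext(cache_file_name)
    return base + suffix_template.format(rank=rank, num_proc=num_proc) + ext


name = shard_cache_name("processed.arrow", rank=1, num_proc=4)
print(name)  # processed_00001_of_00004.arrow
```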

Apply a function to all the examples in the table (individually or in batches) and update the table. If your function returns a column that already exists, then it overwrites it.

You can specify whether the function should be batched or not with the batched parameter:

  • If batched is False, then the function takes 1 example in and should return 1 example. An example is a dictionary, e.g. {"text": "Hello there !"}.
  • If batched is True and batch_size is 1, then the function takes a batch of 1 example as input and can return a batch with 1 or more examples. A batch is a dictionary, e.g. a batch of 1 example is {"text": ["Hello there !"]}.
  • If batched is True and batch_size is n > 1, then the function takes a batch of n examples as input and can return a batch with n examples, or with an arbitrary number of examples. Note that the last batch may have fewer than n examples. A batch is a dictionary, e.g. a batch of n examples is {"text": ["Hello there !"] * n}.
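The batch shapes described above can be sketched without the library. `apply_batched` and `split_into_words` are hypothetical names used only for this illustration; the helper mimics the dict-of-lists batch layout and is not how map is implemented.

```python
from typing import Callable, Dict, List


def apply_batched(columns: Dict[str, List], function: Callable, batch_size: int) -> Dict[str, List]:
    """Hypothetical helper: feed dict-of-lists batches to `function` and concatenate the outputs."""
    num_rows = len(next(iter(columns.values())))
    result: Dict[str, List] = {}
    for start in range(0, num_rows, batch_size):
        # The last batch may hold fewer than batch_size rows.
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in function(batch).items():
            result.setdefault(name, []).extend(values)
    return result


def split_into_words(batch: Dict[str, List[str]]) -> Dict[str, List[str]]:
    # One output row per word, so the output batch can be longer than the input batch.
    return {"word": [word for text in batch["text"] for word in text.split()]}


data = {"text": ["Hello there !", "General Kenobi", "ok"]}
words = apply_batched(data, split_into_words, batch_size=2)
print(words)
```

With the real map, a function that changes the number of rows generally also needs remove_columns, so that the old columns do not end up with a different length than the new ones.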


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> def add_prefix(example):
...     example["text"] = "Review: " + example["text"]
...     return example
>>> ds = ds.map(add_prefix)
>>> ds[0:3]["text"]
['Review: compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
 'Review: the soundtrack alone is worth the price of admission .',
 'Review: rodriguez does a splendid job of racial profiling hollywood style--casting excellent latin actors of all ages--a trend long overdue .']

# process a batch of examples
>>> ds = ds.map(lambda example: tokenizer(example["text"]), batched=True)
# set number of processors
>>> ds = ds.map(add_prefix, num_proc=4)



( function: typing.Optional[typing.Callable] = None with_indices: bool = False with_rank: bool = False input_columns: typing.Union[str, typing.List[str], NoneType] = None batched: bool = False batch_size: typing.Optional[int] = 1000 keep_in_memory: bool = False load_from_cache_file: typing.Optional[bool] = None cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 fn_kwargs: typing.Optional[dict] = None num_proc: typing.Optional[int] = None suffix_template: str = '_{rank:05d}_of_{num_proc:05d}' new_fingerprint: typing.Optional[str] = None desc: typing.Optional[str] = None )


  • function (Callable) — Callable with one of the following signatures:

    • function(example: Dict[str, Any]) -> bool if batched=False and with_indices=False and with_rank=False
    • function(example: Dict[str, Any], *extra_args) -> bool if batched=False and with_indices=True and/or with_rank=True (one extra arg for each)
    • function(batch: Dict[str, List]) -> List[bool] if batched=True and with_indices=False and with_rank=False
    • function(batch: Dict[str, List], *extra_args) -> List[bool] if batched=True and with_indices=True and/or with_rank=True (one extra arg for each)

    If no function is provided, defaults to an always True function: lambda x: True.

  • with_indices (bool, defaults to False) — Provide example indices to function. Note that in this case the signature of function should be def function(example, idx[, rank]): ....
  • with_rank (bool, defaults to False) — Provide process rank to function. Note that in this case the signature of function should be def function(example[, idx], rank): ....
  • input_columns (str or List[str], optional) — The columns to be passed into function as positional arguments. If None, a dict mapping to all formatted columns is passed as one argument.
  • batched (bool, defaults to False) — Provide batch of examples to function.
  • batch_size (int, optional, defaults to 1000) — Number of examples per batch provided to function if batched = True. If batched = False, one example per batch is passed to function. If batch_size <= 0 or batch_size == None, provide the full dataset as a single batch to function.
  • keep_in_memory (bool, defaults to False) — Keep the dataset in memory instead of writing it to a cache file.
  • load_from_cache_file (Optional[bool], defaults to True if caching is enabled) — If a cache file storing the current computation from function can be identified, use it instead of recomputing.
  • cache_file_name (str, optional) — Provide the name of a path for the cache file. It is used to store the results of the computation instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.
  • fn_kwargs (dict, optional) — Keyword arguments to be passed to function.
  • num_proc (int, optional) — Number of processes for multiprocessing. By default it doesn’t use multiprocessing.
  • suffix_template (str) — If cache_file_name is specified, then this suffix will be added at the end of the base name of each cache file. For example, if cache_file_name is "processed.arrow", then for rank = 1 and num_proc = 4, the resulting file would be "processed_00001_of_00004.arrow" for the default suffix (default _{rank:05d}_of_{num_proc:05d}).
  • new_fingerprint (str, optional) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
  • desc (str, optional, defaults to None) — Meaningful description to be displayed alongside the progress bar while filtering examples.

Apply a filter function to all the elements in the table in batches and update the table so that the dataset only includes the examples for which the filter function returns True.
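For the batched signatures, the function returns one boolean per row of the batch. The sketch below shows that contract on a plain dict of lists; `keep_positive` and `apply_mask` are hypothetical names for illustration and do not reproduce the library's internals.

```python
from typing import Dict, List


def keep_positive(batch: Dict[str, List]) -> List[bool]:
    # Batched filter function: one boolean per row in the batch.
    return [label == 1 for label in batch["label"]]


def apply_mask(batch: Dict[str, List], mask: List[bool]) -> Dict[str, List]:
    # Hypothetical helper: keep only the rows whose mask entry is True.
    return {name: [v for v, keep in zip(values, mask) if keep] for name, values in batch.items()}


batch = {"text": ["great", "dull", "fun"], "label": [1, 0, 1]}
mask = keep_positive(batch)
kept = apply_mask(batch, mask)
print(mask)  # [True, False, True]
print(kept)  # {'text': ['great', 'fun'], 'label': [1, 1]}
```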


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.filter(lambda x: x["label"] == 1)
Dataset({
    features: ['text', 'label'],
    num_rows: 533
})



( indices: typing.Iterable keep_in_memory: bool = False indices_cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 new_fingerprint: typing.Optional[str] = None )


  • indices (range, list, iterable, ndarray or Series) — Range, list or 1D-array of integer indices for indexing. If the indices correspond to a contiguous range, the Arrow table is simply sliced. However passing a list of indices that are not contiguous creates indices mapping, which is much less efficient, but still faster than recreating an Arrow table made of the requested rows.
  • keep_in_memory (bool, defaults to False) — Keep the indices mapping in memory instead of writing it to a cache file.
  • indices_cache_file_name (str, optional, defaults to None) — Provide the name of a path for the cache file. It is used to store the indices mapping instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.
  • new_fingerprint (str, optional, defaults to None) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

Create a new dataset with rows selected following the list/array of indices.
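The difference between slicing and an indices mapping can be sketched with a small stand-in. `RowView` is a hypothetical class written for this illustration; it assumes, as the description of indices states, that a contiguous range is served by a slice while any other selection keeps the rows in place and records which ones to read.

```python
from typing import List, Optional, Sequence


class RowView:
    """Hypothetical stand-in for a table with an optional indices mapping."""

    def __init__(self, rows: List[str], indices: Optional[List[int]] = None):
        self.rows = rows
        self.indices = indices  # None means rows are read directly

    def select(self, indices: Sequence[int]) -> "RowView":
        indices = list(indices)
        if self.indices is not None:
            # Compose with an existing mapping so indices refer to the underlying rows.
            indices = [self.indices[i] for i in indices]
        is_contiguous = all(b - a == 1 for a, b in zip(indices, indices[1:]))
        if indices and is_contiguous:
            # Contiguous range: slice the underlying rows, no mapping needed.
            return RowView(self.rows[indices[0]:indices[-1] + 1])
        # Otherwise keep the rows as they are and remember which ones to read.
        return RowView(self.rows, indices)

    def __getitem__(self, i: int) -> str:
        return self.rows[i if self.indices is None else self.indices[i]]

    def __len__(self) -> int:
        return len(self.rows) if self.indices is None else len(self.indices)


view = RowView(["a", "b", "c", "d", "e"])
sliced = view.select(range(1, 4))  # contiguous: plain slice
picked = view.select([4, 0, 2])    # not contiguous: indices mapping
print([sliced[i] for i in range(len(sliced))], sliced.indices)
print([picked[i] for i in range(len(picked))], picked.indices)
```

Every read through `picked` goes through one extra lookup, which is the cost the indices mapping introduces.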


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.select(range(4))
Dataset({
    features: ['text', 'label'],
    num_rows: 4
})



( column_names: typing.Union[str, typing.Sequence[str]] reverse: typing.Union[bool, typing.Sequence[bool]] = False null_placement: str = 'at_end' keep_in_memory: bool = False load_from_cache_file: typing.Optional[bool] = None indices_cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 new_fingerprint: typing.Optional[str] = None )


  • column_names (Union[str, Sequence[str]]) — Column name(s) to sort by.
  • reverse (Union[bool, Sequence[bool]], defaults to False) — If True, sort by descending order rather than ascending. If a single bool is provided, the value is applied to the sorting of all column names. Otherwise a list of bools with the same length and order as column_names must be provided.
  • null_placement (str, defaults to at_end) — Put None values at the beginning if at_start or first, or at the end if at_end or last.

    Added in 1.14.2

  • keep_in_memory (bool, defaults to False) — Keep the sorted indices in memory instead of writing it to a cache file.
  • load_from_cache_file (Optional[bool], defaults to True if caching is enabled) — If a cache file storing the sorted indices can be identified, use it instead of recomputing.
  • indices_cache_file_name (str, optional, defaults to None) — Provide the name of a path for the cache file. It is used to store the sorted indices instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. A higher value gives smaller cache files; a lower value consumes less temporary memory.
  • new_fingerprint (str, optional, defaults to None) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

Create a new dataset sorted according to a single or multiple columns.
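The interaction between several column_names and a per-column reverse list can be sketched in plain Python. `sorted_indices` is a hypothetical helper for illustration only; it relies on the stability of Python's sort and says nothing about how the library computes its sorted indices.

```python
from typing import Dict, List, Sequence


def sorted_indices(columns: Dict[str, List], column_names: Sequence[str], reverse: Sequence[bool]) -> List[int]:
    """Hypothetical helper: row order for a multi-column sort with per-column direction."""
    order = list(range(len(columns[column_names[0]])))
    # Python's sort is stable, so sorting by the last key first and the first key last
    # yields an ordering where earlier columns take precedence.
    for name, descending in reversed(list(zip(column_names, reverse))):
        order.sort(key=lambda i: columns[name][i], reverse=descending)
    return order


columns = {"label": [0, 1, 0, 1], "text": ["b", "d", "a", "c"]}
order = sorted_indices(columns, ["label", "text"], reverse=[True, False])
print(order)                                 # [3, 1, 2, 0]
print([columns["label"][i] for i in order])  # [1, 1, 0, 0]
print([columns["text"][i] for i in order])   # ['c', 'd', 'a', 'b']
```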


>>> from datasets import load_dataset
>>> ds = load_dataset('rotten_tomatoes', split='validation')
>>> ds['label'][:10]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> sorted_ds = ds.sort('label')
>>> sorted_ds['label'][:10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
>>> another_sorted_ds = ds.sort(['label', 'text'], reverse=[True, False])
>>> another_sorted_ds['label'][:10]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]



( seed: typing.Optional[int] = None generator: typing.Optional[numpy.random._generator.Generator] = None keep_in_memory: bool = False load_from_cache_file: typing.Optional[bool] = None indices_cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 new_fingerprint: typing.Optional[str] = None )


  • seed (int, optional) — A seed to initialize the default BitGenerator if generator=None. If None, then fresh, unpredictable entropy will be pulled from the OS. If an int or array_like[ints] is passed, then it will be passed to SeedSequence to derive the initial BitGenerator state.
  • generator (numpy.random.Generator, optional) — Numpy random Generator to use to compute the permutation of the dataset rows. If generator=None (default), uses np.random.default_rng (the default BitGenerator (PCG64) of NumPy).
  • keep_in_memory (bool, defaults to False) — Keep the shuffled indices in memory instead of writing them to a cache file.
  • load_from_cache_file (Optional[bool], defaults to True if caching is enabled) — If a cache file storing the shuffled indices can be identified, use it instead of recomputing.
  • indices_cache_file_name (str, optional) — Provide the name of a path for the cache file. It is used to store the shuffled indices instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.
  • new_fingerprint (str, optional, defaults to None) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

Create a new Dataset where the rows are shuffled.

Currently shuffling uses numpy random generators. You can either supply a NumPy BitGenerator to use, or a seed to initiate NumPy’s default random generator (PCG64).

Shuffling takes the list of indices [0:len(my_dataset)] and shuffles it to create an indices mapping. However as soon as your Dataset has an indices mapping, the speed can become 10x slower. This is because there is an extra step to get the row index to read using the indices mapping, and most importantly, you aren’t reading contiguous chunks of data anymore. To restore the speed, you’d need to rewrite the entire dataset on your disk again using Dataset.flatten_indices(), which removes the indices mapping.

This may take a lot of time depending on the size of your dataset though:

my_dataset[0]  # fast
my_dataset = my_dataset.shuffle(seed=42)
my_dataset[0]  # up to 10x slower
my_dataset = my_dataset.flatten_indices()  # rewrite the shuffled dataset on disk as contiguous chunks of data
my_dataset[0]  # fast again

In this case, we recommend switching to an IterableDataset and leveraging its fast approximate shuffling method IterableDataset.shuffle().

It only shuffles the shards order and adds a shuffle buffer to your dataset, which keeps the speed of your dataset optimal:

my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=128)
for example in enumerate(my_iterable_dataset):  # fast
    pass

shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)

for example in enumerate(shuffled_iterable_dataset):  # as fast as before
    pass
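The idea of a shuffle buffer can be sketched generically. `buffered_shuffle` below is a hypothetical function showing the general technique, not the library's code, and it uses the standard library's random module in place of a NumPy generator to stay dependency-free.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Generic shuffle-buffer sketch (assumes buffer_size >= 1)."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # Fill the buffer first.
            buffer.append(item)
            continue
        # Emit a random buffered element and put the new item in its place.
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    # Drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=4, seed=42))
print(shuffled)
```

The input is still read sequentially, which is why this kind of shuffling is approximate: an element can only move as far as the buffer allows before it is emitted.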


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds['label'][:10]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

# set a seed
>>> shuffled_ds = ds.shuffle(seed=42)
>>> shuffled_ds['label'][:10]
[1, 0, 1, 1, 0, 0, 0, 0, 0, 0]



( n: int )


  • n (int) — Number of elements to skip.

Create a new Dataset that skips the first n elements.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="train")
>>> list(ds.take(3))
[{'label': 1,
 'text': 'the rock is destined to be the 21st century's new " conan " and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
 {'label': 1,
 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .'},
 {'label': 1, 'text': 'effective but too-tepid biopic'}]
>>> ds = ds.skip(1)
>>> list(ds.take(3))
[{'label': 1,
 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .'},
 {'label': 1, 'text': 'effective but too-tepid biopic'},
 {'label': 1,
 'text': 'if you sometimes like to go to the movies to have fun , wasabi is a good place to start .'}]



( n: int )


  • n (int) — Number of elements to take.

Create a new Dataset with only the first n elements.


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="train")
>>> small_ds = ds.take(2)
>>> list(small_ds)
[{'label': 1,
 'text': 'the rock is destined to be the 21st century's new " conan " and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
 {'label': 1,
 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .'}]



( test_size: typing.Union[float, int, NoneType] = None train_size: typing.Union[float, int, NoneType] = None shuffle: bool = True stratify_by_column: typing.Optional[str] = None seed: typing.Optional[int] = None generator: typing.Optional[numpy.random._generator.Generator] = None keep_in_memory: bool = False load_from_cache_file: typing.Optional[bool] = None train_indices_cache_file_name: typing.Optional[str] = None test_indices_cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 train_new_fingerprint: typing.Optional[str] = None test_new_fingerprint: typing.Optional[str] = None )


  • test_size (float or int, optional) — Size of the test split. If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
  • train_size (float or int, optional) — Size of the train split. If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split. If int, represents the absolute number of train samples. If None, the value is automatically set to the complement of the test size.
  • shuffle (bool, optional, defaults to True) — Whether or not to shuffle the data before splitting.
  • stratify_by_column (str, optional, defaults to None) — The column name of labels to be used to perform stratified split of data.
  • seed (int, optional) — A seed to initialize the default BitGenerator if generator=None. If None, then fresh, unpredictable entropy will be pulled from the OS. If an int or array_like[ints] is passed, then it will be passed to SeedSequence to derive the initial BitGenerator state.
  • generator (numpy.random.Generator, optional) — Numpy random Generator to use to compute the permutation of the dataset rows. If generator=None (default), uses np.random.default_rng (the default BitGenerator (PCG64) of NumPy).
  • keep_in_memory (bool, defaults to False) — Keep the splits indices in memory instead of writing it to a cache file.
  • load_from_cache_file (Optional[bool], defaults to True if caching is enabled) — If a cache file storing the splits indices can be identified, use it instead of recomputing.
  • train_indices_cache_file_name (str, optional) — Provide the name of a path for the cache file. It is used to store the train split indices instead of the automatically generated cache file name.
  • test_indices_cache_file_name (str, optional) — Provide the name of a path for the cache file. It is used to store the test split indices instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.
  • train_new_fingerprint (str, optional, defaults to None) — The new fingerprint of the train set after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
  • test_new_fingerprint (str, optional, defaults to None) — The new fingerprint of the test set after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

Return a dictionary (datasets.DatasetDict) with two random train and test subsets (train and test Dataset splits). Splits are created from the dataset according to test_size, train_size and shuffle.

This method is similar to scikit-learn train_test_split.
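How test_size and train_size translate into row counts can be sketched as follows. `split_sizes` is a hypothetical helper; the rounding (ceil for a fractional test size, floor for a fractional train size) is an assumption borrowed from the scikit-learn convention, chosen because it agrees with the 852/214 split of 1066 rows shown in the example that follows.

```python
import math
from typing import Optional, Tuple, Union

Size = Union[float, int]


def split_sizes(n_samples: int, test_size: Optional[Size] = None, train_size: Optional[Size] = None) -> Tuple[int, int]:
    """Hypothetical helper returning (n_train, n_test); rounding rules are assumed."""
    if test_size is None and train_size is None:
        test_size = 0.25  # documented default when both are None

    def resolve(size: Size, rounder) -> int:
        # Floats are proportions of the dataset; ints are absolute row counts.
        return rounder(size * n_samples) if isinstance(size, float) else size

    if test_size is None:
        n_train = resolve(train_size, math.floor)
        return n_train, n_samples - n_train
    n_test = resolve(test_size, math.ceil)
    n_train = n_samples - n_test if train_size is None else resolve(train_size, math.floor)
    return n_train, n_test


print(split_sizes(1066, test_size=0.2))  # (852, 214)
print(split_sizes(100, test_size=30))    # (70, 30)
print(split_sizes(100))                  # (75, 25)
```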


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds = ds.train_test_split(test_size=0.2, shuffle=True)
>>> ds
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 852
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 214
    })
})

# set a seed
>>> ds = ds.train_test_split(test_size=0.2, seed=42)

# stratified split
>>> ds = load_dataset("imdb", split="train")
>>> ds
Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})
>>> ds = ds.train_test_split(test_size=0.2, stratify_by_column="label")
>>> ds
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 20000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 5000
    })
})



( num_shards: int index: int contiguous: bool = True keep_in_memory: bool = False indices_cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 )


  • num_shards (int) — How many shards to split the dataset into.
  • index (int) — Which shard to select and return.
  • contiguous (bool, defaults to True) — Whether to select contiguous blocks of indices for shards.
  • keep_in_memory (bool, defaults to False) — Keep the dataset in memory instead of writing it to a cache file.
  • indices_cache_file_name (str, optional) — Provide the name of a path for the cache file. It is used to store the indices of each shard instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — This only concerns the indices mapping. Number of indices per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.

Return the index-nth shard from dataset split into num_shards pieces.

This shards deterministically. dataset.shard(n, i) splits the dataset into contiguous chunks, so it can be easily concatenated back together after processing. If len(dataset) % n == l, then the first l shards each have length (len(dataset) // n) + 1, and the remaining shards have length (len(dataset) // n). datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)]) returns a dataset with the same order as the original.

Note: n should be less than or equal to the number of elements in the dataset, len(dataset).

On the other hand, dataset.shard(n, i, contiguous=False) contains all elements of the dataset whose index mod n = i.

Be sure to shard before using any randomizing operator (such as shuffle). It is best if the shard operator is used early in the dataset pipeline.
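The two sharding schemes described above can be checked with index arithmetic alone. `shard_indices` is a hypothetical helper that follows the stated lengths and the modulo rule; it is a sketch of the description, not the library's implementation.

```python
from typing import List


def shard_indices(num_rows: int, num_shards: int, index: int, contiguous: bool = True) -> List[int]:
    """Hypothetical helper: row indices that land in shard `index`."""
    if not contiguous:
        # Every row whose position modulo num_shards equals index.
        return list(range(index, num_rows, num_shards))
    div, mod = divmod(num_rows, num_shards)
    # The first `mod` shards get one extra row.
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return list(range(start, end))


contiguous_shards = [shard_indices(10, 3, i) for i in range(3)]
strided_shards = [shard_indices(10, 3, i, contiguous=False) for i in range(3)]
print(contiguous_shards)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(strided_shards)     # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Concatenating the contiguous shards in order gives back rows 0 to 9 in their original order, which is the property that makes them easy to reassemble.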


>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds
Dataset({
    features: ['text', 'label'],
    num_rows: 1066
})
>>> ds.shard(num_shards=2, index=0)
Dataset({
    features: ['text', 'label'],
    num_rows: 533
})



( batch_size: typing.Optional[int] = None columns: typing.Union[str, typing.List[str], NoneType] = None shuffle: bool = False collate_fn: typing.Optional[typing.Callable] = None drop_remainder: bool = False collate_fn_args: typing.Optional[typing.Dict[str, typing.Any]] = None label_cols: typing.Union[str, typing.List[str], NoneType] = None prefetch: bool = True num_workers: int = 0 num_test_batches: int = 20 )


  • batch_size (int, optional) — Size of batches to load from the dataset. Defaults to None, which implies that the dataset won’t be batched, but the returned dataset can be batched later with tf_dataset.batch(batch_size).
  • columns (List[str] or str, optional) — Dataset column(s) to load in the tf.data.Dataset. Column names that are created by the collate_fn and that do not exist in the original dataset can be used.
  • shuffle (bool, defaults to False) — Shuffle the dataset order when loading. Recommended True for training, False for validation/evaluation.
  • drop_remainder (bool, defaults to False) — Drop the last incomplete batch when loading. Ensures that all batches yielded by the dataset will have the same length on the batch dimension.
  • collate_fn (Callable, optional) — A function or callable object (such as a DataCollator) that will collate lists of samples into a batch.
  • collate_fn_args (Dict, optional) — An optional dict of keyword arguments to be passed to the collate_fn.
  • label_cols (List[str] or str, defaults to None) — Dataset column(s) to load as labels. Note that many models compute loss internally rather than letting Keras do it, in which case passing the labels here is optional, as long as they’re in the input columns.
  • prefetch (bool, defaults to True) — Whether to run the dataloader in a separate thread and maintain a small buffer of batches for training. Improves performance by allowing data to be loaded in the background while the model is training.
  • num_workers (int, defaults to 0) — Number of workers to use for loading the dataset.
  • num_test_batches (int, defaults to 20) — Number of batches to use to infer the output signature of the dataset. The higher this number, the more accurate the signature will be, but the longer it will take to create the dataset.

Create a tf.data.Dataset from the underlying Dataset. This tf.data.Dataset will load and collate batches from the Dataset, and is suitable for passing to methods like model.fit() or model.predict(). The dataset will yield dicts for both inputs and labels unless the dict would contain only a single key, in which case a raw tf.Tensor is yielded instead.


>>> ds_train = ds["train"].to_tf_dataset(
...    columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'],
...    shuffle=True,
...    batch_size=16,
...    collate_fn=data_collator,
... )
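
The effect of batch_size and drop_remainder on batch lengths can be illustrated without TensorFlow. The sketch below only counts rows per batch; batch_sizes is a hypothetical helper, not a function of datasets or TensorFlow.

```python
def batch_sizes(num_rows, batch_size, drop_remainder=False):
    """Sizes of the successive batches produced from `num_rows` examples."""
    full, rest = divmod(num_rows, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)  # last, incomplete batch
    return sizes


kept = batch_sizes(50, 16)
dropped = batch_sizes(50, 16, drop_remainder=True)
print(kept)     # [16, 16, 16, 2]
print(dropped)  # [16, 16, 16]
```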


< >

( repo_id: str config_name: str = 'default' set_default: typing.Optional[bool] = None split: typing.Optional[str] = None data_dir: typing.Optional[str] = None commit_message: typing.Optional[str] = None commit_description: typing.Optional[str] = None private: typing.Optional[bool] = None token: typing.Optional[str] = None revision: typing.Optional[str] = None create_pr: typing.Optional[bool] = False max_shard_size: typing.Union[str, int, NoneType] = None num_shards: typing.Optional[int] = None embed_external_files: bool = True )


  • repo_id (str) — The ID of the repository to push to in the following format: <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.
  • config_name (str, defaults to “default”) — The configuration name (or subset) of a dataset. Defaults to “default”.
  • set_default (bool, optional) — Whether to set this configuration as the default one. Otherwise, the default configuration is the one named “default”.
  • split (str, optional) — The name of the split that will be given to that dataset. Defaults to self.split.
  • data_dir (str, optional) — Directory name that will contain the uploaded data files. Defaults to the config_name if different from “default”, else “data”.

    Added in 2.17.0

  • commit_message (str, optional) — Message to commit while pushing. Will default to "Upload dataset".
  • commit_description (str, optional) — Description of the commit that will be created. Additionally, description of the PR if a PR is created (create_pr is True).

    Added in 2.16.0

  • private (bool, optional) — Whether to make the repo private. If None (default), the repo will be public unless the organization’s default is private. This value is ignored if the repo already exists.
  • token (str, optional) — An optional authentication token for the Hugging Face Hub. If no token is passed, will default to the token saved locally when logging in with huggingface-cli login. Will raise an error if no token is passed and the user is not logged-in.
  • revision (str, optional) — Branch to push the uploaded files to. Defaults to the "main" branch.

    Added in 2.15.0

  • create_pr (bool, optional, defaults to False) — Whether to create a PR with the uploaded files or directly commit.

    Added in 2.15.0

  • max_shard_size (int or str, optional, defaults to "500MB") — The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like "5MB").
  • num_shards (int, optional) — Number of shards to write. By default, the number of shards depends on max_shard_size.

    Added in 2.8.0

  • embed_external_files (bool, defaults to True) — Whether to embed file bytes in the shards. In particular, this will do the following before the push for the fields of type:

    • Audio and Image: remove local path information and embed file content in the Parquet files.

Pushes the dataset to the hub as a Parquet dataset. The dataset is pushed using HTTP requests and does not require git or git-lfs to be installed.

The resulting Parquet files are self-contained by default. If your dataset contains Image, Audio or Video data, the Parquet files will store the bytes of your image, audio or video files. You can disable this by setting embed_external_files to False.


>>> dataset.push_to_hub("<organization>/<dataset_id>")
>>> dataset_dict.push_to_hub("<organization>/<dataset_id>", private=True)
>>> dataset.push_to_hub("<organization>/<dataset_id>", max_shard_size="1GB")
>>> dataset.push_to_hub("<organization>/<dataset_id>", num_shards=1024)

If your dataset has multiple splits (e.g. train/validation/test):

>>> train_dataset.push_to_hub("<organization>/<dataset_id>", split="train")
>>> val_dataset.push_to_hub("<organization>/<dataset_id>", split="validation")
>>> # later
>>> dataset = load_dataset("<organization>/<dataset_id>")
>>> train_dataset = dataset["train"]
>>> val_dataset = dataset["validation"]

If you want to add a new configuration (or subset) to a dataset (e.g. if the dataset has multiple tasks/versions/languages):

>>> english_dataset.push_to_hub("<organization>/<dataset_id>", "en")
>>> french_dataset.push_to_hub("<organization>/<dataset_id>", "fr")
>>> # later
>>> english_dataset = load_dataset("<organization>/<dataset_id>", "en")
>>> french_dataset = load_dataset("<organization>/<dataset_id>", "fr")
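
As a rough illustration of how a max_shard_size string relates to a shard count, the sketch below parses "digits followed by a unit" and estimates how many shards a dataset of a given size would need. The unit table, the ceiling rule, and the names parse_size and estimate_num_shards are assumptions made for this example; the library may compute the exact count differently.

```python
import math
import re

# Assumed unit table: decimal (KB, MB, GB) and binary (KiB, MiB, GiB) units.
_UNITS = {"KB": 10**3, "MB": 10**6, "GB": 10**9,
          "KIB": 2**10, "MIB": 2**20, "GIB": 2**30}


def parse_size(size):
    """Convert an int or a string such as "500MB" into a number of bytes."""
    if isinstance(size, int):
        return size
    match = re.fullmatch(r"(\d+)\s*([A-Za-z]+)", size.strip())
    if match is None or match.group(2).upper() not in _UNITS:
        raise ValueError(f"size must be digits followed by a unit, got {size!r}")
    return int(match.group(1)) * _UNITS[match.group(2).upper()]


def estimate_num_shards(dataset_nbytes, max_shard_size="500MB"):
    """Smallest shard count that keeps every shard at or below max_shard_size."""
    return max(1, math.ceil(dataset_nbytes / parse_size(max_shard_size)))


print(parse_size("5MB"))                          # 5000000
print(estimate_num_shards(1_200_000_000))         # 3
print(estimate_num_shards(1_200_000_000, "1GB"))  # 2
```

Passing num_shards explicitly, as in the example above, bypasses this kind of size-based estimate.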


< >

( dataset_path: typing.Union[str, bytes, os.PathLike] max_shard_size: typing.Union[str, int, NoneType] = None num_shards: typing.Optional[int] = None num_proc: typing.Optional[int] = None storage_options: typing.Optional[dict] = None )


  • dataset_path (path-like) — Path (e.g. dataset/train) or remote URI (e.g. s3://my-bucket/dataset/train) of the dataset directory where the dataset will be saved to.
  • max_shard_size (int or str, optional, defaults to "500MB") — The maximum size of the dataset shards to be saved to the filesystem. If expressed as a string, needs to be digits followed by a unit (like "50MB").
  • num_shards (int, optional) — Number of shards to write. By default the number of shards depends on max_shard_size and num_proc.

    Added in 2.8.0

  • num_proc (int, optional) — Number of processes to use when saving the dataset. Multiprocessing is disabled by default.

    Added in 2.8.0

  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.8.0

Saves a dataset to a dataset directory, or in a filesystem using any implementation of fsspec.spec.AbstractFileSystem.

For Image, Audio and Video data:

All the Image(), Audio() and Video() data are stored in the arrow files. If you want to store paths or urls, please use the Value("string") type.


>>> ds.save_to_disk("path/to/dataset/directory")
>>> ds.save_to_disk("path/to/dataset/directory", max_shard_size="1GB")
>>> ds.save_to_disk("path/to/dataset/directory", num_shards=1024)


< >

( dataset_path: typing.Union[str, bytes, os.PathLike] keep_in_memory: typing.Optional[bool] = None storage_options: typing.Optional[dict] = None ) Dataset or DatasetDict


  • dataset_path (path-like) — Path (e.g. "dataset/train") or remote URI (e.g. "s3://my-bucket/dataset/train") of the dataset directory where the dataset will be loaded from.
  • keep_in_memory (bool, defaults to None) — Whether to copy the dataset in-memory. If None, the dataset will not be copied in-memory unless explicitly enabled by setting datasets.config.IN_MEMORY_MAX_SIZE to nonzero. See more details in the improve performance section.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.8.0


Dataset or DatasetDict

  • If dataset_path is a path of a dataset directory, the dataset requested.
  • If dataset_path is a path of a dataset dict directory, a datasets.DatasetDict with each split.

Loads a dataset that was previously saved using save_to_disk from a dataset directory, or from a filesystem using any implementation of fsspec.spec.AbstractFileSystem.


>>> ds = load_from_disk("path/to/dataset/directory")


< >

( keep_in_memory: bool = False cache_file_name: typing.Optional[str] = None writer_batch_size: typing.Optional[int] = 1000 features: typing.Optional[datasets.features.features.Features] = None disable_nullable: bool = False num_proc: typing.Optional[int] = None new_fingerprint: typing.Optional[str] = None )


  • keep_in_memory (bool, defaults to False) — Keep the dataset in memory instead of writing it to a cache file.
  • cache_file_name (str, optional, default None) — Provide the name of a path for the cache file. It is used to store the results of the computation instead of the automatically generated cache file name.
  • writer_batch_size (int, defaults to 1000) — Number of rows per write operation for the cache file writer. This value is a good trade-off between memory usage during processing and processing speed. A higher value makes the processing do fewer lookups; a lower value consumes less temporary memory while running map.
  • features (Optional[datasets.Features], defaults to None) — Use a specific Features to store the cache file instead of the automatically generated one.
  • disable_nullable (bool, defaults to False) — Disallow null values in the table.
  • num_proc (int, optional, default None) — Max number of processes when generating cache. Already cached shards are loaded sequentially.
  • new_fingerprint (str, optional, defaults to None) — The new fingerprint of the dataset after transform. If None, the new fingerprint is computed using a hash of the previous fingerprint and the transform arguments.

Create and cache a new Dataset by flattening the indices mapping.
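
To make the idea of an indices mapping concrete, here is a toy model in plain Python. It assumes that a selection keeps the stored rows untouched and only records which positions to read, and that flattening rewrites the rows in that order. TinyTable and its methods are hypothetical stand-ins, not the library's implementation.

```python
class TinyTable:
    """Toy stand-in for a dataset: stored rows plus an optional indices mapping."""

    def __init__(self, rows, indices=None):
        self.rows = rows
        self.indices = indices  # None means rows are read in storage order

    def __getitem__(self, i):
        # With a mapping, every access goes through one extra lookup.
        if self.indices is not None:
            return self.rows[self.indices[i]]
        return self.rows[i]

    def __len__(self):
        return len(self.indices) if self.indices is not None else len(self.rows)

    def select(self, indices):
        # Compose the new selection with any existing mapping; rows are not copied.
        base = self.indices if self.indices is not None else range(len(self.rows))
        return TinyTable(self.rows, [base[i] for i in indices])

    def flatten_indices(self):
        # Materialize the rows in mapped order and drop the mapping.
        return TinyTable([self[i] for i in range(len(self))])


table = TinyTable(["a", "b", "c", "d"])
view = table.select([3, 1])
flat = view.flatten_indices()
print(view.indices, flat.indices)  # [3, 1] None
print(flat.rows)                   # ['d', 'b']
```

After flattening, rows are stored in the order they are read, so sequential access no longer jumps around the underlying storage.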


< >

( path_or_buf: typing.Union[str, bytes, os.PathLike, typing.BinaryIO] batch_size: typing.Optional[int] = None num_proc: typing.Optional[int] = None storage_options: typing.Optional[dict] = None **to_csv_kwargs ) int


  • path_or_buf (PathLike or FileOrBuffer) — Either a path to a file (e.g. file.csv), a remote URI (e.g. hf://datasets/username/my_dataset_name/data.csv), or a BinaryIO, where the dataset will be saved to in the specified format.
  • batch_size (int, optional) — Size of the batch to load in memory and write at once. Defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE.
  • num_proc (int, optional) — Number of processes for multiprocessing. By default it doesn’t use multiprocessing. batch_size in this case defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE but feel free to make it 5x or 10x of the default value if you have sufficient compute power.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.19.0

  • **to_csv_kwargs (additional keyword arguments) — Parameters to pass to pandas’s pandas.DataFrame.to_csv.

    Changed in 2.10.0

    Now, index defaults to False if not specified.

    If you would like to write the index, pass index=True and also set a name for the index column by passing index_label.



The number of characters or bytes written.

Exports the dataset to CSV.


>>> ds.to_csv("path/to/dataset/directory/filename.csv")
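
The role of batch_size and of the returned count can be sketched with the standard csv module. This is an illustrative analogue only, since the method itself delegates to pandas; write_csv_in_batches is a hypothetical helper, and the sample rows are made up.

```python
import csv
import io


def write_csv_in_batches(rows, columns, buffer, batch_size=2):
    """Write rows to `buffer` batch by batch and return the characters written."""
    written = 0
    for start in range(0, len(rows), batch_size):
        chunk = io.StringIO()
        writer = csv.writer(chunk, lineterminator="\n")
        if start == 0:
            writer.writerow(columns)  # header only once; no index column
        writer.writerows(rows[start:start + batch_size])
        # Only one batch is held in memory at a time.
        written += buffer.write(chunk.getvalue())
    return written


rows = [("great movie", 1), ("too long", 0), ("a delight", 1)]
out = io.StringIO()
count = write_csv_in_batches(rows, ["text", "label"], out)
print(out.getvalue())
print(count == len(out.getvalue()))  # True
```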


< >

( batch_size: typing.Optional[int] = None batched: bool = False )


  • batched (bool) — Set to True to return a generator that yields the dataset as batches of batch_size rows. Defaults to False (returns the whole dataset at once).
  • batch_size (int, optional) — The size (number of rows) of the batches if batched is True. Defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE.

Returns the dataset as a pandas.DataFrame. Can also return a generator for large datasets.


>>> ds.to_pandas()


< >

( batch_size: typing.Optional[int] = None )


  • batch_size (int, optional) — The size (number of rows) of the batches if batched is True. Defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE.

Returns the dataset as a Python dict. Can also return a generator for large datasets.


>>> ds.to_dict()


< >

( path_or_buf: typing.Union[str, bytes, os.PathLike, typing.BinaryIO] batch_size: typing.Optional[int] = None num_proc: typing.Optional[int] = None storage_options: typing.Optional[dict] = None **to_json_kwargs ) int


  • path_or_buf (PathLike or FileOrBuffer) — Either a path to a file (e.g. file.json), a remote URI (e.g. hf://datasets/username/my_dataset_name/data.json), or a BinaryIO, where the dataset will be saved to in the specified format.
  • batch_size (int, optional) — Size of the batch to load in memory and write at once. Defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE.
  • num_proc (int, optional) — Number of processes for multiprocessing. By default, it doesn’t use multiprocessing. batch_size in this case defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE but feel free to make it 5x or 10x of the default value if you have sufficient compute power.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.19.0

  • **to_json_kwargs (additional keyword arguments) — Parameters to pass to pandas’s pandas.DataFrame.to_json. Default arguments are lines=True and orient="records".

    Changed in 2.11.0

    The parameter index defaults to False if orient is "split" or "table".

    If you would like to write the index, pass index=True.



The number of characters or bytes written.

Export the dataset to JSON Lines or JSON.

The default output format is JSON Lines. To export to JSON, pass lines=False and the desired orient.


>>> ds.to_json("path/to/dataset/directory/filename.jsonl")
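
The difference between the two output formats can be shown with the standard json module. This sketch does not call pandas, so exact whitespace and escaping may differ from what the method writes; the sample records are made up.

```python
import json

records = [{"text": "great movie", "label": 1}, {"text": "too long", "label": 0}]

# JSON Lines (the default described above): one JSON object per line.
json_lines = "".join(json.dumps(record) + "\n" for record in records)

# Regular JSON with a records-style layout: a single array of objects.
json_array = json.dumps(records)

print(json_lines)
print(json_array)
```

JSON Lines can be read back one line at a time, which is convenient for large exports; a single JSON array has to be parsed as a whole.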


< >

( path_or_buf: typing.Union[str, bytes, os.PathLike, typing.BinaryIO] batch_size: typing.Optional[int] = None storage_options: typing.Optional[dict] = None **parquet_writer_kwargs ) int


  • path_or_buf (PathLike or FileOrBuffer) — Either a path to a file (e.g. file.parquet), a remote URI (e.g. hf://datasets/username/my_dataset_name/data.parquet), or a BinaryIO, where the dataset will be saved to in the specified format.
  • batch_size (int, optional) — Size of the batch to load in memory and write at once. Defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.19.0

  • **parquet_writer_kwargs (additional keyword arguments) — Parameters to pass to PyArrow’s pyarrow.parquet.ParquetWriter.



The number of characters or bytes written.

Exports the dataset to Parquet.


>>> ds.to_parquet("path/to/dataset/directory/filename.parquet")


< >

( name: str con: typing.Union[str, ForwardRef('sqlalchemy.engine.Connection'), ForwardRef('sqlalchemy.engine.Engine'), ForwardRef('sqlite3.Connection')] batch_size: typing.Optional[int] = None **sql_writer_kwargs ) int


  • name (str) — Name of SQL table.
  • con (str or sqlite3.Connection or sqlalchemy.engine.Connection or sqlalchemy.engine.Engine) — A URI string or a SQLite3/SQLAlchemy connection object used to write to a database.
  • batch_size (int, optional) — Size of the batch to load in memory and write at once. Defaults to datasets.config.DEFAULT_MAX_BATCH_SIZE.
  • **sql_writer_kwargs (additional keyword arguments) — Parameters to pass to pandas’s pandas.DataFrame.to_sql.

    Changed in 2.11.0

    Now, index defaults to False if not specified.

    If you would like to write the index, pass index=True and also set a name for the index column by passing index_label.



The number of records written.

Exports the dataset to a SQL database.


>>> # con provided as a connection URI string
>>> ds.to_sql("data", "sqlite:///my_own_db.sql")
>>> # con provided as a sqlite3 connection object
>>> import sqlite3
>>> con = sqlite3.connect("my_own_db.sql")
>>> with con:
...     ds.to_sql("data", con)
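
For readers unfamiliar with sqlite3 connection objects, the sketch below mimics a batched write into an in-memory SQLite database and then reads the row count back. It is an analogue built on the standard library only; write_rows is a hypothetical helper and the sample rows are made up.

```python
import sqlite3


def write_rows(con, name, columns, rows, batch_size=2):
    """Insert rows into table `name` in batches; return the number of records written."""
    column_list = ", ".join(f'"{c}"' for c in columns)
    placeholders = ", ".join("?" for _ in columns)
    con.execute(f'CREATE TABLE "{name}" ({column_list})')
    written = 0
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        con.executemany(f'INSERT INTO "{name}" VALUES ({placeholders})', batch)
        written += len(batch)
    return written


sample_rows = [("great movie", 1), ("too long", 0), ("a delight", 1)]
con = sqlite3.connect(":memory:")  # in-memory database, nothing is written to disk
with con:  # commits on success, rolls back on error
    num_written = write_rows(con, "data", ["text", "label"], sample_rows)
stored = con.execute('SELECT COUNT(*) FROM "data"').fetchone()[0]
print(num_written, stored)  # 3 3
con.close()
```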


< >

( num_shards: typing.Optional[int] = 1 )


  • num_shards (int, defaults to 1) — Number of shards to define when instantiating the iterable dataset. This is especially useful for big datasets to be able to shuffle properly, and also to enable fast parallel loading using a PyTorch DataLoader or in distributed setups for example. Shards are defined using datasets.Dataset.shard(): it simply slices the data without writing anything on disk.

Get a datasets.IterableDataset from a map-style datasets.Dataset. This is equivalent to loading a dataset in streaming mode with datasets.load_dataset(), but much faster since the data is streamed from local files.

Contrary to map-style datasets, iterable datasets are lazy and can only be iterated over (e.g. using a for loop). Since they are read sequentially in training loops, iterable datasets are much faster than map-style datasets. All the transformations applied to iterable datasets like filtering or processing are done on-the-fly when you start iterating over the dataset.

Still, it is possible to shuffle an iterable dataset using datasets.IterableDataset.shuffle(). This is a fast approximate shuffling that works best if you have multiple shards and if you specify a buffer size that is big enough.
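
The following sketch shows one common way a shuffle buffer produces an approximate shuffle of a stream. It is not necessarily the algorithm the library uses; buffer_shuffle is a hypothetical helper written for this illustration.

```python
import random


def buffer_shuffle(iterable, buffer_size, seed=None):
    """Yield items in approximately shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Emit a random buffered item and put the incoming item in its place.
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    rng.shuffle(buffer)  # drain what is left at the end of the stream
    yield from buffer


shuffled = list(buffer_shuffle(range(20), buffer_size=5, seed=0))
print(sorted(shuffled) == list(range(20)))  # True: every item appears exactly once
```

With a small buffer an item can only move a limited distance from its original position, which is why a larger buffer size and several shards give a better shuffle.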

To get the best speed performance, make sure your dataset doesn’t have an indices mapping. If this is the case, the data are not read contiguously, which can be slow sometimes. You can use ds = ds.flatten_indices() to write your dataset in contiguous chunks of data and have optimal speed before switching to an iterable dataset.


Basic usage:

>>> ids = ds.to_iterable_dataset()
>>> for example in ids:
...     pass

With lazy filtering and processing:

>>> ids = ds.to_iterable_dataset()
>>> ids = ids.filter(filter_fn).map(process_fn)  # will filter and process on-the-fly when you start iterating over the iterable dataset
>>> for example in ids:
...     pass

With sharding to enable efficient shuffling:

>>> ids = ds.to_iterable_dataset(num_shards=64)  # the dataset is split into 64 shards to be iterated over
>>> ids = ids.shuffle(buffer_size=10_000)  # will shuffle the shards order and use a shuffle buffer for fast approximate shuffling when you start iterating
>>> for example in ids:
...     pass

With a PyTorch DataLoader:

>>> import torch
>>> ids = ds.to_iterable_dataset(num_shards=64)
>>> ids = ids.filter(filter_fn).map(process_fn)
>>> dataloader = torch.utils.data.DataLoader(ids, num_workers=4)  # will assign 64 / 4 = 16 shards to each worker to load, filter and process when you start iterating
>>> for example in dataloader:
...     pass

With a PyTorch DataLoader and shuffling:

>>> import torch
>>> ids = ds.to_iterable_dataset(num_shards=64)
>>> ids = ids.shuffle(buffer_size=10_000)  # will shuffle the shards order and use a shuffle buffer when you start iterating
>>> dataloader = torch.utils.data.DataLoader(ids, num_workers=4)  # will assign 64 / 4 = 16 shards from the shuffled list of shards to each worker when you start iterating
>>> for example in dataloader:
...     pass

In a distributed setup like PyTorch DDP with a PyTorch DataLoader and shuffling:

>>> import torch
>>> from datasets.distributed import split_dataset_by_node
>>> ids = ds.to_iterable_dataset(num_shards=512)
>>> ids = ids.shuffle(buffer_size=10_000, seed=42)  # will shuffle the shards order and use a shuffle buffer when you start iterating
>>> ids = split_dataset_by_node(ids, world_size=8, rank=0)  # will keep only 512 / 8 = 64 shards from the shuffled lists of shards when you start iterating
>>> dataloader = torch.utils.data.DataLoader(ids, num_workers=4)  # will assign 64 / 4 = 16 shards from this node's list of shards to each worker when you start iterating
>>> for example in dataloader:
...     pass

With shuffling and multiple epochs:

>>> ids = ds.to_iterable_dataset(num_shards=64)
>>> ids = ids.shuffle(buffer_size=10_000, seed=42)  # will shuffle the shards order and use a shuffle buffer when you start iterating
>>> for epoch in range(n_epochs):
...     ids.set_epoch(epoch)  # will use effective_seed = seed + epoch to shuffle the shards and for the shuffle buffer when you start iterating
...     for example in ids:
...         pass
Feel free to also use `IterableDataset.set_epoch()` when using a PyTorch DataLoader or in distributed setups.
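
The shard counts quoted in the comments above (512 / 8 = 64 per node, 64 / 4 = 16 per worker) and the per-epoch seed can be checked with a small sketch. The strided assignment used here is an assumption chosen for simplicity; the library may distribute shards differently, and all three helper names are hypothetical.

```python
import random


def shuffled_shards(num_shards, seed, epoch):
    """Shard order for one epoch, using effective_seed = seed + epoch."""
    order = list(range(num_shards))
    random.Random(seed + epoch).shuffle(order)
    return order


def shards_for_node(shards, world_size, rank):
    """Shards kept by node `rank` when the list is split across `world_size` nodes."""
    return shards[rank::world_size]


def shards_for_worker(shards, num_workers, worker_id):
    """Shards read by one DataLoader worker on a given node."""
    return shards[worker_id::num_workers]


order = shuffled_shards(512, seed=42, epoch=0)
node_shards = shards_for_node(order, world_size=8, rank=0)
worker_shards = shards_for_worker(node_shards, num_workers=4, worker_id=0)
print(len(node_shards), len(worker_shards))  # 64 16
```

Because the seed changes with the epoch, each epoch visits the shards in a different order while staying reproducible for a given seed.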


< >

( column: str index_name: typing.Optional[str] = None device: typing.Optional[int] = None string_factory: typing.Optional[str] = None metric_type: typing.Optional[int] = None custom_index: typing.Optional[ForwardRef('faiss.Index')] = None batch_size: int = 1000 train_size: typing.Optional[int] = None faiss_verbose: bool = False dtype = <class 'numpy.float32'> )


  • column (str) — The column of the vectors to add to the index.
  • index_name (str, optional) — The index_name/identifier of the index. This is the index_name that is used to call get_nearest_examples() or search(). By default it corresponds to column.
  • device (Union[int, List[int]], optional) — If positive integer, this is the index of the GPU to use. If negative integer, use all GPUs. If a list of positive integers is passed in, run only on those GPUs. By default it uses the CPU.
  • string_factory (str, optional) — This is passed to the index factory of Faiss to create the index. Default index class is IndexFlat.
  • metric_type (int, optional) — Type of metric. Ex: faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2.
  • custom_index (faiss.Index, optional) — Custom Faiss index that you already have instantiated and configured for your needs.
  • batch_size (int) — Size of the batch to use while adding vectors to the FaissIndex. Default value is 1000.

    Added in 2.4.0

  • train_size (int, optional) — If the index needs a training step, specifies how many vectors will be used to train the index.
  • faiss_verbose (bool, defaults to False) — Enable the verbosity of the Faiss index.
  • dtype (data-type) — The dtype of the numpy arrays that are indexed. Default is np.float32.

Add a dense index using Faiss for fast retrieval. By default the index is done over the vectors of the specified column. You can specify device if you want to run it on GPU (device must be the GPU index). You can find more information about Faiss here:


>>> ds = datasets.load_dataset('crime_and_punish', split='train')
>>> ds_with_embeddings = ds.map(lambda example: {'embeddings': embed(example['line'])})
>>> ds_with_embeddings.add_faiss_index(column='embeddings')
>>> # query
>>> scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embeddings', embed('my new query'), k=10)
>>> # save index
>>> ds_with_embeddings.save_faiss_index('embeddings', 'my_index.faiss')

>>> ds = datasets.load_dataset('crime_and_punish', split='train')
>>> # load index
>>> ds.load_faiss_index('embeddings', 'my_index.faiss')
>>> # query
>>> scores, retrieved_examples = ds.get_nearest_examples('embeddings', embed('my new query'), k=10)
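
Conceptually, a flat index compares the query against every stored vector. The pure-Python sketch below shows such an exhaustive search with squared L2 distance; it does not use Faiss, the function name search_flat_l2 is hypothetical, and the vectors are made up.

```python
def search_flat_l2(vectors, query, k=10):
    """Exhaustive nearest-neighbour search using squared L2 distance."""
    distances = [
        (sum((v - q) ** 2 for v, q in zip(vector, query)), row)
        for row, vector in enumerate(vectors)
    ]
    distances.sort()  # smaller distance means closer; ties broken by row index
    top = distances[:k]
    scores = [distance for distance, _ in top]
    indices = [row for _, row in top]
    return scores, indices


embeddings = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
scores, indices = search_flat_l2(embeddings, [0.9, 0.1], k=2)
print(indices)  # [1, 0]
```

The cost of this approach grows linearly with the number of vectors, which is the motivation for the trained or GPU-backed index types that string_factory and device can select.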


< >

( external_arrays: <built-in function array> index_name: str device: typing.Optional[int] = None string_factory: typing.Optional[str] = None metric_type: typing.Optional[int] = None custom_index: typing.Optional[ForwardRef('faiss.Index')] = None batch_size: int = 1000 train_size: typing.Optional[int] = None faiss_verbose: bool = False dtype = <class 'numpy.float32'> )


  • external_arrays (np.array) — If you want to use arrays from outside the lib for the index, you can set external_arrays. It will use external_arrays to create the Faiss index instead of the arrays in the given column.
  • index_name (str) — The index_name/identifier of the index. This is the index_name that is used to call get_nearest_examples() or search().
  • device (Union[int, List[int]], optional) — If positive integer, this is the index of the GPU to use. If negative integer, use all GPUs. If a list of positive integers is passed in, run only on those GPUs. By default it uses the CPU.
  • string_factory (str, optional) — This is passed to the index factory of Faiss to create the index. Default index class is IndexFlat.
  • metric_type (int, optional) — Type of metric. Ex: faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2.
  • custom_index (faiss.Index, optional) — Custom Faiss index that you already have instantiated and configured for your needs.
  • batch_size (int, optional) — Size of the batch to use while adding vectors to the FaissIndex. Default value is 1000.

    Added in 2.4.0

  • train_size (int, optional) — If the index needs a training step, specifies how many vectors will be used to train the index.
  • faiss_verbose (bool, defaults to False) — Enable the verbosity of the Faiss index.
  • dtype (numpy.dtype) — The dtype of the numpy arrays that are indexed. Default is np.float32.

Add a dense index using Faiss for fast retrieval. The index is created using the vectors of external_arrays. You can specify device if you want to run it on GPU (device must be the GPU index). You can find more information about Faiss here:


< >

( index_name: str file: typing.Union[str, pathlib.PurePath] storage_options: typing.Optional[typing.Dict] = None )


  • index_name (str) — The index_name/identifier of the index. This is the index_name that is used to call get_nearest_examples() or search().
  • file (str) — The path to the serialized faiss index on disk or remote URI (e.g. "s3://my-bucket/index.faiss").
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.11.0

Save a FaissIndex on disk.


< >

( index_name: str file: typing.Union[str, pathlib.PurePath] device: typing.Union[int, typing.List[int], NoneType] = None storage_options: typing.Optional[typing.Dict] = None )


  • index_name (str) — The index_name/identifier of the index. This is the index_name that is used to call get_nearest_examples() or search().
  • file (str) — The path to the serialized faiss index on disk or remote URI (e.g. "s3://my-bucket/index.faiss").
  • device (Union[int, List[int]], optional) — If positive integer, this is the index of the GPU to use. If negative integer, use all GPUs. If a list of positive integers is passed in, run only on those GPUs. By default it uses the CPU.
  • storage_options (dict, optional) — Key/value pairs to be passed on to the file-system backend, if any.

    Added in 2.11.0

Load a FaissIndex from disk.

If you want to do additional configurations, you can have access to the faiss index object by doing .get_index(index_name).faiss_index to make it fit your needs.


< >

( column: str index_name: typing.Optional[str] = None host: typing.Optional[str] = None port: typing.Optional[int] = None es_client: typing.Optional[ForwardRef('elasticsearch.Elasticsearch')] = None es_index_name: typing.Optional[str] = None es_index_config: typing.Optional[dict] = None )


  • column (str) — The column of the documents to add to the index.
  • index_name (str, optional) — The index_name/identifier of the index. This is the index name that is used to call get_nearest_examples() or search(). By default it corresponds to column.
  • host (str, optional, defaults to localhost) — Host of where ElasticSearch is running.
  • port (int, optional, defaults to 9200) — Port of where ElasticSearch is running.
  • es_client (elasticsearch.Elasticsearch, optional) — The elasticsearch client used to create the index if host and port are None.
  • es_index_name (str, optional) — The elasticsearch index name used to create the index.
  • es_index_config (dict, optional) — The configuration of the elasticsearch index. Default config is:

Add a text index using ElasticSearch for fast retrieval. This is done in-place.


>>> es_client = elasticsearch.Elasticsearch()
>>> ds = datasets.load_dataset('crime_and_punish', split='train')
>>> ds.add_elasticsearch_index(column='line', es_client=es_client, es_index_name="my_es_index")
>>> scores, retrieved_examples = ds.get_nearest_examples('line', 'my new query', k=10)


< >

( index_name: str es_index_name: str host: typing.Optional[str] = None port: typing.Optional[int] = None es_client: typing.Optional[ForwardRef('Elasticsearch')] = None es_index_config: typing.Optional[dict] = None )


  • index_name (str) — The index_name/identifier of the index. This is the index name that is used to call get_nearest_examples() or search().
  • es_index_name (str) — The name of elasticsearch index to load.
  • host (str, optional, defaults to localhost) — Host of where ElasticSearch is running.
  • port (int, optional, defaults to 9200) — Port of where ElasticSearch is running.
  • es_client (elasticsearch.Elasticsearch, optional) — The elasticsearch client used to create the index if host and port are None.
  • es_index_config (dict, optional) — The configuration of the elasticsearch index. Default config is:

Load an existing text index using ElasticSearch for fast retrieval.


< >

( )

List the index_name/identifiers of all the attached indexes.


< >

( index_name: str )


  • index_name (str) — Index name.

Get the attached index with the specified index_name.


< >

( index_name: str )


  • index_name (str) — The index_name/identifier of the index.

Drop the index with the specified index_name.


< >

( index_name: str query: typing.Union[str, <built-in function array>] k: int = 10 **kwargs ) (scores, indices)


  • index_name (str) — The name/identifier of the index.
  • query (Union[str, np.ndarray]) — The query as a string if index_name is a text index or as a numpy array if index_name is a vector index.
  • k (int) — The number of examples to retrieve.


(scores, indices)

A tuple of (scores, indices) where:

  • scores (List[List[float]]): the retrieval scores from either FAISS (IndexFlatL2 by default) or ElasticSearch of the retrieved examples
  • indices (List[List[int]]): the indices of the retrieved examples

Find the indices of the examples in the dataset that are nearest to the query.


< >

( index_name: str queries: typing.Union[typing.List[str], <built-in function array>] k: int = 10 **kwargs ) (total_scores, total_indices)


  • index_name (str) — The index_name/identifier of the index.
  • queries (Union[List[str], np.ndarray]) — The queries as a list of strings if index_name is a text index or as a numpy array if index_name is a vector index.
  • k (int) — The number of examples to retrieve per query.


(total_scores, total_indices)

A tuple of (total_scores, total_indices) where:

  • total_scores (List[List[float]]): the retrieval scores from either FAISS (IndexFlatL2 by default) or ElasticSearch of the retrieved examples per query
  • total_indices (List[List[int]]): the indices of the retrieved examples per query

Find the indices of the examples in the dataset that are nearest to each query.


< >

( index_name: str query: typing.Union[str, <built-in function array>] k: int = 10 **kwargs ) (scores, examples)


  • index_name (str) — The index_name/identifier of the index.
  • query (Union[str, np.ndarray]) — The query as a string if index_name is a text index or as a numpy array if index_name is a vector index.