
Process image data

🤗 Datasets supports loading and processing images with the Image feature. This guide shows you how to:

  • Load an image dataset.
  • Load a generic image dataset with ImageFolder.
  • Use map() to quickly apply transforms to an entire dataset.
  • Add data augmentations to your images with Dataset.set_transform().

Image datasets

The images in an image dataset are typically either:

  • a PIL image, or
  • a path to an image file you can load.

For example, load the Food-101 dataset and take a look:

>>> from datasets import load_dataset, Image

>>> dataset = load_dataset("food101", split="train[:100]")
>>> dataset[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512 at 0x7FC45AB5C590>,
 'label': 6}

The Image feature automatically decodes the data in the image column and returns an image object. Index into the image column to see the image:

>>> from datasets import load_dataset, Image

>>> dataset = load_dataset("food101", split="train[100:200]")
>>> dataset[0]["image"]


To load an image from its path, use the cast_column() method. The Image feature will decode the data at the path to return an image object:

>>> from datasets import Dataset, Image

>>> dataset = Dataset.from_dict({"image": ["path/to/image_1", "path/to/image_2", ..., "path/to/image_n"]}).cast_column("image", Image())
>>> dataset[0]["image"]

To access the path and bytes of an image file instead of the decoded image object, cast the image column to Image(decode=False):

>>> dataset = load_dataset("food101", split="train[:100]").cast_column('image', Image(decode=False))


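ImageFolder

You can also load your image dataset with the ImageFolder dataset builder without writing a custom dataloader. Your image dataset should be organized with one subdirectory per label, for example:

folder/train/dog/dog_1.png
folder/train/dog/dog_2.png
folder/train/cat/cat_1.png
folder/train/cat/cat_2.png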
Then load your dataset by passing "imagefolder" as the builder name and your dataset's directory to data_dir:

>>> from datasets import load_dataset
>>> dataset = load_dataset("imagefolder", data_dir="/path/to/folder")

Load remote datasets from their URLs with the data_files parameter:

>>> dataset = load_dataset("imagefolder", data_files="", split="train")

ImageFolder automatically creates a label column, with each label name taken from the name of the directory that contains the image.


Map

map() can apply transforms over an entire dataset, and it also generates a cache file.

Create a simple Resize function:

>>> def transforms(examples):
...     examples["pixel_values"] = [image.convert("RGB").resize((100,100)) for image in examples["image"]]
...     return examples
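With batched=True, map() passes your function a dictionary that maps each column name to a list of values. A quick way to sanity-check the resize function before mapping is to call it on a hand-built batch; the blank grayscale and RGBA images below are hypothetical inputs chosen to show that convert("RGB") normalizes the mode:

```python
from PIL import Image as PILImage


def transforms(examples):
    # Same resize function as above: convert to RGB, then resize to 100x100
    examples["pixel_values"] = [image.convert("RGB").resize((100, 100)) for image in examples["image"]]
    return examples


# A hand-built batch: column name -> list of values, the shape map(batched=True) supplies
batch = {
    "image": [PILImage.new("L", (384, 512)), PILImage.new("RGBA", (50, 20))],
    "label": [6, 3],
}
out = transforms(batch)
print([(img.mode, img.size) for img in out["pixel_values"]])
# [('RGB', (100, 100)), ('RGB', (100, 100))]
```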

Now apply the function to the entire dataset with map() and set batched=True. The transform returns pixel_values as PIL.Image objects, which can be cached:

>>> dataset = dataset.map(transforms, remove_columns=["image"], batched=True)
>>> dataset[0]
{'label': 6,
 'pixel_values': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=100x100 at 0x7F058237BB10>}

Because the results are cached, you don’t have to execute the same transform twice, which saves time. map() is best for operations you only run once per training run, like resizing an image, rather than operations executed every epoch, like data augmentations.

map() takes up some memory, but you can reduce its memory requirements with the following parameters:

  • batch_size determines the number of examples that are processed in one call to the transform function.
  • writer_batch_size determines the number of processed examples that are kept in memory before they are stored away.

Both parameter values default to 1000, which can be expensive if you are storing images. Lower the value to use less memory when calling map().

Data augmentation

Data augmentation is commonly used to reduce overfitting and improve model performance. You can use any library or package you want to apply the augmentations. This guide uses the transforms from torchvision.

Feel free to use other data augmentation libraries like Albumentations. 🤗 Datasets can apply any custom function and transforms to an entire dataset!
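To illustrate that point without torchvision, here is a rough PIL-only color jitter built on PIL.ImageEnhance. It is a swapped-in technique, not torchvision's ColorJitter: the helper name pil_color_jitter is hypothetical, it randomly scales brightness, contrast, and saturation, and it does not shift hue.

```python
import random

from PIL import Image as PILImage
from PIL import ImageEnhance


def pil_color_jitter(image, brightness=0.25, contrast=0.25, saturation=0.25, rng=random):
    """Scale brightness, contrast, and saturation by random factors in [1 - x, 1 + x].

    A hypothetical PIL-only stand-in for torchvision's ColorJitter (no hue shift).
    """
    image = image.convert("RGB")
    for enhancer_cls, strength in (
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Color, saturation),  # ImageEnhance.Color adjusts saturation
    ):
        factor = rng.uniform(1 - strength, 1 + strength)
        image = enhancer_cls(image).enhance(factor)
    return image


original = PILImage.new("RGB", (64, 64), (120, 80, 200))
augmented = pil_color_jitter(original, rng=random.Random(0))
# With zero strength every factor is exactly 1.0, so the image is left unchanged
unchanged = pil_color_jitter(original, brightness=0, contrast=0, saturation=0)
print(augmented.size, augmented.mode)  # (64, 64) RGB
```

You could call pil_color_jitter inside the transforms function below in place of jitter, though the result would be a PIL image rather than a tensor.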

Use the ColorJitter transform to randomly change the color properties of an image:

>>> from torchvision.transforms import Compose, ColorJitter, ToTensor

>>> jitter = Compose(
...     [
...          ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.5),  # hue must be in [0, 0.5]
...          ToTensor(),
...     ]
... )

Create a function to apply the ColorJitter transform to an image:

>>> def transforms(examples):
...     examples["pixel_values"] = [jitter(image.convert("RGB")) for image in examples["image"]]
...     return examples

Then use the set_transform() function to apply the transform on-the-fly. The transform runs each time an example is accessed and its output is not cached, which saves disk space and means a random augmentation like ColorJitter gives a different result on every access:

>>> dataset.set_transform(transforms)

Now visualize the results of the ColorJitter transform:

>>> import matplotlib.pyplot as plt

>>> img = dataset[0]["pixel_values"]  # a C x H x W tensor produced by ToTensor()
>>> plt.imshow(img.permute(1, 2, 0))  # matplotlib expects H x W x C
>>> plt.show()