Process image data
🤗 Datasets support loading and processing images with the Image feature. This guide will show you how to:
- Load an image dataset.
- Load a generic image dataset with ImageFolder.
- Use map() to quickly apply transforms to an entire dataset.
- Add data augmentations to your images with Dataset.set_transform().
Image datasets
The images in an image dataset are typically either a:
- PIL image.
- Path to an image file you can load.
For example, load the Food-101 dataset and take a look:
>>> from datasets import load_dataset, Image
>>> dataset = load_dataset("food101", split="train[:100]")
>>> dataset[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512 at 0x7FC45AB5C590>,
'label': 6}
The Image feature automatically decodes the data from the image column to return an image object. Now try calling the image column to see what the image is:
>>> from datasets import load_dataset, Image
>>> dataset = load_dataset("food101", split="train[100:200]")
>>> dataset[0]["image"]
To load an image from its path, use the cast_column() method. The Image feature will decode the data at the path to return an image object:
>>> from datasets import Dataset, Image
>>> dataset = Dataset.from_dict({"image": ["path/to/image_1", "path/to/image_2", ..., "path/to/image_n"]}).cast_column("image", Image())
>>> dataset[0]["image"]
You can also access the path and bytes of an image file by setting decode=False when you load a dataset. In this case, you will need to cast the image column:
>>> dataset = load_dataset("food101", split="train[:100]").cast_column('image', Image(decode=False))
ImageFolder
You can also load your image dataset with an ImageFolder dataset builder without writing a custom dataloader. Your image dataset structure should look like this:
folder/dog/golden_retriever.png
folder/dog/german_shepherd.png
folder/dog/chihuahua.png
folder/cat/maine_coon.png
folder/cat/bengal.png
folder/cat/birman.png
Then load your dataset by specifying imagefolder and the directory of your dataset in data_dir:
>>> from datasets import load_dataset
>>> dataset = load_dataset("imagefolder", data_dir="/path/to/folder")
Load remote datasets from their URLs with the data_files parameter:
>>> dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip", split="train")
ImageFolder will create a label column, and the label name is based on the directory name.
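To make the directory convention concrete, here is a minimal sketch that recreates the layout above in a temporary directory and derives one label per file from its parent directory name. It uses only the standard library and is not how ImageFolder is implemented; the helper `labels_from_layout` is hypothetical, and the files are empty placeholders rather than real images.

```python
import tempfile
from pathlib import Path


def labels_from_layout(root):
    """Map each file under root to the name of its parent directory.

    Hypothetical helper: it only mirrors the folder/<label>/<file> convention.
    """
    root = Path(root)
    return {
        path.relative_to(root).as_posix(): path.parent.name
        for path in sorted(root.glob("*/*"))
        if path.is_file()
    }


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "folder"
    layout = {
        "dog": ["golden_retriever.png", "german_shepherd.png", "chihuahua.png"],
        "cat": ["maine_coon.png", "bengal.png", "birman.png"],
    }
    for label, names in layout.items():
        (root / label).mkdir(parents=True)
        for name in names:
            (root / label / name).touch()  # empty placeholder, not a real image

    labels = labels_from_layout(root)
    class_names = sorted(set(labels.values()))

print(class_names)  # ['cat', 'dog']
```

Each subdirectory name becomes one class, so renaming a directory is enough to rename the label under this convention.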
Map
map() can apply transforms over an entire dataset and it also generates a cache file.
Create a simple resize function:
>>> def transforms(examples):
...     examples["pixel_values"] = [image.convert("RGB").resize((100,100)) for image in examples["image"]]
...     return examples
Now map() the function over the entire dataset and set batched=True. The transform returns pixel_values as a cacheable PIL.Image object:
>>> dataset = dataset.map(transforms, remove_columns=["image"], batched=True)
>>> dataset[0]
{'label': 6,
'pixel_values': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=100x100 at 0x7F058237BB10>}
This saves time because you don’t have to execute the same transform twice. It is best to use map() for operations you only run once per training - like resizing an image - instead of using it for operations executed for each epoch, like data augmentations.
map() takes up some memory, but you can reduce its memory requirements with the following parameters:
- batch_size determines the number of examples that are processed in one call to the transform function.
- writer_batch_size determines the number of processed examples that are kept in memory before they are stored away.
Both parameter values default to 1000, which can be expensive if you are storing images. Lower the value to use less memory when calling map().
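The effect of batch_size can be illustrated without the library. The sketch below is a simplified stand-in for a batched map, assuming a dataset stored as a dictionary of columns; `iter_batches` is a hypothetical helper and not part of 🤗 Datasets. It shows that the transform function receives a dictionary of lists holding at most batch_size examples, so a smaller batch_size means fewer images in memory per call.

```python
def iter_batches(columns, batch_size):
    """Yield dicts of column slices, each holding at most batch_size rows."""
    num_rows = len(next(iter(columns.values())))
    for start in range(0, num_rows, batch_size):
        yield {
            name: values[start:start + batch_size]
            for name, values in columns.items()
        }


def transforms(examples):
    # Same shape as a batched transform: read one column, add another.
    examples["pixel_values"] = [f"resized({image})" for image in examples["image"]]
    return examples


# Strings stand in for images so the sketch runs without any image files.
columns = {"image": [f"img_{i}" for i in range(10)], "label": list(range(10))}

batch_sizes = []
for batch in iter_batches(columns, batch_size=4):
    out = transforms(batch)
    batch_sizes.append(len(out["pixel_values"]))

print(batch_sizes)  # [4, 4, 2]
```

With ten rows and a batch size of four, the transform runs three times and never sees more than four examples at once.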
Data augmentation
Adding data augmentations to a dataset is common to prevent overfitting and achieve better performance. You can use any library or package you want to apply the augmentations. This guide will use the transforms from torchvision.
Feel free to use other data augmentation libraries like Albumentations. 🤗 Datasets can apply any custom function and transforms to an entire dataset!
Add the ColorJitter transform to change the color properties of the image randomly:
>>> from torchvision.transforms import Compose, ColorJitter, ToTensor
>>> jitter = Compose(
...     [
...         # hue must lie in [0, 0.5]
...         ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.5),
...         ToTensor(),
...     ]
... )
Create a function to apply the ColorJitter transform to an image:
>>> def transforms(examples):
...     examples["pixel_values"] = [jitter(image.convert("RGB")) for image in examples["image"]]
...     return examples
Then you can use the set_transform() function to apply the transform on-the-fly to consume less disk space. Use this function if you only need to access the examples once:
>>> dataset.set_transform(transforms)
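The difference between a cached map() and an on-the-fly transform can be sketched in plain Python. The `OnTheFly` class below is a hypothetical toy, not the library's implementation, and a random brightness factor stands in for ColorJitter. Because the function runs each time an example is accessed, every access can produce a different augmentation, and nothing is written to disk.

```python
import random


class OnTheFly:
    """Toy dataset wrapper that applies a transform at access time."""

    def __init__(self, rows, transform):
        self.rows = rows
        self.transform = transform
        self.calls = 0  # how many times the transform has run

    def __getitem__(self, index):
        self.calls += 1
        # Transform a copy so the stored row is never modified.
        return self.transform(dict(self.rows[index]))


def random_brightness(example, rng=random.Random(0)):
    # Stand-in for an augmentation: scale every pixel by a random factor.
    # The seeded generator is shared across calls to keep the run repeatable.
    factor = rng.uniform(0.75, 1.25)
    example["pixel_values"] = [round(p * factor, 3) for p in example["image"]]
    return example


toy_dataset = OnTheFly([{"image": [0.2, 0.4, 0.6], "label": 6}], random_brightness)

first = toy_dataset[0]["pixel_values"]
second = toy_dataset[0]["pixel_values"]

print(first != second)    # the same example is augmented differently per access
print(toy_dataset.calls)  # the transform ran once for each access
```

This is why per-epoch augmentations fit an on-the-fly transform, while one-off steps such as resizing are better computed once and cached.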
Now visualize the results of the ColorJitter transform:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> img = dataset[0]["pixel_values"]
>>> plt.imshow(img.permute(1, 2, 0))
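The permute(1, 2, 0) call is needed because ToTensor returns a tensor laid out as (channels, height, width), while plt.imshow expects (height, width, channels). The sketch below reproduces that axis reordering on nested lists in plain Python; `chw_to_hwc` is a hypothetical helper for illustration only.

```python
def chw_to_hwc(image):
    """Reorder a nested-list image from (channels, height, width)
    to (height, width, channels)."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[c][h][w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A 3-channel image with height 2 and width 4; each value encodes its own
# position as c * 100 + h * 10 + w so the reordering is easy to check.
chw = [[[c * 100 + h * 10 + w for w in range(4)] for h in range(2)] for c in range(3)]

hwc = chw_to_hwc(chw)
shape = (len(hwc), len(hwc[0]), len(hwc[0][0]))

print(shape)      # (2, 4, 3)
print(hwc[1][2])  # [12, 112, 212]
```

After the reordering, indexing a pixel by row and column returns its three channel values together, which is the layout a plotting call needs.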