|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""CIFAR-10 Data Set""" |
|
|
|
|
|
import pickle |
|
|
|
import numpy as np |
|
|
|
import datasets |
|
from datasets.tasks import ImageClassification |
|
|
|
|
|
# BibTeX entry for the technical report that introduced the data set; passed
# to ``DatasetInfo(citation=...)``.
_CITATION = """\
@TECHREPORT{Krizhevsky09learningmultiple,
author = {Alex Krizhevsky},
title = {Learning multiple layers of features from tiny images},
institution = {},
year = {2009}
}
"""

# Human-readable summary; passed to ``DatasetInfo(description=...)``.
_DESCRIPTION = """\
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images
per class. There are 50000 training images and 10000 test images.
"""

# Archive holding the python-pickled batches of the data set.
_DATA_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

# Class names for the ``label`` feature; a name's position in this list is the
# integer id ``ClassLabel`` assigns to it, so the order must not change.
_NAMES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
|
|
|
|
|
class Cifar10(datasets.GeneratorBasedBuilder):
    """Dataset builder for the CIFAR-10 image-classification data set.

    Downloads the python-pickled CIFAR-10 archive and yields one example per
    image, each with an ``img`` array and an integer ``label``.
    """

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="plain_text",
            version=datasets.Version("1.0.0", ""),
            description="Plain text import of CIFAR-10 Data Set",
        )
    ]

    def _info(self):
        """Return the dataset metadata.

        Declares the ``img`` (image) and ``label`` (class label over
        ``_NAMES``) features, the supervised key pair, the homepage, the
        citation, and the image-classification task template.
        """
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "img": datasets.Image(),
                    "label": datasets.features.ClassLabel(names=_NAMES),
                }
            ),
            supervised_keys=("img", "label"),
            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
            citation=_CITATION,
            task_templates=ImageClassification(image_column="img", label_column="label"),
        )

    def _split_generators(self, dl_manager):
        """Download the archive and declare the train and test splits.

        Args:
            dl_manager: Download manager used to fetch ``_DATA_URL`` and to
                iterate over the members of the downloaded archive.

        Returns:
            A list with two ``SplitGenerator`` objects (train, test). Each one
            gets its own ``iter_archive`` iterator over the same archive plus
            the split name that ``_generate_examples`` uses to pick batches.
        """
        archive = dl_manager.download(_DATA_URL)

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "test"},
            ),
        ]

    def _generate_examples(self, files, split):
        """Yield ``(key, example)`` pairs for the requested split.

        Args:
            files: Iterable of ``(path, file_object)`` pairs, one per archive
                member.
            split: Either ``"train"`` or ``"test"``; selects which batch files
                inside the archive are read.

        Yields:
            Tuples of ``("<archive path>_<index in batch>", example)`` where
            ``example`` is a dict with an ``"img"`` array of shape
            ``(32, 32, 3)`` and the matching ``"label"`` entry.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
        """
        if split == "train":
            batch_names = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]
        elif split == "test":
            batch_names = ["test_batch"]
        else:
            # Fail with a clear message instead of an unbound-variable error
            # further down.
            raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'.")

        wanted_paths = {f"cifar-10-batches-py/{filename}" for filename in batch_names}

        for path, fo in files:
            # The archive also holds members that are not part of this split.
            if path not in wanted_paths:
                continue

            # NOTE(security): unpickling executes arbitrary code embedded in
            # the stream; this is only acceptable because the archive comes
            # from the fixed upstream URL in _DATA_URL.
            batch = pickle.load(fo, encoding="bytes")

            labels = batch[b"labels"]
            images = batch[b"data"]

            for idx, flat_pixels in enumerate(images):
                # Each row is reshaped to (3, 32, 32); the transpose moves the
                # leading size-3 axis to the end, giving a (32, 32, 3) array.
                img_reshaped = np.transpose(np.reshape(flat_pixels, (3, 32, 32)), (1, 2, 0))

                yield f"{path}_{idx}", {
                    "img": img_reshaped,
                    "label": labels[idx],
                }
|
|