| import tensorflow as tf | |
| import numpy as np | |
def load_data():
    """Fetch the CIFAR-10 dataset and rescale its images into the range [0, 1].

    The images of both splits are cast to float32 and divided by 255.
    The label arrays are returned exactly as the Keras loader provides them.

    Returns:
        A pair of ``(images, labels)`` tuples laid out as
        ``(x_train, y_train), (x_test, y_test)``.
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()

    def _rescale(images):
        # Cast to float32 first so the division yields float32 values,
        # then divide by 255 so pixel values end up between 0 and 1.
        return images.astype('float32') / 255.0

    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return (
        (_rescale(train_images), train_labels),
        (_rescale(test_images), test_labels),
    )
def get_class_names():
    """Return the ten CIFAR-10 class names as a list of strings.

    The list order is intended to match the dataset's integer labels, so
    position ``i`` names label ``i``. NOTE(review): confirm this against the
    label encoding used by the loader.

    A new list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    names = [
        'airplane',
        'automobile',
        'bird',
        'cat',
        'deer',
        'dog',
        'frog',
        'horse',
        'ship',
        'truck',
    ]
    return names