Spaces:
Runtime error
Runtime error
from typing import Optional | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
def visualize_data(
    dataset: tf.data.Dataset,
    data_augmentation: Optional[tf.keras.Sequential] = None,
) -> None:
    """Show a 3x3 grid of images taken from the first batch of ``dataset``.

    Without ``data_augmentation`` the grid shows the first images of the
    batch (at most nine), each titled with its integer label.

    With ``data_augmentation`` every panel shows a separate call of the
    augmentation model on the first image of the batch, so the grid
    illustrates what the augmentation pipeline does to a single sample.
    Each panel is titled with that first image's label.

    Args:
        dataset: Dataset yielding ``(images, labels)`` batches. Images are
            indexed along their first axis and cast to ``uint8`` for
            display; labels must be convertible with ``int()``.
            NOTE(review): presumably pixel values are already in the
            0-255 range -- confirm against the data pipeline.
        data_augmentation: Optional model called on the image batch
            before display. ``None`` shows the raw images.
    """
    grid_side = 3
    panel_count = grid_side * grid_side

    plt.figure(figsize=(10, 10))
    for images, labels in dataset.take(1):
        batch_size = int(images.shape[0])
        if data_augmentation is None:
            # A short first batch must not raise an indexing error.
            shown = min(panel_count, batch_size)
        else:
            # Every panel reuses image 0, so one image is enough.
            shown = panel_count if batch_size > 0 else 0

        for i in range(shown):
            plt.subplot(grid_side, grid_side, i + 1)
            if data_augmentation is not None:
                # Augment only the displayed image (batch axis kept by the
                # slice) instead of the whole batch on every iteration.
                augmented_image = data_augmentation(images[:1])
                plt.imshow(augmented_image[0].numpy().astype('uint8'))
                # The picture is always image 0, so use its label.
                plt.title(int(labels[0]))
            else:
                plt.imshow(images[i].numpy().astype('uint8'))
                plt.title(int(labels[i]))
            plt.axis('off')
    plt.show()