Spaces:
Runtime error
Runtime error
File size: 699 Bytes
e4497d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from typing import Optional
import tensorflow as tf
import matplotlib.pyplot as plt
def visualize_data(
    dataset: tf.data.Dataset,
    data_augmentation: Optional[tf.keras.Sequential] = None,
) -> None:
    """Plot images from the first batch of ``dataset`` in a 3x3 grid.

    Without ``data_augmentation``, up to the first nine images of the batch are
    shown, each titled with its own label. With ``data_augmentation``, the
    pipeline is applied nine times to the *first* image of the batch, so every
    panel is a separate call on the same sample. Each panel is titled with that
    sample's label.

    Args:
        dataset: Dataset whose first element is unpacked as ``(images, labels)``
            and indexed per image, so it is presumably batched (TODO confirm).
            Pixels are cast to ``uint8`` for display; this assumes values in the
            0-255 range (NOTE(review): float images in [0, 1] would render black).
        data_augmentation: Optional model applied to the image before display.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(10, 10))
    for images, labels in dataset.take(1):
        if data_augmentation is not None:
            # Only the first image is needed; slicing keeps a batch dimension of
            # 1 and avoids augmenting the whole batch on every iteration.
            first_image = images[:1]
            first_label = int(labels[0])
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                augmented = data_augmentation(first_image)
                plt.imshow(augmented[0].numpy().astype('uint8'))
                # Title with the label of the image actually shown
                # (previously labels[i], which mislabeled panels 2-9).
                plt.title(first_label)
                plt.axis('off')
        else:
            # Guard against small/final batches holding fewer than nine images.
            for i in range(min(9, len(images))):
                plt.subplot(3, 3, i + 1)
                plt.imshow(images[i].numpy().astype('uint8'))
                plt.title(int(labels[i]))
                plt.axis('off')
    plt.show()