cat-vs-dog / src /draw.py
eddydecena's picture
Add: Source
e4497d1
from typing import Optional
import tensorflow as tf
import matplotlib.pyplot as plt
def visualize_data(dataset: tf.data.Dataset, data_augmentation: Optional[tf.keras.Sequential] = None) -> None:
    """Plot images from the first batch of ``dataset`` on a 3x3 grid.

    Without ``data_augmentation``, the first (up to) nine images of the batch
    are shown, each titled with its own label. With ``data_augmentation``, the
    pipeline is applied nine times to the first image of the batch, so the grid
    shows nine augmented variants of that one sample, all titled with its label.

    Args:
        dataset: Dataset yielding ``(images, labels)`` pairs. Assumes the
            dataset is batched (samples indexed along axis 0), that labels are
            scalar integers, and that pixel values lie in [0, 255] so the
            ``uint8`` cast is lossless enough for display — TODO confirm
            against the caller.
        data_augmentation: Optional model applied to the image before
            plotting; ``None`` disables augmentation.
    """
    plt.figure(figsize=(10, 10))
    for images, labels in dataset.take(1):
        if data_augmentation is not None:
            # Only sample 0 is ever displayed, so augment just that sample
            # (slice keeps the batch dimension the model expects) instead of
            # re-augmenting the whole batch on every panel.
            first_image = images[:1]
            first_label = labels[0]
            panels = [(data_augmentation(first_image)[0], first_label) for _ in range(9)]
        else:
            # Clamp to the batch size so batches smaller than the 3x3 grid
            # do not raise an index error.
            count = min(9, int(images.shape[0]))
            panels = [(images[i], labels[i]) for i in range(count)]

        for i, (image, label) in enumerate(panels):
            plt.subplot(3, 3, i + 1)
            plt.imshow(image.numpy().astype('uint8'))
            plt.title(int(label))
            plt.axis('off')
    plt.show()