File size: 699 Bytes
e4497d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import Optional

import tensorflow as tf
import matplotlib.pyplot as plt


def visualize_data(dataset: tf.data.Dataset, data_augmentation: Optional[tf.keras.Sequential] = None) -> None:
    """Plot a 3x3 grid of images from the first batch of ``dataset``.

    Without ``data_augmentation``, the first (up to) nine images of the batch
    are shown, each titled with its own label. With ``data_augmentation``, the
    model is applied nine times to the first image of the batch, so the grid
    shows nine augmented versions of that single image, all titled with its
    label.

    Args:
        dataset: A dataset yielding ``(images, labels)`` batches; only the
            first batch is used. Images are cast to ``uint8`` for display, so
            pixel values presumably lie in [0, 255] — TODO confirm with caller.
            Each label is passed through ``int()``, so labels must be scalars
            (e.g. sparse integer class ids, not one-hot vectors).
        data_augmentation: Optional model applied to the image before plotting.
    """
    plt.figure(figsize=(10, 10))
    for images, labels in dataset.take(1):
        if data_augmentation is not None:
            # Keep the batch dimension but feed only the image we display,
            # instead of re-augmenting the whole batch on every iteration.
            first_image = images[:1]
            # Every panel shows image 0, so every panel gets image 0's label.
            first_label = int(labels[0])
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                augmented_image = data_augmentation(first_image)
                plt.imshow(augmented_image[0].numpy().astype('uint8'))
                plt.title(first_label)
                plt.axis('off')
        else:
            # A (final) batch may contain fewer than nine images.
            count = min(9, int(images.shape[0]))
            for i in range(count):
                plt.subplot(3, 3, i + 1)
                plt.imshow(images[i].numpy().astype('uint8'))
                plt.title(int(labels[i]))
                plt.axis('off')
    plt.show()