# GANime — ganime/utils/statistics.py
# Uploaded by Kurokabe ("Upload 84 files", commit 3be620b).
import numpy as np
from tqdm.auto import tqdm
import tensorflow as tf
import tensorflow_datasets as tfds
def dataset_statistics(ds):
    """Compute the global mean, variance and standard deviation of a dataset's features.

    Each batch yielded by ``ds`` may be a ``(features, ...)`` tuple/list --
    e.g. ``(X, y)`` or Keras' ``(X, y, sample_weight)`` -- in which case only
    the first element is used, or a bare features array, which is used as is.
    All feature batches are concatenated along axis 0 and reduced over every
    element, so the whole dataset is materialised in memory at once.

    Args:
        ds: A ``tf.data.Dataset`` (converted to NumPy with ``tfds.as_numpy``),
            a ``tf.keras.utils.Sequence``, or any other iterable of batches
            (e.g. a list of ``(X, y)`` tuples or a generator).

    Returns:
        A ``(mean, variance, std)`` tuple of scalars computed over all
        feature values.

    Raises:
        ValueError: If ``ds`` yields no batches.
    """
    if isinstance(ds, tf.data.Dataset):
        # Iterate NumPy arrays rather than eager tensors.
        batches = tfds.as_numpy(ds)
    else:
        # Keras Sequences and plain iterables are iterated directly. Before,
        # any type other than Dataset/Sequence left the source unbound and
        # crashed with UnboundLocalError.
        batches = ds

    features = []
    for batch in tqdm(batches):
        # Labels (and optional sample weights) are not needed for the stats.
        features.append(batch[0] if isinstance(batch, (tuple, list)) else batch)

    if not features:
        raise ValueError("dataset_statistics: dataset yielded no batches")

    all_features = np.concatenate(features)
    variance = np.var(all_features)
    # np.std is sqrt(np.var); reuse the variance instead of a second full pass.
    return np.mean(all_features), variance, np.sqrt(variance)