import os

from matplotlib import gridspec
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf


def crop_center(image):
    """Return the largest centered square crop of a batched image.

    Args:
      image: Image indexed as (batch, height, width, channels), as produced
        by `load_image`. Height and width must be statically known, since
        they are read from `image.shape`.

    Returns:
      The image cropped to a square side of min(height, width), taken from
      the middle of the longer dimension.
    """
    shape = image.shape
    height, width = shape[1], shape[2]
    side = min(height, width)
    # Only the longer dimension gets a non-zero offset; the shorter is kept whole.
    offset_y = max(height - width, 0) // 2
    offset_x = max(width - height, 0) // 2
    return tf.image.crop_to_bounding_box(image, offset_y, offset_x, side, side)


def load_image(image_url, image_size=(256, 256), preserve_aspect_ratio=True):
    """Download (with local caching), normalize, center-crop and resize an image.

    Args:
      image_url: URL of the image. The file is fetched once and cached by
        `tf.keras.utils.get_file`.
      image_size: Target (height, width) passed to `tf.image.resize`.
      preserve_aspect_ratio: Forwarded to `tf.image.resize`. When True, the
        result fits inside `image_size` without distortion; when False, it
        is stretched to exactly `image_size`.

    Returns:
      A float32 tensor of shape (1, H, W, C) with values in [0, 1].
      Grayscale inputs are expanded to 3 channels, and RGBA inputs have
      their alpha channel dropped.
    """
    # Cache image file locally; keep only the last 128 chars of the name to
    # stay within filesystem filename limits.
    image_path = tf.keras.utils.get_file(os.path.basename(image_url)[-128:], image_url)
    raw = plt.imread(image_path)
    img = raw.astype(np.float32)

    # Normalize by the source dtype instead of the observed max pixel value:
    # an integer image that happens to be very dark (max <= 1) must still be
    # scaled, and non-8-bit integer data must not be divided by 255.
    if np.issubdtype(raw.dtype, np.integer):
        img /= np.iinfo(raw.dtype).max
    elif img.max() > 1.0:
        # Float data outside [0, 1]: keep the original 8-bit assumption.
        img /= 255.

    if img.ndim == 2:
        # Grayscale (H, W) -> replicate into 3 identical channels.
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[-1] == 4:
        # RGBA -> RGB; downstream code expects colour channels only.
        img = img[..., :3]

    # Add the batch dimension.
    img = img[np.newaxis, ...]
    img = crop_center(img)
    img = tf.image.resize(img, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
    return img