# StarryNight_StyleTransfer / helper_functions.py
# Author: breynolds1247
# Last commit: "Update helper_functions.py" (4f449fa)
# raw | history | blame -- 1.53 kB
import tensorflow as tf
from tensorflow import keras
def img_scaler(image, max_dim=512):
    """Resize ``image`` so that its longer spatial side becomes ``max_dim``.

    The aspect ratio is preserved. The scale factor is always
    ``max_dim / longest_side``, so small images are scaled up and large
    images are scaled down.

    Args:
        image: A 3-D ``[height, width, channels]`` or 4-D
            ``[batch, height, width, channels]`` image tensor.
        max_dim: Target length, in pixels, of the longer of height/width.

    Returns:
        The resized image, with the same rank as ``image``. With
        ``tf.image.resize``'s default (bilinear) method the result is float32.
    """
    # Take the two spatial axes (height, width) counted from the end. This
    # works for both [H, W, C] and [N, H, W, C]; a plain [:-1] slice would
    # also pick up the batch axis of a 4-D input.
    original_shape = tf.cast(tf.shape(image)[-3:-1], tf.float32)
    # tf.reduce_max keeps the reduction a single TF op. The builtin max()
    # iterates the tensor in Python, which is slower and is not reliably
    # supported when the function is traced into a graph.
    scale_ratio = max_dim / tf.reduce_max(original_shape)
    # Round rather than truncate, so float32 error such as 511.99997 still
    # gives exactly max_dim on the long side.
    new_shape = tf.cast(tf.round(original_shape * scale_ratio), tf.int32)
    # Never request a zero-sized side (possible for extremely thin images).
    new_shape = tf.maximum(new_shape, 1)
    return tf.image.resize(image, new_shape)
def load_img(image_path, content=True, max_dim=512):
    """Prepare an image as a batched float32 tensor for the style-transfer model.

    Args:
        image_path: When ``content`` is True, the image itself as handed over
            by the web app. Despite the name it is not a path, and it is used
            without opening or decoding. When ``content`` is False, a path to
            an image file (BMP, GIF, JPEG or PNG) that is read and decoded
            here.
        content: Selects how ``image_path`` is interpreted (see above).
        max_dim: Length, in pixels, that the longer image side is scaled to.

    Returns:
        The scaled float32 image with a leading batch axis of size 1. For a
        3-D input this is the 4-D tensor the model requires.
    """
    if content:
        # Content images come straight from the web app, so there is nothing
        # to open or decode.
        # NOTE(review): assumes the web app supplies a 3-D [H, W, 3] image
        # (uint8 or float in [0, 1]) -- confirm against the caller.
        img = image_path
    else:
        raw_bytes = tf.io.read_file(image_path)
        # decode_image detects BMP/GIF/JPEG/PNG and forces 3 channels.
        # expand_animations=False makes a GIF decode to a single 3-D frame
        # instead of a 4-D [frames, H, W, 3] stack. Such a stack would break
        # the resize below and end up 5-D after the batch axis is added.
        img = tf.image.decode_image(
            raw_bytes, channels=3, expand_animations=False
        )

    # Post-processing shared by both image sources.
    # Convert to float32. Integer images (e.g. uint8) are rescaled into [0, 1].
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = img_scaler(img, max_dim)
    # The model requires a 4-D tensor, so prepend a batch axis of size 1.
    return img[tf.newaxis, :]