File size: 2,381 Bytes
09d339d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import tensorflow as tf
import matplotlib.pyplot as plt
class Augment(tf.keras.layers.Layer):
    """Layer that applies a random horizontal flip to inputs and labels.

    Two separate ``RandomFlip`` layers are built from the same ``seed``, with
    the intent that an input and its label receive the same flip decision.
    """

    def __init__(self, seed=42):
        """Create the paired flip layers.

        Args:
            seed: Seed handed to both ``RandomFlip`` layers.
        """
        super().__init__()
        # One shared configuration keeps the two layers' settings in lockstep.
        flip_config = {"mode": "horizontal", "seed": seed}
        self.augment_inputs = tf.keras.layers.RandomFlip(**flip_config)
        self.augment_labels = tf.keras.layers.RandomFlip(**flip_config)

    def call(self, inputs, labels):
        """Flip ``inputs`` and ``labels`` and return them as a pair.

        Args:
            inputs: Value passed through the input flip layer.
            labels: Value passed through the label flip layer.

        Returns:
            Tuple ``(inputs, labels)`` after their respective flip layers.
        """
        flipped_inputs = self.augment_inputs(inputs)
        flipped_labels = self.augment_labels(labels)
        return flipped_inputs, flipped_labels
def load_image(datapoint):
    """Resize one datapoint to 128x128 and normalize it.

    Args:
        datapoint: Mapping with an ``'image'`` entry and a
            ``'segmentation_mask'`` entry.

    Returns:
        The ``(image, mask)`` pair produced by ``normalize_image``.
    """
    target_size = (128, 128)
    resized_image = tf.image.resize(datapoint['image'], target_size)
    # The mask is resized with nearest-neighbor sampling rather than the
    # default method — presumably so label values are not blended.
    resized_mask = tf.image.resize(
        datapoint['segmentation_mask'],
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    return normalize_image(resized_image, resized_mask)
def normalize_image(input_image, input_mask):
    """Scale the image by 1/255 as float32 and shift the mask down by one.

    Args:
        input_image: Image values to cast to ``tf.float32`` and divide by 255.
        input_mask: Mask values from which 1 is subtracted (presumably to
            make the labels zero-based — confirm against the dataset).

    Returns:
        Tuple ``(scaled_image, shifted_mask)``.
    """
    scaled_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return scaled_image, input_mask
def create_mask(pred_mask):
    """Turn per-class scores into a single-channel class-index mask.

    Takes the argmax over the last axis, appends a trailing axis of size one,
    and returns only the first element along the leading axis.

    Args:
        pred_mask: Scores whose last axis is reduced by argmax.

    Returns:
        The class-index mask for the first leading-axis entry.
    """
    class_indices = tf.math.argmax(pred_mask, axis=-1)
    with_channel_axis = tf.expand_dims(class_indices, axis=-1)
    return with_channel_axis[0]
def U_net_model(output_channels:int, down_stack, up_stack):
    """Assemble a Keras model from a down-sampling and an up-sampling stack.

    Args:
        output_channels: Number of filters of the final transposed
            convolution.
        down_stack: Callable applied to the 128x128x3 input; its result is
            indexed as a sequence whose last element is the deepest feature
            and whose earlier elements serve as skip connections.
        up_stack: Iterable of callables, paired in order with the skip
            connections taken deepest-first.

    Returns:
        A ``tf.keras.Model`` mapping the input to the final layer's output.
    """
    model_input = tf.keras.layers.Input(shape=[128, 128, 3])

    encoder_features = down_stack(model_input)
    x = encoder_features[-1]
    # Skip connections are consumed from the deepest to the shallowest.
    skip_features = reversed(encoder_features[:-1])

    for upsample, skip in zip(up_stack, skip_features):
        x = upsample(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final stride-2 transposed convolution producing `output_channels` maps.
    x = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2, padding='same'
    )(x)
    return tf.keras.Model(inputs=model_input, outputs=x)
def display(display_list):
    """Show the given images side by side in one row of a 15x15 figure.

    Args:
        display_list: Sized sequence of images accepted by
            ``tf.keras.utils.array_to_img``. The first two panels are titled
            'Input Image' and 'Predicted Mask'; any further panels are drawn
            without a title instead of raising ``IndexError``.
    """
    titles = ['Input Image', 'Predicted Mask']
    panel_count = len(display_list)

    plt.figure(figsize=(15, 15))
    for index, image in enumerate(display_list):
        plt.subplot(1, panel_count, index + 1)
        # Only the first len(titles) panels have a predefined title.
        if index < len(titles):
            plt.title(titles[index])
        plt.imshow(tf.keras.utils.array_to_img(image))
        plt.axis('off')
    plt.show()
def show_predictions(image_url, model):
    """Download an image, run ``model`` on it and display the predicted mask.

    The image is resized to 128x128, scaled by 1/255 as float32 and given a
    leading batch axis before being passed to ``model.predict``. The input
    image and the mask from ``create_mask`` are then shown with ``display``.

    Args:
        image_url: URL of the image to fetch.
        model: Object exposing ``predict``; called with the single-image
            batch described above.
    """
    import hashlib

    # get_file reuses an existing cached file of the same name and, with no
    # explicit name, derives it from the URL's basename. Two different URLs
    # ending in the same file name would then show the first image twice, so
    # the cache entry is keyed on the full URL instead.
    cache_name = hashlib.sha256(str(image_url).encode("utf-8")).hexdigest()
    image_path = tf.keras.utils.get_file(fname=cache_name, origin=image_url)

    pixels = tf.keras.utils.img_to_array(tf.keras.utils.load_img(image_path))
    pixels = tf.image.resize(pixels, (128, 128))
    pixels = tf.cast(pixels, tf.float32) / 255.0
    batch = tf.expand_dims(pixels, axis=0)

    pred_mask = model.predict(batch)
    display([batch[0], create_mask(pred_mask)])