File size: 2,381 Bytes
09d339d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import tensorflow as tf
import matplotlib.pyplot as plt


class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels
  
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128), method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  input_image, input_mask = normalize_image(input_image, input_mask)
  return input_image, input_mask

def normalize_image(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

def create_mask(pred_mask):
  pred_mask = tf.math.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]

def U_net_model(output_channels:int, down_stack, up_stack):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])
  skips = down_stack(inputs)
  outputs = skips[-1]
  skips = reversed(skips[:-1])
  for up, skip in zip(up_stack, skips):
    outputs = up(outputs)
    concatenate = tf.keras.layers.Concatenate()
    outputs = concatenate([outputs, skip])
  last = tf.keras.layers.Conv2DTranspose(filters=output_channels, kernel_size=3, strides=2, padding='same')
  outputs = last(outputs)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

def display(display_list):
  plt.figure(figsize=(15, 15))
  titles = ['Input Image', 'Predicted Mask']
  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(titles[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

def show_predictions(image_url, model):
  image = tf.keras.utils.get_file(origin=image_url)
  image = tf.keras.utils.load_img(image)
  image = tf.keras.utils.img_to_array(image)
  image = tf.image.resize(image, (128,128))
  image = tf.cast(image, tf.float32) / 255.0
  image = tf.expand_dims(image, axis=0)
  pred_mask = model.predict(image)
  display([image[0], create_mask(pred_mask)])