hbpkillerX committed
Commit 23811f4
1 parent: e4ea306

Delete train.py

Files changed (1)
  1. train.py +0 -114
train.py DELETED
@@ -1,114 +0,0 @@
-import os
-import random
-import numpy as np
-from glob import glob
-from PIL import Image, ImageOps
-import matplotlib.pyplot as plt
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-from model import get_model
-
-# functions to create the dataset
-random.seed(1)
-IMAGE_SIZE = 128
-BATCH_SIZE = 4
-MAX_TRAIN_IMAGES = 300
-
-def autocontrast(tensor, cutoff=0):
-    tensor = tf.cast(tensor, dtype=tf.float32)
-    min_val = tf.reduce_min(tensor)
-    max_val = tf.reduce_max(tensor)
-    range_val = max_val - min_val
-    adjusted_tensor = tf.clip_by_value(tf.cast(tf.round((tensor - min_val - cutoff) * (255 / (range_val - 2 * cutoff))), tf.uint8), 0, 255)
-    return adjusted_tensor
-
-def read_image(image_path):
-    image = tf.io.read_file(image_path)
-    image = tf.image.decode_png(image, channels=3)
-    image = autocontrast(image)
-    image.set_shape([None, None, 3])
-    image = tf.cast(image, dtype=tf.float32) / 255
-    return image
-
-
-def random_crop(low_image, enhanced_image):
-    low_image_shape = tf.shape(low_image)[:2]
-    low_w = tf.random.uniform(
-        shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
-    )
-    low_h = tf.random.uniform(
-        shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
-    )
-    enhanced_w = low_w
-    enhanced_h = low_h
-    low_image_cropped = low_image[
-        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
-    ]
-    enhanced_image_cropped = enhanced_image[
-        enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
-    ]
-    return low_image_cropped, enhanced_image_cropped
-
-
-def load_data(low_light_image_path, enhanced_image_path):
-    low_light_image = read_image(low_light_image_path)
-    enhanced_image = read_image(enhanced_image_path)
-    low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
-    return low_light_image, enhanced_image
-
-
-def get_dataset(low_light_images, enhanced_images):
-    dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
-    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
-    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
-    return dataset
-
-# Loss functions
-
-class CustomLoss:
-    def __init__(self, perceptual_loss_model):
-        self.perceptual_loss_model = perceptual_loss_model
-    def perceptual_loss(self, y_true, y_pred):
-        y_true_features = self.perceptual_loss_model(y_true)
-        y_pred_features = self.perceptual_loss_model(y_pred)
-        loss = tf.reduce_mean(tf.square(y_true_features[0] - y_pred_features[0])) + tf.reduce_mean(tf.square(y_true_features[1] - y_pred_features[1]))
-        return loss
-    def charbonnier_loss(self, y_true, y_pred):
-        return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
-    def __call__(self, y_true, y_pred):
-        return 0.5*self.perceptual_loss(y_true, y_pred) + 0.4*self.charbonnier_loss(y_true, y_pred)
-
-def peak_signal_noise_ratio(y_true, y_pred):
-    return tf.image.psnr(y_pred, y_true, max_val=255.0)
-
-def main():
-    train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
-    train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
-
-    val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
-    val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
-
-    train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
-    val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
-
-    # Model for calculating perceptual loss
-    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
-    for layer in vgg.layers:
-        layer.trainable = False  # Freeze all the layers, since this model is for evaluation only
-    outputs = [vgg.get_layer('block3_conv3').output, vgg.get_layer('block4_conv3').output]
-    perceptual_loss_model = tf.keras.models.Model(inputs=vgg.input, outputs=outputs)
-
-    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
-    loss = CustomLoss(perceptual_loss_model)
-    model = get_model()
-
-    model.compile(
-        optimizer=optimizer, loss=loss, metrics=[peak_signal_noise_ratio]
-    )
-
-    history = model.fit(train_dataset, validation_data=val_dataset, epochs=50)
-    model.save_weights("model.h5")
-
-if __name__ == "__main__":
-    main()