MeetMeAt92 committed
Commit
4cf621a
1 parent: 59d446c

Delete model.h5

Files changed (1)
  1. model.h5 +0 -331
model.h5 DELETED
@@ -1,331 +0,0 @@
- import random
- import numpy as np
- from glob import glob
- from PIL import Image, ImageOps
- import matplotlib.pyplot as plt
-
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
-
- from google.colab import drive
- drive.mount('/content/gdrive')
-
- random.seed(10)
-
- IMAGE_SIZE = 128
- BATCH_SIZE = 4
- MAX_TRAIN_IMAGES = 300
-
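- # Read a PNG from disk, decode it to RGB, and scale pixel values to [0, 1].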
- def read_image(image_path):
-     image = tf.io.read_file(image_path)
-     image = tf.image.decode_png(image, channels=3)
-     image.set_shape([None, None, 3])
-     image = tf.cast(image, dtype=tf.float32) / 255.0
-     return image
-
-
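- # Crop the same random IMAGE_SIZE x IMAGE_SIZE patch from a low-light image and its enhanced ground truth.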
- def random_crop(low_image, enhanced_image):
-     low_image_shape = tf.shape(low_image)[:2]
-     low_w = tf.random.uniform(
-         shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
-     )
-     low_h = tf.random.uniform(
-         shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
-     )
-     enhanced_w = low_w
-     enhanced_h = low_h
-     low_image_cropped = low_image[
-         low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
-     ]
-     enhanced_image_cropped = enhanced_image[
-         enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
-     ]
-     return low_image_cropped, enhanced_image_cropped
-
-
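- # Load a (low-light, enhanced) image pair and take matching random crops.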
- def load_data(low_light_image_path, enhanced_image_path):
-     low_light_image = read_image(low_light_image_path)
-     enhanced_image = read_image(enhanced_image_path)
-     low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
-     return low_light_image, enhanced_image
-
-
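- # Build a batched tf.data pipeline of paired crops.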
- def get_dataset(low_light_images, enhanced_images):
-     dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
-     dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
-     dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
-     return dataset
-
-
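- # Split the LoL dataset: the first MAX_TRAIN_IMAGES pairs from our485 for training, the rest for validation; eval15 is held out for testing.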
- train_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
- train_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
-
- val_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
- val_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
-
- test_low_light_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/low/*"))
- test_enhanced_images = sorted(glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/high/*"))
-
- train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
- val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
-
- print("Train Dataset:", train_dataset)
- print("Val Dataset:", val_dataset)
-
-
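- # Selective Kernel Feature Fusion (SKFF): fuse three multi-scale feature maps using channel-attention weights derived from their sum.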
- def selective_kernel_feature_fusion(
-     multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
- ):
-     channels = list(multi_scale_feature_1.shape)[-1]
-     combined_feature = layers.Add()(
-         [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
-     )
-     gap = layers.GlobalAveragePooling2D()(combined_feature)
-     channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
-     compact_feature_representation = layers.Conv2D(
-         filters=channels // 8, kernel_size=(1, 1), activation="relu"
-     )(channel_wise_statistics)
-     feature_descriptor_1 = layers.Conv2D(
-         channels, kernel_size=(1, 1), activation="softmax"
-     )(compact_feature_representation)
-     feature_descriptor_2 = layers.Conv2D(
-         channels, kernel_size=(1, 1), activation="softmax"
-     )(compact_feature_representation)
-     feature_descriptor_3 = layers.Conv2D(
-         channels, kernel_size=(1, 1), activation="softmax"
-     )(compact_feature_representation)
-     feature_1 = multi_scale_feature_1 * feature_descriptor_1
-     feature_2 = multi_scale_feature_2 * feature_descriptor_2
-     feature_3 = multi_scale_feature_3 * feature_descriptor_3
-     aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
-     return aggregated_feature
-
-
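- # Spatial attention: pool across the channel axis, then gate each spatial location with a learned sigmoid mask.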
- def spatial_attention_block(input_tensor):
-     average_pooling = tf.reduce_mean(input_tensor, axis=-1)
-     average_pooling = tf.expand_dims(average_pooling, axis=-1)
-     max_pooling = tf.reduce_max(input_tensor, axis=-1)
-     max_pooling = tf.expand_dims(max_pooling, axis=-1)
-     concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
-     feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
-     feature_map = tf.nn.sigmoid(feature_map)
-     return input_tensor * feature_map
-
-
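- # Channel attention: squeeze with global average pooling, then gate each channel with a learned sigmoid weight.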
- def channel_attention_block(input_tensor):
-     channels = list(input_tensor.shape)[-1]
-     average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
-     feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
-     feature_activations = layers.Conv2D(
-         filters=channels // 8, kernel_size=(1, 1), activation="relu"
-     )(feature_descriptor)
-     feature_activations = layers.Conv2D(
-         filters=channels, kernel_size=(1, 1), activation="sigmoid"
-     )(feature_activations)
-     return input_tensor * feature_activations
-
-
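- # Dual Attention Unit (DAU): apply channel and spatial attention in parallel, fuse them, and add a residual connection.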
- def dual_attention_unit_block(input_tensor):
-     channels = list(input_tensor.shape)[-1]
-     feature_map = layers.Conv2D(
-         channels, kernel_size=(3, 3), padding="same", activation="relu"
-     )(input_tensor)
-     feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
-         feature_map
-     )
-     channel_attention = channel_attention_block(feature_map)
-     spatial_attention = spatial_attention_block(feature_map)
-     concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
-     concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
-     return layers.Add()([input_tensor, concatenation])
-
-
- # Recursive Residual Modules
-
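- # Downsample: halve spatial resolution with max pooling and double the channel count; a pooled skip branch is added back in.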
- def down_sampling_module(input_tensor):
-     channels = list(input_tensor.shape)[-1]
-     main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
-         input_tensor
-     )
-     main_branch = layers.Conv2D(
-         channels, kernel_size=(3, 3), padding="same", activation="relu"
-     )(main_branch)
-     main_branch = layers.MaxPooling2D()(main_branch)
-     main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
-     skip_branch = layers.MaxPooling2D()(input_tensor)
-     skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
-     return layers.Add()([skip_branch, main_branch])
-
-
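- # Upsample: double spatial resolution and halve the channel count, mirroring the downsampling module.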
- def up_sampling_module(input_tensor):
-     channels = list(input_tensor.shape)[-1]
-     main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
-         input_tensor
-     )
-     main_branch = layers.Conv2D(
-         channels, kernel_size=(3, 3), padding="same", activation="relu"
-     )(main_branch)
-     main_branch = layers.UpSampling2D()(main_branch)
-     main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
-     skip_branch = layers.UpSampling2D()(input_tensor)
-     skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
-     return layers.Add()([skip_branch, main_branch])
-
-
- # MRB Block
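- # Processes features at three resolutions in parallel, exchanging information through DAU, SKFF, and up/down-sampling modules.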
- def multi_scale_residual_block(input_tensor, channels):
-     # features
-     level1 = input_tensor
-     level2 = down_sampling_module(input_tensor)
-     level3 = down_sampling_module(level2)
-     # DAU
-     level1_dau = dual_attention_unit_block(level1)
-     level2_dau = dual_attention_unit_block(level2)
-     level3_dau = dual_attention_unit_block(level3)
-     # SKFF
-     level1_skff = selective_kernel_feature_fusion(
-         level1_dau,
-         up_sampling_module(level2_dau),
-         up_sampling_module(up_sampling_module(level3_dau)),
-     )
-     level2_skff = selective_kernel_feature_fusion(
-         down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
-     )
-     level3_skff = selective_kernel_feature_fusion(
-         down_sampling_module(down_sampling_module(level1_dau)),
-         down_sampling_module(level2_dau),
-         level3_dau,
-     )
-     # DAU 2
-     level1_dau_2 = dual_attention_unit_block(level1_skff)
-     level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
-     level3_dau_2 = up_sampling_module(
-         up_sampling_module(dual_attention_unit_block(level3_skff))
-     )
-     # SKFF 2
-     skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
-     conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
-     return layers.Add()([input_tensor, conv])
-
-
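- # Recursive Residual Group (RRG): a stack of MRBs wrapped in convolutions, with a long residual connection.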
- def recursive_residual_group(input_tensor, num_mrb, channels):
-     conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
-     for _ in range(num_mrb):
-         conv1 = multi_scale_residual_block(conv1, channels)
-     conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
-     return layers.Add()([conv2, input_tensor])
-
-
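- # Assemble the full MIRNet: stacked RRGs plus a global residual from input to output.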
- def mirnet_model(num_rrg, num_mrb, channels):
-     input_tensor = keras.Input(shape=[None, None, 3])
-     x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
-     for _ in range(num_rrg):
-         x1 = recursive_residual_group(x1, num_mrb, channels)
-     conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
-     output_tensor = layers.Add()([input_tensor, conv])
-     return keras.Model(input_tensor, output_tensor)
-
-
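- # Build MIRNet with 3 recursive residual groups of 2 multi-scale residual blocks each, at 64 channels.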
- model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)
-
-
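- # Charbonnier loss: a smooth, differentiable approximation of L1 that is robust to outliers.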
- def charbonnier_loss(y_true, y_pred):
-     return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
-
-
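- # PSNR metric; inputs are scaled to [0, 1], so the peak signal value is 1.0.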
- def peak_signal_noise_ratio(y_true, y_pred):
-     return tf.image.psnr(y_pred, y_true, max_val=1.0)
-
-
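- # Compile and train, halving the learning rate whenever validation PSNR plateaus.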
- optimizer = keras.optimizers.Adam(learning_rate=1e-4)
- model.compile(
-     optimizer=optimizer, loss=charbonnier_loss, metrics=[peak_signal_noise_ratio]
- )
-
- history = model.fit(
-     train_dataset,
-     validation_data=val_dataset,
-     # number of training epochs
-     epochs=1,
-     callbacks=[
-         keras.callbacks.ReduceLROnPlateau(
-             monitor="val_peak_signal_noise_ratio",
-             factor=0.5,
-             patience=5,
-             verbose=1,
-             min_delta=1e-7,
-             mode="max",
-         )
-     ],
- )
-
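- # Plot the training curves for loss and PSNR.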
- plt.plot(history.history["loss"], label="train_loss")
278
- plt.plot(history.history["val_loss"], label="val_loss")
279
- plt.xlabel("Epochs")
280
- plt.ylabel("Loss")
281
- plt.title("Train and Validation Losses Over Epochs", fontsize=14)
282
- plt.legend()
283
- plt.grid()
284
- plt.show()
285
-
286
-
287
- plt.plot(history.history["peak_signal_noise_ratio"], label="train_psnr")
288
- plt.plot(history.history["val_peak_signal_noise_ratio"], label="val_psnr")
289
- plt.xlabel("Epochs")
290
- plt.ylabel("PSNR")
291
- plt.title("Train and Validation PSNR Over Epochs", fontsize=14)
292
- plt.legend()
293
- plt.grid()
294
- plt.show()
295
-
296
-
297
-
298
-
299
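- # Show a row of images side by side with titles.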
- def plot_results(images, titles, figure_size=(12, 12)):
-     fig = plt.figure(figsize=figure_size)
-     for i in range(len(images)):
-         fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
-         _ = plt.imshow(images[i])
-         plt.axis("off")
-     plt.show()
-
-
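- # Run a single image through the model and convert the prediction back to a PIL image.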
- def infer(original_image):
-     image = keras.preprocessing.image.img_to_array(original_image)
-     image = image.astype("float32") / 255.0
-     image = np.expand_dims(image, axis=0)
-     output = model.predict(image)
-     output_image = output[0] * 255.0
-     output_image = output_image.clip(0, 255)
-     output_image = output_image.reshape(
-         (np.shape(output_image)[0], np.shape(output_image)[1], 3)
-     )
-     output_image = Image.fromarray(np.uint8(output_image))
-     return output_image
-
-
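- # Compare two random test images against PIL autocontrast and the MIRNet output.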
- for low_light_image in random.sample(test_low_light_images, 2):
-     original_image = Image.open(low_light_image)
-     enhanced_image = infer(original_image)
-     plot_results(
-         [original_image, ImageOps.autocontrast(original_image), enhanced_image],
-         ["Original", "PIL Autocontrast", "MIRNet Enhanced"],
-         (20, 12),
-     )