geekyrakshit commited on
Commit
0668e89
1 Parent(s): 295bcab

updated zero-dce model

Browse files
enhance_me/zero_dce/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from .zero_dce import ZeroDCE
enhance_me/zero_dce/zero_dce.py CHANGED
@@ -24,7 +24,7 @@ class ZeroDCE(Model):
24
  super(ZeroDCE, self).__init__(**kwargs)
25
  self.experiment_name = experiment_name
26
  if wandb_api_key is not None:
27
- init_wandb("mirnet", experiment_name, wandb_api_key)
28
  self.using_wandb = True
29
  else:
30
  self.using_wandb = False
 
24
  super(ZeroDCE, self).__init__(**kwargs)
25
  self.experiment_name = experiment_name
26
  if wandb_api_key is not None:
27
+ init_wandb("zero-dce", experiment_name, wandb_api_key)
28
  self.using_wandb = True
29
  else:
30
  self.using_wandb = False
notebooks/enhance_me_train.ipynb CHANGED
@@ -41,7 +41,8 @@
41
  "\n",
42
  "from PIL import Image\n",
43
  "from enhance_me import commons\n",
44
- "from enhance_me.mirnet import MIRNet"
 
45
  ]
46
  },
47
  {
@@ -183,7 +184,62 @@
183
  "id": "dO-IbNQHkB3R"
184
  },
185
  "outputs": [],
186
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  }
188
  ],
189
  "metadata": {
 
41
  "\n",
42
  "from PIL import Image\n",
43
  "from enhance_me import commons\n",
44
+ "from enhance_me.mirnet import MIRNet\n",
45
+ "from enhance_me.zero_dce import ZeroDCE"
46
  ]
47
  },
48
  {
 
184
  "id": "dO-IbNQHkB3R"
185
  },
186
  "outputs": [],
187
+ "source": [
188
+ "# @title Zero-DCE Train Configs\n",
189
+ "\n",
190
+ "experiment_name = \"lol_dataset_128\" # @param {type:\"string\"}\n",
191
+ "image_size = 128 # @param {type:\"integer\"}\n",
192
+ "dataset_label = \"lol\" # @param [\"lol\"]\n",
193
+ "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
194
+ "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
195
+ "apply_random_rotation = True # @param {type:\"boolean\"}\n",
196
+ "use_mixed_precision = False # @param {type:\"boolean\"}\n",
197
+ "wandb_api_key = \"\" # @param {type:\"string\"}\n",
198
+ "val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
199
+ "batch_size = 32 # @param {type:\"integer\"}\n",
200
+ "learning_rate = 1e-4 # @param {type:\"number\"}\n",
201
+ "epsilon = 1e-3 # @param {type:\"number\"}\n",
202
+ "epochs = 100 # @param {type:\"slider\", min:10, max:100, step:5}"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "zero_dce = ZeroDCE(\n",
212
+ " experiment_name=experiment_name,\n",
213
+ " wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key\n",
214
+ ")"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "zero_dce.build_datasets(\n",
224
+ " image_size=image_size,\n",
225
+ " dataset_label=dataset_label,\n",
226
+ " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
227
+ " apply_random_vertical_flip=apply_random_vertical_flip,\n",
228
+ " apply_random_rotation=apply_random_rotation,\n",
229
+ " val_split=val_split,\n",
230
+ " batch_size=batch_size\n",
231
+ ")"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "zero_dce.compile(learning_rate=learning_rate)\n",
241
+ "zero_dce.train(epochs=epochs)"
242
+ ]
243
  }
244
  ],
245
  "metadata": {