geekyrakshit commited on
Commit
246dd98
·
1 Parent(s): 60407c7

added model loading function

Browse files
Files changed (1) hide show
  1. enhance_me/mirnet/mirnet.py +11 -1
enhance_me/mirnet/mirnet.py CHANGED
@@ -5,7 +5,7 @@ from typing import List
5
  from datetime import datetime
6
 
7
  from tensorflow import keras
8
- from tensorflow.keras import optimizers
9
 
10
  from wandb.keras import WandbCallback
11
 
@@ -76,6 +76,16 @@ class MIRNet:
76
  metrics=[peak_signal_noise_ratio],
77
  )
78
 
 
 
 
 
 
 
 
 
 
 
79
  def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
80
  self.model.save_weights(
81
  filepath, overwrite=overwrite, save_format=save_format, options=options
 
5
  from datetime import datetime
6
 
7
  from tensorflow import keras
8
+ from tensorflow.keras import optimizers, models
9
 
10
  from wandb.keras import WandbCallback
11
 
 
76
  metrics=[peak_signal_noise_ratio],
77
  )
78
 
79
+ def load_model(
80
+ self, filepath, custom_objects=None, compile=True, options=None
81
+ ) -> None:
82
+ self.model = models.load_model(
83
+ filepath=filepath,
84
+ custom_objects=custom_objects,
85
+ compile=compile,
86
+ options=options,
87
+ )
88
+
89
  def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
90
  self.model.save_weights(
91
  filepath, overwrite=overwrite, save_format=save_format, options=options