Spaces:
Runtime error
Runtime error
geekyrakshit
commited on
Commit
·
246dd98
1
Parent(s):
60407c7
added model loading function
Browse files- enhance_me/mirnet/mirnet.py +11 -1
enhance_me/mirnet/mirnet.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List
|
|
5 |
from datetime import datetime
|
6 |
|
7 |
from tensorflow import keras
|
8 |
-
from tensorflow.keras import optimizers
|
9 |
|
10 |
from wandb.keras import WandbCallback
|
11 |
|
@@ -76,6 +76,16 @@ class MIRNet:
|
|
76 |
metrics=[peak_signal_noise_ratio],
|
77 |
)
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
80 |
self.model.save_weights(
|
81 |
filepath, overwrite=overwrite, save_format=save_format, options=options
|
|
|
5 |
from datetime import datetime
|
6 |
|
7 |
from tensorflow import keras
|
8 |
+
from tensorflow.keras import optimizers, models
|
9 |
|
10 |
from wandb.keras import WandbCallback
|
11 |
|
|
|
76 |
metrics=[peak_signal_noise_ratio],
|
77 |
)
|
78 |
|
79 |
+
def load_model(
|
80 |
+
self, filepath, custom_objects=None, compile=True, options=None
|
81 |
+
) -> None:
|
82 |
+
self.model = models.load_model(
|
83 |
+
filepath=filepath,
|
84 |
+
custom_objects=custom_objects,
|
85 |
+
compile=compile,
|
86 |
+
options=options,
|
87 |
+
)
|
88 |
+
|
89 |
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
90 |
self.model.save_weights(
|
91 |
filepath, overwrite=overwrite, save_format=save_format, options=options
|