Spaces:
Build error
Build error
Sophie98
commited on
Commit
•
2aa2edf
1
Parent(s):
d1be09c
changed model loading
Browse files- segmentation.py +20 -19
segmentation.py
CHANGED
@@ -7,33 +7,34 @@ import matplotlib.pyplot as plt
|
|
7 |
from PIL import Image
|
8 |
import segmentation_models as sm
|
9 |
|
10 |
-
model_path = "model_checkpoint.h5"
|
11 |
-
CLASSES = ['sofa']
|
12 |
BACKBONE = 'resnet50'
|
13 |
|
14 |
# define network parameters
|
15 |
-
|
16 |
-
|
17 |
preprocess_input = sm.get_preprocessing(BACKBONE)
|
18 |
sm.set_framework('tf.keras')
|
19 |
-
LR=0.0001
|
20 |
|
21 |
#create model architecture
|
22 |
-
model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
|
23 |
-
# define optomizer
|
24 |
-
optim = keras.optimizers.Adam(LR)
|
25 |
-
# Segmentation models losses can be combined together by '+' and scaled by integer or float factor
|
26 |
-
dice_loss = sm.losses.DiceLoss()
|
27 |
-
focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
|
28 |
-
total_loss = dice_loss + (1 * focal_loss)
|
29 |
-
# actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
|
30 |
-
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
|
31 |
-
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
|
32 |
-
# compile keras model with defined optimozer, loss and metrics
|
33 |
-
model.compile(optim, total_loss, metrics)
|
34 |
|
35 |
-
#load model
|
36 |
-
model.load_weights(model_path)
|
|
|
37 |
|
38 |
def get_mask(image:Image) -> Image:
|
39 |
"""
|
|
|
7 |
from PIL import Image
|
8 |
import segmentation_models as sm
|
9 |
|
10 |
+
# model_path = "model_checkpoint.h5"
|
11 |
+
# CLASSES = ['sofa']
|
12 |
BACKBONE = 'resnet50'
|
13 |
|
14 |
# define network parameters
|
15 |
+
# in_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
|
16 |
+
# actvation = 'sigmoid' if n_classes == 1 else 'softmax'
|
17 |
preprocess_input = sm.get_preprocessing(BACKBONE)
|
18 |
sm.set_framework('tf.keras')
|
19 |
+
# LR=0.0001
|
20 |
|
21 |
#create model architecture
|
22 |
+
# model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
|
23 |
+
# # define optomizer
|
24 |
+
# optim = keras.optimizers.Adam(LR)
|
25 |
+
# # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
|
26 |
+
# dice_loss = sm.losses.DiceLoss()
|
27 |
+
# focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
|
28 |
+
# total_loss = dice_loss + (1 * focal_loss)
|
29 |
+
# # actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
|
30 |
+
# # total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
|
31 |
+
# metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
|
32 |
+
# # compile keras model with defined optimozer, loss and metrics
|
33 |
+
# model.compile(optim, total_loss, metrics)
|
34 |
|
35 |
+
# #load model
|
36 |
+
# model.load_weights(model_path)
|
37 |
+
model = keras.models.load_model('model_final.h5', compile=False)
|
38 |
|
39 |
def get_mask(image:Image) -> Image:
|
40 |
"""
|