Spaces:
Build error
Build error
Sophie98
commited on
Commit
•
5d930e1
1
Parent(s):
7ae512f
test test test
Browse files- segmentation.py +20 -21
segmentation.py
CHANGED
@@ -20,37 +20,36 @@ def get_mask(image:Image) -> Image:
|
|
20 |
mask = corresponding maks of the image
|
21 |
"""
|
22 |
|
23 |
-
|
24 |
-
|
25 |
BACKBONE = 'resnet50'
|
26 |
|
27 |
# define network parameters
|
28 |
-
|
29 |
-
|
30 |
preprocess_input = sm.get_preprocessing(BACKBONE)
|
31 |
-
|
32 |
|
33 |
#create model architecture
|
34 |
-
|
35 |
-
#
|
36 |
-
|
37 |
-
#
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
#
|
42 |
-
#
|
43 |
-
|
44 |
-
#
|
45 |
-
|
46 |
|
47 |
# #load model
|
48 |
-
|
49 |
-
model = keras.models.load_model('model_final.h5', compile=False)
|
50 |
print('loaded model')
|
51 |
-
test_img = np.array(image)
|
52 |
test_img = cv2.resize(test_img, (640, 640))
|
53 |
-
return test_img
|
54 |
test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
|
55 |
test_img = np.expand_dims(test_img, axis=0)
|
56 |
|
|
|
20 |
mask = corresponding maks of the image
|
21 |
"""
|
22 |
|
23 |
+
model_path = "model_checkpoint.h5"
|
24 |
+
CLASSES = ['sofa']
|
25 |
BACKBONE = 'resnet50'
|
26 |
|
27 |
# define network parameters
|
28 |
+
n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
|
29 |
+
actvation = 'sigmoid' if n_classes == 1 else 'softmax'
|
30 |
preprocess_input = sm.get_preprocessing(BACKBONE)
|
31 |
+
LR=0.0001
|
32 |
|
33 |
#create model architecture
|
34 |
+
model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
|
35 |
+
# define optomizer
|
36 |
+
optim = keras.optimizers.Adam(LR)
|
37 |
+
# Segmentation models losses can be combined together by '+' and scaled by integer or float factor
|
38 |
+
dice_loss = sm.losses.DiceLoss()
|
39 |
+
focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
|
40 |
+
total_loss = dice_loss + (1 * focal_loss)
|
41 |
+
# actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
|
42 |
+
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
|
43 |
+
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
|
44 |
+
# compile keras model with defined optimozer, loss and metrics
|
45 |
+
model.compile(optim, total_loss, metrics)
|
46 |
|
47 |
# #load model
|
48 |
+
model.load_weights(model_path)
|
49 |
+
#model = keras.models.load_model('model_final.h5', compile=False)
|
50 |
print('loaded model')
|
51 |
+
test_img = np.array(image)
|
52 |
test_img = cv2.resize(test_img, (640, 640))
|
|
|
53 |
test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
|
54 |
test_img = np.expand_dims(test_img, axis=0)
|
55 |
|