Sophie98 commited on
Commit
2aa2edf
1 Parent(s): d1be09c

changed model loading

Browse files
Files changed (1) hide show
  1. segmentation.py +20 -19
segmentation.py CHANGED
@@ -7,33 +7,34 @@ import matplotlib.pyplot as plt
7
  from PIL import Image
8
  import segmentation_models as sm
9
 
10
- model_path = "model_checkpoint.h5"
11
- CLASSES = ['sofa']
12
  BACKBONE = 'resnet50'
13
 
14
  # define network parameters
15
- n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
16
- activation = 'sigmoid' if n_classes == 1 else 'softmax'
17
  preprocess_input = sm.get_preprocessing(BACKBONE)
18
  sm.set_framework('tf.keras')
19
- LR=0.0001
20
 
21
  #create model architecture
22
- model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
23
- # define optomizer
24
- optim = keras.optimizers.Adam(LR)
25
- # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
26
- dice_loss = sm.losses.DiceLoss()
27
- focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
28
- total_loss = dice_loss + (1 * focal_loss)
29
- # actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
30
- # total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
31
- metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
32
- # compile keras model with defined optimozer, loss and metrics
33
- model.compile(optim, total_loss, metrics)
34
 
35
- #load model
36
- model.load_weights(model_path)
 
37
 
38
  def get_mask(image:Image) -> Image:
39
  """
 
7
  from PIL import Image
8
  import segmentation_models as sm
9
 
10
+ # model_path = "model_checkpoint.h5"
11
+ # CLASSES = ['sofa']
12
  BACKBONE = 'resnet50'
13
 
14
  # define network parameters
15
+ # in_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
16
+ # actvation = 'sigmoid' if n_classes == 1 else 'softmax'
17
  preprocess_input = sm.get_preprocessing(BACKBONE)
18
  sm.set_framework('tf.keras')
19
+ # LR=0.0001
20
 
21
  #create model architecture
22
+ # model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
23
+ # # define optomizer
24
+ # optim = keras.optimizers.Adam(LR)
25
+ # # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
26
+ # dice_loss = sm.losses.DiceLoss()
27
+ # focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
28
+ # total_loss = dice_loss + (1 * focal_loss)
29
+ # # actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
30
+ # # total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
31
+ # metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
32
+ # # compile keras model with defined optimozer, loss and metrics
33
+ # model.compile(optim, total_loss, metrics)
34
 
35
+ # #load model
36
+ # model.load_weights(model_path)
37
+ model = keras.models.load_model('model_final.h5', compile=False)
38
 
39
  def get_mask(image:Image) -> Image:
40
  """