Sophie98 committed on
Commit a37eb28
1 Parent(s): f628b78

changed model loading

Files changed (2):
  1. app.py +4 -1
  2. segmentation.py +28 -28
app.py CHANGED
@@ -5,6 +5,10 @@ from segmentation import get_mask,replace_sofa
 from styleTransfer import create_styledSofa
 from PIL import Image
 
+#https://colab.research.google.com/drive/11CtQpSeRBGAuw4TtE_rL470tRo-1X-p2#scrollTo=edGukUHXyymr
+#https://colab.research.google.com/drive/1xq33YKf0LVKCkbbUZIoNPzgpR_4Kd0qL#scrollTo=sPuM8Xypjs-c
+#https://github.com/dhawan98/Post-Processing-of-Image-Segmentation-using-CRF
+
 def resize_sofa(img):
     """
     This function adds padding to make the orignal image square and 640by640.
@@ -130,4 +134,3 @@ if __name__ == "__main__":
     demo.launch()
 
 
-#https://github.com/dhawan98/Post-Processing-of-Image-Segmentation-using-CRF
segmentation.py CHANGED
@@ -7,6 +7,34 @@ import matplotlib.pyplot as plt
 from PIL import Image
 import segmentation_models as sm
 
+model_path = "model_checkpoint.h5"
+CLASSES = ['sofa']
+BACKBONE = 'resnet50'
+
+# define network parameters
+n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
+activation = 'sigmoid' if n_classes == 1 else 'softmax'
+preprocess_input = sm.get_preprocessing(BACKBONE)
+sm.set_framework('tf.keras')
+LR=0.0001
+
+#create model architecture
+model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
+# define optomizer
+optim = keras.optimizers.Adam(LR)
+# Segmentation models losses can be combined together by '+' and scaled by integer or float factor
+dice_loss = sm.losses.DiceLoss()
+focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
+total_loss = dice_loss + (1 * focal_loss)
+# actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
+# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
+metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
+# compile keras model with defined optimozer, loss and metrics
+model.compile(optim, total_loss, metrics)
+
+#load model
+model.load_weights(model_path)
+
 def get_mask(image:Image) -> Image:
     """
     This function generates a mask of the image that highlights all the sofas in the image.
@@ -17,34 +45,6 @@ def get_mask(image:Image) -> Image:
     Return:
         mask = corresponding maks of the image
     """
-    model_path = "model_checkpoint.h5"
-    CLASSES = ['sofa']
-    BACKBONE = 'resnet50'
-
-    # define network parameters
-    n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
-    activation = 'sigmoid' if n_classes == 1 else 'softmax'
-    preprocess_input = sm.get_preprocessing(BACKBONE)
-    sm.set_framework('tf.keras')
-    LR=0.0001
-
-    #create model architecture
-    model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
-    # define optomizer
-    optim = keras.optimizers.Adam(LR)
-    # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
-    dice_loss = sm.losses.DiceLoss()
-    focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
-    total_loss = dice_loss + (1 * focal_loss)
-    # actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
-    # total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
-    metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
-    # compile keras model with defined optimozer, loss and metrics
-    model.compile(optim, total_loss, metrics)
-
-    #load model
-    model.load_weights(model_path)
-
     test_img = np.array(image)#cv2.imread(path, cv2.IMREAD_COLOR)
     test_img = cv2.resize(test_img, (640, 640))
     test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
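
Net effect of this commit: segmentation.py now builds the U-Net and loads model_checkpoint.h5 once, at import time, instead of on every get_mask call, so repeated segmentations reuse the same module-level model. A minimal usage sketch under that assumption (the image file names below are hypothetical and not part of the repo):

# Importing segmentation triggers the one-time model construction and weight load.
from PIL import Image
from segmentation import get_mask

img = Image.open("sofa_photo.jpg")   # hypothetical input photo
mask = get_mask(img)                 # reuses the module-level `model`
mask.save("sofa_mask.png")           # hypothetical output path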