Sophie98 commited on
Commit
3b83a8e
1 Parent(s): bf82406
Files changed (2) hide show
  1. segmentation.py +30 -28
  2. styleTransfer.py +16 -10
segmentation.py CHANGED
@@ -7,34 +7,6 @@ import matplotlib.pyplot as plt
7
  from PIL import Image
8
  import segmentation_models as sm
9
 
10
- # model_path = "model_checkpoint.h5"
11
- # CLASSES = ['sofa']
12
- BACKBONE = 'resnet50'
13
-
14
- # define network parameters
15
- # in_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
16
- # actvation = 'sigmoid' if n_classes == 1 else 'softmax'
17
- preprocess_input = sm.get_preprocessing(BACKBONE)
18
- sm.set_framework('tf.keras')
19
- # LR=0.0001
20
-
21
- #create model architecture
22
- # model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
23
- # # define optomizer
24
- # optim = keras.optimizers.Adam(LR)
25
- # # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
26
- # dice_loss = sm.losses.DiceLoss()
27
- # focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
28
- # total_loss = dice_loss + (1 * focal_loss)
29
- # # actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
30
- # # total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
31
- # metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
32
- # # compile keras model with defined optimozer, loss and metrics
33
- # model.compile(optim, total_loss, metrics)
34
-
35
- # #load model
36
- # model.load_weights(model_path)
37
- model = keras.models.load_model('model_final.h5', compile=False)
38
 
39
  def get_mask(image:Image) -> Image:
40
  """
@@ -46,6 +18,36 @@ def get_mask(image:Image) -> Image:
46
  Return:
47
  mask = corresponding maks of the image
48
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  test_img = np.array(image)#cv2.imread(path, cv2.IMREAD_COLOR)
50
  test_img = cv2.resize(test_img, (640, 640))
51
  test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
 
7
  from PIL import Image
8
  import segmentation_models as sm
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def get_mask(image:Image) -> Image:
12
  """
 
18
  Return:
19
  mask = corresponding maks of the image
20
  """
21
+
22
+ # model_path = "model_checkpoint.h5"
23
+ # CLASSES = ['sofa']
24
+ BACKBONE = 'resnet50'
25
+
26
+ # define network parameters
27
+ # in_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1) # case for binary and multiclass segmentation
28
+ # actvation = 'sigmoid' if n_classes == 1 else 'softmax'
29
+ preprocess_input = sm.get_preprocessing(BACKBONE)
30
+ sm.set_framework('tf.keras')
31
+ # LR=0.0001
32
+
33
+ #create model architecture
34
+ # model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
35
+ # # define optomizer
36
+ # optim = keras.optimizers.Adam(LR)
37
+ # # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
38
+ # dice_loss = sm.losses.DiceLoss()
39
+ # focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
40
+ # total_loss = dice_loss + (1 * focal_loss)
41
+ # # actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
42
+ # # total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
43
+ # metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
44
+ # # compile keras model with defined optimozer, loss and metrics
45
+ # model.compile(optim, total_loss, metrics)
46
+
47
+ # #load model
48
+ # model.load_weights(model_path)
49
+ model = keras.models.load_model('model_final.h5', compile=False)
50
+
51
  test_img = np.array(image)#cv2.imread(path, cv2.IMREAD_COLOR)
52
  test_img = cv2.resize(test_img, (640, 640))
53
  test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
styleTransfer.py CHANGED
@@ -34,19 +34,20 @@ def content_transform():
34
  transform = transforms.Compose(transform_list)
35
  return transform
36
 
37
- # Advanced options
38
- content_size=640
39
- style_size=640
40
 
41
- vgg = StyTR.vgg
42
- vgg.load_state_dict(torch.load(vgg_path))
43
- vgg = nn.Sequential(*list(vgg.children())[:44])
44
 
45
- decoder = StyTR.decoder
46
- Trans = transformer.Transformer()
47
- embedding = StyTR.PatchEmbed()
48
 
49
- def StyleTransformer(content_img: Image, style_img: Image):
 
 
 
 
 
 
50
 
51
  decoder.eval()
52
  Trans.eval()
@@ -105,6 +106,11 @@ def styleProjection(content_image,style_image):
105
  }])
106
 
107
  return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
 
 
 
 
 
108
 
109
 
110
 
 
34
  transform = transforms.Compose(transform_list)
35
  return transform
36
 
 
 
 
37
 
38
+ def StyleTransformer(content_img: Image, style_img: Image):
 
 
39
 
40
+ # Advanced options
41
+ content_size=640
42
+ style_size=640
43
 
44
+ vgg = StyTR.vgg
45
+ vgg.load_state_dict(torch.load(vgg_path))
46
+ vgg = nn.Sequential(*list(vgg.children())[:44])
47
+
48
+ decoder = StyTR.decoder
49
+ Trans = transformer.Transformer()
50
+ embedding = StyTR.PatchEmbed()
51
 
52
  decoder.eval()
53
  Trans.eval()
 
106
  }])
107
 
108
  return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
109
+
110
+
111
+ def create_styledSofa(content_image,style_image):
112
+ output = StyleTransformer(content_image,style_image)
113
+ return output
114
 
115
 
116