File size: 3,468 Bytes
72a1628
 
 
 
 
 
 
 
f2e8e46
72a1628
a37eb28
e4fb230
79c6687
 
 
 
 
 
 
 
 
3b83a8e
5d930e1
 
3b83a8e
 
 
5d930e1
758d201
3b83a8e
5d930e1
3b83a8e
 
5d930e1
 
 
 
 
 
 
 
 
 
 
 
3b83a8e
 
5d930e1
 
7ae512f
defa533
5d930e1
72a1628
 
 
 
e4fb230
72a1628
e4fb230
72a1628
e4fb230
79c6687
 
 
 
 
 
 
 
 
 
 
e4fb230
72a1628
 
 
 
 
 
e4fb230
72a1628
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Import libraries

import cv2
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import segmentation_models as sm
# Select the 'tf.keras' framework for segmentation_models so the models it
# builds are consistent with the `from tensorflow import keras` import above.
sm.set_framework('tf.keras')


def get_mask(image: Image.Image, model_path: str = "model_checkpoint.h5") -> Image.Image:
    """
    Generate a mask of the image that highlights all the sofas in it.

    A U-Net with a ResNet-50 backbone (segmentation_models) is built and its
    trained weights are loaded from ``model_path``. The model was trained on
    640x640 images, so the image is resized to 640x640 for inference. The
    predicted mask is then resized back to the input size with
    nearest-neighbour sampling (so it stays binary). This lets it be passed
    directly to ``replace_sofa``, which needs all images to share one size.

    NOTE(review): the model is rebuilt and its weights reloaded on every call;
    cache it if get_mask is called repeatedly.

    Parameters:
            image      = original image (a PIL image is converted to RGB first;
                         a numpy array is used as-is)
            model_path = path of the trained weights checkpoint
    Return:
            mask       = single-channel mask, 255 where a sofa was predicted
                         and 0 elsewhere, with the same size as ``image``
    """
    classes = ['sofa']
    backbone = 'resnet50'
    model_size = (640, 640)  # (width, height) the model was trained on

    # Binary segmentation uses one sigmoid output channel; multiclass uses one
    # channel per class plus background with softmax.
    n_classes = 1 if len(classes) == 1 else len(classes) + 1
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    preprocess_input = sm.get_preprocessing(backbone)

    # Inference only: no optimizer/loss/metrics are needed, so the model is
    # not compiled before loading the weights.
    model = sm.Unet(backbone, classes=n_classes, activation=activation)
    model.load_weights(model_path)

    if isinstance(image, Image.Image):
        original_size = image.size  # PIL size is (width, height)
        # Force 3 channels so RGBA / grayscale inputs do not break cvtColor.
        test_img = np.array(image.convert("RGB"))
    else:
        test_img = np.asarray(image)
        original_size = (test_img.shape[1], test_img.shape[0])

    test_img = cv2.resize(test_img, model_size)
    # Channel order swapped to BGR -- presumably to match how the training
    # images were loaded (e.g. with cv2.imread); TODO confirm.
    test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
    test_img = np.expand_dims(test_img, axis=0)  # add the batch dimension

    prediction = model.predict(preprocess_input(test_img)).round()
    # Channel 0 of the single batch item; values are 0.0 / 1.0 after round().
    mask = (prediction[0, ..., 0] * 255).astype(np.uint8)
    return Image.fromarray(mask).resize(original_size, Image.NEAREST)

def replace_sofa(image: Image.Image, mask: Image.Image, styled_sofa: Image.Image,
                 threshold: int = 10) -> Image.Image:
    """
    Replace the original sofa in the image by the new styled sofa according
    to the mask.

    Every pixel whose mask value is strictly greater than ``threshold`` is
    taken from ``styled_sofa``; every other pixel is kept from ``image``.
    Remark: All images should have the same size.
    Parameters:
        image       = Original image (PIL image or numpy array)
        mask        = Mask highlighting the sofas in the image. A PIL mask is
                      converted to grayscale; for a multi-channel array only
                      the first channel is used.
        styled_sofa = Styled image, same shape as ``image``
        threshold   = mask values above this select the styled pixel
    Return:
        new_image   = Image containing the styled sofa
    Raises:
        ValueError  = if the image, styled image and mask sizes do not match
    """
    if isinstance(mask, Image.Image):
        mask = mask.convert("L")
    image, mask, styled_sofa = np.asarray(image), np.asarray(mask), np.asarray(styled_sofa)

    # A mask stored as a colour image (e.g. read with cv2.imread) has three
    # identical channels; keep a single one so it matches the image's H x W.
    if mask.ndim == 3:
        mask = mask[..., 0]

    if image.shape != styled_sofa.shape:
        raise ValueError(
            f"image {image.shape} and styled_sofa {styled_sofa.shape} must have the same shape"
        )
    if mask.shape != image.shape[:2]:
        raise ValueError(
            f"mask {mask.shape} must match the image height and width {image.shape[:2]}"
        )

    sofa_pixels = mask > threshold
    if image.ndim == 3:
        sofa_pixels = sofa_pixels[..., np.newaxis]  # broadcast over colour channels
    new_image = np.where(sofa_pixels, styled_sofa, image)
    return Image.fromarray(new_image)

# image = cv2.imread('input/sofa.jpg')
# mask = cv2.imread('masks/sofa.jpg')
# styled_sofa = cv2.imread('output/sofa_stylized_style.jpg')

# #get_mask(image)

# plt.imshow(replace_sofa(image,mask,styled_sofa))
# plt.show()