# SofaStyler / segmentation.py
# Author: Sophie98 -- commit defa533 ("more testing"), 3.47 kB
# Import libraries
import functools

import cv2
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import segmentation_models as sm
# Select the 'tf.keras' framework for segmentation_models at import time, so it is
# already in effect when get_mask builds the sm.Unet model below.
sm.set_framework('tf.keras')
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the sofa-segmentation Unet, load its weights and return (model, preprocess_input).

    Cached so the resnet50 Unet is built and "model_checkpoint.h5" is read only once
    per process, instead of on every call to get_mask.
    """
    model_path = "model_checkpoint.h5"
    classes = ['sofa']
    backbone = 'resnet50'
    learning_rate = 0.0001
    # Binary segmentation uses a single sigmoid output; multiclass adds one extra class and uses softmax.
    n_classes = 1 if len(classes) == 1 else (len(classes) + 1)
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    preprocess_input = sm.get_preprocessing(backbone)

    # Create the model architecture.
    model = sm.Unet(backbone, classes=n_classes, activation=activation)

    # Optimizer, loss and metrics are only needed by compile(); predict() does not use them.
    # NOTE(review): presumably they mirror the training configuration -- confirm.
    optim = keras.optimizers.Adam(learning_rate)
    # segmentation_models losses can be combined with '+' and scaled by a number.
    # (sm.losses.binary_focal_dice_loss / categorical_focal_dice_loss are ready-made equivalents.)
    dice_loss = sm.losses.DiceLoss()
    focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
    total_loss = dice_loss + (1 * focal_loss)
    metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
    model.compile(optim, total_loss, metrics)

    # Load the trained weights.
    model.load_weights(model_path)
    print('loaded model')
    return model, preprocess_input


def get_mask(image: Image) -> Image:
    """
    Generate a mask of the image that highlights all the sofas in it.

    A pre-trained Unet with a resnet50 backbone (weights loaded from "model_checkpoint.h5")
    is used. The model was trained on 640x640 images, so the input is resized to 640x640
    and the returned mask is always 640x640, whatever the input size.

    Parameters:
        image = original image (PIL image in any mode, or a 3-channel RGB NumPy array)
    Return:
        mask = grayscale ("L") PIL image, 255 on predicted sofa pixels and 0 elsewhere
    """
    model, preprocess_input = _load_model()

    # RGBA / grayscale / palette PIL images would make the RGB->BGR conversion below fail.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    test_img = np.array(image)
    test_img = cv2.resize(test_img, (640, 640))
    # NOTE(review): BGR channel order presumably matches how training images were read (cv2) -- confirm.
    test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
    # Add the batch axis expected by model.predict.
    test_img = np.expand_dims(test_img, axis=0)

    # The sigmoid output is rounded to 0/1, then scaled to 0/255 for the grayscale mask.
    prediction = model.predict(preprocess_input(test_img)).round()
    mask = Image.fromarray(prediction[..., 0].squeeze() * 255).convert("L")
    return mask
def _to_binary_mask(mask: np.ndarray) -> np.ndarray:
    """Return *mask* as a single-channel uint8 array: 255 where it is brighter than 10, else 0."""
    # Boolean masks (e.g. PIL mode "1") cannot be thresholded by cv2 directly.
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    # cv2.bitwise_and needs a single-channel mask, but masks loaded as RGB/RGBA
    # (e.g. cv2.imread or a PIL RGB image) carry a channel axis.
    if mask.ndim == 3:
        channels = mask.shape[2]
        if channels == 1:
            mask = mask[..., 0]
        elif channels == 4:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    # threshold keeps the input dtype; the bitwise_and mask must be 8-bit.
    return mask.astype(np.uint8)


def replace_sofa(image: Image, mask: Image, styled_sofa: Image) -> Image:
    """
    Replace the original sofa in the image by the new styled sofa according to the mask.

    Pixels where the mask is brighter than 10 are taken from styled_sofa; all other
    pixels are kept from image.
    Remark: all images should have the same height and width, and image and
    styled_sofa the same number of channels.
    Input:
        image = original image
        mask = generated mask highlighting the sofas in the image
               (grayscale, single-channel, RGB/RGBA or boolean)
        styled_sofa = styled image
    Return:
        new_image = PIL image containing the styled sofa
    """
    image, styled_sofa = np.array(image), np.array(styled_sofa)
    mask = _to_binary_mask(np.array(mask))
    mask_inv = cv2.bitwise_not(mask)
    # Keep the background from the original image and the sofa region from the styled image.
    image_bg = cv2.bitwise_and(image, image, mask=mask_inv)
    sofa_fg = cv2.bitwise_and(styled_sofa, styled_sofa, mask=mask)
    # The two regions are disjoint, so adding them composites the final image.
    new_image = cv2.add(image_bg, sofa_fg)
    return Image.fromarray(new_image)
# image = cv2.imread('input/sofa.jpg')
# mask = cv2.imread('masks/sofa.jpg')
# styled_sofa = cv2.imread('output/sofa_stylized_style.jpg')
# #get_mask(image)
# plt.imshow(replace_sofa(image,mask,styled_sofa))
# plt.show()