Spaces:
Build error
Build error
Merge branch 'main' of https://huggingface.co/spaces/SophieDC/SofaStyler into main
Browse files- box_ops.py +88 -0
- segmentation.py +72 -0
box_ops.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Utilities for bounding box manipulation and GIoU.
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
from torchvision.ops.boxes import box_area
|
7 |
+
|
8 |
+
|
9 |
+
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to
    (x0, y0, x1, y1) corner format.

    Works on any tensor whose last dimension is 4; returns a tensor of
    the same shape.
    """
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
|
14 |
+
|
15 |
+
|
16 |
+
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) corner format to
    (center_x, center_y, width, height) format.

    Inverse of `box_cxcywh_to_xyxy`; operates on the last dimension.
    """
    x0, y0, x1, y1 = x.unbind(-1)
    center_and_size = ((x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0)
    return torch.stack(center_and_size, dim=-1)
|
21 |
+
|
22 |
+
|
23 |
+
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of boxes in (x0, y0, x1, y1) format.

    Returns a tuple (iou, union), each of shape [N, M] where
    N = len(boxes1) and M = len(boxes2).
    """
    # Per-box areas (equivalent to torchvision.ops.boxes.box_area).
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # Intersection rectangle for every (box1, box2) pair.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    inter_wh = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = inter_wh[..., 0] * inter_wh[..., 1]        # [N,M]

    union = area1[:, None] + area2 - inter
    return inter / union, union
|
38 |
+
|
39 |
+
|
40 |
+
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # Degenerate boxes (x1 < x0 or y1 < y0) give inf / nan results,
    # so validate up front before doing any arithmetic.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned rectangle enclosing each pair of boxes.
    enclosing_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclosing_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    enclosing_wh = (enclosing_br - enclosing_tl).clamp(min=0)  # [N,M,2]
    enclosing_area = enclosing_wh[..., 0] * enclosing_wh[..., 1]

    # GIoU = IoU - (enclosing_area - union) / enclosing_area
    return iou - (enclosing_area - union) / enclosing_area
|
62 |
+
|
63 |
+
|
64 |
+
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of
    masks and (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format. An empty
    input yields an empty [0, 4] tensor on the same device.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Fix: create the coordinate grids on the same device as `masks`
    # (the original always allocated on CPU, which breaks CUDA inputs).
    # `indexing="ij"` pins the historical meshgrid behavior and silences
    # the deprecation warning in recent PyTorch.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    y, x = torch.meshgrid(y, x, indexing="ij")

    # Max/min of the per-pixel column index over each mask's support;
    # empty pixels are pushed to 1e8 before taking the min.
    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    # Same for the row index.
    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
|
segmentation.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import libraries
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
from tensorflow import keras
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
import segmentation_models as sm
|
9 |
+
|
10 |
+
# Cached segmentation model: building the U-Net and loading its .h5
# weights is expensive, so do it once per process instead of on every
# get_mask() call (the original rebuilt and reloaded everything per call).
_seg_model = None


def _build_seg_model():
    """Build the sofa-segmentation U-Net and load its checkpoint weights."""
    model_path = "Segmentation/model_checkpoint.h5"
    CLASSES = ['sofa']
    BACKBONE = 'resnet50'

    # define network parameters
    n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # case for binary and multiclass segmentation
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    sm.set_framework('tf.keras')
    LR = 0.0001

    # create model architecture
    model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
    # define optimizer
    optim = keras.optimizers.Adam(LR)
    # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
    dice_loss = sm.losses.DiceLoss()
    focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
    total_loss = dice_loss + (1 * focal_loss)
    metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
    # compile keras model with defined optimizer, loss and metrics
    model.compile(optim, total_loss, metrics)

    # load trained weights
    model.load_weights(model_path)
    return model


def get_mask(image):
    """Segment the sofa in `image` and return its binary mask.

    The input is converted with np.array(), resized to 640x640 and fed
    to the cached U-Net; the rounded prediction is saved to
    "masks/sofa.jpg" (the directory must exist — NOTE(review): confirm)
    and returned as a uint8 numpy array (0/255, L mode).
    """
    global _seg_model
    if _seg_model is None:
        _seg_model = _build_seg_model()

    # Model expects a BGR 640x640 batch of one image.
    test_img = np.array(image)
    test_img = cv2.resize(test_img, (640, 640))
    test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
    test_img = np.expand_dims(test_img, axis=0)

    prediction = _seg_model.predict(test_img).round()
    mask = Image.fromarray(prediction[..., 0].squeeze() * 255).convert("L")
    mask.save("masks/sofa.jpg")
    return np.array(mask)
|
49 |
+
|
50 |
+
def replace_sofa(image, mask, styled_sofa):
    """Composite the styled sofa into the original scene.

    Pixels where `mask` is on are taken from `styled_sofa` (converted
    BGR->RGB); every other pixel keeps the original `image` value.
    Returns the combined image as a numpy array.
    """
    image = np.array(image)
    styled_sofa = cv2.cvtColor(styled_sofa, cv2.COLOR_BGR2RGB)

    # Binarize the mask (anything above 10 counts as sofa), then split the
    # scene into background (mask off) and foreground (mask on).
    _, binary_mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    inverted_mask = cv2.bitwise_not(binary_mask)
    background = cv2.bitwise_and(image, image, mask=inverted_mask)
    foreground = cv2.bitwise_and(styled_sofa, styled_sofa, mask=binary_mask)

    return cv2.add(background, foreground)
|
64 |
+
|
65 |
+
# image = cv2.imread('input/sofa.jpg')
|
66 |
+
# mask = cv2.imread('masks/sofa.jpg')
|
67 |
+
# styled_sofa = cv2.imread('output/sofa_stylized_style.jpg')
|
68 |
+
|
69 |
+
# #get_mask(image)
|
70 |
+
|
71 |
+
# plt.imshow(replace_sofa(image,mask,styled_sofa))
|
72 |
+
# plt.show()
|