Sophie98 committed on
Commit
210920a
2 Parent(s): ab92204 033b042

Merge branch 'main' of https://huggingface.co/spaces/SophieDC/SofaStyler into main

Browse files
Files changed (2) hide show
  1. box_ops.py +88 -0
  2. segmentation.py +72 -0
box_ops.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) format."""
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
14
+
15
+
16
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) format."""
    x0, y0, x1, y1 = x.unbind(-1)
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    return torch.stack([center_x, center_y, x1 - x0, y1 - y0], dim=-1)
21
+
22
+
23
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU for two sets of boxes in (x0, y0, x1, y1) format.

    Returns a tuple ``(iou, union)``, each of shape [N, M] where
    N = len(boxes1) and M = len(boxes2).
    """
    # Same computation torchvision's box_area performs, inlined.
    def _areas(boxes):
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    area1 = _areas(boxes1)
    area2 = _areas(boxes2)

    # Intersection rectangle for every (box1, box2) pair.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N, M, 2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N, M, 2]

    extent = (bottom_right - top_left).clamp(min=0)               # [N, M, 2]
    inter = extent[..., 0] * extent[..., 1]                       # [N, M]

    union = area1[:, None] + area2 - inter
    return inter / union, union
38
+
39
+
40
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # Degenerate boxes (x1 < x0 or y1 < y0) would yield inf/nan below,
    # so reject them up front.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each pair.
    enclosing_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclosing_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclosing_wh = (enclosing_br - enclosing_tl).clamp(min=0)  # [N, M, 2]
    enclosing_area = enclosing_wh[..., 0] * enclosing_wh[..., 1]

    # GIoU = IoU - (enclosing area not covered by the union) / enclosing area.
    return iou - (enclosing_area - union) / enclosing_area
62
+
63
+
64
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        # No masks: empty [0, 4] result on the same device.
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Build the coordinate grids on the same device as the masks; the
    # previous code created them on the CPU, which fails for CUDA masks.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    # indexing="ij" keeps the (row, col) layout this code relies on and
    # avoids the deprecation warning for the old implicit default.
    y, x = torch.meshgrid(y, x, indexing="ij")

    # Max over masked x-coordinates gives the right edge; for the left edge,
    # fill background pixels with a huge value so min() ignores them.
    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    # Same trick along the y axis.
    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
segmentation.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Import libraries
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models as sm
from PIL import Image
from tensorflow import keras
9
+
10
def get_mask(image):
    """Segment the sofa in *image* and return its binary mask.

    Args:
        image: RGB input image (PIL image or anything np.array() accepts);
            it is resized to 640x640 before being fed to the network.

    Returns:
        np.ndarray: 640x640 grayscale mask (0 = background, 255 = sofa).
        As a side effect the mask is also written to "masks/sofa.jpg".
    """
    # Build, compile and load the network only once per process; the
    # previous version rebuilt it from scratch on every call.
    if not hasattr(get_mask, "_model"):
        model_path = "Segmentation/model_checkpoint.h5"
        CLASSES = ['sofa']
        BACKBONE = 'resnet50'

        # define network parameters
        n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # case for binary and multiclass segmentation
        activation = 'sigmoid' if n_classes == 1 else 'softmax'
        # NOTE(review): preprocess_input is computed but never applied to the
        # input image below — confirm whether predictions should use it.
        preprocess_input = sm.get_preprocessing(BACKBONE)
        sm.set_framework('tf.keras')
        LR = 0.0001

        # create model architecture
        model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
        # define optimizer
        optim = keras.optimizers.Adam(LR)
        # Segmentation models losses can be combined together by '+' and scaled by integer or float factor
        dice_loss = sm.losses.DiceLoss()
        focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
        total_loss = dice_loss + (1 * focal_loss)
        metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
        # compile keras model with defined optimizer, loss and metrics
        model.compile(optim, total_loss, metrics)

        # load trained weights
        model.load_weights(model_path)
        get_mask._model = model

    model = get_mask._model

    # The network expects a single 640x640 BGR image as a batch of one.
    test_img = np.array(image)
    test_img = cv2.resize(test_img, (640, 640))
    test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
    test_img = np.expand_dims(test_img, axis=0)

    # Round the sigmoid output to a hard 0/1 mask, then scale to 0/255.
    prediction = model.predict(test_img).round()
    mask = Image.fromarray(prediction[..., 0].squeeze() * 255).convert("L")
    # Ensure the output directory exists; the previous version crashed when
    # "masks/" was missing.
    os.makedirs("masks", exist_ok=True)
    mask.save("masks/sofa.jpg")
    return np.array(mask)
49
+
50
def replace_sofa(image, mask, styled_sofa):
    """Composite the restyled sofa onto the original image where the mask is set.

    Args:
        image: the original room photo (PIL image or RGB array).
        mask: single-channel sofa mask; values above 10 count as sofa.
        styled_sofa: BGR image of the restyled sofa, same size as *image*.

    Returns:
        np.ndarray: the composited RGB image.
    """
    image = np.array(image)
    styled_sofa = cv2.cvtColor(styled_sofa, cv2.COLOR_BGR2RGB)

    # Binarise the mask, then keep the scene outside it and the styled
    # sofa inside it, and merge the two halves.
    _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    background = cv2.bitwise_and(image, image, mask=cv2.bitwise_not(mask))
    foreground = cv2.bitwise_and(styled_sofa, styled_sofa, mask=mask)
    return cv2.add(background, foreground)
64
+
65
+ # image = cv2.imread('input/sofa.jpg')
66
+ # mask = cv2.imread('masks/sofa.jpg')
67
+ # styled_sofa = cv2.imread('output/sofa_stylized_style.jpg')
68
+
69
+ # #get_mask(image)
70
+
71
+ # plt.imshow(replace_sofa(image,mask,styled_sofa))
72
+ # plt.show()