YoungMeezz committed on
Commit
1a44432
·
verified ·
1 Parent(s): 8be059e

Create app.py

Files changed (1)
app.py +987 -0
app.py ADDED
@@ -0,0 +1,987 @@
# --------------------------------------------------------
# PersonalizeSAM -- Personalize Segment Anything Model with One Shot
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import json  # needed for segmentation_results.json below

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image, ImageDraw

from show import *
from per_segment_anything import sam_model_registry, SamPredictor


class ImageMask(gr.components.Image):
    """
    Image component preset for mask sketching:
    source="upload", tool="sketch".
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        return super().preprocess(x)


class Mask_Weights(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(2, 1, requires_grad=True) / 3)
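

# point_selection picks prompts from the similarity map: the most similar
# pixel becomes a positive point (label 1) and the least similar one a
# negative point (label 0). For a 2-D map flattened row-major, an index i
# decomposes as row = i // W and col = i % W, W being the second dimension
# (the code's x/y variable names are swapped, but the concatenation order
# yields (x, y) pixel coordinates, which is what SamPredictor expects).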
def point_selection(mask_sim, topk=1):
    # Top-1 point selection
    w, h = mask_sim.shape
    topk_xy = mask_sim.flatten(0).topk(topk)[1]
    topk_x = (topk_xy // h).unsqueeze(0)
    topk_y = (topk_xy - topk_x * h)
    topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
    topk_label = np.array([1] * topk)
    topk_xy = topk_xy.cpu().numpy()

    # Top-last point selection
    last_xy = mask_sim.flatten(0).topk(topk, largest=False)[1]
    last_x = (last_xy // h).unsqueeze(0)
    last_y = (last_xy - last_x * h)
    last_xy = torch.cat((last_y, last_x), dim=0).permute(1, 0)
    last_label = np.array([0] * topk)
    last_xy = last_xy.cpu().numpy()

    return topk_xy, topk_label, last_xy, last_label
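

# Dice loss measures mask overlap: 1 - (2*|P∩G| + 1) / (|P| + |G| + 1),
# with +1 smoothing so empty masks do not divide by zero. E.g., with hard
# 0/1 predictions, a 50-pixel prediction lying entirely inside a 100-pixel
# target gives 1 - 101/151 ≈ 0.33.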
def calculate_dice_loss(inputs, targets, num_masks=1):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks
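

# Focal loss downweights easy examples: each element's BCE is scaled by
# (1 - p_t)**gamma, so confident correct predictions contribute little,
# while alpha rebalances the classes (0.25 here weights the positive class
# at 0.25 and the negative class at 0.75).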
def calculate_sigmoid_focal_loss(inputs, targets, num_masks=1, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default = 0.25; a negative
               value disables weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_masks
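

# Training-free PerSAM pipeline: encode the reference image and mask once,
# average the masked features into a target embedding, then for each test
# image take the cosine similarity between that embedding and the test
# features as a location prior. The best/worst points become positive/
# negative prompts (plus target-guided attention and target-semantic
# prompting), and two cascaded refinement passes (mask input, then mask +
# bounding box) sharpen the prediction.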
def inference(ic_image, ic_mask, image1, image2):
    # In-context image and mask
    ic_image = np.array(ic_image.convert("RGB"))
    ic_mask = np.array(ic_mask.convert("RGB"))

    sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
    predictor = SamPredictor(sam)

    # Image feature encoding
    ref_mask = predictor.set_image(ic_image, ic_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)

    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Target feature extraction
    print("======> Obtain Location Prior")
    target_feat = ref_feat[ref_mask > 0]
    target_embedding = target_feat.mean(0).unsqueeze(0)
    target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
    target_embedding = target_embedding.unsqueeze(0)

    output_image = []

    for test_image in [image1, image2]:
        print("======> Testing Image")
        test_image = np.array(test_image.convert("RGB"))

        # Image feature encoding
        predictor.set_image(test_image)
        test_feat = predictor.features.squeeze()

        # Cosine similarity
        C, h, w = test_feat.shape
        test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
        test_feat = test_feat.reshape(C, h * w)
        sim = target_feat @ test_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = predictor.model.postprocess_masks(
            sim,
            input_size=predictor.input_size,
            original_size=predictor.original_size).squeeze()

        # Positive-negative location prior
        topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
        topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
        topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)

        # Obtain the target guidance for cross-attention layers
        sim = (sim - sim.mean()) / torch.std(sim)
        sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
        attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)

        # First-step prediction
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=False,
            attn_sim=attn_sim,  # Target-guided Attention
            target_embedding=target_embedding  # Target-semantic Prompting
        )
        best_idx = 0

        # Cascaded Post-refinement-1
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        final_mask = masks[best_idx]
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))

    return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
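

# Same pipeline as inference(), but the reference comes from a sketch: the
# gr.ImageMask input is a dict holding the drawn-on "image" and the user's
# stroke as "mask".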
def inference_scribble(image, image1, image2):
    # In-context image and mask from the sketch input
    ic_image = image["image"]
    ic_mask = image["mask"]
    ic_image = np.array(ic_image.convert("RGB"))
    ic_mask = np.array(ic_mask.convert("RGB"))

    sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
    predictor = SamPredictor(sam)

    # Image feature encoding
    ref_mask = predictor.set_image(ic_image, ic_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)

    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Target feature extraction
    print("======> Obtain Location Prior")
    target_feat = ref_feat[ref_mask > 0]
    target_embedding = target_feat.mean(0).unsqueeze(0)
    target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
    target_embedding = target_embedding.unsqueeze(0)

    output_image = []

    for test_image in [image1, image2]:
        print("======> Testing Image")
        test_image = np.array(test_image.convert("RGB"))

        # Image feature encoding
        predictor.set_image(test_image)
        test_feat = predictor.features.squeeze()

        # Cosine similarity
        C, h, w = test_feat.shape
        test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
        test_feat = test_feat.reshape(C, h * w)
        sim = target_feat @ test_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = predictor.model.postprocess_masks(
            sim,
            input_size=predictor.input_size,
            original_size=predictor.original_size).squeeze()

        # Positive-negative location prior
        topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
        topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
        topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)

        # Obtain the target guidance for cross-attention layers
        sim = (sim - sim.mean()) / torch.std(sim)
        sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
        attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)

        # First-step prediction
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=False,
            attn_sim=attn_sim,  # Target-guided Attention
            target_embedding=target_embedding  # Target-semantic Prompting
        )
        best_idx = 0

        # Cascaded Post-refinement-1
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        final_mask = masks[best_idx]
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))

    return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
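

# PerSAM-F: SAM stays frozen; the only trainable parameters are the two
# scalars in Mask_Weights, which (together with 1 minus their sum) weight
# SAM's three multi-scale mask logits. They are fit on the single reference
# pair with dice + focal loss, then reused at test time.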
def inference_finetune_train(ic_image, ic_mask, image1, image2):
    # In-context image and mask
    ic_image = np.array(ic_image.convert("RGB"))
    ic_mask = np.array(ic_mask.convert("RGB"))

    gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to('cpu')

    sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
    for name, param in sam.named_parameters():
        param.requires_grad = False
    predictor = SamPredictor(sam)

    # Obtain the self location prior
    print("======> Obtain Self Location Prior")
    # Image feature encoding
    ref_mask = predictor.set_image(ic_image, ic_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)

    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Target feature extraction
    target_feat = ref_feat[ref_mask > 0]
    target_feat_mean = target_feat.mean(0)
    target_feat_max = torch.max(target_feat, dim=0)[0]
    target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)

    # Cosine similarity
    h, w, C = ref_feat.shape
    target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
    ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
    ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
    sim = target_feat @ ref_feat

    # Save target_feat for the test-only path
    torch.save(target_feat, 'target_feat.pth')
    print("target_feat saved to 'target_feat.pth'.")

    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim,
        input_size=predictor.input_size,
        original_size=predictor.original_size).squeeze()

    # Positive location prior
    topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
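
    # The fused logit is a convex-style combination of SAM's three mask
    # scales: weights = [1 - w1 - w2, w1, w2], initialized at 1/3 each.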
    print('======> Start Training')
    # Learnable mask weights
    mask_weights = Mask_Weights().to('cpu')
    mask_weights.train()
    train_epoch = 1000
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-4, eps=1e-4, betas=(0.9, 0.999), weight_decay=0.01, amsgrad=False)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epoch)

    for train_idx in range(train_epoch):
        # Run the decoder
        masks, scores, logits, logits_high = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True)
        logits_high = logits_high.flatten(1)

        # Weighted sum of the three-scale masks
        weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
        logits_high = logits_high * weights
        logits_high = logits_high.sum(0).unsqueeze(0)

        dice_loss = calculate_dice_loss(logits_high, gt_mask)
        focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask)
        loss = dice_loss + focal_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if train_idx % 10 == 0:
            print('Train Epoch: {:} / {:}'.format(train_idx, train_epoch))
            current_lr = scheduler.get_last_lr()[0]
            print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))

    mask_weights.eval()
    weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
    weights_np = weights.detach().cpu().numpy()
    print('======> Mask weights:\n', weights_np)

    # 1. Save the weights
    torch.save(mask_weights.state_dict(), 'mask_weights.pth')
    print("Weights saved to 'mask_weights.pth'.")

    ######################### Training done ########################################
    # 2. Test-only code
    # Initialize the model and load the weights
    mask_weights = Mask_Weights().to('cpu')
    mask_weights.load_state_dict(torch.load('mask_weights.pth'))
    mask_weights.eval()  # evaluation mode (prevents further training)

    weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
    weights_np = weights.detach().cpu().numpy()
    print('======> Mask weights:\n', weights_np)
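
    # Test stage: fuse the three decoder outputs with the learned weights,
    # threshold the fused logit at 0 for an initial mask, then refine twice
    # with its bounding box.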
    print('======> Start Testing')
    output_image = []

    for test_image in [image1, image2]:
        test_image = np.array(test_image.convert("RGB"))

        # Image feature encoding
        predictor.set_image(test_image)
        test_feat = predictor.features.squeeze()

        # Cosine similarity
        C, h, w = test_feat.shape
        test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
        test_feat = test_feat.reshape(C, h * w)
        sim = target_feat @ test_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = predictor.model.postprocess_masks(
            sim,
            input_size=predictor.input_size,
            original_size=predictor.original_size).squeeze()

        # Positive location prior
        topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
        print("Coordinates:", topk_xy)

        # First-step prediction
        masks, scores, logits, logits_high = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True)

        # Weighted sum of the three-scale masks
        logits_high = logits_high * weights.unsqueeze(-1)
        logit_high = logits_high.sum(0)
        mask = (logit_high > 0).detach().cpu().numpy()

        logits = logits * weights_np[..., None]
        logit = logits.sum(0)

        # Cascaded Post-refinement-1
        y, x = np.nonzero(mask)
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        final_mask = masks[best_idx]

        # Print prediction scores
        print("Prediction scores:")
        for idx, score in enumerate(scores):
            print(f"Mask {idx + 1}: {score.item():.4f}")

        # Create the image and annotate the scores
        output_img = Image.fromarray(test_image.astype('uint8'), 'RGB')
        draw = ImageDraw.Draw(output_img)

        # Draw each mask's confidence score at its centroid
        for idx, (mask, score) in enumerate(zip(masks, scores)):
            y, x = np.nonzero(mask)
            if len(x) > 0 and len(y) > 0:  # only annotate non-empty masks
                x_center = int(x.mean())
                y_center = int(y.mean())
                draw.text((x_center, y_center), f"{score.item():.2f}", fill=(255, 255, 0))

        # Append the overlay with the final mask and scores
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        overlay_image = Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB')
        draw_overlay = ImageDraw.Draw(overlay_image)

        for idx, score in enumerate(scores):
            draw_overlay.text((10, 10 + 20 * idx), f"Mask {idx + 1}: {score.item():.2f}", fill=(255, 255, 0))

        output_image.append(overlay_image)

    return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
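

# Helper: cv2.findContours with RETR_EXTERNAL effectively returns one
# contour per connected component of the binary mask, so len(contours)
# doubles as an object count; each contour gets a bounding rectangle and
# an outline.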
def draw_contours_and_bboxes(image, mask):
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Count the objects
    object_count = len(contours)

    # Draw contours and bounding boxes on the image
    for contour in contours:
        # Bounding box
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)  # green bounding box

        # Draw the contour
        cv2.drawContours(image, [contour], -1, (0, 0, 255), 2)  # red contour

    return image, object_count
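

# Test-only variant: rebuilds the predictor, loads the target features and
# mask weights saved by inference_finetune_train, segments four test images,
# and dumps per-mask coordinates and scores to segmentation_results.json.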
def inference_finetune_test(image1, image2, image3, image4):
    sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
    predictor = SamPredictor(sam)

    print("======> Obtain Self Location Prior")
    # The reference encoding and mask-weight training are handled by
    # inference_finetune_train; here we only load its saved artifacts.

    ######################### Training done ########################################
    # 2. Test-only code
    # Initialize the model and load the weights
    mask_weights = Mask_Weights().to('cpu')
    mask_weights.load_state_dict(torch.load('mask_weights.pth'))
    mask_weights.eval()  # evaluation mode (prevents further training)

    weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
    weights_np = weights.detach().cpu().numpy()
    print('======> Mask weights:\n', weights_np)

    print('======> Start Testing')
    output_image = []

    # List that accumulates the SAM segmentation results
    segmentation_results = []
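
    # Each entry of segmentation_results is JSON-serializable:
    #   {"image": name, "masks": [per-mask (y, x) coordinate lists],
    #    "scores": [per-mask confidence], "coordinates": duplicate of "masks"}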
    for img_idx, test_image in enumerate([image1, image2, image3, image4]):
        test_image = np.array(test_image.convert("RGB"))

        # Image feature encoding
        predictor.set_image(test_image)
        test_feat = predictor.features.squeeze()

        # Cosine similarity
        C, h, w = test_feat.shape
        test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
        test_feat = test_feat.reshape(C, h * w)
        # Load the saved target features
        target_feat = torch.load('target_feat.pth')
        sim = target_feat @ test_feat

        sim = sim.reshape(1, 1, h, w)
        sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
        sim = predictor.model.postprocess_masks(
            sim,
            input_size=predictor.input_size,
            original_size=predictor.original_size).squeeze()

        # Positive location prior
        topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
        print("Coordinates:", topk_xy)

        # First-step prediction
        masks, scores, logits, logits_high = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True)

        # Weighted sum of the three-scale masks
        logits_high = logits_high * weights.unsqueeze(-1)
        logit_high = logits_high.sum(0)
        mask = (logit_high > 0).detach().cpu().numpy()

        logits = logits * weights_np[..., None]
        logit = logits.sum(0)

        # Cascaded Post-refinement-1
        y, x = np.nonzero(mask)
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        final_mask = masks[best_idx]

        # Collect the results for this image
        result = {
            "image": f"image_{img_idx + 1}",  # unique name identifying the image
            "masks": [],
            "scores": [],
            "coordinates": []
        }

        mask_centers = []
        for idx, (mask, score) in enumerate(zip(masks, scores)):
            mask_coords = np.array(np.nonzero(mask)).T.tolist()  # mask coordinates as (y, x) pairs
            result["masks"].append(mask_coords)
            result["scores"].append(score.item())

            # Record the coordinates for each mask
            result["coordinates"].append(mask_coords)

            # Compute each mask's center coordinate (when the mask is non-empty)
            if mask_coords:
                y_coords, x_coords = zip(*mask_coords)
                mask_centers.append((int(np.mean(x_coords)), int(np.mean(y_coords))))

        segmentation_results.append(result)

        # Print prediction scores
        print("Prediction scores:")
        for idx, score in enumerate(scores):
            print(f"Mask {idx + 1}: {score.item():.4f}")

        # Count segmented objects (a mask with any positive pixel counts as valid)
        segmented_count = sum((mask.sum() > 0) for mask in masks)

        # Build the final mask overlay colors
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])

        # Draw contours and bounding boxes for the final mask region
        test_image_np = np.array(test_image)
        final_mask_obj = final_mask.astype(np.uint8)
        overlay_image, object_count = draw_contours_and_bboxes(test_image_np.copy(), final_mask_obj)

        # Print the object count
        print(f"Detected {object_count} objects.")

        # Annotate the final image with counts, scores, and mask centers
        overlay_image = Image.fromarray(overlay_image)
        draw_overlay = ImageDraw.Draw(overlay_image)
        draw_overlay.text((170, 10), f"Cnt: {segmented_count}", fill=(255, 255, 0))

        for idx, score in enumerate(scores):
            draw_overlay.text((10, 10 + 20 * idx), f"Mask {idx + 1}: {score.item():.2f}", fill=(255, 255, 0))

        for center_x, center_y in mask_centers:
            draw_overlay.text((center_x, center_y), f"({center_x}, {center_y})", fill=(255, 0, 0))

        output_image.append(overlay_image)

    # Save the collected results as a JSON file
    with open("segmentation_results.json", "w") as f:
        json.dump(segmentation_results, f, indent=4)
    print("Segmentation results saved as 'segmentation_results.json'")

    return output_image[0].resize((224, 224)), output_image[1].resize((224, 224)), output_image[2].resize((224, 224)), output_image[3].resize((224, 224))


description = """
<div style="text-align: center; font-weight: bold;">
    <span style="font-size: 18px" id="paper-info">
        [<a href="https://github.com/ZrrSkywalker/Personalize-SAM" target="_blank"><font color='black'>Github</font></a>]
        [<a href="https://arxiv.org/pdf/2305.03048.pdf" target="_blank"><font color='black'>Paper</font></a>]
    </span>
</div>
"""

main = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", label="in context image"),
        gr.Image(type="pil", label="in context mask"),
        gr.Image(type="pil", label="test image1"),
        gr.Image(type="pil", label="test image2"),
    ],
    outputs=[
        gr.Image(type="pil", label="output image1"),
        gr.Image(type="pil", label="output image2"),
    ],
    allow_flagging="never",
    title="Personalize Segment Anything Model with 1 Shot",
    description=description,
    examples=[
        ["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
        ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
        ["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
    ]
)

main_scribble = gr.Interface(
    fn=inference_scribble,
    inputs=[
        gr.ImageMask(label="[Stroke] Draw on Image", type="pil"),
        gr.Image(type="pil", label="test image1"),
        gr.Image(type="pil", label="test image2"),
    ],
    outputs=[
        gr.Image(type="pil", label="output image1"),
        gr.Image(type="pil", label="output image2"),
    ],
    allow_flagging="never",
    title="Personalize Segment Anything Model with 1 Shot",
    description=description,
    examples=[
        ["./examples/cat_00.jpg", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
        ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
        ["./examples/duck_toy_00.jpg", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
    ]
)

main_finetune_train = gr.Interface(
    fn=inference_finetune_train,
    inputs=[
        gr.Image(type="pil", label="in context image"),
        gr.Image(type="pil", label="in context mask"),
        gr.Image(type="pil", label="test image1"),
        gr.Image(type="pil", label="test image2"),
    ],
    outputs=[
        gr.Image(type="pil", label="output image1"),
        gr.Image(type="pil", label="output image2"),
    ],
    allow_flagging="never",
    title="Personalize Segment Anything Model with 1 Shot Train",
    description=description,
    examples=[
        ["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
        ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
        ["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
    ]
)

main_finetune_test = gr.Interface(
    fn=inference_finetune_test,
    inputs=[
        gr.Image(type="pil", label="test image1"),
        gr.Image(type="pil", label="test image2"),
        gr.Image(type="pil", label="test image3"),
        gr.Image(type="pil", label="test image4"),
    ],
    outputs=[
        gr.Image(type="pil", label="output image1"),
        gr.Image(type="pil", label="output image2"),
        gr.Image(type="pil", label="output image3"),
        gr.Image(type="pil", label="output image4"),
    ],
    allow_flagging="never",
    title="Personalize Segment Anything Model with 1 Shot Test",
    description=description,
    examples=[
        ["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
        ["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
        ["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
    ]
)

demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [main_finetune_train, main_finetune_test],
        ["Personalize-SAM-F_train", "Personalize-SAM-F_test"],
    )

demo.launch(share=True)