justin-zk commited on
Commit
d1e13bc
1 Parent(s): 9e64572

using cude app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -106,8 +106,8 @@ def inference(ic_image, ic_mask, image1, image2):
106
  ic_mask = np.array(ic_mask.convert("RGB"))
107
 
108
  sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
109
- # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
110
- sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
111
  predictor = SamPredictor(sam)
112
 
113
  # Image features encoding
@@ -206,8 +206,8 @@ def inference_scribble(image, image1, image2):
206
  ic_mask = np.array(ic_mask.convert("RGB"))
207
 
208
  sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
209
- # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
210
- sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
211
  predictor = SamPredictor(sam)
212
 
213
  # Image features encoding
@@ -304,12 +304,12 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
304
  ic_mask = np.array(ic_mask.convert("RGB"))
305
 
306
  gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
307
- # gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
308
- gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
309
 
310
  sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
311
- # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
312
- sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
313
  for name, param in sam.named_parameters():
314
  param.requires_grad = False
315
  predictor = SamPredictor(sam)
@@ -347,8 +347,8 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
347
 
348
  print('======> Start Training')
349
  # Learnable mask weights
350
- # mask_weights = Mask_Weights().cuda()
351
- mask_weights = Mask_Weights()
352
  mask_weights.train()
353
  train_epoch = 1000
354
  optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)
 
106
  ic_mask = np.array(ic_mask.convert("RGB"))
107
 
108
  sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
109
+ sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
110
+ # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
111
  predictor = SamPredictor(sam)
112
 
113
  # Image features encoding
 
206
  ic_mask = np.array(ic_mask.convert("RGB"))
207
 
208
  sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
209
+ sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
210
+ # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
211
  predictor = SamPredictor(sam)
212
 
213
  # Image features encoding
 
304
  ic_mask = np.array(ic_mask.convert("RGB"))
305
 
306
  gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
307
+ gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
308
+ # gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
309
 
310
  sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
311
+ sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
312
+ # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
313
  for name, param in sam.named_parameters():
314
  param.requires_grad = False
315
  predictor = SamPredictor(sam)
 
347
 
348
  print('======> Start Training')
349
  # Learnable mask weights
350
+ mask_weights = Mask_Weights().cuda()
351
+ # mask_weights = Mask_Weights()
352
  mask_weights.train()
353
  train_epoch = 1000
354
  optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)