С Чичерин commited on
Commit
57510bb
1 Parent(s): 8029b4a

fixing issues with cpu

Browse files
Files changed (1) hide show
  1. test.py +12 -12
test.py CHANGED
@@ -526,10 +526,10 @@ def get_crossentropy_loss(gt,pre):
526
  return entropy_loss
527
 
528
  def get_alpha_loss(predict, alpha, trimap):
529
- weighted = torch.zeros(trimap.shape).cuda()
530
  weighted[trimap == 128] = 1.
531
  alpha_f = alpha / 255.
532
- alpha_f = alpha_f.cuda()
533
  diff = predict - alpha_f
534
  diff = diff * weighted
535
  alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
@@ -537,9 +537,9 @@ def get_alpha_loss(predict, alpha, trimap):
537
  return alpha_loss_weighted
538
 
539
  def get_alpha_loss_whole_img(predict, alpha):
540
- weighted = torch.ones(alpha.shape).cuda()
541
  alpha_f = alpha / 255.
542
- alpha_f = alpha_f.cuda()
543
  diff = predict - alpha_f
544
  alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
545
  alpha_loss = alpha_loss.sum()/(weighted.sum())
@@ -555,7 +555,7 @@ def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
555
  kernel = np.sum(gaussian(grid), axis=2)
556
  kernel /= np.sum(kernel)
557
  kernel = np.tile(kernel, (n_channels, 1, 1))
558
- kernel = torch.FloatTensor(kernel[:, None, :, :]).cuda()
559
  return Variable(kernel, requires_grad=False)
560
 
561
  def conv_gauss(img, kernel):
@@ -576,10 +576,10 @@ def laplacian_pyramid(img, kernel, max_levels=5):
576
  return pyr
577
 
578
  def get_laplacian_loss(predict, alpha, trimap):
579
- weighted = torch.zeros(trimap.shape).cuda()
580
  weighted[trimap == 128] = 1.
581
  alpha_f = alpha / 255.
582
- alpha_f = alpha_f.cuda()
583
  alpha_f = alpha_f.clone()*weighted
584
  predict = predict.clone()*weighted
585
  gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
@@ -590,7 +590,7 @@ def get_laplacian_loss(predict, alpha, trimap):
590
 
591
  def get_laplacian_loss_whole_img(predict, alpha):
592
  alpha_f = alpha / 255.
593
- alpha_f = alpha_f.cuda()
594
  gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
595
  pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
596
  pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
@@ -598,7 +598,7 @@ def get_laplacian_loss_whole_img(predict, alpha):
598
  return laplacian_loss
599
 
600
  def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
601
- weighted = torch.ones(alpha.shape).cuda()
602
  predict_3 = torch.cat((predict, predict, predict), 1)
603
  comp = predict_3 * fg + (1. - predict_3) * bg
604
  comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12)
@@ -781,7 +781,7 @@ def inference_img(model, img):
781
  img=cv2.copyMakeBorder(img, 8-h%8, 0, 8-w%8, 0, cv2.BORDER_REFLECT)
782
  # print(img.shape)
783
 
784
- tensor_img = torch.from_numpy(img).permute(2, 0, 1).cuda()
785
  input_t = tensor_img
786
  input_t = input_t/255.0
787
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -839,7 +839,7 @@ def test_am2k(model):
839
  alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
840
 
841
  with torch.no_grad():
842
- torch.cuda.empty_cache()
843
  predict = inference_img( model, img)
844
 
845
 
@@ -926,7 +926,7 @@ def test_p3m10k(model,dataset_choice, max_image=-1):
926
  trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
927
  alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
928
  with torch.no_grad():
929
- torch.cuda.empty_cache()
930
  start = time.time()
931
 
932
 
 
526
  return entropy_loss
527
 
528
  def get_alpha_loss(predict, alpha, trimap):
529
+ weighted = torch.zeros(trimap.shape).to(device)
530
  weighted[trimap == 128] = 1.
531
  alpha_f = alpha / 255.
532
+ alpha_f = alpha_f.to(device)
533
  diff = predict - alpha_f
534
  diff = diff * weighted
535
  alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
 
537
  return alpha_loss_weighted
538
 
539
  def get_alpha_loss_whole_img(predict, alpha):
540
+ weighted = torch.ones(alpha.shape).to(device)
541
  alpha_f = alpha / 255.
542
+ alpha_f = alpha_f.to(device)
543
  diff = predict - alpha_f
544
  alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
545
  alpha_loss = alpha_loss.sum()/(weighted.sum())
 
555
  kernel = np.sum(gaussian(grid), axis=2)
556
  kernel /= np.sum(kernel)
557
  kernel = np.tile(kernel, (n_channels, 1, 1))
558
+ kernel = torch.FloatTensor(kernel[:, None, :, :]).to(device)
559
  return Variable(kernel, requires_grad=False)
560
 
561
  def conv_gauss(img, kernel):
 
576
  return pyr
577
 
578
  def get_laplacian_loss(predict, alpha, trimap):
579
+ weighted = torch.zeros(trimap.shape).to(device)
580
  weighted[trimap == 128] = 1.
581
  alpha_f = alpha / 255.
582
+ alpha_f = alpha_f.to(device)
583
  alpha_f = alpha_f.clone()*weighted
584
  predict = predict.clone()*weighted
585
  gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
 
590
 
591
  def get_laplacian_loss_whole_img(predict, alpha):
592
  alpha_f = alpha / 255.
593
+ alpha_f = alpha_f.to(device)
594
  gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
595
  pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
596
  pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
 
598
  return laplacian_loss
599
 
600
  def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
601
+ weighted = torch.ones(alpha.shape).to(device)
602
  predict_3 = torch.cat((predict, predict, predict), 1)
603
  comp = predict_3 * fg + (1. - predict_3) * bg
604
  comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12)
 
781
  img=cv2.copyMakeBorder(img, 8-h%8, 0, 8-w%8, 0, cv2.BORDER_REFLECT)
782
  # print(img.shape)
783
 
784
+ tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
785
  input_t = tensor_img
786
  input_t = input_t/255.0
787
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
 
839
  alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
840
 
841
  with torch.no_grad():
842
+ # torch.cuda.empty_cache()
843
  predict = inference_img( model, img)
844
 
845
 
 
926
  trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
927
  alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
928
  with torch.no_grad():
929
+ # torch.cuda.empty_cache()
930
  start = time.time()
931
 
932