djl234 commited on
Commit
9d42a5c
·
1 Parent(s): 35b0a7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -4
app.py CHANGED
@@ -8,18 +8,59 @@ os.system("pip3 install torch")
8
  os.system("pip3 install collections")
9
  os.system("pip3 install torchvision")
10
  os.system("pip3 install einops")
 
11
  #os.system("pip3 install argparse")
12
-
13
  from PIL import Image
14
  import torch
15
  from torchvision import transforms
16
  from model_video import build_model
17
  import numpy as np
18
  import collections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  #import argparse
20
  device='cuda:0'
21
  net = build_model(device).to(device)
22
- #net=torch.nn.DataParallel(net)
23
  model_path = 'image_best.pth'
24
  print(model_path)
25
  weight=torch.load(model_path,map_location=torch.device(device))
@@ -44,8 +85,10 @@ def test(gpu_id, net, img_list, group_size, img_size):
44
  for i in range(5):
45
  group_img[i]=img_transform(Image.fromarray(img_list[i]))
46
  _,pred_mask=net(group_img)
47
- print(pred_mask.shape)
48
- result = [Image.fromarray((pred_mask[i].detach().squeeze().unsqueeze(2).repeat(1,1,3) * 255).numpy().astype(np.uint8)) for i in range(5)]
 
 
49
  #w, h = 224,224#Image.open(image_list[i][j]).size
50
  #result = result.resize((w, h), Image.BILINEAR)
51
  #result.convert('L').save('0.png')
 
8
  os.system("pip3 install collections")
9
  os.system("pip3 install torchvision")
10
  os.system("pip3 install einops")
11
+ os.system("pip3 install pydensecrf")
12
  #os.system("pip3 install argparse")
13
+ import pydensecrf.densecrf as dcrf
14
  from PIL import Image
15
  import torch
16
  from torchvision import transforms
17
  from model_video import build_model
18
  import numpy as np
19
  import collections
20
+
21
+ def crf_refine(img, annos):
22
+ def _sigmoid(x):
23
+ return 1 / (1 + np.exp(-x))
24
+
25
+ assert img.dtype == np.uint8
26
+ assert annos.dtype == np.uint8
27
+ assert img.shape[:2] == annos.shape
28
+
29
+ # img and annos should be np array with data type uint8
30
+
31
+ EPSILON = 1e-8
32
+
33
+ M = 2 # salient or not
34
+ tau = 1.05
35
+ # Setup the CRF model
36
+ d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)
37
+
38
+ anno_norm = annos / 255.
39
+
40
+ n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
41
+ p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))
42
+
43
+ U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
44
+ U[0, :] = n_energy.flatten()
45
+ U[1, :] = p_energy.flatten()
46
+
47
+ d.setUnaryEnergy(U)
48
+
49
+ d.addPairwiseGaussian(sxy=3, compat=3)
50
+ d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)
51
+
52
+ # Do the inference
53
+ infer = np.array(d.inference(1)).astype('float32')
54
+ res = infer[1, :]
55
+
56
+ res = res * 255
57
+ res = res.reshape(img.shape[:2])
58
+ return res.astype('uint8')
59
+
60
  #import argparse
61
  device='cuda:0'
62
  net = build_model(device).to(device)
63
+ #net=torch.nn.DataParallel(net)
64
  model_path = 'image_best.pth'
65
  print(model_path)
66
  weight=torch.load(model_path,map_location=torch.device(device))
 
85
  for i in range(5):
86
  group_img[i]=img_transform(Image.fromarray(img_list[i]))
87
  _,pred_mask=net(group_img)
88
+ pred_mask=(pred_mask.detach().squeeze()*255).numpy().astype(np.uint8)
89
+ pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
90
+ print(pred_mask[0].shape)
91
+ result = [Image.fromarray((torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1,1,3)).numpy()) for i in range(5)]
92
  #w, h = 224,224#Image.open(image_list[i][j]).size
93
  #result = result.resize((w, h), Image.BILINEAR)
94
  #result.convert('L').save('0.png')