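# Gradio demo for group (co-saliency) segmentation.
# Five images are uploaded together and run through the network as one group;
# each predicted mask is refined with a dense CRF (pydensecrf) and stitched
# with its input into a single output image for display.
# The script expects the checkpoint 'image_best.pth' and the model definition
# in model_video.py to sit next to it.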
import os

# Install runtime dependencies at startup so the demo also works in a bare environment.
os.system("pip3 install torch")
os.system("pip3 install torchvision")
os.system("pip3 install einops")
os.system("pip3 install pydensecrf")

import collections

import numpy as np
import gradio as gr
import pydensecrf.densecrf as dcrf
import torch
from PIL import Image
from torchvision import transforms

from model_video import build_model

def crf_refine(img, annos):
    """Refine a soft saliency map with a dense CRF.

    `img` is an HxWx3 uint8 image and `annos` an HxW uint8 saliency map;
    the refined map is returned as an HxW uint8 array.
    """
    def _sigmoid(x):
        return 1 / (1 + np.exp(-x))

    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    EPSILON = 1e-8
    M = 2  # two labels: salient or not
    tau = 1.05

    # Set up the CRF model over all pixels.
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)

    # Unary energies derived from the normalized saliency prediction.
    anno_norm = annos / 255.
    n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()
    d.setUnaryEnergy(U)

    # Pairwise terms: spatial smoothness plus color-dependent (bilateral) consistency.
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # Run one round of mean-field inference and keep the "salient" channel.
    infer = np.array(d.inference(1)).astype('float32')
    res = infer[1, :]

    res = res * 255
    res = res.reshape(img.shape[:2])
    return res.astype('uint8')

# Build the model and load the checkpoint on CPU.
device = 'cpu'
net = build_model(device).to(device)

model_path = 'image_best.pth'
print(model_path)
weight = torch.load(model_path, map_location=torch.device(device))

# The checkpoint was saved from a DataParallel model, so strip the
# 'module.' prefix from every key before loading the weights.
new_dict = collections.OrderedDict()
for k in weight.keys():
    new_dict[k[len('module.'):]] = weight[k]
net.load_state_dict(new_dict)
net.eval()
net = net.to(device)
def test(gpu_id, net, img_list, group_size, img_size):
    """Run the network on a group of images and return, for each image,
    the normalized input stacked on top of its CRF-refined saliency mask."""
    img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    with torch.no_grad():
        # Preprocess the group into a single (group_size, 3, img_size, img_size) batch.
        group_img = torch.stack([img_transform(Image.fromarray(img)) for img in img_list])

        # Pass a copy to the network so the batch kept for display stays untouched.
        _, pred_mask = net(group_img.clone())
        pred_mask = pred_mask.detach().squeeze() * 255

        # Undo the normalization per image so the inputs can be shown next to the masks.
        img_resize = [((group_img[i] - group_img[i].min()) / (group_img[i].max() - group_img[i].min()) * 255)
                      .permute(1, 2, 0).contiguous().numpy().astype(np.uint8)
                      for i in range(group_size)]

        # Refine each predicted mask with the dense CRF.
        pred_mask = [crf_refine(img_resize[i], pred_mask[i].numpy().astype(np.uint8)) for i in range(group_size)]

        # Stack input and mask vertically, separated by a thin white strip.
        white = torch.full((2, pred_mask[0].shape[1], 3), 255, dtype=torch.uint8)
        result = [torch.cat([torch.from_numpy(img_resize[i]),
                             white,
                             torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1, 1, 3)],
                            dim=0).numpy()
                  for i in range(group_size)]
        return result
# Smoke test: run the pipeline once on five random images before launching the demo.
img_lst = [(torch.rand(352, 352, 3) * 255).numpy().astype(np.uint8) for _ in range(5)]
res = test('cpu', net, img_lst, 5, 224)
def sepia(img1, img2, img3, img4, img5):
    # Gradio callback (the name is a leftover from the sepia-filter template):
    # run co-saliency segmentation on the five uploaded images and tile the
    # results horizontally, separated by thin white strips.
    img_list = [img1, img2, img3, img4, img5]
    result_list = test(device, net, img_list, 5, 224)
    img1, img2, img3, img4, img5 = result_list
    white = (torch.ones(img1.shape[0], 2, 3) * 255).numpy().astype(np.uint8)
    return np.concatenate([img1, white, img2, white, img3, white, img4, white, img5], axis=1)

# Five image inputs, one stitched output image.
demo = gr.Interface(sepia, inputs=["image","image","image","image","image"], outputs=["image"])
demo.launch(debug=True)