MinhNH committed
Commit 48c5871 · 1 Parent(s): 3e77b0a

Initial commit

app.py ADDED
@@ -0,0 +1,176 @@
+ import gradio as gr
+ import os
+ import cv2
+ import torch
+ import numpy as np
+ import argparse
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from baseline.DRL.actor import *
+ from baseline.Renderer.stroke_gen import *
+ from baseline.Renderer.model import *
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ width = 128
+
+ actor_path = 'ckpts/actor.pkl'
+ renderer_path = 'ckpts/renderer.pkl'
+
+ divide = 4
+ canvas_cnt = divide * divide
+
+ Decoder = FCN()
+ Decoder.load_state_dict(torch.load(renderer_path, map_location=device))
+
+ def decode(x, canvas):  # b * (10 + 3)
+     # Render an action bundle of 5 strokes and composite them onto the canvas.
+     x = x.view(-1, 10 + 3)
+     stroke = 1 - Decoder(x[:, :10])
+     stroke = stroke.view(-1, width, width, 1)
+     color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
+     stroke = stroke.permute(0, 3, 1, 2)
+     color_stroke = color_stroke.permute(0, 3, 1, 2)
+     stroke = stroke.view(-1, 5, 1, width, width)
+     color_stroke = color_stroke.view(-1, 5, 3, width, width)
+     res = []
+     for i in range(5):
+         canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
+         res.append(canvas)
+     return canvas, res
+
+ def small2large(x):
+     # (d * d, width, width) -> (d * width, d * width)
+     x = x.reshape(divide, divide, width, width, -1)
+     x = np.transpose(x, (0, 2, 1, 3, 4))
+     x = x.reshape(divide * width, divide * width, -1)
+     return x
+
+ def large2small(x):
+     # (d * width, d * width) -> (d * d, width, width)
+     x = x.reshape(divide, width, divide, width, 3)
+     x = np.transpose(x, (0, 2, 1, 3, 4))
+     x = x.reshape(canvas_cnt, width, width, 3)
+     return x
+
+ def smooth(img):
+     # Average pixels along patch borders to hide seams between the d * d tiles.
+     def smooth_pix(img, tx, ty):
+         if tx == divide * width - 1 or ty == divide * width - 1 or tx == 0 or ty == 0:
+             return img
+         img[tx, ty] = (img[tx, ty] + img[tx + 1, ty] + img[tx, ty + 1] + img[tx - 1, ty] + img[tx, ty - 1] + img[tx + 1, ty - 1] + img[tx - 1, ty + 1] + img[tx - 1, ty - 1] + img[tx + 1, ty + 1]) / 9
+         return img
+
+     for p in range(divide):
+         for q in range(divide):
+             x = p * width
+             y = q * width
+             for k in range(width):
+                 img = smooth_pix(img, x + k, y + width - 1)
+                 if q != divide - 1:
+                     img = smooth_pix(img, x + k, y + width)
+             for k in range(width):
+                 img = smooth_pix(img, x + width - 1, y + k)
+                 if p != divide - 1:
+                     img = smooth_pix(img, x + width, y + k)
+     return img
+
+ def save_img(res, imgid, origin_shape, output_name, divide=False):
+     output = res.detach().cpu().numpy()  # d * d, 3, width, width
+     output = np.transpose(output, (0, 2, 3, 1))
+     if divide:
+         output = small2large(output)
+         output = smooth(output)
+     else:
+         output = output[0]
+     output = (output * 255).astype('uint8')
+     output = cv2.resize(output, origin_shape)
+     cv2.imwrite(output_name + "/" + str(imgid) + '.jpg', output)
+
+ actor = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
+ actor.load_state_dict(torch.load(actor_path, map_location=device))
+ actor = actor.to(device).eval()
+ Decoder = Decoder.to(device).eval()
+
+
+ def paint_img(img):
+     # Generator used by Gradio: yields progressively refined canvases.
+     max_step = 40
+     # imgid = 0
+     # output_name = os.path.join('output', str(len(os.listdir('output'))) if os.path.exists('output') else '0')
+     # os.makedirs(output_name, exist_ok=True)
+     # img = cv2.imread(args.img, cv2.IMREAD_COLOR)
+     origin_shape = (img.shape[1], img.shape[0])
+
+     patch_img = cv2.resize(img, (width * divide, width * divide))
+     patch_img = large2small(patch_img)
+     patch_img = np.transpose(patch_img, (0, 3, 1, 2))
+     patch_img = torch.tensor(patch_img).to(device).float() / 255.
+
+     img = cv2.resize(img, (width, width))
+     img = img.reshape(1, width, width, 3)
+     img = np.transpose(img, (0, 3, 1, 2))
+     img = torch.tensor(img).to(device).float() / 255.
+
+     T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)
+     coord = torch.zeros([1, 2, width, width])
+     for i in range(width):
+         for j in range(width):
+             coord[0, 0, i, j] = i / (width - 1.)
+             coord[0, 1, i, j] = j / (width - 1.)
+     coord = coord.to(device)  # CoordConv
+     canvas = torch.zeros([1, 3, width, width]).to(device)
+
+     with torch.no_grad():
+         if divide != 1:
+             max_step = max_step // 2
+         # Stage 1: paint the whole image at 128 x 128.
+         for i in range(max_step):
+             stepnum = T * i / max_step
+             actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
+             canvas, res = decode(actions, canvas)
+             for j in range(5):
+                 # save_img(res[j], imgid)
+                 # imgid += 1
+                 output = res[j].detach().cpu().numpy()  # d * d, 3, width, width
+                 output = np.transpose(output, (0, 2, 3, 1))
+                 output = output[0]
+                 output = (output * 255).astype('uint8')
+                 output = cv2.resize(output, origin_shape)
+                 yield output
+         # Stage 2: split the canvas into divide * divide patches and refine each.
+         if divide != 1:
+             canvas = canvas[0].detach().cpu().numpy()
+             canvas = np.transpose(canvas, (1, 2, 0))
+             canvas = cv2.resize(canvas, (width * divide, width * divide))
+             canvas = large2small(canvas)
+             canvas = np.transpose(canvas, (0, 3, 1, 2))
+             canvas = torch.tensor(canvas).to(device).float()
+             coord = coord.expand(canvas_cnt, 2, width, width)
+             T = T.expand(canvas_cnt, 1, width, width)
+             for i in range(max_step):
+                 stepnum = T * i / max_step
+                 actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
+                 canvas, res = decode(actions, canvas)
+                 # print('divided canvas step {}, L2Loss = {}'.format(i, ((canvas - patch_img) ** 2).mean()))
+                 for j in range(5):
+                     # save_img(res[j], imgid, True)
+                     # imgid += 1
+                     output = res[j].detach().cpu().numpy()  # d * d, 3, width, width
+                     output = np.transpose(output, (0, 2, 3, 1))
+                     output = small2large(output)
+                     output = smooth(output)
+                     output = (output * 255).astype('uint8')
+                     output = cv2.resize(output, origin_shape)
+                     yield output
+
+     return output
+
+ examples = [
+     ["image/chaoyue.png"],
+     ["image/degang.png"],
+     ["image/JayChou.png"],
+     ["image/Leslie.png"],
+     ["image/mayun.png"],
+ ]
+
+ demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples=examples)
+ demo.queue()
+ demo.launch(server_name="0.0.0.0")
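
Note: paint_img is a generator that yields progressively refined canvases, which is what lets the Gradio demo stream intermediate results. A minimal local sanity check, a sketch assuming the LFS checkpoints in ckpts/ have been pulled and that image/test.png exists, could temporarily replace the demo.launch() call:

    # Sketch: consume paint_img directly and keep only the last yielded frame.
    test = cv2.imread("image/test.png", cv2.IMREAD_COLOR)
    final = None
    for frame in paint_img(test):      # each frame is an H x W x 3 uint8 canvas
        final = frame
    cv2.imwrite("painted_test.png", final)
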
baseline/DRL/__pycache__/actor.cpython-310.pyc ADDED
Binary file (4.08 kB).
 
baseline/DRL/actor.py ADDED
@@ -0,0 +1,114 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.utils.weight_norm as weightNorm
+
+ from torch.autograd import Variable
+ import sys
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     return (nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False))
+
+ def cfg(depth):
+     depth_lst = [18, 34, 50, 101, 152]
+     assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152"
+     cf_dict = {
+         '18': (BasicBlock, [2,2,2,2]),
+         '34': (BasicBlock, [3,4,6,3]),
+         '50': (Bottleneck, [3,4,6,3]),
+         '101':(Bottleneck, [3,4,23,3]),
+         '152':(Bottleneck, [3,8,36,3]),
+     }
+
+     return cf_dict[str(depth)]
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_planes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             self.shortcut = nn.Sequential(
+                 (nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)),
+                 nn.BatchNorm2d(self.expansion*planes)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+
+         return out
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(Bottleneck, self).__init__()
+         self.conv1 = (nn.Conv2d(in_planes, planes, kernel_size=1, bias=False))
+         self.conv2 = (nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))
+         self.conv3 = (nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False))
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.bn3 = nn.BatchNorm2d(self.expansion*planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion*planes:
+             self.shortcut = nn.Sequential(
+                 (nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)),
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = F.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+
+         return out
+
+ class ResNet(nn.Module):
+     def __init__(self, num_inputs, depth, num_outputs):
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         block, num_blocks = cfg(depth)
+
+         self.conv1 = conv3x3(num_inputs, 64, 2)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+         self.fc = nn.Linear(512 * block.expansion, num_outputs)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         strides = [stride] + [1]*(num_blocks-1)
+         layers = []
+
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = F.relu(self.bn1(self.conv1(x)))
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = F.avg_pool2d(x, 4)
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+         x = torch.sigmoid(x)
+         return x
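
Note: a sketch of the actor's I/O contract, with shapes inferred from the call sites in app.py and ddpg.py (not part of the committed code; run from the repo root). The 9 input channels are canvas (3) + target image (3) + step number (1) + CoordConv (2), and the 65 sigmoid outputs are an action bundle of 5 strokes x 13 parameters.

    import torch
    from baseline.DRL.actor import ResNet

    actor = ResNet(9, 18, 65).eval()
    dummy = torch.rand(1, 9, 128, 128)        # canvas, target, step, coord planes
    with torch.no_grad():
        actions = actor(dummy)                # (1, 65), values squashed to [0, 1]
    strokes = actions.view(-1, 13)            # 5 strokes: 10 shape params + 3 colour values
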
baseline/DRL/critic.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.utils.weight_norm as weightNorm
+
+ from torch.autograd import Variable
+ import sys
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     return weightNorm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True))
+
+ class TReLU(nn.Module):
+     def __init__(self):
+         super(TReLU, self).__init__()
+         self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
+         self.alpha.data.fill_(0)
+
+     def forward(self, x):
+         x = F.relu(x - self.alpha) + self.alpha
+         return x
+
+ def cfg(depth):
+     depth_lst = [18, 34, 50, 101, 152]
+     assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152"
+     cf_dict = {
+         '18': (BasicBlock, [2,2,2,2]),
+         '34': (BasicBlock, [3,4,6,3]),
+         '50': (Bottleneck, [3,4,6,3]),
+         '101':(Bottleneck, [3,4,23,3]),
+         '152':(Bottleneck, [3,8,36,3]),
+     }
+
+     return cf_dict[str(depth)]
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_planes, planes, stride)
+         self.conv2 = conv3x3(planes, planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             self.shortcut = nn.Sequential(
+                 weightNorm(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True)),
+             )
+         self.relu_1 = TReLU()
+         self.relu_2 = TReLU()
+
+     def forward(self, x):
+         out = self.relu_1(self.conv1(x))
+         out = self.conv2(out)
+         out += self.shortcut(x)
+         out = self.relu_2(out)
+
+         return out
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(Bottleneck, self).__init__()
+         self.conv1 = weightNorm(nn.Conv2d(in_planes, planes, kernel_size=1, bias=True))
+         self.conv2 = weightNorm(nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True))
+         self.conv3 = weightNorm(nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True))
+         self.relu_1 = TReLU()
+         self.relu_2 = TReLU()
+         self.relu_3 = TReLU()
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion*planes:
+             self.shortcut = nn.Sequential(
+                 weightNorm(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True)),
+             )
+
+     def forward(self, x):
+         out = self.relu_1(self.conv1(x))
+         out = self.relu_2(self.conv2(out))
+         out = self.conv3(out)
+         out += self.shortcut(x)
+         out = self.relu_3(out)
+
+         return out
+
+ class ResNet_wobn(nn.Module):
+     def __init__(self, num_inputs, depth, num_outputs):
+         super(ResNet_wobn, self).__init__()
+         self.in_planes = 64
+
+         block, num_blocks = cfg(depth)
+
+         self.conv1 = conv3x3(num_inputs, 64, 2)
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+         self.fc = nn.Linear(512 * block.expansion, num_outputs)
+         self.relu_1 = TReLU()
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         strides = [stride] + [1]*(num_blocks-1)
+         layers = []
+
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.relu_1(self.conv1(x))
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = F.avg_pool2d(x, 4)
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+         return x
baseline/DRL/ddpg.py ADDED
@@ -0,0 +1,220 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim import Adam, SGD
+ from Renderer.model import *
+ from DRL.rpm import rpm
+ from DRL.actor import *
+ from DRL.critic import *
+ from DRL.wgan import *
+ from utils.util import *
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ coord = torch.zeros([1, 2, 128, 128])
+ for i in range(128):
+     for j in range(128):
+         coord[0, 0, i, j] = i / 127.
+         coord[0, 1, i, j] = j / 127.
+ coord = coord.to(device)
+
+ criterion = nn.MSELoss()
+
+ Decoder = FCN()
+ Decoder.load_state_dict(torch.load('../renderer.pkl'))
+
+ def decode(x, canvas):  # b * (10 + 3)
+     x = x.view(-1, 10 + 3)
+     stroke = 1 - Decoder(x[:, :10])
+     stroke = stroke.view(-1, 128, 128, 1)
+     color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
+     stroke = stroke.permute(0, 3, 1, 2)
+     color_stroke = color_stroke.permute(0, 3, 1, 2)
+     stroke = stroke.view(-1, 5, 1, 128, 128)
+     color_stroke = color_stroke.view(-1, 5, 3, 128, 128)
+     for i in range(5):
+         canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
+     return canvas
+
+ def cal_trans(s, t):
+     return (s.transpose(0, 3) * t).transpose(0, 3)
+
+ class DDPG(object):
+     def __init__(self, batch_size=64, env_batch=1, max_step=40, \
+                  tau=0.001, discount=0.9, rmsize=800, \
+                  writer=None, resume=None, output_path=None):
+
+         self.max_step = max_step
+         self.env_batch = env_batch
+         self.batch_size = batch_size
+
+         self.actor = ResNet(9, 18, 65)  # target, canvas, stepnum, coordconv 3 + 3 + 1 + 2
+         self.actor_target = ResNet(9, 18, 65)
+         self.critic = ResNet_wobn(3 + 9, 18, 1)  # add the last canvas for better prediction
+         self.critic_target = ResNet_wobn(3 + 9, 18, 1)
+
+         self.actor_optim = Adam(self.actor.parameters(), lr=1e-2)
+         self.critic_optim = Adam(self.critic.parameters(), lr=1e-2)
+
+         if (resume != None):
+             self.load_weights(resume)
+
+         hard_update(self.actor_target, self.actor)
+         hard_update(self.critic_target, self.critic)
+
+         # Create replay buffer
+         self.memory = rpm(rmsize * max_step)
+
+         # Hyper-parameters
+         self.tau = tau
+         self.discount = discount
+
+         # Tensorboard
+         self.writer = writer
+         self.log = 0
+
+         self.state = [None] * self.env_batch  # Most recent state
+         self.action = [None] * self.env_batch  # Most recent action
+         self.choose_device()
+
+     def play(self, state, target=False):
+         state = torch.cat((state[:, :6].float() / 255, state[:, 6:7].float() / self.max_step, coord.expand(state.shape[0], 2, 128, 128)), 1)
+         if target:
+             return self.actor_target(state)
+         else:
+             return self.actor(state)
+
+     def update_gan(self, state):
+         canvas = state[:, :3]
+         gt = state[:, 3 : 6]
+         fake, real, penal = update(canvas.float() / 255, gt.float() / 255)
+         if self.log % 20 == 0:
+             self.writer.add_scalar('train/gan_fake', fake, self.log)
+             self.writer.add_scalar('train/gan_real', real, self.log)
+             self.writer.add_scalar('train/gan_penal', penal, self.log)
+
+     def evaluate(self, state, action, target=False):
+         T = state[:, 6 : 7]
+         gt = state[:, 3 : 6].float() / 255
+         canvas0 = state[:, :3].float() / 255
+         canvas1 = decode(action, canvas0)
+         gan_reward = cal_reward(canvas1, gt) - cal_reward(canvas0, gt)
+         # L2_reward = ((canvas0 - gt) ** 2).mean(1).mean(1).mean(1) - ((canvas1 - gt) ** 2).mean(1).mean(1).mean(1)
+         coord_ = coord.expand(state.shape[0], 2, 128, 128)
+         merged_state = torch.cat([canvas0, canvas1, gt, (T + 1).float() / self.max_step, coord_], 1)
+         # canvas0 is not necessarily added
+         if target:
+             Q = self.critic_target(merged_state)
+             return (Q + gan_reward), gan_reward
+         else:
+             Q = self.critic(merged_state)
+             if self.log % 20 == 0:
+                 self.writer.add_scalar('train/expect_reward', Q.mean(), self.log)
+                 self.writer.add_scalar('train/gan_reward', gan_reward.mean(), self.log)
+             return (Q + gan_reward), gan_reward
+
+     def update_policy(self, lr):
+         self.log += 1
+
+         for param_group in self.critic_optim.param_groups:
+             param_group['lr'] = lr[0]
+         for param_group in self.actor_optim.param_groups:
+             param_group['lr'] = lr[1]
+
+         # Sample batch
+         state, action, reward, \
+             next_state, terminal = self.memory.sample_batch(self.batch_size, device)
+
+         self.update_gan(next_state)
+
+         with torch.no_grad():
+             next_action = self.play(next_state, True)
+             target_q, _ = self.evaluate(next_state, next_action, True)
+             target_q = self.discount * ((1 - terminal.float()).view(-1, 1)) * target_q
+
+         cur_q, step_reward = self.evaluate(state, action)
+         target_q += step_reward.detach()
+
+         value_loss = criterion(cur_q, target_q)
+         self.critic.zero_grad()
+         value_loss.backward(retain_graph=True)
+         self.critic_optim.step()
+
+         action = self.play(state)
+         pre_q, _ = self.evaluate(state.detach(), action)
+         policy_loss = -pre_q.mean()
+         self.actor.zero_grad()
+         policy_loss.backward(retain_graph=True)
+         self.actor_optim.step()
+
+         # Target update
+         soft_update(self.actor_target, self.actor, self.tau)
+         soft_update(self.critic_target, self.critic, self.tau)
+
+         return -policy_loss, value_loss
+
+     def observe(self, reward, state, done, step):
+         s0 = torch.tensor(self.state, device='cpu')
+         a = to_tensor(self.action, "cpu")
+         r = to_tensor(reward, "cpu")
+         s1 = torch.tensor(state, device='cpu')
+         d = to_tensor(done.astype('float32'), "cpu")
+         for i in range(self.env_batch):
+             self.memory.append([s0[i], a[i], r[i], s1[i], d[i]])
+         self.state = state
+
+     def noise_action(self, noise_factor, state, action):
+         noise = np.zeros(action.shape)
+         for i in range(self.env_batch):
+             action[i] = action[i] + np.random.normal(0, self.noise_level[i], action.shape[1:]).astype('float32')
+         return np.clip(action.astype('float32'), 0, 1)
+
+     def select_action(self, state, return_fix=False, noise_factor=0):
+         self.eval()
+         with torch.no_grad():
+             action = self.play(state)
+             action = to_numpy(action)
+         if noise_factor > 0:
+             action = self.noise_action(noise_factor, state, action)
+         self.train()
+         self.action = action
+         if return_fix:
+             return action
+         return self.action
+
+     def reset(self, obs, factor):
+         self.state = obs
+         self.noise_level = np.random.uniform(0, factor, self.env_batch)
+
+     def load_weights(self, path):
+         if path is None: return
+         self.actor.load_state_dict(torch.load('{}/actor.pkl'.format(path)))
+         self.critic.load_state_dict(torch.load('{}/critic.pkl'.format(path)))
+         load_gan(path)
+
+     def save_model(self, path):
+         self.actor.cpu()
+         self.critic.cpu()
+         torch.save(self.actor.state_dict(), '{}/actor.pkl'.format(path))
+         torch.save(self.critic.state_dict(), '{}/critic.pkl'.format(path))
+         save_gan(path)
+         self.choose_device()
+
+     def eval(self):
+         self.actor.eval()
+         self.actor_target.eval()
+         self.critic.eval()
+         self.critic_target.eval()
+
+     def train(self):
+         self.actor.train()
+         self.actor_target.train()
+         self.critic.train()
+         self.critic_target.train()
+
+     def choose_device(self):
+         Decoder.to(device)
+         self.actor.to(device)
+         self.actor_target.to(device)
+         self.critic.to(device)
+         self.critic_target.to(device)
baseline/DRL/evaluator.py ADDED
@@ -0,0 +1,31 @@
+ import numpy as np
+ from utils.util import *
+
+ class Evaluator(object):
+
+     def __init__(self, args, writer):
+         self.validate_episodes = args.validate_episodes
+         self.max_step = args.max_step
+         self.env_batch = args.env_batch
+         self.writer = writer
+         self.log = 0
+
+     def __call__(self, env, policy, debug=False):
+         observation = None
+         for episode in range(self.validate_episodes):
+             # reset at the start of episode
+             observation = env.reset(test=True, episode=episode)
+             episode_steps = 0
+             episode_reward = 0.
+             assert observation is not None
+             # start episode
+             episode_reward = np.zeros(self.env_batch)
+             while (episode_steps < self.max_step or not self.max_step):
+                 action = policy(observation)
+                 observation, reward, done, (step_num) = env.step(action)
+                 episode_reward += reward
+                 episode_steps += 1
+                 env.save_image(self.log, episode_steps)
+             dist = env.get_dist()
+             self.log += 1
+         return episode_reward, dist
baseline/DRL/multi.py ADDED
@@ -0,0 +1,53 @@
+ import cv2
+ import torch
+ import numpy as np
+ from env import Paint
+ from utils.util import *
+ from DRL.ddpg import decode
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class fastenv():
+     def __init__(self,
+                  max_episode_length=10, env_batch=64, \
+                  writer=None):
+         self.max_episode_length = max_episode_length
+         self.env_batch = env_batch
+         self.env = Paint(self.env_batch, self.max_episode_length)
+         self.env.load_data()
+         self.observation_space = self.env.observation_space
+         self.action_space = self.env.action_space
+         self.writer = writer
+         self.test = False
+         self.log = 0
+
+     def save_image(self, log, step):
+         for i in range(self.env_batch):
+             if self.env.imgid[i] <= 10:
+                 canvas = cv2.cvtColor((to_numpy(self.env.canvas[i].permute(1, 2, 0))), cv2.COLOR_BGR2RGB)
+                 self.writer.add_image('{}/canvas_{}.png'.format(str(self.env.imgid[i]), str(step)), canvas, log)
+         if step == self.max_episode_length:
+             for i in range(self.env_batch):
+                 if self.env.imgid[i] < 50:
+                     gt = cv2.cvtColor((to_numpy(self.env.gt[i].permute(1, 2, 0))), cv2.COLOR_BGR2RGB)
+                     canvas = cv2.cvtColor((to_numpy(self.env.canvas[i].permute(1, 2, 0))), cv2.COLOR_BGR2RGB)
+                     self.writer.add_image(str(self.env.imgid[i]) + '/_target.png', gt, log)
+                     self.writer.add_image(str(self.env.imgid[i]) + '/_canvas.png', canvas, log)
+
+     def step(self, action):
+         with torch.no_grad():
+             ob, r, d, _ = self.env.step(torch.tensor(action).to(device))
+         if d[0]:
+             if not self.test:
+                 self.dist = self.get_dist()
+                 for i in range(self.env_batch):
+                     self.writer.add_scalar('train/dist', self.dist[i], self.log)
+                 self.log += 1
+         return ob, r, d, _
+
+     def get_dist(self):
+         return to_numpy((((self.env.gt.float() - self.env.canvas.float()) / 255) ** 2).mean(1).mean(1).mean(1))
+
+     def reset(self, test=False, episode=0):
+         self.test = test
+         ob = self.env.reset(self.test, episode * self.env_batch)
+         return ob
baseline/DRL/rpm.py ADDED
@@ -0,0 +1,43 @@
+ # from collections import deque
+ import numpy as np
+ import random
+ import torch
+ import pickle as pickle
+
+ class rpm(object):
+     # replay memory
+     def __init__(self, buffer_size):
+         self.buffer_size = buffer_size
+         self.buffer = []
+         self.index = 0
+
+     def append(self, obj):
+         if self.size() > self.buffer_size:
+             print('buffer size larger than set value, trimming...')
+             self.buffer = self.buffer[(self.size() - self.buffer_size):]
+         elif self.size() == self.buffer_size:
+             self.buffer[self.index] = obj
+             self.index += 1
+             self.index %= self.buffer_size
+         else:
+             self.buffer.append(obj)
+
+     def size(self):
+         return len(self.buffer)
+
+     def sample_batch(self, batch_size, device, only_state=False):
+         if self.size() < batch_size:
+             batch = random.sample(self.buffer, self.size())
+         else:
+             batch = random.sample(self.buffer, batch_size)
+
+         if only_state:
+             res = torch.stack(tuple(item[3] for item in batch), dim=0)
+             return res.to(device)
+         else:
+             item_count = 5
+             res = []
+             for i in range(5):
+                 k = torch.stack(tuple(item[i] for item in batch), dim=0)
+                 res.append(k.to(device))
+             return res[0], res[1], res[2], res[3], res[4]
baseline/DRL/wgan.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from torch.optim import Adam, SGD
+ from torch import autograd
+ from torch.autograd import Variable
+ import torch.nn.functional as F
+ from torch.autograd import grad as torch_grad
+ import torch.nn.utils.weight_norm as weightNorm
+ from utils.util import *
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dim = 128
+ LAMBDA = 10  # Gradient penalty lambda hyperparameter
+
+ class TReLU(nn.Module):
+     def __init__(self):
+         super(TReLU, self).__init__()
+         self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
+         self.alpha.data.fill_(0)
+
+     def forward(self, x):
+         x = F.relu(x - self.alpha) + self.alpha
+         return x
+
+ class Discriminator(nn.Module):
+     def __init__(self):
+         super(Discriminator, self).__init__()
+
+         self.conv0 = weightNorm(nn.Conv2d(6, 16, 5, 2, 2))
+         self.conv1 = weightNorm(nn.Conv2d(16, 32, 5, 2, 2))
+         self.conv2 = weightNorm(nn.Conv2d(32, 64, 5, 2, 2))
+         self.conv3 = weightNorm(nn.Conv2d(64, 128, 5, 2, 2))
+         self.conv4 = weightNorm(nn.Conv2d(128, 1, 5, 2, 2))
+         self.relu0 = TReLU()
+         self.relu1 = TReLU()
+         self.relu2 = TReLU()
+         self.relu3 = TReLU()
+
+     def forward(self, x):
+         x = self.conv0(x)
+         x = self.relu0(x)
+         x = self.conv1(x)
+         x = self.relu1(x)
+         x = self.conv2(x)
+         x = self.relu2(x)
+         x = self.conv3(x)
+         x = self.relu3(x)
+         x = self.conv4(x)
+         x = F.avg_pool2d(x, 4)
+         x = x.view(-1, 1)
+         return x
+
+ netD = Discriminator()
+ target_netD = Discriminator()
+ netD = netD.to(device)
+ target_netD = target_netD.to(device)
+ hard_update(target_netD, netD)
+
+ optimizerD = Adam(netD.parameters(), lr=3e-4, betas=(0.5, 0.999))
+
+ def cal_gradient_penalty(netD, real_data, fake_data, batch_size):
+     alpha = torch.rand(batch_size, 1)
+     alpha = alpha.expand(batch_size, int(real_data.nelement()/batch_size)).contiguous()
+     alpha = alpha.view(batch_size, 6, dim, dim)
+     alpha = alpha.to(device)
+     fake_data = fake_data.view(batch_size, 6, dim, dim)
+     interpolates = Variable(alpha * real_data.data + ((1 - alpha) * fake_data.data), requires_grad=True)
+     disc_interpolates = netD(interpolates)
+     gradients = autograd.grad(disc_interpolates, interpolates,
+                               grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                               create_graph=True, retain_graph=True)[0]
+     gradients = gradients.view(gradients.size(0), -1)
+     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
+     return gradient_penalty
+
+ def cal_reward(fake_data, real_data):
+     return target_netD(torch.cat([real_data, fake_data], 1))
+
+ def save_gan(path):
+     netD.cpu()
+     torch.save(netD.state_dict(), '{}/wgan.pkl'.format(path))
+     netD.to(device)
+
+ def load_gan(path):
+     netD.load_state_dict(torch.load('{}/wgan.pkl'.format(path)))
+
+ def update(fake_data, real_data):
+     fake_data = fake_data.detach()
+     real_data = real_data.detach()
+     fake = torch.cat([real_data, fake_data], 1)
+     real = torch.cat([real_data, real_data], 1)
+     D_real = netD(real)
+     D_fake = netD(fake)
+     gradient_penalty = cal_gradient_penalty(netD, real, fake, real.shape[0])
+     optimizerD.zero_grad()
+     D_cost = D_fake.mean() - D_real.mean() + gradient_penalty
+     D_cost.backward()
+     optimizerD.step()
+     soft_update(target_netD, netD, 0.001)
+     return D_fake.mean(), D_real.mean(), gradient_penalty
baseline/Renderer/__init__.py ADDED
File without changes
baseline/Renderer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (181 Bytes).
 
baseline/Renderer/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.52 kB).
 
baseline/Renderer/__pycache__/stroke_gen.cpython-310.pyc ADDED
Binary file (1.12 kB).
 
baseline/Renderer/model.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.utils.weight_norm as weightNorm
+
+ class FCN(nn.Module):
+     def __init__(self):
+         super(FCN, self).__init__()
+         self.fc1 = (nn.Linear(10, 512))
+         self.fc2 = (nn.Linear(512, 1024))
+         self.fc3 = (nn.Linear(1024, 2048))
+         self.fc4 = (nn.Linear(2048, 4096))
+         self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1))
+         self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1))
+         self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1))
+         self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1))
+         self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1))
+         self.conv6 = (nn.Conv2d(8, 4, 3, 1, 1))
+         self.pixel_shuffle = nn.PixelShuffle(2)
+
+     def forward(self, x):
+         x = F.relu(self.fc1(x))
+         x = F.relu(self.fc2(x))
+         x = F.relu(self.fc3(x))
+         x = F.relu(self.fc4(x))
+         x = x.view(-1, 16, 16, 16)
+         x = F.relu(self.conv1(x))
+         x = self.pixel_shuffle(self.conv2(x))
+         x = F.relu(self.conv3(x))
+         x = self.pixel_shuffle(self.conv4(x))
+         x = F.relu(self.conv5(x))
+         x = self.pixel_shuffle(self.conv6(x))
+         x = torch.sigmoid(x)
+         return 1 - x.view(-1, 128, 128)
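
Note: a sketch of driving the neural renderer on its own (assumes ckpts/renderer.pkl is available locally; not part of the committed code). FCN maps a 10-dim stroke parameter vector to a 128 x 128 map; app.py's decode() takes 1 minus this output as the stroke alpha and multiplies in the last 3 action values as colour.

    import torch
    from baseline.Renderer.model import FCN

    renderer = FCN()
    renderer.load_state_dict(torch.load('ckpts/renderer.pkl', map_location='cpu'))
    renderer.eval()
    with torch.no_grad():
        params = torch.rand(1, 10)        # one stroke's shape parameters
        alpha = 1 - renderer(params)      # (1, 128, 128), as used in decode()
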
baseline/Renderer/stroke_gen.py ADDED
@@ -0,0 +1,28 @@
+ import cv2
+ import numpy as np
+
+ def normal(x, width):
+     return (int)(x * (width - 1) + 0.5)
+
+ def draw(f, width=128):
+     x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f
+     x1 = x0 + (x2 - x0) * x1
+     y1 = y0 + (y2 - y0) * y1
+     x0 = normal(x0, width * 2)
+     x1 = normal(x1, width * 2)
+     x2 = normal(x2, width * 2)
+     y0 = normal(y0, width * 2)
+     y1 = normal(y1, width * 2)
+     y2 = normal(y2, width * 2)
+     z0 = (int)(1 + z0 * width // 2)
+     z2 = (int)(1 + z2 * width // 2)
+     canvas = np.zeros([width * 2, width * 2]).astype('float32')
+     tmp = 1. / 100
+     for i in range(100):
+         t = i * tmp
+         x = (int)((1-t) * (1-t) * x0 + 2 * t * (1-t) * x1 + t * t * x2)
+         y = (int)((1-t) * (1-t) * y0 + 2 * t * (1-t) * y1 + t * t * y2)
+         z = (int)((1-t) * z0 + t * z2)
+         w = (1-t) * w0 + t * w2
+         cv2.circle(canvas, (y, x), z, w, -1)
+     return 1 - cv2.resize(canvas, dsize=(width, width))
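
Note: draw() is the ground-truth rasteriser the neural renderer is trained to imitate; f packs ten values in [0, 1] describing a quadratic Bezier stroke (endpoints, control point, radii and opacities at the two ends). A quick sketch, assuming only numpy and OpenCV and the repo root on the path:

    import numpy as np
    import cv2
    from baseline.Renderer.stroke_gen import draw

    f = np.random.uniform(0, 1, 10).astype('float32')   # x0,y0,x1,y1,x2,y2,z0,z2,w0,w2
    stroke = draw(f, width=128)                          # float32 in [0, 1], background = 1
    cv2.imwrite("stroke.png", (stroke * 255).astype('uint8'))
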
baseline/env.py ADDED
@@ -0,0 +1,107 @@
+ import sys
+ import json
+ import torch
+ import numpy as np
+ import argparse
+ import torchvision.transforms as transforms
+ import cv2
+ from DRL.ddpg import decode
+ from utils.util import *
+ from PIL import Image
+ from torchvision import transforms, utils
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ aug = transforms.Compose(
+     [transforms.ToPILImage(),
+      transforms.RandomHorizontalFlip(),
+      ])
+
+ width = 128
+ convas_area = width * width
+
+ img_train = []
+ img_test = []
+ train_num = 0
+ test_num = 0
+
+ class Paint:
+     def __init__(self, batch_size, max_step):
+         self.batch_size = batch_size
+         self.max_step = max_step
+         self.action_space = (13)
+         self.observation_space = (self.batch_size, width, width, 7)
+         self.test = False
+
+     def load_data(self):
+         # CelebA
+         global train_num, test_num
+         for i in range(200000):
+             img_id = '%06d' % (i + 1)
+             try:
+                 img = cv2.imread('./data/img_align_celeba/' + img_id + '.jpg', cv2.IMREAD_UNCHANGED)
+                 img = cv2.resize(img, (width, width))
+                 if i > 2000:
+                     train_num += 1
+                     img_train.append(img)
+                 else:
+                     test_num += 1
+                     img_test.append(img)
+             finally:
+                 if (i + 1) % 10000 == 0:
+                     print('loaded {} images'.format(i + 1))
+         print('finish loading data, {} training images, {} testing images'.format(str(train_num), str(test_num)))
+
+     def pre_data(self, id, test):
+         if test:
+             img = img_test[id]
+         else:
+             img = img_train[id]
+         if not test:
+             img = aug(img)
+         img = np.asarray(img)
+         return np.transpose(img, (2, 0, 1))
+
+     def reset(self, test=False, begin_num=False):
+         self.test = test
+         self.imgid = [0] * self.batch_size
+         self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)
+         for i in range(self.batch_size):
+             if test:
+                 id = (i + begin_num) % test_num
+             else:
+                 id = np.random.randint(train_num)
+             self.imgid[i] = id
+             self.gt[i] = torch.tensor(self.pre_data(id, test))
+         self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
+         self.stepnum = 0
+         self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)
+         self.lastdis = self.ini_dis = self.cal_dis()
+         return self.observation()
+
+     def observation(self):
+         # canvas B * 3 * width * width
+         # gt B * 3 * width * width
+         # T B * 1 * width * width
+         ob = []
+         T = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8) * self.stepnum
+         return torch.cat((self.canvas, self.gt, T.to(device)), 1)  # canvas, img, T
+
+     def cal_trans(self, s, t):
+         return (s.transpose(0, 3) * t).transpose(0, 3)
+
+     def step(self, action):
+         self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()
+         self.stepnum += 1
+         ob = self.observation()
+         done = (self.stepnum == self.max_step)
+         reward = self.cal_reward()  # np.array([0.] * self.batch_size)
+         return ob.detach(), reward, np.array([done] * self.batch_size), None
+
+     def cal_dis(self):
+         return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)
+
+     def cal_reward(self):
+         dis = self.cal_dis()
+         reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
+         self.lastdis = dis
+         return to_numpy(reward)
baseline/utils/tensorboard.py ADDED
@@ -0,0 +1,29 @@
+ import PIL
+ import scipy.misc
+ from io import BytesIO
+ import tensorboardX as tb
+ from tensorboardX.summary import Summary
+
+ class TensorBoard(object):
+     def __init__(self, model_dir):
+         self.summary_writer = tb.FileWriter(model_dir)
+
+     def add_image(self, tag, img, step):
+         summary = Summary()
+         bio = BytesIO()
+
+         if type(img) == str:
+             img = PIL.Image.open(img)
+         elif type(img) == PIL.Image.Image:
+             pass
+         else:
+             img = scipy.misc.toimage(img)
+
+         img.save(bio, format="png")
+         image_summary = Summary.Image(encoded_image_string=bio.getvalue())
+         summary.value.add(tag=tag, image=image_summary)
+         self.summary_writer.add_summary(summary, global_step=step)
+
+     def add_scalar(self, tag, value, step):
+         summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
+         self.summary_writer.add_summary(summary, global_step=step)
baseline/utils/util.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import torch
+ from torch.autograd import Variable
+
+ USE_CUDA = torch.cuda.is_available()
+
+ def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
+ def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
+ def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
+ def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
+ def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
+ def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
+ def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
+ def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))
+
+ def to_numpy(var):
+     return var.cpu().data.numpy() if USE_CUDA else var.data.numpy()
+
+ def to_tensor(ndarray, device):
+     return torch.tensor(ndarray, dtype=torch.float, device=device)
+
+ def soft_update(target, source, tau):
+     for target_param, param in zip(target.parameters(), source.parameters()):
+         target_param.data.copy_(
+             target_param.data * (1.0 - tau) + param.data * tau
+         )
+
+ def hard_update(target, source):
+     for m1, m2 in zip(target.modules(), source.modules()):
+         m1._buffers = m2._buffers.copy()
+     for target_param, param in zip(target.parameters(), source.parameters()):
+         target_param.data.copy_(param.data)
+
+ def get_output_folder(parent_dir, env_name):
+     """Return save folder.
+
+     Assumes folders in the parent_dir have suffix -run{run
+     number}. Finds the highest run number and sets the output folder
+     to that number + 1. This is just convenient so that if you run the
+     same script multiple times tensorboard can plot all of the results
+     on the same plots with different names.
+
+     Parameters
+     ----------
+     parent_dir: str
+         Path of the directory containing all experiment runs.
+
+     Returns
+     -------
+     parent_dir/run_dir
+         Path to this run's save directory.
+     """
+     os.makedirs(parent_dir, exist_ok=True)
+     experiment_id = 0
+     for folder_name in os.listdir(parent_dir):
+         if not os.path.isdir(os.path.join(parent_dir, folder_name)):
+             continue
+         try:
+             folder_name = int(folder_name.split('-run')[-1])
+             if folder_name > experiment_id:
+                 experiment_id = folder_name
+         except:
+             pass
+     experiment_id += 1
+
+     parent_dir = os.path.join(parent_dir, env_name)
+     parent_dir = parent_dir + '-run{}'.format(experiment_id)
+     os.makedirs(parent_dir, exist_ok=True)
+     return parent_dir
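
Note: hard_update and soft_update implement the target-network copies used by ddpg.py and wgan.py; a tiny illustration with two identically shaped modules (hypothetical example, run from the repo root, not part of the committed code):

    import torch.nn as nn
    from baseline.utils.util import hard_update, soft_update

    src, tgt = nn.Linear(4, 4), nn.Linear(4, 4)
    hard_update(tgt, src)            # copy: tgt <- src
    soft_update(tgt, src, 0.001)     # Polyak averaging: tgt <- 0.999 * tgt + 0.001 * src
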
ckpts/actor.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75e908a42d4c90ad092892a9183bd467453696d8e2202b9640f9a5b4a488eab9
+ size 44898539
ckpts/renderer.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34639d4ab4f30807fb056bf4b71de0c3f434d6cbe8600cf97b6af3a855a5ca8e
+ size 44165821
image/JayChou.png ADDED
image/Leslie.png ADDED
image/Trump.png ADDED
image/chaoyue.png ADDED
image/degang.png ADDED
image/lisa.png ADDED
image/mayun.png ADDED
image/test.png ADDED