File size: 3,170 Bytes
9ff1108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import paddle
import paddle.nn as nn
import math

class Painter(nn.Layer):
    """
    Stroke-prediction network written in paddle.

    Two identical convolutional encoders embed the target image and the
    current canvas. Their feature maps are concatenated along the channel
    axis, projected to ``hidden_dim`` with a 1x1 convolution, combined with a
    learned 2-D positional embedding and fed to a transformer as the source
    sequence. A learned set of ``total_strokes`` query vectors is the target
    sequence. Every decoder output is mapped to ``param_per_stroke`` values
    and to one extra scalar by two separate heads.
    """

    def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3,
                 max_feat_size=8):
        """
        Args:
            param_per_stroke: size of the last axis of the ``param`` output.
            total_strokes: number of learned queries, i.e. predictions per sample.
            hidden_dim: transformer model width. Must be even, because the
                positional embedding is the concatenation of two halves of
                width ``hidden_dim // 2``.
            n_heads: number of attention heads passed to ``nn.Transformer``.
            n_enc_layers: number of transformer encoder layers.
            n_dec_layers: number of transformer decoder layers.
            max_feat_size: number of rows in the row/column positional-embedding
                tables, i.e. the largest encoder feature-map height/width that
                ``forward`` accepts. The encoders contain two stride-2
                convolutions, so the default of 8 fits a 32x32 input.
        """
        super().__init__()
        # Creation order (enc_img, enc_canvas, conv, transformer, heads,
        # embeddings) is kept stable on purpose: it determines the order in
        # which parameters are initialised.
        self.enc_img = self._make_encoder()
        self.enc_canvas = self._make_encoder()
        # Fuses the channel-wise concatenation of both encoders (2 * 128 channels).
        self.conv = nn.Conv2D(128 * 2, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
        self.linear_param = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_per_stroke))
        self.linear_decider = nn.Linear(hidden_dim, 1)
        self.query_pos = paddle.static.create_parameter([total_strokes, hidden_dim], dtype='float32',
                                                        default_initializer=nn.initializer.Uniform(0, 1))
        self.row_embed = paddle.static.create_parameter([max_feat_size, hidden_dim // 2], dtype='float32',
                                                        default_initializer=nn.initializer.Uniform(0, 1))
        self.col_embed = paddle.static.create_parameter([max_feat_size, hidden_dim // 2], dtype='float32',
                                                        default_initializer=nn.initializer.Uniform(0, 1))

    @staticmethod
    def _make_encoder():
        """Build one 3 -> 32 -> 64 -> 128 channel encoder (reflect-pad, 3x3 conv, batch norm, ReLU per stage)."""
        layers = []
        in_channels = 3
        for out_channels, stride in ((32, 1), (64, 2), (128, 2)):
            layers += [
                nn.Pad2D([1, 1, 1, 1], 'reflect'),
                nn.Conv2D(in_channels, out_channels, 3, stride),
                nn.BatchNorm2D(out_channels),
                nn.ReLU(),
            ]
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, img, canvas):
        """
        Predict stroke outputs for a batch of (image, canvas) pairs.

        Args:
            img: 4-D tensor with 3 channels (batch, 3, H, W).
            canvas: tensor that yields an encoder feature map of the same
                shape as ``img`` does (they are concatenated on axis 1).

        Returns:
            Tuple ``(param, decision)``:
            ``param`` has shape (batch, total_strokes, param_per_stroke) and
            ``decision`` has shape (batch, total_strokes, 1). No activation is
            applied to either output here.

        Raises:
            ValueError: if the encoder feature map is larger than the
                positional-embedding tables.
        """
        batch = img.shape[0]
        img_feat = self.enc_img(img)
        canvas_feat = self.enc_canvas(canvas)
        h, w = img_feat.shape[-2:]

        max_h = self.row_embed.shape[0]
        max_w = self.col_embed.shape[0]
        if h > max_h or w > max_w:
            # Without this check the slices below silently truncate and the
            # concat fails later with an opaque shape-mismatch error.
            raise ValueError(
                'encoder feature map is {}x{}, but positional embeddings only cover {}x{}; '
                'use a smaller input or a larger max_feat_size'.format(h, w, max_h, max_w))

        feat_conv = self.conv(paddle.concat([img_feat, canvas_feat], axis=1))

        # (h, w, hidden_dim): column half first, row half second; then
        # flattened to (h * w, 1, hidden_dim) so it broadcasts over the batch.
        pos_embed = paddle.concat([
            self.col_embed[:w].unsqueeze(0).tile([h, 1, 1]),
            self.row_embed[:h].unsqueeze(1).tile([1, w, 1]),
        ], axis=-1).flatten(0, 1).unsqueeze(1)

        # (batch, hidden_dim, h, w) -> (h * w, batch, hidden_dim), add the
        # positions, then switch to the batch-first layout.
        src = (pos_embed + feat_conv.flatten(2).transpose([2, 0, 1])).transpose([1, 0, 2])
        # The same learned queries are repeated for every sample: (batch, total_strokes, hidden_dim).
        tgt = self.query_pos.unsqueeze(1).tile([1, batch, 1]).transpose([1, 0, 2])
        hidden_state = self.transformer(src, tgt)

        param = self.linear_param(hidden_state)
        decision = self.linear_decider(hidden_state)
        return param, decision