File size: 2,666 Bytes
51da11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch

class StyleTransferController(torch.nn.Module):
    """Map a pair of input/reference embeddings to normalized plugin parameters.

    The two embeddings are first aggregated (by a 1-D convolution, a linear
    projection, or plain concatenation) and then passed through a 3-layer MLP
    whose final sigmoid keeps every predicted parameter in [0, 1].
    """

    def __init__(
        self,
        num_control_params,
        edim,
        hidden_dim=256,
        agg_method="mlp",
    ):
        """Plugin parameter controller module to map from input to target style.

        Args:
            num_control_params (int): Number of plugin parameters to predict.
            edim (int): Size of the encoder representations.
            hidden_dim (int, optional): Hidden size of the 3-layer parameter predictor MLP. Default: 256
            agg_method (str, optional): Input/reference embed aggregation method
                ("conv", "linear" or "mlp"). Default: "mlp"

        Raises:
            ValueError: If ``agg_method`` is not one of the supported methods.
        """
        super().__init__()
        self.num_control_params = num_control_params
        self.edim = edim
        self.hidden_dim = hidden_dim
        self.agg_method = agg_method

        if agg_method == "conv":
            # Treat the two embeddings as two channels and mix them down to one.
            self.agg = torch.nn.Conv1d(
                2,
                1,
                kernel_size=129,
                stride=1,
                padding="same",
                bias=False,
            )
            mlp_in_dim = edim
        elif agg_method == "linear":
            # Project the concatenated embeddings back down to edim features.
            self.agg = torch.nn.Linear(edim * 2, edim)
            # The projection outputs edim features, so that is what the MLP sees.
            # (Previously left unset, which made this mode fail at construction.)
            mlp_in_dim = edim
        elif agg_method == "mlp":
            # No learnable aggregation: the MLP consumes the raw concatenation.
            self.agg = None
            mlp_in_dim = edim * 2
        else:
            raise ValueError(f"Invalid agg_method = {self.agg_method}.")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(mlp_in_dim, hidden_dim),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Linear(hidden_dim, num_control_params),
            torch.nn.Sigmoid(),  # normalize between 0 and 1
        )

    def forward(self, e_x, e_y, z=None):
        """Forward pass to generate plugin parameters.

        Args:
            e_x (tensor): Input signal embedding of shape (batch, edim)
            e_y (tensor): Target signal embedding of shape (batch, edim)
            z (optional): Unused by this module; accepted and ignored.
        Returns:
            p (tensor): Estimated control parameters of shape (batch, num_control_params)
        """

        # use learnable projection
        if self.agg_method == "conv":
            e_xy = torch.stack((e_x, e_y), dim=1)  # stack on a new channel dim
            e_xy = self.agg(e_xy)
        elif self.agg_method == "linear":
            e_xy = torch.cat((e_x, e_y), dim=-1)  # concat on embed dim
            e_xy = self.agg(e_xy)
        else:
            e_xy = torch.cat((e_x, e_y), dim=-1)  # concat on embed dim

        # pass through MLP to project to control parameters;
        # squeeze(1) drops the single channel dim left by the conv aggregation
        # and only affects the other paths when dim 1 happens to have size 1
        p = self.mlp(e_xy.squeeze(1))

        return p