import torch
import numpy as np
import torch.multiprocessing as mp

from deepafx_st.processors.dsp.peq import ParametricEQ
from deepafx_st.processors.dsp.compressor import Compressor
from deepafx_st.processors.spsa.spsa_func import SPSAFunction
from deepafx_st.utils import rademacher


def dsp_func(x, p, dsp, sample_rate=24000):
    """Apply the parametric EQ followed by the compressor to a signal.

    Args:
        x: Input audio.
        p: Concatenated control parameters: [peq params, comp params].
        dsp: Tuple of ((peq, comp), num_peq_params).
        sample_rate (int, optional): Sample rate of the audio.
    """
    (peq, comp), meta = dsp

    # split the parameter vector between the two processors
    p_peq = p[:meta]
    p_comp = p[meta:]

    y = peq(x, p_peq, sample_rate)
    y = comp(y, p_comp, sample_rate)

    return y
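

# A minimal sketch (not part of the original module) illustrating the
# two-sided SPSA estimator used in SPSAChannel.static_backward below,
# applied to a toy objective with a known gradient. It reuses the same
# rademacher helper imported above.
def _spsa_toy_gradient(epsilon=0.001):
    """Estimate the gradient of J(p) = sum((p - 0.5)**2) via SPSA.

    The true gradient is 2 * (p - 0.5). SPSA perturbs all parameters at
    once along a random +/-1 direction; the per-parameter estimate
    (J+ - J-) / (2 * epsilon * delta_i) matches it in expectation.
    """
    p = np.random.rand(4)  # stand-in parameter vector in [0, 1]
    delta = rademacher(p.shape).numpy()

    # evaluate the objective at the two perturbed points
    J_plus = np.sum((p + epsilon * delta - 0.5) ** 2)
    J_minus = np.sum((p - epsilon * delta - 0.5) ** 2)

    g_hat = (J_plus - J_minus) / (2.0 * epsilon * delta)
    return g_hat, 2.0 * (p - 0.5)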


class SPSAChannel(torch.nn.Module):
    """

    Args:
        sample_rate (float): Sample rate of the plugin instance
        parallel (bool, optional): Use parallel workers for DSP.

    By default, this utilizes parallelized instances of the plugin channel,
    where the number of workers is equal to the batch size.
    """

    def __init__(
        self,
        sample_rate: int,
        parallel: bool = False,
        batch_size: int = 8,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.parallel = parallel

        if self.parallel:
            self.apply_func = SPSAFunction.apply

            # spawn one DSP worker process per batch item
            procs = {}
            for b in range(self.batch_size):
                peq = ParametricEQ(sample_rate)
                comp = Compressor(sample_rate)
                dsp = ((peq, comp), peq.num_control_params)

                parent_conn, child_conn = mp.Pipe()
                p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp))
                p.start()
                procs[b] = [p, parent_conn, child_conn]

            # expose the public members describing the DSP chain; these
            # are identical across workers, so the last instances suffice
            self.ports = [peq.ports, comp.ports]
            self.num_control_params = (
                comp.num_control_params + peq.num_control_params
            )

            self.procs = procs

        else:
            self.peq = ParametricEQ(sample_rate)
            self.comp = Compressor(sample_rate)
            self.apply_func = SPSAFunction.apply
            self.ports = [self.peq.ports, self.comp.ports]
            self.num_control_params = (
                self.comp.num_control_params + self.peq.num_control_params
            )
            self.dsp = ((self.peq, self.comp), self.peq.num_control_params)

        # add one param for wet/dry mix
        # self.num_control_params += 1

    def __del__(self):
        if hasattr(self, "procs"):
            for proc_idx, proc in self.procs.items():
                # stop the worker process for this batch item
                proc[0].terminate()
                proc[0].join()

    def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs):
        """
        Args:
            x (Tensor): Input signal with shape: [batch x channels x samples]
            p (Tensor): Audio effect control parameters with shape: [batch x parameters]
            epsilon (float, optional): Perturbation size for SPSA gradient estimation.
            sample_rate (int, optional): Sample rate of the input audio.

        Returns:
            y (Tensor): Processed audio signal.

        """
        # both the parallel and serial paths dispatch through the custom
        # autograd function; the DSP itself runs on CPU in NumPy
        y = self.apply_func(x, p, None, epsilon, self, sample_rate)

        return y.type_as(x)

    @staticmethod
    def static_backward(dsp, value):
        """Estimate input and parameter gradients for one batch item via SPSA."""
        (
            batch_index,
            x,
            params,
            needs_input_grad,
            needs_param_grad,
            grad_output,
            epsilon,
        ) = value

        grads_input = None
        grads_params = None
        ps = params.shape[-1]
        factors = [1.0]  # epsilon scales; a single run by default

        # estimate gradient w.r.t. input with a two-sided perturbation
        if needs_input_grad:
            delta_k = rademacher(x.shape).numpy()
            J_plus = dsp_func(x + epsilon * delta_k, params, dsp)
            J_minus = dsp_func(x - epsilon * delta_k, params, dsp)
            grads_input = (J_plus - J_minus) / (2.0 * epsilon)

        # estimate gradient w.r.t. params
        grads_params_runs = []
        if needs_param_grad:
            for factor in factors:
                params_sublist = []
                delta_k = rademacher(params.shape).numpy()

                # compute output in two random directions of the parameter space,
                # keeping the perturbed parameters inside the normalized range
                params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1)
                J_plus = dsp_func(x, params_plus, dsp)

                params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1)
                J_minus = dsp_func(x, params_minus, dsp)
                grad_param = J_plus - J_minus

                # per-parameter estimate (J+ - J-) / (2 * epsilon * delta_i),
                # reduced against the incoming gradient via the chain rule
                for sub_p_idx in range(ps):
                    grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx])
                    params_sublist.append(np.sum(grad_output * grad_p))

                grads_params = np.array(params_sublist)
                grads_params_runs.append(grads_params)

            # average gradients across runs
            grads_params = np.mean(grads_params_runs, axis=0)

        return grads_input, grads_params

    @staticmethod
    def static_forward(dsp, value):
        """Run the DSP chain forward for one batch item."""
        batch_index, x, p, sample_rate = value
        y = dsp_func(x, p, dsp, sample_rate)
        return y

    @staticmethod
    def worker_pipe(child_conn, dsp):
        """Serve forward/backward requests over a pipe until shutdown."""
        while True:
            msg, value = child_conn.recv()
            if msg == "forward":
                child_conn.send(SPSAChannel.static_forward(dsp, value))
            elif msg == "backward":
                child_conn.send(SPSAChannel.static_backward(dsp, value))
            elif msg == "shutdown":
                break
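

# Usage sketch (not part of the original module). Assumes mono audio at
# 24 kHz and control parameters normalized to [0, 1]; gradients w.r.t. p
# arrive through the SPSA estimate rather than traced autograd ops.
if __name__ == "__main__":
    channel = SPSAChannel(sample_rate=24000, parallel=False)

    x = torch.randn(1, 1, 24000)  # [batch, channels, samples]
    p = torch.rand(1, channel.num_control_params, requires_grad=True)

    y = channel(x, p)  # forward through the PEQ + compressor chain
    y.mean().backward()  # SPSA-estimated gradients flow back to p
    print(y.shape, p.grad.shape)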