import torch


def spsa_func(input, params, process, i, sample_rate=24000):
    """Run ``process`` on CPU copies of the inputs and cast the result back.

    Args:
        input (Tensor): Audio handed to ``process`` after ``.cpu()``.
        params (Tensor): Parameters handed to ``process`` after ``.cpu()``.
        process (callable): Called as ``process(input, params, i, sample_rate)``;
            must return a Tensor.
        i (int): Item index forwarded to ``process``.
        sample_rate (int): Sample rate forwarded to ``process``.

    Returns:
        Tensor: Output of ``process`` converted with ``type_as(input)``.
    """
    return process(input.cpu(), params.cpu(), i, sample_rate).type_as(input)


def _flat_numpy(tensor):
    """Return ``tensor`` as a detached, flattened CPU NumPy array.

    ``reshape`` is used instead of ``view`` so that non-contiguous tensors
    (e.g. transposed audio or non-contiguous gradients) are copied instead of
    raising a "view size is not compatible" error.
    """
    return tensor.reshape(-1).detach().cpu().numpy()


class SPSAFunction(torch.autograd.Function):
    """Autograd function applying a black-box processor with SPSA gradients.

    Per-item work is delegated to ``thread_context``. When
    ``thread_context.parallel`` is true, messages are sent to and received
    from ``thread_context.procs[i][1]``. Otherwise
    ``thread_context.static_forward`` / ``static_backward`` are called
    in-process with ``thread_context.dsp``.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        params,
        process,
        epsilon,
        thread_context,
        sample_rate=24000,
    ):
        """Apply processor to a batch of tensors using given parameters.

        Args:
            input (Tensor): Audio with shape: batch x 2 x samples
            params (Tensor): Processor parameters with shape: batch x params
            process (function): Function that will apply processing.
                Stored on ``ctx``; not called here.
            epsilon (float): Perturbation strength for SPSA computation.
            thread_context: Object providing ``parallel``, ``procs``, ``dsp``,
                ``static_forward`` and ``static_backward``.
            sample_rate (int): Sample rate included in each forward payload.

        Returns:
            output (Tensor): Processed audio with same shape as input.
        """
        ctx.save_for_backward(input, params)
        ctx.epsilon = epsilon
        ctx.process = process
        ctx.thread_context = thread_context

        batch_size = input.shape[0]
        z = torch.empty_like(input)

        if thread_context.parallel:
            # Dispatch every item first so the workers run concurrently,
            # then collect the results in batch order.
            for i in range(batch_size):
                value = (
                    i,
                    _flat_numpy(input[i]),
                    _flat_numpy(params[i]),
                    sample_rate,
                )
                thread_context.procs[i][1].send(("forward", value))
            for i in range(batch_size):
                z[i] = torch.from_numpy(thread_context.procs[i][1].recv())
        else:
            for i in range(batch_size):
                value = (
                    i,
                    _flat_numpy(input[i]),
                    _flat_numpy(params[i]),
                    sample_rate,
                )
                z[i] = torch.from_numpy(
                    thread_context.static_forward(thread_context.dsp, value)
                )
        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Estimate gradients using SPSA.

        Args:
            grad_output (Tensor): Gradient of the loss w.r.t. the forward
                output; may be non-contiguous.

        Returns:
            tuple: ``(grads_input, grads_params, None, None, None, None)``.
            There is one slot per forward argument. An entry is ``None`` when
            autograd did not request that gradient.
        """
        input, params = ctx.saved_tensors
        epsilon = ctx.epsilon
        needs_input_grad = ctx.needs_input_grad[0]
        needs_param_grad = ctx.needs_input_grad[1]
        thread_context = ctx.thread_context
        batch_size = input.shape[0]

        # Only allocate buffers for gradients autograd actually asked for.
        grads_input = torch.empty_like(input) if needs_input_grad else None
        grads_params = torch.empty_like(params) if needs_param_grad else None

        def payload(i):
            # Per-item message consumed by the worker / static_backward.
            return (
                i,
                _flat_numpy(input[i]),
                _flat_numpy(params[i]),
                needs_input_grad,
                needs_param_grad,
                _flat_numpy(grad_output[i]),
                epsilon,
            )

        if thread_context.parallel:
            for i in range(batch_size):
                thread_context.procs[i][1].send(("backward", payload(i)))
            # Wait for output
            results = [thread_context.procs[i][1].recv() for i in range(batch_size)]
        else:
            # Lazily evaluated so each item is computed and stored in turn.
            results = (
                thread_context.static_backward(thread_context.dsp, payload(i))
                for i in range(batch_size)
            )

        for i, (grad_in, grad_param) in enumerate(results):
            # Guard against a worker returning a gradient that was not
            # requested (its buffer is None in that case).
            if grad_in is not None and grads_input is not None:
                grads_input[i] = torch.from_numpy(grad_in)
            if grad_param is not None and grads_params is not None:
                grads_params[i] = torch.from_numpy(grad_param)

        # process, epsilon, thread_context and sample_rate are not differentiable.
        return grads_input, grads_params, None, None, None, None