IBYDMT / ibydmt /payoff.py
jacopoteneggi's picture
Update
4f55ca2 verified
raw
history blame
4.98 kB
from abc import ABC, abstractmethod
from functools import reduce
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import linear_kernel, rbf_kernel
from ibydmt.bet import get_bet
class Payoff(ABC):
def __init__(self, config):
self.bet = get_bet(config.bet)(config)
@abstractmethod
def compute(self, *args, **kwargs):
pass
class Kernel:
def __init__(self, kernel: str, scale_method: str, scale: float):
if kernel == "linear":
self.base_kernel = linear_kernel
elif kernel == "rbf":
self.base_kernel = rbf_kernel
self.scale_method = scale_method
self.scale = scale
self.gamma = None
self.recompute_gamma = True
self.prev = None
else:
raise NotImplementedError(f"{kernel} is not implemented")
def __call__(self, x, y):
if self.base_kernel == linear_kernel:
return self.base_kernel(x, y)
if self.base_kernel == rbf_kernel:
if self.scale_method == "constant":
self.gamma = self.scale
elif self.scale_method == "quantile":
if self.prev is None:
self.prev = y
if self.recompute_gamma:
dist = pairwise_distances(
self.prev.reshape(-1, self.prev.shape[-1])
)
scale = np.quantile(dist, self.scale)
gamma = 1 / (2 * scale**2) if scale > 0 else None
self.gamma = gamma
if len(self.prev) > 100:
self.recompute_gamma = False
self.prev = np.vstack([self.prev, x])
else:
raise NotImplementedError(
f"{self.scale} is not implemented for rbf_kernel"
)
return self.base_kernel(x, y, gamma=self.gamma)
class KernelPayoff(Payoff):
def __init__(self, config):
super().__init__(config)
self.kernel = config.kernel
self.scale_method = config.get("kernel_scale_method", "quantile")
self.scale = config.get("kernel_scale", 0.5)
@abstractmethod
def witness_function(self, d, prev_d):
pass
def compute(self, d, null_d, prev_d):
g = reduce(
lambda acc, u: acc
+ self.witness_function(u[0], prev_d)
- self.witness_function(u[1], prev_d),
zip(d, null_d),
0,
)
return self.bet.compute(g)
class HSIC(KernelPayoff):
def __init__(self, config):
super().__init__(config)
kernel = self.kernel
scale_method = self.scale_method
scale = self.scale
self.kernel_y = Kernel(kernel, scale_method, scale)
self.kernel_z = Kernel(kernel, scale_method, scale)
def witness_function(self, d, prev_d):
y, z = d
prev_y, prev_z = prev_d[:, 0], prev_d[:, 1]
y_mat = self.kernel_y(y.reshape(-1, 1), prev_y.reshape(-1, 1))
z_mat = self.kernel_z(z.reshape(-1, 1), prev_z.reshape(-1, 1))
mu_joint = np.mean(y_mat * z_mat)
mu_prod = np.mean(y_mat, axis=1) @ np.mean(z_mat, axis=1)
return mu_joint - mu_prod
class cMMD(KernelPayoff):
def __init__(self, config):
super().__init__(config)
kernel = self.kernel
scale_method = self.scale_method
scale = self.scale
self.kernel_y = Kernel(kernel, scale_method, scale)
self.kernel_zj = Kernel(kernel, scale_method, scale)
self.kernel_cond_z = Kernel(kernel, scale_method, scale)
def witness_function(self, u, prev_d):
y, zj, cond_z = u[0], u[1], u[2:]
prev_y, prev_zj, prev_null_zj, prev_cond_z = (
prev_d[:, 0],
prev_d[:, 1],
prev_d[:, 2],
prev_d[:, 3:],
)
y_mat = self.kernel_y(y.reshape(-1, 1), prev_y.reshape(-1, 1))
zj_mat = self.kernel_zj(zj.reshape(-1, 1), prev_zj.reshape(-1, 1))
cond_z_mat = self.kernel_cond_z(
cond_z.reshape(-1, prev_cond_z.shape[1]),
prev_cond_z.reshape(-1, prev_cond_z.shape[1]),
)
null_zj_mat = self.kernel_zj(zj.reshape(-1, 1), prev_null_zj.reshape(-1, 1))
mu = np.mean(y_mat * zj_mat * cond_z_mat)
mu_null = np.mean(y_mat * null_zj_mat * cond_z_mat)
return mu - mu_null
class xMMD(KernelPayoff):
def __init__(self, config):
super().__init__(config)
self.kernel = Kernel(self.kernel, self.scale_method, self.scale)
def witness_function(self, u, prev_d):
prev_y, prev_y_null = prev_d[:, 0], prev_d[:, 1]
mu_y = np.mean(self.kernel(u.reshape(-1, 1), prev_y.reshape(-1, 1)), axis=1)
mu_y_null = np.mean(
self.kernel(u.reshape(-1, 1), prev_y_null.reshape(-1, 1)), axis=1
)
return mu_y - mu_y_null