Spaces:
Sleeping
Sleeping
File size: 5,015 Bytes
4f55ca2 a40e67a 4f55ca2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from abc import ABC, abstractmethod
from functools import reduce
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import linear_kernel, rbf_kernel
from ibydmt.bet import get_bet
class Payoff(ABC):
    """Abstract payoff; concrete subclasses must implement ``compute``.

    On construction the object returned by ``get_bet(config.bet)(config)`` is
    stored as ``self.bet``.
    """

    def __init__(self, config):
        make_bet = get_bet(config.bet)
        self.bet = make_bet(config)

    @abstractmethod
    def compute(self, *args, **kwargs):
        """Compute the payoff; the concrete signature is defined by subclasses."""
class Kernel:
    """Callable wrapper around scikit-learn's ``linear_kernel`` / ``rbf_kernel``.

    For the RBF kernel the bandwidth ``gamma`` is chosen by ``scale_method``:

    * ``"constant"``: ``gamma = scale``.
    * ``"quantile"``: ``gamma = 1 / (2 * q**2)``, where ``q`` is the
      ``scale``-quantile of the pairwise distances among a history of points.
      The history is seeded with the first ``y`` passed in and grows by each
      ``x``. Once a gamma has been computed from a history of more than 100
      rows, gamma is frozen. If ``q`` is 0, gamma is ``None`` and
      ``rbf_kernel`` falls back to its own default.

    Calls are therefore stateful for ``"quantile"``.
    """

    def __init__(self, kernel: str, scale_method: str, scale: float):
        """Select the base kernel and initialise the bandwidth state.

        Args:
            kernel: ``"linear"`` or ``"rbf"``.
            scale_method: ``"constant"`` or ``"quantile"``; only consulted for
                the RBF kernel.
            scale: the constant gamma (``"constant"``) or the quantile level
                passed to ``np.quantile`` (``"quantile"``).

        Raises:
            NotImplementedError: if ``kernel`` is neither ``"linear"`` nor
                ``"rbf"``.
        """
        if kernel == "linear":
            self.base_kernel = linear_kernel
        elif kernel == "rbf":
            self.base_kernel = rbf_kernel
        else:
            raise NotImplementedError(f"{kernel} is not implemented")
        # Only the RBF kernel uses this state. Setting it for every kernel
        # gives all instances the same attributes.
        self.scale_method = scale_method
        self.scale = scale
        self.gamma = None
        self.recompute_gamma = True
        self.prev = None

    def __call__(self, x, y):
        """Return the kernel matrix between the rows of ``x`` and ``y``.

        For the RBF kernel ``self.gamma`` is updated first (see
        ``_update_gamma``).

        Raises:
            NotImplementedError: for an unsupported RBF ``scale_method``.
        """
        if self.base_kernel is linear_kernel:
            return self.base_kernel(x, y)
        if self.base_kernel is rbf_kernel:
            self._update_gamma(x, y)
            return self.base_kernel(x, y, gamma=self.gamma)

    def _update_gamma(self, x, y):
        """Set ``self.gamma`` according to ``self.scale_method``.

        Raises:
            NotImplementedError: if ``scale_method`` is not ``"constant"`` or
                ``"quantile"``.
        """
        if self.scale_method == "constant":
            self.gamma = self.scale
        elif self.scale_method == "quantile":
            if self.prev is None:
                # Seed the distance history with the first reference set.
                self.prev = y
            if self.recompute_gamma:
                dist = pairwise_distances(
                    self.prev.reshape(-1, self.prev.shape[-1])
                )
                scale = np.quantile(dist, self.scale)
                # A zero quantile would divide by zero; None makes rbf_kernel
                # use its default gamma instead.
                self.gamma = 1 / (2 * scale**2) if scale > 0 else None
                if len(self.prev) > 100:
                    # Enough history: freeze gamma from now on.
                    self.recompute_gamma = False
                # The history only matters while gamma is still being
                # recomputed, so it stops growing once gamma is frozen.
                self.prev = np.vstack([self.prev, x])
        else:
            raise NotImplementedError(
                f"{self.scale_method} is not implemented for rbf_kernel"
            )
class KernelPayoff(Payoff):
    """Payoff built from a kernel witness function.

    Reads ``config.kernel`` plus the optional ``kernel_scale_method``
    (default ``"quantile"``) and ``kernel_scale`` (default ``0.5``) entries.
    Subclasses implement ``witness_function``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.kernel = config.kernel
        self.scale_method = config.get("kernel_scale_method", "quantile")
        self.scale = config.get("kernel_scale", 0.5)

    @abstractmethod
    def witness_function(self, d, prev_d):
        """Evaluate the witness for one sample ``d`` against the history ``prev_d``."""

    def compute(self, d, null_d, prev_d):
        """Sum paired witness differences and hand the scalar to ``self.bet``.

        For each pair ``(d_i, null_d_i)`` taken from ``zip(d, null_d)`` this adds
        ``witness(d_i) - witness(null_d_i)``. The observed sample is always
        evaluated before its null partner, because witness evaluation may
        update kernel state.

        Returns:
            Whatever ``self.bet.compute`` returns for the summed value.
        """
        total = 0
        for observed, null_sample in zip(d, null_d):
            w_observed = self.witness_function(observed, prev_d)
            w_null = self.witness_function(null_sample, prev_d)
            total = total + w_observed - w_null
        score = total.squeeze().item()
        return self.bet.compute(score)
class HSIC(KernelPayoff):
    """Kernel payoff whose witness compares the joint and product kernel means of (y, z)."""

    def __init__(self, config):
        super().__init__(config)
        spec = (self.kernel, self.scale_method, self.scale)
        # One Kernel per variable, so each keeps its own bandwidth state.
        self.kernel_y = Kernel(*spec)
        self.kernel_z = Kernel(*spec)

    def witness_function(self, d, prev_d):
        """Return mean(k_y * k_z) - mean(k_y, axis=1) @ mean(k_z, axis=1).

        ``k_y`` / ``k_z`` are kernel rows between the new ``y`` / ``z`` and
        columns 0 / 1 of ``prev_d``. Presumably ``d`` is a ``(y, z)`` pair of
        numpy scalars (TODO: confirm against callers).
        """
        y, z = d
        k_y = self.kernel_y(y.reshape(-1, 1), prev_d[:, 0].reshape(-1, 1))
        k_z = self.kernel_z(z.reshape(-1, 1), prev_d[:, 1].reshape(-1, 1))
        joint_mean = np.mean(k_y * k_z)
        product_mean = np.mean(k_y, axis=1) @ np.mean(k_z, axis=1)
        return joint_mean - product_mean
class cMMD(KernelPayoff):
    """Kernel payoff contrasting observed ``zj`` with null ``zj`` given (y, cond_z).

    ``prev_d`` columns are read as: 0 = y, 1 = zj, 2 = null zj, 3: = cond_z.
    """

    def __init__(self, config):
        super().__init__(config)
        spec = (self.kernel, self.scale_method, self.scale)
        self.kernel_y = Kernel(*spec)
        self.kernel_zj = Kernel(*spec)
        self.kernel_cond_z = Kernel(*spec)

    def witness_function(self, u, prev_d):
        """Return mean(k_y * k_zj * k_cond) - mean(k_y * k_zj_null * k_cond).

        ``u`` is read as ``(y, zj, *cond_z)``. ``kernel_zj`` is called first
        against the observed ``zj`` history and then against the null one.
        That order matters because the kernel's bandwidth state is shared.
        """
        y, zj, cond_z = u[0], u[1], u[2:]
        hist_cond_z = prev_d[:, 3:]
        n_cond = hist_cond_z.shape[1]
        zj_col = zj.reshape(-1, 1)

        k_y = self.kernel_y(y.reshape(-1, 1), prev_d[:, 0].reshape(-1, 1))
        k_zj = self.kernel_zj(zj_col, prev_d[:, 1].reshape(-1, 1))
        k_cond = self.kernel_cond_z(
            cond_z.reshape(-1, n_cond),
            hist_cond_z.reshape(-1, n_cond),
        )
        k_zj_null = self.kernel_zj(zj_col, prev_d[:, 2].reshape(-1, 1))

        observed_mean = np.mean(k_y * k_zj * k_cond)
        null_mean = np.mean(k_y * k_zj_null * k_cond)
        return observed_mean - null_mean
class xMMD(KernelPayoff):
    """Kernel payoff contrasting a sample's mean kernel similarity to two histories.

    ``prev_d`` columns are read as: 0 = y, 1 = null y.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE: this replaces the kernel-name string stored by
        # KernelPayoff.__init__ with a Kernel object built from it.
        self.kernel = Kernel(self.kernel, self.scale_method, self.scale)

    def witness_function(self, u, prev_d):
        """Return the row-wise mean kernel to column 0 minus that to column 1.

        Both means are taken along axis 1, so the result is a 1-D array with
        one entry per element of ``u.reshape(-1, 1)``.
        """
        sample = u.reshape(-1, 1)
        sim_y = self.kernel(sample, prev_d[:, 0].reshape(-1, 1))
        sim_null = self.kernel(sample, prev_d[:, 1].reshape(-1, 1))
        return np.mean(sim_y, axis=1) - np.mean(sim_null, axis=1)
|