| """Unlearning methods wrapper.""" |
| import numpy as np |
| from typing import Dict, Optional, Tuple |
| from src.graph_utils import build_adjacency, get_deletion_neighborhood, get_user_item_sets_in_radius |
|
|
|
|
def one_step_downdate_poisson_gamma(edges, edge_to_remove, full_params,
                                    N, M, K, a0, b0, c0, d0):
    """One-step local downdate: subtract deleted contribution from seed blocks only.

    Starting from the fully fitted variational Gamma parameters, remove the
    contribution of a single edge ``(i, j, x)`` from the row of user ``i``
    (``a``, ``b``) and the row of item ``j`` (``c``, ``d``) without re-running
    inference. All other rows are left untouched.

    Args:
        edges: Unused; kept for signature parity with other unlearning methods.
        edge_to_remove: ``(i_del, j_del, x_del)``, the row index, column
            index and observed value of the edge being forgotten.
        full_params: Mapping containing at least ``'a'``, ``'b'``, ``'c'`` and
            ``'d'``. Every value must support ``.copy()``. The input mapping
            is not mutated.
        N, M, K: Unused; kept for signature parity.
        a0, b0, c0, d0: Prior hyperparameters used as lower bounds when
            subtracting (shapes are floored at ``a0``/``c0``, rates at
            ``0.5 * b0``/``0.5 * d0``).

    Returns:
        FitResult holding the downdated parameters, with ``n_iterations=1``
        and ``diagnostics={'method': 'one_step_downdate'}``.
    """
    import time
    t0 = time.perf_counter()
    i_del, j_del, x_del = edge_to_remove

    # Work on copies so the caller's full-model parameters stay intact.
    params = {k: v.copy() for k, v in full_params.items()}
    a, b, c, d = params['a'], params['b'], params['c'], params['d']

    # NOTE(review): edges with x_del <= 0 are skipped entirely, so any rate
    # contribution they made during fitting is not removed. Confirm this
    # matches how the fitting routine accumulates the rates.
    if x_del > 0:
        from scipy.special import digamma
        from src.utils import stable_softmax

        # Snapshot the fitted-state expectations BEFORE modifying anything.
        # The edge's contribution to each rate was computed from these values.
        # The item-side subtraction must therefore not see the already
        # downdated user row (and vice versa).
        e_theta_i = a[i_del] / b[i_del]
        e_beta_j = c[j_del] / d[j_del]

        # Responsibilities of the deleted count over the K latent factors,
        # computed from the fitted (pre-downdate) parameters.
        log_r = (digamma(a[i_del]) - np.log(b[i_del])
                 + digamma(c[j_del]) - np.log(d[j_del]))
        r = stable_softmax(log_r)

        # NOTE(review): shapes are floored at the prior (a0, c0), but rates
        # only at half the prior (0.5*b0, 0.5*d0). Confirm the asymmetry
        # is intended.

        # User row: remove the allocated counts and the item's expected loading.
        a[i_del] = np.maximum(a0, a[i_del] - x_del * r)
        b[i_del] = np.maximum(b0 * 0.5, b[i_del] - e_beta_j)

        # Item row: symmetric update using the pre-downdate user expectation.
        c[j_del] = np.maximum(c0, c[j_del] - x_del * r)
        d[j_del] = np.maximum(d0 * 0.5, d[j_del] - e_theta_i)

    runtime = time.perf_counter() - t0

    from src.utils import FitResult
    return FitResult(
        params=params,
        objective_trace=[],
        n_iterations=1,
        converged=True,
        runtime_sec=runtime,
        model_family='poisson_gamma',
        inference_type='vi',
        likelihood='poisson',
        prior='gamma',
        diagnostics={'method': 'one_step_downdate'}
    )
|
|
|
|
def one_step_downdate_gaussian(edges, edge_to_remove, full_params,
                               N, M, K, sigma_x, sigma_U=None, sigma_V=None,
                               model_family='gaussian_gaussian'):
    """One-step local downdate for Gaussian models.

    For ``model_family == 'gaussian_gaussian'``, remove the likelihood
    precision contributed by edge ``(i, j)`` from the variances ``s_U[i]`` and
    ``s_V[j]`` (first ``K`` columns). Every other row and every mean is left
    at its full-fit value. Any other ``model_family`` returns an unmodified
    copy of the parameters.

    Args:
        edges: Unused; kept for signature parity with other unlearning methods.
        edge_to_remove: ``(i_del, j_del, x_del)``. Only the two indices are used.
        full_params: Mapping of arrays. The ``'gaussian_gaussian'`` branch
            reads ``'m_U'``, ``'s_U'``, ``'m_V'`` and ``'s_V'`` (indexed as
            2-D ``[row, k]``). The input mapping is not mutated.
        N, M: Unused; kept for signature parity.
        K: Number of latent columns to downdate.
        sigma_x: Observation noise standard deviation.
        sigma_U, sigma_V: Prior standard deviations. A falsy value (``None``
            or ``0``) falls back to unit variance. The resulting prior
            precision is the floor for the downdated precision.
        model_family: Selects the update and the metadata on the result.

    Returns:
        FitResult with the downdated parameters, ``n_iterations=1`` and
        ``diagnostics={'method': 'one_step_downdate'}``.
    """
    import time
    t0 = time.perf_counter()
    i_del, j_del, x_del = edge_to_remove

    # Work on copies so the caller's full-model parameters stay intact.
    params = {k: v.copy() for k, v in full_params.items()}

    if model_family == 'gaussian_gaussian':
        m_U, s_U = params['m_U'], params['s_U']
        m_V, s_V = params['m_V'], params['s_V']
        prec_x = 1.0 / (sigma_x ** 2)

        # A posterior precision never drops below its prior precision once
        # likelihood terms are removed, so the prior precision is the floor.
        prior_prec_U = 1.0 / (sigma_U ** 2 if sigma_U else 1.0)
        prior_prec_V = 1.0 / (sigma_V ** 2 if sigma_V else 1.0)

        # Snapshot the fitted-state second moments E[v_jk^2] and E[u_ik^2]
        # BEFORE editing either row. The V update must subtract what the
        # edge contributed at the fitted state, not a value computed from the
        # already-downdated s_U.
        second_moment_V = m_V[j_del, :K] ** 2 + s_V[j_del, :K]
        second_moment_U = m_U[i_del, :K] ** 2 + s_U[i_del, :K]

        new_prec_U = np.maximum(1.0 / s_U[i_del, :K] - prec_x * second_moment_V,
                                prior_prec_U)
        new_prec_V = np.maximum(1.0 / s_V[j_del, :K] - prec_x * second_moment_U,
                                prior_prec_V)
        s_U[i_del, :K] = 1.0 / new_prec_U
        s_V[j_del, :K] = 1.0 / new_prec_V

        # NOTE(review): the means m_U[i_del] / m_V[j_del] are not downdated,
        # and x_del is never used. The deleted rating still influences the
        # means. Confirm this is intended.

    elif model_family == 'gaussian_gamma_map':
        # NOTE(review): no downdate is implemented for this family. The
        # full-model parameters are returned unchanged.
        pass

    # NOTE(review): any other model_family also falls through unchanged and
    # is labelled as a MAP / gamma-prior result below.

    runtime = time.perf_counter() - t0

    from src.utils import FitResult
    return FitResult(
        params=params, objective_trace=[], n_iterations=1,
        converged=True, runtime_sec=runtime,
        model_family=model_family, inference_type='vi' if 'gaussian_gaussian' == model_family else 'map',
        likelihood='gaussian', prior='gaussian' if model_family == 'gaussian_gaussian' else 'gamma',
        diagnostics={'method': 'one_step_downdate'}
    )
|
|