import chex
import jax
from jax import random, jit, vmap
import jax.numpy as jnp
from functools import partial
from consts import EPS
from globals import UserInfo, Char, State


def norm_playtime(arr: chex.Array, cid: int) -> chex.Array:
    """Playtime on character `cid` relative to the user's most-played character."""
    max_playtime = jnp.max(arr) + EPS
    return arr[cid] / max_playtime


@jit
def construct_feats(user: UserInfo, char: Char, char_id: int) -> chex.Array:
    """Build the feature vector for one (user, character) pair.

    Output dimension is 11 + len(char.archetype_vec): six base scalars, the
    archetype vector, and five derived scalars.
    """
    feats = [
        user.skill_level,
        jnp.log1p(user.games_played),  # compress long-tailed play counts
        char.difficulty,
        char.execution_level,
        char.neutral_required,
        char.tier,
    ]
    feats.append(char.archetype_vec)

    # How well the user's skill lines up with the character's difficulty:
    # 1.0 when skill_level == 1.0 - difficulty, falling off linearly.
    skill_match = 1.0 - jnp.abs(user.skill_level - (1.0 - char.difficulty))
    feats.append(skill_match)

    # Alignment between the user's preferred archetype and this character's.
    archetype_sim = jnp.dot(user.pref_archetype, char.archetype_vec)
    feats.append(archetype_sim)

    # Favor exploration: characters the user has never tried get a bonus.
    tried_before = user.chars_attempted_mask[char_id]
    novelty_bonus = 1.0 - tried_before
    feats.append(novelty_bonus)

    # Past win rate on this character; fall back to a neutral 0.5 if untried.
    past_perf = jnp.where(tried_before > 0.5, user.wr[char_id], 0.5)
    feats.append(past_perf)

    # Playtime on this character relative to the user's most-played one.
    feats.append(norm_playtime(user.playtime, char_id))

    # atleast_1d lets scalars and the archetype vector share one concatenate.
    return jnp.concatenate([jnp.atleast_1d(feat) for feat in feats])


@partial(jit, static_argnums=(2,))
def build_feats(user: UserInfo, chars: Char, n_chars: int) -> chex.Array:
    """Feature matrix of shape (n_chars, feature_dim) for one user."""

    def build_single(cid: chex.Array) -> chex.Array:
        # Slice out character `cid` from the stacked character pytree.
        char = jax.tree.map(lambda x: x[cid], chars)
        return construct_feats(user, char, cid)

    return vmap(build_single)(jnp.arange(n_chars))


@jit
def sample_params(key: chex.PRNGKey, mu: chex.Array, Sigma: chex.Array) -> chex.Array:
    """Draw one parameter vector from N(mu, Sigma) for Thompson sampling."""
    d = mu.shape[0]
    # Jitter the diagonal so the covariance stays positive definite.
    Sigma_pd = Sigma + EPS * jnp.eye(d)
    return random.multivariate_normal(key, mu, Sigma_pd)


@jit
def compute_expected_rewards(thetas: chex.Array, feats: chex.Array) -> chex.Array:
    """Per-character linear rewards: dot each sampled theta with its features."""
    return vmap(jnp.dot)(thetas, feats)


@jit
def thompson_sample(
    key: chex.PRNGKey, state: State, feats: chex.Array
) -> tuple[chex.Array, chex.Array]:
    """Sample one theta per character from its posterior and score its features."""
    num_chars = feats.shape[0]
    keys = random.split(key, num_chars)
    thetas = vmap(sample_params)(keys, state.mu, state.Sigma)
    rewards = compute_expected_rewards(thetas, feats)
    return rewards, thetas


@partial(jit, static_argnames=("use_adaptive_noise",))
def update_posterior(
    state: State,
    char_id: int,
    feats: chex.Array,
    reward: float,
    noise_var: float = 1.0,
    use_adaptive_noise: bool = True,
) -> State:
    """Conjugate Bayesian linear-regression update for one character's arm.

    `use_adaptive_noise` must be static: it drives a Python-level branch, which
    would fail on a traced boolean under a plain `@jit`.
    """
    x = feats
    d = x.shape[0]
    mu_old = state.mu[char_id]
    Sigma_old = state.Sigma[char_id]

    # Standard Gaussian conjugate update:
    #   Sigma_new^-1 = Sigma_old^-1 + x x^T / noise_var
    #   mu_new = Sigma_new (Sigma_old^-1 mu_old + reward * x / noise_var)
    Sigma_old_inv = jnp.linalg.inv(Sigma_old + EPS * jnp.eye(d))
    Sigma_new_inv = Sigma_old_inv + (1.0 / noise_var) * jnp.outer(x, x)
    Sigma_new = jnp.linalg.inv(Sigma_new_inv)

    mu_new = Sigma_new @ (Sigma_old_inv @ mu_old + (reward / noise_var) * x)

    new_mu = state.mu.at[char_id].set(mu_new)
    new_Sigma = state.Sigma.at[char_id].set(Sigma_new)

    # Static branch, resolved at trace time: count the observation for this arm.
    new_beta = state.beta.at[char_id].add(1) if use_adaptive_noise else state.beta
    return State(
        mu=new_mu,
        Sigma=new_Sigma,
        alpha=state.alpha,
        beta=new_beta,
    )
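

# Note: the update above only bumps `beta` as an observation counter; nothing
# in this module yet maps (alpha, beta) to a noise estimate. A hypothetical
# helper sketch, ASSUMING the intent is an inverse-gamma posterior over each
# arm's observation noise (an assumption, not confirmed by this code):
@jit
def estimate_noise_var(state: State, char_id: int) -> chex.Array:
    alpha = state.alpha[char_id]
    beta = state.beta[char_id]
    # Inverse-gamma mean is beta / (alpha - 1) for alpha > 1; fall back to the
    # default noise_var of 1.0 where the mean is undefined.
    return jnp.where(alpha > 1.0, beta / (alpha - 1.0), 1.0)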


@partial(jit, static_argnums=(2,))
def select_top_k_diverse(
    scores: chex.Array, archetypes: chex.Array, k: int, diversity_thresh: float
) -> chex.Array:
    """Greedily pick up to k high-scoring characters with distinct archetypes.

    Returns k indices; slots stay -1 if fewer than k candidates pass the
    diversity check. Only `k` is static (it fixes the output shape), so the
    threshold can be a traced value.
    """
    n_chars = scores.shape[0]
    sorted_idx = jnp.argsort(-scores)  # candidate order: best score first

    def selection_step(carry, rank):
        selected, cnt = carry
        cand_idx = sorted_idx[rank]

        # All k slots filled once cnt reaches k; ignore remaining candidates.
        done = cnt >= k
        cand_arch = archetypes[cand_idx]

        def check_item_diversity(sel_idx):
            # Unfilled slots hold -1; treat them as trivially diverse. The
            # archetypes[-1] read below is harmless since is_valid masks it.
            is_valid = sel_idx >= 0
            sel_arch = archetypes[sel_idx]
            # Cosine similarity between candidate and an already-picked item.
            sim = jnp.dot(cand_arch, sel_arch) / (
                jnp.linalg.norm(cand_arch) * jnp.linalg.norm(sel_arch) + 1e-8
            )
            return jnp.where(is_valid, sim < diversity_thresh, True)

        all_diverse = jnp.all(vmap(check_item_diversity)(selected))
        add_ok = jnp.logical_and(jnp.logical_not(done), all_diverse)

        new_sel = jnp.where(add_ok, selected.at[cnt].set(cand_idx), selected)
        new_cnt = jnp.where(add_ok, cnt + 1, cnt)
        return (new_sel, new_cnt), None

    # Seed with the single best-scoring character, then scan over the rest.
    init = jnp.full(k, -1, dtype=jnp.int32)
    init = init.at[0].set(sorted_idx[0])

    (final_sel, _), _ = jax.lax.scan(
        selection_step, (init, 1), jnp.arange(1, n_chars)
    )
    return final_sel


@jit
def compute_reward(
    won: bool,
    completed: bool,
    rating: float,
    playtime_mins: float,
    weights: chex.Array = jnp.array([0.3, 0.15, 0.25, 0.3]),
) -> chex.Array:
    """Scalar reward in [0, 1]: weighted win, completion, rating, engagement."""
    win_reward = jnp.where(won, weights[0], 0.0)
    completion_reward = jnp.where(completed, weights[1], 0.0)
    # Star rating rescaled from [0, 5] to [0, 1].
    rating_reward = weights[2] * jnp.clip(rating / 5.0, 0.0, 1.0)
    # log1p dampens long sessions; /5 caps engagement at roughly 147 minutes.
    engagement_reward = weights[3] * jnp.clip(jnp.log1p(playtime_mins) / 5.0, 0.0, 1.0)
    return win_reward + completion_reward + rating_reward + engagement_reward
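
# Worked example with hypothetical inputs: won=True, completed=False,
# rating=4.0, playtime_mins=30.0 and the default weights gives
#   0.3 + 0.0 + 0.25 * (4.0 / 5.0) + 0.3 * log1p(30) / 5.0
#   = 0.3 + 0.2 + 0.206... ~= 0.71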


@partial(jit, static_argnums=(4, 5))
def recommend_characters(
    key: chex.PRNGKey,
    state: State,
    user: UserInfo,
    characters: Char,
    n_chars: int,
    top_k: int = 3,
    diversity_threshold: float = 0.75,
) -> tuple[chex.Array, chex.Array]:
    """End-to-end step: featurize, Thompson-sample, pick a diverse top-k."""
    features = build_feats(user, characters, n_chars)
    sampled_rewards, _ = thompson_sample(key, state, features)

    selected = select_top_k_diverse(
        sampled_rewards, characters.archetype_vec, top_k, diversity_threshold
    )
    return selected, sampled_rewards


def init_thompson(n_chars: int, feature_dim: int, prior_var: float = 1.0) -> State:
    """Independent N(0, prior_var * I) prior for each character arm."""
    return State(
        mu=jnp.zeros((n_chars, feature_dim)),
        Sigma=jnp.tile(prior_var * jnp.eye(feature_dim), (n_chars, 1, 1)),
        alpha=jnp.ones(n_chars),
        beta=jnp.ones(n_chars),
    )


@jit
def batch_update_posterior(
    state: State,
    char_ids: chex.Array,
    features: chex.Array,
    rewards: chex.Array,
    noise_var: float = 1.0,
) -> State:
    """Fold a batch of (char_id, features, reward) observations into the state."""

    def single_update(s, data):
        char_id, feat, reward = data
        return update_posterior(s, char_id, feat, reward, noise_var), None

    final_state, _ = jax.lax.scan(single_update, state, (char_ids, features, rewards))
    return final_state
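

# Minimal usage sketch (illustrative, not part of the module's API). It
# assumes UserInfo and Char from `globals` are NamedTuple/dataclass-style
# containers that accept the fields used above as keywords, with Char holding
# stacked per-character arrays; adapt the constructors if they differ.
if __name__ == "__main__":
    n_chars, arch_dim = 8, 4
    feature_dim = 11 + arch_dim  # 6 base + 5 derived scalars + archetype_vec

    key = random.PRNGKey(0)
    k_arch, k_diff, k_exec, k_neut, k_tier, k_rec = random.split(key, 6)

    user = UserInfo(
        skill_level=jnp.array(0.6),
        games_played=jnp.array(120.0),
        pref_archetype=jnp.array([1.0, 0.0, 0.0, 0.0]),
        chars_attempted_mask=jnp.zeros(n_chars),
        wr=jnp.full(n_chars, 0.5),
        playtime=jnp.zeros(n_chars),
    )
    chars = Char(
        difficulty=random.uniform(k_diff, (n_chars,)),
        execution_level=random.uniform(k_exec, (n_chars,)),
        neutral_required=random.uniform(k_neut, (n_chars,)),
        tier=random.uniform(k_tier, (n_chars,)),
        archetype_vec=random.uniform(k_arch, (n_chars, arch_dim)),
    )

    state = init_thompson(n_chars, feature_dim)
    selected, scores = recommend_characters(k_rec, state, user, chars, n_chars)
    print("recommended:", selected, "sampled scores:", scores)

    # Simulate one session on the top pick, then fold it into the posterior.
    feats = build_feats(user, chars, n_chars)[selected[0]]
    reward = compute_reward(True, True, 4.0, 35.0)
    state = update_posterior(state, selected[0], feats, reward)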