demo / src /helpers /generator.py
Xmaster6y's picture
working interface
340463d unverified
raw
history blame
2.11 kB
"""Generate model outputs, activations and SAE features for a pair of board states.
"""
from typing import Optional
from lczerolens import ModelWrapper
from lczerolens.xai import ActivationLens
from lczerolens.encodings import InputEncoding
import chess
import einops
import torch
from .sae import SparseAutoEncoder
class OutputGenerator:
    """Run a model on a (root, trajectory) board pair and encode the activations with an SAE.

    The activations of a single module (selected by ``module_exp``) are
    recorded for both boards, flattened over their spatial ``(h, w)`` dims,
    concatenated along the channel axis and fed to the sparse autoencoder.
    """

    def __init__(
        self,
        sae: SparseAutoEncoder,
        wrapper: ModelWrapper,
        module_exp: Optional[str] = None,
    ):
        """
        Args:
            sae: Sparse autoencoder applied to the concatenated activations.
            wrapper: Model wrapper the lens runs the boards through.
            module_exp: Expression forwarded to ``ActivationLens`` to select
                which module's activations are recorded. ``generate`` requires
                that exactly one module matches it.
        """
        self.sae = sae
        self.wrapper = wrapper
        self.lens = ActivationLens(module_exp=module_exp)

    # Called form works on every torch version; the bare `@torch.no_grad`
    # decorator form is only accepted by newer releases.
    @torch.no_grad()
    def generate(
        self,
        root_fen: Optional[str] = None,
        traj_fen: Optional[str] = None,
        root_board: Optional[chess.Board] = None,
        traj_board: Optional[chess.Board] = None,
    ):
        """Compute model output, per-position activations and SAE features.

        Inputs are given either as two boards or as two FEN strings:

        * ``root_board`` + ``traj_board`` are encoded with
          ``INPUT_CLASSICAL_112_PLANE``. This pair takes precedence when both
          pairs are supplied.
        * ``root_fen`` + ``traj_fen`` are parsed into boards and encoded with
          ``INPUT_CLASSICAL_112_PLANE_REPEATED``. Presumably this is because a
          board built from a FEN carries no move history (TODO confirm).

        Returns:
            A tuple ``(model_output, pixel_acts, sae_output)``:

            * ``model_output`` is the wrapper output for the batch.
            * ``pixel_acts`` has shape ``(h * w, 2 * c)``: the root activations
              followed by the trajectory activations, one row per spatial
              position.
            * ``sae_output`` is what ``self.sae`` returns with
              ``output_features=True``.

        Raises:
            ValueError: if neither a complete board pair nor a complete FEN
                pair is given, or if the lens's module expression matched
                zero or several modules.
        """
        if root_board is not None and traj_board is not None:
            input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE
        elif root_fen is not None and traj_fen is not None:
            root_board = chess.Board(root_fen)
            traj_board = chess.Board(traj_fen)
            input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
        else:
            raise ValueError(
                "Provide either both `root_board` and `traj_board`, "
                "or both `root_fen` and `traj_fen`."
            )

        # A single batch holding both boards, wrapped in a 1-tuple as the
        # lens's batched iterator expects.
        iter_boards = iter([([root_board, traj_board],)])
        result_iter = self.lens.analyse_batched_boards(
            iter_boards,
            self.wrapper,
            return_output=True,
            wrapper_kwargs={
                "input_encoding": input_encoding,
            },
        )
        act_dict, (model_output,) = next(result_iter)

        if not act_dict:
            raise ValueError("No module matched the given expression.")
        if len(act_dict) > 1:
            raise ValueError("Multiple modules matched the given expression.")

        acts = next(iter(act_dict.values()))
        # Index 0 is the root board and index 1 is the trajectory board, in
        # batch order. Each is flattened from (c, h, w) to (h*w, c).
        root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
        traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
        pixel_acts = torch.cat([root_acts, traj_acts], dim=1)

        sae_output = self.sae(pixel_acts, output_features=True)
        return model_output, pixel_acts, sae_output