from typing import Tuple

import einops
import torch
from jaxtyping import Float
from typeguard import typechecked

|
@torch.no_grad()
@typechecked
def get_contributions(
    parts: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> torch.Tensor:
    """
    Compute contributions of the `parts` vectors to the `whole` vector.

    Shapes of the tensors are as follows:
    parts:  p_1 ... p_k, v_1 ... v_n, d
    whole:  v_1 ... v_n, d
    result: p_1 ... p_k, v_1 ... v_n

    Here
    * `p_1 ... p_k` are the dimensions enumerating the parts,
    * `v_1 ... v_n` are the dimensions listing the independent cases (batching),
    * `d` is the dimension along which the distances are computed.

    The resulting contributions are normalized so that
    for each v_: sum(over p_ of result(p_, v_)) = 1.
    """
    EPS = 1e-5

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0
    assert parts.shape[k:] == whole.shape
    bc_whole = whole.expand(parts.shape)

    distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)

    whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
    # Turn distances into proximities: parts close to the whole get a large
    # value, distant parts get (almost) zero. Clipping at EPS keeps the
    # normalization below well-defined even when every part is distant.
    proximity = (whole_norm - distance).clip(min=EPS)

    # Normalize over the part-enumerating dimensions so the contributions
    # sum to 1 for each batched case.
    total = proximity.sum(dim=tuple(range(k)), keepdim=True)

    return proximity / total
|
|
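# A minimal usage sketch (illustrative, not part of the original module):
# decompose a vector into random parts and check that the contributions
# sum to 1 over the part dimension. Shapes and names are assumptions.
def _example_get_contributions() -> None:
    parts = torch.randn(3, 2, 8)  # 3 parts, batch of 2 cases, d = 8
    whole = parts.sum(dim=0)  # shape: (2, 8)
    c = get_contributions(parts, whole)  # shape: (3, 2)
    assert torch.allclose(c.sum(dim=0), torch.ones(2))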
|
|
@torch.no_grad()
@typechecked
def get_contributions_with_one_off_part(
    parts: torch.Tensor,
    one_off: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Same as `get_contributions`, but with one additional part. This is useful
    because the residual stream is always one of the parts.

    See the `get_contributions` documentation for the `parts` and `whole`
    dimensions. The `one_off` tensor must have the same shape as `whole`.

    Returns a pair consisting of
    1. the contributions tensor for the `parts`
    2. the contributions tensor for the `one_off` vector
    """
    assert one_off.shape == whole.shape

    k = len(parts.shape) - len(whole.shape)
    # At least one leading dimension must enumerate the parts; otherwise the
    # flatten below would collapse the whole tensor.
    assert k >= 1

    # Flatten the part-enumerating dimensions into one, append `one_off` as an
    # extra part, and compute the contributions of all parts at once.
    flat = parts.flatten(start_dim=0, end_dim=k - 1)
    flat = torch.cat([flat, one_off.unsqueeze(0)])
    contributions = get_contributions(flat, whole, distance_norm)
    # Split off the last "part", which is the `one_off` vector.
    parts_contributions, one_off_contributions = torch.split(
        contributions, flat.shape[0] - 1
    )
    return (
        parts_contributions.unflatten(0, parts.shape[0:k]),
        one_off_contributions[0],
    )
|
|
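# A hedged sketch (names and shapes are assumptions): the one-off part plus
# the regular parts should together account for the whole.
def _example_get_contributions_with_one_off_part() -> None:
    parts = torch.randn(4, 2, 8)  # 4 parts, batch of 2 cases, d = 8
    one_off = torch.randn(2, 8)  # same shape as `whole`
    whole = parts.sum(dim=0) + one_off
    c_parts, c_one_off = get_contributions_with_one_off_part(parts, one_off, whole)
    # c_parts: (4, 2), c_one_off: (2,); together they sum to 1 per case.
    assert torch.allclose(c_parts.sum(dim=0) + c_one_off, torch.ones(2))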
|
|
@torch.no_grad()
@typechecked
def get_attention_contributions(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
    distance_norm: int = 1,
) -> Tuple[
    Float[torch.Tensor, "batch pos key_pos head"],
    Float[torch.Tensor, "batch pos"],
]:
    """
    Returns a pair of
    - a tensor of contributions of each token via each head,
    - the contribution of the residual stream.
    """

    # Move the (key_pos, head) part-enumerating dimensions to the front, as
    # required by `get_contributions_with_one_off_part`.
    parts = einops.rearrange(
        decomposed_attn,
        "batch pos key_pos head d_model -> key_pos head batch pos d_model",
    )
    # The residual stream is the one-off part: resid_mid is the sum of
    # resid_pre and the attention outputs.
    attn_contribution, residual_contribution = get_contributions_with_one_off_part(
        parts, resid_pre, resid_mid, distance_norm
    )
    return (
        einops.rearrange(
            attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
        ),
        residual_contribution,
    )
|
|
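# A hedged usage sketch (dimension sizes are arbitrary assumptions): if
# resid_mid is exactly resid_pre plus the per-token, per-head attention
# outputs, the returned contributions sum to 1 for every position.
def _example_get_attention_contributions() -> None:
    batch, pos, n_heads, d_model = 2, 5, 4, 8
    resid_pre = torch.randn(batch, pos, d_model)
    decomposed_attn = torch.randn(batch, pos, pos, n_heads, d_model)
    resid_mid = resid_pre + decomposed_attn.sum(dim=(2, 3))
    c_attn, c_resid = get_attention_contributions(
        resid_pre, resid_mid, decomposed_attn
    )
    # c_attn: (batch, pos, key_pos, head), c_resid: (batch, pos).
    total = c_attn.sum(dim=(2, 3)) + c_resid
    assert torch.allclose(total, torch.ones(batch, pos))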
|
|
@torch.no_grad()
@typechecked
def get_mlp_contributions(
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    resid_post: Float[torch.Tensor, "batch pos d_model"],
    mlp_out: Float[torch.Tensor, "batch pos d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
    """
    Returns a pair of (mlp, residual) contributions for each sentence and token.
    """

    # Treat the MLP output and the residual stream as the two parts of
    # resid_post.
    contributions = get_contributions(
        torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
    )
    return contributions[0], contributions[1]
|
|
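# A hedged sketch (shapes are assumptions): when resid_post = resid_mid +
# mlp_out, the two returned contributions sum to 1 for every token.
def _example_get_mlp_contributions() -> None:
    batch, pos, d_model = 2, 5, 8
    resid_mid = torch.randn(batch, pos, d_model)
    mlp_out = torch.randn(batch, pos, d_model)
    resid_post = resid_mid + mlp_out
    c_mlp, c_resid = get_mlp_contributions(resid_mid, resid_post, mlp_out)
    assert torch.allclose(c_mlp + c_resid, torch.ones(batch, pos))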
|
|
@torch.no_grad()
@typechecked
def get_decomposed_mlp_contributions(
    resid_mid: Float[torch.Tensor, "d_model"],
    resid_post: Float[torch.Tensor, "d_model"],
    decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
    """
    Similar to `get_mlp_contributions`, but takes the MLP output for each
    neuron of the hidden layer and thus computes a contribution per neuron.

    Doesn't have batch and token dimensions for the sake of saving memory,
    but we may consider adding them.
    """

    neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
        decomposed_mlp_out, resid_mid, resid_post, distance_norm
    )
    return neuron_contributions, residual_contribution.item()
|
|
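# A hedged sketch (sizes are assumptions): per-neuron contributions plus the
# residual contribution add up to 1 for the single representation.
def _example_get_decomposed_mlp_contributions() -> None:
    hidden, d_model = 16, 8
    resid_mid = torch.randn(d_model)
    decomposed_mlp_out = torch.randn(hidden, d_model)
    resid_post = resid_mid + decomposed_mlp_out.sum(dim=0)
    c_neurons, c_resid = get_decomposed_mlp_contributions(
        resid_mid, resid_post, decomposed_mlp_out
    )
    # c_neurons: (hidden,), c_resid: float.
    assert abs(c_neurons.sum().item() + c_resid - 1.0) < 1e-4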
|
|
@torch.no_grad()
def apply_threshold_and_renormalize(
    threshold: float,
    c_blocks: torch.Tensor,
    c_residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Thresholding mechanism used in the original graphs paper. After the
    threshold is applied, the remaining contributions are renormalized in
    order to sum up to 1 for each representation.

    threshold: The threshold.
    c_residual: Contribution of the residual stream for each representation.
        This tensor should contain 1 element per representation, i.e., its
        dimensions are all batch dimensions.
    c_blocks: Contributions of the blocks. Could be 1 block per representation,
        like ffn, or heads*tokens blocks in the case of attention. The shape of
        `c_residual` must be a prefix of the shape of this tensor. The
        remaining dimensions are for listing the blocks.
    """

    block_dims = len(c_blocks.shape)
    resid_dims = len(c_residual.shape)
    bound_dims = block_dims - resid_dims
    assert bound_dims >= 0
    assert c_blocks.shape[0:resid_dims] == c_residual.shape

    # Zero out the contributions that don't exceed the threshold.
    c_blocks = c_blocks * (c_blocks > threshold)
    c_residual = c_residual * (c_residual > threshold)

    # Renormalize so that the surviving contributions sum to 1 for each
    # representation; the residual term is broadcast over the block dims.
    denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
    return (
        c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
        c_residual / denom,
    )
|
|
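# A hedged sketch with hand-picked numbers (not from the paper): blocks at or
# below the threshold are zeroed, and each row is rescaled to sum to 1 again.
def _example_apply_threshold_and_renormalize() -> None:
    c_residual = torch.tensor([[0.4, 0.5]])  # (batch=1, pos=2)
    c_blocks = torch.tensor([[[0.5, 0.1], [0.3, 0.2]]])  # (1, 2, 2 blocks)
    new_blocks, new_residual = apply_threshold_and_renormalize(
        0.15, c_blocks, c_residual
    )
    # The 0.1 block is zeroed; its row renormalizes as 0.5 -> 0.5/0.9, etc.
    assert torch.allclose(new_blocks.sum(dim=-1) + new_residual, torch.ones(1, 2))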