mahnerak's picture
Initial Commit 🚀
ce00289
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import einops
import torch
from jaxtyping import Float
from typeguard import typechecked
@torch.no_grad()
@typechecked
def get_contributions(
    parts: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> torch.Tensor:
    """
    Compute contributions of the `parts` vectors into the `whole` vector.

    Shapes of the tensors are as follows:
    parts:  p_1 ... p_k, v_1 ... v_n, d
    whole:               v_1 ... v_n, d
    result: p_1 ... p_k, v_1 ... v_n

    Here
     * `p_1 ... p_k`: dimensions for enumerating the parts
     * `v_1 ... v_n`: dimensions listing the independent cases (batching),
     * `d` is the dimension to compute the distances on.

    The raw score of a part is `max(||whole|| - dist(part, whole), EPS)`, where the
    norm is `torch.norm` and the distance is
    `torch.nn.functional.pairwise_distance`, both of order `distance_norm` over `d`.
    The scores are then normalized so that
    for each v_: sum(over p_ of result(p_, v_)) = 1.

    When k == 0 there is a single part per case, so every contribution is 1.
    """
    EPS = 1e-5

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0
    assert parts.shape[k:] == whole.shape

    bc_whole = whole.expand(parts.shape)  # new dims p_1 ... p_k are added to the front
    distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)
    whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
    # Parts close to the whole score high. Clipping keeps every score strictly
    # positive, so the normalization below never divides by zero.
    score = (whole_norm - distance).clip(min=EPS)

    if k == 0:
        # `Tensor.sum(dim=())` reduces over ALL dimensions rather than none, which
        # would normalize across the batch. With no part dims, each case has exactly
        # one part, and that part's share is the whole.
        return score / score
    total = score.sum(dim=tuple(range(k)), keepdim=True)
    return score / total
@torch.no_grad()
@typechecked
def get_contributions_with_one_off_part(
    parts: torch.Tensor,
    one_off: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Same as computing the contributions, but there is one additional part. That's useful
    because we always have the residual stream as one of the parts.

    See `get_contributions` documentation about `parts` and `whole` dimensions. The
    `one_off` should have the same dimensions as `whole`.

    Returns a pair consisting of
    1. contributions tensor for the `parts` (shape p_1 ... p_k, v_1 ... v_n)
    2. contributions tensor for the `one_off` vector (shape v_1 ... v_n)
    """
    assert one_off.shape == whole.shape

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0

    # Collapse the p_ dimensions into a single leading one. Unlike `flatten`, this
    # also works for k == 0, where it just adds a part dim of size 1. Then append
    # `one_off` as the last part and compute contributions for the flat list.
    flat = parts.reshape((-1,) + tuple(whole.shape))
    flat = torch.cat([flat, one_off.unsqueeze(0)])
    contributions = get_contributions(flat, whole, distance_norm)

    # Slicing (instead of `torch.split`) also copes with an empty list of parts.
    parts_contributions = contributions[:-1]
    one_off_contributions = contributions[-1]
    return (
        # Restore the original p_1 ... p_k dimensions.
        parts_contributions.reshape(
            tuple(parts.shape[:k]) + tuple(parts_contributions.shape[1:])
        ),
        one_off_contributions,
    )
@torch.no_grad()
@typechecked
def get_attention_contributions(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
    distance_norm: int = 1,
) -> Tuple[
    Float[torch.Tensor, "batch pos key_pos head"],
    Float[torch.Tensor, "batch pos"],
]:
    """
    Compute how the attention inputs contribute to `resid_mid`.

    The parts are the per-(key token, head) vectors of `decomposed_attn`, plus
    `resid_pre` as the extra residual part. `resid_mid` is the whole.

    Returns a pair of
    - a tensor of contributions of each token via each head
    - the contribution of the residual stream.
    """
    # The helper expects the part dimensions first:
    #   part dims: key_pos, head | batch dims: batch, pos | vector dim: d_model
    heads_first = einops.rearrange(
        decomposed_attn,
        "batch pos key_pos head d_model -> key_pos head batch pos d_model",
    )
    per_head, per_residual = get_contributions_with_one_off_part(
        heads_first, resid_pre, resid_mid, distance_norm
    )
    # Put the batch dimensions back in front, as the return annotation requires.
    batch_first = einops.rearrange(
        per_head, "key_pos head batch pos -> batch pos key_pos head"
    )
    return batch_first, per_residual
@torch.no_grad()
@typechecked
def get_mlp_contributions(
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    resid_post: Float[torch.Tensor, "batch pos d_model"],
    mlp_out: Float[torch.Tensor, "batch pos d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
    """
    Returns a pair of (mlp, residual) contributions for each sentence and token.

    The two parts are `mlp_out` and `resid_mid`. The whole they are measured
    against is `resid_post`.
    """
    # A new leading dimension of size 2 enumerates the parts:
    # index 0 is the MLP output, index 1 is the residual stream.
    two_parts = torch.stack([mlp_out, resid_mid], dim=0)
    mlp_share, residual_share = get_contributions(
        two_parts, resid_post, distance_norm
    ).unbind(dim=0)
    return mlp_share, residual_share
@torch.no_grad()
@typechecked
def get_decomposed_mlp_contributions(
    resid_mid: Float[torch.Tensor, "d_model"],
    resid_post: Float[torch.Tensor, "d_model"],
    decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
    """
    Per-neuron variant of `get_mlp_contributions`.

    Each row of `decomposed_mlp_out` is treated as a separate part, with
    `resid_mid` as the extra residual part and `resid_post` as the whole.
    This yields one contribution per hidden neuron plus a single Python float
    for the residual stream.

    Doesn't contain batch and token dimensions for sake of saving memory. But we may
    consider adding them.
    """
    per_neuron, residual = get_contributions_with_one_off_part(
        decomposed_mlp_out, resid_mid, resid_post, distance_norm
    )
    # `resid_post` is a single vector, so the residual contribution is a scalar tensor.
    residual_value = residual.item()
    return per_neuron, residual_value
@torch.no_grad()
def apply_threshold_and_renormalize(
    threshold: float,
    c_blocks: torch.Tensor,
    c_residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Thresholding mechanism used in the original graphs paper. After the threshold is
    applied, the remaining contributions are renormalized in order to sum up to 1 for
    each representation.

    threshold: The threshold. Contributions that are not strictly greater than it are
        set to 0.
    c_blocks: Contributions of the blocks. Could be 1 block per representation, like
        ffn, or heads*tokens blocks in case of attention. The shape of `c_residual`
        must be a prefix of the shape of this tensor. The remaining dimensions are for
        listing the blocks.
    c_residual: Contribution of the residual stream for each representation. This tensor
        should contain 1 element per representation, i.e., its dimensions are all batch
        dimensions.

    Returns the renormalized `(c_blocks, c_residual)` with the same shapes as the
    inputs. If every contribution of a representation ends up at or below the
    threshold, all of that representation's outputs are 0 (instead of NaN).
    """
    block_dims = len(c_blocks.shape)
    resid_dims = len(c_residual.shape)
    bound_dims = block_dims - resid_dims
    assert bound_dims >= 0
    assert c_blocks.shape[0:resid_dims] == c_residual.shape

    c_blocks = c_blocks * (c_blocks > threshold)
    c_residual = c_residual * (c_residual > threshold)

    # Total block contribution per representation. `sum(dim=())` would reduce over
    # EVERY dimension, so the one-block-per-representation case (bound_dims == 0)
    # must skip the reduction.
    block_axes = tuple(range(resid_dims, block_dims))
    blocks_total = c_blocks.sum(dim=block_axes) if block_axes else c_blocks
    denom = c_residual + blocks_total
    # If everything was thresholded away, divide 0 by 1 instead of 0 by 0.
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)

    return (
        # Append singleton dims so `denom` broadcasts over the block dimensions.
        c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
        c_residual / denom,
    )