# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import einops
import torch
from jaxtyping import Float
from typeguard import typechecked
@torch.no_grad()
@typechecked
def get_contributions(
parts: torch.Tensor,
whole: torch.Tensor,
distance_norm: int = 1,
) -> torch.Tensor:
"""
    Compute the contributions of the `parts` vectors to the `whole` vector.
    Shapes of the tensors are as follows:
    parts:  p_1 ... p_k, v_1 ... v_n, d
    whole:  v_1 ... v_n, d
    result: p_1 ... p_k, v_1 ... v_n
    Here
    * `p_1 ... p_k`: dimensions enumerating the parts,
    * `v_1 ... v_n`: dimensions listing the independent cases (batching),
    * `d`: the dimension along which the distances are computed.
    The resulting contributions are normalized so that
    for each v_: sum(over p_ of result(p_, v_)) = 1.
"""
EPS = 1e-5
k = len(parts.shape) - len(whole.shape)
assert k >= 0
assert parts.shape[k:] == whole.shape
bc_whole = whole.expand(parts.shape) # new dims p_1 ... p_k are added to the front
distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)
whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
distance = (whole_norm - distance).clip(min=EPS)
    total = distance.sum(dim=tuple(range(k)), keepdim=True)
    return distance / total
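
# Usage sketch (illustrative only; the shapes below are arbitrary choices, not
# required by the API):
#
#     parts = torch.randn(3, 2, 5, 16)     # 3 parts, batch dims (2, 5), d = 16
#     whole = parts.sum(dim=0)             # the vector being decomposed
#     c = get_contributions(parts, whole)  # -> shape (3, 2, 5)
#     assert torch.allclose(c.sum(dim=0), torch.ones(2, 5))
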
@torch.no_grad()
@typechecked
def get_contributions_with_one_off_part(
parts: torch.Tensor,
one_off: torch.Tensor,
whole: torch.Tensor,
distance_norm: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Same as `get_contributions`, but with one additional part. That's useful because
    we always have the residual stream as one of the parts.
    See the `get_contributions` documentation for the `parts` and `whole` dimensions.
    The `one_off` tensor should have the same dimensions as `whole`.
Returns a pair consisting of
1. contributions tensor for the `parts`
2. contributions tensor for the `one_off` vector
"""
assert one_off.shape == whole.shape
k = len(parts.shape) - len(whole.shape)
assert k >= 0
# Flatten the p_ dimensions, get contributions for the list, unflatten.
flat = parts.flatten(start_dim=0, end_dim=k - 1)
flat = torch.cat([flat, one_off.unsqueeze(0)])
contributions = get_contributions(flat, whole, distance_norm)
parts_contributions, one_off_contributions = torch.split(
contributions, flat.shape[0] - 1
)
return (
parts_contributions.unflatten(0, parts.shape[0:k]),
one_off_contributions[0],
)
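
# Usage sketch (illustrative shapes; the "one off" part plays the role of the
# residual stream in the callers below):
#
#     parts = torch.randn(4, 2, 8)      # 4 parts, batch dim 2, d = 8
#     residual = torch.randn(2, 8)      # the extra part
#     whole = parts.sum(dim=0) + residual
#     c_parts, c_resid = get_contributions_with_one_off_part(parts, residual, whole)
#     # c_parts: (4, 2), c_resid: (2,); together they sum to 1 per batch element
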
@torch.no_grad()
@typechecked
def get_attention_contributions(
resid_pre: Float[torch.Tensor, "batch pos d_model"],
resid_mid: Float[torch.Tensor, "batch pos d_model"],
decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
distance_norm: int = 1,
) -> Tuple[
Float[torch.Tensor, "batch pos key_pos head"],
Float[torch.Tensor, "batch pos"],
]:
"""
Returns a pair of
- a tensor of contributions of each token via each head
- the contribution of the residual stream.
"""
# part dimensions | batch dimensions | vector dimension
# ----------------+------------------+-----------------
# key_pos, head | batch, pos | d_model
parts = einops.rearrange(
decomposed_attn,
"batch pos key_pos head d_model -> key_pos head batch pos d_model",
)
attn_contribution, residual_contribution = get_contributions_with_one_off_part(
parts, resid_pre, resid_mid, distance_norm
)
return (
einops.rearrange(
attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
),
residual_contribution,
)
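
# Usage sketch (toy numbers; in practice `decomposed_attn` would be the attention
# block output decomposed per key position and per head, as the type hints suggest):
#
#     batch, pos, head, d_model = 1, 3, 2, 8
#     resid_pre = torch.randn(batch, pos, d_model)
#     decomposed_attn = torch.randn(batch, pos, pos, head, d_model)
#     resid_mid = resid_pre + decomposed_attn.sum(dim=(2, 3))
#     c_attn, c_resid = get_attention_contributions(resid_pre, resid_mid, decomposed_attn)
#     # c_attn: (batch, pos, key_pos, head), c_resid: (batch, pos)
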
@torch.no_grad()
@typechecked
def get_mlp_contributions(
resid_mid: Float[torch.Tensor, "batch pos d_model"],
resid_post: Float[torch.Tensor, "batch pos d_model"],
mlp_out: Float[torch.Tensor, "batch pos d_model"],
distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
"""
Returns a pair of (mlp, residual) contributions for each sentence and token.
"""
contributions = get_contributions(
torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
)
return contributions[0], contributions[1]
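
# Usage sketch (illustrative shapes):
#
#     resid_mid = torch.randn(2, 4, 16)   # batch = 2, pos = 4, d_model = 16
#     mlp_out = torch.randn(2, 4, 16)
#     resid_post = resid_mid + mlp_out
#     c_mlp, c_resid = get_mlp_contributions(resid_mid, resid_post, mlp_out)
#     # both results have shape (batch, pos) and sum to 1 elementwise
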
@torch.no_grad()
@typechecked
def get_decomposed_mlp_contributions(
resid_mid: Float[torch.Tensor, "d_model"],
resid_post: Float[torch.Tensor, "d_model"],
decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
"""
    Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of
    the hidden layer and thus computes a contribution per neuron.
    Batch and token dimensions are omitted in order to save memory, but we may consider
    adding them later.
"""
neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
decomposed_mlp_out, resid_mid, resid_post, distance_norm
)
return neuron_contributions, residual_contribution.item()
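
# Usage sketch for a single (batch, token) position (illustrative sizes):
#
#     hidden, d_model = 32, 16
#     resid_mid = torch.randn(d_model)
#     decomposed_mlp_out = torch.randn(hidden, d_model)  # one output vector per neuron
#     resid_post = resid_mid + decomposed_mlp_out.sum(dim=0)
#     c_neurons, c_resid = get_decomposed_mlp_contributions(
#         resid_mid, resid_post, decomposed_mlp_out
#     )
#     # c_neurons: (hidden,), c_resid: a Python float
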
@torch.no_grad()
def apply_threshold_and_renormalize(
threshold: float,
c_blocks: torch.Tensor,
c_residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Thresholding mechanism used in the original graphs paper. After the threshold is
    applied, the remaining contributions are renormalized in order to sum to 1 for
    each representation.
    threshold: Contributions that do not exceed this value are zeroed out.
    c_blocks: Contributions of the blocks. Could be 1 block per representation, like
        ffn, or heads*tokens blocks in the case of attention. The shape of `c_residual`
        must be a prefix of the shape of this tensor. The remaining dimensions are for
        listing the blocks.
    c_residual: Contribution of the residual stream for each representation. This tensor
        should contain 1 element per representation, i.e., its dimensions are all batch
        dimensions.
"""
block_dims = len(c_blocks.shape)
resid_dims = len(c_residual.shape)
bound_dims = block_dims - resid_dims
assert bound_dims >= 0
assert c_blocks.shape[0:resid_dims] == c_residual.shape
c_blocks = c_blocks * (c_blocks > threshold)
c_residual = c_residual * (c_residual > threshold)
denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
return (
c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
c_residual / denom,
)
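
# Usage sketch (the 0.1 threshold and the numbers are arbitrary):
#
#     c_blocks = torch.tensor([[0.05, 0.60], [0.30, 0.30]])  # 2 representations, 2 blocks
#     c_residual = torch.tensor([0.35, 0.40])
#     new_blocks, new_resid = apply_threshold_and_renormalize(0.1, c_blocks, c_residual)
#     # the 0.05 entry is zeroed out; each row of new_blocks plus its residual sums to 1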