# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import einops
import torch
from jaxtyping import Float
from typeguard import typechecked


@torch.no_grad()
@typechecked
def get_contributions(
    parts: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> torch.Tensor:
    """
    Compute contributions of the `parts` vectors into the `whole` vector.

    Shapes of the tensors are as follows:
    parts:  p_1 ... p_k, v_1 ... v_n, d
    whole:               v_1 ... v_n, d
    result: p_1 ... p_k, v_1 ... v_n

    Here:
    * `p_1 ... p_k`: dimensions enumerating the parts,
    * `v_1 ... v_n`: dimensions listing the independent cases (batching),
    * `d`: the dimension along which the distances are computed.

    The resulting contributions will be normalized so that
    for each v_: sum(over p_ of result(p_, v_)) = 1.
    """
    EPS = 1e-5

    # k is the number of leading dimensions that enumerate the parts (p_1 ... p_k).
    k = len(parts.shape) - len(whole.shape)
    assert k >= 0
    assert parts.shape[k:] == whole.shape
    bc_whole = whole.expand(parts.shape)  # new dims p_1 ... p_k are added to the front

    distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)

    # Raw score of each part: the norm of `whole` minus the part's distance to `whole`,
    # clipped from below so that every part receives a small positive contribution.
    whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
    distance = (whole_norm - distance).clip(min=EPS)

    # Normalize over the part dimensions so the contributions sum to 1 per case.
    total = distance.sum(dim=tuple(range(k)), keepdim=True)

    return distance / total
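
# A minimal usage sketch for `get_contributions` (illustrative only; the shapes and
# values below are arbitrary, not taken from any model):
#
#     parts = torch.randn(3, 2, 4)                # 3 parts, 2 batched cases, d = 4
#     whole = parts.sum(dim=0)                    # shape (2, 4)
#     contrib = get_contributions(parts, whole)   # shape (3, 2)
#     contrib.sum(dim=0)                          # ~= torch.ones(2)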


@torch.no_grad()
@typechecked
def get_contributions_with_one_off_part(
    parts: torch.Tensor,
    one_off: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Same as computing the contributions, but there is one additional part. That's useful
    because we always have the residual stream as one of the parts.

    See `get_contributions` documentation about `parts` and `whole` dimensions. The
    `one_off` should have the same dimensions as `whole`.

    Returns a pair consisting of
    1. contributions tensor for the `parts`
    2. contributions tensor for the `one_off` vector
    """
    assert one_off.shape == whole.shape

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0

    # Flatten the p_ dimensions, get contributions for the list, unflatten.
    flat = parts.flatten(start_dim=0, end_dim=k - 1)
    flat = torch.cat([flat, one_off.unsqueeze(0)])
    contributions = get_contributions(flat, whole, distance_norm)
    # The trailing chunk of size 1 corresponds to the appended `one_off` part.
    parts_contributions, one_off_contributions = torch.split(
        contributions, flat.shape[0] - 1
    )
    return (
        parts_contributions.unflatten(0, parts.shape[0:k]),
        one_off_contributions[0],
    )
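
# A minimal usage sketch for `get_contributions_with_one_off_part` (illustrative only;
# shapes are arbitrary):
#
#     parts = torch.randn(5, 2, 8)                # 5 parts, 2 batched cases, d = 8
#     residual = torch.randn(2, 8)                # the one-off part
#     whole = parts.sum(dim=0) + residual
#     c_parts, c_resid = get_contributions_with_one_off_part(parts, residual, whole)
#     c_parts.shape, c_resid.shape                # (5, 2) and (2,)
#     c_parts.sum(dim=0) + c_resid                # ~= torch.ones(2)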


@torch.no_grad()
@typechecked
def get_attention_contributions(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
    distance_norm: int = 1,
) -> Tuple[
    Float[torch.Tensor, "batch pos key_pos head"],
    Float[torch.Tensor, "batch pos"],
]:
    """
    Returns a pair of
    - a tensor of contributions of each token via each head
    - the contribution of the residual stream.
    """

    # part dimensions | batch dimensions | vector dimension
    # ----------------+------------------+-----------------
    # key_pos, head   | batch, pos       | d_model
    parts = einops.rearrange(
        decomposed_attn,
        "batch pos key_pos head d_model -> key_pos head batch pos d_model",
    )
    attn_contribution, residual_contribution = get_contributions_with_one_off_part(
        parts, resid_pre, resid_mid, distance_norm
    )
    return (
        einops.rearrange(
            attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
        ),
        residual_contribution,
    )


@torch.no_grad()
@typechecked
def get_mlp_contributions(
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    resid_post: Float[torch.Tensor, "batch pos d_model"],
    mlp_out: Float[torch.Tensor, "batch pos d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
    """
    Returns a pair of (mlp, residual) contributions for each sentence and token.
    """

    contributions = get_contributions(
        torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
    )
    return contributions[0], contributions[1]


@torch.no_grad()
@typechecked
def get_decomposed_mlp_contributions(
    resid_mid: Float[torch.Tensor, "d_model"],
    resid_post: Float[torch.Tensor, "d_model"],
    decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
    """
    Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of
    the hidden layer and thus computes a contribution per neuron.

    Doesn't include batch and token dimensions in order to save memory, but we may
    consider adding them.
    """

    neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
        decomposed_mlp_out, resid_mid, resid_post, distance_norm
    )
    return neuron_contributions, residual_contribution.item()
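
# A minimal usage sketch for `get_decomposed_mlp_contributions` (illustrative only;
# sizes are arbitrary):
#
#     decomposed = torch.randn(32, 8)             # 32 hidden neurons, d_model = 8
#     resid_mid = torch.randn(8)
#     resid_post = resid_mid + decomposed.sum(dim=0)
#     per_neuron, c_resid = get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed)
#     per_neuron.sum().item() + c_resid           # ~= 1.0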


@torch.no_grad()
def apply_threshold_and_renormalize(
    threshold: float,
    c_blocks: torch.Tensor,
    c_residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Thresholding mechanism used in the original graphs paper. After the threshold is
    applied, the remaining contributions are renormalized in order to sum up to 1 for
    each representation.

    threshold: Contributions that do not exceed this value are zeroed out.
    c_blocks: Contributions of the blocks. Could be 1 block per representation, like
        the FFN, or heads*tokens blocks in the case of attention. The shape of
        `c_residual` must be a prefix of the shape of this tensor. The remaining
        dimensions are for listing the blocks.
    c_residual: Contribution of the residual stream for each representation. This tensor
        should contain 1 element per representation, i.e., its dimensions are all batch
        dimensions.
    """

    block_dims = len(c_blocks.shape)
    resid_dims = len(c_residual.shape)
    bound_dims = block_dims - resid_dims
    assert bound_dims >= 0
    assert c_blocks.shape[0:resid_dims] == c_residual.shape

    # Zero out contributions that do not exceed the threshold.
    c_blocks = c_blocks * (c_blocks > threshold)
    c_residual = c_residual * (c_residual > threshold)

    # Renormalize the surviving contributions so that they sum to 1 per representation.
    denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
    return (
        c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
        c_residual / denom,
    )
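

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the shapes, the threshold, and the random
    # inputs below are arbitrary and not tied to any particular model. It decomposes
    # an attention block output into per-(key position, head) contributions, then
    # thresholds and renormalizes them.
    batch, pos, key_pos, head, d_model = 2, 3, 3, 4, 8
    decomposed_attn = torch.randn(batch, pos, key_pos, head, d_model)
    resid_pre = torch.randn(batch, pos, d_model)
    # Make the decomposition exact: resid_mid = resid_pre + sum of all head outputs.
    resid_mid = resid_pre + decomposed_attn.sum(dim=(2, 3))

    c_attn, c_resid = get_attention_contributions(resid_pre, resid_mid, decomposed_attn)
    total = c_attn.sum(dim=(2, 3)) + c_resid
    assert torch.allclose(total, torch.ones(batch, pos), atol=1e-5)

    c_attn_t, c_resid_t = apply_threshold_and_renormalize(0.05, c_attn, c_resid)
    total_t = c_attn_t.sum(dim=(2, 3)) + c_resid_t
    assert torch.allclose(total_t, torch.ones(batch, pos), atol=1e-5)
    print("contribution sanity checks passed")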