File size: 6,298 Bytes
ce00289 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import networkx as nx
import torch
import llm_transparency_tool.routes.contributions as contributions
from llm_transparency_tool.models.transparent_llm import TransparentLlm
class GraphBuilder:
    """
    Incrementally assembles the contributions graph, one edge at a time. The result
    is a networkx digraph available through the `graph` attribute. Node naming
    scheme:
    - X0_<token>: the original token.
    - A<layer>_<token>: the residual stream after attention at the given layer for the
      given token.
    - M<layer>_<token>: the ffn block.
    - I<layer>_<token>: the residual stream after the ffn block.
    """

    def __init__(self, n_layers: int, n_tokens: int):
        self._n_layers = n_layers
        self._n_tokens = n_tokens
        self.graph = nx.DiGraph()
        # Pre-create every node up front so edge insertion never has to care
        # about node existence. Insertion order matches the naming scheme above.
        for layer_i in range(n_layers):
            for token_i in range(n_tokens):
                for prefix in ("A", "I", "M"):
                    self.graph.add_node(f"{prefix}{layer_i}_{token_i}")
        for token_i in range(n_tokens):
            self.graph.add_node(f"X0_{token_i}")

    def get_output_node(self, token: int):
        # The model output for a token is the residual stream after the last ffn.
        return f"I{self._n_layers - 1}_{token}"

    def _add_edge(self, u: str, v: str, weight: float):
        # TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
        # attention from the current token and the residual edge. Ideally these need to
        # be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
        # but when we try to traverse it, we face some NetworkX issue with EDGE_OK
        # receiving 3 arguments instead of 2.
        if not self.graph.has_edge(u, v):
            self.graph.add_edge(u, v, weight=weight)
        else:
            self.graph[u][v]["weight"] += weight

    def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
        # Layer 0 attends to the raw token embeddings rather than a previous layer.
        source = f"I{layer - 1}_{token_from}" if layer > 0 else f"X0_{token_from}"
        self._add_edge(source, f"A{layer}_{token_to}", w)

    def add_residual_to_attn(self, layer: int, token: int, w: float):
        # Residual (skip) connection bypassing the attention block.
        source = f"I{layer - 1}_{token}" if layer > 0 else f"X0_{token}"
        self._add_edge(source, f"A{layer}_{token}", w)

    def add_ffn_edge(self, layer: int, token: int, w: float):
        # The ffn node sits between the post-attention and post-ffn residual nodes,
        # so one logical contribution becomes two graph edges of the same weight.
        attn_node = f"A{layer}_{token}"
        ffn_node = f"M{layer}_{token}"
        out_node = f"I{layer}_{token}"
        self._add_edge(attn_node, ffn_node, w)
        self._add_edge(ffn_node, out_node, w)

    def add_residual_to_ffn(self, layer: int, token: int, w: float):
        # Residual (skip) connection bypassing the ffn block.
        self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)
@torch.no_grad()
def build_full_graph(
    model: TransparentLlm,
    batch_i: int = 0,
    renormalizing_threshold: Optional[float] = None,
) -> nx.Graph:
    """
    Build the contribution graph for all blocks of the model and all tokens.

    model: The transparent llm which already did the inference.
    batch_i: Which sentence to use from the batch that was given to the model.
    renormalizing_threshold: If specified, will apply renormalizing thresholding to the
        contributions. All contributions below the threshold will be erased and the
        rest will be renormalized.
    """
    n_layers = model.model_info().n_layers
    n_tokens = model.tokens()[batch_i].shape[0]
    builder = GraphBuilder(n_layers, n_tokens)
    for layer in range(n_layers):
        # The residual stream between attention and ffn is used by both
        # decompositions below; fetch it once per layer.
        resid_mid = model.residual_after_attn(layer)[batch_i].unsqueeze(0)

        c_attn, c_resid_attn = contributions.get_attention_contributions(
            resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0),
            resid_mid=resid_mid,
            decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_attn, c_resid_attn
            )
        # The inputs were sliced to `batch_i` and unsqueezed to a batch of size 1,
        # so the contribution tensors must be indexed with 0 here — indexing with
        # `batch_i` would be out of range for any batch_i > 0.
        for token_from in range(n_tokens):
            for token_to in range(n_tokens):
                # Sum attention contributions over heads.
                c = c_attn[0, token_to, token_from].sum().item()
                builder.add_attention_edge(layer, token_from, token_to, c)
        for token in range(n_tokens):
            builder.add_residual_to_attn(
                layer, token, c_resid_attn[0, token].item()
            )

        c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
            resid_mid=resid_mid,
            resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
            mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_ffn, c_resid_ffn
            )
        for token in range(n_tokens):
            builder.add_ffn_edge(layer, token, c_ffn[0, token].item())
            builder.add_residual_to_ffn(
                layer, token, c_resid_ffn[0, token].item()
            )
    return builder.graph
def build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
) -> List[nx.Graph]:
    """
    Given the full graph, this function returns only the trees leading to the specified
    tokens. Edges with weight below `threshold` will be ignored.

    graph: The full contributions graph, as built by `build_full_graph`.
    n_layers: Number of layers in the model.
    n_tokens: Number of tokens in the sentence.
    starting_tokens: Token positions whose output nodes become the roots of the
        returned trees. Each must be in [0, n_tokens).
    threshold: An edge is traversed only if its weight is strictly above this value.
    """
    rgraph = graph.reverse()
    # A view (no copy) of the reversed graph that hides sub-threshold edges.
    search_graph = nx.subgraph_view(
        rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
    )
    result = []
    for start in starting_tokens:
        assert 0 <= start < n_tokens
        # The output node for a token is the residual stream after the last ffn.
        # Format the name directly instead of constructing a throwaway GraphBuilder
        # (which would allocate O(n_layers * n_tokens) graph nodes just for this).
        output_node = f"I{n_layers - 1}_{start}"
        edges = nx.edge_dfs(search_graph, source=output_node)
        tree = search_graph.edge_subgraph(edges)
        # Reverse the edges because the dfs was going from upper layer downwards.
        result.append(tree.reverse())
    return result
|