# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import networkx as nx
import torch
import llm_transparency_tool.routes.contributions as contributions
from llm_transparency_tool.models.transparent_llm import TransparentLlm
class GraphBuilder:
"""
Constructs the contributions graph with edges given one by one. The resulting graph
is a networkx graph that can be accessed via the `graph` field. It contains the
following types of nodes:
- X0_<token>: the original token.
- A<layer>_<token>: the residual stream after attention at the given layer for the
given token.
- M<layer>_<token>: the ffn block.
- I<layer>_<token>: the residual stream after the ffn block.
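
    Example of the intended shape for a single token t (illustrative):
    X0_t -> A0_t -> I0_t -> A1_t -> ... -> I{n_layers-1}_t, where each M<layer>_t
    sits on a detour A<layer>_t -> M<layer>_t -> I<layer>_t, and attention edges
    additionally connect I<layer-1>_<from> (or X0_<from>) to A<layer>_<to>.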
"""
def __init__(self, n_layers: int, n_tokens: int):
self._n_layers = n_layers
self._n_tokens = n_tokens
self.graph = nx.DiGraph()
for layer in range(n_layers):
for token in range(n_tokens):
self.graph.add_node(f"A{layer}_{token}")
self.graph.add_node(f"I{layer}_{token}")
self.graph.add_node(f"M{layer}_{token}")
for token in range(n_tokens):
self.graph.add_node(f"X0_{token}")
def get_output_node(self, token: int):
return f"I{self._n_layers - 1}_{token}"
def _add_edge(self, u: str, v: str, weight: float):
# TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
# attention from the current token and the residual edge. Ideally these need to
# be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
# but when we try to traverse it, we face some NetworkX issue with EDGE_OK
# receiving 3 arguments instead of 2.
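        # For example, at layer 0 both add_attention_edge(0, t, t, w1) and
        # add_residual_to_attn(0, t, w2) produce the edge X0_t -> A0_t, so the
        # weights w1 and w2 end up summed here.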
if self.graph.has_edge(u, v):
self.graph[u][v]["weight"] += weight
else:
self.graph.add_edge(u, v, weight=weight)
def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
self._add_edge(
f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}",
f"A{layer}_{token_to}",
w,
)
def add_residual_to_attn(self, layer: int, token: int, w: float):
self._add_edge(
f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}",
f"A{layer}_{token}",
w,
)
def add_ffn_edge(self, layer: int, token: int, w: float):
self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)
def add_residual_to_ffn(self, layer: int, token: int, w: float):
self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)
@torch.no_grad()
def build_full_graph(
model: TransparentLlm,
batch_i: int = 0,
renormalizing_threshold: Optional[float] = None,
) -> nx.DiGraph:
"""
Build the contribution graph for all blocks of the model and all tokens.
    model: The transparent LLM that has already run inference.
    batch_i: Which sentence to use from the batch that was given to the model.
    renormalizing_threshold: If specified, apply thresholding and renormalization to
        the contributions: all contributions below the threshold are erased and the
        rest are renormalized.
"""
n_layers = model.model_info().n_layers
n_tokens = model.tokens()[batch_i].shape[0]
builder = GraphBuilder(n_layers, n_tokens)
for layer in range(n_layers):
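        # Decompose the attention block's output into per-(source token, head)
        # contributions plus the share carried by the residual (skip) connection.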
c_attn, c_resid_attn = contributions.get_attention_contributions(
resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0),
resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
)
if renormalizing_threshold is not None:
c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
renormalizing_threshold, c_attn, c_resid_attn
)
for token_from in range(n_tokens):
for token_to in range(n_tokens):
# Sum attention contributions over heads.
                # The batch dimension of c_attn is 1 because of the unsqueeze above.
                c = c_attn[0, token_to, token_from].sum().item()
builder.add_attention_edge(layer, token_from, token_to, c)
for token in range(n_tokens):
builder.add_residual_to_attn(
                layer, token, c_resid_attn[0, token].item()
)
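        # Decompose the post-FFN residual into the FFN output contribution and the
        # share carried by the residual (skip) connection.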
c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
)
if renormalizing_threshold is not None:
c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
renormalizing_threshold, c_ffn, c_resid_ffn
)
for token in range(n_tokens):
            builder.add_ffn_edge(layer, token, c_ffn[0, token].item())
            builder.add_residual_to_ffn(
                layer, token, c_resid_ffn[0, token].item()
)
return builder.graph
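

# Usage sketch (illustrative; assumes `model` is a TransparentLlm implementation on
# which inference has already been run):
#
#   graph = build_full_graph(model, batch_i=0, renormalizing_threshold=0.01)
#   total_weight = sum(w for _, _, w in graph.edges.data("weight"))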
def build_paths_to_predictions(
    graph: nx.DiGraph,
n_layers: int,
n_tokens: int,
starting_tokens: List[int],
threshold: float,
) -> List[nx.DiGraph]:
"""
    Given the full graph, return only the trees of edges that lead to the specified
    tokens' output nodes. Edges with weight at or below `threshold` are ignored.
"""
builder = GraphBuilder(n_layers, n_tokens)
rgraph = graph.reverse()
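    # Searching the reversed graph lets the DFS start from the top-layer output node
    # and walk back toward the input tokens.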
search_graph = nx.subgraph_view(
rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
)
result = []
for start in starting_tokens:
        assert 0 <= start < n_tokens
edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start))
tree = search_graph.edge_subgraph(edges)
        # Reverse the edges back because the DFS walked from the upper layers down.
result.append(tree.reverse())
return result
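

# Usage sketch (illustrative; `graph`, `n_layers` and `n_tokens` as in the
# build_full_graph example above):
#
#   trees = build_paths_to_predictions(
#       graph, n_layers, n_tokens, starting_tokens=[n_tokens - 1], threshold=0.05
#   )
#   # trees[0] is the tree of contributions ending at the last token's output node.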