# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import networkx as nx
import pandas as pd
import plotly.express
import plotly.graph_objects as go
import streamlit as st
import streamlit_extras.row as st_row
import torch
from jaxtyping import Float
from networkx.classes.digraph import DiGraph
from torch.amp import autocast
from transformers import HfArgumentParser

import llm_transparency_tool.components
import llm_transparency_tool.routes.contributions as contributions
import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm
from llm_transparency_tool.routes.graph_node import NodeType
from llm_transparency_tool.server.graph_selection import (
    GraphSelection,
    UiGraphEdge,
    UiGraphNode,
)
from llm_transparency_tool.server.monitor import SystemMonitor
from llm_transparency_tool.server.styles import (
    RenderSettings,
    logits_color_map,
    margins_css,
    string_to_display,
)
from llm_transparency_tool.server.utils import (
    B0,
    get_contribution_graph,
    load_dataset,
    load_model,
    possible_devices,
    run_model_with_session_caching,
    st_placeholder,
)


# Cache the expensive path search. Graphs are hashed by identity because
# networkx graphs are not hashable.
@st.cache_resource(
    hash_funcs={
        nx.Graph: id,
        DiGraph: id,
    }
)
def cached_build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
):
    return llm_transparency_tool.routes.graph.build_paths_to_predictions(
        graph, n_layers, n_tokens, starting_tokens, threshold
    )


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id,
    }
)
def cached_run_inference_and_populate_state(
    stateless_model,
    sentences,
):
    stateful_model = stateless_model.copy()
    stateful_model.run(sentences)
    return stateful_model


@dataclass
class LlmViewerConfig:
    debug: bool = field(
        default=False,
        metadata={"help": "Show debugging information, like the time profile."},
    )
    preloaded_dataset_filename: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the text file to load the lines from."},
    )
    demo_mode: bool = field(
        default=False,
        metadata={"help": "Whether the app should be in the demo mode."},
    )
    allow_loading_dataset_files: bool = field(
        default=True,
        metadata={
            "help": "Whether the app should be able to load the dataset files "
            "on the server side."
        },
    )
    max_user_string_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Limit for the length of user-provided sentences (in characters), "
            "or None if there is no limit."
        },
    )
    models: Dict[str, str] = field(
        default_factory=dict,
        metadata={
            "help": "Locations of models which are stored locally. Dictionary: "
            "official HuggingFace name -> path to dir. If None is specified, the "
            "model will be downloaded from HuggingFace."
        },
    )
    default_model: str = field(
        default="",
        metadata={"help": "The model to load once the UI is started."},
    )


class App:
    _stateful_model: Optional[TransparentLlm] = None
    render_settings = RenderSettings()
    _graph: Optional[nx.Graph] = None
    _contribution_threshold: float = 0.0
    _renormalize_after_threshold: bool = False
    _normalize_before_unembedding: bool = True

    @property
    def stateful_model(self) -> TransparentLlm:
        return self._stateful_model

    def __init__(self, config: LlmViewerConfig):
        self._config = config
        st.set_page_config(layout="wide")
        st.markdown(margins_css, unsafe_allow_html=True)

    def _get_representation(
        self, node: Optional[UiGraphNode]
    ) -> Optional[Float[torch.Tensor, "d_model"]]:
        if node is None:
            return None
        fn = {
            NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
            NodeType.AFTER_FFN: self.stateful_model.residual_out,
            # FFN nodes are not part of the residual stream and are never
            # passed here.
            NodeType.FFN: None,
            NodeType.ORIGINAL: self.stateful_model.residual_in,
        }
        return fn[node.type](node.layer)[B0][node.token]

    def draw_model_info(self):
        info = self.stateful_model.model_info().__dict__
        df = pd.DataFrame(
            data=[str(x) for x in info.values()],
            index=info.keys(),
            columns=["Model parameter"],
        )
        st.dataframe(df, use_container_width=False)

    def draw_dataset_selection(self) -> Optional[str]:
        def update_dataset(filename: Optional[str]):
            dataset = load_dataset(filename) if filename is not None else []
            st.session_state["dataset"] = dataset
            st.session_state["dataset_file"] = filename

        if "dataset" not in st.session_state:
            update_dataset(self._config.preloaded_dataset_filename)

        if not self._config.demo_mode:
            if self._config.allow_loading_dataset_files:
                row_f = st_row.row([2, 1], vertical_align="bottom")
                filename = row_f.text_input(
                    "Dataset", value=st.session_state.dataset_file or ""
                )
                if row_f.button("Load"):
                    update_dataset(filename)
            row_s = st_row.row([2, 1], vertical_align="bottom")
            new_sentence = row_s.text_input("New sentence")
            new_sentence_added = False
            if row_s.button("Add"):
                max_len = self._config.max_user_string_length
                n = len(new_sentence)
                if max_len is None or n <= max_len:
                    st.session_state.dataset.append(new_sentence)
                    new_sentence_added = True
                    st.session_state.sentence_selector = new_sentence
                else:
                    st.warning(
                        f"Sentence length {n} is larger than "
                        f"the configured limit of {max_len}"
                    )

        sentences = st.session_state.dataset
        selection = st.selectbox(
            "Sentence",
            sentences,
            index=len(sentences) - 1,
            key="sentence_selector",
        )
        return selection

    def _unembed(
        self,
        representation: torch.Tensor,
    ) -> torch.Tensor:
        return self.stateful_model.unembed(
            representation, normalize=self._normalize_before_unembedding
        )

    def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
        tokens = self.stateful_model.tokens()[B0]
        n_tokens = tokens.shape[0]

        model_info = self.stateful_model.model_info()
        graphs = cached_build_paths_to_predictions(
            self._graph,
            model_info.n_layers,
            n_tokens,
            range(n_tokens),
            contribution_threshold,
        )

        return llm_transparency_tool.components.contribution_graph(
            model_info,
            self.stateful_model.tokens_to_strings(tokens),
            graphs,
            key=f"graph_{hash(self.sentence)}",
        )

    def draw_token_matrix(
        self,
        values: Float[torch.Tensor, "t t"],
        tokens: List[str],
        value_name: str,
        title: str,
    ):
        assert values.shape[0] == len(tokens)
        labels = {
            "x": "src",
            "y": "tgt",
            "color": value_name,
        }
        captions = [f"({i}){t}" for i, t in enumerate(tokens)]
        fig = plotly.express.imshow(
            values.cpu(),
            title=title,
            labels=labels,
            x=captions,
            y=captions,
            color_continuous_scale=self.render_settings.attention_color_map,
            aspect="equal",
        )
        fig.update_layout(
            autosize=True,
            margin=go.layout.Margin(
                l=50,  # left margin
                r=0,  # right margin
                b=100,  # bottom margin
                t=100,  # top margin
            ),
        )
        fig.update_xaxes(tickmode="linear")
        fig.update_yaxes(tickmode="linear")
        fig.update_coloraxes(showscale=False)
        st.plotly_chart(fig, use_container_width=True, theme=None)

    def draw_attn_info(
        self, edge: UiGraphEdge, container_attention_map
    ) -> Optional[int]:
        """
        Returns: the index of the selected head.
        """
        n_heads = self.stateful_model.model_info().n_heads

        layer = edge.target.layer

        head_contrib, _ = contributions.get_attention_contributions(
            resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0),
            resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0),
            decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
        )

        # [batch pos key_pos head] -> [head]
        flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
        assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"

        selected_head = llm_transparency_tool.components.selector(
            items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
            indices=range(-1, n_heads),
            temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
            preselected_index=flat_contrib.argmax().item(),
            key=f"head_selector_layer_{layer}",
        )
        # "All" (-1) or no selection falls back to the head with the largest
        # contribution.
        if selected_head == -1 or selected_head is None:
            selected_head = flat_contrib.argmax().item()

        # Draw attention matrix and contributions for the selected head.
        if selected_head is not None:
            tokens = [
                string_to_display(s)
                for s in self.stateful_model.tokens_to_strings(
                    self.stateful_model.tokens()[B0]
                )
            ]
            with container_attention_map:
                attn_container, contrib_container = st.columns([1, 1])
                with attn_container:
                    attn = self.stateful_model.attention_matrix(
                        B0, layer, selected_head
                    )
                    self.draw_token_matrix(
                        attn,
                        tokens,
                        "attention",
                        f"Attention map L{layer} H{selected_head}",
                    )
                with contrib_container:
                    contrib = head_contrib[B0, :, :, selected_head]
                    self.draw_token_matrix(
                        contrib,
                        tokens,
                        "contribution",
                        f"Contribution map L{layer} H{selected_head}",
                    )

        return selected_head

    def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
        """
        Returns: the index of the selected neuron.
""" resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token] resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token] decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token) c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn) top_values, top_i = c_ffn.sort(descending=True) n = min(self.render_settings.n_top_neurons, c_ffn.shape[0]) top_neurons = top_i[0:n].tolist() selected_neuron = llm_transparency_tool.components.selector( items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)], indices=range(-1, n), temperatures=[0.0] + top_values[0:n].tolist(), preselected_index=-1, key="neuron_selector", ) if selected_neuron is None: selected_neuron = -1 selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron] return selected_neuron def _draw_token_table( self, n_top: int, n_bottom: int, representation: torch.Tensor, predecessor: Optional[torch.Tensor] = None, ): n_total = n_top + n_bottom logits = self._unembed(representation) n_vocab = logits.shape[0] scores, indices = torch.topk(logits, n_top, largest=True) positions = list(range(n_top)) if n_bottom > 0: low_scores, low_indices = torch.topk(logits, n_bottom, largest=False) indices = torch.cat((indices, low_indices.flip(0))) scores = torch.cat((scores, low_scores.flip(0))) positions += range(n_vocab - n_bottom, n_vocab) tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)] if predecessor is not None: pre_logits = self._unembed(predecessor) _, sorted_pre_indices = pre_logits.sort(descending=True) pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())} old_positions = [pre_indices_dict[i] for i in indices.tolist()] def pos_gain_string(pos, old_pos): if pos == old_pos: return "" sign = "↓" if pos > old_pos else "↑" return f"({sign}{abs(pos - old_pos)})" position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)] else: position_strings = [str(pos) for pos in positions] def pos_gain_color(s): color = "black" if isinstance(s, str): if "↓" in s: color = "red" if "↑" in s: color = "green" return f"color: {color}" top_df = pd.DataFrame( data=zip(position_strings, tokens, scores.tolist()), columns=["Pos", "Token", "Score"], ) st.dataframe( top_df.style.map(pos_gain_color) .background_gradient( axis=0, cmap=logits_color_map(positive_and_negative=n_bottom > 0), ) .format(precision=3), hide_index=True, height=self.render_settings.table_cell_height * (n_total + 1), use_container_width=True, ) def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None: st.caption(block_name) self._draw_token_table( self.render_settings.n_promoted_tokens, self.render_settings.n_suppressed_tokens, representation, None, ) def draw_top_tokens( self, node: UiGraphNode, container_top_tokens, container_token_dynamics, ) -> None: pre_node = node.get_residual_predecessor() if pre_node is None: return representation = self._get_representation(node) predecessor = self._get_representation(pre_node) with container_top_tokens: st.caption(node.get_name()) self._draw_token_table( self.render_settings.n_top_tokens, 0, representation, predecessor, ) if container_token_dynamics is not None: with container_token_dynamics: self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name()) def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]): block_name = 
    def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
        block_name = node.get_head_name(head)
        block_output = (
            self.stateful_model.attention_output_per_head(
                B0, node.layer, node.token, head
            )
            if head is not None
            else self.stateful_model.attention_output(B0, node.layer, node.token)
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
        block_name = node.get_neuron_name(neuron)
        block_output = (
            self.stateful_model.neuron_output(node.layer, neuron)
            if neuron is not None
            else self.stateful_model.ffn_out(node.layer)[B0][node.token]
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
        """
        Draw the precision switch (fp16/bf16/fp32) and AMP control.

        return: The selected precision and whether AMP should be enabled.
        """
        if device == "cpu":
            dtype = torch.float32
        else:
            dtype = st.selectbox(
                "Precision",
                [torch.float16, torch.bfloat16, torch.float32],
                index=0,
            )
        amp_enabled = dtype != torch.float32
        return dtype, amp_enabled

    def draw_controls(self):
        with st.sidebar.expander("Model", expanded=True):
            list_of_devices = possible_devices()
            if len(list_of_devices) > 1:
                self.device = st.selectbox(
                    "Device",
                    list_of_devices,
                    index=0,
                )
            else:
                self.device = list_of_devices[0]

            self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)

            model_list = list(self._config.models)
            # Fall back to the first model if the configured default is not in
            # the list.
            default_choice = (
                model_list.index(self._config.default_model)
                if self._config.default_model in model_list
                else 0
            )

            self.model_name = st.selectbox(
                "Model",
                model_list,
                index=default_choice,
            )

            if self.model_name:
                self._stateful_model = load_model(
                    model_name=self.model_name,
                    _model_path=self._config.models[self.model_name],
                    _device=self.device,
                    _dtype=self.dtype,
                )
                self.model_key = self.model_name  # TODO maybe something else?
                self.draw_model_info()

        self.sentence = self.draw_dataset_selection()

        with st.sidebar.expander("Graph", expanded=True):
            self._contribution_threshold = st.slider(
                min_value=0.01,
                max_value=0.1,
                step=0.01,
                value=0.04,
                format=r"%.3f",
                label="Contribution threshold",
            )
            self._renormalize_after_threshold = st.checkbox(
                "Renormalize after threshold", value=True
            )
            self._normalize_before_unembedding = st.checkbox(
                "Normalize before unembedding", value=True
            )

    def run_inference(self):
        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._stateful_model = cached_run_inference_and_populate_state(
                self.stateful_model, [self.sentence]
            )

        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._graph = get_contribution_graph(
                self.stateful_model,
                self.model_key,
                self.stateful_model.tokens()[B0].tolist(),
                (
                    self._contribution_threshold
                    if self._renormalize_after_threshold
                    else 0.0
                ),
            )

    def draw_graph_and_selection(self) -> None:
        (
            container_graph,
            container_tokens,
        ) = st.columns(self.render_settings.column_proportions)

        container_graph_left, container_graph_right = container_graph.columns([5, 1])
        container_graph_left.write('##### Graph')
        heads_placeholder = container_graph_right.empty()
        heads_placeholder.write('##### Blocks')
        container_graph_right_used = False

        container_top_tokens, container_token_dynamics = container_tokens.columns(
            [1, 1]
        )
        container_top_tokens.write('##### Top Tokens')
        container_top_tokens_used = False
        container_token_dynamics.write('##### Promoted Tokens')
        container_token_dynamics_used = False

        try:
            if self.sentence is None:
                return

            with container_graph_left:
                # If contributions are renormalized after thresholding, the
                # graph was already built with the threshold applied.
                selection = self.draw_graph(
                    self._contribution_threshold
                    if not self._renormalize_after_threshold
                    else 0.0
                )
            if selection is None:
                return

            node = selection.node
            edge = selection.edge

            if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write('##### Heads')
                    head = self.draw_attn_info(edge, container_graph)
                with container_token_dynamics:
                    self.draw_attention_dynamics(edge.target, head)
                    container_token_dynamics_used = True
            elif node is not None and node.type == NodeType.FFN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write('##### Neurons')
                    neuron = self.draw_ffn_info(node)
                with container_token_dynamics:
                    self.draw_ffn_dynamics(node, neuron)
                    container_token_dynamics_used = True

            if node is not None and node.is_in_residual_stream():
                self.draw_top_tokens(
                    node,
                    container_top_tokens,
                    container_token_dynamics
                    if not container_token_dynamics_used
                    else None,
                )
                container_top_tokens_used = True
                container_token_dynamics_used = True
        finally:
            if not container_graph_right_used:
                st_placeholder(
                    'Click on an edge to see head contributions. \n\n'
                    'Or click on FFN to see individual neuron contributions.',
                    container_graph_right,
                    height=1100,
                )
            if not container_top_tokens_used:
                st_placeholder(
                    'Select a node from residual stream to see its top tokens.',
                    container_top_tokens,
                    height=1100,
                )
            if not container_token_dynamics_used:
                st_placeholder(
                    'Select a node to see its promoted tokens.',
                    container_token_dynamics,
                    height=1100,
                )

    def run(self):
        with st.sidebar.expander("About", expanded=True):
            if self._config.demo_mode:
                st.caption(
                    """
                    The app is deployed in Demo Mode, thus only predefined
                    models and inputs are available.\n
                    You can still install the app locally and use your own
                    models and inputs.\n
                    See https://github.com/facebookresearch/llm-transparency-tool
                    for more information.
                    """
                )

        self.draw_controls()

        if not self.model_name:
            st.warning("No model selected")
            st.stop()

        if self.sentence is None:
            st.warning("No sentence selected")
        else:
            with torch.inference_mode():
                self.run_inference()
            self.draw_graph_and_selection()


if __name__ == "__main__":
    top_parser = argparse.ArgumentParser()
    top_parser.add_argument("config_file")
    args = top_parser.parse_args()

    parser = HfArgumentParser([LlmViewerConfig])
    config = parser.parse_json_file(args.config_file)[0]

    with SystemMonitor(config.debug) as prof:
        app = App(config)
        app.run()
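
# ---------------------------------------------------------------------------
# Example configuration: a minimal sketch, not a file shipped with this
# module. The field names come from LlmViewerConfig above; the model name and
# dataset path are illustrative assumptions, and a null model path means the
# weights are downloaded from HuggingFace (see the "models" help text).
#
#     {
#         "debug": false,
#         "demo_mode": false,
#         "allow_loading_dataset_files": true,
#         "preloaded_dataset_filename": "dataset.txt",
#         "max_user_string_length": null,
#         "models": {"gpt2": null},
#         "default_model": "gpt2"
#     }
#
# Since this is a Streamlit app, the config path is passed after Streamlit's
# `--` argument separator, e.g.:
#
#     streamlit run llm_transparency_tool/server/app.py -- config.json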