# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import networkx as nx
import streamlit as st
import torch

import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm

GPU = "gpu"
CPU = "cpu"

# This constant expresses the idea that batch_id = 0 while being more readable
# than a bare 0.
B0 = 0


def possible_devices() -> List[str]:
    """List the devices the model can run on, GPU first when available."""
    devices = []
    if torch.cuda.is_available():
        devices.append(GPU)
    devices.append(CPU)
    return devices


def load_dataset(filename: str) -> List[str]:
    """Read a dataset as a list of sentences, one sentence per line."""
    with open(filename) as f:
        dataset = [s.strip("\n") for s in f.readlines()]
    print(f"Loaded {len(dataset)} sentences from {filename}")
    return dataset


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id
    }
)
def load_model(
    model_name: str,
    _device: str,
    _model_path: Optional[str] = None,
    _dtype: torch.dtype = torch.float32,
) -> TransparentLlm:
    """
    Load the model and return it. Only `model_name` participates in the
    Streamlit cache key; the remaining arguments start with `_` so that
    Streamlit skips them when hashing. Note that `_model_path` is accepted
    but not used in this function body.
    """
    assert _device in possible_devices()

    causal_lm = None
    tokenizer = None

    tl_lm = TransformerLensTransparentLlm(
        model_name=model_name,
        hf_model=causal_lm,
        tokenizer=tokenizer,
        device=_device,
        dtype=_dtype,
    )

    return tl_lm


def run_model(model: TransparentLlm, sentence: str) -> None:
    print(f"Running inference for '{sentence}'")
    model.run([sentence])


def load_model_with_session_caching(
    **kwargs,
) -> TransparentLlm:
    """Session-level convenience wrapper around `load_model`."""
    return load_model(**kwargs)


def run_model_with_session_caching(
    _model: TransparentLlm,
    model_key: str,
    sentence: str,
):
    """
    Run the model on the sentence, skipping the run if the same model and
    sentence were already processed during this session.
    """
    LAST_RUN_MODEL_KEY = "last_run_model_key"
    LAST_RUN_SENTENCE = "last_run_sentence"
    state = st.session_state

    if (
        state.get(LAST_RUN_MODEL_KEY, None) == model_key
        and state.get(LAST_RUN_SENTENCE, None) == sentence
    ):
        return

    run_model(_model, sentence)
    state[LAST_RUN_MODEL_KEY] = model_key
    state[LAST_RUN_SENTENCE] = sentence


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id
    }
)
def get_contribution_graph(
    _model: TransparentLlm,
    model_key: str,
    tokens: List[str],
    threshold: float,
) -> nx.Graph:
    """
    The `model_key` and `tokens` are used only for caching. The model itself is
    not hashed, hence the `_` in the beginning.
    """
    return llm_transparency_tool.routes.graph.build_full_graph(
        _model,
        B0,
        threshold,
    )


def st_placeholder(
    text: str,
    container=st,
    border: bool = True,
    height: Optional[int] = 500,
):
    """
    Render `text` inside a bordered, fixed-height container and return the
    `st.empty()` handle so the content can be overwritten later.
    """
    empty = container.empty()
    empty.container(border=border, height=height).write(text, unsafe_allow_html=True)
    return empty
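

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a hedged illustration of how
# the helpers above are meant to compose inside a Streamlit app. The model
# name, sentence, threshold, and `tokens` variable below are hypothetical
# example values; `tokens` is assumed to come from the app's own tokenization
# of the sentence.
#
#     devices = possible_devices()          # e.g. ["gpu", "cpu"] or ["cpu"]
#     model = load_model(
#         model_name="gpt2",                # hypothetical model choice
#         _device=devices[0],
#     )
#     run_model_with_session_caching(
#         _model=model,
#         model_key="gpt2",                 # any string identifying the model
#         sentence="The quick brown fox",
#     )
#     graph = get_contribution_graph(
#         model,
#         model_key="gpt2",
#         tokens=tokens,                    # sentence tokens, as strings
#         threshold=0.04,                   # hypothetical contribution cutoff
#     )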