import os

import pandas as pd
import torch
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformer_lens import HookedTransformer

from activation_additions.prompt_utils import get_x_vector
from activation_additions.completion_utils import gen_using_activation_additions

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Set NO_REFUSAL=1 to disable the steering-based refusal (useful for comparing against the baseline).
NO_REFUSAL = os.getenv("NO_REFUSAL") == "1"

def load_models() -> tuple[AutoModelForCausalLM, HookedSAETransformer, SAE]:
    """Load Llama-3-8B-Instruct, wrap it in a HookedSAETransformer, and load the residual-stream SAE."""
    hf_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        device_map="auto" if DEVICE == "cuda" else DEVICE,
        torch_dtype="float16",
    )
    model = HookedSAETransformer.from_pretrained_no_processing(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        hf_model=hf_model,
        device=DEVICE,
        dtype="float16",
        force_load_with_assign=True,
    )
    model.eval()

    # SAE trained on the residual stream after block 25.
    sae_id = "blocks.25.hook_resid_post"
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="Juliushanhanhan/llama-3-8b-it-res",
        sae_id=sae_id,
        device=DEVICE,
    )
    # I suspect we need to return the HF model as well to avoid its memory being deallocated.
    return hf_model, model, sae

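# Quick sanity check after loading (a sketch; assumes the sae_lens SAE config exposes
# `hook_name` and `d_sae` — adjust to whatever fields your sae_lens version provides):
#   _, model, sae = load_models()
#   print(sae.cfg.hook_name)  # expected to match "blocks.25.hook_resid_post"
#   print(sae.cfg.d_sae)      # dictionary width, i.e. how many features like 23610 exist
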
def generate_with_dms(model: HookedSAETransformer, prompt: str, sae: SAE, min_trigger=1.0) -> tuple[str, bool]:
    """
    Generate from the model, triggering a refusal if the prompt contains a query that might be risky to answer.
    """
    sampling_kwargs = {
        "do_sample": True,
        "top_k": 10,
        "top_p": 0.85,
        "temperature": 0.2,
    }

    should_refuse = should_trigger_refusal(model, prompt, sae, min_trigger=min_trigger)

    if should_refuse:
        # Steer the model toward refusal with an activation addition built from a contrast pair:
        # the difference between activations for a refusal prompt and a compliance prompt,
        # injected at layer 8 with coefficient 8.
        coeff = 8
        act_name = 8
        x_vectors = get_x_vector(
            prompt1="I'm sorry, but I cannot",
            prompt2="Sure, I can help with that",
            coeff=coeff,
            act_name=act_name,
            model=model,
            pad_method="tokens_right",
        )
        mod_df: pd.DataFrame = gen_using_activation_additions(
            prompt_batch=[prompt],
            model=model,
            activation_additions=x_vectors,
            addition_location="front",
            res_stream_slice=slice(None),
            **sampling_kwargs,
        )
        return mod_df.loc[0, "prompts"] + mod_df.loc[0, "completions"], should_refuse
    else:
        return model.generate(prompt, **(sampling_kwargs | {"max_new_tokens": 40})), should_refuse

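# Illustrative call (hypothetical prompt; min_trigger=1.0 is just the default threshold):
#   text, refused = generate_with_dms(model, "User: please lie to my boss for me\nAssistant:", sae)
#   # `refused` is True when a monitored feature's norm crossed the threshold, in which
#   # case `text` was generated under the refusal steering vector.
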
def should_trigger_refusal(
    model: HookedSAETransformer, prompt: str, sae: SAE, deception_features=(23610,), min_trigger=1.0
) -> bool:
    """
    Returns True if we detect the presence of a concerning feature in the prompt.

    Consider the simplest case of a single feature. There are a couple of ways we could detect it.
    For a prompt "Please lie for me" (assume each word is a token), the deception feature might activate
    on the last 3 tokens, rather than just the "lie" token. Hence, I check whether the norm along the
    specified feature(s) is large enough.
    """
    _, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
    cache_tensor = cache["blocks.25.hook_resid_post.hook_sae_acts_post"]
    norms = [
        # Ignore the BOS token; it doesn't behave well with the SAE.
        torch.linalg.vector_norm(cache_tensor[0, 1:, deception_feature], ord=2)
        for deception_feature in deception_features
    ]
    print(f"DEBUG: norms {norms}")
    if NO_REFUSAL:
        return False
    return any(norm >= min_trigger for norm in norms)

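# A minimal worked example of the norm check above on a toy tensor (assumed shapes and a
# made-up feature index; the real cache entry is [batch, seq, d_sae]):
#   acts = torch.zeros(1, 5, 1000)                     # 1 prompt, 5 tokens, 1000 features
#   acts[0, 2:, 123] = torch.tensor([0.6, 0.5, 0.4])   # feature 123 fires on the last 3 tokens
#   torch.linalg.vector_norm(acts[0, 1:, 123], ord=2)  # ~0.88, below min_trigger=1.0 -> no refusal
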
if __name__ == "__main__":
    hf_model, model, sae = load_models()
    print("Finished loading.")
    print("Note: each input is independent, not a continuous chat.")
    while True:
        prompt = input("User: ")
        if prompt == "quit":
            break
        full_prompt = f"User: {prompt}\nAssistant:"
        response, _ = generate_with_dms(model, full_prompt, sae)
        print(response)