from collections import defaultdict
from typing import Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from . import nethook


class LogitLens:
    """
    Applies the LM head at the output of each hidden layer, then analyzes the
    resultant token probability distribution.

    Only works when hooking outputs of *one* individual generation.

    Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens

    Warning: when running multiple times (e.g. during generation), will return
    outputs _only_ for the last processing step.
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        layer_module_tmp: str,
        ln_f_module: str,
        lm_head_module: str,
        disabled: bool = False,
    ):
        self.disabled = disabled
        self.model, self.tok = model, tok
        self.n_layers = self.model.config.n_layer

        # Final layer norm and LM head, reused to decode every layer's hidden state.
        self.lm_head, self.ln_f = (
            nethook.get_module(model, lm_head_module),
            nethook.get_module(model, ln_f_module),
        )

        self.output: Optional[Dict] = None
        self.td: Optional[nethook.TraceDict] = None
        self.trace_layers = [
            layer_module_tmp.format(layer) for layer in range(self.n_layers)
        ]

    def __enter__(self):
        if not self.disabled:
            self.td = nethook.TraceDict(
                self.model,
                self.trace_layers,
                retain_input=False,
                retain_output=True,
            )
            self.td.__enter__()

    def __exit__(self, *args):
        if self.disabled:
            return

        self.td.__exit__(*args)

        self.output = {layer: [] for layer in range(self.n_layers)}

        with torch.no_grad():
            for layer, (_, t) in enumerate(self.td.items()):
                cur_out = t.output[0]
                assert (
                    cur_out.size(0) == 1
                ), "Make sure you're running LogitLens on a single generation only."

                # Decode the last token position of this layer's hidden state
                # into a next-token probability distribution.
                self.output[layer] = torch.softmax(
                    self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1
                )

        return self.output

    def pprint(self, k=5):
        to_print = defaultdict(list)

        for layer, pred in self.output.items():
            rets = torch.topk(pred[0], k)
            for i in range(k):
                to_print[layer].append(
                    (
                        self.tok.decode(rets[1][i]),
                        round(rets[0][i].item() * 1e2) / 1e2,
                    )
                )

        print(
            "\n".join(
                [
                    f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}"
                    for layer in range(self.n_layers)
                ]
            )
        )
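

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): runs LogitLens
    # on a single forward pass. Assumes a GPT-2-style checkpoint, where hidden
    # layers live under "transformer.h.{}" and the final layer norm / LM head are
    # "transformer.ln_f" / "lm_head"; adjust these module paths for other
    # architectures. Because of the relative import above, run this as a module
    # (e.g. `python -m <package>.logit_lens`) rather than as a script.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tok = AutoTokenizer.from_pretrained("gpt2")
    model.eval()

    lens = LogitLens(
        model,
        tok,
        layer_module_tmp="transformer.h.{}",
        ln_f_module="transformer.ln_f",
        lm_head_module="lm_head",
    )

    inputs = tok("The Eiffel Tower is located in", return_tensors="pt")
    with lens:
        model(**inputs)  # single forward pass; per-layer outputs are captured on __exit__
    lens.pprint(k=5)  # top-5 decoded next-token predictions for each layer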