cdpearlman committed on
Commit
5528a77
·
1 Parent(s): ac0335e

Bug fix: type issue with models leading to gibberish outputs

Browse files
app.py CHANGED
@@ -16,7 +16,8 @@ import dash
16
  from dash import html, dcc, Input, Output, State, callback, no_update, ALL, MATCH
17
  import json
18
  import torch
19
- from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
 
20
  perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
21
  from utils.head_detection import get_active_head_summary
22
  from utils.model_config import get_auto_selections
@@ -374,10 +375,9 @@ def run_generation(n_clicks, model_name, prompt, max_new_tokens, beam_width, pat
374
  return no_update, no_update, no_update, no_update, no_update, no_update, no_update, no_update
375
 
376
  try:
377
- from transformers import AutoModelForCausalLM, AutoTokenizer
378
- model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
379
  tokenizer = AutoTokenizer.from_pretrained(model_name)
380
- model.eval()
381
 
382
  # Always run beam search (even with max_new_tokens=1)
383
  results = perform_beam_search(model, tokenizer, prompt, beam_width, max_new_tokens)
@@ -555,16 +555,15 @@ def store_selected_beam(n_clicks_list, results_data, existing_activation_data, o
555
  new_activation_data = no_update
556
  if existing_activation_data:
557
  try:
558
- from transformers import AutoModelForCausalLM, AutoTokenizer
559
  model_name = existing_activation_data['model']
560
  config = {
561
  'attention_modules': existing_activation_data['attention_modules'],
562
  'block_modules': existing_activation_data['block_modules'],
563
  'norm_parameters': existing_activation_data.get('norm_parameters', [])
564
  }
565
- model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
566
  tokenizer = AutoTokenizer.from_pretrained(model_name)
567
- model.eval()
568
  # Pass original_prompt so per-position top-5 data is computed for scrubber
569
  orig_prompt = original_prompt_data.get('prompt', '') if original_prompt_data else ''
570
  new_activation_data = execute_forward_pass(
@@ -613,10 +612,10 @@ def update_pipeline_content(activation_data, model_name):
613
  return tuple(empty_outputs)
614
 
615
  try:
616
- from transformers import AutoModelForCausalLM, AutoTokenizer
617
- model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
618
  tokenizer = AutoTokenizer.from_pretrained(model_name)
619
-
620
  # Use pre-decoded tokens if available, otherwise decode from input_ids
621
  input_ids = activation_data.get('input_ids', [[]])[0]
622
  tokens = activation_data.get('tokens') or [tokenizer.decode([tid]) for tid in input_ids]
@@ -922,11 +921,10 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
922
  return no_update, no_update, no_update
923
 
924
  try:
925
- from transformers import AutoModelForCausalLM, AutoTokenizer
926
- model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
927
  tokenizer = AutoTokenizer.from_pretrained(model_name)
928
- model.eval()
929
-
930
  sequence_text = prompt
931
 
932
  config = {
@@ -1087,7 +1085,9 @@ def update_attribution_target_options(activation_data):
1087
  options = []
1088
  for t in global_top5:
1089
  if isinstance(t, dict):
1090
- options.append({'label': f"{t['token']} ({t['probability']:.1%})", 'value': t['token']})
 
 
1091
  else:
1092
  options.append({'label': t[0], 'value': t[0]})
1093
  return options
@@ -1108,11 +1108,10 @@ def run_attribution_experiment(n_clicks, method, target_token, activation_data,
1108
  return no_update
1109
 
1110
  try:
1111
- from transformers import AutoModelForCausalLM, AutoTokenizer
1112
- model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
1113
  tokenizer = AutoTokenizer.from_pretrained(model_name)
1114
- model.eval()
1115
-
1116
  sequence_text = activation_data.get('prompt', prompt)
1117
 
1118
  # Get target token ID if specified
 
16
  from dash import html, dcc, Input, Output, State, callback, no_update, ALL, MATCH
17
  import json
18
  import torch
19
+ from utils import (load_model_for_inference, load_model_and_get_patterns,
20
+ execute_forward_pass, extract_layer_data,
21
  perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
22
  from utils.head_detection import get_active_head_summary
23
  from utils.model_config import get_auto_selections
 
375
  return no_update, no_update, no_update, no_update, no_update, no_update, no_update, no_update
376
 
377
  try:
378
+ from transformers import AutoTokenizer
379
+ model = load_model_for_inference(model_name)
380
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
381
 
382
  # Always run beam search (even with max_new_tokens=1)
383
  results = perform_beam_search(model, tokenizer, prompt, beam_width, max_new_tokens)
 
555
  new_activation_data = no_update
556
  if existing_activation_data:
557
  try:
558
+ from transformers import AutoTokenizer
559
  model_name = existing_activation_data['model']
560
  config = {
561
  'attention_modules': existing_activation_data['attention_modules'],
562
  'block_modules': existing_activation_data['block_modules'],
563
  'norm_parameters': existing_activation_data.get('norm_parameters', [])
564
  }
565
+ model = load_model_for_inference(model_name)
566
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
567
  # Pass original_prompt so per-position top-5 data is computed for scrubber
568
  orig_prompt = original_prompt_data.get('prompt', '') if original_prompt_data else ''
569
  new_activation_data = execute_forward_pass(
 
612
  return tuple(empty_outputs)
613
 
614
  try:
615
+ from transformers import AutoTokenizer
616
+ model = load_model_for_inference(model_name)
617
  tokenizer = AutoTokenizer.from_pretrained(model_name)
618
+
619
  # Use pre-decoded tokens if available, otherwise decode from input_ids
620
  input_ids = activation_data.get('input_ids', [[]])[0]
621
  tokens = activation_data.get('tokens') or [tokenizer.decode([tid]) for tid in input_ids]
 
921
  return no_update, no_update, no_update
922
 
923
  try:
924
+ from transformers import AutoTokenizer
925
+ model = load_model_for_inference(model_name)
926
  tokenizer = AutoTokenizer.from_pretrained(model_name)
927
+
 
928
  sequence_text = prompt
929
 
930
  config = {
 
1085
  options = []
1086
  for t in global_top5:
1087
  if isinstance(t, dict):
1088
+ prob = t.get('probability')
1089
+ prob_str = f" ({prob:.1%})" if prob is not None else ""
1090
+ options.append({'label': f"{t['token']}{prob_str}", 'value': t['token']})
1091
  else:
1092
  options.append({'label': t[0], 'value': t[0]})
1093
  return options
 
1108
  return no_update
1109
 
1110
  try:
1111
+ from transformers import AutoTokenizer
1112
+ model = load_model_for_inference(model_name)
1113
  tokenizer = AutoTokenizer.from_pretrained(model_name)
1114
+
 
1115
  sequence_text = activation_data.get('prompt', prompt)
1116
 
1117
  # Get target token ID if specified
components/pipeline.py CHANGED
@@ -1127,7 +1127,7 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
1127
  fig = go.Figure(go.Bar(
1128
  x=probs, y=tokens, orientation='h',
1129
  marker_color=['#00f2fe' if i == 0 else '#4facfe' for i in range(len(tokens))],
1130
- text=[f"{p:.1%}" for p in probs], textposition='outside',
1131
  hovertemplate='%{y} (%{x:.1%})<extra></extra>'
1132
  ))
1133
  fig.update_layout(
 
1127
  fig = go.Figure(go.Bar(
1128
  x=probs, y=tokens, orientation='h',
1129
  marker_color=['#00f2fe' if i == 0 else '#4facfe' for i in range(len(tokens))],
1130
+ text=[f"{p:.1%}" if p is not None else "?" for p in probs], textposition='outside',
1131
  hovertemplate='%{y} (%{x:.1%})<extra></extra>'
1132
  ))
1133
  fig.update_layout(
debug_logs.md ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -------------------------------------------------------
2
+ Qwen2ForCausalLM LOAD REPORT from: Qwen/Qwen2.5-0.5B
3
+ Key | Status |
4
+ ---------------+---------+-
5
+ lm_head.weight | MISSING |
6
+
7
+ Notes:
8
+ - MISSING :those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
9
+ -------------------------------------------------------
10
+ Traceback (most recent call last):
11
+ File "/app/app.py", line 383, in run_generation
12
+ results = perform_beam_search(model, tokenizer, prompt, beam_width, max_new_tokens)
13
+ File "/app/utils/beam_search.py", line 142, in perform_beam_search
14
+ outputs = model(seq)
15
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
16
+ return self._call_impl(*args, **kwargs)
17
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
18
+ return forward_call(*args, **kwargs)
19
+ File "/usr/local/lib/python3.10/site-packages/transformers/utils/generic.py", line 843, in wrapper
20
+ output = func(self, *args, **kwargs)
21
+ File "/usr/local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 476, in forward
22
+ outputs: BaseModelOutputWithPast = self.model(
23
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
24
+ return self._call_impl(*args, **kwargs)
25
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
26
+ return forward_call(*args, **kwargs)
27
+ File "/usr/local/lib/python3.10/site-packages/transformers/utils/generic.py", line 917, in wrapper
28
+ output = func(self, *args, **kwargs)
29
+ File "/usr/local/lib/python3.10/site-packages/transformers/utils/output_capturing.py", line 253, in wrapper
30
+ outputs = func(self, *args, **kwargs)
31
+ File "/usr/local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 411, in forward
32
+ hidden_states = decoder_layer(
33
+ File "/usr/local/lib/python3.10/site-packages/transformers/modeling_layers.py", line 93, in __call__
34
+ return super().__call__(*args, **kwargs)
35
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
36
+ return self._call_impl(*args, **kwargs)
37
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
38
+ return forward_call(*args, **kwargs)
39
+ File "/usr/local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 298, in forward
40
+ hidden_states, _ = self.self_attn(
41
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
42
+ return self._call_impl(*args, **kwargs)
43
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
44
+ return forward_call(*args, **kwargs)
45
+ File "/usr/local/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 218, in forward
46
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
47
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
48
+ return self._call_impl(*args, **kwargs)
49
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
50
+ return forward_call(*args, **kwargs)
51
+ File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 134, in forward
52
+ return F.linear(input, self.weight, self.bias)
53
+ RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
54
+ -------------------------------------------------------
55
+ Qwen2ForCausalLM LOAD REPORT from: Qwen/Qwen2.5-0.5B
56
+ Key | Status |
57
+ ---------------+---------+-
58
+ lm_head.weight | MISSING |
59
+
60
+ Notes:
61
+ - MISSING :those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
62
+ 10.16.43.195 - - [19/Mar/2026 19:47:47] "POST /_dash-update-component HTTP/1.1" 200 -
63
+ 10.16.43.195 - - [19/Mar/2026 19:47:55] "POST /_dash-update-component HTTP/1.1" 200 -
64
+ 10.16.43.195 - - [19/Mar/2026 19:47:55] "POST /_dash-update-component HTTP/1.1" 200 -
65
+ Executing forward pass with prompt: 'Draw ascii art for a cat'
66
+ Captured 48 module outputs using PyVene
67
+ Loading weights: 0%| | 0/290 [00:00<?, ?it/s]10.16.31.44 - - [19/Mar/2026 19:47:55] "POST /_dash-update-component HTTP/1.1" 200 -
68
+ [2026-03-19 19:47:55,972] ERROR in app: Exception on /_dash-update-component [POST]
69
+ Traceback (most recent call last):
70
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 1511, in wsgi_app
71
+ response = self.full_dispatch_request()
72
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 919, in full_dispatch_request
73
+ rv = self.handle_user_exception(e)
74
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 917, in full_dispatch_request
75
+ rv = self.dispatch_request()
76
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 902, in dispatch_request
77
+ return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return]
78
+ File "/usr/local/lib/python3.10/site-packages/dash/_get_app.py", line 17, in wrap
79
+ return ctx.run(func, self, *args, **kwargs)
80
+ File "/usr/local/lib/python3.10/site-packages/dash/dash.py", line 1600, in dispatch
81
+ response_data = ctx.run(partial_func)
82
+ File "/usr/local/lib/python3.10/site-packages/dash/_callback.py", line 720, in add_context
83
+ raise err
84
+ File "/usr/local/lib/python3.10/site-packages/dash/_callback.py", line 711, in add_context
85
+ output_value = _invoke_callback(func, *func_args, **func_kwargs) # type: ignore[reportArgumentType]
86
+ File "/usr/local/lib/python3.10/site-packages/dash/_callback.py", line 58, in _invoke_callback
87
+ return func(*args, **kwargs) # %% callback invoked %%
88
+ File "/app/app.py", line 1090, in update_attribution_target_options
89
+ options.append({'label': f"{t['token']} ({t['probability']:.1%})", 'value': t['token']})
90
+ TypeError: unsupported format string passed to NoneType.__format__
91
+ -------------------------------------------------------
92
+ DEBUG extract_layer_data: Found 24 attention modules
93
+ Loading model: gpt2-medium
94
+ Loading weights: 0%| | 0/292 [00:00<?, ?it/s]
95
+ Loading weights: 51%|█████ | 149/292 [00:00<00:00, 1299.04it/s]
96
+ Loading weights: 100%|██████████| 292/292 [00:00<00:00, 1450.02it/s]
97
+ GPT2LMHeadModel LOAD REPORT from: gpt2-medium
98
+ Key | Status |
99
+ ---------------+---------+-
100
+ lm_head.weight | MISSING |
101
+
102
+ Notes:
103
+ - MISSING :those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
104
+ -------------------------------------------------------
105
+ 10.16.43.195 - - [19/Mar/2026 20:11:03] "POST /_dash-update-component HTTP/1.1" 200 -
106
+ [2026-03-19 20:11:03,238] ERROR in app: Exception on /_dash-update-component [POST]
107
+ Traceback (most recent call last):
108
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 1511, in wsgi_app
109
+ response = self.full_dispatch_request()
110
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 919, in full_dispatch_request
111
+ rv = self.handle_user_exception(e)
112
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 917, in full_dispatch_request
113
+ rv = self.dispatch_request()
114
+ File "/usr/local/lib/python3.10/site-packages/flask/app.py", line 902, in dispatch_request
115
+ return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return]
116
+ File "/usr/local/lib/python3.10/site-packages/dash/_get_app.py", line 17, in wrap
117
+ return ctx.run(func, self, *args, **kwargs)
118
+ File "/usr/local/lib/python3.10/site-packages/dash/dash.py", line 1600, in dispatch
119
+ response_data = ctx.run(partial_func)
120
+ File "/usr/local/lib/python3.10/site-packages/dash/_callback.py", line 720, in add_context
121
+ raise err
122
+ File "/usr/local/lib/python3.10/site-packages/dash/_callback.py", line 711, in add_context
123
+ output_value = _invoke_callback(func, *func_args, **func_kwargs) # type: ignore[reportArgumentType]
124
+ File "/usr/local/lib/python3.10/site-packages/dash/_callback.py", line 58, in _invoke_callback
125
+ return func(*args, **kwargs) # %% callback invoked %%
126
+ File "/app/app.py", line 1090, in update_attribution_target_options
127
+ options.append({'label': f"{t['token']} ({t['probability']:.1%})", 'value': t['token']})
128
+ TypeError: unsupported format string passed to NoneType.__format__
129
+ 10.16.43.195 - - [19/Mar/2026 20:11:03] "POST /_dash-update-component HTTP/1.1" 500 -
130
+ 10.16.43.195 - - [19/Mar/2026 20:11:03] "POST /_dash-update-component HTTP/1.1" 200 -
131
+ Traceback (most recent call last):
132
+ File "/app/utils/model_patterns.py", line 1337, in generate_bertviz_html
133
+ attention_weights = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
134
+ RuntimeError: Could not infer dtype of NoneType
135
+ Traceback (most recent call last):
136
+ File "/app/app.py", line 691, in update_pipeline_content
137
+ outputs.append(create_output_content(
138
+ File "/app/components/pipeline.py", line 1130, in create_output_content
139
+ text=[f"{p:.1%}" for p in probs], textposition='outside',
140
+ File "/app/components/pipeline.py", line 1130, in <listcomp>
141
+ text=[f"{p:.1%}" for p in probs], textposition='outside',
142
+ TypeError: unsupported format string passed to NoneType.__format__
143
+ -------------------------------------------------------
144
+ 10.16.43.195 - - [19/Mar/2026 20:26:26] "POST /_dash-update-component HTTP/1.1" 200 -
145
+ DEBUG extract_layer_data: Found 24 attention modules
146
+ Warning: Could not compute logit lens for gpt_neox.layers.0: mixed dtype (CPU): expect parameter to have scalar type of Float
147
+ Warning: Could not compute token probabilities for gpt_neox.layers.0: mixed dtype (CPU): expect parameter to have scalar type of Float
148
+ Warning: Could not compute logit lens for gpt_neox.layers.1: mixed dtype (CPU): expect parameter to have scalar type of Float
149
+ Warning: Could not compute token probabilities for gpt_neox.layers.1: mixed dtype (CPU): expect parameter to have scalar type of Float
150
+ Warning: Could not compute logit lens for gpt_neox.layers.2: mixed dtype (CPU): expect parameter to have scalar type of Float
151
+ Warning: Could not compute token probabilities for gpt_neox.layers.2: mixed dtype (CPU): expect parameter to have scalar type of Float
152
+ -------------------------------------------------------
153
+ Warning: Could not compute logit lens for gpt_neox.layers.13: mixed dtype (CPU): expect parameter to have scalar type of Float
154
+ Warning: Could not compute token probabilities for gpt_neox.layers.13: mixed dtype (CPU): expect parameter to have scalar type of Float
155
+ Warning: Could not compute logit lens for gpt_neox.layers.14: Could not infer dtype of NoneType
156
+ Warning: Could not compute token probabilities for gpt_neox.layers.14: Could not infer dtype of NoneType
157
+ Warning: Could not compute logit lens for gpt_neox.layers.15: Could not infer dtype of NoneType
158
+ Warning: Could not compute token probabilities for gpt_neox.layers.15: Could not infer dtype of NoneType
159
+ Warning: Could not compute logit lens for gpt_neox.layers.16: Could not infer dtype of NoneType
160
+ Warning: Could not compute token probabilities for gpt_neox.layers.16: Could not infer dtype of NoneType
161
+ Warning: Could not compute logit lens for gpt_neox.layers.17: Could not infer dtype of NoneType
162
+ Warning: Could not compute token probabilities for gpt_neox.layers.17: Could not infer dtype of NoneType
163
+ -------------------------------------------------------
tests/test_model_patterns.py CHANGED
@@ -10,8 +10,8 @@ Tests pure logic functions that don't require model loading:
10
  import pytest
11
  import torch
12
  import numpy as np
13
- from utils.model_patterns import merge_token_probabilities, safe_to_serializable
14
- from utils import execute_forward_pass_with_multi_layer_head_ablation
15
 
16
 
17
  class TestMergeTokenProbabilities:
@@ -478,3 +478,43 @@ class TestFullSequenceAttentionData:
478
  attn = data['attention_outputs'][module]['output'][1]
479
  assert len(attn[0][0]) == 8
480
  assert len(attn[0][0][0]) == 8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import pytest
11
  import torch
12
  import numpy as np
13
+ from utils.model_patterns import merge_token_probabilities, safe_to_serializable, _prepare_hidden_state
14
+ from utils import execute_forward_pass_with_multi_layer_head_ablation, load_model_for_inference
15
 
16
 
17
  class TestMergeTokenProbabilities:
 
478
  attn = data['attention_outputs'][module]['output'][1]
479
  assert len(attn[0][0]) == 8
480
  assert len(attn[0][0][0]) == 8
481
+
482
+
483
+ class TestPrepareHiddenState:
484
+ """Tests for _prepare_hidden_state helper."""
485
+
486
+ def test_raises_on_none(self):
487
+ """_prepare_hidden_state(None) should raise ValueError."""
488
+ with pytest.raises(ValueError, match="Layer output is None"):
489
+ _prepare_hidden_state(None)
490
+
491
+ def test_unwraps_tuple_with_none_second(self):
492
+ """Tuple where second element is None should unwrap first element."""
493
+ result = _prepare_hidden_state(([[1.0, 2.0]], None))
494
+ assert isinstance(result, torch.Tensor)
495
+ assert result.shape[-1] == 2
496
+
497
+ def test_converts_list(self):
498
+ """Plain list should be converted to torch.Tensor."""
499
+ result = _prepare_hidden_state([[[1.0, 2.0]]])
500
+ assert isinstance(result, torch.Tensor)
501
+
502
+
503
+ class TestSafeToSerializableTupleWithNone:
504
+ """Test that safe_to_serializable handles tuples containing None."""
505
+
506
+ def test_tuple_with_tensor_and_none(self):
507
+ """Tuple of (tensor, None) should become [list, None]."""
508
+ tensor = torch.tensor([1.0, 2.0])
509
+ result = safe_to_serializable((tensor, None))
510
+ assert isinstance(result, list)
511
+ assert result[0] == [1.0, 2.0]
512
+ assert result[1] is None
513
+
514
+
515
+ class TestLoadModelForInference:
516
+ """Tests for load_model_for_inference helper."""
517
+
518
+ def test_function_is_importable(self):
519
+ """load_model_for_inference should be importable from utils."""
520
+ assert callable(load_model_for_inference)
utils/__init__.py CHANGED
@@ -1,11 +1,12 @@
1
- from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
2
- logit_lens_transformation, extract_layer_data,
 
3
  generate_bertviz_html,
4
  execute_forward_pass_with_head_ablation,
5
  execute_forward_pass_with_multi_layer_head_ablation,
6
- merge_token_probabilities,
7
  compute_global_top5_tokens, compute_per_position_top5,
8
- detect_significant_probability_increases,
9
  evaluate_sequence_ablation, generate_bertviz_model_view_html)
10
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
11
  from .head_detection import load_head_categories, verify_head_activation, get_active_head_summary
@@ -16,7 +17,8 @@ from .token_attribution import compute_integrated_gradients, compute_simple_grad
16
 
17
  __all__ = [
18
  # Model patterns
19
- 'load_model_and_get_patterns',
 
20
  'execute_forward_pass',
21
  'execute_forward_pass_with_head_ablation',
22
  'execute_forward_pass_with_multi_layer_head_ablation',
 
1
+ from .model_patterns import (load_model_for_inference, load_model_and_get_patterns,
2
+ execute_forward_pass,
3
+ logit_lens_transformation, extract_layer_data,
4
  generate_bertviz_html,
5
  execute_forward_pass_with_head_ablation,
6
  execute_forward_pass_with_multi_layer_head_ablation,
7
+ merge_token_probabilities,
8
  compute_global_top5_tokens, compute_per_position_top5,
9
+ detect_significant_probability_increases,
10
  evaluate_sequence_ablation, generate_bertviz_model_view_html)
11
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
12
  from .head_detection import load_head_categories, verify_head_activation, get_active_head_summary
 
17
 
18
  __all__ = [
19
  # Model patterns
20
+ 'load_model_for_inference',
21
+ 'load_model_and_get_patterns',
22
  'execute_forward_pass',
23
  'execute_forward_pass_with_head_ablation',
24
  'execute_forward_pass_with_multi_layer_head_ablation',
utils/model_patterns.py CHANGED
@@ -7,6 +7,26 @@ from typing import Dict, List, Tuple, Any, Optional
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def extract_patterns(model, use_modules=True) -> Dict[str, List[str]]:
11
  """Extract patterns from model modules or parameters."""
12
  items = model.named_modules() if use_modules else model.named_parameters()
@@ -36,9 +56,8 @@ def load_model_and_get_patterns(model_name: str) -> Tuple[Dict[str, List[str]],
36
  print(f"Loading model: {model_name}")
37
 
38
  # Load model and tokenizer
39
- model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
40
  tokenizer = AutoTokenizer.from_pretrained(model_name)
41
- model.eval()
42
 
43
  # Extract patterns
44
  module_patterns = extract_patterns(model, use_modules=True)
@@ -919,10 +938,13 @@ def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dic
919
 
920
  def _prepare_hidden_state(layer_output: Any) -> torch.Tensor:
921
  """Helper to convert layer output to tensor, handling tuple outputs."""
 
 
 
922
  # Handle PyVene captured tuple outputs where 2nd element is None (e.g. use_cache=False)
923
  if isinstance(layer_output, (list, tuple)) and len(layer_output) > 1 and layer_output[1] is None:
924
  layer_output = layer_output[0]
925
-
926
  hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
927
  if hidden.dim() == 4:
928
  hidden = hidden.squeeze(0)
@@ -954,7 +976,9 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, to
954
  with torch.no_grad():
955
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
956
  hidden = _prepare_hidden_state(layer_output)
957
-
 
 
958
  # Step 1: Apply final layer normalization (critical for intermediate layers)
959
  final_norm = get_norm_layer_from_parameter(model, norm_parameter)
960
  if final_norm is not None:
@@ -1271,7 +1295,10 @@ def generate_bertviz_model_view_html(activation_data: Dict[str, Any]) -> str:
1271
  attention_output = attention_outputs[module_name]['output']
1272
  if isinstance(attention_output, list) and len(attention_output) >= 2:
1273
  # Get attention weights (element 1 of the output tuple)
1274
- attention_weights = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
 
 
 
1275
  layer_attention_pairs.append((layer_num, attention_weights))
1276
 
1277
  if not layer_attention_pairs:
@@ -1334,7 +1361,10 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
1334
  attention_output = attention_outputs[module_name]['output']
1335
  if isinstance(attention_output, list) and len(attention_output) >= 2:
1336
  # Get attention weights (element 1 of the output tuple)
1337
- attention_weights = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
 
 
 
1338
  layer_attention_pairs.append((layer_num, attention_weights))
1339
 
1340
  if not layer_attention_pairs:
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
 
10
+ def load_model_for_inference(model_name: str):
11
+ """Load model with float32 dtype for CPU stability and verify weight tying."""
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ model_name,
14
+ attn_implementation='eager',
15
+ torch_dtype=torch.float32
16
+ )
17
+ model.eval()
18
+
19
+ # Verify lm_head is properly tied to embeddings (not randomly initialized)
20
+ embed = model.get_input_embeddings()
21
+ lm_head = model.get_output_embeddings()
22
+ if embed is not None and lm_head is not None:
23
+ if embed.weight.data_ptr() != lm_head.weight.data_ptr():
24
+ print(f"Warning: {model_name} lm_head not tied to embeddings, re-tying...")
25
+ model.tie_weights()
26
+
27
+ return model
28
+
29
+
30
  def extract_patterns(model, use_modules=True) -> Dict[str, List[str]]:
31
  """Extract patterns from model modules or parameters."""
32
  items = model.named_modules() if use_modules else model.named_parameters()
 
56
  print(f"Loading model: {model_name}")
57
 
58
  # Load model and tokenizer
59
+ model = load_model_for_inference(model_name)
60
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
61
 
62
  # Extract patterns
63
  module_patterns = extract_patterns(model, use_modules=True)
 
938
 
939
  def _prepare_hidden_state(layer_output: Any) -> torch.Tensor:
940
  """Helper to convert layer output to tensor, handling tuple outputs."""
941
+ if layer_output is None:
942
+ raise ValueError("Layer output is None")
943
+
944
  # Handle PyVene captured tuple outputs where 2nd element is None (e.g. use_cache=False)
945
  if isinstance(layer_output, (list, tuple)) and len(layer_output) > 1 and layer_output[1] is None:
946
  layer_output = layer_output[0]
947
+
948
  hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
949
  if hidden.dim() == 4:
950
  hidden = hidden.squeeze(0)
 
976
  with torch.no_grad():
977
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
978
  hidden = _prepare_hidden_state(layer_output)
979
+ # Serialized intermediates may be float64; cast to model dtype
980
+ hidden = hidden.to(dtype=next(model.parameters()).dtype)
981
+
982
  # Step 1: Apply final layer normalization (critical for intermediate layers)
983
  final_norm = get_norm_layer_from_parameter(model, norm_parameter)
984
  if final_norm is not None:
 
1295
  attention_output = attention_outputs[module_name]['output']
1296
  if isinstance(attention_output, list) and len(attention_output) >= 2:
1297
  # Get attention weights (element 1 of the output tuple)
1298
+ raw_weights = attention_output[1]
1299
+ if raw_weights is None:
1300
+ continue # Skip layers with missing attention data
1301
+ attention_weights = torch.tensor(raw_weights) # [batch, heads, seq, seq]
1302
  layer_attention_pairs.append((layer_num, attention_weights))
1303
 
1304
  if not layer_attention_pairs:
 
1361
  attention_output = attention_outputs[module_name]['output']
1362
  if isinstance(attention_output, list) and len(attention_output) >= 2:
1363
  # Get attention weights (element 1 of the output tuple)
1364
+ raw_weights = attention_output[1]
1365
+ if raw_weights is None:
1366
+ continue # Skip layers with missing attention data
1367
+ attention_weights = torch.tensor(raw_weights) # [batch, heads, seq, seq]
1368
  layer_attention_pairs.append((layer_num, attention_weights))
1369
 
1370
  if not layer_attention_pairs: