File size: 4,672 Bytes
c8fa89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import traceback

import torch

from .llm_iface import get_or_load_model
from .utils import dbg

def _check(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` when *condition* is false.

    Used instead of the ``assert`` statement so the diagnostics keep running
    when Python is started with ``-O`` (which strips ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def _kv_cache_seq_len(past_key_values) -> int:
    """Return the sequence length held by a KV cache.

    Prefers the ``get_seq_length()`` method of ``Cache`` objects when present;
    otherwise falls back to the legacy tuple format, reading the key tensor of
    layer 0 (``[0][0]``) along axis ``-2``.
    """
    get_seq_length = getattr(past_key_values, "get_seq_length", None)
    if callable(get_seq_length):
        return int(get_seq_length())
    return past_key_values[0][0].shape[-2]


def _test_attention_output(llm, inputs) -> None:
    """Test 1: a forward pass with ``output_attentions=True`` yields one attention entry per layer."""
    dbg("Running Test 1: Attention Output Verification...")
    # This test ensures that 'eager' attention implementation is active, which is
    # necessary for reliable hook functionality in many transformers versions.
    outputs = llm.model(**inputs, output_attentions=True)
    _check(outputs.attentions is not None, "FAIL: `outputs.attentions` is None. 'eager' implementation is likely not active.")
    _check(isinstance(outputs.attentions, tuple), "FAIL: `outputs.attentions` is not a tuple.")
    _check(len(outputs.attentions) == llm.config.num_hidden_layers, "FAIL: Number of attention tuples does not match number of layers.")
    dbg("Test 1 PASSED.")


def _test_hook_causal_efficacy(llm, inputs) -> None:
    """Test 2: a forward pre-hook on the middle decoder layer changes that layer's output.

    This is the most critical test: it verifies that the injection mechanism
    (via hooks) has a real, causal effect on the model's computation.
    """
    dbg("Running Test 2: Hook Causal Efficacy Verification...")

    # Run 1: baseline hidden state without any intervention.
    outputs_no_hook = llm.model(**inputs, output_hidden_states=True)
    target_layer_idx = llm.config.num_hidden_layers // 2
    # In Hugging Face models hidden_states[0] is the embedding output, so
    # index i + 1 holds the output of decoder layer i.
    state_no_hook = outputs_no_hook.hidden_states[target_layer_idx + 1].clone()

    # A simple hook that adds a large constant to the layer's first positional input.
    injection_value = 42.0

    def test_hook_fn(module, layer_input):
        modified_input = layer_input[0] + injection_value
        return (modified_input,) + layer_input[1:]

    target_layer = llm.model.model.layers[target_layer_idx]
    handle = target_layer.register_forward_pre_hook(test_hook_fn)
    try:
        # Run 2: hidden state with the hook active.
        outputs_with_hook = llm.model(**inputs, output_hidden_states=True)
        state_with_hook = outputs_with_hook.hidden_states[target_layer_idx + 1].clone()
    finally:
        # Always detach the hook, even if the forward pass raised: the model
        # object is presumably reused across calls (get_or_load_model), and a
        # leaked hook would silently corrupt every later run.
        handle.remove()

    # The core check: the hook MUST change the subsequent hidden state.
    _check(
        not torch.allclose(state_no_hook, state_with_hook),
        "FAIL: Hook had no measurable effect on the subsequent layer's hidden state. Injections are not working.",
    )
    dbg("Test 2 PASSED.")


def _test_kv_cache_integrity(llm, inputs) -> None:
    """Test 3: feeding one extra token with ``past_key_values`` grows the cache by exactly one position.

    This is the core mechanic of the silent cogitation loop.
    """
    dbg("Running Test 3: KV-Cache Integrity Verification...")

    # Step 1: initial pass with `use_cache=True`.
    outputs1 = llm.model(**inputs, use_cache=True)
    kv_cache1 = outputs1.past_key_values
    _check(kv_cache1 is not None, "FAIL: KV-Cache was not generated in the first pass.")

    # Step 2: second pass reusing the cache from step 1.
    next_token = torch.tensor([[123]], device=llm.model.device)  # Arbitrary next token ID
    outputs2 = llm.model(input_ids=next_token, past_key_values=kv_cache1, use_cache=True)
    kv_cache2 = outputs2.past_key_values

    original_seq_len = inputs.input_ids.shape[-1]
    cached_len = _kv_cache_seq_len(kv_cache2)
    # The sequence length of the cached keys/values should have grown by 1.
    _check(
        cached_len == original_seq_len + 1,
        f"FAIL: KV-Cache sequence length did not update correctly. Expected {original_seq_len + 1}, got {cached_len}.",
    )
    dbg("Test 3 PASSED.")


def run_diagnostic_suite(model_id: str, seed: int) -> str:
    """
    Run a series of self-tests that verify the mechanical integrity of the experiment.

    The tests check that attention outputs are returned, that forward pre-hooks
    causally change a layer's output, and that the KV cache grows correctly.

    Args:
        model_id: Identifier passed to ``get_or_load_model``.
        seed: Seed passed to ``get_or_load_model``.

    Returns:
        A newline-separated summary with one "✅ Test N: ... PASSED" line per test.

    Raises:
        AssertionError: if any diagnostic check fails.
        Exception: any error raised while loading the model or running a
            forward pass is logged via ``dbg`` and re-raised unchanged so the
            caller (e.g. the Gradio UI) can stop execution.
    """
    dbg("--- STARTING DIAGNOSTIC SUITE ---")
    results = []
    llm = None

    try:
        # --- Setup ---
        dbg("Loading model for diagnostics...")
        llm = get_or_load_model(model_id, seed)
        test_prompt = "Hello world"
        inputs = llm.tokenizer(test_prompt, return_tensors="pt").to(llm.model.device)

        # All passes are read-only inspections; building autograd graphs would
        # only waste (GPU) memory.
        with torch.no_grad():
            _test_attention_output(llm, inputs)
            results.append("✅ Test 1: Attention Output PASSED")

            _test_hook_causal_efficacy(llm, inputs)
            results.append("✅ Test 2: Hook Causal Efficacy PASSED")

            _test_kv_cache_integrity(llm, inputs)
            results.append("✅ Test 3: KV-Cache Integrity PASSED")

        return "\n".join(results)

    except Exception:
        dbg(f"--- DIAGNOSTIC SUITE FAILED --- \n{traceback.format_exc()}")
        # Re-raise the exception (traceback intact) to be caught by the Gradio UI.
        raise

    finally:
        # Clean up memory on success AND failure: drop our reference and
        # release cached CUDA allocator blocks.
        llm = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()