"""
Test: verify that output_attentions correctly returns the attention output.

When output_attentions=True, Gemma4TextDecoderLayer returns
(hidden_states, attn_output). attn_output is the output of self_attn
(the hidden states before post_attention_layernorm is applied).

The capture_outputs hook captures output[1] (attn_weights) of
Gemma4TextAttention, but with the sdpa implementation attn_weights is None,
so the captured result is empty. This script therefore verifies that
attn_output can be obtained correctly at the DecoderLayer level.
"""
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Checkpoint location handed to both from_pretrained calls below.
MODEL_PATH = "/workspace/llm/gemma-4-31B-Text"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Tokenize a one-word prompt and place the encoded batch on model.device so
# it can be fed straight into the forward calls used by the tests.
inputs = tokenizer("hello", return_tensors="pt")
inputs = inputs.to(model.device)

# Dimensions the later checks compare tensor shapes against.
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
input_shape = inputs["input_ids"].shape
batch_size = input_shape[0]
seq_len = input_shape[1]

print(f"Model: num_layers={num_layers}, hidden_size={hidden_size}")
print(f"Input: batch={batch_size}, seq_len={seq_len}")
|
|
| |
| |
| |
# Test 1: run the inner text model (model.model) with output_attentions=True
# and report what ends up in its `.attentions` field.
print("\n=== Test 1: Gemma4TextModel.forward(output_attentions=True) ===")
with torch.no_grad():
    text_outputs = model.model(**inputs, output_attentions=True, use_cache=False)

attentions = text_outputs.attentions
print(f"attentions is None: {attentions is None}")

if attentions is not None:
    print(f"Number of attention entries: {len(attentions)}")
    if len(attentions) == 0:
        print(" (empty tuple - capture_outputs hook did not collect anything)")
    # For an empty collection the loop body never runs, so only the message
    # above is printed in that case.
    for layer_idx, layer_attn in enumerate(attentions):
        if layer_attn is None:
            print(f" Layer {layer_idx}: None")
            continue
        print(f" Layer {layer_idx}: shape={layer_attn.shape}, dtype={layer_attn.dtype}")
        if layer_idx != 0:
            continue
        # Only the first entry is shape-checked, against the
        # (batch, seq_len, hidden_size) layout.
        expected_shape = (batch_size, seq_len, hidden_size)
        if layer_attn.shape == expected_shape:
            print(f" PASS: shape matches expected {expected_shape}")
        else:
            print(f" FAIL: expected {expected_shape}, got {layer_attn.shape}")
|
|
| |
| |
| |
# Test 2: call the first decoder layer directly with output_attentions=True
# and inspect the second element of what it returns.
print("\n=== Test 2: DecoderLayer direct call with output_attentions=True ===")
with torch.no_grad():
    first_layer = model.model.layers[0]

    # The layer's inputs are assembled by hand: token embeddings, position
    # ids covering the sequence, and position embeddings from
    # model.model.rotary_emb for the first entry of config.layer_types.
    input_ids = inputs["input_ids"].to(model.device)
    inputs_embeds = model.model.embed_tokens(input_ids)
    position_ids = torch.arange(seq_len, device=model.device).unsqueeze(0)
    layer_type = model.config.layer_types[0]
    position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids, layer_type)

    layer_outputs = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=True,
    )

output_count = len(layer_outputs)
print(f"DecoderLayer returned {output_count} outputs")
if output_count < 2:
    print(f" FAIL: expected 2 outputs, got {output_count}")
else:
    hidden_out, attn_out = layer_outputs[0], layer_outputs[1]
    print(f" hidden_states: shape={hidden_out.shape}, dtype={hidden_out.dtype}")
    print(f" attn_output: shape={attn_out.shape}, dtype={attn_out.dtype}")

    # Check 1: attn_output has the (batch, seq_len, hidden_size) shape.
    expected_shape = (batch_size, seq_len, hidden_size)
    if attn_out.shape != expected_shape:
        print(f" FAIL: expected {expected_shape}, got {attn_out.shape}")
    else:
        print(f" PASS: attn_output shape is correct {expected_shape}")

    # Check 2: the sum of absolute values is positive, i.e. not all zeros.
    has_signal = attn_out.abs().sum() > 0
    if has_signal:
        print(f" PASS: attn_output is non-zero (norm={attn_out.float().norm().item():.4f})")
    else:
        print(" FAIL: attn_output is all zeros")

    # Check 3: attn_output must not simply be the layer's hidden_states.
    if torch.equal(hidden_out, attn_out):
        print(" FAIL: attn_output is identical to hidden_states")
    else:
        print(" PASS: attn_output differs from hidden_states (as expected)")
|
|
| |
| |
| |
# Test 3: repeat the direct first-layer call with output_attentions=False and
# confirm that exactly one output comes back.
print("\n=== Test 3: DecoderLayer with output_attentions=False ===")
with torch.no_grad():
    layer_outputs_no_attn = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=False,
    )

no_attn_count = len(layer_outputs_no_attn)
print(f"DecoderLayer returned {no_attn_count} outputs")
if no_attn_count != 1:
    print(f" FAIL: expected 1 output, got {no_attn_count}")
else:
    print(" PASS: only hidden_states returned (no attn_output)")
|
|
| |
| |
| |
# Test 4: run the full causal-LM wrapper with output_attentions=True and
# compare the number of collected attention entries with num_layers.
print("\n=== Test 4: Gemma4ForCausalLM output_attentions propagation ===")
with torch.no_grad():
    causal_outputs = model(**inputs, output_attentions=True, use_cache=False)

attentions_causal = causal_outputs.attentions
print(f"CausalLM attentions is None: {attentions_causal is None}")
if attentions_causal is not None:
    collected = len(attentions_causal)
    print(f"CausalLM attentions length: {collected}")
    # Branch order matters: the num_layers match is tested before the
    # empty case.
    if collected == num_layers:
        print(f" PASS: got {num_layers} layers of attention output")
    elif collected == 0:
        for message in (
            " FAIL: empty tuple (capture_outputs hook could not collect attn_weights from sdpa)",
            " NOTE: This is a known issue - sdpa does not return attention weights.",
            " Use attn_implementation='eager' to get attention weights via this path.",
        ):
            print(message)
    else:
        print(f" Got {collected} (expected {num_layers})")
|
|
| |
| |
| |
# Closing banner: a fixed list of conclusions, printed one line at a time.
# The text is static and does not depend on the results printed above.
separator = "=" * 60
summary_lines = (
    "\n" + separator,
    "SUMMARY",
    separator,
    "- DecoderLayer correctly returns attn_output when output_attentions=True",
    "- DecoderLayer correctly omits attn_output when output_attentions=False",
    "- capture_outputs hook on CausalLM/TextModel collects Gemma4TextAttention output[1]",
    " which is attn_weights (None with sdpa), so CausalLM.attentions is empty.",
    "- To get attention outputs at model level, either:",
    " (a) use attn_implementation='eager', or",
    " (b) access DecoderLayer outputs directly.",
)
for summary_line in summary_lines:
    print(summary_line)
|
|