broadfield-dev committed on
Commit
315b928
·
verified ·
1 Parent(s): 8aa7b0f

Update debug_ee.py

Browse files
Files changed (1) hide show
  1. debug_ee.py +63 -58
debug_ee.py CHANGED
@@ -1,7 +1,5 @@
1
  """
2
-
3
- EE Sanity Check Script
4
- Run this locally (not on HF Spaces) to verify the transform is correct.
5
  Usage:
6
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
7
  """
@@ -12,9 +10,7 @@ import argparse
12
 
13
  def get_sigma(hidden_size, seed):
14
  rng = np.random.default_rng(seed)
15
- sigma = rng.permutation(hidden_size)
16
- sigma_inv = np.argsort(sigma)
17
- return sigma, sigma_inv
18
 
19
  def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
20
  print(f"\n{'='*60}")
@@ -28,66 +24,64 @@ def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
28
  inputs = tokenizer(prompt, return_tensors="pt")
29
  input_ids = inputs.input_ids
30
 
31
- print("\n[1] Loading original model...")
32
- orig = AutoModelForCausalLM.from_pretrained(
33
- original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
34
- )
35
- orig.eval()
36
-
37
- print("[2] Loading EE model...")
38
- ee = AutoModelForCausalLM.from_pretrained(
39
- ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
40
- )
41
- ee.eval()
42
 
43
  hidden_size = orig.config.hidden_size
44
- sigma, sigma_inv = get_sigma(hidden_size, seed)
 
45
 
46
- # --- Check 1: Does the EE embed layer match original? ---
47
- orig_embed = orig.model.embed_tokens.weight.data
48
- ee_embed = ee.model.embed_tokens.weight.data
49
- embed_match = torch.allclose(orig_embed, ee_embed, atol=1e-3)
 
50
  print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
51
  if not embed_match:
52
- diff = (orig_embed - ee_embed).abs().max().item()
53
- print(f" ⚠️ Max diff: {diff:.6f} — EE embed was permuted, this BREAKS client-side encryption")
54
- print(f" → Re-run transform with the embed layer skipped (see transform_fix.py)")
55
 
56
- # --- Check 2: Run plain forward on original ---
57
  print("\n[CHECK 2] Running plain forward on original...")
58
  with torch.no_grad():
59
- plain_embeds = orig.model.embed_tokens(input_ids)
60
- orig_out = orig(inputs_embeds=plain_embeds, output_hidden_states=False)
61
- orig_logits = orig_out.logits # (1, seq, vocab)
62
 
63
- # --- Check 3: Run encrypted forward on EE model ---
64
  print("[CHECK 3] Running encrypted forward on EE model...")
65
  with torch.no_grad():
66
- encrypted_embeds = plain_embeds[..., sigma]
67
- ee_out = ee(inputs_embeds=encrypted_embeds, output_hidden_states=False)
68
- ee_logits = ee_out.logits
69
 
70
- # --- Check 4: Do logits match? ---
71
- logit_match = torch.allclose(orig_logits, ee_logits, atol=1e-1)
72
  max_diff = (orig_logits - ee_logits).abs().max().item()
73
- print(f"\n[CHECK 4] Logits match (atol=0.1): {logit_match}")
 
74
  print(f" Max logit diff: {max_diff:.4f}")
75
- if not logit_match:
76
- print(" ⚠️ Logits differ — equivariance is BROKEN")
77
- # Find where it breaks — check RoPE suspicion
78
- print("\n Diagnosing: checking if RoPE is the culprit...")
79
- print(" RoPE applies rotation in head_dim space (64), not hidden space (1024)")
80
- print(" If q_proj/k_proj output is permuted (because output==hidden_size),")
81
- print(" the head_dim slices fed to RoPE will be scrambled broken attention")
82
-
83
- # --- Check 5: Greedy decode comparison ---
84
  print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
85
  with torch.no_grad():
86
- orig_ids = orig.generate(input_ids, max_new_tokens=10, do_sample=False)
87
- ee_ids = ee.generate(inputs_embeds=encrypted_embeds,
88
- attention_mask=inputs.attention_mask,
89
- max_new_tokens=10, do_sample=False,
90
- pad_token_id=tokenizer.eos_token_id)
 
 
 
 
 
 
 
 
 
91
 
92
  orig_text = tokenizer.decode(orig_ids[0], skip_special_tokens=True)
93
  ee_text = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
@@ -95,16 +89,27 @@ def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
95
  print(f" EE model output : {repr(ee_text)}")
96
  print(f" Match: {orig_text == ee_text}")
97
 
98
- if orig_text != ee_text:
99
- print("\n ⚠️ OUTPUTS DIFFER. Most likely causes in order:")
100
- print(" 1. Embed layer was permuted in EE model (Check 1 above)")
101
- print(" 2. RoPE disruption q_proj/k_proj output rows were permuted")
102
- print(" FIX: do NOT permute output rows of q_proj and k_proj")
103
- print(" because their outputs are split into heads for RoPE rotation")
104
- print(" 3. Model on Hub is stale — re-run transform and re-push")
 
 
 
105
 
106
  print(f"\n{'='*60}\n")
107
- return embed_match and logit_match
 
 
 
 
 
 
 
 
108
 
109
  if __name__ == "__main__":
110
 
 
1
  """
2
+ EE Sanity Check
 
 
3
  Usage:
4
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
5
  """
 
10
 
11
  def get_sigma(hidden_size, seed):
12
  rng = np.random.default_rng(seed)
13
+ return rng.permutation(hidden_size)
 
 
14
 
15
  def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
16
  print(f"\n{'='*60}")
 
24
  inputs = tokenizer(prompt, return_tensors="pt")
25
  input_ids = inputs.input_ids
26
 
27
+ print("\n[1] Loading models...")
28
+ orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
29
+ ee = AutoModelForCausalLM.from_pretrained(ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
30
+ orig.eval(); ee.eval()
 
 
 
 
 
 
 
31
 
32
  hidden_size = orig.config.hidden_size
33
+ sigma = get_sigma(hidden_size, seed)
34
+ sigma_t = torch.tensor(sigma, dtype=torch.long)
35
 
36
+ # --- CHECK 1: Embed layers must be identical ---
37
+ embed_match = torch.allclose(
38
+ orig.model.embed_tokens.weight.data,
39
+ ee.model.embed_tokens.weight.data, atol=1e-3
40
+ )
41
  print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
42
  if not embed_match:
43
+ print(" ⚠️ Embed was permuted — client-side encryption will be double-permuted")
 
 
44
 
45
+ # --- CHECK 2 & 3: Forward pass with encrypted embeds ---
46
  print("\n[CHECK 2] Running plain forward on original...")
47
  with torch.no_grad():
48
+ plain_embeds = orig.model.embed_tokens(input_ids) # use ORIGINAL embed
49
+ orig_logits = orig(inputs_embeds=plain_embeds).logits
 
50
 
 
51
  print("[CHECK 3] Running encrypted forward on EE model...")
52
  with torch.no_grad():
53
+ # Client encrypts: take plain embeds, apply sigma
54
+ encrypted_embeds = plain_embeds[..., sigma_t]
55
+ ee_logits = ee(inputs_embeds=encrypted_embeds).logits
56
 
57
+ # --- CHECK 4: Logits ---
 
58
  max_diff = (orig_logits - ee_logits).abs().max().item()
59
+ match = max_diff < 0.5
60
+ print(f"\n[CHECK 4] Logits match (atol=0.1): {match}")
61
  print(f" Max logit diff: {max_diff:.4f}")
62
+ if not match:
63
+ print(" ⚠️ Equivariance BROKEN")
64
+
65
+ # --- CHECK 5: Greedy decode ---
66
+ # Both models must use inputs_embeds (not input_ids).
67
+ # Original uses plain embeds, EE uses sigma-encrypted embeds.
68
+ # Their outputs should be identical token sequences.
 
 
69
  print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
70
  with torch.no_grad():
71
+ orig_ids = orig.generate(
72
+ inputs_embeds=plain_embeds,
73
+ attention_mask=inputs.attention_mask,
74
+ max_new_tokens=10,
75
+ do_sample=False,
76
+ pad_token_id=tokenizer.eos_token_id
77
+ )
78
+ ee_ids = ee.generate(
79
+ inputs_embeds=encrypted_embeds,
80
+ attention_mask=inputs.attention_mask,
81
+ max_new_tokens=10,
82
+ do_sample=False,
83
+ pad_token_id=tokenizer.eos_token_id
84
+ )
85
 
86
  orig_text = tokenizer.decode(orig_ids[0], skip_special_tokens=True)
87
  ee_text = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
 
89
  print(f" EE model output : {repr(ee_text)}")
90
  print(f" Match: {orig_text == ee_text}")
91
 
92
+ if orig_text == ee_text:
93
+ print("\n All checks passed EE transform is correct")
94
+ else:
95
+ print("\n⚠️ Text differs despite logits matching.")
96
+ print(" This usually means floating point drift in autoregressive generation.")
97
+ print(" Check if token IDs match even if decoded text differs slightly:")
98
+ print(f" orig_ids: {orig_ids[0].tolist()}")
99
+ print(f" ee_ids: {ee_ids[0].tolist()}")
100
+ ids_match = orig_ids[0].tolist() == ee_ids[0].tolist()
101
+ print(f" Token IDs match: {ids_match}")
102
 
103
  print(f"\n{'='*60}\n")
104
+
105
+ '''if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument("--original", required=True)
108
+ parser.add_argument("--ee", required=True)
109
+ parser.add_argument("--seed", type=int, required=True)
110
+ parser.add_argument("--prompt", default="Hello, how are you?")
111
+ args = parser.parse_args()
112
+ run_check(args.original, args.ee, args.seed, args.prompt)'''
113
 
114
  if __name__ == "__main__":
115