WCNegentropy committed
Commit 42dd387 · verified · 1 Parent(s): f64dfb1

🚀 Refined BitTransformerLM: Organized codebase with best practices

Files changed (1)
  1. scripts/examples/simple_test.py +117 -0
scripts/examples/simple_test.py ADDED
@@ -0,0 +1,117 @@
+ #!/usr/bin/env python3
+ """
+ Simple BitTransformerLM Test - No Interactive Input
+ """
+
+ import sys
+ import torch
+ import torch.nn.functional as F
+
+ # Add paths for imports
+ sys.path.append('/data')
+ sys.path.append('/data/BitTransformerLM')
+
+ from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
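+ # Note: text_to_bits/bits_to_text round-trip text through a bit sequence.
+ # Judging by the "18 bits = 2 characters" budget used below, the encoding
+ # appears to spend 9 bits per character (presumably 8 data bits plus parity).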
+
+ def test_breakthrough_model():
+     """Simple test of the breakthrough model."""
+     print("🚀 Loading breakthrough BitTransformerLM...")
+
+     # Create model with exact config
+     model = BitTransformerLM(
+         d_model=512,
+         nhead=16,
+         num_layers=8,
+         dim_feedforward=1024,
+         max_seq_len=512,
+         reversible=True,
+         use_checkpoint=False,  # Disable for inference
+         use_autocast=False,    # Disable for inference
+         use_act=True,
+         act_threshold=0.9,
+         lambda_K=0.05,
+         lambda_C=0.05,
+         lambda_S=0.05
+     )
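+     # NOTE: the architecture hyperparameters above must match the config the
+     # checkpoint was trained with, or load_state_dict() below will raise on
+     # missing keys or mismatched shapes.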
+
+     # Load checkpoint
+     checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+
+     print(f"✅ Model loaded! Loss: {checkpoint['loss']:.6f}")
+
+     # Simple test prompts
+     prompts = [
+         "Hello",
+         "Hi there",
+         "What is your name?",
+         "The weather is"
+     ]
+
+     for prompt in prompts:
+         print(f"\n🤖 Testing: '{prompt}'")
+
+         # Convert to bits
+         input_bits = text_to_bits(prompt)
+         input_tensor = torch.tensor(input_bits, dtype=torch.long).unsqueeze(0)
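+         # unsqueeze(0) adds a batch dimension, giving shape (1, num_bits).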
+
+         print(f"📏 Input: {len(input_bits)} bits")
+
+         with torch.no_grad():
+             try:
+                 # Forward pass
+                 logits, telemetry = model(input_tensor)
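+                 # The forward pass yields per-position logits over the two bit
+                 # values and a telemetry dict of auxiliary metrics (read below).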
+
+                 # Get next bit probabilities
+                 next_probs = F.softmax(logits[0, -1, :], dim=-1)
+
+                 print(f"🎯 Next bit probs: [0]={next_probs[0]:.3f}, [1]={next_probs[1]:.3f}")
+
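+                 # K/C/S mirror the lambda_K/lambda_C/lambda_S weights above:
+                 # K = negentropy, C = LZ complexity, S = symbiosis score.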
+                 if telemetry:
+                     k_val = telemetry.get('negentropy_logits', 0)
+                     c_val = telemetry.get('lz_complexity_logits', 0)
+                     s_val = telemetry.get('symbiosis_score', 0)
+
+                     # Convert to scalar if tensor
+                     if torch.is_tensor(k_val):
+                         k_val = k_val.mean().item()
+                     if torch.is_tensor(c_val):
+                         c_val = c_val.mean().item()
+                     if torch.is_tensor(s_val):
+                         s_val = s_val.mean().item()
+
+                     print(f"📊 Telemetry: K={k_val:.3f}, C={c_val:.3f}, S={s_val:.3f}")
+
+                 # Try simple generation (just 18 bits = 2 characters)
+                 generated_bits = input_bits.copy()
+
+                 for _ in range(18):  # 2 characters' worth
+                     current_tensor = torch.tensor(generated_bits, dtype=torch.long).unsqueeze(0)
+                     if current_tensor.size(1) > 500:  # Truncate if too long
+                         current_tensor = current_tensor[:, -500:]
+
+                     logits, _ = model(current_tensor)
+                     next_bit_logits = logits[0, -1, :]
+
+                     # Sample with temperature
+                     next_bit_logits = next_bit_logits / 0.8
+                     probs = F.softmax(next_bit_logits, dim=-1)
+                     next_bit = torch.multinomial(probs, 1).item()
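+                     # Dividing logits by 0.8 (temperature < 1) sharpens the
+                     # distribution before multinomial() samples a single bit.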
+
+                     generated_bits.append(next_bit)
+
+                 # Try to decode
+                 generated_only = generated_bits[len(input_bits):]
+                 try:
+                     generated_text = bits_to_text(generated_only)
+                     print(f"✨ Generated: '{generated_text}'")
+                 except Exception as e:
+                     print(f"🔧 Decode failed: {e}")
+                     print(f"Raw bits: {generated_only}")
+
+             except Exception as e:
+                 print(f"❌ Model error: {e}")
+
+ if __name__ == "__main__":
+     test_breakthrough_model()
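
Usage note: the script is deliberately non-interactive; assuming the hard-coded
/data paths and the checkpoints/checkpoint_best.pt file exist, it can be run
directly:

    python scripts/examples/simple_test.py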