WCNegentropy committed
Commit f64dfb1 · verified · 1 Parent(s): 3928bd8

🚀 Refined BitTransformerLM: Organized codebase with best practices

Files changed (1):
  1. scripts/examples/raw_generation.py +121 -0
scripts/examples/raw_generation.py ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
"""
Raw BitTransformerLM Generation - Bypass Parity
"""

import sys
import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits


def load_model():
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
    )

    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model, checkpoint['loss']


def bits_to_ascii_raw(bits):
    """Convert bits to ASCII without parity checking."""
    if len(bits) % 8 != 0:
        # Pad to a multiple of 8
        bits = bits + [0] * (8 - len(bits) % 8)

    chars = []
    for i in range(0, len(bits), 8):
        byte_bits = bits[i:i + 8]
        byte_value = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(byte_bits))

        # Only accept printable ASCII
        if 32 <= byte_value <= 126:
            chars.append(chr(byte_value))
        elif byte_value == 10:  # newline
            chars.append('\n')
        elif byte_value == 13:  # carriage return
            chars.append('\r')
        else:
            chars.append('�')  # replacement for non-printable

    return ''.join(chars)


def generate_raw(model, prompt, num_bits=72):  # 9 bytes' worth
    """Generate bits and decode them as raw ASCII."""
    print(f"\n🎯 Generating {num_bits} bits from: '{prompt}'")

    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")

    generated_bits = input_bits.copy()

    with torch.no_grad():
        for i in range(num_bits):
            # Context window
            context_bits = generated_bits[-400:] if len(generated_bits) > 400 else generated_bits
            context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)

            logits, telemetry = model(context_tensor)
            next_bit_logits = logits[0, -1, :]

            # Lower temperature for more coherent output
            temperature = 0.6
            next_bit_logits = next_bit_logits / temperature
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()

            generated_bits.append(next_bit)

            # Progress update
            if (i + 1) % 16 == 0:  # every 2 bytes
                generated_only = generated_bits[len(input_bits):]
                partial_text = bits_to_ascii_raw(generated_only)
                print(f" {i + 1:2d} bits: '{partial_text}'")

    # Final decode
    generated_only = generated_bits[len(input_bits):]
    final_text = bits_to_ascii_raw(generated_only)

    print(f"✨ Final: '{prompt}' + '{final_text}'")

    if telemetry:
        k = telemetry.get('negentropy_logits', 0)
        c = telemetry.get('lz_complexity_logits', 0)
        s = telemetry.get('symbiosis_score', 0)
        if torch.is_tensor(k): k = k.mean().item()
        if torch.is_tensor(c): c = c.mean().item()
        if torch.is_tensor(s): s = s.mean().item()
        print(f"📊 Telemetry: K={k:.3f}, C={c:.3f}, S={s:.3f}")

    return final_text


def main():
    print("🚀 RAW BITRANSFORMERLM GENERATION")
    print("=" * 40)

    model, loss = load_model()
    print(f"✅ Model loaded! Loss: {loss:.6f}")

    prompts = [
        "Hello",
        "Hi there",
        "What",
        "The weather",
        "AI:",
        "Q: What is your name?\nA:"
    ]

    for prompt in prompts:
        generate_raw(model, prompt, num_bits=64)  # 8 characters' worth


if __name__ == "__main__":
    main()
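
A minimal usage sketch, assuming the checkpoint exists at /data/BitTransformerLM/checkpoints/checkpoint_best.pt and the bit_transformer package is importable from /data/BitTransformerLM:

    python scripts/examples/raw_generation.py

For reference, the MSB-first byte decoding that bits_to_ascii_raw performs can be checked in isolation (a standalone snippet, not part of the committed file):

    # 'H' is ASCII 72 = 0b01001000, most significant bit first
    bits = [0, 1, 0, 0, 1, 0, 0, 0]
    byte_value = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(bits))
    assert byte_value == 72 and chr(byte_value) == 'H'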