WCNegentropy committed on
Commit
7fe700d
·
verified ·
1 Parent(s): 2866d3d

🚀 Refined BitTransformerLM: Organized codebase with best practices

Files changed (1)
  1. scripts/examples/better_sampling.py +138 -0
scripts/examples/better_sampling.py ADDED
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+"""
+Better Sampling for BitTransformerLM
+"""
+
+import sys
+
+import torch
+import torch.nn.functional as F
+
+sys.path.append('/data')
+sys.path.append('/data/BitTransformerLM')
+
+from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
+
+
+def load_model():
+    """Rebuild the model and load the best checkpoint."""
+    model = BitTransformerLM(
+        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
+        max_seq_len=512, reversible=True, use_checkpoint=False,
+        use_autocast=False, use_act=True, act_threshold=0.9,
+        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
+    )
+
+    checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
+                            map_location='cpu')
+    model.load_state_dict(checkpoint['model_state_dict'])
+    model.eval()
+
+    return model
+
+
+def smart_generate(model, prompt, max_chars=5):
+    """Generate with better sampling strategies."""
+    print(f"\n🎯 Smart generating from: '{prompt}'")
+
+    input_bits = text_to_bits(prompt)
+    generated_bits = input_bits.copy()
+
+    with torch.no_grad():
+        for char_idx in range(max_chars):
+            # Generate 9 bits for one character (8 data + 1 parity)
+            char_bits = []
+
+            for bit_idx in range(9):
+                # Keep the context to a reasonable length
+                context = generated_bits + char_bits
+                context = context[-300:] if len(context) > 300 else context
+                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
+
+                logits, telemetry = model(context_tensor)  # telemetry unused here
+                next_bit_logits = logits[0, -1, :]
+
+                # Different strategies based on bit position
+                if bit_idx < 8:  # Data bits
+                    # Higher temperature for more variety; with a binary
+                    # vocabulary, top-k with k=2 keeps both options, so
+                    # temperature is the effective knob here.
+                    temperature = 0.8
+                    next_bit_logits = next_bit_logits / temperature
+
+                    k = 2
+                    top_k_logits, top_k_indices = torch.topk(next_bit_logits, k)
+                    probs = F.softmax(top_k_logits, dim=-1)
+                    selected_idx = torch.multinomial(probs, 1).item()
+                    next_bit = top_k_indices[selected_idx].item()
+                else:  # Parity bit
+                    # Force the correct parity instead of sampling it, so
+                    # every 9-bit frame stays decodable.
+                    data_bits = char_bits[:8]
+                    next_bit = sum(data_bits) % 2
+
+                char_bits.append(next_bit)
+
+            # Add the completed character
+            generated_bits.extend(char_bits)
+
+            # Try to decode the new character (data bits only, parity dropped)
+            try:
+                data_bits = char_bits[:8]
+                byte_val = sum(bit * (2 ** (7 - i)) for i, bit in enumerate(data_bits))
+
+                if 32 <= byte_val <= 126:  # Printable ASCII
+                    char = chr(byte_val)
+                    print(f"  Char {char_idx + 1}: '{char}' (byte={byte_val})")
+
+                    # Early stopping for sentence enders
+                    if char in '.!?\n':
+                        break
+                else:
+                    print(f"  Char {char_idx + 1}: Non-printable (byte={byte_val})")
+
+            except Exception as e:
+                print(f"  Char {char_idx + 1}: Decode error: {e}")
+
+    # Final decode attempt
+    generated_only = generated_bits[len(input_bits):]
+    try:
+        final_text = bits_to_text(generated_only)
+        print(f"✨ Result: '{prompt}' + '{final_text}'")
+        return final_text
+    except Exception as e:
+        print(f"❌ Final decode failed: {e}")
+
+    # Manual decode of complete characters
+    manual_result = ""
+    for i in range(0, len(generated_only), 9):
+        if i + 8 < len(generated_only):
+            char_bits = generated_only[i:i + 8]  # Just the data bits
+            byte_val = sum(bit * (2 ** (7 - j)) for j, bit in enumerate(char_bits))
+            if 32 <= byte_val <= 126:
+                manual_result += chr(byte_val)
+            else:
+                manual_result += '?'
+
+    print(f"🔧 Manual decode: '{prompt}' + '{manual_result}'")
+    return manual_result
+
+
+def main():
+    print("🚀 SMART BITTRANSFORMERLM GENERATION")
+    print("=" * 40)
+
+    model = load_model()
+    print("✅ Model loaded!")
+
+    # Test different prompt styles
+    prompts = [
+        "Hello",
+        "Hi",
+        "A",
+        "The cat",
+        "I am",
+        "Yes",
+        "No",
+    ]
+
+    for prompt in prompts:
+        smart_generate(model, prompt, max_chars=4)
+
+
+if __name__ == "__main__":
+    main()
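
For context, the script assumes the bit_transformer codec frames each character as 9 bits: 8 data bits, most significant bit first, followed by one even-parity bit. That is why the parity bit is computed rather than sampled above. A minimal sketch of that framing, inferred from the manual decode logic in the script (the encode_char/decode_char helpers below are illustrative, not part of the bit_transformer API):

def encode_char(ch):
    """Pack one ASCII character as 8 data bits (MSB first) + 1 even-parity bit."""
    byte_val = ord(ch)
    data_bits = [(byte_val >> (7 - i)) & 1 for i in range(8)]
    return data_bits + [sum(data_bits) % 2]  # assumed even-parity framing

def decode_char(bits):
    """Unpack 9 bits back to a character, verifying the parity bit."""
    data_bits, parity = bits[:8], bits[8]
    assert sum(data_bits) % 2 == parity, "parity mismatch"
    return chr(sum(bit << (7 - i) for i, bit in enumerate(data_bits)))

assert decode_char(encode_char('A')) == 'A'  # 65 -> [0,1,0,0,0,0,0,1] + parity 0

Forcing the parity bit during generation guarantees every 9-bit frame passes this check, so a decode failure can only come from incomplete trailing frames.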