CharlesCNorton commited on
Commit
df99f2e
·
1 Parent(s): e84332b

Update play.py and eval.py to current bit-cascade comparator/mod-5 layout

Browse files

play.py: alu_compare and mod5 still called legacy gate names from before
the comparators and modular detectors were bit-cascaded. CMP demo crashed
on a 16-vs-8 input shape mismatch; mod-5 demo crashed on a missing
modular.mod5.layer3.or.weight key. Both now walk the current gate family
(arithmetic.cmp8bit.* per-bit gt/lt/eq + cascade.eq_prefix + cascade.gt/lt
+ final OR/AND; modular.mod5.eq.k{val}.bit{i}.match + .all + top OR).

eval.py: _test_integration's add+compare and sub+conditional sub-tests
called pop['arithmetic.{greater,less}than8bit.weight'].view(pop_size, 16)
against the current 8-element OR-tree weight tensor, raising RuntimeError
that was caught and silently SKIPped. Added _pop_cmp8bit helper that
drives the bit-cascade comparator across the population and rewrote both
sub-tests to use it. Wired _test_integration into evaluate(); previously
defined but never called. Test count: 6,364 -> 6,379 on the ALU variants
(+15 integration cases), all passing at fit=1.0000.

Files changed (2) hide show
  1. eval.py +67 -10
  2. play.py +32 -14
eval.py CHANGED
@@ -4253,6 +4253,60 @@ class BatchedFitnessEvaluator:
4253
  # INTEGRATION TESTS (Multi-circuit chains)
4254
  # =========================================================================
4255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4256
  def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
4257
  """Test complex operations that chain multiple circuit families."""
4258
  pop_size = next(iter(pop.values())).shape[0]
@@ -4279,11 +4333,8 @@ class BatchedFitnessEvaluator:
4279
  c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)],
4280
  device=self.device, dtype=torch.float32)
4281
 
4282
- # Use comparator
4283
- w = pop['arithmetic.greaterthan8bit.weight'].view(pop_size, 16)
4284
- bias = pop['arithmetic.greaterthan8bit.bias'].view(pop_size)
4285
- inp = torch.cat([sum_bits, c_bits])
4286
- out = heaviside((inp * w).sum(-1) + bias)
4287
  correct = (out == expected).float()
4288
  op_scores += correct
4289
  op_total += 1
@@ -4407,11 +4458,8 @@ class BatchedFitnessEvaluator:
4407
  b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
4408
  device=self.device, dtype=torch.float32)
4409
 
4410
- # Check LT comparator
4411
- w = pop['arithmetic.lessthan8bit.weight'].view(pop_size, 16)
4412
- bias = pop['arithmetic.lessthan8bit.bias'].view(pop_size)
4413
- inp = torch.cat([a_bits, b_bits])
4414
- lt_out = heaviside((inp * w).sum(-1) + bias)
4415
 
4416
  correct = (lt_out[0].item() == float(is_negative))
4417
  op_scores += float(correct)
@@ -4734,6 +4782,15 @@ class BatchedFitnessEvaluator:
4734
  total_tests += t
4735
  self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
4736
 
 
 
 
 
 
 
 
 
 
4737
  self.total_tests = total_tests
4738
 
4739
  if debug:
 
4253
  # INTEGRATION TESTS (Multi-circuit chains)
4254
  # =========================================================================
4255
 
4256
+ def _pop_cmp8bit(self, pop: Dict, pop_size: int,
4257
+ a_bits: torch.Tensor, b_bits: torch.Tensor,
4258
+ kind: str) -> torch.Tensor:
4259
+ """Drive the bit-cascade comparator (cmp8bit) over a population.
4260
+
4261
+ Returns a (pop_size,) tensor of heaviside outputs for the requested
4262
+ comparison kind ('gt' | 'lt' | 'eq'). Bit 0 is MSB.
4263
+ """
4264
+ def apply(name: str, inp: torch.Tensor, fan_in: int) -> torch.Tensor:
4265
+ w = pop[f'{name}.weight'].view(pop_size, fan_in)
4266
+ b = pop[f'{name}.bias'].view(pop_size)
4267
+ return heaviside((inp * w).sum(-1) + b)
4268
+
4269
+ # Per-bit primitives.
4270
+ bit_gt, bit_lt, bit_eq = [], [], []
4271
+ for i in range(8):
4272
+ ab = torch.stack([a_bits[i], b_bits[i]])
4273
+ bit_gt.append(apply(f'arithmetic.cmp8bit.bit{i}.gt', ab, 2))
4274
+ bit_lt.append(apply(f'arithmetic.cmp8bit.bit{i}.lt', ab, 2))
4275
+ eq_and = apply(f'arithmetic.cmp8bit.bit{i}.eq.layer1.and', ab, 2)
4276
+ eq_nor = apply(f'arithmetic.cmp8bit.bit{i}.eq.layer1.nor', ab, 2)
4277
+ eq_in = torch.stack([eq_and, eq_nor], dim=-1)
4278
+ w = pop[f'arithmetic.cmp8bit.bit{i}.eq.weight'].view(pop_size, 2)
4279
+ b = pop[f'arithmetic.cmp8bit.bit{i}.eq.bias'].view(pop_size)
4280
+ bit_eq.append(heaviside((eq_in * w).sum(-1) + b))
4281
+
4282
+ # Cascade.
4283
+ cas_gt = [bit_gt[0]]
4284
+ cas_lt = [bit_lt[0]]
4285
+ for i in range(1, 8):
4286
+ eq_pref_in = torch.stack(bit_eq[:i], dim=-1)
4287
+ w_pref = pop[f'arithmetic.cmp8bit.cascade.eq_prefix.bit{i}.weight'].view(pop_size, i)
4288
+ b_pref = pop[f'arithmetic.cmp8bit.cascade.eq_prefix.bit{i}.bias'].view(pop_size)
4289
+ eq_pref = heaviside((eq_pref_in * w_pref).sum(-1) + b_pref)
4290
+ cas_in = torch.stack([eq_pref, bit_gt[i]], dim=-1)
4291
+ w_g = pop[f'arithmetic.cmp8bit.cascade.gt.bit{i}.weight'].view(pop_size, 2)
4292
+ b_g = pop[f'arithmetic.cmp8bit.cascade.gt.bit{i}.bias'].view(pop_size)
4293
+ cas_gt.append(heaviside((cas_in * w_g).sum(-1) + b_g))
4294
+ cas_in_lt = torch.stack([eq_pref, bit_lt[i]], dim=-1)
4295
+ w_l = pop[f'arithmetic.cmp8bit.cascade.lt.bit{i}.weight'].view(pop_size, 2)
4296
+ b_l = pop[f'arithmetic.cmp8bit.cascade.lt.bit{i}.bias'].view(pop_size)
4297
+ cas_lt.append(heaviside((cas_in_lt * w_l).sum(-1) + b_l))
4298
+
4299
+ if kind == 'gt':
4300
+ inp = torch.stack(cas_gt, dim=-1)
4301
+ return apply('arithmetic.greaterthan8bit', inp, 8)
4302
+ if kind == 'lt':
4303
+ inp = torch.stack(cas_lt, dim=-1)
4304
+ return apply('arithmetic.lessthan8bit', inp, 8)
4305
+ if kind == 'eq':
4306
+ inp = torch.stack(bit_eq, dim=-1)
4307
+ return apply('arithmetic.equality8bit', inp, 8)
4308
+ raise ValueError(kind)
4309
+
4310
  def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
4311
  """Test complex operations that chain multiple circuit families."""
4312
  pop_size = next(iter(pop.values())).shape[0]
 
4333
  c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)],
4334
  device=self.device, dtype=torch.float32)
4335
 
4336
+ # Drive sum_bits vs c_bits through the bit-cascade comparator.
4337
+ out = self._pop_cmp8bit(pop, pop_size, sum_bits, c_bits, 'gt')
 
 
 
4338
  correct = (out == expected).float()
4339
  op_scores += correct
4340
  op_total += 1
 
4458
  b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
4459
  device=self.device, dtype=torch.float32)
4460
 
4461
+ # Drive a_bits vs b_bits through the bit-cascade LT comparator.
4462
+ lt_out = self._pop_cmp8bit(pop, pop_size, a_bits, b_bits, 'lt')
 
 
 
4463
 
4464
  correct = (lt_out[0].item() == float(is_negative))
4465
  op_scores += float(correct)
 
4782
  total_tests += t
4783
  self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
4784
 
4785
+ # Cross-family integration tests (chain ripple-carry, comparator,
4786
+ # modular, shifts, subtractor). Each test is internally guarded with
4787
+ # try/except, so unsupported variants silently skip individual tests.
4788
+ if 'arithmetic.cmp8bit.bit0.gt.weight' in population:
4789
+ s, t = self._test_integration(population, debug)
4790
+ scores += s
4791
+ total_tests += t
4792
+ self.category_scores['integration'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
4793
+
4794
  self.total_tests = total_tests
4795
 
4796
  if debug:
play.py CHANGED
@@ -145,12 +145,30 @@ def main() -> int:
145
  return bits_msb_to_int(list(reversed(diff_lsb))), carry
146
 
147
  def alu_compare(a, b, kind):
148
- inp = int_to_bits_msb(a, 8) + int_to_bits_msb(b, 8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  if kind == "eq":
150
- h_geq = gate("arithmetic.equality8bit.layer1.geq", inp)
151
- h_leq = gate("arithmetic.equality8bit.layer1.leq", inp)
152
- return gate("arithmetic.equality8bit.layer2", [h_geq, h_leq])
153
- return gate(f"arithmetic.{kind}8bit", inp)
154
 
155
  def alu_mul(a, b):
156
  a_bits = int_to_bits_msb(a, 8)
@@ -197,16 +215,16 @@ def main() -> int:
197
  print("=" * 64)
198
 
199
  def mod5(v):
 
 
 
200
  bits = int_to_bits_msb(v, 8)
201
- n = 0
202
- while f"modular.mod5.layer1.geq{n}.weight" in T:
203
- n += 1
204
- eqs = []
205
- for i in range(n):
206
- h_geq = gate(f"modular.mod5.layer1.geq{i}", bits)
207
- h_leq = gate(f"modular.mod5.layer1.leq{i}", bits)
208
- eqs.append(gate(f"modular.mod5.layer2.eq{i}", [h_geq, h_leq]))
209
- return gate("modular.mod5.layer3.or", eqs)
210
 
211
  hits = [v for v in range(256) if mod5(v)]
212
  print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
 
145
  return bits_msb_to_int(list(reversed(diff_lsb))), carry
146
 
147
  def alu_compare(a, b, kind):
148
+ # Walks the bit-cascade comparator family: per-bit gt/lt/eq, cascaded
149
+ # eq_prefix, cascade.gt/lt, and the final OR/AND gates. Bit 0 is MSB.
150
+ a_msb = int_to_bits_msb(a, 8)
151
+ b_msb = int_to_bits_msb(b, 8)
152
+ bit_gt = [gate(f"arithmetic.cmp8bit.bit{i}.gt", [a_msb[i], b_msb[i]]) for i in range(8)]
153
+ bit_lt = [gate(f"arithmetic.cmp8bit.bit{i}.lt", [a_msb[i], b_msb[i]]) for i in range(8)]
154
+ bit_eq = []
155
+ for i in range(8):
156
+ eq_and = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.and", [a_msb[i], b_msb[i]])
157
+ eq_nor = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.nor", [a_msb[i], b_msb[i]])
158
+ bit_eq.append(gate(f"arithmetic.cmp8bit.bit{i}.eq", [eq_and, eq_nor]))
159
+ cas_gt = [bit_gt[0]]
160
+ cas_lt = [bit_lt[0]]
161
+ for i in range(1, 8):
162
+ eq_pref = gate(f"arithmetic.cmp8bit.cascade.eq_prefix.bit{i}", bit_eq[:i])
163
+ cas_gt.append(gate(f"arithmetic.cmp8bit.cascade.gt.bit{i}", [eq_pref, bit_gt[i]]))
164
+ cas_lt.append(gate(f"arithmetic.cmp8bit.cascade.lt.bit{i}", [eq_pref, bit_lt[i]]))
165
+ if kind == "greaterthan":
166
+ return gate("arithmetic.greaterthan8bit", cas_gt)
167
+ if kind == "lessthan":
168
+ return gate("arithmetic.lessthan8bit", cas_lt)
169
  if kind == "eq":
170
+ return gate("arithmetic.equality8bit", bit_eq)
171
+ raise ValueError(kind)
 
 
172
 
173
  def alu_mul(a, b):
174
  a_bits = int_to_bits_msb(a, 8)
 
215
  print("=" * 64)
216
 
217
  def mod5(v):
218
+ # Per-multiple-of-5 match (k0, k5, ..., k255): each k has 8 single-input
219
+ # "bit{i}.match" gates that fire when bit i of v matches bit i of k,
220
+ # ANDed by ".all". Final ".weight" ORs all 52 "all" outputs.
221
  bits = int_to_bits_msb(v, 8)
222
+ ks = [k for k in range(256) if k % 5 == 0]
223
+ alls = []
224
+ for k in ks:
225
+ matches = [gate(f"modular.mod5.eq.k{k}.bit{i}.match", [bits[i]]) for i in range(8)]
226
+ alls.append(gate(f"modular.mod5.eq.k{k}.all", matches))
227
+ return gate("modular.mod5", alls)
 
 
 
228
 
229
  hits = [v for v in range(256) if mod5(v)]
230
  print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")