Update play.py and eval.py to current bit-cascade comparator/mod-5 layout
play.py: alu_compare and mod5 still called legacy gate names from before
the comparators and modular detectors were bit-cascaded. CMP demo crashed
on a 16-vs-8 input shape mismatch; mod-5 demo crashed on a missing
modular.mod5.layer3.or.weight key. Both now walk the current gate family
(arithmetic.cmp8bit.* per-bit gt/lt/eq + cascade.eq_prefix + cascade.gt/lt
+ final OR/AND; modular.mod5.eq.k{val}.bit{i}.match + .all + top OR).
eval.py: _test_integration's add+compare and sub+conditional sub-tests
called pop['arithmetic.{greater,less}than8bit.weight'].view(pop_size, 16)
against the current 8-element OR-tree weight tensor, raising RuntimeError
that was caught and silently SKIPped. Added _pop_cmp8bit helper that
drives the bit-cascade comparator across the population and rewrote both
sub-tests to use it. Wired _test_integration into evaluate(); previously
defined but never called. Test count: 6,364 -> 6,379 on the ALU variants
(+15 integration cases), all passing at fit=1.0000.
|
@@ -4253,6 +4253,60 @@ class BatchedFitnessEvaluator:
|
|
| 4253 |
# INTEGRATION TESTS (Multi-circuit chains)
|
| 4254 |
# =========================================================================
|
| 4255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4256 |
def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 4257 |
"""Test complex operations that chain multiple circuit families."""
|
| 4258 |
pop_size = next(iter(pop.values())).shape[0]
|
|
@@ -4279,11 +4333,8 @@ class BatchedFitnessEvaluator:
|
|
| 4279 |
c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)],
|
| 4280 |
device=self.device, dtype=torch.float32)
|
| 4281 |
|
| 4282 |
-
#
|
| 4283 |
-
|
| 4284 |
-
bias = pop['arithmetic.greaterthan8bit.bias'].view(pop_size)
|
| 4285 |
-
inp = torch.cat([sum_bits, c_bits])
|
| 4286 |
-
out = heaviside((inp * w).sum(-1) + bias)
|
| 4287 |
correct = (out == expected).float()
|
| 4288 |
op_scores += correct
|
| 4289 |
op_total += 1
|
|
@@ -4407,11 +4458,8 @@ class BatchedFitnessEvaluator:
|
|
| 4407 |
b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
|
| 4408 |
device=self.device, dtype=torch.float32)
|
| 4409 |
|
| 4410 |
-
#
|
| 4411 |
-
|
| 4412 |
-
bias = pop['arithmetic.lessthan8bit.bias'].view(pop_size)
|
| 4413 |
-
inp = torch.cat([a_bits, b_bits])
|
| 4414 |
-
lt_out = heaviside((inp * w).sum(-1) + bias)
|
| 4415 |
|
| 4416 |
correct = (lt_out[0].item() == float(is_negative))
|
| 4417 |
op_scores += float(correct)
|
|
@@ -4734,6 +4782,15 @@ class BatchedFitnessEvaluator:
|
|
| 4734 |
total_tests += t
|
| 4735 |
self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4736 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4737 |
self.total_tests = total_tests
|
| 4738 |
|
| 4739 |
if debug:
|
|
|
|
| 4253 |
# INTEGRATION TESTS (Multi-circuit chains)
|
| 4254 |
# =========================================================================
|
| 4255 |
|
| 4256 |
+
def _pop_cmp8bit(self, pop: Dict, pop_size: int,
                 a_bits: torch.Tensor, b_bits: torch.Tensor,
                 kind: str) -> torch.Tensor:
    """Drive the bit-cascade comparator (cmp8bit) over a population.

    Returns a (pop_size,) tensor of heaviside outputs for the requested
    comparison kind ('gt' | 'lt' | 'eq'). Bit 0 is MSB.

    Args:
        pop: flat parameter dict; each '<gate>.weight' entry is viewed as
            (pop_size, fan_in) and '<gate>.bias' as (pop_size,), so every
            gate is evaluated for the whole population in one pass.
        pop_size: leading population dimension shared by all pop tensors.
        a_bits, b_bits: the two operands, indexed per bit as a_bits[i]
            for i in 0..7 (bit 0 = MSB). NOTE(review): assumed to be (8,)
            float tensors on self.device — not validated here; confirm
            against callers.
        kind: which cascade read-out to return.

    Raises:
        ValueError: for any kind other than 'gt', 'lt', or 'eq'.
    """
    def apply(name: str, inp: torch.Tensor, fan_in: int) -> torch.Tensor:
        # One gate across the population: broadcast the shared input
        # against every individual's weights, then threshold.
        w = pop[f'{name}.weight'].view(pop_size, fan_in)
        b = pop[f'{name}.bias'].view(pop_size)
        return heaviside((inp * w).sum(-1) + b)

    # Per-bit primitives: single gt/lt gates plus a two-layer eq
    # (layer1 AND = both bits set, layer1 NOR = both bits clear,
    # layer2 ORs the two).
    bit_gt, bit_lt, bit_eq = [], [], []
    for i in range(8):
        ab = torch.stack([a_bits[i], b_bits[i]])
        bit_gt.append(apply(f'arithmetic.cmp8bit.bit{i}.gt', ab, 2))
        bit_lt.append(apply(f'arithmetic.cmp8bit.bit{i}.lt', ab, 2))
        eq_and = apply(f'arithmetic.cmp8bit.bit{i}.eq.layer1.and', ab, 2)
        eq_nor = apply(f'arithmetic.cmp8bit.bit{i}.eq.layer1.nor', ab, 2)
        # The eq layer-2 input already carries the population dimension,
        # so the gate is applied inline rather than through apply().
        eq_in = torch.stack([eq_and, eq_nor], dim=-1)
        w = pop[f'arithmetic.cmp8bit.bit{i}.eq.weight'].view(pop_size, 2)
        b = pop[f'arithmetic.cmp8bit.bit{i}.eq.bias'].view(pop_size)
        bit_eq.append(heaviside((eq_in * w).sum(-1) + b))

    # Cascade: bit i may decide the comparison only when every more
    # significant bit (0..i-1) compared equal — eq_prefix ANDs those —
    # and bit i's own gt/lt fired. Bit 0 (the MSB) needs no prefix.
    cas_gt = [bit_gt[0]]
    cas_lt = [bit_lt[0]]
    for i in range(1, 8):
        eq_pref_in = torch.stack(bit_eq[:i], dim=-1)
        w_pref = pop[f'arithmetic.cmp8bit.cascade.eq_prefix.bit{i}.weight'].view(pop_size, i)
        b_pref = pop[f'arithmetic.cmp8bit.cascade.eq_prefix.bit{i}.bias'].view(pop_size)
        eq_pref = heaviside((eq_pref_in * w_pref).sum(-1) + b_pref)
        cas_in = torch.stack([eq_pref, bit_gt[i]], dim=-1)
        w_g = pop[f'arithmetic.cmp8bit.cascade.gt.bit{i}.weight'].view(pop_size, 2)
        b_g = pop[f'arithmetic.cmp8bit.cascade.gt.bit{i}.bias'].view(pop_size)
        cas_gt.append(heaviside((cas_in * w_g).sum(-1) + b_g))
        cas_in_lt = torch.stack([eq_pref, bit_lt[i]], dim=-1)
        w_l = pop[f'arithmetic.cmp8bit.cascade.lt.bit{i}.weight'].view(pop_size, 2)
        b_l = pop[f'arithmetic.cmp8bit.cascade.lt.bit{i}.bias'].view(pop_size)
        cas_lt.append(heaviside((cas_in_lt * w_l).sum(-1) + b_l))

    # Final 8-wide reduction gate over the per-bit cascade outputs
    # (OR-tree for gt/lt, AND for eq — per the gate family naming).
    if kind == 'gt':
        inp = torch.stack(cas_gt, dim=-1)
        return apply('arithmetic.greaterthan8bit', inp, 8)
    if kind == 'lt':
        inp = torch.stack(cas_lt, dim=-1)
        return apply('arithmetic.lessthan8bit', inp, 8)
    if kind == 'eq':
        inp = torch.stack(bit_eq, dim=-1)
        return apply('arithmetic.equality8bit', inp, 8)
    raise ValueError(kind)
|
| 4310 |
def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 4311 |
"""Test complex operations that chain multiple circuit families."""
|
| 4312 |
pop_size = next(iter(pop.values())).shape[0]
|
|
|
|
| 4333 |
c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)],
|
| 4334 |
device=self.device, dtype=torch.float32)
|
| 4335 |
|
| 4336 |
+
# Drive sum_bits vs c_bits through the bit-cascade comparator.
|
| 4337 |
+
out = self._pop_cmp8bit(pop, pop_size, sum_bits, c_bits, 'gt')
|
|
|
|
|
|
|
|
|
|
| 4338 |
correct = (out == expected).float()
|
| 4339 |
op_scores += correct
|
| 4340 |
op_total += 1
|
|
|
|
| 4458 |
b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
|
| 4459 |
device=self.device, dtype=torch.float32)
|
| 4460 |
|
| 4461 |
+
# Drive a_bits vs b_bits through the bit-cascade LT comparator.
|
| 4462 |
+
lt_out = self._pop_cmp8bit(pop, pop_size, a_bits, b_bits, 'lt')
|
|
|
|
|
|
|
|
|
|
| 4463 |
|
| 4464 |
correct = (lt_out[0].item() == float(is_negative))
|
| 4465 |
op_scores += float(correct)
|
|
|
|
| 4782 |
total_tests += t
|
| 4783 |
self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4784 |
|
| 4785 |
+
# Cross-family integration tests (chain ripple-carry, comparator,
|
| 4786 |
+
# modular, shifts, subtractor). Each test is internally guarded with
|
| 4787 |
+
# try/except, so unsupported variants silently skip individual tests.
|
| 4788 |
+
if 'arithmetic.cmp8bit.bit0.gt.weight' in population:
|
| 4789 |
+
s, t = self._test_integration(population, debug)
|
| 4790 |
+
scores += s
|
| 4791 |
+
total_tests += t
|
| 4792 |
+
self.category_scores['integration'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4793 |
+
|
| 4794 |
self.total_tests = total_tests
|
| 4795 |
|
| 4796 |
if debug:
|
|
@@ -145,12 +145,30 @@ def main() -> int:
|
|
| 145 |
return bits_msb_to_int(list(reversed(diff_lsb))), carry
|
| 146 |
|
| 147 |
def alu_compare(a, b, kind):
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
if kind == "eq":
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
return gate("arithmetic.equality8bit.layer2", [h_geq, h_leq])
|
| 153 |
-
return gate(f"arithmetic.{kind}8bit", inp)
|
| 154 |
|
| 155 |
def alu_mul(a, b):
|
| 156 |
a_bits = int_to_bits_msb(a, 8)
|
|
@@ -197,16 +215,16 @@ def main() -> int:
|
|
| 197 |
print("=" * 64)
|
| 198 |
|
| 199 |
def mod5(v):
|
|
|
|
|
|
|
|
|
|
| 200 |
bits = int_to_bits_msb(v, 8)
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
h_leq = gate(f"modular.mod5.layer1.leq{i}", bits)
|
| 208 |
-
eqs.append(gate(f"modular.mod5.layer2.eq{i}", [h_geq, h_leq]))
|
| 209 |
-
return gate("modular.mod5.layer3.or", eqs)
|
| 210 |
|
| 211 |
hits = [v for v in range(256) if mod5(v)]
|
| 212 |
print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
|
|
|
|
| 145 |
return bits_msb_to_int(list(reversed(diff_lsb))), carry
|
| 146 |
|
| 147 |
def alu_compare(a, b, kind):
    """Compare two 8-bit values through the bit-cascade comparator gates.

    Walks the bit-cascade comparator family: per-bit gt/lt/eq, cascaded
    eq_prefix, cascade.gt/lt, and the final OR/AND gates. Bit 0 is MSB.
    `kind` is "greaterthan", "lessthan", or "eq"; anything else raises
    ValueError. Uses the enclosing scope's `gate` and `int_to_bits_msb`
    helpers.
    """
    a_msb = int_to_bits_msb(a, 8)
    b_msb = int_to_bits_msb(b, 8)
    # Per-bit primitives over each (a_bit, b_bit) pair.
    bit_gt = [gate(f"arithmetic.cmp8bit.bit{i}.gt", [a_msb[i], b_msb[i]]) for i in range(8)]
    bit_lt = [gate(f"arithmetic.cmp8bit.bit{i}.lt", [a_msb[i], b_msb[i]]) for i in range(8)]
    bit_eq = []
    for i in range(8):
        # Two-layer eq: AND (both set) and NOR (both clear), OR-combined.
        eq_and = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.and", [a_msb[i], b_msb[i]])
        eq_nor = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.nor", [a_msb[i], b_msb[i]])
        bit_eq.append(gate(f"arithmetic.cmp8bit.bit{i}.eq", [eq_and, eq_nor]))
    # Cascade: bit i decides only when all more-significant bits 0..i-1
    # compared equal (eq_prefix) and bit i's own gt/lt fired; the MSB
    # (bit 0) needs no prefix.
    cas_gt = [bit_gt[0]]
    cas_lt = [bit_lt[0]]
    for i in range(1, 8):
        eq_pref = gate(f"arithmetic.cmp8bit.cascade.eq_prefix.bit{i}", bit_eq[:i])
        cas_gt.append(gate(f"arithmetic.cmp8bit.cascade.gt.bit{i}", [eq_pref, bit_gt[i]]))
        cas_lt.append(gate(f"arithmetic.cmp8bit.cascade.lt.bit{i}", [eq_pref, bit_lt[i]]))
    # Final 8-wide reduction gates.
    if kind == "greaterthan":
        return gate("arithmetic.greaterthan8bit", cas_gt)
    if kind == "lessthan":
        return gate("arithmetic.lessthan8bit", cas_lt)
    if kind == "eq":
        return gate("arithmetic.equality8bit", bit_eq)
    raise ValueError(kind)
|
|
|
|
|
|
|
| 172 |
|
| 173 |
def alu_mul(a, b):
|
| 174 |
a_bits = int_to_bits_msb(a, 8)
|
|
|
|
| 215 |
print("=" * 64)
|
| 216 |
|
| 217 |
def mod5(v):
    """Return the mod-5 detector's gate output for v in [0, 255].

    Per-multiple-of-5 match (k0, k5, ..., k255): each k has 8 single-input
    "bit{i}.match" gates that fire when bit i of v matches bit i of k,
    ANDed by ".all". The final "modular.mod5" gate ORs all 52 "all"
    outputs. Expected to be 1 exactly when v % 5 == 0, assuming the
    evolved gate weights implement the intended function. Uses the
    enclosing scope's `gate` and `int_to_bits_msb` helpers.
    """
    bits = int_to_bits_msb(v, 8)
    # The 52 multiples of 5 in [0, 255].
    ks = [k for k in range(256) if k % 5 == 0]
    alls = []
    for k in ks:
        # 8 per-bit match gates for this k, then the AND over all 8.
        matches = [gate(f"modular.mod5.eq.k{k}.bit{i}.match", [bits[i]]) for i in range(8)]
        alls.append(gate(f"modular.mod5.eq.k{k}.all", matches))
    # Top-level OR across every per-k "all" output.
    return gate("modular.mod5", alls)
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
hits = [v for v in range(256) if mod5(v)]
|
| 230 |
print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
|