|
|
import torch
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Read every tensor stored in a safetensors file.

    Args:
        path: file to read; defaults to ``'model.safetensors'``.

    Returns:
        The mapping of tensor name -> tensor produced by
        ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors
|
|
|
|
|
|
def xor2_from_weights(a, b, w, or_w, or_b, nand_w, nand_b, and_w, and_b):
    """Evaluate a two-layer threshold-gate network on the bit pair (a, b).

    Layer 1 evaluates the gates named OR and NAND on ``[a, b]``. Layer 2 feeds
    their two outputs into the gate named AND. A gate fires (outputs 1) when
    ``sum(input * weight) + bias >= 0``. With OR/NAND/AND weights the result
    is XOR(a, b); the actual function is whatever the supplied weights encode.

    Args:
        a, b: input bits (anything ``float()`` accepts, normally 0/1).
        w: unused; kept only so existing callers' argument lists still match.
        or_w, or_b: weight tensor and bias for the layer-1 OR gate.
        nand_w, nand_b: weight tensor and bias for the layer-1 NAND gate.
        and_w, and_b: weight tensor and bias for the layer-2 AND gate.

    Returns:
        int: 1 if the output gate fires, else 0.
    """
    def fires(x, weight, bias):
        # Threshold activation: fire when the affine response is non-negative.
        return (x * weight).sum() + bias >= 0

    pair = torch.tensor([float(a), float(b)])
    hidden = torch.tensor([
        float(fires(pair, or_w, or_b)),
        float(fires(pair, nand_w, nand_b)),
    ])
    return int(fires(hidden, and_w, and_b))
|
|
|
|
|
|
def hamming74_encode(d1, d2, d3, d4, w):
    """Hamming(7,4) encoder built from threshold gates: 4 data bits -> 7 coded bits.

    Each parity bit passes through two cascaded two-layer gate networks
    (OR + NAND in layer 1, AND in layer 2) stored in ``w``:

    * a first stage over the 4-element data vector, under the keys
      ``p1.xor12``, ``p2.xor13`` and ``p3.xor23``. The names suggest the XOR
      of two data bits; which bits are selected is decided by the weights.
    * a second stage ``p{k}.xor_final`` over ``[first_stage_output, d4]``.

    The remaining codeword bits come from single threshold gates ``d1`` ..
    ``d4`` applied to the data vector.

    Args:
        d1, d2, d3, d4: data bits (anything ``float()`` accepts, normally 0/1).
        w: mapping from weight names to tensors, e.g. the dict returned by
            ``load_model``. It must contain a ``.weight`` and a ``.bias``
            entry for every gate prefix used below.

    Returns:
        list[int]: codeword ``[p1, p2, c3, p3, c5, c6, c7]``.

    Raises:
        KeyError: if ``w`` lacks any required gate entry.
    """
    def fires(x, prefix):
        # Threshold gate: fires when sum(x * weight) + bias >= 0.
        # Reads '<prefix>.weight' before '<prefix>.bias'.
        return (x * w[prefix + '.weight']).sum() + w[prefix + '.bias'] >= 0

    def gate_net(x, prefix):
        # Two-layer network: OR and NAND gates on x, then AND on their outputs.
        hidden = torch.tensor([
            float(fires(x, prefix + '.layer1.or')),
            float(fires(x, prefix + '.layer1.nand')),
        ])
        return int(fires(hidden, prefix + '.layer2'))

    data = torch.tensor([float(d1), float(d2), float(d3), float(d4)])

    def parity(bit, first_stage):
        # Stage 1 over the full data vector, stage 2 combines its output with d4.
        partial = gate_net(data, f'{bit}.{first_stage}')
        return gate_net(torch.tensor([float(partial), float(d4)]), f'{bit}.xor_final')

    p1 = parity('p1', 'xor12')
    p2 = parity('p2', 'xor13')
    p3 = parity('p3', 'xor23')

    # Data-position bits: one threshold gate each over the data vector.
    c3, c5, c6, c7 = [int(fires(data, name)) for name in ('d1', 'd2', 'd3', 'd4')]

    return [p1, p2, c3, p3, c5, c6, c7]
|
|
|
|
|
|
if __name__ == '__main__':
    # Self-test: encode all 16 data words with the loaded gate network and
    # compare each result against a plain-XOR reference encoder.
    w = load_model()

    print('Hamming(7,4) Encoder')
    print('Input (d1d2d3d4) -> Output (c1c2c3c4c5c6c7)')

    def ref_encode(d1, d2, d3, d4):
        # Reference codeword layout: [p1, p2, d1, p3, d2, d3, d4].
        p1 = d1 ^ d2 ^ d4
        p2 = d1 ^ d3 ^ d4
        p3 = d2 ^ d3 ^ d4
        return [p1, p2, d1, p3, d2, d3, d4]

    errors = 0
    for d in range(16):
        # d1 is taken from the least-significant bit of d.
        d1, d2, d3, d4 = (d>>0)&1, (d>>1)&1, (d>>2)&1, (d>>3)&1
        result = hamming74_encode(d1, d2, d3, d4, w)
        expected = ref_encode(d1, d2, d3, d4)
        ok = result == expected
        if not ok:
            errors += 1
        status = 'OK' if ok else 'FAIL'
        r_str = ''.join(map(str, result))
        print(f'{d1}{d2}{d3}{d4} -> {r_str} {status}')

    print(f'\n{16-errors}/16 correct')

    # Exit non-zero on any mismatch so shell scripts / CI can detect failure.
    if errors:
        raise SystemExit(1)
|
|
|
|