| import torch
|
| from safetensors.torch import load_file
|
|
|
def load_model(path='model.safetensors'):
    """Load the checkpoint tensors stored in the safetensors file at ``path``.

    Args:
        path: Location of the safetensors file. The default
            ``'model.safetensors'`` is a relative path, so it is resolved
            against the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``
        (a mapping from tensor name to tensor).
    """
    state = load_file(path)
    return state
|
|
|
def _threshold_unit(x, weights, name):
    """Evaluate the threshold neuron ``name`` on the 1-D input tensor ``x``.

    Reads ``weights[f'{name}.weight']`` and ``weights[f'{name}.bias']``,
    computes ``x @ weight.T + bias`` and returns 1 if that is >= 0, else 0.
    ``.item()`` requires the neuron to produce exactly one value.
    """
    weight = weights[f'{name}.weight']
    bias = weights[f'{name}.bias']
    # torch.matmul does not type-promote (and needs both operands on one
    # device), so match the stored weight instead of assuming float32/CPU.
    x = x.to(device=weight.device, dtype=weight.dtype)
    return int((x @ weight.T + bias >= 0).item())


def mux2(d0, d1, s, weights):
    """2:1 Multiplexer: returns d0 if s=0, d1 if s=1.

    The result is produced by a two-layer threshold network whose parameters
    come from ``weights``. The expected keys are ``'sel0.weight'``,
    ``'sel0.bias'``, ``'sel1.weight'``, ``'sel1.bias'``, ``'or.weight'`` and
    ``'or.bias'``. Whether the output really follows the MUX truth table
    depends entirely on those stored values.

    Args:
        d0: Data input selected when ``s`` is 0 (anything ``float()`` accepts;
            presumably 0/1 — non-binary values are not validated).
        d1: Data input selected when ``s`` is 1.
        s: Select input.
        weights: Mapping from parameter name to tensor, e.g. the result of
            ``load_model()``.

    Returns:
        int: 0 or 1.
    """
    inp = torch.tensor([float(d0), float(d1), float(s)])

    # Hidden layer: two threshold units over (d0, d1, s). By their names these
    # presumably gate d0 / d1 with the select line — NOTE(review): confirm
    # against how the checkpoint was trained/constructed.
    sel0 = _threshold_unit(inp, weights, 'sel0')
    sel1 = _threshold_unit(inp, weights, 'sel1')

    # Output layer: a single threshold unit over the two hidden outputs.
    hidden = torch.tensor([float(sel0), float(sel1)])
    return _threshold_unit(hidden, weights, 'or')
|
|
|
if __name__ == '__main__':
    # Exhaustively compare the loaded network with the reference MUX
    # behaviour (out = d1 if s else d0) and print one row per input combo.
    w = load_model()

    print('MUX2 Truth Table:')
    print('s d0 d1 | out | expected')
    print('-' * 28)

    failures = 0
    total = 0
    for s in [0, 1]:
        for d0 in [0, 1]:
            for d1 in [0, 1]:
                result = mux2(d0, d1, s, w)
                expected = d1 if s else d0
                total += 1
                if result == expected:
                    status = 'OK'
                else:
                    status = 'FAIL'
                    failures += 1
                print(f'{s} {d0} {d1} | {result} | {expected} {status}')

    # Exit with a non-zero status on any mismatch so scripts/CI notice a bad
    # checkpoint; previously the FAIL rows were printed but the exit code was 0.
    if failures:
        raise SystemExit(f'{failures} of {total} truth-table rows FAILED')
|
|
|