# Sweep several lattice codebooks (half-integer grids, D4, E8) over bit rates,
# measure the lowest MSE achievable when quantizing i.i.d. Gaussian data, and
# plot the resulting rate-distortion curves.
import math

import matplotlib.pyplot as plt
import torch
from scipy.optimize import minimize_scalar

torch.set_float32_matmul_precision('high')
torch.manual_seed(0)


def opt_err_cvx(fn):
    # Search for the codebook scale in [0.1, 100] that minimizes the
    # quantization error reported by fn; return (best error, best scale).
    res = minimize_scalar(fn, bounds=(0.1, 100))
    scale = float(res.x)
    err = float(res.fun)
    return err, scale

def round(X, grid, grid_norm):
    # Nearest-codebook-entry rounding: argmax_g 2<x, g> - ||g||^2 equals
    # argmin_g ||x - g||^2, so each row of X snaps to its closest grid point.
    Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
    return grid[Xqidx]

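# Quick illustrative check of the rounding rule (toy values chosen here): with
# the 1-D codebook {0.5, 1.5}, the point 1.2 should snap to 1.5.
_toy_grid = torch.tensor([[0.5], [1.5]])
assert round(torch.tensor([[1.2]]), _toy_grid, _toy_grid.norm(dim=-1)**2).item() == 1.5
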
def get_hint_curve(bit_cap=4, cols=1):
    # MSE-vs-bits curve for half-integer product codebooks over `cols` columns.
    # Samples are folded to the positive orthant with abs(), so one sign bit
    # per column is added back into the bit count below.
    def round_mvn(grid, dim=cols, nsamples=50000, sample_bs=500):
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype).to(grid.device).abs()

        grid_norm = grid.norm(dim=-1)**2

        def test_s(s):
            total_err = 0
            for i in range(nsamples // sample_bs):
                sample_b = X[i * sample_bs:(i + 1) * sample_bs].cuda()
                total_err += (round(sample_b * s, grid, grid_norm) / s - sample_b).float().norm()**2 / torch.numel(sample_b)
            total_err = total_err / (nsamples // sample_bs)
            return total_err.cpu()

        return opt_err_cvx(test_s)

    bits = 0
    last_bits = 0
    cr = 1
    data = [[], [], []]
    while bits < bit_cap:
        base_grid = torch.arange(0, cr).to(torch.float16)
        grid = torch.cartesian_prod(*[base_grid + 1 / 2] * cols)
        if cols == 1:
            grid = grid.unsqueeze(-1)
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:
            # Codebook = all grid points inside the ball of squared radius `norm`.
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            # The +1 restores the sign bit dropped by folding samples with abs().
            bits = math.log(len(cb)) / math.log(2) / cols + 1
            if bits - last_bits < 0.1:
                continue
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb)
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data

def get_D4_curve(bit_cap=4):
    # MSE-vs-bits curve for codebooks cut from the half-integer 4-vectors with
    # even coordinate sum (the D4 lattice shifted by 1/2 in every coordinate).
    def round_mvn(grid, nsamples=50000):
        dim = grid.shape[-1]

        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype).to(grid.device)

        grid_norm = grid.norm(dim=-1)**2

        def test_s(s):
            # The D4 codebooks are small enough to score every sample in one pass.
            err = (round(X * s, grid, grid_norm) / s - X).float().norm()**2 / torch.numel(X)
            return err.cpu()

        return opt_err_cvx(test_s)

    _D4_CODESZ = 4
    bits = 0
    last_bits = 0
    cr = 1
    data = [[], [], []]
    while bits < bit_cap:
        base_grid = torch.arange(-cr, cr).to(torch.float16)
        grid = torch.cartesian_prod(*[base_grid + 1 / 2] * _D4_CODESZ)
        grid = grid[torch.where(grid.sum(dim=-1) % 2 == 0)[0]]
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            bits = math.log(len(cb)) / math.log(2) / _D4_CODESZ
            if bits - last_bits < 0.1:
                continue
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb)
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data

def get_E8_curve(bit_cap=4):
    # MSE-vs-bits curve for codebooks cut from the E8 lattice.
    def round_mvn(grid, nsamples=50000, sample_bs=250):
        dim = grid.shape[-1]

        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype)
        # Fold each sample to a canonical representative under even-parity sign
        # flips: take abs(), then negate the first coordinate whenever the
        # original sample had an odd number of negative entries.
        X_part = torch.abs(X)
        X_odd = torch.where((X < 0).sum(dim=-1) % 2 != 0)[0]
        X_part[X_odd, 0] = -X_part[X_odd, 0]
        X = X_part

        grid_norm = grid.norm(dim=-1)**2

        def test_s(s):
            total_err = 0
            for i in range(nsamples // sample_bs):
                sample_b = X[i * sample_bs:(i + 1) * sample_bs].cuda()
                total_err += (round(sample_b * s, grid, grid_norm) / s - sample_b).float().norm()**2 / torch.numel(sample_b)
            total_err = total_err / (nsamples // sample_bs)
            return total_err.cpu()

        return opt_err_cvx(test_s)

    def flip_cb(cb, flips, batch_size=5000000):
        # Expand every codebook entry over all sign-flip patterns in `flips`
        # (0 keeps the sign, 1 negates it), batching to bound GPU memory.
        sign_map = (1 - 2 * flips).unsqueeze(0)
        output = torch.zeros((len(cb), len(flips), cb.shape[-1]), dtype=cb.dtype, device='cpu')
        for i in range(math.ceil(len(cb) / batch_size)):
            end = min(len(cb), (i + 1) * batch_size)
            output[i * batch_size:end] = (cb[i * batch_size:end].unsqueeze(1) * sign_map).cpu()
        return output.reshape(-1, cb.shape[-1])

    def batched_unique(cpu_tensor, batch_size=10**9):
        # Deduplicate rows on the GPU, one chunk at a time (each chunk is
        # deduplicated independently), keeping the result on the CPU.
        res = []
        for i in range(math.ceil(len(cpu_tensor) / batch_size)):
            end = min(len(cpu_tensor), (i + 1) * batch_size)
            res.append(torch.unique(cpu_tensor[i * batch_size:end].cuda(), dim=0).cpu())
        return torch.concat(res, dim=0)

    def combo(n, k):
        # Binomial coefficient via log-gamma (not used below).
        return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()

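    # Construction sketch for the candidate grids below: E8 is the union of the
    # integer points with even coordinate sum (D8) and the half-integer points
    # with even coordinate sum (D8 shifted by 1/2 in every coordinate). The
    # loop builds both cosets inside a box, keeps the points within a growing
    # ball, and expands them over the 128 even-parity sign-flip patterns in
    # `bitmap`.
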
    _E8_CODESZ = 8

    # All 8-bit sign patterns with an even number of flips (128 patterns).
    int_map = 2**torch.arange(8)
    bitmap = torch.zeros(256, 8)
    for i in range(256):
        bitmap[i] = (i & int_map) != 0
    bitmap = bitmap[torch.where(bitmap.sum(dim=-1) % 2 == 0)[0]].cuda()

    bits = 0
    cr = 2
    data = [[], [], []]
    last_bits = 0
    while bits < bit_cap:
        base_grid = torch.arange(-1, cr).to(torch.float16)
        int_grid = torch.cartesian_prod(*[base_grid] * _E8_CODESZ)
        int_grid = int_grid[torch.where(int_grid.sum(dim=-1) % 2 == 0)[0]]
        hint_grid = torch.cartesian_prod(*[base_grid + 1 / 2] * _E8_CODESZ)
        hint_grid = hint_grid[torch.where(hint_grid.sum(dim=-1) % 2 == 0)[0]]
        grid = torch.concat([int_grid, hint_grid], dim=0)

        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            cb = batched_unique(flip_cb(cb, bitmap))
            idxs = torch.where(
                ((cb[:, 1:] < 0).sum(dim=-1) <= 1) *
                (cb[:, 1:].min(dim=-1).values >= -0.5)
            )[0]
            cb_part = cb[idxs]

            # The bit rate counts the full sign-expanded codebook; the error is
            # estimated on the reduced subset cb_part.
            bits = math.log(len(cb)) / math.log(2) / _E8_CODESZ
            if bits - last_bits < 0.1:
                continue
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb_part.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data

def parse_cached(s):
    # Parse a whitespace-separated dump into (bits, err) lists, assuming
    # records of three values per entry: norm, bits, err.
    s = s.replace('\n', ' ')
    s = s.strip().split(' ')
    bits = [float(_) for _ in s[1::3]]
    err = [float(_) for _ in s[2::3]]
    return bits, err

bit_cap = 3.5
hint_1c = get_hint_curve(bit_cap, 1)
hint_4c = get_hint_curve(bit_cap, 4)
hint_8c = get_hint_curve(bit_cap, 8)
D4 = get_D4_curve(bit_cap)
E8 = get_E8_curve(bit_cap)

import pickle as pkl
all_data = {
    'half_int_1col': hint_1c,
    'half_int_4col': hint_4c,
    'half_int_8col': hint_8c,
    'D4': D4,
    'E8': E8,
}

print(all_data)
with open('plot_data.pkl', 'wb') as f:
    pkl.dump(all_data, f)
# exit()  # uncomment to stop after writing plot_data.pkl, skipping the plot

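# The curves are also pickled above, so the figure below can be regenerated
# later without recomputing; a minimal sketch (same key names as all_data):
#
#     import pickle as pkl
#     with open('plot_data.pkl', 'rb') as f:
#         all_data = pkl.load(f)
#     hint_1c = all_data['half_int_1col']
#     hint_4c = all_data['half_int_4col']
#     hint_8c = all_data['half_int_8col']
#     D4, E8 = all_data['D4'], all_data['E8']
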
plt.rcParams["figure.figsize"] = (6, 5)
plt.cla()
box = plt.plot(hint_1c[0], hint_1c[1], 's', label='Half Integer 1 Column')[0]
plt.plot(hint_1c[0], hint_1c[1], '-', alpha=0.5, color=box.get_color())
box = plt.plot(hint_4c[0], hint_4c[1], 'o', label='Half Integer 4 Column')[0]
plt.plot(hint_4c[0], hint_4c[1], '-', alpha=0.5, color=box.get_color())
box = plt.plot(hint_8c[0], hint_8c[1], '+', label='Half Integer 8 Column')[0]
plt.plot(hint_8c[0], hint_8c[1], '-', alpha=0.5, color=box.get_color())
box = plt.plot(D4[0], D4[1], '*', label='D4')[0]
plt.plot(D4[0], D4[1], '-', alpha=0.5, color=box.get_color())
box = plt.plot(E8[0], E8[1], 'x', label='E8')[0]
plt.plot(E8[0], E8[1], '-', alpha=0.5, color=box.get_color())
plt.plot(2.0, 0.0915, 'yD', label='E8 Padded ($2^{16}$ entries)')
plt.legend()
plt.title('Lowest MSE Achievable for a Multivariate Gaussian')
plt.ylabel('MSE')
plt.yscale('log')
plt.xlabel('Bits')
plt.tight_layout()
plt.savefig('lattice_err.png', dpi=600)