# Source: Hugging Face file-viewer export ("Upload 132 files", commit c1a41d7, 8.69 kB).
# The viewer page chrome (raw / history / blame / contribute / delete, "No virus")
# was captured along with the file; kept here as a comment so the script stays valid Python.
import math
import matplotlib.pyplot as plt
import torch
from scipy.optimize import minimize_scalar
# Allow TF32 matmuls for speed; fix the RNG seed so the sampled curves are reproducible.
torch.set_float32_matmul_precision('high')
torch.manual_seed(0)
def opt_err_cvx(fn):
    """Minimize the scalar objective `fn` over the scale range [0.1, 100].

    Returns ``(err, scale)`` where ``scale`` is the minimizing argument and
    ``err`` is the objective value at that point (whatever type `fn` returns).
    """
    # method='bounded' makes the bounds actually binding: older scipy
    # releases default to (unbounded) Brent even when bounds are passed.
    res = minimize_scalar(fn, bounds=(0.1, 100), method='bounded')
    # res.x may be a plain float or a numpy 0-d scalar depending on the scipy
    # version; float() handles both, unlike the original `res.x.item()`.
    scale = float(res.x)
    err = res.fun
    return err, scale
def round(X, grid, grid_norm):
    """Map each row of X to its nearest codebook vector in `grid`.

    Maximizing ``2*<x, g> - |g|^2`` over g is equivalent to minimizing
    ``|x - g|^2``, so this returns the Euclidean-nearest grid point.
    `grid_norm` must hold the precomputed squared norms of the grid rows.
    (Note: intentionally shadows the builtin `round` in this script.)
    """
    scores = torch.matmul(X, grid.T)
    scores = scores * 2 - grid_norm
    nearest = scores.argmax(dim=-1)
    return grid[nearest]
def get_hint_curve(bit_cap=4, cols=1):
    """Build the bits-vs-MSE trade-off curve for 'half integer' codebooks.

    Codebooks are `cols`-dimensional cartesian products of positive half
    integers (1/2, 3/2, ...) truncated to balls of increasing squared radius.
    Returns ``data = [bits, errs, scales]`` (three parallel lists), stopping
    once the bit rate exceeds ``bit_cap``. Requires a CUDA device.
    """
    def round_mvn(grid, dim=cols, nsamples=50000, sample_bs=500):
        # Estimate the best-scale rounding MSE of `grid` on |N(0, I_dim)| samples.
        # abs() folds samples into the positive orthant the codebook covers;
        # presumably signs are stored separately (hence the +1 sign bit below).
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype).to(grid.device).abs()
        grid_norm = grid.norm(dim=-1)**2
        def test_s(s):
            # Mean squared rounding error at scale s, averaged over sample batches.
            total_err = 0
            for i in range(nsamples//sample_bs):
                sample_b = X[i*sample_bs: (i+1)*sample_bs].cuda()
                total_err += (round(sample_b*s, grid, grid_norm)/s - sample_b).float().norm()**2 / torch.numel(sample_b)
            total_err = total_err/(nsamples//sample_bs)
            return total_err.cpu()
        return opt_err_cvx(test_s)
    bits = 0
    last_bits = 0
    cr = 1  # current shell radius: coordinates live in [1/2, cr - 1/2]
    data = [[], [], []]
    while bits < bit_cap:
        base_grid = torch.arange(0, cr).to(torch.float16)
        grid = torch.cartesian_prod(*[base_grid + 1/2] * cols)
        if cols == 1:
            grid = grid.unsqueeze(-1)  # cartesian_prod of a single factor is 1-D
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        # Only sweep squared radii in the new annulus [(cr-1)^2, cr^2).
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:  # subsample every 4th radius to keep the sweep cheap
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            # Bits per coordinate, plus one sign bit (codebook is unsigned).
            bits = math.log(len(cb))/math.log(2)/cols + 1
            if bits - last_bits < 0.1:
                continue  # skip radii that barely change the bit rate
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data
def get_D4_curve(bit_cap=4):
    """Build the bits-vs-MSE curve for shifted-D4 lattice codebooks.

    Grid points are 4-vectors of half integers whose coordinates sum to an
    even value, truncated to balls of increasing squared radius. Returns
    ``data = [bits, errs, scales]``; stops once bits exceed ``bit_cap``.
    Requires a CUDA device.
    """
    def round_mvn(grid, nsamples=50000, sample_bs=1000):
        # NOTE(review): sample_bs is unused — all samples are rounded in one
        # shot here (unlike the batched loops in the other curves); confirm
        # the full [nsamples, len(grid)] score matrix fits in GPU memory.
        dim = grid.shape[-1]
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype).to(grid.device)
        grid_norm = grid.norm(dim=-1)**2
        def test_s(s):
            # Mean squared rounding error at scale s over all samples.
            err = (round(X*s, grid, grid_norm)/s - X).float().norm()**2 / torch.numel(X)
            return err.cpu()
        return opt_err_cvx(test_s)
    _D4_CODESZ = 4  # codeword dimensionality
    bits = 0
    last_bits = 0
    cr = 1
    data = [[], [], []]
    while bits < bit_cap:
        base_grid = torch.arange(-cr, cr).to(torch.float16)
        grid = torch.cartesian_prod(*[base_grid + 1/2] * _D4_CODESZ)
        # Keep only points with even coordinate sum (the D4 parity constraint).
        grid = grid[torch.where(grid.sum(dim=-1) % 2 == 0)[0]]
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        # Only sweep squared radii in the new annulus [(cr-1)^2, cr^2).
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:  # subsample every 4th radius
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            bits = math.log(len(cb))/math.log(2)/_D4_CODESZ  # bits per weight
            if bits - last_bits < 0.1:
                continue  # skip radii that barely change the bit rate
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data
def get_E8_curve(bit_cap=4):
    """Build the bits-vs-MSE curve for E8-style lattice codebooks.

    The full codebook is the union of integer and half-integer 8-vectors with
    even coordinate sum, truncated to balls of increasing squared radius.
    Sign symmetry is exploited: samples and codewords are canonicalized to one
    sign pattern per orbit (`cb_part`) for the nearest-point search, while the
    bit count is taken from the full sign-expanded codebook. Returns
    ``data = [bits, errs, scales]``; requires a CUDA device.
    """
    def round_mvn(grid, nsamples=50000, sample_bs=250):
        dim = grid.shape[-1]
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype)
        # Canonicalize signs: take |X|, then push an odd sign parity into
        # coordinate 0 so samples match the canonical-orbit codebook.
        X_part = torch.abs(X)
        X_odd = torch.where((X < 0).sum(dim=-1) % 2 != 0)[0]
        X_part[X_odd, 0] = -X_part[X_odd, 0]
        X = X_part
        grid_norm = grid.norm(dim=-1)**2
        def test_s(s):
            # Mean squared rounding error at scale s, averaged over sample batches.
            total_err = 0
            for i in range(nsamples//sample_bs):
                sample_b = X[i*sample_bs: (i+1)*sample_bs].cuda()
                total_err += (round(sample_b*s, grid, grid_norm)/s - sample_b).float().norm()**2 / torch.numel(sample_b)
            total_err = total_err/(nsamples//sample_bs)
            return total_err.cpu()
        return opt_err_cvx(test_s)
    def flip_cb(cb, flips, batch_size=5000000):
        # Expand cb under every sign pattern in `flips` (0/1 rows become +/-1
        # multipliers), processed in batches to bound peak GPU memory.
        map = 1 - 2*flips
        output = torch.zeros((len(cb), len(map), cb.shape[-1]), dtype=cb.dtype, device='cpu')
        map = map.unsqueeze(0)
        for i in range(math.ceil(len(cb)/batch_size)):
            next = min(len(cb), (i+1)*batch_size)
            output[i*batch_size: next] = (cb[i*batch_size:next].unsqueeze(1)*map).cpu()
        return output.reshape(-1, cb.shape[-1])
    def batched_unique(cpu_tensor, batch_size=10**9):
        # Row-wise unique computed on the GPU in batches.
        # NOTE(review): uniqueness is only enforced within each batch, not
        # across batches — fine while the input fits in a single batch.
        res = []
        for i in range(math.ceil(len(cpu_tensor)/batch_size)):
            next = min(len(cpu_tensor), (i+1)*batch_size)
            res.append(torch.unique(cpu_tensor[i*batch_size:next].cuda(), dim=0).cpu())
        return torch.concat(res, dim=0)
    def combo(n, k):
        # Binomial coefficient C(n, k) via log-gamma on tensors.
        # NOTE(review): defined but never called in this function.
        return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
    _E8_CODESZ = 8  # codeword dimensionality
    # Build all 128 even-parity sign patterns over 8 coordinates (0/1 masks).
    int_map = 2**torch.arange(8)
    bitmap = torch.zeros(256, 8)
    for i in range(256):
        bitmap[i] = (i & int_map) != 0
    bitmap = bitmap[torch.where(bitmap.sum(dim=-1)%2 == 0)[0]].cuda()
    bits = 0
    cr = 2
    data = [[], [], []]
    last_bits = 0
    while bits < bit_cap:
        base_grid = torch.arange(-1, cr).to(torch.float16)
        # Integer-coordinate points with even coordinate sum ...
        int_grid = torch.cartesian_prod(*[base_grid] * _E8_CODESZ)
        int_grid = int_grid[torch.where(int_grid.sum(dim=-1) % 2 == 0)[0]]
        # ... unioned with half-integer points with even coordinate sum.
        hint_grid = torch.cartesian_prod(*[base_grid + 1/2] * _E8_CODESZ)
        hint_grid = hint_grid[torch.where(hint_grid.sum(dim=-1) % 2 == 0)[0]]
        grid = torch.concat([int_grid, hint_grid], dim=0)
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        # Only sweep squared radii in the new annulus [(cr-1)^2, cr^2).
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:  # subsample every 4th radius
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            # Full sign-expanded, deduplicated codebook; its size sets the bit rate.
            cb = batched_unique(flip_cb(cb, bitmap))
            # Canonical sign orbit: at most one negative among coords 1..7,
            # and none of those coords below -0.5.
            idxs = torch.where(
                ((cb[:, 1:] < 0).sum(dim=-1) <= 1) * \
                (cb[:, 1:].min(dim=-1).values >= -0.5)
            )[0]
            cb_part = cb[idxs]
            bits = math.log(len(cb))/math.log(2)/_E8_CODESZ  # bits per weight
            if bits - last_bits < 0.1:
                continue  # skip radii that barely change the bit rate
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb_part.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data
def parse_cached(s):
    """Parse a cached results dump of whitespace-separated triples.

    `s` is expected to contain repeated ``<norm> <bits> <err>`` triples (the
    format printed by the curve functions above). Returns ``(bits, err)`` as
    two lists of floats.
    """
    # str.split() with no argument splits on any whitespace run (including
    # newlines), so repeated separators no longer yield empty tokens that
    # would crash float() — the old split(' ') did exactly that. It also
    # makes the explicit '\n' replace and strip/rstrip dance unnecessary.
    tokens = s.split()
    bits = [float(t) for t in tokens[1::3]]
    err = [float(t) for t in tokens[2::3]]
    return bits, err
# --- Driver: compute every curve up to 3.5 bits and pickle the results. ---
bit_cap = 3.5
hint_1c = get_hint_curve(bit_cap, 1)
hint_4c = get_hint_curve(bit_cap, 4)
hint_8c = get_hint_curve(bit_cap, 8)
D4 = get_D4_curve(bit_cap)
E8 = get_E8_curve(bit_cap)
import pickle as pkl
all_data = {
    'half_int_1col': hint_1c,
    'half_int_4col': hint_4c,
    'half_int_8col': hint_8c,
    'D4': D4,
    'E8': E8,
}
print(all_data)
pkl.dump(all_data, open('plot_data.pkl', 'wb'))
# NOTE(review): this exit() makes the plotting code below unreachable;
# presumably the figure is produced in a separate run from plot_data.pkl.
exit()
# --- Plotting: bits vs. MSE for every codebook family (log-scale y axis). ---
# Currently dead code: the script exit()s before reaching this section.
plt.rcParams["figure.figsize"] = (6,5)
plt.cla()
# Each curve: markers for the measured points plus a faded connecting line
# reusing the marker series' color.
# NOTE(review): box._color is a private Matplotlib attribute; the public
# accessor is box.get_color() — confirm before relying on newer Matplotlib.
box = plt.plot(hint_1c[0], hint_1c[1], 's', label='Half Integer 1 Column')[0]
plt.plot(hint_1c[0], hint_1c[1], '-', alpha=0.5, color=box._color)
box = plt.plot(hint_4c[0], hint_4c[1], 'o', label='Half Integer 4 Column')[0]
plt.plot(hint_4c[0], hint_4c[1], '-', alpha=0.5, color=box._color)
box = plt.plot(hint_8c[0], hint_8c[1], '+', label='Half Integer 8 Column')[0]
plt.plot(hint_8c[0], hint_8c[1], '-', alpha=0.5, color=box._color)
box = plt.plot(D4[0], D4[1], '*', label='D4')[0]
plt.plot(D4[0], D4[1], '-', alpha=0.5, color=box._color)
box = plt.plot(E8[0], E8[1], 'x', label='E8')[0]
plt.plot(E8[0], E8[1], '-', alpha=0.5, color=box._color)
# Single reference point for the padded 2^16-entry E8 codebook.
plt.plot(2.0, 0.0915, 'yD', label='E8 Padded ($2^{16}$ entries)')
plt.legend()
plt.title('Lowest MSE Achievable for a Multivariate Gaussian')
plt.ylabel('MSE')
plt.yscale('log')
plt.xlabel('Bits')
plt.tight_layout()
plt.savefig('lattice_err.png', dpi=600)