|
import math |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
|
|
|
|
def projection_linf(points_to_project, w_hyperplane, b_hyperplane):
    """Batched L-inf step from points towards hyperplanes, kept inside [0, 1].

    For every row ``i`` the function looks for a displacement ``d[i]`` with
    the smallest maximum absolute entry such that
    ``w[i] . (t[i] + d[i]) == b[i]`` while each coordinate of ``t[i] + d[i]``
    stays between 0 and 1.

    Args:
        points_to_project: tensor indexed as ``(batch, dim)``. The box
            handling below assumes its entries lie in ``[0, 1]`` -- this is
            not checked here.
        w_hyperplane: hyperplane normals, same shape as ``points_to_project``.
            Cloned before use, the caller's tensor is not modified.
        b_hyperplane: hyperplane offsets, one per row (``(batch,)``).
            Cloned before use, the caller's tensor is not modified.

    Returns:
        Tensor shaped like ``points_to_project`` holding the displacement
        ``d``. Coordinates whose weight is exactly zero get a zero
        displacement. Rows for which neither of the two solution cases
        below applies keep the full step to the box corner.
    """
    device = points_to_project.device
    t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane.clone()

    # Flip (w, b) where needed so that w.t - b >= 0 holds for every row.
    sign = 2 * ((w * t).sum(1) - b >= 0) - 1
    w.mul_(sign.unsqueeze(1))
    b.mul_(sign)

    # a is the box bound each coordinate moves towards (1 where w < 0, else 0);
    # d is the full step to that bound, zeroed where the weight is zero.
    a = (w < 0).float()
    d = (a - t) * (w != 0).float()

    # p equals t where a == 0 and 1 - t where a == 1, i.e. the distance to the
    # targeted bound; coordinates are ranked by it.
    p = a - t * (2 * a - 1)
    indp = torch.argsort(p, dim=1)

    # From here on b is the offset relative to the current point, and b0 is
    # the change of w.x obtained by taking the full step d.
    b = b - (w * t).sum(1)
    b0 = (w * d).sum(1)

    # Cumulative sums taken in order of decreasing p.
    indp2 = indp.flip((1,))
    ws = w.gather(1, indp2)
    bs2 = -ws * d.gather(1, indp2)

    s = torch.cumsum(ws.abs(), dim=1)
    sb = torch.cumsum(bs2, dim=1) + b0.unsqueeze(1)

    # c_l: rows solved with one common step size on all coordinates.
    # c2: rows that need the search below to find how many coordinates clip.
    b2 = sb[:, -1] - s[:, -1] * p.gather(1, indp[:, 0:1]).squeeze(1)
    c_l = b - b2 > 0
    c2 = (b - b0 > 0) & (~c_l)
    lb = torch.zeros(int(c2.sum()), device=device)
    ub = torch.full_like(lb, w.shape[1] - 1)
    nitermax = math.ceil(math.log2(w.shape[1]))

    # Bisection on the sorted position, run for all c2 rows at once.
    indp_, sb_, s_, p_, b_ = indp[c2], sb[c2], s[c2], p[c2], b[c2]
    for _ in range(nitermax):
        counter4 = torch.floor((lb + ub) / 2)

        counter2 = counter4.long().unsqueeze(1)
        indcurr = indp_.gather(1, indp_.size(1) - 1 - counter2)
        b2 = (sb_.gather(1, counter2) - s_.gather(1, counter2) * p_.gather(1, indcurr)).squeeze(1)
        c = b_ - b2 > 0

        lb = torch.where(c, counter4, lb)
        ub = torch.where(c, ub, counter4)

    lb = lb.long()

    if c_l.any():
        lmbd_opt = torch.clamp_min((b[c_l] - sb[c_l, -1]) / (-s[c_l, -1]), min=0).unsqueeze(-1)
        d[c_l] = (2 * a[c_l] - 1) * lmbd_opt

    # For c2 rows every coordinate moves by lmbd_opt towards its bound but
    # never further than the full step already stored in d. With no c2 rows
    # these statements operate on empty tensors and change nothing.
    lmbd_opt = torch.clamp_min((b[c2] - sb[c2, lb]) / (-s[c2, lb]), min=0).unsqueeze(-1)
    d[c2] = torch.min(lmbd_opt, d[c2]) * a[c2] + torch.max(-lmbd_opt, d[c2]) * (1 - a[c2])

    return d * (w != 0).float()
|
|
|
|
|
def projection_l2(points_to_project, w_hyperplane, b_hyperplane):
    """Batched L2 step from points towards hyperplanes, kept inside [0, 1].

    For every row ``i`` the function looks for the displacement ``d[i]`` of
    smallest Euclidean norm such that ``w[i] . (t[i] + d[i]) == b[i]`` while
    each coordinate of ``t[i] + d[i]`` stays between 0 and 1.

    Args:
        points_to_project: tensor indexed as ``(batch, dim)``. The box
            handling below assumes its entries lie in ``[0, 1]`` -- this is
            not checked here.
        w_hyperplane: hyperplane normals, same shape as ``points_to_project``.
            Cloned before use, the caller's tensor is not modified.
        b_hyperplane: hyperplane offsets, one per row (``(batch,)``). Only
            read, never written.

    Returns:
        Tensor shaped like ``points_to_project`` holding the displacement
        ``d``. Coordinates with ``|w| <= 1e-8`` get a zero displacement.
        Rows for which neither solution case below applies keep the full
        step to the box boundary.
    """
    device = points_to_project.device
    t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane

    # Signed offset of each point; flip w and c so that c >= 0 on every row.
    c = (w * t).sum(1) - b
    ind2 = 2 * (c >= 0) - 1
    w.mul_(ind2.unsqueeze(1))
    c.mul_(ind2)

    # r is the multiplier alpha at which t - alpha * w reaches the box
    # boundary (0 for positive weights, 1 for negative ones). 1e12 is the
    # sentinel for "never", used for (near-)zero weights and overflow.
    r = torch.max(t / w, (t - 1) / w).clamp(min=-1e12, max=1e12)
    r.masked_fill_(w.abs() < 1e-8, 1e12)
    r[r == -1e12] *= -1
    rs, indr = torch.sort(r, dim=1)
    rs2 = F.pad(rs[:, 1:], (0, 1))
    # Sentinel entries must not contribute to the sums below.
    rs.masked_fill_(rs == 1e12, 0)
    rs2.masked_fill_(rs2 == 1e12, 0)

    # w5: squared norm of w per row; ws: squared weight remaining after each
    # sorted position (the coordinates with a larger r).
    w3s = (w ** 2).gather(1, indr)
    w5 = w3s.sum(dim=1, keepdim=True)
    ws = w5 - torch.cumsum(w3s, dim=1)
    # Full step to the box boundary, zero where the weight is negligible.
    d = -(r * w)
    d.mul_((w.abs() > 1e-8).float())
    s = torch.cat((-w5 * rs[:, 0:1], torch.cumsum((-rs2 + rs) * ws, dim=1) - w5 * rs[:, 0:1]), 1)

    # c4: the plain orthogonal step already stays inside the box.
    # c3: even the full step d does not bring w.x down to the hyperplane.
    # c2: everything in between; some coordinates clip, found by the search.
    c4 = s[:, 0] + c < 0
    c3 = (d * w).sum(dim=1) + c > 0
    c2 = ~(c4 | c3)

    lb = torch.zeros(int(c2.sum()), device=device)
    ub = torch.full_like(lb, w.shape[1] - 1)
    nitermax = math.ceil(math.log2(w.shape[1]))

    # Bisection on the sorted position, run for all c2 rows at once.
    s_, c_ = s[c2], c[c2]
    for _ in range(nitermax):
        counter4 = torch.floor((lb + ub) / 2)
        counter2 = counter4.long().unsqueeze(1)
        move_up = s_.gather(1, counter2).squeeze(1) + c_ > 0
        lb = torch.where(move_up, counter4, lb)
        ub = torch.where(move_up, ub, counter4)

    lb = lb.long()

    if c4.any():
        alpha = c[c4] / w5[c4].squeeze(-1)
        d[c4] = -alpha.unsqueeze(-1) * w[c4]

    if c2.any():
        alpha = (s[c2, lb] + c[c2]) / ws[c2, lb] + rs[c2, lb]
        # No weight left past this position: avoid the division by zero.
        alpha[ws[c2, lb] == 0] = 0
        # Coordinates whose limit r is below alpha keep the clipped step,
        # the others move by -alpha * w.
        c5 = (alpha.unsqueeze(-1) > r[c2]).float()
        d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5)

    return d * (w.abs() > 1e-8).float()
|
|
|
|
|
def projection_l1(points_to_project, w_hyperplane, b_hyperplane):
    """Batched L1 step from points towards hyperplanes, kept inside [0, 1].

    For every row ``i`` the function looks for the displacement ``d[i]`` of
    smallest L1 norm such that ``w[i] . (t[i] + d[i]) == b[i]`` while each
    coordinate of ``t[i] + d[i]`` stays between 0 and 1. Coordinates are
    used in order of decreasing ``|w|``: the leading ones are moved all the
    way to the box bound and one pivot coordinate absorbs the remainder.

    Args:
        points_to_project: tensor indexed as ``(batch, dim)``. The box
            handling below assumes its entries lie in ``[0, 1]`` -- this is
            not checked here.
        w_hyperplane: hyperplane normals, same shape as ``points_to_project``.
            Cloned before use, the caller's tensor is not modified.
        b_hyperplane: hyperplane offsets, one per row (``(batch,)``). Only
            read, never written.

    Returns:
        Tensor shaped like ``points_to_project`` holding the displacement
        ``d``. Coordinates with ``|w| <= 1e-8`` get a zero displacement.
        Rows where even moving every coordinate to its bound does not reach
        the hyperplane keep that full step.
    """
    device = points_to_project.device
    t, w, b = points_to_project, w_hyperplane.clone(), b_hyperplane

    # Signed offset of each point; flip w and c so that c >= 0 on every row.
    c = (w * t).sum(1) - b
    ind2 = 2 * (c >= 0) - 1
    w.mul_(ind2.unsqueeze(1))
    c.mul_(ind2)

    # Ascending 1/|w| == descending |w|; zero weights (inf, capped) go last.
    r = (1 / w).abs().clamp_max(1e12)
    indr = torch.argsort(r, dim=1)
    # Inverse permutation: sorted position of every original coordinate.
    indr_rev = torch.argsort(indr)

    # Full step to the box bound (1 where w < 0, else 0), zero where w == 0.
    c6 = (w < 0).float()
    d = (-t + c6) * (w != 0).float()
    # Change of w.x caused by each full step, in sorted order; s[:, k] is the
    # offset left after the first k sorted coordinates have been saturated.
    ds = torch.min(-w * t, w * (1 - t)).gather(1, indr)
    ds2 = torch.cat((c.unsqueeze(-1), ds), 1)
    s = torch.cumsum(ds2, dim=1)

    # Rows where saturating every coordinate overshoots the hyperplane, so a
    # partial solution exists and has to be located.
    c2 = s[:, -1] < 0

    lb = torch.zeros(int(c2.sum()), device=device)
    ub = torch.full_like(lb, s.shape[1])
    nitermax = math.ceil(math.log2(w.shape[1]))

    # Bisection for a position where the remaining offset is still positive.
    s_ = s[c2]
    for _ in range(nitermax):
        counter4 = torch.floor((lb + ub) / 2)
        counter2 = counter4.long().unsqueeze(1)
        c3 = s_.gather(1, counter2).squeeze(1) > 0
        lb = torch.where(c3, counter4, lb)
        ub = torch.where(c3, ub, counter4)

    lb2 = lb.long()

    if c2.any():
        # Original index of the pivot coordinate of every selected row.
        pivot = indr[c2].gather(1, lb2.unsqueeze(1)).squeeze(1)
        rows = torch.arange(0, lb2.shape[0], device=device).unsqueeze(1)
        u2 = torch.arange(0, w.shape[1], device=device, dtype=torch.float).unsqueeze(0)
        alpha = -s[c2, lb2] / w[c2, pivot]
        # True for sorted positions before the pivot, then mapped back to the
        # original coordinate order via the inverse permutation.
        c5 = u2 < lb.unsqueeze(-1)
        u3 = c5[rows, indr_rev[c2]]
        # Keep the full step before the pivot, zero elsewhere, then let the
        # pivot coordinate take up the remaining offset.
        d[c2] = d[c2] * u3.float()
        d[c2, pivot] = alpha

    return d * (w.abs() > 1e-8).float()
|
|