|
|
""" |
|
|
3D Voxel Shape Classifier — Complete Geometric Primitive Vocabulary |
|
|
5×5×5 binary voxel grid → rigid cascade → curvature analysis → classify |
|
|
|
|
|
38 shape classes covering: |
|
|
- Rigid 0D-3D: points, lines, joints, triangles, quads, polyhedra, prisms |
|
|
- Curved 1D: arcs, helices |
|
|
- Curved 2D: circles, ellipses, discs |
|
|
- Curved 3D solid: sphere, hemisphere, cylinder, cone, capsule, torus |
|
|
- Curved 3D hollow: shell, tube |
|
|
- Curved 3D open: bowl (concave), saddle (hyperbolic) |
|
|
|
|
|
Curvature types: none, convex, concave, cylindrical, conical, toroidal, hyperbolic, helical |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional |
|
|
import math |
|
|
from itertools import combinations |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SwiGLU(nn.Module):
    """Gated linear unit whose gate is passed through SiLU (Swish).

    Computes ``w1(x) * silu(w2(x))`` where ``w1`` and ``w2`` are independent
    ``nn.Linear(in_dim, out_dim)`` projections (each with its default bias),
    and ``silu(z) = z * sigmoid(z)``. The output's last dimension is
    ``out_dim``; leading dimensions of ``x`` are preserved.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Value projection (created first, so parameter init order is w1 then w2).
        self.w1 = nn.Linear(in_dim, out_dim)
        # Gate projection, activated with SiLU in forward().
        self.w2 = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        value = self.w1(x)
        gate = F.silu(self.w2(x))
        return value * gate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Label vocabulary for the classifier. Keys are class names; the dict's
# insertion order fixes each class's integer label (CLASS_NAMES and
# CLASS_TO_IDX are derived from it), so reordering or inserting entries
# relabels every generated sample.
#
# Per-class metadata:
#   "dim"       -- dimension label (0-3); copied into each sample as
#                  "peak_dim" and used to fill its "dim_confidence" vector.
#   "curved"    -- selects the generator path: False -> ShapeGenerator._rigid
#                  (a few vertex voxels), True -> ShapeGenerator._curved
#                  (rasterised curves/surfaces/solids).
#   "curvature" -- must be an entry of CURVATURE_TYPES; it is mapped to an
#                  integer via CURV_TO_IDX when a sample is built.
SHAPE_CATALOG = {
    # --- 0-D ---
    "point": {"dim": 0, "curved": False, "curvature": "none"},

    # --- 1-D rigid: two-endpoint segments ---
    "line_x": {"dim": 1, "curved": False, "curvature": "none"},
    "line_y": {"dim": 1, "curved": False, "curvature": "none"},
    "line_z": {"dim": 1, "curved": False, "curvature": "none"},
    "line_diag": {"dim": 1, "curved": False, "curvature": "none"},

    # --- 1-D rigid: multi-segment / multi-point linear arrangements ---
    "cross": {"dim": 1, "curved": False, "curvature": "none"},
    "l_shape": {"dim": 1, "curved": False, "curvature": "none"},
    "collinear": {"dim": 1, "curved": False, "curvature": "none"},

    # --- 2-D rigid: triangles (vertex voxels only) ---
    "triangle_xy": {"dim": 2, "curved": False, "curvature": "none"},
    "triangle_xz": {"dim": 2, "curved": False, "curvature": "none"},
    "triangle_3d": {"dim": 2, "curved": False, "curvature": "none"},

    # --- 2-D rigid: quadrilaterals and generic coplanar point sets ---
    "square_xy": {"dim": 2, "curved": False, "curvature": "none"},
    "square_xz": {"dim": 2, "curved": False, "curvature": "none"},
    "rectangle": {"dim": 2, "curved": False, "curvature": "none"},
    "coplanar": {"dim": 2, "curved": False, "curvature": "none"},

    # --- 2-D rigid: filled axis-aligned planar patch ---
    "plane": {"dim": 2, "curved": False, "curvature": "none"},

    # --- 3-D rigid: simplices and pyramid ---
    "tetrahedron": {"dim": 3, "curved": False, "curvature": "none"},
    "pyramid": {"dim": 3, "curved": False, "curvature": "none"},
    "pentachoron": {"dim": 3, "curved": False, "curvature": "none"},

    # --- 3-D rigid: boxes, prism, octahedron ---
    "cube": {"dim": 3, "curved": False, "curvature": "none"},
    "cuboid": {"dim": 3, "curved": False, "curvature": "none"},
    "triangular_prism": {"dim": 3, "curved": False, "curvature": "none"},
    "octahedron": {"dim": 3, "curved": False, "curvature": "none"},

    # --- curved 1-D ---
    "arc": {"dim": 1, "curved": True, "curvature": "convex"},
    "helix": {"dim": 1, "curved": True, "curvature": "helical"},

    # --- curved 2-D outlines ---
    "circle": {"dim": 2, "curved": True, "curvature": "convex"},
    "ellipse": {"dim": 2, "curved": True, "curvature": "convex"},

    # --- curved 2-D filled ---
    "disc": {"dim": 2, "curved": True, "curvature": "convex"},

    # --- curved 3-D solids ---
    "sphere": {"dim": 3, "curved": True, "curvature": "convex"},
    "hemisphere": {"dim": 3, "curved": True, "curvature": "convex"},
    "cylinder": {"dim": 3, "curved": True, "curvature": "cylindrical"},
    "cone": {"dim": 3, "curved": True, "curvature": "conical"},
    "capsule": {"dim": 3, "curved": True, "curvature": "convex"},
    "torus": {"dim": 3, "curved": True, "curvature": "toroidal"},

    # --- curved 3-D hollow ---
    "shell": {"dim": 3, "curved": True, "curvature": "convex"},
    "tube": {"dim": 3, "curved": True, "curvature": "cylindrical"},

    # --- curved open surfaces (labelled dim 3 here) ---
    "bowl": {"dim": 3, "curved": True, "curvature": "concave"},
    "saddle": {"dim": 3, "curved": True, "curvature": "hyperbolic"},
}
|
|
|
|
|
# Class vocabulary derived from the catalog: position i in CLASS_NAMES is the
# integer label of that class, following SHAPE_CATALOG's insertion order.
CLASS_NAMES = [*SHAPE_CATALOG]
NUM_CLASSES = len(CLASS_NAMES)
CLASS_TO_IDX = dict(zip(CLASS_NAMES, range(NUM_CLASSES)))

# Curvature vocabulary; the position of each entry is its integer label.
CURVATURE_TYPES = [
    "none", "convex", "concave", "cylindrical", "conical",
    "toroidal", "hyperbolic", "helical",
]
NUM_CURVATURES = len(CURVATURE_TYPES)
CURV_TO_IDX = dict(zip(CURVATURE_TYPES, range(NUM_CURVATURES)))

# Edge length of the cubic voxel grid (samples are GS x GS x GS).
GS = 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cayley_menger_det(points: np.ndarray) -> float:
    """Return the Cayley–Menger determinant of a set of points.

    Builds the (n+1) x (n+1) bordered matrix whose first row and column are
    ones (with a zero in the corner) and whose remaining block holds the
    pairwise squared Euclidean distances, then returns ``np.linalg.det`` of it.

    ``points`` is indexed as ``points[i]`` and subtracted element-wise, so it
    is expected to be a NumPy array with one point per row.
    """
    count = len(points)
    bordered = np.zeros((count + 1, count + 1))
    bordered[0, 1:] = 1
    bordered[1:, 0] = 1
    for row, p in enumerate(points, start=1):
        for col, q in enumerate(points, start=1):
            bordered[row, col] = np.sum((p - q) ** 2)
    return np.linalg.det(bordered)
|
|
|
|
|
|
|
|
def simplex_volume(points: np.ndarray) -> float:
    """Return the content of the simplex whose vertices are ``points``.

    With k vertices the result is the (k-1)-dimensional measure: a length for
    k = 2, an area for k = 3, a volume for k = 4. It is computed from the
    Cayley–Menger determinant via

        V^2 = (-1)^k * det(CM) / (2^(k-1) * ((k-1)!)^2)

    Fewer than two vertices yield 0.0. A slightly negative V^2 caused by
    floating-point round-off is clamped to zero before the square root.
    """
    n_vertices = len(points)
    if n_vertices < 2:
        return 0.0
    order = n_vertices - 1
    scale = (2 ** order) * (math.factorial(order) ** 2)
    det = cayley_menger_det(points)
    if n_vertices % 2:
        # (-1)^k is -1 for an odd number of vertices.
        det = -det
    return np.sqrt(max(0, det / scale))
|
|
|
|
|
|
|
|
def effective_volume(points: np.ndarray) -> float:
    """Return the largest simplex measure spanned by a subset of ``points``.

    * fewer than 2 points -> 0.0
    * 2 points            -> distance between them
    * 3 points            -> area of the triangle they form
    * 4 or more points    -> largest tetrahedron volume over all 4-point
                             subsets of the first 8 points

    For 4+ points the result is always a 3-D volume, so coplanar or collinear
    point sets give 0.0; there is no fallback to an area. Only the first 8
    points are enumerated, which caps the work at C(8, 4) = 70 Cayley–Menger
    determinants.
    """
    k = len(points)
    if k < 2:
        return 0.0
    if k == 2:
        return np.linalg.norm(points[0] - points[1])
    # Evaluate only the simplex order whose measure is returned: triangles for
    # exactly three points, tetrahedra for four or more.
    order = 3 if k == 3 else 4
    candidates = range(min(k, 8))
    return max(
        (simplex_volume(points[list(idx)]) for idx in combinations(candidates, order)),
        default=0.0,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ShapeGenerator:
    """Procedural generator of labelled voxel samples for the SHAPE_CATALOG classes.

    Every sample starts as a set of integer voxel coordinates in
    ``[0, GS - 1]^3``. Rigid classes (``curved=False``) are described by a few
    vertex voxels. Curved classes are rasterised by testing every grid cell
    against an implicit curve/surface/solid equation. All randomness is drawn
    from ``self.rng`` (a ``np.random.RandomState`` seeded in ``__init__``).
    Generators return None to reject a draw; callers retry.
    """

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

    def generate(self, n_samples: int) -> list:
        """Return ``n_samples`` sample dicts, roughly balanced across classes.

        Each class first receives up to ``n_samples // NUM_CLASSES`` samples,
        with at most five attempts per requested sample because ``_make`` may
        reject a draw. Any shortfall is filled with uniformly random classes,
        retrying until the list is full. The result is then shuffled with
        ``self.rng``.
        """
        samples = []
        per_class = n_samples // NUM_CLASSES
        for name in CLASS_NAMES:
            count = 0
            attempts = 0
            while count < per_class and attempts < per_class * 5:
                s = self._make(name)
                attempts += 1
                if s is not None:
                    samples.append(s)
                    count += 1
        while len(samples) < n_samples:
            name = self.rng.choice(CLASS_NAMES)
            s = self._make(name)
            if s is not None:
                samples.append(s)
        self.rng.shuffle(samples)
        return samples[:n_samples]

    def _make(self, name: str) -> Optional[dict]:
        """Generate, sanitise and package one sample of class ``name``.

        Voxels are clipped into the grid, cast to int and de-duplicated before
        ``_build`` packages them. Returns None if the class generator rejected
        the draw or produced no voxels.
        """
        info = SHAPE_CATALOG[name]
        if info["curved"]:
            voxels = self._curved(name)
        else:
            voxels = self._rigid(name)
        if voxels is None:
            return None
        voxels = np.clip(voxels, 0, GS - 1).astype(int)
        voxels = np.unique(voxels, axis=0)
        if len(voxels) < 1:
            return None
        return self._build(name, info, voxels)

    @staticmethod
    def _affine_rank(pts) -> int:
        """Return the dimension of the affine hull spanned by ``pts``.

        0 for a single point (or all points identical), 1 for collinear points,
        2 for coplanar points, 3 for points spanning 3-D space. Accepts 2-D or
        3-D coordinates, one point per row.
        """
        arr = np.asarray(pts, dtype=float)
        if len(arr) < 2:
            return 0
        return int(np.linalg.matrix_rank(arr[1:] - arr[0]))

    def _rigid(self, name):
        """Return the vertex voxels of a rigid (``curved=False``) class.

        Returns an integer array of shape (k, 3), or None when the draw is
        rejected. Draws that would degenerate into a different class are
        rejected so that every sample matches its label. Examples are a
        collinear "triangle", a square "rectangle", a cubic "cuboid" and a
        flat "tetrahedron".
        """
        rng = self.rng

        if name == "point":
            return rng.randint(0, GS, size=(1, 3))

        elif name == "line_x":
            y, z = rng.randint(0, GS, size=2)
            x1, x2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x1, y, z], [x2, y, z]])

        elif name == "line_y":
            x, z = rng.randint(0, GS, size=2)
            y1, y2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y1, z], [x, y2, z]])

        elif name == "line_z":
            x, y = rng.randint(0, GS, size=2)
            z1, z2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y, z1], [x, y, z2]])

        elif name == "line_diag":
            p1 = rng.randint(0, 3, size=3)
            step = rng.randint(1, 3)
            # Components are drawn from {-1, +1} (never 0), so the direction
            # moves along every axis; clipping below may still flatten some.
            direction = rng.choice([-1, 1], size=3)
            p2 = np.clip(p1 + step * direction, 0, GS - 1)
            if np.array_equal(p1, p2):
                # Only happens when p1 is the origin and every component is -1.
                p2 = np.clip(p1 + np.array([1, 1, 0]), 0, GS - 1)
            return np.array([p1, p2])

        elif name == "cross":
            # Centre voxel plus two equal-length arms along two distinct axes.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            length = rng.randint(1, 3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            pts = [[cx, cy, cz]]
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis1] = np.clip(p[axis1] + sign * length, 0, GS - 1)
                pts.append(list(p))
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis2] = np.clip(p[axis2] + sign * length, 0, GS - 1)
                pts.append(list(p))
            return np.array(pts)

        elif name == "l_shape":
            # Corner voxel with two voxel runs along two distinct axes.
            corner = rng.randint(1, GS - 1, size=3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            len1 = rng.randint(1, 3)
            len2 = rng.randint(1, 3)
            dir1 = rng.choice([-1, 1])
            dir2 = rng.choice([-1, 1])
            pts = [list(corner)]
            for i in range(1, len1 + 1):
                p = list(corner)
                p[axis1] = np.clip(p[axis1] + dir1 * i, 0, GS - 1)
                pts.append(p)
            for i in range(1, len2 + 1):
                p = list(corner)
                p[axis2] = np.clip(p[axis2] + dir2 * i, 0, GS - 1)
                pts.append(p)
            return np.array(pts)

        elif name == "collinear":
            axis = rng.randint(3)
            fixed = rng.randint(0, GS, size=2)
            vals = sorted(rng.choice(GS, 3, replace=False))
            pts = np.zeros((3, 3), dtype=int)
            for i, v in enumerate(vals):
                pts[i, axis] = v
                pts[i, (axis + 1) % 3] = fixed[0]
                pts[i, (axis + 2) % 3] = fixed[1]
            return pts

        elif name == "triangle_xy":
            z = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            # A collinear draw would really be a "collinear" sample.
            if pts is None or self._affine_rank(pts) < 2:
                return None
            return np.column_stack([pts, np.full(3, z)])

        elif name == "triangle_xz":
            y = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is None or self._affine_rank(pts) < 2:
                return None
            return np.column_stack([pts[:, 0], np.full(3, y), pts[:, 1]])

        elif name == "triangle_3d":
            pts = self._rand_pts_3d(3, min_dist=1)
            if pts is None or self._affine_rank(pts) < 2:
                return None
            return pts

        elif name == "square_xy":
            z = rng.randint(0, GS)
            x1, y1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y1, z], [x1 + s, y1, z],
                            [x1, y1 + s, z], [x1 + s, y1 + s, z]])
            return np.clip(pts, 0, GS - 1)

        elif name == "square_xz":
            y = rng.randint(0, GS)
            x1, z1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y, z1], [x1 + s, y, z1],
                            [x1, y, z1 + s], [x1 + s, y, z1 + s]])
            return np.clip(pts, 0, GS - 1)

        elif name == "rectangle":
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a1, a2 = rng.randint(0, 3), rng.randint(0, 3)
            w, h = rng.randint(1, 4), rng.randint(1, 3)
            if w == h:
                w = min(GS - 1, w + 1)
            c = np.array([[a1, a2], [a1 + w, a2], [a1, a2 + h], [a1 + w, a2 + h]])
            c = np.clip(c, 0, GS - 1)
            # Clipping at the grid edge can shrink the width back to the
            # height, turning the "rectangle" into a square; reject such draws.
            width = c[:, 0].max() - c[:, 0].min()
            height = c[:, 1].max() - c[:, 1].min()
            if width == height:
                return None
            if axis == 0:
                return np.column_stack([np.full(4, val), c])
            elif axis == 1:
                return np.column_stack([c[:, 0], np.full(4, val), c[:, 1]])
            else:
                return np.column_stack([c, np.full(4, val)])

        elif name == "coplanar":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            # Flatten one coordinate to a shared value so all points lie in a plane.
            pts[:, rng.randint(3)] = pts[0, rng.randint(3)]
            # Flattening can merge points or leave them on a line; keep only
            # four distinct, genuinely planar points.
            if len(np.unique(pts, axis=0)) < 4 or self._affine_rank(pts) < 2:
                return None
            return pts

        elif name == "plane":
            # Filled axis-aligned rectangular patch of at least 2 x 2 voxels.
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a_start = rng.randint(0, 2)
            b_start = rng.randint(0, 2)
            a_size = rng.randint(2, GS - a_start + 1)
            b_size = rng.randint(2, GS - b_start + 1)
            pts = []
            for a in range(a_start, min(GS, a_start + a_size)):
                for b in range(b_start, min(GS, b_start + b_size)):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[(axis + 1) % 3] = a
                    p[(axis + 2) % 3] = b
                    pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tetrahedron":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            centered = pts - pts.mean(axis=0)
            _, s, _ = np.linalg.svd(centered.astype(float))
            if s[-1] < 0.5:
                # Nearly flat draw: move one coordinate in an attempt to lift it off the plane.
                pts[rng.randint(4), rng.randint(3)] = (pts[0, 0] + 2) % GS
            # The adjustment above is not guaranteed to succeed (it can even
            # duplicate a vertex), so verify the vertices really span 3-D.
            if self._affine_rank(pts) < 3:
                return None
            return pts

        elif name == "pyramid":
            z_base = rng.randint(0, 3)
            x1, y1 = rng.randint(0, 3), rng.randint(0, 3)
            s = rng.randint(1, 3)
            base = np.array([[x1, y1, z_base], [x1 + s, y1, z_base],
                             [x1, y1 + s, z_base], [x1 + s, y1 + s, z_base]])
            apex = np.array([[x1 + s // 2, y1 + s // 2, z_base + rng.randint(1, 3)]])
            return np.clip(np.vstack([base, apex]), 0, GS - 1)

        elif name == "pentachoron":
            pts = self._rand_pts_3d(5, min_dist=1)
            # Five coplanar points would be indistinguishable from "coplanar".
            if pts is None or self._affine_rank(pts) < 3:
                return None
            return pts

        elif name == "cube":
            x1, y1, z1 = rng.randint(0, 3, size=3)
            s = rng.randint(1, 3)
            pts = []
            for dx in [0, s]:
                for dy in [0, s]:
                    for dz in [0, s]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "cuboid":
            x1, y1, z1 = rng.randint(0, 2, size=3)
            sx, sy, sz = rng.randint(1, 4, size=3)
            # Break the all-equal case so the box is not a cube.
            if sx == sy == sz:
                sx = min(GS - 1, sx + 1)
            pts = []
            for dx in [0, sx]:
                for dy in [0, sy]:
                    for dz in [0, sz]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            pts = np.clip(np.array(pts), 0, GS - 1)
            # Clipping can undo the bump above (e.g. x1 = 1, sx = 4 -> width 3);
            # reject draws whose clipped extents are all equal.
            extents = pts.max(axis=0) - pts.min(axis=0)
            if extents[0] == extents[1] == extents[2]:
                return None
            return pts

        elif name == "triangular_prism":
            # A triangle extruded along `axis` over 2-3 consecutive layers.
            axis = rng.randint(3)
            ext_start = rng.randint(0, 3)
            ext_len = rng.randint(1, 3)
            tri = self._rand_pts_2d(3, min_dist=1)
            # A collinear cross-section would make the "prism" a flat slab.
            if tri is None or self._affine_rank(tri) < 2:
                return None
            pts = []
            for e in range(ext_start, min(GS, ext_start + ext_len + 1)):
                for t in tri:
                    p = [0, 0, 0]
                    p[axis] = e
                    p[(axis + 1) % 3] = t[0]
                    p[(axis + 2) % 3] = t[1]
                    pts.append(p)
            return np.clip(np.array(pts), 0, GS - 1) if len(pts) >= 6 else None

        elif name == "octahedron":
            # Six vertices at +-s along each axis from the centre.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            s = rng.randint(1, 3)
            pts = [[cx, cy, cz + s], [cx, cy, cz - s],
                   [cx + s, cy, cz], [cx - s, cy, cz],
                   [cx, cy + s, cz], [cx, cy - s, cz]]
            return np.clip(np.array(pts), 0, GS - 1)

        return None

    def _curved(self, name):
        """Rasterise a curved class into integer voxel coordinates.

        A centre is drawn uniformly from [1, 3]^3 before dispatching on
        ``name``; per-class parameters are drawn afterwards. Returns an
        integer array of shape (k, 3), or None if too few voxels survive.
        """
        rng = self.rng
        cx, cy, cz = rng.uniform(1.0, 3.0, size=3)

        if name == "arc":
            r = rng.uniform(1.2, 2.2)
            plane = rng.choice(["xy", "xz", "yz"])
            start = rng.uniform(0, 2 * np.pi)
            span = rng.uniform(np.pi * 0.4, np.pi * 1.2)
            n = rng.randint(6, 12)
            angles = np.linspace(start, start + span, n)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
                else:
                    pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 3 else None

        elif name == "helix":
            # Circle of radius r in the plane orthogonal to `axis`, advancing
            # along `axis` by `pitch` per radian.
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            pitch = rng.uniform(0.3, 0.8)
            n = rng.randint(15, 30)
            t = np.linspace(0, 2 * np.pi * rng.uniform(1.0, 2.5), n)
            pts = []
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            start_h = rng.uniform(0, 1.0)
            for ti in t:
                p = [0.0, 0.0, 0.0]
                p[axes[0]] = center[axes[0]] + r * np.cos(ti)
                p[axes[1]] = center[axes[1]] + r * np.sin(ti)
                p[axis] = start_h + pitch * ti
                pts.append(p)
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "circle":
            r = rng.uniform(1.0, 2.0)
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
                else:
                    pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "ellipse":
            rx, ry = rng.uniform(0.8, 2.0), rng.uniform(0.8, 2.0)
            # Push nearly equal radii apart so the outline is visibly elliptical.
            if abs(rx - ry) < 0.3:
                rx *= 1.4
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + rx * np.cos(a), cy + ry * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + rx * np.cos(a), cy, cz + ry * np.sin(a)])
                else:
                    pts.append([cx, cy + rx * np.cos(a), cz + ry * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "disc":
            # Filled circle lying in the axis-aligned plane p[axis] == val.
            r = rng.uniform(1.0, 2.2)
            axis = rng.randint(3)
            val = round(rng.uniform(0.5, 3.5))
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[axes[0]] = x
                    p[axes[1]] = y
                    dist = np.sqrt((x - center[axes[0]])**2 + (y - center[axes[1]])**2)
                    if dist <= r:
                        pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "sphere":
            r = rng.uniform(1.0, 2.2)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        if (x - cx)**2 + (y - cy)**2 + (z - cz)**2 <= r**2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "hemisphere":
            # Solid ball restricted to the half-space p[cut_axis] >= centre.
            r = rng.uniform(1.0, 2.2)
            cut_axis = rng.randint(3)
            center = [cx, cy, cz]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if (x - cx)**2 + (y - cy)**2 + (z - cz)**2 <= r**2:
                            if p[cut_axis] >= center[cut_axis]:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 3 else None

        elif name == "cylinder":
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a])**2 for a in axes)
                        if dist_sq <= r**2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "cone":
            # Radius shrinks linearly from r_base at base_pos towards 0 at the top layer.
            r_base = rng.uniform(1.0, 2.0)
            axis = rng.randint(3)
            height = rng.randint(2, 5)
            base_pos = rng.randint(0, GS - height + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            span = height - 1 + 1e-6
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        along = p[axis] - base_pos
                        if along < 0 or along >= height:
                            continue
                        t = along / span
                        r_at = r_base * (1.0 - t)
                        dist_sq = sum((p[a] - center[a])**2 for a in axes)
                        if dist_sq <= r_at**2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "capsule":
            # Cylindrical body along `axis` over [body_start, body_end], capped
            # at each end by a ball of radius r centred on that end.
            r = rng.uniform(0.8, 1.5)
            axis = rng.randint(3)
            body_len = rng.randint(1, 3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            body_start = round(center[axis] - body_len / 2)
            body_end = body_start + body_len
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        radial_sq = sum((p[a] - center[a])**2 for a in axes)
                        along = p[axis]
                        if body_start <= along <= body_end and radial_sq <= r**2:
                            pts.append(p)
                        elif along < body_start:
                            cap_center = list(center)
                            cap_center[axis] = body_start
                            dist_sq = sum((p[i] - cap_center[i])**2 for i in range(3))
                            if dist_sq <= r**2:
                                pts.append(p)
                        elif along > body_end:
                            cap_center = list(center)
                            cap_center[axis] = body_end
                            dist_sq = sum((p[i] - cap_center[i])**2 for i in range(3))
                            if dist_sq <= r**2:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 5 else None

        elif name == "torus":
            # Voxels within r of a ring of radius R around `axis`.
            R = rng.uniform(1.2, 2.0)
            r = rng.uniform(0.5, 0.9)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            ring_axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_in_plane = np.sqrt(
                            sum((p[a] - center[a])**2 for a in ring_axes))
                        dist_from_ring = np.sqrt(
                            (dist_in_plane - R)**2 + (p[axis] - center[axis])**2)
                        if dist_from_ring <= r:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "shell":
            # Spherical annulus r_in <= |p - c| <= r_out.
            r_out = rng.uniform(1.5, 2.3)
            r_in = r_out - rng.uniform(0.4, 0.8)
            if r_in < 0.3:
                r_in = 0.3
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        d_sq = (x - cx)**2 + (y - cy)**2 + (z - cz)**2
                        if r_in**2 <= d_sq <= r_out**2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tube":
            # Cylindrical annulus r_in <= radial distance <= r_out over a span of `axis`.
            r_out = rng.uniform(1.0, 2.0)
            r_in = r_out - rng.uniform(0.3, 0.7)
            if r_in < 0.2:
                r_in = 0.2
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a])**2 for a in axes)
                        if r_in**2 <= dist_sq <= r_out**2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "bowl":
            # Thickened paraboloid h = c + k * rho^2 with k = 1 / r.
            r = rng.uniform(1.2, 2.2)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            thickness = 0.6
            k = 1.0 / (r + 1e-6)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_planar = np.sqrt(
                            sum((p[a] - center[a])**2 for a in axes))
                        if dist_planar > r:
                            continue
                        expected_h = center[axis] + k * dist_planar**2
                        actual_h = p[axis]
                        if abs(actual_h - expected_h) <= thickness:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "saddle":
            # Thickened hyperbolic paraboloid h = c + k * (da^2 - db^2),
            # limited to a planar radius of 2 around the centre.
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            k = rng.uniform(0.3, 0.8)
            thickness = 0.7
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        da = p[axes[0]] - center[axes[0]]
                        db = p[axes[1]] - center[axes[1]]
                        expected_h = center[axis] + k * (da**2 - db**2)
                        if abs(p[axis] - expected_h) <= thickness:
                            dist_sq = da**2 + db**2
                            if dist_sq <= 4.0:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        return None

    def _rand_pts_2d(self, n, min_dist=0):
        """Draw ``n`` distinct random 2-D integer points in ``[0, GS)^2``.

        When ``min_dist > 0``, each draw must also pass ``_check_dist``. Up to
        50 draws are tried; returns None if none qualifies.
        """
        for _ in range(50):
            pts = set()
            while len(pts) < n:
                pts.add((self.rng.randint(0, GS), self.rng.randint(0, GS)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _rand_pts_3d(self, n, min_dist=0):
        """Draw ``n`` distinct random 3-D integer points in ``[0, GS)^3``.

        When ``min_dist > 0``, each draw must also pass ``_check_dist``. Up to
        100 draws are tried; returns None if none qualifies.
        """
        for _ in range(100):
            pts = set()
            while len(pts) < n:
                pts.add(tuple(self.rng.randint(0, GS, size=3)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _check_dist(self, pts, min_dist):
        """Return True if every pair of points is at least ``min_dist`` apart (Manhattan distance)."""
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                if np.sum(np.abs(pts[i] - pts[j])) < min_dist:
                    return False
        return True

    def _build(self, name, info, voxels):
        """Package unique integer voxels into a training-sample dict.

        Geometry features (``cm_det``, ``volume``) are computed from at most
        the first 6 voxels. ``dim_confidence`` has a 1.0 at index 0, and at
        index 1 when there are 2+ voxels. Index 2 is set when the class dim
        is >= 2, and index 3 when it is >= 3. ``grid`` is a float32
        GS x GS x GS occupancy grid.
        """
        n = len(voxels)
        sub = voxels[:6].astype(float) if n > 6 else voxels.astype(float)
        cm_det = cayley_menger_det(sub)
        volume = effective_volume(sub)

        dim_conf = np.zeros(4, dtype=np.float32)
        dim_conf[0] = 1.0
        if n >= 2:
            dim_conf[1] = 1.0
        if info["dim"] >= 2:
            dim_conf[2] = 1.0
        if info["dim"] >= 3:
            dim_conf[3] = 1.0

        grid = np.zeros((GS, GS, GS), dtype=np.float32)
        for v in voxels:
            grid[v[0], v[1], v[2]] = 1.0

        return {
            "grid": grid, "label": CLASS_TO_IDX[name], "class_name": name,
            "n_points": n, "n_occupied": int(grid.sum()),
            "cm_det": float(cm_det), "volume": float(volume),
            "peak_dim": info["dim"], "dim_confidence": dim_conf,
            "is_curved": info["curved"], "curvature": CURV_TO_IDX[info["curvature"]],
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_chunk(args):
    """Worker function for parallel shape generation.

    ``args`` is ``(class_indices, seed, start_idx)``; ``start_idx`` is carried
    along but not used here. A fresh ShapeGenerator seeded with ``seed`` makes
    one sample per class index, trying that class up to 10 times. If all 10
    attempts are rejected, a single "cube" sample is tried instead. A slot is
    dropped only if the fallback is rejected as well.
    """
    class_indices, seed, _start_idx = args
    gen = ShapeGenerator(seed=seed)
    out = []
    for ci in class_indices:
        name = CLASS_NAMES[ci]
        sample = None
        tries = 0
        while sample is None and tries < 10:
            sample = gen._make(name)
            tries += 1
        if sample is None:
            sample = gen._make("cube")
        if sample is not None:
            out.append(sample)
    return out
|
|
|
|
|
|
|
|
def generate_parallel(n_samples, seed=42, n_workers=8):
    """Pre-generate all samples using multiprocessing.

    Class slots are assigned up front: ``n_samples // NUM_CLASSES`` per class,
    topped up with uniformly random classes, then shuffled. The slots are split
    into at most ``n_workers`` contiguous chunks. Chunk ``i`` is generated by
    ``_generate_chunk`` with seed ``seed + i * 1000000``, and the merged samples
    are shuffled again. Label assignment and both shuffles share one
    RandomState seeded with ``seed``. The per-chunk seeds depend on
    ``n_workers``, so changing the worker count changes the dataset.

    Args:
        n_samples: Number of class slots to fill. The result can be shorter
            when a slot and its fallback are both rejected by the worker.
        seed: Base seed for slot assignment, chunk seeding and shuffling.
        n_workers: Maximum number of worker processes; must be >= 1.

    Returns:
        List of sample dicts produced by ``_generate_chunk``.

    Raises:
        ValueError: If ``n_workers`` is less than 1.
    """
    import multiprocessing as mp
    import time

    if n_workers < 1:
        raise ValueError(f"n_workers must be >= 1, got {n_workers}")

    per_class = n_samples // NUM_CLASSES
    class_assignments = []
    for ci in range(NUM_CLASSES):
        class_assignments.extend([ci] * per_class)
    rng = np.random.RandomState(seed)
    while len(class_assignments) < n_samples:
        class_assignments.append(rng.randint(0, NUM_CLASSES))
    rng.shuffle(class_assignments)
    class_assignments = class_assignments[:n_samples]

    # Split the slots into contiguous, near-equal chunks (one per worker).
    chunk_size = (n_samples + n_workers - 1) // n_workers
    chunks = []
    for i in range(n_workers):
        start = i * chunk_size
        end = min(start + chunk_size, n_samples)
        if start >= n_samples:
            break
        chunks.append((class_assignments[start:end], seed + i * 1000000, start))

    print(f"Generating {n_samples} shapes across {len(chunks)} workers...")
    t0 = time.perf_counter()
    if chunks:
        # Never spawn more processes than there are chunks to process.
        with mp.Pool(min(n_workers, len(chunks))) as pool:
            results = pool.map(_generate_chunk, chunks)
    else:
        # n_samples <= 0: nothing to generate, and Pool(0) would raise.
        results = []
    samples = []
    for r in results:
        samples.extend(r)
    rng.shuffle(samples)
    dt = time.perf_counter() - t0
    # Guard against a zero elapsed time (e.g. empty workload on a coarse clock).
    rate = len(samples) / dt if dt > 0 else float("inf")
    print(f"Generated {len(samples)} samples in {dt:.1f}s ({rate:.0f} samples/s)")
    return samples
|
|
|
|
|
|
|
|
class ShapeDataset(torch.utils.data.Dataset):
    """In-memory tensor dataset over sample dicts.

    The dicts are expected in the format built by ``ShapeGenerator._build``.
    Every field is stacked into a tensor once at construction time, so
    ``__getitem__`` only indexes. Each item is the 8-tuple
    ``(grid, label, dim_conf, peak_dim, volume, cm_det, is_curved, curvature)``.
    Integer labels are ``torch.long``; everything else is ``torch.float32``.
    """

    def __init__(self, samples):
        def gather(key):
            return [sample[key] for sample in samples]

        self.grids = torch.tensor(np.stack(gather("grid")), dtype=torch.float32)
        self.labels = torch.tensor(gather("label"), dtype=torch.long)
        self.dim_conf = torch.tensor(np.stack(gather("dim_confidence")), dtype=torch.float32)
        self.peak_dim = torch.tensor(gather("peak_dim"), dtype=torch.long)
        self.volume = torch.tensor(gather("volume"), dtype=torch.float32)
        self.cm_det = torch.tensor(gather("cm_det"), dtype=torch.float32)
        self.is_curved = torch.tensor(gather("is_curved"), dtype=torch.float32)
        self.curvature = torch.tensor(gather("curvature"), dtype=torch.long)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        fields = (
            self.grids, self.labels, self.dim_conf, self.peak_dim,
            self.volume, self.cm_det, self.is_curved, self.curvature,
        )
        return tuple(field[idx] for field in fields)
|
|
|
|
|
|
|
|
|
|
|
# Import-time side effect: prints the number of catalogued shape classes and
# the voxel grid edge length whenever this module is loaded.
print(f'Loaded {NUM_CLASSES} shape classes, GS={GS}')