# grid-geometric-classifier-proto / data_generator.py
# AbstractPhil's picture
# Create data_generator.py
# 4706821 verified
"""
3D Voxel Shape Classifier — Complete Geometric Primitive Vocabulary
5×5×5 binary voxel grid → rigid cascade → curvature analysis → classify
38 shape classes covering:
- Rigid 0D-3D: points, lines, joints, triangles, quads, polyhedra, prisms
- Curved 1D: arcs, helices
- Curved 2D: circles, ellipses, discs
- Curved 3D solid: sphere, hemisphere, cylinder, cone, capsule, torus
- Curved 3D hollow: shell, tube
- Curved 3D open: bowl (concave), saddle (hyperbolic)
Curvature types: none, convex, concave, cylindrical, conical, toroidal, hyperbolic, helical
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import math
from itertools import combinations
# === SwiGLU Activation =======================================================
class SwiGLU(nn.Module):
    """Gated linear unit whose gate is a SiLU (Swish) nonlinearity.

    Computes ``w1(x) * silu(w2(x))``, where ``w1`` and ``w2`` are two
    independent ``nn.Linear(in_dim, out_dim)`` projections (both with bias)
    and ``silu(z) = z * sigmoid(z)``.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Constructed in this order (value branch, then gate branch); keep it
        # so seeded parameter initialisation stays reproducible.
        self.w1 = nn.Linear(in_dim, out_dim)  # value projection
        self.w2 = nn.Linear(in_dim, out_dim)  # gate projection

    def forward(self, x):
        gate = F.silu(self.w2(x))
        value = self.w1(x)
        return value * gate
# === Shape Catalog ===========================================================
def _catalog_entry(dim, curvature="none"):
    """Build one SHAPE_CATALOG record; a shape is curved iff its curvature is not "none"."""
    return {"dim": dim, "curved": curvature != "none", "curvature": curvature}


# Insertion order defines the integer class labels (see CLASS_TO_IDX below),
# so new shapes must be appended, never inserted or reordered.
SHAPE_CATALOG = {
    # ---- Rigid 0D ----
    "point": _catalog_entry(0),
    # ---- Rigid 1D: lines ----
    "line_x": _catalog_entry(1),
    "line_y": _catalog_entry(1),
    "line_z": _catalog_entry(1),
    "line_diag": _catalog_entry(1),
    # ---- Rigid 1D: compounds ----
    "cross": _catalog_entry(1),
    "l_shape": _catalog_entry(1),
    "collinear": _catalog_entry(1),
    # ---- Rigid 2D: triangles ----
    "triangle_xy": _catalog_entry(2),
    "triangle_xz": _catalog_entry(2),
    "triangle_3d": _catalog_entry(2),
    # ---- Rigid 2D: quads ----
    "square_xy": _catalog_entry(2),
    "square_xz": _catalog_entry(2),
    "rectangle": _catalog_entry(2),
    "coplanar": _catalog_entry(2),
    # ---- Rigid 2D: filled ----
    "plane": _catalog_entry(2),
    # ---- Rigid 3D: simplices ----
    "tetrahedron": _catalog_entry(3),
    "pyramid": _catalog_entry(3),
    "pentachoron": _catalog_entry(3),
    # ---- Rigid 3D: prisms/polyhedra ----
    "cube": _catalog_entry(3),
    "cuboid": _catalog_entry(3),
    "triangular_prism": _catalog_entry(3),
    "octahedron": _catalog_entry(3),
    # ---- Curved 1D ----
    "arc": _catalog_entry(1, "convex"),
    "helix": _catalog_entry(1, "helical"),
    # ---- Curved 2D: outlines ----
    "circle": _catalog_entry(2, "convex"),
    "ellipse": _catalog_entry(2, "convex"),
    # ---- Curved 2D: filled ----
    "disc": _catalog_entry(2, "convex"),
    # ---- Curved 3D: solid ----
    "sphere": _catalog_entry(3, "convex"),
    "hemisphere": _catalog_entry(3, "convex"),
    "cylinder": _catalog_entry(3, "cylindrical"),
    "cone": _catalog_entry(3, "conical"),
    "capsule": _catalog_entry(3, "convex"),
    "torus": _catalog_entry(3, "toroidal"),
    # ---- Curved 3D: hollow ----
    "shell": _catalog_entry(3, "convex"),
    "tube": _catalog_entry(3, "cylindrical"),
    # ---- Curved 3D: open surfaces ----
    "bowl": _catalog_entry(3, "concave"),
    "saddle": _catalog_entry(3, "hyperbolic"),
}

CLASS_NAMES = list(SHAPE_CATALOG)
NUM_CLASSES = len(CLASS_NAMES)
CLASS_TO_IDX = dict(zip(CLASS_NAMES, range(NUM_CLASSES)))

CURVATURE_TYPES = ["none", "convex", "concave", "cylindrical", "conical",
                   "toroidal", "hyperbolic", "helical"]
CURV_TO_IDX = {curv: idx for idx, curv in enumerate(CURVATURE_TYPES)}
NUM_CURVATURES = len(CURVATURE_TYPES)

GS = 5  # edge length of the cubic voxel grid (GS x GS x GS)
# === Cayley-Menger Utilities =================================================
def cayley_menger_det(points: np.ndarray) -> float:
    """Return the Cayley-Menger determinant of ``points``.

    The bordered matrix is ``[[0, 1, ..., 1], [1, D], ...]`` where ``D`` holds
    the squared pairwise Euclidean distances between the rows of ``points``.

    Args:
        points: (n, d) array of coordinates; rows are indexed and subtracted
            element-wise, so an ndarray is expected.
    """
    count = len(points)
    bordered = np.ones((count + 1, count + 1))
    bordered[0, 0] = 0.0
    # Squared-distance block; the diagonal is 0 because p - p == 0.
    bordered[1:, 1:] = [[np.sum((p - q) ** 2) for q in points] for p in points]
    return np.linalg.det(bordered)
def simplex_volume(points: np.ndarray) -> float:
    """Return the (k-1)-dimensional volume of the simplex on ``k`` points.

    Uses ``V^2 = (-1)^k * CM / (2^(k-1) * ((k-1)!)^2)`` with ``CM`` the
    Cayley-Menger determinant. Fewer than 2 points give 0.0, and a negative
    V^2 (numerical noise on degenerate input) is clamped to 0.
    """
    n_pts = len(points)
    if n_pts < 2:
        return 0.0
    order = n_pts - 1
    scale = (2 ** order) * (math.factorial(order) ** 2)
    squared = (-1) ** n_pts * cayley_menger_det(points) / scale
    return np.sqrt(max(0, squared))
def effective_volume(points: np.ndarray) -> float:
    """Return the largest simplex measure spanned by ``points``.

    The dimension of the measure follows the point count:
      * fewer than 2 points -> 0.0
      * 2 points            -> segment length
      * 3 points            -> triangle area
      * 4 or more points    -> largest tetrahedron volume over 4-point subsets

    Only the first 8 points take part in the subset enumeration. As a
    consequence, 4+ coplanar points yield 0 even if they span a non-zero area.

    Args:
        points: (k, d) coordinate array; subsets are taken with list indexing,
            so an ndarray is required.
    """
    k = len(points)
    if k < 2:
        return 0.0
    if k == 2:
        return np.linalg.norm(points[0] - points[1])
    # Only the subset size that is actually returned is evaluated: with
    # k >= 4 the triangle areas would never be reported, so skip them.
    subset_size = 3 if k == 3 else 4
    best = 0
    for idx in combinations(range(min(k, 8)), subset_size):
        best = max(best, simplex_volume(points[list(idx)]))
    return best
# === Shape Generator =========================================================
class ShapeGenerator:
    """Procedurally generate labelled voxel samples for every SHAPE_CATALOG class.

    All randomness is drawn from one ``np.random.RandomState(seed)``, so the
    same seed and call sequence reproduce the same samples. Each shape
    generator returns integer voxel coordinates, or None to reject a draw.
    ``_make`` clips and deduplicates the coordinates, and ``_build`` packages
    the result into a sample dict.
    """

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

    def generate(self, n_samples: int) -> list:
        """Return ``n_samples`` shuffled sample dicts (see ``_build``).

        Each class first receives up to ``n_samples // NUM_CLASSES`` samples,
        with at most 5x that many attempts since generators may reject a
        draw. The remainder is filled with uniformly random classes.
        """
        samples = []
        per_class = n_samples // NUM_CLASSES
        for name in CLASS_NAMES:
            count = 0
            attempts = 0
            while count < per_class and attempts < per_class * 5:
                s = self._make(name)
                attempts += 1
                if s is not None:
                    samples.append(s)
                    count += 1
        while len(samples) < n_samples:
            # str() so class_name is a plain str rather than numpy.str_;
            # the RNG draw itself is unchanged.
            name = str(self.rng.choice(CLASS_NAMES))
            s = self._make(name)
            if s is not None:
                samples.append(s)
        self.rng.shuffle(samples)
        return samples[:n_samples]

    def _make(self, name: str) -> Optional[dict]:
        """Generate one sample of class ``name``, or None if the draw was rejected."""
        info = SHAPE_CATALOG[name]
        voxels = self._curved(name) if info["curved"] else self._rigid(name)
        if voxels is None:
            return None
        voxels = np.clip(voxels, 0, GS - 1).astype(int)
        # np.unique(axis=0) both deduplicates and sorts rows lexicographically.
        voxels = np.unique(voxels, axis=0)
        if len(voxels) < 1:
            return None
        return self._build(name, info, voxels)

    # === Rigid Generators ===
    def _rigid(self, name):
        """Return vertex voxels for a straight-edged shape, or None to reject the draw."""
        rng = self.rng
        if name == "point":
            return rng.randint(0, GS, size=(1, 3))
        elif name == "line_x":
            y, z = rng.randint(0, GS, size=2)
            x1, x2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x1, y, z], [x2, y, z]])
        elif name == "line_y":
            x, z = rng.randint(0, GS, size=2)
            y1, y2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y1, z], [x, y2, z]])
        elif name == "line_z":
            x, y = rng.randint(0, GS, size=2)
            z1, z2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y, z1], [x, y, z2]])
        elif name == "line_diag":
            p1 = rng.randint(0, 3, size=3)
            step = rng.randint(1, 3)
            # Every component is +/-1, so no direction component is ever zero.
            direction = rng.choice([-1, 1], size=3)
            p2 = np.clip(p1 + step * direction, 0, GS - 1)
            if np.array_equal(p1, p2):
                # Only happens when p1 is the origin and every step is clipped.
                p2 = np.clip(p1 + np.array([1, 1, 0]), 0, GS - 1)
            return np.array([p1, p2])
        elif name == "cross":
            # Two perpendicular segments through a shared center point.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            length = rng.randint(1, 3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            pts = [[cx, cy, cz]]
            for axis in (axis1, axis2):
                for sign in (-1, 1):
                    p = [cx, cy, cz]
                    p[axis] = np.clip(p[axis] + sign * length, 0, GS - 1)
                    pts.append(p)
            return np.array(pts)
        elif name == "l_shape":
            # Two segments meeting at a right angle at `corner`.
            corner = rng.randint(1, GS - 1, size=3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            len1 = rng.randint(1, 3)
            len2 = rng.randint(1, 3)
            dir1 = rng.choice([-1, 1])
            dir2 = rng.choice([-1, 1])
            pts = [list(corner)]
            for i in range(1, len1 + 1):
                p = list(corner)
                p[axis1] = np.clip(p[axis1] + dir1 * i, 0, GS - 1)
                pts.append(p)
            for i in range(1, len2 + 1):
                p = list(corner)
                p[axis2] = np.clip(p[axis2] + dir2 * i, 0, GS - 1)
                pts.append(p)
            return np.array(pts)
        elif name == "collinear":
            axis = rng.randint(3)
            fixed = rng.randint(0, GS, size=2)
            vals = sorted(rng.choice(GS, 3, replace=False))
            pts = np.zeros((3, 3), dtype=int)
            for i, v in enumerate(vals):
                pts[i, axis] = v
                pts[i, (axis + 1) % 3] = fixed[0]
                pts[i, (axis + 2) % 3] = fixed[1]
            return pts
        elif name == "triangle_xy":
            z = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            # Three collinear vertices are a line, not a triangle: reject.
            if pts is None or self._is_collinear(pts):
                return None
            return np.column_stack([pts, np.full(3, z)])
        elif name == "triangle_xz":
            y = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is None or self._is_collinear(pts):
                return None
            return np.column_stack([pts[:, 0], np.full(3, y), pts[:, 1]])
        elif name == "triangle_3d":
            pts = self._rand_pts_3d(3, min_dist=1)
            if pts is None or self._is_collinear(pts):
                return None
            return pts
        elif name == "square_xy":
            z = rng.randint(0, GS)
            x1, y1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y1, z], [x1 + s, y1, z],
                            [x1, y1 + s, z], [x1 + s, y1 + s, z]])
            return np.clip(pts, 0, GS - 1)
        elif name == "square_xz":
            y = rng.randint(0, GS)
            x1, z1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y, z1], [x1 + s, y, z1],
                            [x1, y, z1 + s], [x1 + s, y, z1 + s]])
            return np.clip(pts, 0, GS - 1)
        elif name == "rectangle":
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a1, a2 = rng.randint(0, 3), rng.randint(0, 3)
            w, h = rng.randint(1, 4), rng.randint(1, 3)
            if w == h:
                w = min(GS - 1, w + 1)
            # Shift the origin instead of clipping the corners: clipping could
            # shrink w until it equals h (a square) or collapse an edge.
            a1 = min(a1, GS - 1 - w)
            a2 = min(a2, GS - 1 - h)
            c = np.array([[a1, a2], [a1 + w, a2], [a1, a2 + h], [a1 + w, a2 + h]])
            if axis == 0:
                return np.column_stack([np.full(4, val), c])
            elif axis == 1:
                return np.column_stack([c[:, 0], np.full(4, val), c[:, 1]])
            else:
                return np.column_stack([c, np.full(4, val)])
        elif name == "coplanar":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            # Give every point the same value in one coordinate column so they
            # share an axis-aligned plane. Python evaluates the right-hand side
            # (and its randint) before the target column's randint.
            pts[:, rng.randint(3)] = pts[0, rng.randint(3)]
            return pts
        elif name == "plane":
            # Filled rectangular slab, 1 voxel thick.
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a_start = rng.randint(0, 2)
            b_start = rng.randint(0, 2)
            a_size = rng.randint(2, GS - a_start + 1)
            b_size = rng.randint(2, GS - b_start + 1)
            pts = []
            for a in range(a_start, min(GS, a_start + a_size)):
                for b in range(b_start, min(GS, b_start + b_size)):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[(axis + 1) % 3] = a
                    p[(axis + 2) % 3] = b
                    pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "tetrahedron":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            if self._smallest_extent(pts) < 0.5:
                # Nearly coplanar: nudge one coordinate off the plane.
                pts[rng.randint(4), rng.randint(3)] = (pts[0, 0] + 2) % GS
                # The nudge is not guaranteed to work (and may even create a
                # duplicate vertex), so reject draws that are still flat.
                if self._smallest_extent(pts) < 0.5:
                    return None
            return pts
        elif name == "pyramid":
            z_base = rng.randint(0, 3)
            x1, y1 = rng.randint(0, 3), rng.randint(0, 3)
            s = rng.randint(1, 3)
            base = np.array([[x1, y1, z_base], [x1 + s, y1, z_base],
                             [x1, y1 + s, z_base], [x1 + s, y1 + s, z_base]])
            apex = np.array([[x1 + s // 2, y1 + s // 2, z_base + rng.randint(1, 3)]])
            return np.clip(np.vstack([base, apex]), 0, GS - 1)
        elif name == "pentachoron":
            return self._rand_pts_3d(5, min_dist=1)
        elif name == "cube":
            x1, y1, z1 = rng.randint(0, 3, size=3)
            s = rng.randint(1, 3)
            pts = [[x1 + dx, y1 + dy, z1 + dz]
                   for dx in (0, s) for dy in (0, s) for dz in (0, s)]
            return np.clip(np.array(pts), 0, GS - 1)
        elif name == "cuboid":
            x1, y1, z1 = rng.randint(0, 2, size=3)
            sx, sy, sz = rng.randint(1, 4, size=3)
            # Ensure not a cube: at least 2 different edge lengths.
            if sx == sy == sz:
                sx = min(GS - 1, sx + 1)
            # Shift the origin instead of clipping so the distinct edge
            # lengths survive (clipping could turn the cuboid into a cube).
            x1 = min(x1, GS - 1 - sx)
            y1 = min(y1, GS - 1 - sy)
            z1 = min(z1, GS - 1 - sz)
            pts = [[x1 + dx, y1 + dy, z1 + dz]
                   for dx in (0, sx) for dy in (0, sy) for dz in (0, sz)]
            return np.array(pts)
        elif name == "triangular_prism":
            # Triangle in one plane, extruded along the remaining axis.
            axis = rng.randint(3)  # extrusion axis
            ext_start = rng.randint(0, 3)
            ext_len = rng.randint(1, 3)
            tri = self._rand_pts_2d(3, min_dist=1)
            if tri is None or self._is_collinear(tri):
                return None
            pts = []
            for e in range(ext_start, min(GS, ext_start + ext_len + 1)):
                for t in tri:
                    p = [0, 0, 0]
                    p[axis] = e
                    p[(axis + 1) % 3] = t[0]
                    p[(axis + 2) % 3] = t[1]
                    pts.append(p)
            return np.clip(np.array(pts), 0, GS - 1) if len(pts) >= 6 else None
        elif name == "octahedron":
            # 6 vertices at distance s along each axis from the center.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            s = rng.randint(1, 3)
            pts = [[cx, cy, cz + s], [cx, cy, cz - s],
                   [cx + s, cy, cz], [cx - s, cy, cz],
                   [cx, cy + s, cz], [cx, cy - s, cz]]
            return np.clip(np.array(pts), 0, GS - 1)
        return None

    # === Curved Generators ===
    def _curved(self, name):
        """Rasterise a curved shape into integer voxels, or None to reject the draw.

        A random continuous center (cx, cy, cz) in [1, 3]^3 is drawn first for
        every shape.
        """
        rng = self.rng
        cx, cy, cz = rng.uniform(1.0, 3.0, size=3)
        if name == "arc":
            r = rng.uniform(1.2, 2.2)
            plane = rng.choice(["xy", "xz", "yz"])
            start = rng.uniform(0, 2 * np.pi)
            span = rng.uniform(np.pi * 0.4, np.pi * 1.2)
            n = rng.randint(6, 12)
            angles = np.linspace(start, start + span, n)
            return self._planar_curve((cx, cy, cz), r, r, plane, angles, min_pts=3)
        elif name == "helix":
            # Spiral through 3D: parametric curve around `axis`.
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            pitch = rng.uniform(0.3, 0.8)  # rise per radian
            n = rng.randint(15, 30)
            t = np.linspace(0, 2 * np.pi * rng.uniform(1.0, 2.5), n)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            start_h = rng.uniform(0, 1.0)
            pts = []
            for ti in t:
                p = [0.0, 0.0, 0.0]
                p[axes[0]] = center[axes[0]] + r * np.cos(ti)
                p[axes[1]] = center[axes[1]] + r * np.sin(ti)
                # Can exceed GS - 1 (pitch * t reaches ~12.6); such points are
                # clamped onto the boundary face by np.clip below.
                p[axis] = start_h + pitch * ti
                pts.append(p)
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None
        elif name == "circle":
            r = rng.uniform(1.0, 2.0)
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            return self._planar_curve((cx, cy, cz), r, r, plane, angles, min_pts=5)
        elif name == "ellipse":
            rx, ry = rng.uniform(0.8, 2.0), rng.uniform(0.8, 2.0)
            if abs(rx - ry) < 0.3:
                rx *= 1.4  # keep the two semi-axes visibly different
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            return self._planar_curve((cx, cy, cz), rx, ry, plane, angles, min_pts=5)
        elif name == "disc":
            # Filled circle in an axis-aligned plane (not just the outline).
            r = rng.uniform(1.0, 2.2)
            axis = rng.randint(3)
            val = round(rng.uniform(0.5, 3.5))
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[axes[0]] = x
                    p[axes[1]] = y
                    dist = np.sqrt((x - center[axes[0]]) ** 2 + (y - center[axes[1]]) ** 2)
                    if dist <= r:
                        pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "sphere":
            r = rng.uniform(1.0, 2.2)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "hemisphere":
            r = rng.uniform(1.0, 2.2)
            cut_axis = rng.randint(3)
            center = [cx, cy, cz]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2:
                            # Keep only the half on the +cut_axis side.
                            if p[cut_axis] >= center[cut_axis]:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 3 else None
        elif name == "cylinder":
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if dist_sq <= r ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "cone":
            r_base = rng.uniform(1.0, 2.0)
            axis = rng.randint(3)
            height = rng.randint(2, 5)
            base_pos = rng.randint(0, GS - height + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        along = p[axis] - base_pos
                        if along < 0 or along >= height:
                            continue
                        # Radius shrinks linearly to ~0 at along == height - 1,
                        # so that apex layer is almost always empty.
                        t = along / (height - 1 + 1e-6)
                        r_at = r_base * (1.0 - t)
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if dist_sq <= r_at ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "capsule":
            # Cylinder with hemispherical caps.
            r = rng.uniform(0.8, 1.5)
            axis = rng.randint(3)
            body_len = rng.randint(1, 3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            body_start = round(center[axis] - body_len / 2)
            body_end = body_start + body_len
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        radial_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        along = p[axis]
                        if body_start <= along <= body_end and radial_sq <= r ** 2:
                            pts.append(p)  # body
                        elif along < body_start:
                            cap_center = list(center)
                            cap_center[axis] = body_start
                            dist_sq = sum((p[i] - cap_center[i]) ** 2 for i in range(3))
                            if dist_sq <= r ** 2:
                                pts.append(p)  # bottom cap
                        elif along > body_end:
                            cap_center = list(center)
                            cap_center[axis] = body_end
                            dist_sq = sum((p[i] - cap_center[i]) ** 2 for i in range(3))
                            if dist_sq <= r ** 2:
                                pts.append(p)  # top cap
            return np.array(pts) if len(pts) >= 5 else None
        elif name == "torus":
            R = rng.uniform(1.2, 2.0)
            r = rng.uniform(0.5, 0.9)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            ring_axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_in_plane = np.sqrt(
                            sum((p[a] - center[a]) ** 2 for a in ring_axes))
                        dist_from_ring = np.sqrt(
                            (dist_in_plane - R) ** 2 + (p[axis] - center[axis]) ** 2)
                        if dist_from_ring <= r:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "shell":
            # Hollow sphere: voxels between the inner and outer radius.
            r_out = rng.uniform(1.5, 2.3)
            r_in = r_out - rng.uniform(0.4, 0.8)
            if r_in < 0.3:
                r_in = 0.3
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        d_sq = (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2
                        if r_in ** 2 <= d_sq <= r_out ** 2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "tube":
            # Hollow cylinder.
            r_out = rng.uniform(1.0, 2.0)
            r_in = r_out - rng.uniform(0.3, 0.7)
            if r_in < 0.2:
                r_in = 0.2
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if r_in ** 2 <= dist_sq <= r_out ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "bowl":
            # Paraboloid shell h = k * dist^2, opening towards +axis.
            r = rng.uniform(1.2, 2.2)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            thickness = 0.6
            k = 1.0 / (r + 1e-6)  # loop-invariant curvature factor
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_planar = np.sqrt(
                            sum((p[a] - center[a]) ** 2 for a in axes))
                        if dist_planar > r:
                            continue
                        expected_h = center[axis] + k * dist_planar ** 2
                        if abs(p[axis] - expected_h) <= thickness:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        elif name == "saddle":
            # Hyperbolic paraboloid: h = k * (a^2 - b^2).
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            k = rng.uniform(0.3, 0.8)
            thickness = 0.7
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        da = p[axes[0]] - center[axes[0]]
                        db = p[axes[1]] - center[axes[1]]
                        expected_h = center[axis] + k * (da ** 2 - db ** 2)
                        if abs(p[axis] - expected_h) <= thickness:
                            # Limit the planar radius so it doesn't fill everything.
                            if da ** 2 + db ** 2 <= 4.0:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None
        return None

    # === Helpers ===
    @staticmethod
    def _planar_curve(center, rx, ry, plane, angles, min_pts):
        """Rasterise (rx*cos a, ry*sin a) around ``center`` in the given axis plane.

        Returns the unique, clipped integer voxels, or None if fewer than
        ``min_pts`` remain.
        """
        cx, cy, cz = center
        pts = []
        for a in angles:
            if plane == "xy":
                pts.append([cx + rx * np.cos(a), cy + ry * np.sin(a), cz])
            elif plane == "xz":
                pts.append([cx + rx * np.cos(a), cy, cz + ry * np.sin(a)])
            else:
                pts.append([cx, cy + rx * np.cos(a), cz + ry * np.sin(a)])
        pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
        return pts if len(pts) >= min_pts else None

    @staticmethod
    def _is_collinear(pts):
        """True when all points lie on a single line (or coincide)."""
        pts = np.asarray(pts, dtype=float)
        if len(pts) < 3:
            return True
        return bool(np.linalg.matrix_rank(pts[1:] - pts[0]) <= 1)

    @staticmethod
    def _smallest_extent(pts):
        """Smallest singular value of the mean-centred points (~0 when coplanar)."""
        centered = np.asarray(pts, dtype=float)
        centered = centered - centered.mean(axis=0)
        return np.linalg.svd(centered, compute_uv=False)[-1]

    def _rand_pts_2d(self, n, min_dist=0):
        """Draw ``n`` distinct 2-D grid points; retry up to 50 times to satisfy ``min_dist``."""
        for _ in range(50):
            seen = set()
            while len(seen) < n:
                seen.add((self.rng.randint(0, GS), self.rng.randint(0, GS)))
            pts = np.array(list(seen)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _rand_pts_3d(self, n, min_dist=0):
        """Draw ``n`` distinct 3-D grid points; retry up to 100 times to satisfy ``min_dist``."""
        for _ in range(100):
            seen = set()
            while len(seen) < n:
                seen.add(tuple(self.rng.randint(0, GS, size=3)))
            pts = np.array(list(seen)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _check_dist(self, pts, min_dist):
        """True iff every pair of points is at least ``min_dist`` apart in L1 (Manhattan) distance."""
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                if np.sum(np.abs(pts[i] - pts[j])) < min_dist:
                    return False
        return True

    def _build(self, name, info, voxels):
        """Package an (n, 3) integer voxel array into a labelled sample dict.

        ``voxels`` is expected to be the clipped, deduplicated array from
        ``_make``. The dict holds the dense GS^3 occupancy grid, the class and
        curvature labels, and Cayley-Menger derived features.
        """
        n = len(voxels)
        # Geometric features use at most the first 6 voxels (lexicographically
        # smallest, since _make sorts via np.unique), bounding the matrix size.
        sub = voxels[:6].astype(float)
        cm_det = cayley_menger_det(sub)
        volume = effective_volume(sub)
        dim_conf = np.zeros(4, dtype=np.float32)
        dim_conf[0] = 1.0
        if n >= 2:
            dim_conf[1] = 1.0
        if info["dim"] >= 2:
            dim_conf[2] = 1.0
        if info["dim"] >= 3:
            dim_conf[3] = 1.0
        grid = np.zeros((GS, GS, GS), dtype=np.float32)
        grid[voxels[:, 0], voxels[:, 1], voxels[:, 2]] = 1.0
        return {
            "grid": grid, "label": CLASS_TO_IDX[name], "class_name": name,
            "n_points": n, "n_occupied": int(grid.sum()),
            "cm_det": float(cm_det), "volume": float(volume),
            "peak_dim": info["dim"], "dim_confidence": dim_conf,
            "is_curved": info["curved"], "curvature": CURV_TO_IDX[info["curvature"]],
        }
# === Dataset =================================================================
def _generate_chunk(args):
    """Multiprocessing worker: build one sample per requested class index.

    ``args`` is ``(class_assignments, seed, start_idx)``. ``start_idx`` is
    unpacked but not used here. Each class gets up to 10 generation attempts.
    If all fail, a single "cube" sample is tried instead, and it is skipped
    if that also fails, so the result can be shorter than the request.
    """
    class_assignments, seed, _start_idx = args
    generator = ShapeGenerator(seed=seed)
    produced = []
    for class_idx in class_assignments:
        target = CLASS_NAMES[class_idx]
        sample = None
        tries = 0
        while sample is None and tries < 10:
            sample = generator._make(target)
            tries += 1
        if sample is None:
            sample = generator._make("cube")
        if sample is not None:
            produced.append(sample)
    return produced
def generate_parallel(n_samples, seed=42, n_workers=8):
    """Pre-generate ``n_samples`` shapes using a multiprocessing pool.

    Each class gets ``n_samples // NUM_CLASSES`` slots, the remainder is
    filled with random classes, and the assignment list is shuffled with
    ``RandomState(seed)``. The list is then split into at most ``n_workers``
    contiguous chunks. Chunk ``i`` is generated by ``_generate_chunk`` with
    seed ``seed + i * 1000000``. The merged samples are shuffled again with
    the same RNG before being returned.

    Returns:
        list of sample dicts. It can be shorter than ``n_samples`` if a
        worker drops a sample it failed to build.
    """
    import multiprocessing as mp
    import time

    per_class = n_samples // NUM_CLASSES
    class_assignments = [ci for ci in range(NUM_CLASSES) for _ in range(per_class)]
    rng = np.random.RandomState(seed)
    while len(class_assignments) < n_samples:
        class_assignments.append(rng.randint(0, NUM_CLASSES))
    rng.shuffle(class_assignments)
    class_assignments = class_assignments[:n_samples]

    # Split into contiguous chunks, one per worker.
    chunk_size = (n_samples + n_workers - 1) // n_workers
    chunks = []
    for i in range(n_workers):
        start = i * chunk_size
        if start >= n_samples:
            break
        end = min(start + chunk_size, n_samples)
        chunks.append((class_assignments[start:end], seed + i * 1000000, start))

    print(f"Generating {n_samples} shapes across {len(chunks)} workers...")
    t0 = time.perf_counter()
    results = []
    if chunks:
        # Never start more processes than there are chunks to process.
        with mp.Pool(min(n_workers, len(chunks))) as pool:
            results = pool.map(_generate_chunk, chunks)
    samples = [s for chunk in results for s in chunk]
    rng.shuffle(samples)
    dt = time.perf_counter() - t0
    # A coarse clock can report zero elapsed time for tiny runs.
    rate = len(samples) / dt if dt > 0 else float("inf")
    print(f"Generated {len(samples)} samples in {dt:.1f}s ({rate:.0f} samples/s)")
    return samples
class ShapeDataset(torch.utils.data.Dataset):
    """In-memory torch Dataset over a list of sample dicts.

    Every field is stacked into one tensor at construction time. Indexing
    returns the 8-tuple ``(grid, label, dim_conf, peak_dim, volume, cm_det,
    is_curved, curvature)``.

    Args:
        samples: non-empty list of dicts with keys "grid", "label",
            "dim_confidence", "peak_dim", "volume", "cm_det", "is_curved",
            "curvature" (``np.stack`` raises on an empty list).
    """

    def __init__(self, samples):
        def gather(key):
            return [sample[key] for sample in samples]

        as_float, as_long = torch.float32, torch.long
        self.grids = torch.tensor(np.stack(gather("grid")), dtype=as_float)
        self.labels = torch.tensor(gather("label"), dtype=as_long)
        self.dim_conf = torch.tensor(np.stack(gather("dim_confidence")), dtype=as_float)
        self.peak_dim = torch.tensor(gather("peak_dim"), dtype=as_long)
        self.volume = torch.tensor(gather("volume"), dtype=as_float)
        self.cm_det = torch.tensor(gather("cm_det"), dtype=as_float)
        self.is_curved = torch.tensor(gather("is_curved"), dtype=as_float)
        self.curvature = torch.tensor(gather("curvature"), dtype=as_long)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        columns = (self.grids, self.labels, self.dim_conf, self.peak_dim,
                   self.volume, self.cm_det, self.is_curved, self.curvature)
        return tuple(column[idx] for column in columns)
# Import-time banner reporting the catalogue size and the voxel grid edge length.
print(f"Loaded {NUM_CLASSES} shape classes, GS={GS}")