| """ |
| segment.py |
| ========== |
| |
| Command-line script to run learned region growing on a PLY / PCD point cloud. |
| |
| Example: |
| python segment.py --input scene.ply --ckpt checkpoints/best_model.pt \ |
| --output segmented_scene.ply --device cuda |
| """ |
|
|
| import argparse |
| import torch |
| import numpy as np |
| from pathlib import Path |
|
|
| from learn_region_grow.io import load_point_cloud, save_ply |
| from learn_region_grow.preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector |
| from learn_region_grow.lrg_net import LrgNet |
| from learn_region_grow.growing import RegionGrower |
|
|
|
|
def _build_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="LRGNet: Learned Region Growing on a Point Cloud")
    parser.add_argument("--input", required=True, help="Input .ply or .pcd file")
    parser.add_argument("--ckpt", required=True, help="PyTorch checkpoint (.pt)")
    parser.add_argument("--output", default="output.ply", help="Output segmented PLY")
    parser.add_argument("--device", default="cuda", help="cuda or cpu")
    parser.add_argument("--resolution", type=float, default=0.1, help="Voxel resolution in meters")
    parser.add_argument("--lite", type=int, default=0, choices=[0, 1, 2], help="Lite model variant")
    # --stochastic / --deterministic toggle the same flag. A mutually exclusive
    # group rejects passing both instead of silently letting the last one win;
    # set_defaults makes stochastic growing the default.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--stochastic", dest="stochastic", action="store_true",
                      help="Use stochastic growing (default)")
    mode.add_argument("--deterministic", dest="stochastic", action="store_false",
                      help="Use deterministic thresholding")
    parser.set_defaults(stochastic=True)
    parser.add_argument("--add_threshold", type=float, default=0.5)
    parser.add_argument("--remove_threshold", type=float, default=0.5)
    parser.add_argument("--cluster_threshold", type=int, default=10)
    return parser


def _resolve_device(requested):
    """Return a torch.device for *requested*, falling back to CPU if CUDA is unavailable.

    Only CUDA requests fall back; any other device string (e.g. "cpu") is
    passed to torch.device unchanged.
    """
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(f"Warning: device '{requested}' requested but CUDA is not available; using CPU.")
        return torch.device("cpu")
    return torch.device(requested)


def _propagate_labels(eq_xyz, eq_labels, xyz):
    """Give every original point the label of its nearest equalized point.

    Args:
        eq_xyz: coordinates of the voxel-equalized points that were labelled.
        eq_labels: array of labels, one per row of ``eq_xyz``.
        xyz: coordinates of the full, original point cloud.

    Returns:
        Array of labels, one per row of ``xyz``.
    """
    from scipy.spatial import cKDTree

    tree = cKDTree(eq_xyz)
    _, nearest = tree.query(xyz)
    return eq_labels[nearest]


def main():
    """Run learned region growing on a point cloud and save a labelled PLY.

    Pipeline: load cloud -> voxel-equalize -> normals/curvature -> feature
    vectors -> LrgNet + RegionGrower -> propagate labels back to every input
    point -> write PLY to ``--output``.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Fail fast with a usage message rather than a traceback from deep inside
    # the loaders or the voxelizer.
    if not Path(args.input).is_file():
        parser.error(f"input file not found: {args.input}")
    if not Path(args.ckpt).is_file():
        parser.error(f"checkpoint not found: {args.ckpt}")
    if args.resolution <= 0:
        parser.error("--resolution must be positive")

    device = _resolve_device(args.device)

    print(f"Loading point cloud from {args.input} ...")
    xyz, rgb, normals_input = load_point_cloud(args.input)
    print(f" {len(xyz)} points loaded")
    if len(xyz) == 0:
        raise SystemExit(f"error: no points loaded from {args.input}")

    print(f"Voxel equalization (resolution={args.resolution}m) ...")
    eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, args.resolution)
    eq_rgb = rgb[eq_idx] if rgb is not None else None
    print(f" {len(eq_xyz)} points after equalization")

    # Curvature always comes from PCA on the equalized cloud. The PCA normals
    # are used only when the input file does not carry its own normals.
    if normals_input is None:
        print("Computing normals & curvature via PCA ...")
    else:
        print("Computing curvature via PCA (normals taken from input) ...")
    pca_normals, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)
    if normals_input is not None:
        # NOTE(review): only file-supplied normals go through np.abs here;
        # confirm compute_normals_and_curvature uses the same sign convention.
        normals = np.abs(normals_input[eq_idx])
    else:
        normals = pca_normals

    features = build_feature_vector(eq_xyz, eq_rgb, normals, curvature)
    print(f"Feature vector shape: {features.shape}")

    print(f"Loading checkpoint {args.ckpt} ...")
    # NOTE(review): in_channels is hard-coded to 13; presumably it matches the
    # per-point width produced by build_feature_vector — confirm.
    model = LrgNet(in_channels=13, lite=args.lite)
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.to(device)
    # Inference only: switch dropout / batch-norm layers to evaluation behaviour.
    model.eval()
    print("Model loaded.")

    grower = RegionGrower(
        model, device,
        add_threshold=args.add_threshold,
        remove_threshold=args.remove_threshold,
        cluster_threshold=args.cluster_threshold,
        stochastic=args.stochastic,
    )
    print("Running learned region growing ...")
    # No gradients are needed for inference; this avoids building autograd graphs.
    with torch.no_grad():
        labels = grower.grow(eq_xyz, features, voxel_map, args.resolution)
    labels = np.asarray(labels)
    n_instances = len(np.unique(labels[labels >= 0]))
    print(f"Segmented into {n_instances} instances")

    # Growing ran on the equalized subset; map labels back to every input point.
    full_labels = _propagate_labels(eq_xyz, labels, xyz)

    print(f"Saving output to {args.output} ...")
    save_ply(args.output, xyz, labels=full_labels)
    print("Done.")


if __name__ == "__main__":
    main()
|
|