# bdck's picture
# Upload scripts/train.py
# 859d9ba verified
"""
train.py
========
Train LrgNet on staged H5 data generated from labeled point clouds.
Example:
python train.py --data_dir staged_h5/ --epochs 50 --batch_size 16 \
--lr 1e-3 --save_dir checkpoints/ --device cuda
"""
import argparse
import glob
from pathlib import Path
from learn_region_grow.train import train_lrgnet
def _split_train_val(h5_files, val_split):
    """Split an ordered list of files into training and validation subsets.

    The leading ``int(len(h5_files) * (1 - val_split))`` files are used for
    training and the remaining files for validation. At least one file is
    always kept for training, so a small dataset never produces an empty
    training set.

    Args:
        h5_files: Non-empty, ordered list of file paths.
        val_split: Fraction of files to hold out, with ``0 <= val_split < 1``.

    Returns:
        A ``(train_files, val_files)`` tuple. ``val_files`` is ``None`` when
        no file ends up held out (``val_split == 0`` or too few files).

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)`` (NaN included) or
            ``h5_files`` is empty.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")
    if not h5_files:
        raise ValueError("cannot split an empty file list")

    # int() truncates toward zero, so e.g. 1 file * 0.9 would give 0 training
    # files; clamp so training always receives at least one file.
    split = max(1, int(len(h5_files) * (1 - val_split)))
    train_files = h5_files[:split]
    # Normalize "nothing held out" to None rather than an empty list.
    val_files = h5_files[split:] or None
    return train_files, val_files


def main(argv=None):
    """Parse command-line options, split the staged H5 files, and train LrgNet.

    All ``*.h5`` files directly inside ``--data_dir`` are collected in sorted
    order, split into training/validation subsets, and forwarded together
    with the remaining options to ``train_lrgnet``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line.

    Raises:
        FileNotFoundError: If ``--data_dir`` contains no ``*.h5`` files.
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--val_split`` outside ``[0, 1)``.
    """
    parser = argparse.ArgumentParser(description="Train LrgNet on staged H5 data")
    parser.add_argument("--data_dir", required=True, help="Directory with *.h5 staged training files")
    parser.add_argument("--val_split", type=float, default=0.1, help="Fraction of files for validation")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--lite", type=int, default=0, choices=[0, 1, 2])
    parser.add_argument("--save_dir", default="checkpoints")
    parser.add_argument("--resume", default=None)
    args = parser.parse_args(argv)

    # Sorted so the train/validation split is the same on every run.
    h5_files = sorted(glob.glob(str(Path(args.data_dir) / "*.h5")))
    if not h5_files:
        raise FileNotFoundError(f"No H5 files found in {args.data_dir}")

    try:
        train_files, val_files = _split_train_val(h5_files, args.val_split)
    except ValueError as err:
        # Report a bad --val_split as a usage error instead of a traceback.
        parser.error(str(err))

    if args.val_split > 0 and val_files is None:
        print("Warning: too few files to hold out a validation set; training without validation")
    print(f"Train files: {len(train_files)}, Val files: {len(val_files) if val_files else 0}")

    train_lrgnet(
        train_files=train_files,
        val_files=val_files,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        device=args.device,
        lite=args.lite,
        save_dir=args.save_dir,
        resume=args.resume,
    )
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()