TerrainFreeSpaceNet / inference.py
Dinusharg's picture
Initial commit
5b1ad51
from __future__ import annotations
import argparse
import json
import pandas as pd
import torch
from .model import PointNetRegressor
from .preprocess import prepare_points
from .utils import get_device
def load_points_from_csv(
csv_path: str,
x_col: str = "x",
y_col: str = "y",
z_col: str = "z",
) -> torch.Tensor:
df = pd.read_csv(csv_path, usecols=[x_col, y_col, z_col])
for c in [x_col, y_col, z_col]:
df[c] = pd.to_numeric(df[c], errors="coerce")
df = df.dropna(subset=[x_col, y_col, z_col])
points = df[[x_col, y_col, z_col]].to_numpy(dtype="float32")
return points
def load_model(checkpoint_path: str, device: str | None = None):
device = device or get_device()
checkpoint = torch.load(checkpoint_path, map_location=device)
input_dim = int(checkpoint.get("input_dim", 3))
model = PointNetRegressor(input_dim=input_dim)
model.load_state_dict(checkpoint["model_state"])
model.to(device)
model.eval()
return model, checkpoint, device
@torch.no_grad()
def predict(
csv_path: str,
checkpoint_path: str,
num_points: int | None = None,
device: str | None = None,
) -> dict:
model, checkpoint, device = load_model(checkpoint_path, device=device)
if num_points is None:
num_points = int(checkpoint.get("num_points", 2048))
points = load_points_from_csv(csv_path)
points = prepare_points(points, num_points=num_points)
x = torch.from_numpy(points).transpose(0, 1).unsqueeze(0).to(device) # [1, 3, N]
pred = model(x).item()
return {
"free_space_score": float(pred),
"num_points_used": int(num_points),
"device": device,
}
def main():
parser = argparse.ArgumentParser(description="Inference for TerrainFreeSpaceNet")
parser.add_argument("--input_csv", type=str, required=True, help="Input CSV path")
parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint path")
parser.add_argument("--num_points", type=int, default=None, help="Override checkpoint num_points")
args = parser.parse_args()
result = predict(
csv_path=args.input_csv,
checkpoint_path=args.checkpoint,
num_points=args.num_points,
)
print(json.dumps(result, indent=2))
if __name__ == "__main__":
main()