#!/usr/bin/env python3
# Copyright (c) Space Robotics Lab, SnT, University of Luxembourg, SpaceR
# RANS: arXiv:2310.07393 - OpenEnv training examples
"""
PPO Training Against a Running RANS Server
==========================================
Trains a spacecraft navigation policy via the OpenEnv HTTP/WebSocket client,
connecting to a RANS server running locally (uvicorn) or in Docker.
This is the canonical OpenEnv training pattern:
    ┌───────────────────────────┐     HTTP/WS   ┌──────────────────────┐
    │ ppo_train (this script)   │  ──────────►  │ uvicorn / Docker     │
    │ RemoteRANSGymnasiumEnv    │               │ RANSEnvironment      │
    │ ActorCritic + PPO         │  ◄──────────  │ spacecraft physics   │
    └───────────────────────────┘               └──────────────────────┘
Start the server first:
uvicorn rans_env.server.app:app --host 0.0.0.0 --port 8000
Then run this script:
python examples/train_against_server.py --task GoToPosition
python examples/train_against_server.py --task GoToPose --url http://localhost:8000
python examples/train_against_server.py --eval --checkpoint rans_ppo_remote_GoToPosition.pt
Requirements:
pip install torch gymnasium openenv-core
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import gymnasium as gym
from gymnasium import spaces
except ImportError:
print("gymnasium is required: pip install gymnasium")
sys.exit(1)
try:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
except ImportError:
print("torch is required: pip install torch")
sys.exit(1)
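# Make the repository root importable so ``from rans_env import ...`` resolves
# when this script is run directly from the examples/ directory.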
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
# ---------------------------------------------------------------------------
# Remote gymnasium wrapper (uses OpenEnv HTTP client)
# ---------------------------------------------------------------------------
class RemoteRANSGymnasiumEnv(gym.Env):
"""
Gymnasium-compatible environment that connects to a running RANS server
via the OpenEnv ``RANSEnv`` WebSocket/HTTP client.
Identical observation and action spaces to ``RANSGymnasiumEnv``, but all
physics runs inside the server process (or Docker container).
"""
metadata = {"render_modes": []}
def __init__(self, base_url: str = "http://localhost:8000") -> None:
super().__init__()
from rans_env import RANSEnv, SpacecraftAction
self._SpacecraftAction = SpacecraftAction
# EnvClient is already synchronous (WebSocket-based); just connect.
self._client = RANSEnv(base_url=base_url)
self._client.connect()
# Probe the environment to determine spaces
result = self._client.reset()
obs = result.observation
flat = self._flatten(obs)
n = len(obs.thruster_masks)
self.action_space = spaces.Box(low=0.0, high=1.0, shape=(n,), dtype=np.float32)
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(flat.shape[0],), dtype=np.float32
)
self._last_flat = flat
self._task = obs.task
def reset(self, *, seed=None, options=None) -> Tuple[np.ndarray, Dict]:
super().reset(seed=seed)
result = self._client.reset()
self._last_flat = self._flatten(result.observation)
return self._last_flat, {"task": result.observation.task}
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
result = self._client.step(
self._SpacecraftAction(thrusters=action.tolist())
)
flat = self._flatten(result.observation)
reward = float(result.reward or 0.0)
done = bool(result.done)
self._last_flat = flat
return flat, reward, done, False, result.observation.info or {}
def close(self) -> None:
self._client.disconnect()
@staticmethod
def _flatten(obs) -> np.ndarray:
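        # Flatten the server's observation into a single vector: task/state
        # features, per-thruster transforms, thruster availability masks, and
        # the scalar mass and inertia appended at the end.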
return np.concatenate([
np.array(obs.state_obs, dtype=np.float32),
np.array(obs.thruster_transforms, dtype=np.float32).flatten(),
np.array(obs.thruster_masks, dtype=np.float32),
np.array([obs.mass, obs.inertia], dtype=np.float32),
])
# ---------------------------------------------------------------------------
# Re-use ActorCritic and PPO from ppo_train.py
# ---------------------------------------------------------------------------
def _mlp(in_dim: int, hidden: List[int], out_dim: int) -> nn.Sequential:
layers: List[nn.Module] = []
prev = in_dim
for h in hidden:
layers += [nn.Linear(prev, h), nn.Tanh()]
prev = h
layers.append(nn.Linear(prev, out_dim))
return nn.Sequential(*layers)
class ActorCritic(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden: Optional[List[int]] = None):
super().__init__()
hidden = hidden or [64, 64]
self.actor_mean = _mlp(obs_dim, hidden, act_dim)
self.log_std = nn.Parameter(torch.zeros(act_dim))
self.critic = _mlp(obs_dim, hidden, 1)
def forward(self, obs):
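        # Thruster commands live in [0, 1], so the policy mean is squashed with
        # a sigmoid; exploration noise uses a state-independent learned log_std.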
mean = torch.sigmoid(self.actor_mean(obs))
std = self.log_std.exp().expand_as(mean)
return Normal(mean, std), self.critic(obs).squeeze(-1)
@torch.no_grad()
def act(self, obs):
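        # Sample from the Gaussian, then clamp back into the valid thruster
        # range. The log-prob is taken w.r.t. the clamped action, a common
        # simplification for bounded action spaces.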
dist, value = self(obs)
action = dist.sample().clamp(0.0, 1.0)
return action, dist.log_prob(action).sum(-1), value
@torch.no_grad()
def act_deterministic(self, obs):
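        # Mean action without sampling; used for evaluation rollouts.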
return torch.sigmoid(self.actor_mean(obs)).clamp(0.0, 1.0)
class RolloutBuffer:
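    """Fixed-size on-policy storage for a single-environment rollout."""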
def __init__(self, n: int, obs_dim: int, act_dim: int, device: str):
self.n, self.device = n, device
self.obs = torch.zeros(n, obs_dim, device=device)
self.actions = torch.zeros(n, act_dim, device=device)
self.log_probs = torch.zeros(n, device=device)
self.rewards = torch.zeros(n, device=device)
self.values = torch.zeros(n, device=device)
self.dones = torch.zeros(n, device=device)
self.ptr = 0
def add(self, obs, action, log_prob, reward, value, done):
i = self.ptr
self.obs[i], self.actions[i] = obs, action
self.log_probs[i], self.rewards[i] = log_prob, reward
self.values[i], self.dones[i] = value, done
self.ptr += 1
def reset(self): self.ptr = 0
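    # Generalized Advantage Estimation (Schulman et al., 2016):
    #   delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
    # Value targets are the returns R_t = A_t + V(s_t).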
    def compute_gae(self, last_value, gamma=0.99, lam=0.95):
        adv = torch.zeros_like(self.rewards)
        last_gae = 0.0
        for t in reversed(range(self.n)):
            nv = last_value if t == self.n - 1 else self.values[t + 1]
            # dones[t] == 1 means step t ended the episode, so neither the next
            # value nor the running advantage is bootstrapped across the reset.
            not_done = 1.0 - self.dones[t]
            delta = self.rewards[t] + gamma * nv * not_done - self.values[t]
            last_gae = delta + gamma * lam * not_done * last_gae
            adv[t] = last_gae
        return adv, adv + self.values
def ppo_update(policy, optimizer, buf, adv, returns,
clip=0.2, ent=0.01, vf=0.5, epochs=10, bs=64):
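    # Clipped-surrogate PPO objective (Schulman et al., 2017):
    #   L_pi = -E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
    # with ratio r_t = exp(log pi_new(a|s) - log pi_old(a|s)), plus a
    # value-function MSE term and an entropy bonus for exploration.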
n = buf.n
stats = {"pi": 0.0, "vf": 0.0, "ent": 0.0}
n_updates = 0
    for _ in range(epochs):
        # One shuffled pass over the full buffer per epoch, so every sample is visited exactly once.
        perm = torch.randperm(n, device=buf.device)
        for s in range(0, n, bs):
            mb = perm[s:s + bs]
a_b = (adv[mb] - adv[mb].mean()) / (adv[mb].std() + 1e-8)
dist, val = policy(buf.obs[mb])
lp = dist.log_prob(buf.actions[mb]).sum(-1)
r = (lp - buf.log_probs[mb]).exp()
pi_loss = -torch.min(r * a_b, r.clamp(1-clip, 1+clip) * a_b).mean()
vf_loss = (val - returns[mb]).pow(2).mean()
loss = pi_loss + vf * vf_loss - ent * dist.entropy().sum(-1).mean()
optimizer.zero_grad(); loss.backward()
nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
optimizer.step()
stats["pi"] += pi_loss.item()
stats["vf"] += vf_loss.item()
n_updates += 1
return {key: val / max(n_updates, 1) for key, val in stats.items()}
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nRANS PPO β€” Remote Training via OpenEnv Client")
print(f" server : {args.url}")
print(f" task : connecting… (task set by RANS_TASK on server)")
print(f" device : {device}")
print(f" steps : {args.timesteps:,}")
print("=" * 60)
env = RemoteRANSGymnasiumEnv(base_url=args.url)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
print(f" task : {env._task}")
print(f" obs_dim : {obs_dim}")
print(f" act_dim : {act_dim} (thrusters)")
print()
policy = ActorCritic(obs_dim, act_dim).to(device)
optimizer = optim.Adam(policy.parameters(), lr=args.lr)
if args.checkpoint and os.path.exists(args.checkpoint):
ck = torch.load(args.checkpoint, map_location=device)
policy.load_state_dict(ck["policy"])
optimizer.load_state_dict(ck["optimizer"])
print(f" Loaded checkpoint: {args.checkpoint}")
buf = RolloutBuffer(args.n_steps, obs_dim, act_dim, device)
ep_rewards: List[float] = []
ep_lengths: List[int] = []
    ep_reward, ep_len = 0.0, 0
best_mean = -float("inf")
obs_np, _ = env.reset()
obs = torch.from_numpy(obs_np).float().to(device)
total_steps = update_num = 0
t0 = time.perf_counter()
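    # Each iteration: collect n_steps transitions from the remote env, compute
    # GAE advantages, run several PPO epochs over the buffer, then log and
    # checkpoint on improvement.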
while total_steps < args.timesteps:
buf.reset()
for _ in range(args.n_steps):
action, log_prob, value = policy.act(obs)
next_obs_np, reward, terminated, truncated, info = env.step(
action.cpu().numpy()
)
done = terminated or truncated
buf.add(obs, action, log_prob,
torch.tensor(reward, device=device),
value,
torch.tensor(float(done), device=device))
ep_reward += reward
ep_len += 1
total_steps += 1
if done:
ep_rewards.append(ep_reward)
ep_lengths.append(ep_len)
                ep_reward, ep_len = 0.0, 0
next_obs_np, _ = env.reset()
obs = torch.from_numpy(next_obs_np).float().to(device)
with torch.no_grad():
_, last_val = policy(obs)
adv, returns = buf.compute_gae(last_val, args.gamma, args.lam)
stats = ppo_update(policy, optimizer, buf, adv, returns,
clip=args.clip_eps, ent=args.entropy_coef,
epochs=args.n_epochs, bs=args.batch_size)
update_num += 1
if update_num % args.log_interval == 0:
mean_rew = np.mean(ep_rewards[-100:]) if ep_rewards else float("nan")
fps = total_steps / (time.perf_counter() - t0)
print(f" update {update_num:4d} | steps {total_steps:7,} | "
f"mean_rew {mean_rew:6.3f} | fps {fps:4.0f} | "
f"pi {stats['pi']:+.4f} vf {stats['vf']:.4f}")
if ep_rewards:
mean_rew = np.mean(ep_rewards[-100:])
if mean_rew > best_mean:
best_mean = mean_rew
ck_path = args.checkpoint or f"rans_ppo_remote_{env._task}.pt"
torch.save({"policy": policy.state_dict(),
"optimizer": optimizer.state_dict(),
"best_mean_reward": best_mean,
"task": env._task}, ck_path)
env.close()
print(f"\nTraining complete. Best mean reward: {best_mean:.3f}")
print(f"Checkpoint: {args.checkpoint or f'rans_ppo_remote_{env._task}.pt'}")
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
def evaluate(args: argparse.Namespace) -> None:
env = RemoteRANSGymnasiumEnv(base_url=args.url)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
ck = torch.load(args.checkpoint, map_location="cpu")
policy = ActorCritic(obs_dim, act_dim)
policy.load_state_dict(ck["policy"])
policy.eval()
print(f"\nEvaluating {args.checkpoint} against {args.url}")
print(f" task: {env._task} | best training reward: {ck.get('best_mean_reward', '?'):.3f}")
print("=" * 60)
for ep in range(args.eval_episodes):
obs_np, _ = env.reset()
total_r, steps = 0.0, 0
while True:
action = policy.act_deterministic(
torch.from_numpy(obs_np).float()
).numpy()
obs_np, r, term, trunc, info = env.step(action)
total_r += r; steps += 1
if term or trunc: break
print(f" ep {ep+1:2d} | steps {steps:4d} | reward {total_r:.3f} | "
f"goal {info.get('goal_reached', '?')}")
env.close()
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
p = argparse.ArgumentParser(description="RANS PPO training via OpenEnv client")
p.add_argument("--url", default="http://localhost:8000")
p.add_argument("--timesteps", type=int, default=200_000)
p.add_argument("--n-steps", type=int, default=1024)
p.add_argument("--n-epochs", type=int, default=10)
p.add_argument("--batch-size", type=int, default=64)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--gamma", type=float, default=0.99)
p.add_argument("--lam", type=float, default=0.95)
p.add_argument("--clip-eps", type=float, default=0.2)
p.add_argument("--entropy-coef", type=float, default=0.01)
p.add_argument("--log-interval", type=int, default=5)
p.add_argument("--checkpoint", default=None)
p.add_argument("--eval", action="store_true")
p.add_argument("--eval-episodes", type=int, default=10)
args = p.parse_args()
if args.eval:
if not args.checkpoint:
print("--eval requires --checkpoint"); sys.exit(1)
evaluate(args)
else:
train(args)
if __name__ == "__main__":
main()