# team222 / train_liquid.py
# Deploy 2M step LNN training with optimized GPU utilization
# (commit 28dbd6d, verified; 3.06 kB)
import os
import sys
import argparse
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
import multiprocessing
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.callbacks import CheckpointCallback
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from env_3d import Drone3DEnv
from models.liquid_policy import LiquidFeatureExtractor
def make_env(rank, seed=0):
    """Build a thunk that constructs and seeds a fresh Drone3DEnv.

    SubprocVecEnv invokes the returned callable inside a worker process,
    so environment construction is deferred until then.  Each worker gets
    a distinct seed (``seed + rank``) so rollouts are decorrelated.
    """
    def _thunk():
        environment = Drone3DEnv()
        # Seed via reset() per the Gymnasium API; offset keeps workers unique.
        environment.reset(seed=seed + rank)
        return environment

    return _thunk
def main():
    """Train a PPO agent with a Liquid-NN feature extractor on Drone3DEnv.

    Parses ``--timesteps`` and ``--seed`` from the command line, builds a
    vectorized set of subprocess environments, trains with periodic
    checkpointing, and saves the final model to
    ``models/liquid_ppo_3d_final.zip``.
    """
    # Verify CUDA up front so a misconfigured box is obvious before a long run.
    import torch
    print(f"Is CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        print("WARNING: CUDA is NOT available. Training will be slow.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--timesteps", type=int, default=8_000_000)  # 8M steps as requested
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Create Vectorized Env.
    # 8x L40S has ~192 vCPUs. Using 64-96 is usually the sweet spot for PPO.
    # Too many envs can cause overhead or instability.
    max_cpu = 64
    available_cpu = multiprocessing.cpu_count()
    num_cpu = min(max_cpu, available_cpu)
    print(f"Using {num_cpu} CPUs for parallel environments (Available: {available_cpu}).")

    # SubprocVecEnv for true parallelism; VecMonitor logs episode stats.
    env = SubprocVecEnv([make_env(i, args.seed) for i in range(num_cpu)])
    env = VecMonitor(env)

    # Liquid Policy Config — large-capacity extractor and heads.
    policy_kwargs = dict(
        features_extractor_class=LiquidFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=128, hidden_size=128, dt=0.05),
        net_arch=dict(pi=[256, 256], vf=[256, 256]),
    )

    # Initialize PPO.
    model = PPO(
        "MlpPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        learning_rate=3e-4,
        n_steps=8192,       # Large horizon for stability
        batch_size=2048,    # Large batch for powerful GPU
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        tensorboard_log="./logs/ppo_liquid_3d/",
        device="cuda",
    )

    # Checkpoint every ~500k total environment steps.  SB3 counts save_freq
    # per vectorized step (i.e. per num_cpu transitions), hence the division;
    # max(..., 1) guards against the quotient truncating to 0.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(500_000 // num_cpu, 1),
        save_path="./models/checkpoints/",
        name_prefix="liquid_ppo_3d",
    )

    print("Starting Training with Liquid Neural Network Policy...")
    print(f"Target Timesteps: {args.timesteps}")
    print(f"Configuration: {num_cpu} Envs, {model.n_steps} Steps, {model.batch_size} Batch Size")

    try:
        model.learn(total_timesteps=args.timesteps, callback=checkpoint_callback)
        model.save("models/liquid_ppo_3d_final")
        print("Model saved to models/liquid_ppo_3d_final.zip")
    finally:
        # Shut down the subprocess workers even if training aborts — without
        # this, an exception in learn() leaks num_cpu child processes.
        env.close()
# Standard script entry point: run training only when executed directly,
# not when this module is imported (e.g. by SubprocVecEnv workers on spawn).
if __name__ == "__main__":
    main()