rogermt commited on
Commit
376238e
·
verified ·
1 Parent(s): 3e32ac2

Add pool.finalize() call after building trajectory pool for O(1) sampling

Browse files
Files changed (1) hide show
  1. trainer.py +3 -0
trainer.py CHANGED
@@ -74,6 +74,9 @@ class NSGFTrainer:
74
  if (batch_idx + 1) % max(1, num_batches // 10) == 0:
75
  logger.info(f" Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
76
  logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")
 
 
 
77
 
78
  def train(self) -> Dict[str, list]:
79
  self.model.train()
 
74
  if (batch_idx + 1) % max(1, num_batches // 10) == 0:
75
  logger.info(f" Pool building: {batch_idx + 1}/{num_batches}, pool size: {len(self.pool)}")
76
  logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")
77
+ # Pre-concatenate for O(1) sampling during training
78
+ self.pool.finalize()
79
+ logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).")
80
 
81
  def train(self) -> Dict[str, list]:
82
  self.model.train()