renpas22 committed on
Commit ·
ccd696b
1
Parent(s): cd76323
Remove dead code with direct config access
Browse files
src/reasoning/step_level_cot.py
CHANGED
|
@@ -636,79 +636,6 @@ class StepLevelCoTTrainer:
|
|
| 636 |
|
| 637 |
progress_bar.close()
|
| 638 |
logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}")
|
| 639 |
-
|
| 640 |
-
prm_trainer = PRMTrainer(
|
| 641 |
-
model=self.prm,
|
| 642 |
-
learning_rate=learning_rate, # Use the passed parameter
|
| 643 |
-
weight_decay=self.config.weight_decay,
|
| 644 |
-
warmup_steps=self.config.warmup_steps,
|
| 645 |
-
)
|
| 646 |
-
|
| 647 |
-
train_dataloader = torch.utils.data.DataLoader(
|
| 648 |
-
train_dataset,
|
| 649 |
-
batch_size=self.config.train_batch_size,
|
| 650 |
-
shuffle=True,
|
| 651 |
-
num_workers=self.config.dataloader_num_workers,
|
| 652 |
-
)
|
| 653 |
-
|
| 654 |
-
# Prepare with accelerator
|
| 655 |
-
prm_trainer.model, prm_trainer.optimizer, train_dataloader = self.accelerator.prepare(
|
| 656 |
-
prm_trainer.model, prm_trainer.optimizer, train_dataloader
|
| 657 |
-
)
|
| 658 |
-
|
| 659 |
-
best_val_loss = float('inf')
|
| 660 |
-
global_step = 0
|
| 661 |
-
|
| 662 |
-
# Training loop using steps instead of epochs
|
| 663 |
-
prm_trainer.model.train()
|
| 664 |
-
|
| 665 |
-
progress_bar = tqdm(
|
| 666 |
-
total=max_steps,
|
| 667 |
-
desc=f"PRM Training",
|
| 668 |
-
disable=not self.accelerator.is_local_main_process,
|
| 669 |
-
)
|
| 670 |
-
|
| 671 |
-
while global_step < max_steps:
|
| 672 |
-
for batch in train_dataloader:
|
| 673 |
-
if global_step >= max_steps:
|
| 674 |
-
break
|
| 675 |
-
|
| 676 |
-
# Extract batch data
|
| 677 |
-
vision_features = batch['visual_features']
|
| 678 |
-
step_embeddings = batch['step_descriptions'] # Need to encode these
|
| 679 |
-
target_rewards = batch['step_rewards']
|
| 680 |
-
|
| 681 |
-
# Encode step descriptions
|
| 682 |
-
# (In practice, you'd encode text properly)
|
| 683 |
-
|
| 684 |
-
metrics = prm_trainer.train_step(
|
| 685 |
-
vision_features,
|
| 686 |
-
step_embeddings,
|
| 687 |
-
target_rewards,
|
| 688 |
-
)
|
| 689 |
-
|
| 690 |
-
global_step += 1
|
| 691 |
-
progress_bar.update(1)
|
| 692 |
-
progress_bar.set_postfix({'loss': metrics['loss'], 'step': global_step})
|
| 693 |
-
|
| 694 |
-
# Save checkpoint
|
| 695 |
-
if global_step % save_steps == 0:
|
| 696 |
-
logger.info(f"Saving checkpoint at step {global_step}")
|
| 697 |
-
self._save_prm(global_step)
|
| 698 |
-
|
| 699 |
-
# Validation
|
| 700 |
-
if eval_steps > 0 and global_step % eval_steps == 0 and val_dataset is not None:
|
| 701 |
-
val_metrics = self._evaluate_prm(prm_trainer, val_dataset)
|
| 702 |
-
logger.info(f"Step {global_step} - Validation: {val_metrics}")
|
| 703 |
-
|
| 704 |
-
# Save best model
|
| 705 |
-
if val_metrics['mse'] < best_val_loss:
|
| 706 |
-
best_val_loss = val_metrics['mse']
|
| 707 |
-
self._save_prm(f"best_step_{global_step}")
|
| 708 |
-
|
| 709 |
-
progress_bar.close()
|
| 710 |
-
|
| 711 |
-
logger.info("PRM training completed")
|
| 712 |
|
| 713 |
def train_rl(
|
| 714 |
self,
|
|
@@ -811,14 +738,6 @@ class StepLevelCoTTrainer:
|
|
| 811 |
|
| 812 |
progress_bar.close()
|
| 813 |
logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}")
|
| 814 |
-
|
| 815 |
-
self.rl_trainer.train(
|
| 816 |
-
train_dataset=train_dataset,
|
| 817 |
-
num_iterations=num_iterations,
|
| 818 |
-
log_interval=self.config.logging_steps,
|
| 819 |
-
)
|
| 820 |
-
|
| 821 |
-
logger.info("RL training completed")
|
| 822 |
|
| 823 |
def evaluate_inference_scaling(
|
| 824 |
self,
|
|
|
|
| 636 |
|
| 637 |
progress_bar.close()
|
| 638 |
logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
def train_rl(
|
| 641 |
self,
|
|
|
|
| 738 |
|
| 739 |
progress_bar.close()
|
| 740 |
logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
|
| 742 |
def evaluate_inference_scaling(
|
| 743 |
self,
|