renpas22 commited on
Commit
ccd696b
·
1 Parent(s): cd76323

Remove dead code that used direct config access

Browse files
Files changed (1) hide show
  1. src/reasoning/step_level_cot.py +0 -81
src/reasoning/step_level_cot.py CHANGED
@@ -636,79 +636,6 @@ class StepLevelCoTTrainer:
636
 
637
  progress_bar.close()
638
  logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}")
639
-
640
- prm_trainer = PRMTrainer(
641
- model=self.prm,
642
- learning_rate=learning_rate, # Use the passed parameter
643
- weight_decay=self.config.weight_decay,
644
- warmup_steps=self.config.warmup_steps,
645
- )
646
-
647
- train_dataloader = torch.utils.data.DataLoader(
648
- train_dataset,
649
- batch_size=self.config.train_batch_size,
650
- shuffle=True,
651
- num_workers=self.config.dataloader_num_workers,
652
- )
653
-
654
- # Prepare with accelerator
655
- prm_trainer.model, prm_trainer.optimizer, train_dataloader = self.accelerator.prepare(
656
- prm_trainer.model, prm_trainer.optimizer, train_dataloader
657
- )
658
-
659
- best_val_loss = float('inf')
660
- global_step = 0
661
-
662
- # Training loop using steps instead of epochs
663
- prm_trainer.model.train()
664
-
665
- progress_bar = tqdm(
666
- total=max_steps,
667
- desc=f"PRM Training",
668
- disable=not self.accelerator.is_local_main_process,
669
- )
670
-
671
- while global_step < max_steps:
672
- for batch in train_dataloader:
673
- if global_step >= max_steps:
674
- break
675
-
676
- # Extract batch data
677
- vision_features = batch['visual_features']
678
- step_embeddings = batch['step_descriptions'] # Need to encode these
679
- target_rewards = batch['step_rewards']
680
-
681
- # Encode step descriptions
682
- # (In practice, you'd encode text properly)
683
-
684
- metrics = prm_trainer.train_step(
685
- vision_features,
686
- step_embeddings,
687
- target_rewards,
688
- )
689
-
690
- global_step += 1
691
- progress_bar.update(1)
692
- progress_bar.set_postfix({'loss': metrics['loss'], 'step': global_step})
693
-
694
- # Save checkpoint
695
- if global_step % save_steps == 0:
696
- logger.info(f"Saving checkpoint at step {global_step}")
697
- self._save_prm(global_step)
698
-
699
- # Validation
700
- if eval_steps > 0 and global_step % eval_steps == 0 and val_dataset is not None:
701
- val_metrics = self._evaluate_prm(prm_trainer, val_dataset)
702
- logger.info(f"Step {global_step} - Validation: {val_metrics}")
703
-
704
- # Save best model
705
- if val_metrics['mse'] < best_val_loss:
706
- best_val_loss = val_metrics['mse']
707
- self._save_prm(f"best_step_{global_step}")
708
-
709
- progress_bar.close()
710
-
711
- logger.info("PRM training completed")
712
 
713
  def train_rl(
714
  self,
@@ -811,14 +738,6 @@ class StepLevelCoTTrainer:
811
 
812
  progress_bar.close()
813
  logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}")
814
-
815
- self.rl_trainer.train(
816
- train_dataset=train_dataset,
817
- num_iterations=num_iterations,
818
- log_interval=self.config.logging_steps,
819
- )
820
-
821
- logger.info("RL training completed")
822
 
823
  def evaluate_inference_scaling(
824
  self,
 
636
 
637
  progress_bar.close()
638
  logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
  def train_rl(
641
  self,
 
738
 
739
  progress_bar.close()
740
  logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}")
 
 
 
 
 
 
 
 
741
 
742
  def evaluate_inference_scaling(
743
  self,