renpas22 commited on
Commit
c74a578
·
1 Parent(s): 37e8f2f

Add type conversion and debug logging for config values

Browse files
Files changed (1) hide show
  1. src/reasoning/step_level_cot.py +25 -13
src/reasoning/step_level_cot.py CHANGED
@@ -97,19 +97,31 @@ class StepLevelCoTTrainer:
97
  )
98
 
99
  # Initialize inference scaler
100
- inference_config = InferenceConfig(
101
- num_samples=self.config.num_inference_samples,
102
- temperature=self.config.inference_temperature,
103
- aggregation=self.config.aggregation_method,
104
- use_prm_scores=True,
105
- )
106
-
107
- self.inference_scaler = InferenceTimeScaling(
108
- model=self.model,
109
- prm=self.prm,
110
- config=inference_config,
111
- device=device,
112
- )
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  logger.info("StepLevelCoTTrainer initialized")
115
 
 
97
  )
98
 
99
  # Initialize inference scaler
100
+ # Ensure numeric types are properly converted
101
+ try:
102
+ num_samples = int(self.config.num_inference_samples)
103
+ temperature = float(self.config.inference_temperature)
104
+ aggregation = str(self.config.aggregation_method)
105
+
106
+ inference_config = InferenceConfig(
107
+ num_samples=num_samples,
108
+ temperature=temperature,
109
+ aggregation=aggregation,
110
+ use_prm_scores=True,
111
+ )
112
+
113
+ self.inference_scaler = InferenceTimeScaling(
114
+ model=self.model,
115
+ prm=self.prm,
116
+ config=inference_config,
117
+ device=device,
118
+ )
119
+ except Exception as e:
120
+ logger.error(f"Failed to initialize inference scaler: {e}")
121
+ logger.error(f"Config values - num_samples: {self.config.num_inference_samples} (type: {type(self.config.num_inference_samples)})")
122
+ logger.error(f"Config values - temperature: {self.config.inference_temperature} (type: {type(self.config.inference_temperature)})")
123
+ logger.error(f"Config values - aggregation: {self.config.aggregation_method} (type: {type(self.config.aggregation_method)})")
124
+ raise
125
 
126
  logger.info("StepLevelCoTTrainer initialized")
127