renpas22 commited on
Commit ·
c74a578
1
Parent(s): 37e8f2f
Add type conversion and debug logging for config values
Browse files- src/reasoning/step_level_cot.py +25 -13
src/reasoning/step_level_cot.py
CHANGED
|
@@ -97,19 +97,31 @@ class StepLevelCoTTrainer:
|
|
| 97 |
)
|
| 98 |
|
| 99 |
# Initialize inference scaler
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
logger.info("StepLevelCoTTrainer initialized")
|
| 115 |
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
# Initialize inference scaler
|
| 100 |
+
# Ensure numeric types are properly converted
|
| 101 |
+
try:
|
| 102 |
+
num_samples = int(self.config.num_inference_samples)
|
| 103 |
+
temperature = float(self.config.inference_temperature)
|
| 104 |
+
aggregation = str(self.config.aggregation_method)
|
| 105 |
+
|
| 106 |
+
inference_config = InferenceConfig(
|
| 107 |
+
num_samples=num_samples,
|
| 108 |
+
temperature=temperature,
|
| 109 |
+
aggregation=aggregation,
|
| 110 |
+
use_prm_scores=True,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.inference_scaler = InferenceTimeScaling(
|
| 114 |
+
model=self.model,
|
| 115 |
+
prm=self.prm,
|
| 116 |
+
config=inference_config,
|
| 117 |
+
device=device,
|
| 118 |
+
)
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.error(f"Failed to initialize inference scaler: {e}")
|
| 121 |
+
logger.error(f"Config values - num_samples: {self.config.num_inference_samples} (type: {type(self.config.num_inference_samples)})")
|
| 122 |
+
logger.error(f"Config values - temperature: {self.config.inference_temperature} (type: {type(self.config.inference_temperature)})")
|
| 123 |
+
logger.error(f"Config values - aggregation: {self.config.aggregation_method} (type: {type(self.config.aggregation_method)})")
|
| 124 |
+
raise
|
| 125 |
|
| 126 |
logger.info("StepLevelCoTTrainer initialized")
|
| 127 |
|