renpas22 commited on
Commit ·
fa9e543
1
Parent(s): 5af9eca
Skip .to(device) for quantized models with device_map
Browse files
src/reasoning/rl_trainer.py
CHANGED
|
@@ -76,9 +76,13 @@ class RLReasoningTrainer:
|
|
| 76 |
self.config = config
|
| 77 |
self.device = device
|
| 78 |
|
| 79 |
-
# Move models to device
|
| 80 |
-
self.policy
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Freeze PRM (only train policy)
|
| 84 |
for param in self.prm.parameters():
|
|
|
|
| 76 |
self.config = config
|
| 77 |
self.device = device
|
| 78 |
|
| 79 |
+
# Move models to device (skip if already quantized with device_map)
|
| 80 |
+
if not (hasattr(self.policy, 'hf_device_map') or
|
| 81 |
+
getattr(self.policy, 'is_quantized', False)):
|
| 82 |
+
self.policy.to(device)
|
| 83 |
+
if not (hasattr(self.prm, 'hf_device_map') or
|
| 84 |
+
getattr(self.prm, 'is_quantized', False)):
|
| 85 |
+
self.prm.to(device)
|
| 86 |
|
| 87 |
# Freeze PRM (only train policy)
|
| 88 |
for param in self.prm.parameters():
|