renpas22 committed on
Commit
fa9e543
·
1 Parent(s): 5af9eca

Skip .to(device) for quantized models with device_map

Browse files
Files changed (1) hide show
  1. src/reasoning/rl_trainer.py +7 -3
src/reasoning/rl_trainer.py CHANGED
@@ -76,9 +76,13 @@ class RLReasoningTrainer:
76
  self.config = config
77
  self.device = device
78
 
79
- # Move models to device
80
- self.policy.to(device)
81
- self.prm.to(device)
 
 
 
 
82
 
83
  # Freeze PRM (only train policy)
84
  for param in self.prm.parameters():
 
76
  self.config = config
77
  self.device = device
78
 
79
+ # Move models to device (skip if already quantized with device_map)
80
+ if not (hasattr(self.policy, 'hf_device_map') or
81
+ getattr(self.policy, 'is_quantized', False)):
82
+ self.policy.to(device)
83
+ if not (hasattr(self.prm, 'hf_device_map') or
84
+ getattr(self.prm, 'is_quantized', False)):
85
+ self.prm.to(device)
86
 
87
  # Freeze PRM (only train policy)
88
  for param in self.prm.parameters():