nicolinho committed · verified
Commit 1d2da62 · 1 Parent(s): ce64863

Update modeling_custom.py

Files changed (1)
  1. modeling_custom.py +6 -6
modeling_custom.py CHANGED
@@ -162,19 +162,19 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
         prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
         gating_output = self.gating(prompt_embedding)
 
-        #with torch.autocast(device_type=rewards.device.type, dtype=torch.float32):
+        with torch.autocast(device_type=rewards.device.type, dtype=torch.float32):
             # [B, num_quantiles, num_objectives]
-            #reward_quantiles = torch.mul(
-            #    gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
-            #    torch.transpose(rewards, 1, 2)
-            #).sum(1)
+            reward_quantiles = torch.mul(
+                gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
+                torch.transpose(rewards, 1, 2)
+            ).sum(1)
 
         rewards_expectation = rewards.float().mean(dim=2)
 
         score = torch.sum(gating_output.float() * rewards_expectation.float(), dim=1, keepdim=True)
 
         return CustomOutput(
-            # reward_quantiles=reward_quantiles,
+            reward_quantiles=reward_quantiles,
             rewards=rewards_expectation,
             hidden_state=hidden_states,
             prompt_embedding=prompt_embedding,
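
For context: the change re-enables the gated per-quantile reward aggregation that was previously commented out and returns it as reward_quantiles alongside the expectation-based score. The snippet below is a minimal standalone sketch of that idea, not the model's verbatim code: the batch size, the objective and quantile counts, the softmax-normalized gating weights, and the einsum formulation of the mixture are illustrative assumptions, and the float32 autocast wrapper from the commit is only noted in a comment so the sketch also runs on CPU.

import torch

# Illustrative sizes (assumptions, not taken from the model config).
B, num_objectives, num_quantiles = 2, 5, 8

# Assumed shapes: gating weights [B, num_objectives],
# per-objective quantile rewards [B, num_objectives, num_quantiles].
gating_output = torch.softmax(torch.randn(B, num_objectives), dim=-1)
rewards = torch.randn(B, num_objectives, num_quantiles)

# The commit wraps the aggregation in
# torch.autocast(device_type=rewards.device.type, dtype=torch.float32)
# to keep it in full precision; omitted here so the sketch runs on CPU.
# Gated mixture over objectives -> one quantile distribution per example, [B, num_quantiles].
reward_quantiles = torch.einsum("bo,boq->bq", gating_output.float(), rewards.float())

# Expectation over quantiles and the scalar score, mirroring the unchanged lines of the diff.
rewards_expectation = rewards.float().mean(dim=2)  # [B, num_objectives]
score = torch.sum(gating_output.float() * rewards_expectation.float(), dim=1, keepdim=True)  # [B, 1]

print(reward_quantiles.shape, score.shape)  # torch.Size([2, 8]) torch.Size([2, 1])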