Update modeling_custom.py
modeling_custom.py  CHANGED  (+6 -6)
@@ -162,19 +162,19 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
         prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
         gating_output = self.gating(prompt_embedding)
 
+        with torch.autocast(device_type=rewards.device.type, dtype=torch.float32):
             # [B, num_quantiles, num_objectives]
+            reward_quantiles = torch.mul(
+                gating_output.unsqueeze(-1).repeat(1, 1, self.num_objectives),
+                torch.transpose(rewards, 1, 2)
+            ).sum(1)
 
         rewards_expectation = rewards.float().mean(dim=2)
 
         score = torch.sum(gating_output.float() * rewards_expectation.float(), dim=1, keepdim=True)
 
         return CustomOutput(
+            reward_quantiles=reward_quantiles,
             rewards=rewards_expectation,
             hidden_state=hidden_states,
             prompt_embedding=prompt_embedding,
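For context: the added lines compute a gating-weighted mixture of the per-objective quantile reward estimates, wrapped in torch.autocast with dtype=torch.float32 (apparently to keep this step in full precision), and expose the result through CustomOutput as reward_quantiles. The snippet below is a minimal, self-contained sketch of that mixture on dummy tensors; the shapes (rewards as [B, num_objectives, num_quantiles], gating_output as [B, num_objectives]) and the broadcasting formulation are illustrative assumptions, not the repository's exact implementation.

import torch

# Assumed illustrative shapes; not taken from the model's config
batch_size, num_objectives, num_quantiles = 2, 4, 5

# Per-objective quantile reward estimates: [B, num_objectives, num_quantiles]
rewards = torch.randn(batch_size, num_objectives, num_quantiles)
# Gating weights over objectives, rows summing to 1: [B, num_objectives]
gating_output = torch.softmax(torch.randn(batch_size, num_objectives), dim=-1)

# Gating-weighted mixture of the per-objective quantile estimates,
# computed in float32 as in the change above: [B, num_quantiles]
reward_quantiles = (gating_output.float().unsqueeze(-1) * rewards.float()).sum(dim=1)

# Expected reward per objective and the scalar score, mirroring the
# unchanged lines of the hunk
rewards_expectation = rewards.float().mean(dim=2)  # [B, num_objectives]
score = torch.sum(gating_output.float() * rewards_expectation.float(),
                  dim=1, keepdim=True)             # [B, 1]

print(reward_quantiles.shape, rewards_expectation.shape, score.shape)
# torch.Size([2, 5]) torch.Size([2, 4]) torch.Size([2, 1])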