Update modeling_custom.py
#14
by
gabrielmbmb
HF staff
- opened
- modeling_custom.py +6 -1
modeling_custom.py
CHANGED
@@ -96,6 +96,9 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
96 |
temperature=config_dict.get("gating_temperature", 10),
|
97 |
hidden_dim=config_dict.get("gating_hidden_dim", 1024),
|
98 |
n_hidden=config_dict.get("gating_n_hidden", 3))
|
|
|
|
|
|
|
99 |
|
100 |
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
101 |
def forward(
|
@@ -153,6 +156,8 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
153 |
prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
|
154 |
gating_output = self.gating(prompt_embedding)
|
155 |
|
|
|
|
|
156 |
rewards_adjusted = rewards @ self.reward_transform_matrix
|
157 |
score = torch.sum(gating_output * rewards_adjusted, dim=1)
|
158 |
|
@@ -163,4 +168,4 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
163 |
gating_output=gating_output,
|
164 |
score=score,
|
165 |
logits=score,
|
166 |
-
)
|
|
|
96 |
temperature=config_dict.get("gating_temperature", 10),
|
97 |
hidden_dim=config_dict.get("gating_hidden_dim", 1024),
|
98 |
n_hidden=config_dict.get("gating_n_hidden", 3))
|
99 |
+
def align_tensor_devices(self, *tensors):
    """Move every tensor in *tensors* onto the device of the first one.

    Used before cross-tensor ops (e.g. ``rewards @ self.reward_transform_matrix``)
    so that operands sharded or offloaded to different devices do not raise a
    device-mismatch error at matmul time.

    Args:
        *tensors: Any number of ``torch.Tensor`` objects. The first tensor's
            ``.device`` is taken as the target device.

    Returns:
        list[torch.Tensor]: The input tensors in order, all on the target
        device. Tensors already on the target device are returned as-is
        (no copy); others are moved via ``Tensor.to``.
    """
    # Guard the empty call: the original indexed tensors[0] unconditionally,
    # which raised IndexError when invoked with no arguments.
    if not tensors:
        return []
    target_device = tensors[0].device
    return [
        tensor if tensor.device == target_device else tensor.to(target_device)
        for tensor in tensors
    ]
|
102 |
|
103 |
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
104 |
def forward(
|
|
|
156 |
prompt_embedding = tokens_hidden_states[dummy_iterator, gating_token_positions, :]
|
157 |
gating_output = self.gating(prompt_embedding)
|
158 |
|
159 |
+
rewards, self.reward_transform_matrix = self.align_tensor_devices(rewards, self.reward_transform_matrix)
|
160 |
+
|
161 |
rewards_adjusted = rewards @ self.reward_transform_matrix
|
162 |
score = torch.sum(gating_output * rewards_adjusted, dim=1)
|
163 |
|
|
|
168 |
gating_output=gating_output,
|
169 |
score=score,
|
170 |
logits=score,
|
171 |
+
)
|