from typing import List, Optional

import torch
import torch.nn as nn
from transformers import MistralConfig, MistralModel, PreTrainedModel


class EurusRewardModel(PreTrainedModel):
    """Mistral backbone with a bias-free linear head that scores sequences.

    The backbone's final hidden states are projected to one scalar per token.
    For each sequence, the scalar at its last attended position is returned
    as that sequence's reward.
    """

    config_class = MistralConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.model = MistralModel(config)
        # Maps each hidden state to a single scalar score; no bias term.
        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)
        # transformers' standard final-setup hook for PreTrainedModel
        # subclasses (weight initialization and related bookkeeping).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Return one reward per sequence in the batch.

        Args:
            input_ids: Token ids, forwarded to the backbone.
            attention_mask: Non-zero for real tokens, zero for padding. It is
                forwarded to the backbone and also used to locate the last
                real token of each row. If ``None``, the final position of
                every row is scored.
            position_ids: Forwarded unchanged to the backbone.
            past_key_values: Forwarded unchanged to the backbone.
            inputs_embeds: Forwarded unchanged to the backbone.
            labels: Accepted for interface compatibility; ignored.
            use_cache: Forwarded to the backbone.
            output_attentions: Forwarded to the backbone.
            output_hidden_states: Forwarded to the backbone.
            return_dict: Accepted for interface compatibility; ignored.
                Element 0 of the backbone output is read, which works for
                both tuple and dict-style outputs.

        Returns:
            Tensor of shape ``(batch, 1)``: the head's score at the selected
            position of each row.
        """
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = transformer_outputs[0]
        # Drop the trailing size-1 dim of the head output: one score per position.
        rewards = self.regression_head(hidden_states).squeeze(-1)

        if attention_mask is None:
            # No padding information: score the last position of every row.
            ends = torch.full(
                (rewards.size(0), 1),
                rewards.size(1) - 1,
                dtype=torch.long,
                device=rewards.device,
            )
        else:
            # The running count of real tokens first reaches its maximum at
            # the last real token, and argmax returns the first maximal index,
            # so this selects the last attended position under either left or
            # right padding.
            ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)
            # gather requires index and input on the same device; they can
            # differ when the model is split across devices.
            ends = ends.to(rewards.device)

        return torch.gather(rewards, 1, ends)