from typing import List, Optional

import torch
import torch.nn as nn
from transformers import MistralConfig, MistralModel, PreTrainedModel


class EurusRewardModel(PreTrainedModel):
    """Mistral backbone topped with a bias-free linear head that emits one scalar per sequence.

    The head is applied to every position of the backbone's first output and
    the value at a single "end" position per sequence is returned (see
    ``forward`` for how that position is chosen).
    """

    config_class = MistralConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        """Build the Mistral backbone and the ``hidden_size -> 1`` regression head."""
        super().__init__(config)
        self.model = MistralModel(config)
        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)

    def forward(  # args are the same as LlamaForCausalLM
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Score each sequence in the batch with a single scalar reward.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Mask forwarded to the backbone and used to pick the
                scored position. When ``None``, the final position of every
                sequence is scored.
            position_ids: Forwarded to the backbone.
            past_key_values: Forwarded to the backbone.
            inputs_embeds: Forwarded to the backbone.
            labels, use_cache, output_attentions, output_hidden_states,
            return_dict: Accepted for signature compatibility only; this
                method does not read them.

        Returns:
            The reward gathered at one position per sequence; the result has
            the same shape as the index tensor, i.e. one column per row.
        """
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        # One scalar per position; assumes hidden_states is (batch, seq, hidden) — TODO confirm.
        rewards = self.regression_head(hidden_states).squeeze(-1)

        if attention_mask is None:
            # No mask supplied: treat every position as attended and score the
            # last one, instead of failing on ``None.cumsum``.
            ends = torch.full(
                (rewards.size(0), 1),
                rewards.size(1) - 1,
                dtype=torch.long,
                device=rewards.device,
            )
        else:
            # The running sum first reaches its maximum at the last non-zero
            # mask entry, and argmax reports the first maximal index, so this
            # is intended to work for both left- and right-padded batches.
            # NOTE(review): assumes a 0/1 mask whose length matches the scored
            # sequence; with past_key_values the mask may be longer — confirm.
            ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)

        rewards = torch.gather(rewards, 1, ends)
        return rewards