Eurus-RM-7b / modeling_eurus_rm.py
Commit 93dbd18: Enable flash_attention_2 support since the underlying Mistral model supports it
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import MistralConfig, MistralModel, PreTrainedModel

class EurusRewardModel(PreTrainedModel):
    config_class = MistralConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.model = MistralModel(config)
        # Scalar head mapping each hidden state to a single reward logit.
        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)

    def forward(  # args are the same as LlamaForCausalLM
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )
        # Last hidden states: (batch, seq_len, hidden_size)
        hidden_states = transformer_outputs[0]
        # Per-token scalar rewards: (batch, seq_len)
        rewards = self.regression_head(hidden_states).squeeze(-1)
        # Index of the last non-padding token in each sequence: the cumulative
        # sum of the mask peaks at the final attended position, and argmax
        # returns the first index of that peak (works for left or right padding).
        ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)
        # Keep only the reward at the final token: (batch, 1)
        rewards = torch.gather(rewards, 1, ends)
        return rewards
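
For orientation, here is a minimal usage sketch, not taken from this file: it assumes the hosting repo id is openbmb/Eurus-RM-7b and that the repo registers this class as custom code (hence trust_remote_code=True); the chat format in the example is a placeholder, so consult the model card for the exact template the reward model was trained with.

import torch
from transformers import AutoModel, AutoTokenizer

model_id = "openbmb/Eurus-RM-7b"  # assumed repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,      # loads EurusRewardModel from this file
    torch_dtype=torch.bfloat16,  # optional; assumes bf16-capable hardware
)
model.eval()

# Placeholder prompt/response format; check the model card for the real template.
text = "[INST] What is 2 + 2? [/INST] 2 + 2 = 4."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    reward = model(**inputs)  # shape (1, 1): reward at the last real token
print(reward.item())

Note that forward returns the raw reward tensor rather than a ModelOutput, so callers index or .item() it directly; higher rewards indicate responses the model scores as better.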