Eurus-RM-7b / modeling_eurus_rm.py
Enable flash_attention_2 support since the underlying Mistral model supports it
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import MistralConfig, MistralModel, PreTrainedModel


class EurusRewardModel(PreTrainedModel):
    """Mistral backbone with a scalar regression head for reward modeling."""

    config_class = MistralConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.model = MistralModel(config)
        # Scalar head applied to every hidden state; the final reward is read
        # off at the last non-padded token in forward().
        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)
    def forward(  # args are the same as LlamaForCausalLM
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        # Per-token scalar scores: (batch_size, seq_len)
        rewards = self.regression_head(hidden_states).squeeze(-1)
        # Index of the last non-padded token in each sequence: the cumulative sum
        # of the attention mask peaks at the final attended position.
        ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)
        # Keep one reward per sequence: (batch_size, 1)
        rewards = torch.gather(rewards, 1, ends)
        return rewards
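
A minimal usage sketch for this modeling file. It assumes the checkpoint is published as "openbmb/Eurus-RM-7b" and that its config maps AutoModel to EurusRewardModel via trust_remote_code; the flash-attention line is optional and requires the flash-attn package plus a half-precision dtype on GPU. Treat the repo id and prompt format as assumptions, not part of this file.

import torch
from transformers import AutoModel, AutoTokenizer

model_id = "openbmb/Eurus-RM-7b"  # assumed repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # needed so this custom modeling file is loaded
    # attn_implementation="flash_attention_2",  # optional; needs flash-attn installed
)
model.eval()

# Example conversation to score; the [INST] template is an assumed prompt format.
text = "[INST] What is 2 + 2? [/INST] 2 + 2 = 4."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    reward = model(**inputs)  # shape (batch_size, 1): one scalar score per sequence
print(reward.item())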