HuatuoGPT-reward-model-7B / modeling_llama_rm.py
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, PreTrainedModel
class LlamaRewardModel(PreTrainedModel):
    config_class = LlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Scalar value head: projects the last hidden state to a single reward score.
        self.value_head = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,  # accepted for Trainer compatibility; unused here
    ) -> Tuple[torch.Tensor]:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            output_attentions=self.config.output_attentions,
        )
        last_hidden_states = outputs.hidden_states[-1]
        if attention_mask is None:
            # No mask: assume no padding and pool the final position.
            last_hidden_states = last_hidden_states[:, -1]
        else:
            # Find the index of the last non-padded token in each sequence:
            # cumsum over the mask peaks at the last real token, and argmax
            # returns the first position where that peak occurs.
            last_index = attention_mask.cumsum(dim=1).argmax(dim=1)
            last_hidden_states = last_hidden_states.gather(
                1, last_index.view(-1, 1, 1).expand(-1, -1, last_hidden_states.size(-1))
            ).squeeze(1)
        # One scalar reward per sequence, shape (batch_size, 1).
        logits = self.value_head(last_hidden_states)
        return (logits,)
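
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): how this reward model might be
# loaded and queried. The repo id and example prompt below are assumptions,
# not taken from this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import LlamaTokenizer

    repo_id = "HuatuoGPT-reward-model-7B"  # placeholder; substitute the real hub path

    tokenizer = LlamaTokenizer.from_pretrained(repo_id)
    model = LlamaRewardModel.from_pretrained(repo_id)
    model.eval()

    text = "Question: What are common symptoms of anemia? Answer: Fatigue and pallor."
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        (reward,) = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # reward has shape (1, 1): one scalar score for the single input sequence.
    print(reward.item())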