from transformers import PreTrainedModel, LlamaConfig, LlamaModel
import torch.nn as nn
import torch
from typing import Optional, Tuple
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import SequenceClassifierOutput


def _last_token_index(attention_mask: torch.Tensor) -> torch.Tensor:
    """Return, for each row, the index of the last non-zero entry of ``attention_mask``.

    Works for right padding, left padding, and masks with interior gaps.
    A row with no attended position yields all-zero scores, so ``argmax``
    falls back to index 0 (PyTorch returns the first maximal index), matching
    the earlier ``cumsum().argmax()`` formulation.

    Args:
        attention_mask: 2-D mask; dim 1 is the sequence dimension and non-zero
            entries mark attended positions.

    Returns:
        1-D ``long`` tensor with one index per row of ``attention_mask``.
    """
    seq_len = attention_mask.size(1)
    positions = torch.arange(1, seq_len + 1, device=attention_mask.device)
    # Weight each attended slot by its 1-based position: the maximum is then
    # unique for any row with at least one attended token, so the result does
    # not depend on how argmax breaks ties.
    scores = (attention_mask != 0).long() * positions
    return scores.argmax(dim=1)


class LlamaRewardModel(PreTrainedModel):
    """Scalar reward model: a ``LlamaModel`` backbone followed by a linear value head.

    The value head is applied to the final-layer hidden state of the last
    attended token of each sequence, producing one score per sequence.
    """

    config_class = LlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Projects a hidden_size vector down to a single scalar score.
        self.value_head = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Score each sequence in the batch.

        Args:
            input_ids: Token ids, presumably shaped (batch, seq_len) — TODO confirm
                against callers.
            attention_mask: Optional 2-D mask aligned with ``input_ids``; non-zero
                entries mark real tokens. Any padding side is supported. When
                ``None``, the hidden state at the final position is used.
            labels: Accepted for API compatibility only; it is ignored and no
                loss is computed.

        Returns:
            A one-element tuple ``(logits,)`` where ``logits`` has shape
            (batch, 1). (The method has always returned a tuple; the return
            annotation now reflects that.)
        """
        # Only the final hidden state is needed, so don't request all layers'
        # hidden states or attentions, and skip building a KV cache.
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
        # Index 0 is last_hidden_state for both ModelOutput and tuple returns,
        # so this works regardless of config.return_dict.
        hidden_states = outputs[0]

        if attention_mask is None:
            pooled = hidden_states[:, -1]
        else:
            # Move the index to the hidden states' device (they can differ
            # when the model is sharded across devices).
            last_index = _last_token_index(attention_mask).to(hidden_states.device)
            batch_index = torch.arange(hidden_states.size(0), device=hidden_states.device)
            pooled = hidden_states[batch_index, last_index]

        logits = self.value_head(pooled)
        return (logits,)