from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from configuration_reward_model import RewardConfig


@dataclass
class RewardOutputs(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    rewards: torch.FloatTensor = None


class RewardModel(PreTrainedModel):
    config_class = RewardConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        base_model = AutoModelForCausalLM.from_pretrained(config.base_model)
        self.config = config
        self.neox = "neox" in self.config.model_type

        # GPT-NeoX-style configs expose `hidden_size`; GPT-2/GPT-J-style configs expose `n_embd`.
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        # The transformer backbone also lives under a different attribute for NeoX models.
        self.transformer = base_model.gpt_neox if self.neox else base_model.transformer
        dtype = getattr(self.config, "torch_dtype", None)
        dtype = torch.float16 if dtype in ("float16", torch.float16) else torch.float32
        # Scalar value head mapping the final hidden state to a single reward.
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False, dtype=dtype)
        self.PAD_ID = config.pad_id
        self.base_model = base_model

    def gradient_checkpointing_enable(self):
        self.base_model.gradient_checkpointing_enable()

    def forward(
        self,
        chosen_input_ids=None,
        rejected_input_ids=None,
        past_key_values=None,
        chosen_attention_mask=None,
        rejected_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        # Score chosen and rejected sequences in a single forward pass:
        # the first half of the batch is chosen, the second half is rejected.
        input_ids = torch.cat([chosen_input_ids, rejected_input_ids], dim=0)
        attention_mask = torch.cat([chosen_attention_mask, rejected_attention_mask], dim=0)
        if self.neox:
            # GPT-NeoX backbones do not accept `token_type_ids`.
            transformer_outputs = self.transformer(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
        else:
            transformer_outputs = self.transformer(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

        hidden_states = transformer_outputs[0]

        # One scalar reward per token position.
        rewards = self.v_head(hidden_states).squeeze(-1)

        bs = input_ids.shape[0] // 2

        # Score each sequence at its first PAD position, i.e. just past the last real token.
        # Note: argmax returns 0 if a sequence contains no PAD token, so inputs are expected
        # to be padded to max_length with at least one trailing PAD.
        ends = torch.argmax((input_ids == self.PAD_ID).type(torch.float32), dim=1).view(-1, 1)
        rewards = torch.gather(rewards, 1, ends)

        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # Pairwise ranking loss: -log(sigmoid(r_chosen - r_rejected)), written in its
        # numerically stable log-sigmoid form.
        loss = -nn.functional.logsigmoid(chosen_rewards - rejected_rewards).mean()

        return RewardOutputs(
            loss=loss,
            rewards=rewards,
        )
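

# --- Usage sketch (illustrative; kept as comments because RewardConfig's exact fields are
# defined in configuration_reward_model.py and only assumed here) ---
# Assuming `RewardConfig` carries at least `base_model`, `pad_id`, `model_type`, and the
# backbone's size fields read in `__init__`, scoring one chosen/rejected pair would look like:
#
#     from transformers import AutoTokenizer
#
#     config = RewardConfig.from_pretrained("path/to/reward_config")  # hypothetical path
#     model = RewardModel(config)
#     tokenizer = AutoTokenizer.from_pretrained(config.base_model)
#     tokenizer.pad_token_id = config.pad_id  # padding must use the same PAD_ID the model expects
#
#     chosen = tokenizer("prompt + preferred answer", return_tensors="pt",
#                        padding="max_length", max_length=64)
#     rejected = tokenizer("prompt + worse answer", return_tensors="pt",
#                          padding="max_length", max_length=64)
#     out = model(
#         chosen_input_ids=chosen.input_ids,
#         chosen_attention_mask=chosen.attention_mask,
#         rejected_input_ids=rejected.input_ids,
#         rejected_attention_mask=rejected.attention_mask,
#     )
#     out.loss.backward()   # pairwise ranking loss
#     print(out.rewards)    # one scalar reward per sequence (chosen first, then rejected)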