# gpt-j-reward-model / modeling_rewards.py
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from configuration_reward_model import RewardConfig

# adapted from https://github.com/Dahoas/reward-modeling


@dataclass
class RewardOutputs(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    rewards: torch.FloatTensor = None
class RewardModel(PreTrainedModel):
    config_class = RewardConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        base_model = AutoModelForCausalLM.from_pretrained(config.base_model)
        self.config = config
        self.neox = "neox" in self.config.model_type
        # gpt-neo(x) configs expose hidden_size instead of n_embd
        self.config.n_embd = (
            self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        )
        # gpt-neox checkpoints keep the decoder under `gpt_neox`; gpt-j/gpt-2 style
        # checkpoints expose it as `transformer`
        self.transformer = base_model.gpt_neox if self.neox else base_model.transformer
        # build the value head in the same dtype as the base model (float32 by default)
        dtype = getattr(self.config, "torch_dtype", None)
        dtype = torch.float16 if dtype in ("float16", torch.float16) else torch.float32
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False, dtype=dtype)
        self.PAD_ID = config.pad_id
        self.base_model = base_model

    def gradient_checkpointing_enable(self):
        self.base_model.gradient_checkpointing_enable()
    def forward(
        self,
        chosen_input_ids=None,
        rejected_input_ids=None,
        past_key_values=None,
        chosen_attention_mask=None,
        rejected_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        """Score a batch of chosen/rejected pairs and return the pairwise ranking loss.

        `chosen_*` and `rejected_*` must be padded to the same sequence length so the
        two halves can be concatenated along the batch dimension.
        """
        # concat chosen + rejected so the first half of the batch is chosen
        # and the second half is rejected
        input_ids = torch.cat([chosen_input_ids, rejected_input_ids], dim=0)
        attention_mask = torch.cat([chosen_attention_mask, rejected_attention_mask], dim=0)
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        # per-token scalar scores, shape (2 * bs, seq_len)
        rewards = self.v_head(hidden_states).squeeze(-1)
        bs = input_ids.shape[0] // 2
        # argmax returns the first index of the maximum value, so this finds the first
        # pad/eos token in each row; every sequence is assumed to contain at least one
        # PAD_ID token, otherwise the index falls back to position 0
        ends = torch.argmax((input_ids == self.PAD_ID).type(torch.float32), dim=1).view(-1, 1)
        # keep only the reward scored at that end-of-sequence position, shape (2 * bs, 1)
        rewards = torch.gather(rewards, 1, ends)
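        # Illustrative walk-through (hypothetical values) of the end-of-sequence indexing:
        # with PAD_ID = 0 and
        #     input_ids = [[5, 7, 9, 0, 0],
        #                  [3, 4, 0, 0, 0]]
        # (input_ids == PAD_ID) gives [[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]], argmax picks the
        # first 1 in each row, so ends = [[3], [2]], and gather keeps rewards[0, 3] and
        # rewards[1, 2] as the per-sequence scores.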
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]
        # pairwise logistic ranking loss: push the chosen score above the rejected score
        loss = -torch.log(torch.sigmoid(chosen_rewards - rejected_rewards)).mean()
        return RewardOutputs(
            loss=loss,
            rewards=rewards,
        )
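

if __name__ == "__main__":
    # Minimal usage sketch, not part of the uploaded model code. It assumes that
    # RewardConfig subclasses PretrainedConfig and forwards extra keyword arguments to
    # it, so that the attributes read above (base_model, model_type, pad_id, and
    # hidden_size/n_embd) all end up on the config; the checkpoint name and prompts
    # below are placeholders.
    from transformers import AutoConfig, AutoTokenizer

    base_name = "EleutherAI/gpt-neo-125m"  # small placeholder causal LM checkpoint
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    tokenizer.pad_token = tokenizer.eos_token  # gpt-style tokenizers ship without a pad token

    base_cfg = AutoConfig.from_pretrained(base_name)
    config = RewardConfig(
        base_model=base_name,
        model_type=base_cfg.model_type,
        pad_id=tokenizer.pad_token_id,
        hidden_size=base_cfg.hidden_size,
    )
    # Note: this builds a freshly initialized value head; it only demonstrates the
    # forward pass and loss shapes, not the trained reward model weights.
    model = RewardModel(config)
    model.eval()

    def encode(text):
        # pad to a fixed length so chosen/rejected can be concatenated and every row
        # contains at least one pad token (required by the argmax trick in forward)
        return tokenizer(text, return_tensors="pt", padding="max_length", max_length=32)

    chosen = encode("Question: What is 2 + 2? Answer: 4")
    rejected = encode("Question: What is 2 + 2? Answer: 22")

    with torch.no_grad():
        out = model(
            chosen_input_ids=chosen["input_ids"],
            chosen_attention_mask=chosen["attention_mask"],
            rejected_input_ids=rejected["input_ids"],
            rejected_attention_mask=rejected["attention_mask"],
        )
    print("loss:", out.loss.item(), "rewards:", out.rewards.squeeze(-1).tolist())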