# adapted from https://github.com/Dahoas/reward-modeling
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from configuration_reward_model import RewardConfig


@dataclass
class RewardOutputs(ModelOutput):
    """Outputs of the reward model: the pairwise ranking loss and per-sequence rewards."""

    loss: Optional[torch.FloatTensor] = None
    rewards: torch.FloatTensor = None


class RewardModel(PreTrainedModel):
    config_class = RewardConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        base_model = AutoModelForCausalLM.from_pretrained(config.base_model)
        self.config = config
        self.neox = "neox" in self.config.model_type
        # gpt-neo(x) models expose hidden_size instead of n_embd
        self.config.n_embd = (
            self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        )
        # gpt-neox exposes its transformer stack as `gpt_neox`, gpt2-style models as `transformer`
        self.transformer = base_model.gpt_neox if self.neox else base_model.transformer
        dtype = self.config.torch_dtype if getattr(self.config, "torch_dtype", None) is not None else torch.float32
        dtype = torch.float16 if dtype == "float16" else torch.float32
        # scalar value head mapping each hidden state to a reward
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False, dtype=dtype)
        self.PAD_ID = config.pad_id
        self.base_model = base_model

    def gradient_checkpointing_enable(self):
        self.base_model.gradient_checkpointing_enable()

    def forward(
        self,
        chosen_input_ids=None,
        rejected_input_ids=None,
        past_key_values=None,
        chosen_attention_mask=None,
        rejected_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        # concat chosen + rejected so the first half of the batch is chosen and the second half is rejected
        input_ids = torch.cat([chosen_input_ids, rejected_input_ids], dim=0)
        attention_mask = torch.cat([chosen_attention_mask, rejected_attention_mask], dim=0)

        transformer_kwargs = dict(
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        # gpt-neox does not accept token_type_ids
        if not self.neox:
            transformer_kwargs["token_type_ids"] = token_type_ids
        transformer_outputs = self.transformer(input_ids, **transformer_kwargs)

        hidden_states = transformer_outputs[0]
        rewards = self.v_head(hidden_states).squeeze(-1)
        bs = input_ids.shape[0] // 2

        # argmax returns the first index of the maximum value,
        # so this finds the first pad/eos token in each row
        ends = torch.argmax((input_ids == self.PAD_ID).type(torch.float32), dim=1).view(-1, 1)
        rewards = torch.gather(rewards, 1, ends)
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # pairwise ranking loss: the chosen response should score higher than the rejected one
        loss = -torch.log(torch.sigmoid(chosen_rewards - rejected_rewards)).mean()

        return RewardOutputs(
            loss=loss,
            rewards=rewards,
        )
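

# --- Usage sketch (illustration only, not part of the adapted code) ---
# A minimal example of how this RewardModel might be driven, assuming
# RewardConfig forwards extra keyword arguments to PretrainedConfig so that
# `base_model`, `pad_id`, and `n_embd` become config attributes. The "gpt2"
# base model and the toy prompts below are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    config = RewardConfig(base_model="gpt2", pad_id=tokenizer.pad_token_id, n_embd=768)
    model = RewardModel(config)

    chosen = tokenizer(["Q: 2+2? A: 4"], return_tensors="pt", padding="max_length", max_length=32)
    rejected = tokenizer(["Q: 2+2? A: 5"], return_tensors="pt", padding="max_length", max_length=32)

    out = model(
        chosen_input_ids=chosen["input_ids"],
        rejected_input_ids=rejected["input_ids"],
        chosen_attention_mask=chosen["attention_mask"],
        rejected_attention_mask=rejected["attention_mask"],
    )
    # loss is a scalar; rewards holds one score per sequence (chosen first, then rejected)
    print(out.loss.item(), out.rewards.squeeze(-1))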