from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from torch import nn
import torch
from dataclasses import dataclass
from typing import Optional 
from configuration_reward_model import RewardConfig
from transformers import AutoModelForCausalLM


# adapted from https://github.com/Dahoas/reward-modeling

@dataclass
class RewardOutputs(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    rewards: torch.FloatTensor = None


class RewardModel(PreTrainedModel):
    config_class = RewardConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        base_model = AutoModelForCausalLM.from_pretrained(config.base_model)
        self.config = config
        self.neox = "neox" in self.config.model_type
        # gpt-neo/gpt-neox configs expose hidden_size instead of n_embd
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        # GPT-NeoX models keep their backbone under `gpt_neox`; the GPT-2/GPT-Neo
        # family exposes it as `transformer`
        self.transformer = base_model.gpt_neox if self.neox else base_model.transformer
        # resolve the value-head dtype from the config, defaulting to float32
        dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.float32
        dtype = torch.float16 if dtype == "float16" else torch.float32
        # scalar value head that maps the final hidden state to a reward
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False, dtype=dtype)
        self.PAD_ID = config.pad_id
        self.base_model = base_model

    def gradient_checkpointing_enable(self):
        self.base_model.gradient_checkpointing_enable()


    def forward(
        self,
        chosen_input_ids=None,
        rejected_input_ids=None,
        past_key_values=None,
        chosen_attention_mask=None,
        rejected_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        # concatenate along the batch dimension: the first half of the batch is
        # the chosen sequences, the second half the rejected ones
        input_ids = torch.cat([chosen_input_ids, rejected_input_ids], dim=0)
        attention_mask = torch.cat([chosen_attention_mask, rejected_attention_mask], dim=0)
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        rewards = self.v_head(hidden_states).squeeze(-1)

        bs = input_ids.shape[0] // 2

        # argmax returns the first index of the maximum value, so this finds the
        # position of the first pad/eos token in each row (assumes every sequence
        # contains at least one PAD_ID; otherwise the index defaults to 0)
        ends = torch.argmax((input_ids == self.PAD_ID).type(torch.float32), dim=1).view(-1, 1)
        # take the reward predicted just after the last real token of each sequence
        rewards = torch.gather(rewards, 1, ends)

        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # pairwise ranking loss: -log(sigmoid(r_chosen - r_rejected)), averaged
        # over the batch, pushes chosen rewards above rejected rewards
        loss = -torch.log(torch.sigmoid(chosen_rewards - rejected_rewards)).mean()

        return RewardOutputs(
            loss=loss,
            rewards=rewards,
        )
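

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): how the
# reward model is typically wired up for a pairwise comparison. This assumes
# RewardConfig accepts `base_model` and `pad_id` keyword arguments and carries
# the base model's hidden size (read via `hidden_size`/`n_embd` in __init__
# above); the "gpt2" checkpoint and the sequence length are placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    # hypothetical construction; adjust to however RewardConfig is defined
    config = RewardConfig(base_model="gpt2", pad_id=tokenizer.pad_token_id)
    model = RewardModel(config)

    chosen = tokenizer(
        ["Question: 2+2? Answer: 4"], padding="max_length", max_length=32, return_tensors="pt"
    )
    rejected = tokenizer(
        ["Question: 2+2? Answer: 5"], padding="max_length", max_length=32, return_tensors="pt"
    )

    out = model(
        chosen_input_ids=chosen["input_ids"],
        rejected_input_ids=rejected["input_ids"],
        chosen_attention_mask=chosen["attention_mask"],
        rejected_attention_mask=rejected["attention_mask"],
    )
    # out.rewards has shape (2 * batch_size, 1): chosen rewards first, then rejected
    print(out.loss.item(), out.rewards.shape)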