from configuration_vgs import VGSConfig
from transformers import Qwen2PreTrainedModel, Qwen2Model
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.cache_utils import Cache
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass


@dataclass
class CustomSequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
    # Prob of reward being 1
    success_probs: Optional[torch.FloatTensor] = None


class VGSModel(Qwen2PreTrainedModel):
    config_class = VGSConfig

    def __init__(self, config):
        super().__init__(config)
        num_labels = config.num_labels
        self.model = Qwen2Model(config)
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels, bias=config.use_bias),
        )
        self.p_dropout = config.attention_dropout
        self.score_dropout = nn.Dropout(self.p_dropout)
        self.inference_impl = "naive"
        self.train_bt_model = False
        self.num_labels = num_labels

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_mask: Optional[torch.Tensor] = None,
        continuation_ids: Optional[torch.LongTensor] = None,
        continuation_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CustomSequenceClassifierOutputWithPast]:
        """

        During training:

        - labels should not be None and have shape: [bs, 1]

        - input_ids: [bs, seqlen]

        - loss_mask [bs, seqlen]



        During inference:

        labels, loss_mask should be None

        continuation_ids is [bs, N, c_len].

        If input_ids is [bs, seqlen], this is prefill stage.

        Otherwise, input_ids is also [bs, c_len] which contains the chosen continuation from last step. And we update the kv_cache.

        Here, attention_mask should be [bs, q_len] where q_len is seqlen + len of continuations so far.

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert return_dict, "Only return_dict=True is supported."
        is_training = labels is not None
        is_single_eval = continuation_ids is None
        if not is_training:
            assert not self.training, "Model should not be in training mode during inference."

        if is_training:
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]  # [bs, seqlen, hidden_dim]
            logits = self.score(self.score_dropout(hidden_states)).float()  # [bs, seqlen, num_labels]
            bs, seqlen, _ = logits.shape
            if self.train_bt_model:
                assert self.num_labels == 1, f"BT model should have 1 label. Got {self.num_labels}."
                assert bs % 2 == 0, f"Batch size should be even for BT model. Got {bs}."
                logits = logits[:, -1, 0]  # [bs, seqlen, 1] -> [bs]
                # bt loss
                assert torch.all(labels[::2] == 1), f"Labels should be 1 for chosen logits. Got {labels[::2]}."
                assert torch.all(labels[1::2] == 0), f"Labels should be 0 for rejected logits. Got {labels[1::2]}."
                chosen_logits = logits[::2]  # [bs//2]
                reject_logits = logits[1::2]  # [bs//2]
                elemwise_loss = -F.logsigmoid(chosen_logits - reject_logits)  # [bs//2]
                loss = elemwise_loss.mean()
            else:
                if self.num_labels == 1:
                    # BCE Loss
                    labels_expanded = labels.unsqueeze(-1).expand_as(logits)
                    elemwise_loss = F.binary_cross_entropy_with_logits(logits, labels_expanded, reduction="none")  # [bs, seqlen]
                else:
                    # CrossEntropyLoss
                    labels_expanded = labels.long().unsqueeze(-1).expand((bs, seqlen))  # [bs, seqlen]
                    elemwise_loss = F.cross_entropy(
                        logits.transpose(1, 2),  # [bs, seqlen, num_labels] -> [bs, num_labels, seqlen]
                        labels_expanded,  # [bs, seqlen]
                        reduction="none",
                    )
                # Average over seqlen and batch; guard against division by zero when a row's loss_mask is all zeros.
                mask_sum = loss_mask.sum(1).float()
                safe_denom = torch.where(mask_sum > 0, mask_sum, torch.ones_like(mask_sum))
                loss = torch.where(mask_sum > 0, (elemwise_loss * loss_mask).sum(1) / safe_denom, mask_sum)  # [bs]
                loss = loss.mean()

            return CustomSequenceClassifierOutputWithPast(loss=loss, logits=logits)

        elif is_single_eval:
            # single eval is also useful for updating kv_cache
            assert continuation_ids is None
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]  # [bs, seqlen, hidden_dim]
            logits = self.score(hidden_states).float()  # [bs, seqlen, num_labels]
            if logits.shape[-1] > 1:
                # assume 1 is the index/label for success
                success_probs = F.softmax(logits, dim=-1)[:, :, 1]  # [bs, seqlen]
            else:
                assert logits.shape[-1] == 1, f"Expected logits to have 1 output, got {logits.shape}."
                success_probs = logits.squeeze(-1).sigmoid()  # [bs, seqlen]

            return CustomSequenceClassifierOutputWithPast(
                logits=logits, success_probs=success_probs, past_key_values=transformer_outputs.past_key_values)
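

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): exercises the two forward() modes
# documented in the docstring above with toy tensors. The tiny config values
# and the use_bias/num_labels kwargs below are assumptions for this example;
# VGSConfig is expected to accept standard Qwen2Config keyword arguments.
if __name__ == "__main__":
    config = VGSConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_labels=2,    # index 1 is treated as the "success" label at inference
        use_bias=False,  # read by the score head in VGSModel.__init__
    )
    model = VGSModel(config)

    bs, seqlen = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (bs, seqlen))
    attention_mask = torch.ones(bs, seqlen, dtype=torch.long)

    # Training mode: per-sequence success labels plus a per-token loss_mask.
    # Labels are passed as a 1-D tensor here so the cross-entropy branch can
    # broadcast them over seqlen.
    labels = torch.tensor([1, 0])
    loss_mask = torch.ones(bs, seqlen)
    model.train()
    out = model(input_ids, attention_mask=attention_mask, labels=labels, loss_mask=loss_mask)
    print("train loss:", out.loss.item())

    # Inference mode (single eval / prefill): no labels; returns per-token
    # success probabilities plus the kv-cache, which can be passed back as
    # past_key_values when scoring a chosen continuation.
    model.eval()
    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask, use_cache=True)
    print("success_probs shape:", tuple(out.success_probs.shape))  # (bs, seqlen)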