from configuration_vgs import VGSConfig

from transformers import Qwen2PreTrainedModel, Qwen2Model
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.cache_utils import Cache

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional, Tuple, Union
from dataclasses import dataclass


@dataclass
class CustomSequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
    success_probs: Optional[torch.FloatTensor] = None


class VGSModel(Qwen2PreTrainedModel):
    config_class = VGSConfig

    def __init__(self, config):
        super().__init__(config)
        num_labels = config.num_labels
        self.model = Qwen2Model(config)
        # Token-level score head: hidden_size -> hidden_size -> num_labels.
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels, bias=config.use_bias),
        )
        # Dropout applied to hidden states before the score head; reuses config.attention_dropout.
        self.p_dropout = config.attention_dropout
        self.score_dropout = nn.Dropout(self.p_dropout)
        self.inference_impl = "naive"
        self.train_bt_model = False
        self.num_labels = num_labels

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_mask: Optional[torch.Tensor] = None,
        continuation_ids: Optional[torch.LongTensor] = None,
        continuation_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CustomSequenceClassifierOutputWithPast]:
"""
|
|
During training:
|
|
- labels should not be None and have shape: [bs, 1]
|
|
- input_ids: [bs, seqlen]
|
|
- loss_mask [bs, seqlen]
|
|
|
|
During inference:
|
|
labels, loss_mask should be None
|
|
continuation_ids is [bs, N, c_len].
|
|
If input_ids is [bs, seqlen], this is prefill stage.
|
|
Otherwise, input_ids is also [bs, c_len] which contains the chosen continuation from last step. And we update the kv_cache.
|
|
Here, attention_mask should be [bs, q_len] where q_len is seqlen + len of continuations so far.
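
        Example (an illustrative sketch of the call patterns above; the tensor values and
        the names `model`, `bs`, `seqlen` are placeholder assumptions, not values produced
        by this module):

            bs, seqlen = 2, 16
            # training: per-token classification with a loss mask
            out = model(
                input_ids=torch.randint(0, 1000, (bs, seqlen)),
                attention_mask=torch.ones(bs, seqlen, dtype=torch.long),
                labels=torch.ones(bs, 1),
                loss_mask=torch.ones(bs, seqlen),
            )
            out.loss.backward()

            # inference on a single sequence (continuation_ids=None): per-token success probs
            model.eval()
            out = model(
                input_ids=torch.randint(0, 1000, (bs, seqlen)),
                attention_mask=torch.ones(bs, seqlen, dtype=torch.long),
            )
            probs = out.success_probs  # [bs, seqlen]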
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert return_dict, "Only return_dict=True is supported."
        is_training = labels is not None
        is_single_eval = continuation_ids is None
        if not is_training:
            assert not self.training, "Model should not be in training mode during inference."

        if is_training:
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]
            logits = self.score(self.score_dropout(hidden_states)).float()
            bs, seqlen, _ = logits.shape
            if self.train_bt_model:
                # Bradley-Terry pairwise loss on the last-token score: chosen and rejected
                # examples are interleaved along the batch dimension, and the loss is
                # -log sigmoid(score_chosen - score_rejected).
                assert self.num_labels == 1, f"BT model should have 1 label. Got {self.num_labels}."
                assert bs % 2 == 0, f"Batch size should be even for BT model. Got {bs}."
                logits = logits[:, -1, 0]

                assert torch.all(labels[::2] == 1), f"Labels should be 1 for chosen logits. Got {labels[::2]}."
                assert torch.all(labels[1::2] == 0), f"Labels should be 0 for rejected logits. Got {labels[1::2]}."
                chosen_logits = logits[::2]
                reject_logits = logits[1::2]
                elemwise_loss = -F.logsigmoid(chosen_logits - reject_logits)
                loss = elemwise_loss.mean()
            else:
                if self.num_labels == 1:
                    # Per-token binary classification: broadcast the [bs, 1] labels over the
                    # sequence, then drop the trailing label dim so the element-wise loss is
                    # [bs, seqlen] and lines up with loss_mask.
                    labels_expanded = labels.float().reshape(bs, 1, 1).expand_as(logits)
                    elemwise_loss = F.binary_cross_entropy_with_logits(
                        logits, labels_expanded, reduction="none"
                    ).squeeze(-1)
                else:
                    # Per-token multi-class classification: expand the [bs, 1] labels to [bs, seqlen].
                    labels_expanded = labels.long().reshape(bs, 1).expand(bs, seqlen)
                    elemwise_loss = F.cross_entropy(
                        logits.transpose(1, 2),
                        labels_expanded,
                        reduction="none",
                    )

                # Average the per-token loss over unmasked positions; rows whose loss_mask
                # is all zeros contribute a loss of 0.
                mask_sum = loss_mask.sum(1).float()
                safe_denom = torch.where(mask_sum > 0, mask_sum, torch.ones_like(mask_sum))
                loss = torch.where(mask_sum > 0, (elemwise_loss * loss_mask).sum(1) / safe_denom, mask_sum)
                loss = loss.mean()

            return CustomSequenceClassifierOutputWithPast(loss=loss, logits=logits)

        elif is_single_eval:
            # Inference on a single sequence per example (no candidate continuations given).
            assert continuation_ids is None
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]
            logits = self.score(hidden_states).float()
            if logits.shape[-1] > 1:
                # Multi-class head: success probability is the softmax mass on class 1.
                success_probs = F.softmax(logits, dim=-1)[:, :, 1]
            else:
                assert logits.shape[-1] == 1, f"Expected logits to have 1 output, got {logits.shape}."
                success_probs = logits.squeeze(-1).sigmoid()

            return CustomSequenceClassifierOutputWithPast(
                logits=logits, success_probs=success_probs, past_key_values=transformer_outputs.past_key_values)