# initial_test/modeling_vgs.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen2Model, Qwen2PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

from configuration_vgs import VGSConfig


@dataclass
class CustomSequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
    # Per-position probability that the reward equals 1 (i.e. P(success)).
    success_probs: Optional[torch.FloatTensor] = None


class VGSModel(Qwen2PreTrainedModel):
    config_class = VGSConfig

    def __init__(self, config):
super().__init__(config)
num_labels = config.num_labels
self.model = Qwen2Model(config)
self.score = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias),
nn.ReLU(),
nn.Linear(config.hidden_size, num_labels, bias=config.use_bias),
)
        # Dropout applied to hidden states before the score head; reuses the
        # attention_dropout value from the config.
        self.p_dropout = config.attention_dropout
        self.score_dropout = nn.Dropout(self.p_dropout)
        self.inference_impl = "naive"
        # When True, train with a pairwise Bradley-Terry loss on (chosen, rejected) pairs.
        self.train_bt_model = False
self.num_labels = num_labels
# Initialize weights and apply final processing
        self.post_init()

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
loss_mask: Optional[torch.Tensor] = None,
continuation_ids: Optional[torch.LongTensor] = None,
continuation_attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, CustomSequenceClassifierOutputWithPast]:
"""
During training:
- labels should not be None and have shape: [bs, 1]
- input_ids: [bs, seqlen]
- loss_mask [bs, seqlen]
During inference:
labels, loss_mask should be None
continuation_ids is [bs, N, c_len].
If input_ids is [bs, seqlen], this is prefill stage.
Otherwise, input_ids is also [bs, c_len] which contains the chosen continuation from last step. And we update the kv_cache.
Here, attention_mask should be [bs, q_len] where q_len is seqlen + len of continuations so far.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
assert return_dict, "Only return_dict=True is supported."
is_training = labels is not None
is_single_eval = continuation_ids is None
        if not is_training:
            assert not self.training, "Model should not be in training mode during inference."
if is_training:
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0] # [bs, seqlen, hidden_dim]
logits = self.score(self.score_dropout(hidden_states)).float() # [bs, seqlen, num_labels]
bs, seqlen, _ = logits.shape
if self.train_bt_model:
assert self.num_labels == 1, f"BT model should have 1 label. Got {self.num_labels}."
assert bs % 2 == 0, f"Batch size should be even for BT model. Got {bs}."
logits = logits[:, -1, 0] # [bs, seqlen, 1] -> [bs]
                # Bradley-Terry loss on interleaved (chosen, rejected) pairs:
                # even rows are chosen, odd rows are rejected.
                assert torch.all(labels[::2] == 1), f"Labels should be 1 for chosen logits. Got {labels[::2]}."
                assert torch.all(labels[1::2] == 0), f"Labels should be 0 for rejected logits. Got {labels[1::2]}."
                chosen_logits = logits[::2]  # [bs//2]
                reject_logits = logits[1::2]  # [bs//2]
                elemwise_loss = -F.logsigmoid(chosen_logits - reject_logits)  # [bs//2]
loss = elemwise_loss.mean()
else:
if self.num_labels == 1:
                    # BCE loss: targets must be float; broadcast labels [bs, 1]
                    # across the sequence dimension to match the per-token logits.
                    labels_expanded = labels.float().reshape(bs, 1).expand(bs, seqlen)  # [bs, seqlen]
                    elemwise_loss = F.binary_cross_entropy_with_logits(
                        logits.squeeze(-1), labels_expanded, reduction="none"
                    )  # [bs, seqlen]
else:
                    # Cross-entropy loss: broadcast labels [bs, 1] across the sequence dimension.
                    labels_expanded = labels.long().reshape(bs, 1).expand(bs, seqlen)  # [bs, seqlen]
elemwise_loss = F.cross_entropy(
logits.transpose(1, 2), # [bs, seqlen, num_labels] -> [bs, num_labels, seqlen]
labels_expanded, # [bs, seqlen]
reduction="none",
)
                # Average over seqlen, then over the batch; guard against
                # division by zero when a row's loss_mask is all zeros.
                mask_sum = loss_mask.sum(1).float()
                safe_denom = torch.where(mask_sum > 0, mask_sum, torch.ones_like(mask_sum))
                # Rows with an empty mask fall back to mask_sum, which is 0 there.
                loss = torch.where(mask_sum > 0, (elemwise_loss * loss_mask).sum(1) / safe_denom, mask_sum)  # [bs]
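                # Worked example (illustration only): if loss_mask.sum(1) == [3., 0.],
                # then safe_denom == [3., 1.]; the first row averages its 3 unmasked
                # token losses, while the fully-masked second row falls back to
                # mask_sum's 0., contributing 0 to the mean instead of NaN.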
loss = loss.mean()
return CustomSequenceClassifierOutputWithPast(loss=loss, logits=logits)
elif is_single_eval:
# single eval is also useful for updating kv_cache
assert continuation_ids is None
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0] # [bs, seqlen, hidden_dim]
logits = self.score(hidden_states).float() # [bs, seqlen, num_labels]
            if logits.shape[-1] > 1:
                # Multi-class head: assume index 1 is the success class.
                success_probs = F.softmax(logits, dim=-1)[:, :, 1]  # [bs, seqlen]
            else:
                # Scalar head: the sigmoid of the logit is the success probability.
                assert logits.shape[-1] == 1, f"Expected logits to have 1 output, got {logits.shape}."
                success_probs = logits.squeeze(-1).sigmoid()  # [bs, seqlen]
            return CustomSequenceClassifierOutputWithPast(
                logits=logits,
                success_probs=success_probs,
                past_key_values=transformer_outputs.past_key_values,
            )
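

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the uploaded checkpoint).
# Assumes VGSConfig subclasses Qwen2Config and accepts the standard Qwen2
# fields below plus `use_bias` (consumed by the score head); the call shapes
# follow the forward() docstring above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = VGSConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_labels=2,
        use_bias=False,
    )
    model = VGSModel(config)

    bs, seqlen = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (bs, seqlen))
    attention_mask = torch.ones(bs, seqlen, dtype=torch.long)

    # Training-style call: per-token classification with a loss mask.
    labels = torch.tensor([[1], [0]])   # [bs, 1]
    loss_mask = torch.ones(bs, seqlen)  # [bs, seqlen]
    model.train()
    out = model(input_ids, attention_mask=attention_mask, labels=labels, loss_mask=loss_mask)
    print("train loss:", out.loss.item())

    # Inference-style call (single eval / prefill): per-token success
    # probabilities plus the kv cache for later continuation steps.
    model.eval()
    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask, use_cache=True)
    print("success_probs shape:", tuple(out.success_probs.shape))  # (bs, seqlen)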