test_rm_8b / hf_utils.py
jdchang's picture
Update hf_utils.py
37325f3 verified
raw
history blame
11.4 kB
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import (
Any,
Optional,
Union,
)
import numpy as np
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput
@dataclass
class SequenceClassifierOutput(ModelOutput):
    """Output container returned by the reward model's forward pass.

    Every field is optional and defaults to ``None``.

    Attributes:
        loss: Loss reported by the wrapped language model (typically only
            populated when ``labels`` are supplied).
        scores: Raw outputs of the value head (before any SoftMax).
            NOTE(review): ``AutoModelForCausalLMWithRM.forward`` applies the
            head to the whole last hidden state and squeezes the trailing
            dimension, so this is presumably
            ``(batch_size, sequence_length)`` when ``n_labels == 1`` rather
            than one row per sequence -- confirm against callers.
        logits: Language-modeling head prediction scores, expected shape
            ``(batch_size, sequence_length, config.vocab_size)``, i.e. the
            score of every vocabulary token before SoftMax.
        past_key_values: Cached attention keys/values (returned when
            ``use_cache=True``): a tuple with one entry per layer, each
            holding two tensors of shape
            ``(batch_size, num_heads, sequence_length, embed_size_per_head)``.
            Can be fed back as ``past_key_values`` to speed up sequential
            decoding.
        hidden_states: Tuple of tensors of shape
            ``(batch_size, sequence_length, hidden_size)``: one for the
            embedding output (if the model has an embedding layer) plus one
            for the output of each layer.
        attentions: Tuple with one tensor per layer, of shape
            ``(batch_size, num_heads, sequence_length, sequence_length)``,
            holding the attention weights after the attention softmax
            (returned when ``output_attentions=True``).
    """

    # NOTE: the declaration order below is part of the contract; the base
    # output class exposes its fields positionally in this order.
    loss: Optional[torch.FloatTensor] = None
    scores: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
class ValueHead(nn.Module):
    """Small scoring head that maps hidden states to ``n_labels`` values.

    The forward pass is:
    dropout -> ``dense`` (hidden_size -> hidden_size) -> tanh -> dropout ->
    ``score`` (hidden_size -> n_labels).

    Args:
        n_labels: Number of values produced for each input vector.
        hidden_size: Size of the last dimension of the incoming hidden states.
        p_dropout: Dropout probability, shared by both dropout applications.
    """

    def __init__(self, n_labels: int, hidden_size: int, p_dropout: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p_dropout)
        self.score = nn.Linear(hidden_size, n_labels)

        # Only the output projection is re-initialised: zero-mean normal
        # weights with std 1/sqrt(hidden_size + 1) and an all-zero bias.
        # ``dense`` keeps the default ``nn.Linear`` initialisation.
        init_std = 1 / np.sqrt(hidden_size + 1)
        nn.init.normal_(self.score.weight, std=init_std)
        nn.init.constant_(self.score.bias, val=0.0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Score ``hidden_states``.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor with the same leading dimensions as ``hidden_states`` and a
            last dimension of size ``n_labels``.
        """
        projected = self.dense(self.dropout(hidden_states))
        activated = self.dropout(torch.tanh(projected))
        return self.score(activated)
class RewardModelConfig(PretrainedConfig):
    """Configuration for a reward model built on top of a base language model.

    The attributes of the base model's configuration are mirrored onto this
    object (see ``__init__``), after which the reward-model specific settings
    are stored.

    Args:
        base_model: Name or path of the base model. Used to fetch the base
            configuration when ``base_config`` is not given.
        base_config: Configuration of the base model. When ``None`` it is
            loaded with ``AutoConfig.from_pretrained(base_model)``. A plain
            ``dict`` is also accepted.
        p_dropout: Dropout probability stored for the value head.
        n_labels: Number of values the value head outputs.
        bias: Constant stored on the config as ``bias``.
        return_logits: Stored flag saying whether language-model logits
            should be returned.
        pretrain_cfg: Extra keyword arguments for loading the pretrained
            backbone; ``None`` is stored as an empty dict.
        pretrained: Stored flag saying whether the backbone should be loaded
            from pretrained weights.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = 'pairwise_rm'

    def __init__(
        self,
        base_model: Optional[Union[str, os.PathLike]
                            ] = 'meta-llama/Meta-Llama-3-70B-Instruct',
        base_config: Optional[PretrainedConfig] = None,
        p_dropout: float = 0.0,
        n_labels: int = 1,
        bias: float = 0.0,
        return_logits: bool = False,
        pretrain_cfg: Optional[dict[str, Any]] = None,
        pretrained: bool = False,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.base_model = base_model

        # Without an explicit base config, look it up from the base model.
        if base_config is None:
            base_config = AutoConfig.from_pretrained(base_model)
        self.base_config = base_config

        # Mirror the base config's entries onto this config. A copy is walked
        # so the stored ``base_config`` never shares values with ``self``.
        mirrored = deepcopy(self.base_config)
        if not isinstance(mirrored, dict):
            mirrored = mirrored.__dict__
        for attr_name, attr_value in mirrored.items():
            if attr_name in ['_name_or_path', 'architectures']:
                continue
            setattr(self, attr_name, attr_value)

        # Reward-model settings are assigned last, so they take precedence
        # over any same-named entry mirrored from the base config.
        self.p_dropout = p_dropout
        self.n_labels = n_labels
        self.bias = bias
        self.return_logits = return_logits
        self.pretrain_cfg = {} if pretrain_cfg is None else pretrain_cfg
        self.pretrained = pretrained
class AutoModelForCausalLMWithRM(PreTrainedModel):
    """Causal language model backbone with a value head producing scores.

    The backbone is stored as ``lm_backbone`` and the scoring head as
    ``value_head``. Embedding accessors and ``generate`` delegate to the
    backbone.
    """

    config_class = RewardModelConfig

    def __init__(self, config: RewardModelConfig):
        """Build the backbone and the value head from ``config``.

        Args:
            config: Reward model configuration. When ``config.pretrained`` is
                true the backbone weights are loaded from
                ``config.base_model`` using ``config.pretrain_cfg`` as extra
                keyword arguments; otherwise the backbone is instantiated
                from ``config.base_config`` without loading weights.
        """
        super().__init__(config)
        self.config = config
        pretrain_cfg = config.pretrain_cfg
        pretrained = config.pretrained
        if pretrained:
            self.lm_backbone = AutoModelForCausalLM.from_pretrained(
                config.base_model,
                config=config.base_config,
                **pretrain_cfg,
            )
        else:
            # HACK: ``base_config`` may be a plain dict here; rebuild a
            # config object from the base model with the dict entries passed
            # as keyword arguments.
            if isinstance(config.base_config, dict):
                config.base_config = AutoConfig.from_pretrained(
                    config.base_model,
                    **config.base_config,
                )
            self.lm_backbone = AutoModelForCausalLM.from_config(
                config.base_config,
                trust_remote_code=True,
            )
        self.value_head = ValueHead(
            n_labels=self.config.n_labels,
            hidden_size=self.config.hidden_size,
            p_dropout=self.config.p_dropout,
        )

    def generate(self, *args: Any, **kwargs: Any):
        """Delegate generation to the backbone.

        Positional arguments are forwarded as well, so calls such as
        ``model.generate(input_ids)`` reach the backbone intact.
        """
        return self.lm_backbone.generate(*args, **kwargs)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """Resize the token embeddings and keep ``base_config`` in sync.

        Args:
            new_num_tokens: Requested number of tokens; ``None`` leaves the
                size to the parent implementation.
            pad_to_multiple_of: Optional multiple the embedding size is
                padded to by the parent implementation.

        Returns:
            The value returned by the parent ``resize_token_embeddings``.
        """
        # The vocab size in the base config must be updated as well so the
        # lm_head modification happens. Never store ``None`` there when no
        # explicit size was requested.
        if new_num_tokens is not None:
            self.config.base_config.vocab_size = new_num_tokens
        model_embeds = super().resize_token_embeddings(
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
        )
        # With padding the final size can differ from ``new_num_tokens``;
        # mirror the size the parent recorded on this config (if any).
        if pad_to_multiple_of is not None:
            resized_vocab = getattr(self.config, 'vocab_size', None)
            if resized_vocab is not None:
                self.config.base_config.vocab_size = resized_vocab
        return model_embeds

    def set_input_embeddings(self, new_embeddings: Any):
        """Set the backbone's input embeddings."""
        return self.lm_backbone.set_input_embeddings(new_embeddings)

    def get_input_embeddings(self):
        """Return the backbone's input embeddings."""
        return self.lm_backbone.get_input_embeddings()

    def set_output_embeddings(self, new_embeddings: Any):
        """Set the backbone's output embeddings."""
        return self.lm_backbone.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self):
        """Return the backbone's output embeddings."""
        return self.lm_backbone.get_output_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Any] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ):
        """Run the backbone and score its last hidden state.

        ``output_hidden_states``, ``return_dict`` and ``**kwargs`` are
        accepted for interface compatibility but not forwarded: the backbone
        is always called with ``output_hidden_states=True`` and
        ``return_dict=True`` because the value head needs the hidden states.

        Returns:
            A ``SequenceClassifierOutput`` whose ``scores`` are the value
            head outputs on the last entry of the backbone's hidden states
            (``squeeze(-1)`` drops the trailing dimension when it has size
            one) minus ``config.bias``. ``logits`` is populated only when
            ``config.return_logits`` is true; the remaining fields are copied
            from the backbone output.
        """
        output = self.lm_backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
        )
        scores = self.value_head(
            output.hidden_states[-1],
        ).squeeze(-1) - self.config.bias
        logits = None
        if self.config.return_logits:
            logits = output.logits
        return SequenceClassifierOutput(
            loss=output.loss,
            scores=scores,
            logits=logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )

    @classmethod
    def from_config(
        cls,
        config: PretrainedConfig,
        **kwargs: Any,
    ) -> PreTrainedModel:
        """Instantiate the model from ``config`` via ``_from_config``."""
        return cls._from_config(config, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args: Any,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = 'main',
        use_safetensors: Optional[bool] = None,
        **kwargs: Any,
    ) -> PreTrainedModel:
        """Load a reward model, or wrap a plain causal LM checkpoint.

        The checkpoint's config is inspected first. If it already is a
        ``RewardModelConfig`` the standard ``PreTrainedModel.from_pretrained``
        path is used. Otherwise a new ``RewardModelConfig`` is built around
        the checkpoint's config and the model is constructed with
        ``pretrained=True`` so the backbone weights are loaded.

        Extra keyword arguments consumed here (popped from ``kwargs``):
            trust_remote_code (default ``True``), use_flash_attention_2
            (default ``False``), return_lm_logits (default ``False``),
            load_in_8bit (default ``False``).

        NOTE(review): ``config``, ``cache_dir``, ``ignore_mismatched_sizes``,
        ``force_download``, ``local_files_only``, ``token``, ``revision`` and
        ``use_safetensors`` are only honoured on the reward-model checkpoint
        path; config inspection and the wrapping path always use
        ``token=True``.
        """
        trust_remote_code = kwargs.pop('trust_remote_code', True)
        use_flash_attention_2 = kwargs.pop('use_flash_attention_2', False)
        return_lm_logits = kwargs.pop('return_lm_logits', False)
        load_in_8bit = kwargs.pop('load_in_8bit', False)
        requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
        pretrained_model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            token=True,
            attn_implementation=requested_attention_implementation,
            use_cache=False,
        )
        if isinstance(pretrained_model_config, cls.config_class):
            # These parameters are keyword-only on the parent; passing them
            # positionally would put them in ``*model_args`` and hand them
            # to ``__init__``, which only accepts ``config``.
            return super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        # Plain causal LM checkpoint: wrap it in a fresh reward-model config.
        pretrain_cfg = {
            'trust_remote_code': trust_remote_code,
            'token': True,
            'load_in_8bit': load_in_8bit,
        }
        reward_model_config = RewardModelConfig(
            base_model=pretrained_model_name_or_path,
            base_config=pretrained_model_config,
            hidden_size=pretrained_model_config.hidden_size,
            torch_dtype=pretrained_model_config.torch_dtype,
            return_logits=return_lm_logits,
            vocab_size=pretrained_model_config.vocab_size,
            pretrained=True,
            pretrain_cfg=pretrain_cfg,
        )
        model = cls(reward_model_config)
        return model