from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import (
    GPTNeoXConfig,
    GPTNeoXModel,
    GPTNeoXPreTrainedModel,
)
from transformers.utils import ModelOutput

from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR


class GPTNeoXRewardModelConfig(GPTNeoXConfig):
    model_type = "gpt_neox_reward_model"

    pooling: Literal["mean", "last"]

    def __init__(
        self,
        pooling: Literal["mean", "last"] = "last",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.pooling = pooling or "last"


@dataclass
class GPTNeoXRewardModelOutput(ModelOutput):
    """
    Reward model output.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Reward score
    """

    logits: torch.FloatTensor = None


class GPTNeoXRewardModel(GPTNeoXPreTrainedModel):
    config_class = GPTNeoXRewardModelConfig

    def __init__(self, config):
        if isinstance(config, GPTNeoXConfig):
            # Convert a plain GPT-NeoX config (e.g. when initializing from a vanilla
            # GPT-NeoX checkpoint) into the reward-model config so that the
            # `pooling` attribute is available; an existing GPTNeoXRewardModelConfig
            # is simply rebuilt with identical values.
            config = GPTNeoXRewardModelConfig.from_dict(config.to_dict())
        super().__init__(config)

        self.gpt_neox = GPTNeoXModel(config)
        self.out_proj = nn.Linear(config.hidden_size, 1)
        self.pooling = config.pooling

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> GPTNeoXRewardModelOutput:
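        """Runs the GPT-NeoX backbone and maps the pooled final hidden states to a
        single scalar reward per sequence."""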
        outputs = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
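        # Pool the per-token hidden states into a single vector per sequence: "mean"
        # averages the attended positions, "last" takes the hidden state of the
        # final attended (non-padded) token.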
        if self.pooling == "mean":
            if attention_mask is None:
                pooled = hidden_states.mean(dim=1)
            else:
                # The mask has to be broadcast over the hidden dimension for the
                # masked average to be well defined.
                mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
                pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        elif self.pooling == "last":
            if attention_mask is None:
                pooled = hidden_states[:, -1]
            else:
                # Index of the last attended (non-padded) token in each sequence.
                last_idx = attention_mask.cumsum(dim=1).argmax(dim=1)
                pooled = hidden_states.gather(
                    1, last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
                ).squeeze(1)
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")

        logits = self.out_proj(pooled)

        if not return_dict:
            return (logits,) + outputs[1:]

        return GPTNeoXRewardModelOutput(logits=logits)
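
# Illustrative sketch (comment only, not executed): once the classes above have been
# registered with the Auto factories (see `RewardModel.__init__` below), a checkpoint
# whose config declares model_type "gpt_neox_reward_model" can be scored directly.
# The example text is a placeholder; the checkpoint is one of those handled in
# `RewardModel.get_score`.
#
#   tokenizer = AutoTokenizer.from_pretrained(
#       "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1"
#   )
#   model = AutoModelForSequenceClassification.from_pretrained(
#       "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1"
#   )
#   batch = tokenizer(
#       "<|prompter|>Hi there<|endoftext|><|assistant|>Hello!<|endoftext|>",
#       return_tensors="pt",
#   )
#   reward = model(**batch).logits[0].item()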


class RewardModel(nn.Module):
    def __init__(self, cfg):
        super(RewardModel, self).__init__()
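
        # Register the custom GPT-NeoX reward head with the Auto factories so that
        # checkpoints whose config declares model_type "gpt_neox_reward_model" can
        # be loaded via AutoModelForSequenceClassification.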
        AutoConfig.register("gpt_neox_reward_model", GPTNeoXRewardModelConfig)
        AutoModelForSequenceClassification.register(
            GPTNeoXRewardModelConfig, GPTNeoXRewardModel
        )

        self.cfg = cfg
        self.model_name = cfg.reward_model
        self.device = cfg.environment._device
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            torch_dtype=(
                torch.float16
                if (torch.cuda.is_available() and len(cfg.environment.gpus) > 0)
                else torch.float32
            ),
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, model_max_length=2048
        )

    def get_score(
        self,
        prompts=None,
        answers=None,
    ):
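        """Scores prompt/answer pairs with the loaded reward model.

        Each prompt is expected to contain the conversation turns joined by
        `TEXT_SEPARATOR`; each answer is the candidate response to score.
        Returns one float score per pair.
        """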
        scores = []
        for prompt, answer in zip(prompts, answers):
            if "deberta-v3" in self.model_name:
                inputs = self.tokenizer(
                    " ".join(prompt.split(TEXT_SEPARATOR)),
                    answer,
                    return_tensors="pt",
                    max_length=2048,
                    truncation=True,
                ).to(self.device)
            elif self.model_name in [
                "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5",
                "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1",
            ]:
                prompt = prompt.split(TEXT_SEPARATOR)

                input_text = ""
                for i, prompt_part in enumerate(prompt[::-1]):
                    if i % 2 == 0:
                        prefix = "<|prompter|>"
                    else:
                        prefix = "<|assistant|>"
                    input_text = f"{prefix}{prompt_part}<|endoftext|>" + input_text

                input_text = input_text + f"<|assistant|>{answer}<|endoftext|>"

                inputs = self.tokenizer(
                    input_text, return_tensors="pt", max_length=2048, truncation=True
                ).to(self.device)
            else:
                raise ValueError(
                    f"Reward model {self.model_name} not supported for scoring."
                )

            scores.append(self.model(**inputs).logits[0].cpu().detach().item())
            del inputs
        return scores
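
# Illustrative usage sketch (assumes a populated llm_studio experiment config `cfg`
# with `cfg.reward_model` set to one of the checkpoints handled in `get_score`):
#
#   reward_model = RewardModel(cfg)
#   scores = reward_model.get_score(
#       prompts=["What is the capital of France?"],
#       answers=["The capital of France is Paris."],
#   )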