# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import defaultdict
from typing import Any, Dict, Tuple, Union

import pandas as pd
import torch
import torch.nn as nn
from accelerate.utils import gather_object
from transformers import PreTrainedModel
from trl import RewardTrainer as HFRewardTrainer
from trl.trainer.utils import print_rich_table

from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

# Drop trl's RewardTrainer.__init__ so that initialization falls through to the Swift mixins
# and the underlying transformers Trainer instead.
del HFRewardTrainer.__init__


class RewardTrainer(RLHFTrainerMixin, SwiftMixin, HFRewardTrainer):

    def compute_loss(self,
                     model: Union[PreTrainedModel, nn.Module],
                     inputs: Dict[str, Union[torch.Tensor, Any]],
                     return_outputs=False,
                     num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        inputs.pop('labels', None)  # not used
        margin = inputs.pop('margin', None)
        attention_mask = inputs['attention_mask']
        # The batch concatenates chosen and rejected sequences: first half chosen, second half rejected.
        batch_size = attention_mask.shape[0] // 2
        rewards = model(**inputs).logits
        rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0)
        # Pairwise (Bradley-Terry) loss; subtract the margin when the dataset provides one.
        if margin is not None:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
        else:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if self.args.center_rewards_coefficient is not None:
            # Regularize rewards towards zero mean to keep their scale bounded.
            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected)**2)
        # compat transformers>=4.46.*
        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss = loss / self.args.gradient_accumulation_steps
        if return_outputs:
            return loss, {
                'rewards_chosen': rewards_chosen,
                'rewards_rejected': rewards_rejected,
            }
        return loss

    def visualize_samples(self, num_print_samples: int):
        """
        Visualize the reward model's logits predictions.

        Args:
            num_print_samples (`int`):
                The number of samples to print. Set to `-1` to print all samples.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)
        for _, inputs in enumerate(eval_dataloader):
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            # Length of each sequence up to the first padding token.
            sequence_lengths = ((torch.eq(attention_mask, 0).int().argmax(-1) - 1) % attention_mask.shape[1]).tolist()
            text = [self.template.safe_decode(tokens[:sequence_lengths[i]]) for i, tokens in enumerate(input_ids)]
            # The batch is chosen sequences followed by rejected sequences.
            batch_size = input_ids.shape[0] // 2
            chosen_text, rejected_text = text[:batch_size], text[batch_size:]
            table['chosen_text'].extend(gather_object(chosen_text))
            table['rejected_text'].extend(gather_object(rejected_text))
            table['logits'].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]))
            if 0 <= num_print_samples <= len(table['chosen_text']):
                break
        df = pd.DataFrame(table)
        # Only the main process prints and logs the table.
        if self.accelerator.process_index == 0:
            print_rich_table(df[:num_print_samples])
            if 'wandb' in self.args.report_to:
                import wandb
                if wandb.run is not None:
                    wandb.log({'completions': wandb.Table(dataframe=df)})
            if 'swanlab' in self.args.report_to:
                import swanlab
                if swanlab.get_run() is not None:
                    swanlab_table = swanlab.echarts.Table()
                    swanlab_table.add(headers=df.columns.tolist(), rows=df.values.tolist())
                    swanlab.log({'completions': swanlab_table})