import logging
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as CausalLMCustomDataset,
)
from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR

logger = logging.getLogger(__name__)


class CustomDataset(CausalLMCustomDataset):
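    """Dataset used for RLHF training, extending the causal language modeling dataset."""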
    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
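        """Checks for dataset settings that are incompatible with RLHF before
        initializing the base causal LM dataset."""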
        assert (
            cfg.dataset.system_column == "None"
        ), "RLHF is not compatible with system column."
        assert (
            cfg.dataset.limit_chained_samples is False
        ), "RLHF is not compatible with limit_chained_samples."
        assert (
            cfg.dataset.mask_prompt_labels is True
        ), "RLHF is not compatible with mask_prompt_labels."

        super().__init__(df, cfg, mode)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        sample = super().__getitem__(idx)
        sample["reward_model_prompt_text"] = TEXT_SEPARATOR.join(
            self.get_chained_prompt_text_list(idx)
        )
        return sample

    def get_labels(self, prompt_encodings, answer_encodings):
        if self.mode == "train":  # no labels required for RLHF during training
            return dict()
        else:
            return super().get_labels(prompt_encodings, answer_encodings)

    def get_encodings(self, input_text_dict):
        system_encoding, prompt_encodings, answer_encodings = super().get_encodings(
            input_text_dict
        )
        # remove last ground truth answer,
        # as RLHF will generate the answer from the prompt
        answer_encodings[-1] = torch.empty(0)
        return system_encoding, prompt_encodings, answer_encodings

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
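        """Decodes predicted answer ids into text and detaches the id tensor."""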
if "predicted_answer_ids" in output.keys():
predicted_text = [
self.tokenizer.decode(ids, skip_special_tokens=True).strip()
for ids in output["predicted_answer_ids"]
]
output["predicted_text"] = np.array(predicted_text)
output["predicted_answer_ids"] = output["predicted_answer_ids"].detach()
return output

    def augment_data(self, encodings):
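        # No augmentation is applied for RLHF; encodings are returned unchanged.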
        return encodings

    def get_chained_prompt_text_list(self, idx: int) -> List[str]:
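        """Builds the full conversation for sample ``idx`` (system text, previous
        prompt/answer turns, and the final prompt) and returns it as a list of
        text chunks split on ``TEXT_SEPARATOR``."""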
        text_dict = self.conversation_chain_handler[idx]

        chat_history = "".join(
            [
                prompt + TEXT_SEPARATOR + answer + TEXT_SEPARATOR
                for prompt, answer in zip(
                    text_dict["prompts"][:-1], text_dict["answers"][:-1]
                )
            ]
        )

        prompt_text = text_dict["systems"][0] + chat_history + text_dict["prompts"][-1]
        return prompt_text.split(TEXT_SEPARATOR)