import logging
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as CausalLMCustomDataset,
)
from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR

logger = logging.getLogger(__name__)


class CustomDataset(CausalLMCustomDataset):
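    """Dataset for RLHF, extending the causal language modeling dataset."""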

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        assert (
            cfg.dataset.system_column == "None"
        ), "RLHF is not compatible with system column."
        assert (
            cfg.dataset.limit_chained_samples is False
        ), "RLHF is not compatible with limit_chained_samples."
        assert (
            cfg.dataset.mask_prompt_labels is True
        ), "RLHF is not compatible with mask_prompt_labels."
        super().__init__(df, cfg, mode)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        sample = super().__getitem__(idx)
        # Join the full conversation chain so the reward model can score the
        # generated answer with its complete prompt context.
        sample["reward_model_prompt_text"] = TEXT_SEPARATOR.join(
            self.get_chained_prompt_text_list(idx)
        )
        return sample

    def get_labels(self, prompt_encodings, answer_encodings):
        # In RLHF training there are no ground-truth labels, as the answer is
        # generated during training; validation falls back to the default
        # causal LM labels.
        if self.mode == "train":
            return dict()
        else:
            return super().get_labels(prompt_encodings, answer_encodings)

    def get_encodings(self, input_text_dict):
        system_encoding, prompt_encodings, answer_encodings = super().get_encodings(
            input_text_dict
        )
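        # Remove the ground-truth answer of the last conversation turn:
        # during RLHF the policy model generates this answer from the prompt.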
        answer_encodings[-1] = torch.empty(0)
        return system_encoding, prompt_encodings, answer_encodings

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
        if "predicted_answer_ids" in output:
            # Decode generated token ids into plain text predictions.
            predicted_text = [
                self.tokenizer.decode(ids, skip_special_tokens=True).strip()
                for ids in output["predicted_answer_ids"]
            ]

            output["predicted_text"] = np.array(predicted_text)
            output["predicted_answer_ids"] = output["predicted_answer_ids"].detach()
        return output

    def augment_data(self, encodings):
        # No data augmentation is applied in the RLHF setting.
        return encodings

    def get_chained_prompt_text_list(self, idx: int) -> List[str]:
        """Collects the system prompt, all previous prompt/answer turns, and
        the final prompt, returned as a list of texts split on TEXT_SEPARATOR.
        """
        text_dict = self.conversation_chain_handler[idx]
        chat_history = "".join(
            [
                prompt + TEXT_SEPARATOR + answer + TEXT_SEPARATOR
                for prompt, answer in zip(
                    text_dict["prompts"][:-1], text_dict["answers"][:-1]
                )
            ]
        )
        prompt_text = text_dict["systems"][0] + chat_history + text_dict["prompts"][-1]
        return prompt_text.split(TEXT_SEPARATOR)