import logging
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as CausalLMCustomDataset,
)
from llm_studio.src.datasets.text_utils import TEXT_SEPARATOR

logger = logging.getLogger(__name__)


class CustomDataset(CausalLMCustomDataset):
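    """Dataset wrapper for RLHF training.

    Extends the causal language modeling dataset by attaching the chained
    prompt text for the reward model and by dropping the final ground
    truth answer from the encodings, which RLHF generates instead.
    """
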
    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        assert (
            cfg.dataset.system_column == "None"
        ), "RLHF is not compatible with system column."
        assert (
            cfg.dataset.limit_chained_samples is False
        ), "RLHF is not compatible with limit_chained_samples."
        assert (
            cfg.dataset.mask_prompt_labels is True
        ), "RLHF is not compatible with mask_prompt_labels."
        super().__init__(df, cfg, mode)

    def __getitem__(self, idx: int) -> Dict:
        """Reads a single text observation."""
        sample = super().__getitem__(idx)
        sample["reward_model_prompt_text"] = TEXT_SEPARATOR.join(
            self.get_chained_prompt_text_list(idx)
        )
        return sample

    def get_labels(self, prompt_encodings, answer_encodings):
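        """Returns labels only for validation; RLHF training needs none."""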
        if self.mode == "train":  # no labels required for RLHF during training
            return dict()
        else:
            return super().get_labels(prompt_encodings, answer_encodings)

    def get_encodings(self, input_text_dict):
        system_encoding, prompt_encodings, answer_encodings = super().get_encodings(
            input_text_dict
        )
        # remove last ground truth answer,
        # as RLHF will generate the answer from the prompt
        answer_encodings[-1] = torch.empty(0)
        return system_encoding, prompt_encodings, answer_encodings

    def postprocess_batch_predictions(self, output: Dict) -> Dict:
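        """Decodes generated answer ids to text and detaches them from the graph."""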
        if "predicted_answer_ids" in output.keys():
            predicted_text = [
                self.tokenizer.decode(ids, skip_special_tokens=True).strip()
                for ids in output["predicted_answer_ids"]
            ]

            output["predicted_text"] = np.array(predicted_text)
            output["predicted_answer_ids"] = output["predicted_answer_ids"].detach()
        return output

    def augment_data(self, encodings):
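        """No augmentation is applied to RLHF samples."""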
        return encodings

    def get_chained_prompt_text_list(self, idx: int) -> List[str]:
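        """Builds the full reward model prompt (system text, previous
        prompt/answer turns, and the current prompt) as a list split on
        TEXT_SEPARATOR."""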
        text_dict = self.conversation_chain_handler[idx]
        chat_history = "".join(
            [
                prompt + TEXT_SEPARATOR + answer + TEXT_SEPARATOR
                for prompt, answer in zip(
                    text_dict["prompts"][:-1], text_dict["answers"][:-1]
                )
            ]
        )
        prompt_text = text_dict["systems"][0] + chat_history + text_dict["prompts"][-1]
        return prompt_text.split(TEXT_SEPARATOR)