File size: 8,956 Bytes
da8868f
 
 
 
 
 
 
0537112
 
d754e91
 
40a8f4e
d754e91
da8868f
 
 
0537112
da8868f
 
 
 
d754e91
 
 
 
 
0537112
d754e91
0537112
 
 
 
 
 
40a8f4e
0537112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8868f
 
 
 
 
 
 
 
d754e91
 
 
da8868f
 
d754e91
c620e0b
 
 
 
d754e91
 
0537112
 
 
 
d754e91
0537112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d754e91
da8868f
c620e0b
 
 
 
 
 
d754e91
 
 
 
 
 
 
 
 
 
 
da8868f
 
 
 
 
 
 
d754e91
 
00263ef
 
 
 
 
116804a
00263ef
116804a
d754e91
 
 
 
 
4a6324a
d754e91
 
 
0537112
 
4a6324a
 
0537112
 
 
 
4a6324a
 
 
0537112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d754e91
 
 
 
 
 
 
 
 
 
 
 
4a6324a
 
 
 
 
0537112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
A dedicated helper to manage templates and prompt building.
From https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
"""

import importlib
import importlib.util
import itertools
import json
import os.path as osp
from typing import List, Union

from ..config import Config
from ..globals import Global


class Prompter(object):
    """
    Build prompts from a named template and extract responses from model output.

    Templates are resolved under ``<Config.data_dir>/templates/``:

    * ``"None"`` (or an empty name): no template; the prompt is passed
      through unchanged.
    * ``<name>`` or ``<name>.json``: a JSON template. It either declares
      ``"variables"`` together with a ``"default"`` prompt name and
      ``"prompt_with_..."`` variants, or uses the legacy
      ``"prompt_input"`` / ``"prompt_no_input"`` pair.
    * ``<name>.py``: a Python module that must define ``variables`` and is
      called through ``get_prompt(variables_dict)``. It may also define
      ``response_split``, ``get_train_data(item)`` and
      ``get_train_data_list_from_dataset(data)``.
    """
    __slots__ = ("template_name", "template", "template_module", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        """
        Load the template called *template_name*.

        Args:
            template_name: Template file name, relative to the templates
                directory. An empty name means "None" (no template). A name
                without an extension is treated as ``.json``.
            verbose: When true, print the chosen JSON template and every
                generated prompt.

        Raises:
            ValueError: if the template file does not exist, or a ``.py``
                template does not define ``variables``.
        """
        self._verbose = verbose
        # Initialise every slot up front. Otherwise the "None" template would
        # leave these unset and later attribute reads would raise
        # AttributeError.
        self.template = {}
        self.template_module = None

        if not template_name:
            template_name = "None"
        self.template_name = template_name
        if template_name == "None":
            return

        base_filename, ext = osp.splitext(template_name)
        # Templates given without an extension are assumed to be JSON.
        filename = base_filename + (ext or ".json")

        file_path = osp.join(Config.data_dir, "templates", filename)
        if not osp.exists(file_path):
            raise ValueError(f"Can't read {file_path}")

        if ext == ".py":
            self._load_python_template(file_path)
            return

        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(file_path, encoding="utf-8") as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {template_name}: {self.template.get('description', '')}"
            )

    def _load_python_template(self, file_path: str) -> None:
        """Execute the template module at *file_path* and adopt its settings."""
        # NOTE(review): this executes arbitrary code from the templates
        # directory. Only load templates from trusted sources.
        template_module_spec = importlib.util.spec_from_file_location(
            "template_module", file_path)
        template_module = importlib.util.module_from_spec(
            template_module_spec)
        template_module_spec.loader.exec_module(template_module)
        self.template_module = template_module

        if not hasattr(template_module, "variables"):
            raise ValueError(
                "The template module does not have a \"variables\" attribute.")

        self.template = {'variables': template_module.variables}
        if hasattr(template_module, "response_split"):
            self.template["response_split"] = template_module.response_split

    def generate_prompt(
        self,
        variables: Union[List[Union[None, str]], dict, None] = None,
        label: Union[None, str] = None,
    ) -> str:
        """
        Render the prompt for *variables* and optionally append *label*.

        Args:
            variables: Either a list of values in the order given by
                :meth:`get_variable_names`, or a dict keyed by variable name.
                Defaults to no variables.
            label: Text (e.g. the expected response). When truthy, it is
                appended verbatim to the rendered prompt.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: if a JSON template declares ``"variables"`` but its
                ``"default"`` prompt is missing or not defined.
        """
        if variables is None:
            # Fresh list per call rather than a shared mutable default.
            variables = []

        if self.template_name == "None":
            if isinstance(variables, list):
                res = get_val(variables, 0, "")
            else:
                res = variables.get("prompt", "")
        elif "variables" in self.template:
            variable_names = self.template.get("variables")
            if self.template_module:
                # Python templates receive a name -> value mapping.
                if isinstance(variables, list):
                    variables = dict(zip(variable_names, variables))
                res = self.template_module.get_prompt(variables)
            else:
                # JSON templates work on values ordered like variable_names.
                if isinstance(variables, dict):
                    variables = [variables.get(name, None)
                                 for name in variable_names]

                if "default" not in self.template:
                    raise ValueError(
                        f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
                default_prompt_name = self.template.get("default")
                if default_prompt_name not in self.template:
                    raise ValueError(
                        f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
                # Choose the variant named after the non-empty variables
                # (e.g. "prompt_with_instruction_input"). If the template
                # does not define it, fall back to the default variant.
                prompt_name = get_prompt_name(variables, variable_names)
                prompt_template = self.template.get(
                    prompt_name, self.template.get(default_prompt_name))

                res = prompt_template.format(
                    **variables_to_dict(variables, variable_names))
        else:
            # Legacy Alpaca-style template: instruction plus optional input.
            if isinstance(variables, dict):
                instruction = variables.get("instruction", "")
                input_text = variables.get("input")
            else:
                instruction = get_val(variables, 0, "")
                input_text = get_val(variables, 1)
            if input_text:
                res = self.template["prompt_input"].format(
                    instruction=instruction, input=input_text
                )
            else:
                res = self.template["prompt_no_input"].format(
                    instruction=instruction
                )

        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        """
        Extract the model's response from the full generated *output*.

        For the "None" template, *output* is returned unchanged. Otherwise,
        return everything after the first occurrence of the template's
        ``response_split`` separator, stripped of surrounding whitespace.
        If the separator does not occur at all, return an empty string.

        NOTE(review): a ``.py`` template without ``response_split`` makes
        this raise KeyError.
        """
        if self.template_name == "None":
            return output

        separator = self.template["response_split"]
        _, *after_separator = output.split(separator)
        return separator.join(after_separator).strip()

    def get_variable_names(self) -> List[str]:
        """Return the names of the variables this template accepts, in order."""
        if self.template_name == "None":
            return ["prompt"]
        elif "variables" in self.template:
            return self.template['variables']
        else:
            return ["instruction", "input"]

    def get_train_data_from_dataset(self, data, only_first_n_items=None):
        """
        Convert a loaded dataset into a list of training records.

        Python templates handle the conversion themselves. They may
        pre-process the whole dataset (``get_train_data_list_from_dataset``)
        and then map each item to zero or more records
        (``get_train_data``). All other templates normalise *data* with
        :func:`process_json_dataset`. They then emit one
        ``{"prompt", "completion", "_var_<name>"...}`` record per item.

        Args:
            data: The loaded dataset.
            only_first_n_items: If truthy, keep only this many items.
        """
        if self.template_module:
            if hasattr(self.template_module,
                       "get_train_data_list_from_dataset"):
                data = self.template_module.get_train_data_list_from_dataset(
                    data)
            if only_first_n_items:
                data = data[:only_first_n_items]
            # Each item may expand into several training records; flatten them.
            return list(itertools.chain.from_iterable(
                map(self.template_module.get_train_data, data)
            ))

        if only_first_n_items:
            data = data[:only_first_n_items]

        data = process_json_dataset(data)

        return [
            {
                'prompt': self.generate_prompt(d['variables']),
                'completion': d['output'],
                **{"_var_" + k: v for k, v in d['variables'].items()}
            }
            for d in data]


def get_val(arr, index, default=None):
    """
    Return ``arr[index]``, or *default* when *index* is out of range.

    Negative indices are accepted the same way sequence indexing accepts
    them (``-len(arr)`` through ``-1``).
    """
    size = len(arr)
    if index < -size or index >= size:
        return default
    return arr[index]


def get_prompt_name(variables, variable_names):
    """
    Build the template key for the set of non-empty variables.

    Pairs *variables* with *variable_names* positionally (stopping at the
    shorter of the two). It skips values that are ``None`` or ``''`` and
    returns ``"prompt_with_"`` followed by the remaining names joined with
    ``"_"``.
    """
    present_names = []
    for value, name in zip(variables, variable_names):
        if value in (None, ''):
            continue
        present_names.append(name)
    return "prompt_with_" + "_".join(present_names)


def variables_to_dict(variables, variable_names):
    """
    Map each name in *variable_names* to its positional value in *variables*.

    Missing positions and ``None`` values become ``''``, so the result can
    be passed straight to ``str.format``. Values beyond the last name are
    ignored.
    """
    mapping = {}
    value_count = len(variables)
    for position, name in enumerate(variable_names):
        value = variables[position] if position < value_count else None
        mapping[name] = '' if value is None else value
    return mapping


def _rename_completion_to_output(data):
    """Return a copy of *data* in which each item's "completion" key is renamed "output"."""
    return [
        {"output" if k == "completion" else k: v for k, v in d.items()}
        for d in data]


def process_json_dataset(data):
    """
    Normalise a loaded JSON dataset into ``{"variables": ..., "output": ...}`` items.

    The format is detected from the first item only, so the dataset is
    assumed to be uniform. Supported shapes:

    * OpenAI fine-tuning style items, which carry ``completion`` instead of
      ``output``;
    * Stanford Alpaca style items with an ``instances`` list. These are
      flattened into one item per instance: outer keys are copied into
      each instance, and instance keys win on conflict;
    * items that already have a ``variables`` key, which are kept as-is.

    Raises:
        ValueError: if *data* is not a non-empty list of objects, or its
            items have neither an "output" nor a "completion" field.
    """
    if not isinstance(data, list):
        raise ValueError("The dataset is not an array of objects.")

    first_item = data[0] if data else None

    if first_item is None:
        raise ValueError("The dataset is empty.")
    if not isinstance(first_item, dict):
        raise ValueError("The dataset is not an array of objects.")

    # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
    if "completion" in first_item and "output" not in first_item:
        data = _rename_completion_to_output(data)
        first_item = data[0]

    # Flatten Stanford Alpaca style instances
    if "instances" in first_item and isinstance(first_item["instances"], list):
        data = _rename_completion_to_output(data)
        flattened_data = []
        for item in data:
            shared = {k: v for k, v in item.items() if k != "instances"}
            for instance in item["instances"]:
                d = dict(shared)
                d.update(instance)
                flattened_data.append(d)
        data = flattened_data
        # Without this guard, empty instance lists would leave first_item as
        # None, and the "output" check below would fail with TypeError.
        if not data:
            raise ValueError("The dataset is empty.")
        first_item = data[0]
        # The instances themselves may use OpenAI-style "completion" keys.
        if "completion" in first_item and "output" not in first_item:
            data = _rename_completion_to_output(data)
            first_item = data[0]

    if "output" not in first_item:
        raise ValueError(
            "The data does not contains an \"output\" or \"completion\".")

    # Put all variables under the "variables" key if it does not exist
    if "variables" not in first_item:
        data = [
            {
                "variables": {k: v for k, v in d.items() if k != "output"},
                "output": d["output"],
            }
            for d in data
        ]
    return data


def get_val_from_arr(arr, index, default=None):
    """
    Safely index *arr*: return ``arr[index]`` if *index* is valid, else *default*.

    Valid indices run from ``-len(arr)`` to ``len(arr) - 1`` inclusive.
    """
    length = len(arr)
    index_is_valid = -length <= index < length
    if not index_is_valid:
        return default
    return arr[index]