File size: 10,863 Bytes
c6e9c8c
 
 
1e05e68
 
 
eee0bf8
a024d9a
c6e9c8c
 
 
cd9d84b
c6e9c8c
 
1e05e68
 
c6e9c8c
eee0bf8
 
 
 
 
 
 
 
 
 
c6e9c8c
e6be0c8
cd9d84b
1e05e68
c6e9c8c
eee0bf8
 
5bbb99c
 
 
 
eee0bf8
 
 
5bbb99c
c6e9c8c
5bbb99c
cd9d84b
c6e9c8c
 
 
 
 
 
eee0bf8
 
c6e9c8c
 
1e05e68
 
 
 
 
5bbb99c
 
 
 
 
a024d9a
5bbb99c
 
 
1e05e68
eee0bf8
 
 
1e05e68
eee0bf8
 
 
 
 
1e05e68
eee0bf8
 
 
 
 
 
1e05e68
eee0bf8
 
 
 
 
 
1e05e68
5bbb99c
 
67f4e71
 
 
0762906
67f4e71
 
 
0762906
67f4e71
 
 
0762906
67f4e71
c6e9c8c
 
 
a024d9a
 
 
 
 
 
 
 
 
 
c6e9c8c
 
eee0bf8
 
1e05e68
eee0bf8
 
c6e9c8c
 
 
 
 
eee0bf8
 
 
 
c6e9c8c
 
eee0bf8
c6e9c8c
 
 
cd9d84b
c6e9c8c
 
 
5bbb99c
1e05e68
 
 
 
 
 
 
c6e9c8c
1e05e68
c6e9c8c
67f4e71
5bbb99c
1e05e68
 
 
 
 
 
 
 
 
cd9d84b
1e05e68
eee0bf8
 
 
1e05e68
c6e9c8c
 
 
 
 
 
 
 
e6be0c8
 
5bbb99c
e6be0c8
 
 
 
 
eee0bf8
 
 
 
e6be0c8
eee0bf8
 
 
 
cd9d84b
eee0bf8
 
 
cd9d84b
eee0bf8
e6be0c8
 
5bbb99c
 
 
1e05e68
eee0bf8
 
5bbb99c
 
 
 
 
 
cd9d84b
eee0bf8
1e05e68
5bbb99c
 
 
 
 
 
 
 
 
 
 
cd9d84b
5bbb99c
 
eee0bf8
5bbb99c
 
 
 
 
 
 
 
 
 
cd9d84b
5bbb99c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
from typing import List

from .card import TaskCard
from .dataclass import Field, InternalField, OptionalField
from .formats import Format, SystemFormat
from .logging_utils import get_logger
from .operator import SourceSequentialOperator, StreamingOperator
from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
from .recipe import Recipe
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .system_prompts import EmptySystemPrompt, SystemPrompt
from .templates import Template

logger = get_logger()


# The two subclasses below add no behavior; they exist only so that the
# recipe's step list shows descriptive, recipe-specific class names.
class CreateDemosPool(SeparateSplit):
    """A SeparateSplit, renamed for the step that carves the demos pool out of a split."""


class AddDemosField(SpreadSplit):
    """A SpreadSplit, renamed for the step that adds sampled demos from the pool to each instance."""


class BaseRecipe(Recipe, SourceSequentialOperator):
    """Assemble the standard processing pipeline for a TaskCard.

    ``prepare`` builds ``self.steps`` in this order:

    1. The card's loader.
    2. An ``AddFields`` step recording the recipe metadata.
    3. An optional loader-limit refiner.
    4. The card's preprocessing steps, then the card's task.
    5. The optional task-input augmentor.
    6. The optional demos-pool split.
    7. The per-split refiners, then the template.
    8. The optional demos step.
    9. The system prompt and the format.
    10. The optional model-input augmentor.
    11. A final ``ToUnitxtGroup``.

    ``verify`` checks that the demo and instance-limit settings are
    mutually consistent.
    """

    card: TaskCard
    # Required in practice: prepare() raises if it is still None.
    template: Template = None
    system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
    format: Format = Field(default_factory=SystemFormat)

    # Forwarded to the card's loader and also enforced by an extra StreamRefiner.
    loader_limit: int = None

    # Per-split caps, applied by the matching refiners in prepare_refiners().
    max_train_instances: int = None
    max_validation_instances: int = None
    max_test_instances: int = None

    train_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    validation_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
    test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

    # demos_pool_size must be a positive integer whenever num_demos > 0.
    demos_pool_size: int = None
    num_demos: int = 0
    demos_removed_from_data: bool = True

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    # When None and num_demos > 0, prepare() falls back to card.sampler.
    sampler: Sampler = None

    augmentor: Augmentor = OptionalField(default_factory=NullAugmentor)

    # Built by prepare(); not meant to be set by users.
    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def before_process_multi_stream(self):
        """Reseed the sampler's random generator before each multi-stream pass."""
        super().before_process_multi_stream()
        if self.sampler is not None:  # e.g. when num_demos is 0, the sampler may not be initialized
            self.sampler.init_new_random_generator()

    def verify(self):
        """Validate the demo settings and the per-split caps against loader_limit.

        Raises:
            ValueError: if num_demos is negative; if demos are requested
                without a positive demos_pool_size; if demos_pool_size is
                smaller than num_demos or larger than loader_limit; or if any
                max_*_instances exceeds loader_limit.
        """
        super().verify()
        # A negative value would otherwise silently behave like zero-shot.
        if self.num_demos < 0:
            raise ValueError(
                f"num_demos should be a non-negative integer (got: {self.num_demos})"
            )
        if self.num_demos > 0:
            if self.demos_pool_size is None or self.demos_pool_size < 1:
                raise ValueError(
                    "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
                )
            if self.demos_pool_size < self.num_demos:
                raise ValueError(
                    f"num_demos (got: {self.num_demos}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
                )
            if self.loader_limit and self.demos_pool_size > self.loader_limit:
                raise ValueError(
                    f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
                )

        if self.loader_limit:
            # Same check for every per-split cap, in the order test, validation,
            # train; unset (None/0) caps are skipped.
            for limit_name in (
                "max_test_instances",
                "max_validation_instances",
                "max_train_instances",
            ):
                limit_value = getattr(self, limit_name)
                if limit_value and limit_value > self.loader_limit:
                    raise ValueError(
                        f"{limit_name} should not exceed loader_limit ({self.loader_limit}), Got {limit_name}={limit_value}"
                    )

    def prepare_refiners(self):
        """Configure the train/validation/test refiners and append them to the steps.

        Each refiner is restricted to its own split and capped by the
        corresponding max_*_instances value.
        """
        self.train_refiner.max_instances = self.max_train_instances
        self.train_refiner.apply_to_streams = ["train"]
        self.steps.append(self.train_refiner)

        self.validation_refiner.max_instances = self.max_validation_instances
        self.validation_refiner.apply_to_streams = ["validation"]
        self.steps.append(self.validation_refiner)

        self.test_refiner.max_instances = self.max_test_instances
        self.test_refiner.apply_to_streams = ["test"]
        self.steps.append(self.test_refiner)

    def prepare(self):
        """Build ``self.steps``, the ordered pipeline for this recipe.

        Raises:
            ValueError: if no template is set, or if num_demos > 0 while
                neither this recipe nor the card provides a sampler.
        """
        # The pipeline ends by asking the template for its postprocessors.
        # Fail early with a clear message, before anything on the card is
        # mutated, instead of an opaque AttributeError on None.
        if self.template is None:
            raise ValueError(
                "A template must be provided to build the recipe (got template=None)."
            )

        self.steps = [
            self.card.loader,
            AddFields(
                fields={
                    "recipe_metadata": {
                        "card": self.card,
                        "template": self.template,
                        "system_prompt": self.system_prompt,
                        "format": self.format,
                    }
                }
            ),
        ]

        if self.loader_limit:
            # NOTE(review): this mutates the card's loader in place; if the same
            # card object is shared by several recipes, the limit carries over.
            self.card.loader.loader_limit = self.loader_limit
            logger.info(f"Loader line limit was set to  {self.loader_limit}")
            # Presumably a safeguard for loaders that ignore loader_limit -- confirm.
            self.steps.append(StreamRefiner(max_instances=self.loader_limit))

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.augmentor.augment_task_input:
            self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
            self.steps.append(self.augmentor)

        if self.demos_pool_size is not None:
            # Move demos_pool_size instances of demos_taken_from into a separate
            # demos_pool_name split, optionally removing them from the source.
            self.steps.append(
                CreateDemosPool(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                    remove_targets_from_source_split=self.demos_removed_from_data,
                )
            )

        if self.num_demos > 0:
            if self.sampler is None:
                if self.card.sampler is None:
                    raise ValueError(
                        "Unexpected None value for card.sampler. "
                        "To use num_demos > 0, please set a sampler on the TaskCard."
                    )
                # NOTE(review): this shares (and set_size below mutates) the
                # card's sampler object; recipes sharing one card also share it.
                self.sampler = self.card.sampler

            self.sampler.set_size(self.num_demos)

        self.prepare_refiners()

        self.steps.append(self.template)
        if self.num_demos > 0:
            self.steps.append(
                AddDemosField(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=self.sampler,
                )
            )
        self.steps.append(self.system_prompt)
        self.steps.append(self.format)
        if self.augmentor.augment_model_input:
            self.steps.append(self.augmentor)

        postprocessors = self.template.get_postprocessors()

        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=postprocessors,
            )
        )


class StandardRecipeWithIndexes(BaseRecipe):
    """BaseRecipe whose template may be selected from ``card.templates``.

    Exactly one of ``template`` and ``template_card_index`` must be given.
    ``template_card_index`` is an index into ``card.templates``, or a key
    when ``card.templates`` is a dict.
    """

    template_card_index: int = None

    def prepare(self):
        """Resolve ``template`` from ``template_card_index`` (if set), then build the steps.

        Raises:
            AssertionError: if both or neither of ``template`` and
                ``template_card_index`` are specified.
            ValueError: if ``template_card_index`` does not select a template
                from ``card.templates``.
        """
        # Explicit raises rather than ``assert``: the checks stay active under
        # ``python -O`` while keeping the AssertionError type callers expect.
        if self.template_card_index is not None and self.template is not None:
            raise AssertionError(
                f"Specify either template ({self.template}) or template_card_index ({self.template_card_index}) but not both"
            )
        if self.template_card_index is None and self.template is None:
            raise AssertionError(
                "Specify either template or template_card_index in card"
            )
        if self.template_card_index is not None:
            templates = self.card.templates
            try:
                self.template = templates[self.template_card_index]
            except Exception as e:
                # List the valid choices: keys for dict templates, positions for
                # sequence-like ones. A missing (None) collection has none, and
                # must not crash the handler and mask the ValueError below.
                if templates is None:
                    options = []
                elif isinstance(templates, dict):
                    options = list(templates.keys())
                else:
                    options = list(range(len(templates)))
                # NOTE(review): the message says "card_template_index" although
                # the attribute is "template_card_index"; kept verbatim in case
                # callers match on it.
                raise ValueError(
                    f"card_template_index '{self.template_card_index}' is not defined in card. Possible card_template_index options: {options}"
                ) from e

        super().prepare()


class StandardRecipe(StandardRecipeWithIndexes):
    """This class represents a standard recipe for data processing and preparation.

    It prepares a recipe with all necessary steps, refiners and renderers
    included, and lets the caller configure each of them through the
    attributes below. The steps are arranged into a sequential pipeline.

    Attributes:
        card (TaskCard): TaskCard object associated with the recipe.
        template (Template, optional): Template object to be used for the recipe.
            Exactly one of template and template_card_index must be given.
        template_card_index (int, optional): Index of the template in
            card.templates (a key when card.templates is a dict).
        system_prompt (SystemPrompt, optional): SystemPrompt object to be used
            for the recipe. Defaults to an EmptySystemPrompt.
        format (Format, optional): Format object to be used for the recipe.
            Defaults to a SystemFormat.
        loader_limit (int, optional): Specifies the maximum number of instances
            per stream to be returned from the loader (used to reduce loading
            time in large datasets).
        train_refiner (StreamRefiner, optional): Refiner applied to the "train" stream.
        max_train_instances (int, optional): Maximum training instances for the
            refiner. Must not exceed loader_limit.
        validation_refiner (StreamRefiner, optional): Refiner applied to the
            "validation" stream.
        max_validation_instances (int, optional): Maximum validation instances
            for the refiner. Must not exceed loader_limit.
        test_refiner (StreamRefiner, optional): Refiner applied to the "test" stream.
        max_test_instances (int, optional): Maximum test instances for the
            refiner. Must not exceed loader_limit.
        demos_pool_size (int, optional): Size of the demos pool. Required (a
            positive integer, at least num_demos and at most loader_limit)
            when num_demos > 0.
        num_demos (int, optional): Number of demos to be used. Default is 0.
        demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
        demos_taken_from (str, optional): Split from which the demos are taken.
            Default is "train".
        demos_field (str, optional): Field name for demos. Default is "demos".
        demos_removed_from_data (bool, optional): Whether to remove the demos
            from the source data. Default is True.
        sampler (Sampler, optional): Sampler object used to pick demos. If None
            and num_demos > 0, the card's sampler is used.
        augmentor (Augmentor, optional): Augmentor applied to the task inputs
            and/or the model input, according to its augment_task_input and
            augment_model_input flags. Defaults to a NullAugmentor.
        steps (List[StreamingOperator]): Internal list of StreamingOperator
            objects built by prepare(); not meant to be set directly.

    Methods:
        prepare(): This overridden method is used for preparing the recipe
            by arranging all the steps, refiners, and renderers in a sequential manner.

    Raises:
        AssertionError: If both template and template_card_index are specified,
            or if neither is.
        ValueError: If template_card_index is not defined in the card's
            templates, or if the demo / instance-limit settings are inconsistent.
    """