File size: 6,994 Bytes
c6e9c8c
 
 
5bbb99c
c6e9c8c
 
 
5bbb99c
c6e9c8c
 
 
 
 
 
 
5bbb99c
c6e9c8c
e6be0c8
c6e9c8c
e6be0c8
c6e9c8c
5bbb99c
 
 
 
 
 
 
 
 
 
c6e9c8c
5bbb99c
c6e9c8c
 
 
 
 
 
 
 
5bbb99c
 
 
 
 
 
 
 
 
 
 
 
c6e9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bbb99c
c6e9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bbb99c
 
 
 
 
 
 
 
 
c6e9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6be0c8
 
5bbb99c
e6be0c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bbb99c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from typing import List

from .card import TaskCard
from .dataclass import InternalField, OptionalField
from .formats import ICLFormat
from .instructions import Instruction
from .operator import SourceSequntialOperator, StreamingOperator
from .operators import StreamRefiner
from .recipe import Recipe
from .renderers import StandardRenderer
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .templates import Template


class BaseRecipe(Recipe, SourceSequntialOperator):
    card: TaskCard
    template: Template = None
    instruction: Instruction = None
    format: ICLFormat = ICLFormat()

    max_train_instances: int = None
    max_validation_instances: int = None
    max_test_instances: int = None

    train_refiner: StreamRefiner = OptionalField(default_factory=lambda: StreamRefiner(apply_to_streams=["train"]))
    validation_refiner: StreamRefiner = OptionalField(
        default_factory=lambda: StreamRefiner(apply_to_streams=["validation"])
    )
    test_refiner: StreamRefiner = OptionalField(default_factory=lambda: StreamRefiner(apply_to_streams=["test"]))

    demos_pool_size: int = None
    num_demos: int = 0

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    sampler: Sampler = None

    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def verify(self):
        super().verify()
        if self.num_demos > 0:
            if self.demos_pool_size is None or self.demos_pool_size < 1:
                raise ValueError(
                    "When using demonstrations both num_demos and demos_pool_size should be assigned with postive integers."
                )
            if self.demos_pool_size < self.num_demos:
                raise ValueError(
                    f"demos_pool_size must be bigger than num_demos={self.num_demos}, Got demos_pool_size={self.demos_pool_size}"
                )

    def prepare(self):
        self.steps = [
            self.card.loader,
        ]

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.demos_pool_size is not None:
            self.steps.append(
                SeparateSplit(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos > 0:
            sampler = self.card.sampler

            if self.sampler is not None:
                sampler = self.sampler

            sampler.set_size(self.num_demos)

            self.steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=sampler,
                )
            )

        self.train_refiner.max_instances = self.max_train_instances
        self.steps.append(self.train_refiner)

        self.validation_refiner.max_instances = self.max_validation_instances
        self.steps.append(self.validation_refiner)

        self.test_refiner.max_instances = self.max_test_instances
        self.steps.append(self.test_refiner)

        render = StandardRenderer(
            instruction=self.instruction,
            template=self.template,
            format=self.format,
            demos_field=self.demos_field,
        )

        self.steps.append(render)

        postprocessors = render.get_postprocessors()

        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=postprocessors,
            )
        )


class StandardRecipeWithIndexes(BaseRecipe):
    instruction_card_index: int = None
    template_card_index: int = None

    def prepare(self):
        assert (
            self.template_card_index is None or self.template is None
        ), "Specify either template or template_card_index"
        if self.template_card_index is not None:
            self.template = self.card.templates[int(self.template_card_index)]

        assert (
            self.instruction_card_index is None or self.instruction is None
        ), "Specify either instruction or instruction_card_index"
        if self.instruction_card_index is not None:
            self.instruction = self.card.instructions[int(self.instruction_card_index)]

        super().prepare()


class StandardRecipe(StandardRecipeWithIndexes):
    """
    This class represents a standard recipe for data processing and preperation.
    This class can be used to prepare a recipe
    with all necessary steps, refiners and renderers included. It allows to set various
    parameters and steps in a sequential manner for preparing the recipe.

    Attributes:
        card (TaskCard): TaskCard object associated with the recipe.
        template (Template, optional): Template object to be used for the recipe.
        instruction (Instruction, optional): Instruction object to be used for the recipe.
        format (ICLFormat, optional): ICLFormat object to be used for the recipe.
        train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
        max_train_instances (int, optional): Maximum training instances for the refiner.
        validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
        max_validation_instances (int, optional): Maximum validation instances for the refiner.
        test_refiner (StreamRefiner, optional): Test refiner to be used in the recipe.
        max_test_instances (int, optional): Maximum test instances for the refiner.
        demos_pool_size (int, optional): Size of the demos pool.
        num_demos (int, optional): Number of demos to be used.
        demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
        demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train".
        demos_field (str, optional): Field name for demos. Default is "demos".
        sampler (Sampler, optional): Sampler object to be used in the recipe.
        steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
        instruction_card_index (int, optional): Index of instruction card to be used
            for preparing the recipe.
        template_card_index (int, optional): Index of template card to be used for
            preparing the recipe.

    Methods:
        prepare(): This overridden method is used for preparing the recipe
            by arranging all the steps, refiners, and renderers in a sequential manner.

    Raises:
        AssertionError: If both template and template_card_index, or instruction and instruction_card_index
            are specified at the same time.
    """

    pass