File size: 3,059 Bytes
a61eb0e
 
b5acc40
 
b0144ee
a61eb0e
 
b5acc40
b0144ee
a61eb0e
 
b5acc40
 
 
 
 
b74242e
b5acc40
 
 
b0144ee
b5acc40
 
40a8c03
b5acc40
 
b0144ee
b5acc40
 
 
 
 
 
 
 
 
 
 
 
 
b74242e
 
 
 
b5acc40
 
 
 
b0144ee
 
 
 
 
 
b5acc40
 
 
 
 
 
 
 
 
 
b74242e
 
b5acc40
 
 
 
 
 
 
 
 
b74242e
 
b5acc40
 
 
 
 
 
 
 
40a8c03
b5acc40
 
 
 
 
 
 
 
368a37d
b5acc40
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import Union

from .card import TaskCard
from .collections import ItemPicker, RandomPicker
from .dataclass import OptionalField
from .operator import SourceOperator
from .recipe import Recipe, SequentialRecipe
from .schema import ToUnitxtGroup
from .splitters import RandomSampler, Sampler, SeparateSplit, SliceSplit, SpreadSplit
from .stream import MultiStream
from .templates import RenderTemplatedICL


class CommonRecipe(Recipe, SourceOperator):
    """Recipe that assembles a processing pipeline from a ``TaskCard``.

    ``prepare`` collects, in order: the card's loader, its optional
    preprocess steps, its task, an optional demos-pool split, an optional
    demos sampling step, a ``RenderTemplatedICL`` step and a final
    ``ToUnitxtGroup`` step, and wraps them in a ``SequentialRecipe``.
    ``process`` runs that recipe.
    """

    # Card supplying the loader, preprocess steps, task, sampler,
    # instructions and templates used below.
    card: TaskCard
    # Name of the split that holds the demos pool.
    demos_pool_name: str = "demos_pool"
    # Split the demos pool is separated from.
    demos_taken_from: str = "train"
    # When set, a SeparateSplit step creating the demos pool is added.
    demos_pool_size: int = None
    # Field that receives the sampled demos and is read by the renderer.
    demos_field: str = "demos"
    # When set, a SpreadSplit step sampling this many demos is added.
    num_demos: int = None
    # Overrides ``card.sampler`` when provided.
    sampler: Sampler = None
    # Which instruction / template to pick from the card; a random one is
    # picked when left as None.
    instruction_item: Union[str, int] = None
    template_item: Union[str, int] = None
    # Forwarded unchanged to RenderTemplatedICL.
    system_prompt: str = None

    def verify(self):
        """Run the verification defined by the parent classes."""
        super().verify()

    def prepare(self):
        """Build the step list and store it as ``self.recipe``.

        Raises:
            ValueError: if ``num_demos`` is set but neither this recipe nor
                the card provides a sampler.
        """
        steps = [
            self.card.loader,
        ]

        if self.card.preprocess_steps is not None:
            steps.extend(self.card.preprocess_steps)

        steps.append(self.card.task)

        if self.demos_pool_size is not None:
            # Only the pool size is given; presumably SeparateSplit leaves
            # the remainder in ``demos_taken_from`` -- TODO confirm.
            steps.append(
                SeparateSplit(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos is not None:
            # A sampler set on the recipe takes precedence over the card's.
            sampler = self.card.sampler
            if self.sampler is not None:
                sampler = self.sampler

            if sampler is None:
                # Fail with an actionable message instead of an
                # AttributeError on ``None.set_size``.
                raise ValueError(
                    "num_demos is set but no sampler is available: "
                    "set 'sampler' on the recipe or on the card."
                )

            # NOTE(review): this mutates the sampler object in place; if the
            # card's sampler is shared between recipes the size is shared too.
            sampler.set_size(self.num_demos)

            steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=sampler,
                )
            )

        instruction = None
        if self.card.instructions is not None:
            if self.instruction_item is not None:
                # Keep the historical integer coercion for numeric values,
                # but accept non-numeric string keys (the field is declared
                # Union[str, int]) the same way ``template_item`` does.
                try:
                    item = int(self.instruction_item)
                except ValueError:
                    item = self.instruction_item
                picker = ItemPicker(item)
            else:
                picker = RandomPicker()
            instruction = picker(self.card.instructions)

        template = None
        if self.card.templates is not None:
            if self.template_item is None:
                picker = RandomPicker()
            else:
                picker = ItemPicker(self.template_item)
            template = picker(self.card.templates)

        render = RenderTemplatedICL(
            instruction=instruction,
            template=template,
            demos_field=self.demos_field,
            system_prompt=self.system_prompt,
        )

        steps.append(render)

        # The renderer supplies the postprocessors attached to the group.
        postprocessors = render.get_postprocessors()

        steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=postprocessors,
            )
        )

        self.recipe = SequentialRecipe(steps)

    def process(self) -> MultiStream:
        """Run the recipe built by ``prepare`` and return its result."""
        return self.recipe()