Elron commited on
Commit
b5acc40
1 Parent(s): 6754ccb

Upload common.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. common.py +99 -0
common.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .stream import MultiStream
2
+ from .operator import SourceOperator
3
+ from .card import TaskCard
4
+ from .splitters import SliceSplit, SpreadSplit, RandomSampler
5
+ from .recipe import SequentialRecipe, Recipe
6
+ from .collections import ItemPicker, RandomPicker
7
+ from .templates import RenderTemplatedICL
8
+ from .schema import ToUnitxtGroup
9
+
10
+ from typing import Union
11
+
12
+
13
+ class CommonRecipe(Recipe, SourceOperator):
14
+ card: TaskCard
15
+ demos_pool_name: str = "demos_pool"
16
+ demos_pool_size: int = None
17
+ demos_field: str = "demos"
18
+ num_demos: int = None
19
+ sampler_type: str = "random"
20
+ instruction_item: Union[str, int] = None
21
+ template_item: Union[str, int] = None
22
+
23
+ def verify(self):
24
+ self.sampler_type in ["random"]
25
+
26
+ def prepare(self):
27
+ steps = [
28
+ self.card.loader,
29
+ ]
30
+
31
+ if self.card.preprocess_steps is not None:
32
+ steps.extend(self.card.preprocess_steps)
33
+
34
+ steps.append(self.card.task)
35
+
36
+ if self.demos_pool_size is not None:
37
+ steps.append(
38
+ SliceSplit(
39
+ slices={
40
+ self.demos_pool_name: f"train[:{self.demos_pool_size}]",
41
+ "train": f"train[{self.demos_pool_size}:]",
42
+ "validation": "validation",
43
+ "test": "test",
44
+ }
45
+ )
46
+ )
47
+
48
+ if self.num_demos is not None:
49
+ if self.sampler_type == "random":
50
+ sampler = RandomSampler(sample_size=self.num_demos)
51
+
52
+ steps.append(
53
+ SpreadSplit(
54
+ source_stream=self.demos_pool_name,
55
+ target_field=self.demos_field,
56
+ sampler=sampler,
57
+ )
58
+ )
59
+
60
+ if self.card.instructions is not None:
61
+ if self.instruction_item is None:
62
+ picker = ItemPicker(self.instruction_item)
63
+ else:
64
+ picker = RandomPicker()
65
+ instruction = picker(self.card.instructions)
66
+ else:
67
+ instruction = None
68
+
69
+ if self.card.templates is not None:
70
+ if self.template_item is None:
71
+ picker = ItemPicker(self.template_item)
72
+ else:
73
+ picker = RandomPicker()
74
+ template = picker(self.card.templates)
75
+ else:
76
+ template = None
77
+
78
+ render = RenderTemplatedICL(
79
+ instruction=instruction,
80
+ template=template,
81
+ demos_field=self.demos_field,
82
+ )
83
+
84
+ steps.append(render)
85
+
86
+ postprocessors = render.get_postprocessors()
87
+
88
+ steps.append(
89
+ ToUnitxtGroup(
90
+ group="default",
91
+ metrics=self.card.task.metrics,
92
+ postprocessors=postprocessors,
93
+ )
94
+ )
95
+
96
+ self.recipe = SequentialRecipe(steps)
97
+
98
+ def process(self) -> MultiStream:
99
+ return self.recipe()