Elron commited on
Commit
c6e9c8c
1 Parent(s): 370e1f5

Upload standard.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. standard.py +83 -0
standard.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from .card import TaskCard
4
+ from .dataclass import InternalField
5
+ from .formats import ICLFormat
6
+ from .instructions import Instruction
7
+ from .operator import SourceSequntialOperator, StreamingOperator
8
+ from .recipe import Recipe
9
+ from .renderers import StandardRenderer
10
+ from .schema import ToUnitxtGroup
11
+ from .splitters import Sampler, SeparateSplit, SpreadSplit
12
+ from .templates import Template
13
+
14
+
15
+ class StandardRecipe(Recipe, SourceSequntialOperator):
16
+ card: TaskCard
17
+ template: Template
18
+ instruction: Instruction = None
19
+ format: ICLFormat = None
20
+
21
+ demos_pool_size: int = None
22
+ num_demos: int = None
23
+
24
+ demos_pool_name: str = "demos_pool"
25
+ demos_taken_from: str = "train"
26
+ demos_field: str = "demos"
27
+ sampler: Sampler = None
28
+
29
+ steps: List[StreamingOperator] = InternalField(default_factory=list)
30
+
31
+ def prepare(self):
32
+ self.steps = [
33
+ self.card.loader,
34
+ ]
35
+
36
+ if self.card.preprocess_steps is not None:
37
+ self.steps.extend(self.card.preprocess_steps)
38
+
39
+ self.steps.append(self.card.task)
40
+
41
+ if self.demos_pool_size is not None:
42
+ self.steps.append(
43
+ SeparateSplit(
44
+ from_split=self.demos_taken_from,
45
+ to_split_names=[self.demos_pool_name, self.demos_taken_from],
46
+ to_split_sizes=[int(self.demos_pool_size)],
47
+ )
48
+ )
49
+
50
+ if self.num_demos is not None:
51
+ sampler = self.card.sampler
52
+
53
+ if self.sampler is not None:
54
+ sampler = self.sampler
55
+
56
+ sampler.set_size(self.num_demos)
57
+
58
+ self.steps.append(
59
+ SpreadSplit(
60
+ source_stream=self.demos_pool_name,
61
+ target_field=self.demos_field,
62
+ sampler=sampler,
63
+ )
64
+ )
65
+
66
+ render = StandardRenderer(
67
+ instruction=self.instruction,
68
+ template=self.template,
69
+ format=self.format,
70
+ demos_field=self.demos_field,
71
+ )
72
+
73
+ self.steps.append(render)
74
+
75
+ postprocessors = render.get_postprocessors()
76
+
77
+ self.steps.append(
78
+ ToUnitxtGroup(
79
+ group="unitxt",
80
+ metrics=self.card.task.metrics,
81
+ postprocessors=postprocessors,
82
+ )
83
+ )