Upload standard.py with huggingface_hub
Browse files- standard.py +10 -14
standard.py
CHANGED
@@ -3,7 +3,6 @@ from typing import List
|
|
3 |
from .card import TaskCard
|
4 |
from .dataclass import Field, InternalField, OptionalField
|
5 |
from .formats import Format, SystemFormat
|
6 |
-
from .instructions import EmptyInstruction, Instruction
|
7 |
from .logging_utils import get_logger
|
8 |
from .operator import SourceSequentialOperator, StreamingOperator
|
9 |
from .operators import (
|
@@ -14,6 +13,7 @@ from .operators import (
|
|
14 |
from .recipe import Recipe
|
15 |
from .schema import ToUnitxtGroup
|
16 |
from .splitters import Sampler, SeparateSplit, SpreadSplit
|
|
|
17 |
from .templates import Template
|
18 |
|
19 |
logger = get_logger()
|
@@ -31,7 +31,7 @@ class AddDemosField(SpreadSplit):
|
|
31 |
class BaseRecipe(Recipe, SourceSequentialOperator):
|
32 |
card: TaskCard
|
33 |
template: Template = None
|
34 |
-
|
35 |
format: Format = Field(default_factory=SystemFormat)
|
36 |
|
37 |
loader_limit: int = None
|
@@ -46,6 +46,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
46 |
|
47 |
demos_pool_size: int = None
|
48 |
num_demos: int = 0
|
|
|
49 |
|
50 |
demos_pool_name: str = "demos_pool"
|
51 |
demos_taken_from: str = "train"
|
@@ -135,6 +136,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
135 |
from_split=self.demos_taken_from,
|
136 |
to_split_names=[self.demos_pool_name, self.demos_taken_from],
|
137 |
to_split_sizes=[int(self.demos_pool_size)],
|
|
|
138 |
)
|
139 |
)
|
140 |
|
@@ -160,7 +162,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
160 |
sampler=self.sampler,
|
161 |
)
|
162 |
)
|
163 |
-
self.steps.append(self.instruction)
|
164 |
self.steps.append(self.format)
|
165 |
if self.augmentor.augment_model_input:
|
166 |
self.steps.append(self.augmentor)
|
@@ -177,7 +179,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
177 |
|
178 |
|
179 |
class StandardRecipeWithIndexes(BaseRecipe):
|
180 |
-
instruction_card_index: int = None
|
181 |
template_card_index: int = None
|
182 |
|
183 |
def prepare(self):
|
@@ -192,17 +193,12 @@ class StandardRecipeWithIndexes(BaseRecipe):
|
|
192 |
self.template = self.card.templates[self.template_card_index]
|
193 |
except Exception as e:
|
194 |
if isinstance(self.card.templates, dict):
|
195 |
-
options = self.card.templates.keys()
|
196 |
else:
|
197 |
options = list(range(0, len(self.card.templates)))
|
198 |
raise ValueError(
|
199 |
-
f"card_template_index '{self.template_card_index}' is not in card.
|
200 |
) from e
|
201 |
-
assert (
|
202 |
-
self.instruction_card_index is None or self.instruction is None
|
203 |
-
), "Specify either instruction or instruction_card_index"
|
204 |
-
if self.instruction_card_index is not None:
|
205 |
-
self.instruction = self.card.instructions[int(self.instruction_card_index)]
|
206 |
|
207 |
super().prepare()
|
208 |
|
@@ -217,7 +213,7 @@ class StandardRecipe(StandardRecipeWithIndexes):
|
|
217 |
Attributes:
|
218 |
card (TaskCard): TaskCard object associated with the recipe.
|
219 |
template (Template, optional): Template object to be used for the recipe.
|
220 |
-
|
221 |
loader_limit (int, optional): Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
|
222 |
format (SystemFormat, optional): SystemFormat object to be used for the recipe.
|
223 |
train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
|
@@ -231,6 +227,7 @@ class StandardRecipe(StandardRecipeWithIndexes):
|
|
231 |
demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
|
232 |
demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train".
|
233 |
demos_field (str, optional): Field name for demos. Default is "demos".
|
|
|
234 |
sampler (Sampler, optional): Sampler object to be used in the recipe.
|
235 |
steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
|
236 |
augmentor (Augmentor) : Augmentor to be used to pseudo randomly augment the source text
|
@@ -244,8 +241,7 @@ class StandardRecipe(StandardRecipeWithIndexes):
|
|
244 |
by arranging all the steps, refiners, and renderers in a sequential manner.
|
245 |
|
246 |
Raises:
|
247 |
-
AssertionError: If both template and template_card_index
|
248 |
-
are specified at the same time.
|
249 |
"""
|
250 |
|
251 |
pass
|
|
|
3 |
from .card import TaskCard
|
4 |
from .dataclass import Field, InternalField, OptionalField
|
5 |
from .formats import Format, SystemFormat
|
|
|
6 |
from .logging_utils import get_logger
|
7 |
from .operator import SourceSequentialOperator, StreamingOperator
|
8 |
from .operators import (
|
|
|
13 |
from .recipe import Recipe
|
14 |
from .schema import ToUnitxtGroup
|
15 |
from .splitters import Sampler, SeparateSplit, SpreadSplit
|
16 |
+
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
17 |
from .templates import Template
|
18 |
|
19 |
logger = get_logger()
|
|
|
31 |
class BaseRecipe(Recipe, SourceSequentialOperator):
|
32 |
card: TaskCard
|
33 |
template: Template = None
|
34 |
+
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
35 |
format: Format = Field(default_factory=SystemFormat)
|
36 |
|
37 |
loader_limit: int = None
|
|
|
46 |
|
47 |
demos_pool_size: int = None
|
48 |
num_demos: int = 0
|
49 |
+
demos_removed_from_data: bool = True
|
50 |
|
51 |
demos_pool_name: str = "demos_pool"
|
52 |
demos_taken_from: str = "train"
|
|
|
136 |
from_split=self.demos_taken_from,
|
137 |
to_split_names=[self.demos_pool_name, self.demos_taken_from],
|
138 |
to_split_sizes=[int(self.demos_pool_size)],
|
139 |
+
remove_targets_from_source_split=self.demos_removed_from_data,
|
140 |
)
|
141 |
)
|
142 |
|
|
|
162 |
sampler=self.sampler,
|
163 |
)
|
164 |
)
|
165 |
+
self.steps.append(self.system_prompt)
|
166 |
self.steps.append(self.format)
|
167 |
if self.augmentor.augment_model_input:
|
168 |
self.steps.append(self.augmentor)
|
|
|
179 |
|
180 |
|
181 |
class StandardRecipeWithIndexes(BaseRecipe):
|
|
|
182 |
template_card_index: int = None
|
183 |
|
184 |
def prepare(self):
|
|
|
193 |
self.template = self.card.templates[self.template_card_index]
|
194 |
except Exception as e:
|
195 |
if isinstance(self.card.templates, dict):
|
196 |
+
options = list(self.card.templates.keys())
|
197 |
else:
|
198 |
options = list(range(0, len(self.card.templates)))
|
199 |
raise ValueError(
|
200 |
+
f"card_template_index '{self.template_card_index}' is not defined in card. Possible card_template_index options: {options}"
|
201 |
) from e
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
super().prepare()
|
204 |
|
|
|
213 |
Attributes:
|
214 |
card (TaskCard): TaskCard object associated with the recipe.
|
215 |
template (Template, optional): Template object to be used for the recipe.
|
216 |
+
system_prompt (SystemPrompt, optional): SystemPrompt object to be used for the recipe.
|
217 |
loader_limit (int, optional): Specifies the maximum number of instances per stream to be returned from the loader (used to reduce loading time in large datasets)
|
218 |
format (SystemFormat, optional): SystemFormat object to be used for the recipe.
|
219 |
train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
|
|
|
227 |
demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
|
228 |
demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train".
|
229 |
demos_field (str, optional): Field name for demos. Default is "demos".
|
230 |
+
demos_removed_from_data (bool, optional): Whether to remove the demos from the source data. Default is True.
|
231 |
sampler (Sampler, optional): Sampler object to be used in the recipe.
|
232 |
steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
|
233 |
augmentor (Augmentor) : Augmentor to be used to pseudo randomly augment the source text
|
|
|
241 |
by arranging all the steps, refiners, and renderers in a sequential manner.
|
242 |
|
243 |
Raises:
|
244 |
+
AssertionError: If both template and template_card_index are specified at the same time.
|
|
|
245 |
"""
|
246 |
|
247 |
pass
|