Elron commited on
Commit
161e5a1
·
1 Parent(s): e4d5cd2

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +150 -33
templates.py CHANGED
@@ -4,11 +4,13 @@ from dataclasses import field
4
  from typing import Any, Dict, List, Optional, Union
5
 
6
  from .artifact import Artifact
 
7
  from .dataclass import NonPositionalField
8
  from .instructions import Instruction, TextualInstruction
9
- from .operator import InstanceOperatorWithGlobalAccess, StreamInstanceOperator
10
- from .random_utils import random
11
  from .text_utils import split_words
 
12
 
13
 
14
  class Renderer(ABC):
@@ -39,10 +41,14 @@ class RenderFormatTemplate(Renderer, StreamInstanceOperator):
39
  random_reference: bool = False
40
 
41
  def verify(self):
42
- assert isinstance(self.template, Template), "Template must be an instance of Template"
 
 
43
  assert self.template is not None, "Template must be specified"
44
 
45
- def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
 
 
46
  return self.render(instance)
47
 
48
  def render(self, instance: Dict[str, Any]) -> Dict[str, Any]:
@@ -55,7 +61,7 @@ class RenderFormatTemplate(Renderer, StreamInstanceOperator):
55
  if self.template.is_multi_reference:
56
  references = targets
57
  if self.random_reference:
58
- target = random.choice(references)
59
  else:
60
  if len(references) == 0:
61
  raise ValueError("No references found")
@@ -87,7 +93,7 @@ class RenderAutoFormatTemplate(RenderFormatTemplate):
87
  except:
88
  pass
89
 
90
- inputs = {key: value for key, value in instance["inputs"].items()}
91
 
92
  return super().render({**instance, "inputs": inputs})
93
 
@@ -118,7 +124,12 @@ class RenderTemplatedICL(RenderAutoFormatTemplate):
118
 
119
  example = super().render(instance)
120
 
121
- input_str = self.input_prefix + example["source"] + self.input_output_separator + self.output_prefix
 
 
 
 
 
122
 
123
  if self.instruction is not None:
124
  source += self.instruction_prefix + self.instruction() + self.demo_separator
@@ -136,7 +147,9 @@ class RenderTemplatedICL(RenderAutoFormatTemplate):
136
  )
137
 
138
  if self.size_limiter is not None:
139
- if not self.size_limiter.check(source + demo_str + input_str + example["target"]):
 
 
140
  continue
141
 
142
  source += demo_str
@@ -155,7 +168,9 @@ class RenderTemplatedICL(RenderAutoFormatTemplate):
155
  class InputOutputTemplate(Template):
156
  input_format: str = None
157
  output_format: str = None
158
- postprocessors: List[str] = field(default_factory=lambda: ["processors.to_string_stripped"])
 
 
159
 
160
  def process_template(self, template: str, data: Dict[str, object]) -> str:
161
  data = {k: ", ".join(v) if isinstance(v, list) else v for k, v in data.items()}
@@ -166,16 +181,91 @@ class InputOutputTemplate(Template):
166
  return self.process_template(self.input_format, inputs)
167
  except KeyError as e:
168
  raise KeyError(
169
- f"Available inputs are {inputs.keys()} but input format requires a different one: {self.input_format}"
170
- )
171
 
172
  def process_outputs(self, outputs: Dict[str, object]) -> str:
173
  try:
174
  return self.process_template(self.output_format, outputs)
175
  except KeyError as e:
176
  raise KeyError(
177
- f"Available inputs are {outputs.keys()} but output format requires a different one: {self.output_format}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  def get_postprocessors(self) -> List[str]:
181
  return self.postprocessors
@@ -188,10 +278,17 @@ class KeyValTemplate(Template):
188
  outputs_key_val_seperator: str = ": "
189
  use_keys_for_outputs: bool = False
190
 
191
- postprocessors: List[str] = field(default_factory=lambda: ["processors.to_string_stripped"])
 
 
192
 
193
- def process_dict(self, dic: Dict[str, object], key_val_sep, pairs_sep, use_keys) -> str:
194
- dic = {k: ", ".join([str(vi) for vi in v]) if isinstance(v, list) else v for k, v in dic.items()}
 
 
 
 
 
195
  pairs = []
196
  for key, val in dic.items():
197
  key_val = [key, val] if use_keys else [val]
@@ -221,9 +318,10 @@ class KeyValTemplate(Template):
221
  class OutputQuantizingTemplate(InputOutputTemplate):
222
  quantum: float = 0.1
223
 
224
- def process_outputs(self, outputs: Dict[str, object]) -> Dict[str, object]:
225
  quantized_outputs = {
226
- key: round(input_float / self.quantum) * self.quantum for key, input_float in outputs.items()
 
227
  }
228
  return super().process_outputs(quantized_outputs)
229
 
@@ -235,12 +333,25 @@ class MultiLabelTemplate(InputOutputTemplate):
235
  output_format = "{labels}"
236
  empty_label = "None"
237
 
238
- def process_outputs(self, outputs: Dict[str, object]) -> Dict[str, object]:
239
  labels = outputs[self.labels_field]
240
  if len(labels) == 0:
241
  labels = [self.empty_label]
242
  labels_str = self.labels_seprator.join(labels)
243
- return super().process_outputs({"labels": labels_str})
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
 
246
  def escape_chars(s, chars_to_escape):
@@ -296,13 +407,16 @@ class SpanLabelingTemplate(SpanLabelingBaseTemplate):
296
 
297
 
298
  class SpanLabelingJsonTemplate(SpanLabelingBaseTemplate):
299
- postprocessors = ["processors.load_json", "processors.dict_of_lists_to_value_key_pairs"]
 
 
 
300
 
301
  def span_label_pairs_to_targets(self, span_label_pairs):
302
  groups = {}
303
  for span, label in span_label_pairs:
304
  if label not in groups:
305
- groups[label] = list()
306
  groups[label].append(span)
307
  if len(groups) > 0:
308
  targets = [json.dumps(groups)]
@@ -315,7 +429,9 @@ class AutoInputOutputTemplate(InputOutputTemplate):
315
  def infer_input_format(self, inputs):
316
  input_format = ""
317
  for key in inputs.keys():
318
- name = " ".join(word.lower().capitalize() for word in split_words(key) if word != " ")
 
 
319
  input_format += name + ": " + "{" + key + "}" + "\n"
320
  self.input_format = input_format
321
 
@@ -332,21 +448,20 @@ class AutoInputOutputTemplate(InputOutputTemplate):
332
  return self.input_format is not None and self.output_format is not None
333
 
334
 
335
- from .collections import ListCollection
336
-
337
-
338
  class TemplatesList(ListCollection):
339
  def verify(self):
340
  for template in self.items:
341
  assert isinstance(template, Template)
342
 
343
 
344
- def outputs_inputs2templates(inputs: Union[str, List], outputs: Union[str, List]) -> TemplatesList:
345
- """
346
- combines input and output formats into their dot product
 
 
347
  :param inputs: list of input formats (or one)
348
  :param outputs: list of output formats (or one)
349
- :return: TemplatesList of InputOutputTemplate
350
  """
351
  templates = []
352
  if isinstance(inputs, str):
@@ -367,8 +482,8 @@ def outputs_inputs2templates(inputs: Union[str, List], outputs: Union[str, List]
367
  def instructions2templates(
368
  instructions: List[TextualInstruction], templates: List[InputOutputTemplate]
369
  ) -> TemplatesList:
370
- """
371
- Insert instructions into per demonstration templates
372
  :param instructions:
373
  :param templates: strings containing {instuction} where the instruction should be placed
374
  :return:
@@ -378,7 +493,9 @@ def instructions2templates(
378
  for template in templates:
379
  res_templates.append(
380
  InputOutputTemplate(
381
- input_format=template.input_format.replace("{instruction}", instruction.text),
 
 
382
  output_format=template.output_format,
383
  )
384
  )
@@ -387,5 +504,5 @@ def instructions2templates(
387
 
388
  class TemplatesDict(Dict):
389
  def verify(self):
390
- for key, template in self.items():
391
  assert isinstance(template, Template)
 
4
  from typing import Any, Dict, List, Optional, Union
5
 
6
  from .artifact import Artifact
7
+ from .collections import ListCollection
8
  from .dataclass import NonPositionalField
9
  from .instructions import Instruction, TextualInstruction
10
+ from .operator import StreamInstanceOperator
11
+ from .random_utils import get_random
12
  from .text_utils import split_words
13
+ from .type_utils import isoftype
14
 
15
 
16
  class Renderer(ABC):
 
41
  random_reference: bool = False
42
 
43
  def verify(self):
44
+ assert isinstance(
45
+ self.template, Template
46
+ ), "Template must be an instance of Template"
47
  assert self.template is not None, "Template must be specified"
48
 
49
+ def process(
50
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
51
+ ) -> Dict[str, Any]:
52
  return self.render(instance)
53
 
54
  def render(self, instance: Dict[str, Any]) -> Dict[str, Any]:
 
61
  if self.template.is_multi_reference:
62
  references = targets
63
  if self.random_reference:
64
+ target = get_random().choice(references)
65
  else:
66
  if len(references) == 0:
67
  raise ValueError("No references found")
 
93
  except:
94
  pass
95
 
96
+ inputs = dict(instance["inputs"].items())
97
 
98
  return super().render({**instance, "inputs": inputs})
99
 
 
124
 
125
  example = super().render(instance)
126
 
127
+ input_str = (
128
+ self.input_prefix
129
+ + example["source"]
130
+ + self.input_output_separator
131
+ + self.output_prefix
132
+ )
133
 
134
  if self.instruction is not None:
135
  source += self.instruction_prefix + self.instruction() + self.demo_separator
 
147
  )
148
 
149
  if self.size_limiter is not None:
150
+ if not self.size_limiter.check(
151
+ source + demo_str + input_str + example["target"]
152
+ ):
153
  continue
154
 
155
  source += demo_str
 
168
  class InputOutputTemplate(Template):
169
  input_format: str = None
170
  output_format: str = None
171
+ postprocessors: List[str] = field(
172
+ default_factory=lambda: ["processors.to_string_stripped"]
173
+ )
174
 
175
  def process_template(self, template: str, data: Dict[str, object]) -> str:
176
  data = {k: ", ".join(v) if isinstance(v, list) else v for k, v in data.items()}
 
181
  return self.process_template(self.input_format, inputs)
182
  except KeyError as e:
183
  raise KeyError(
184
+ f"Available inputs are {list(inputs.keys())} but input format requires a different ones: '{self.input_format}'"
185
+ ) from e
186
 
187
  def process_outputs(self, outputs: Dict[str, object]) -> str:
188
  try:
189
  return self.process_template(self.output_format, outputs)
190
  except KeyError as e:
191
  raise KeyError(
192
+ f"Available outputs are {outputs.keys()} but output format requires a different one: {self.output_format}"
193
+ ) from e
194
+
195
+ def get_postprocessors(self) -> List[str]:
196
+ return self.postprocessors
197
+
198
+
199
+ class YesNoTemplate(Template):
200
+ """A template for generating binary Yes/No questions asking whether an input text is of a specific class.
201
+
202
+ input_format:
203
+ Defines the format of the question.
204
+ class_field:
205
+ Defines the field that contains the name of the class that this template
206
+ asks of.
207
+ label_field:
208
+ Defines the field which contains the true label of the input text. If a gold label is equal to the
209
+ value in class_name, then the correct output is self.yes_answer (by default, "Yes").
210
+ Otherwise the correct output is self.no_answer (by default, "No").
211
+ yes_answer:
212
+ The output value for when the gold label equals self.class_name.
213
+ Defaults to "Yes".
214
+ no_answer:
215
+ The output value for when the gold label differs from self.class_name.
216
+ Defaults to "No".
217
+ """
218
+
219
+ input_format: str = None
220
+ class_field: str = None
221
+ label_field: str = None
222
+ yes_answer: str = "Yes"
223
+ no_answer: str = "No"
224
+ postprocessors: List[str] = field(
225
+ default_factory=lambda: ["processors.to_string_stripped"]
226
+ )
227
+
228
+ def process_inputs(self, inputs: Dict[str, object]) -> str:
229
+ try:
230
+ data = {
231
+ k: ", ".join(v) if isinstance(v, list) else v for k, v in inputs.items()
232
+ }
233
+ return self.input_format.format(**data)
234
+ except KeyError as e:
235
+ raise RuntimeError(
236
+ f"Available inputs are {list(inputs.keys())} but input format requires a different one: {self.input_format}"
237
+ ) from e
238
+
239
+ def process_outputs(self, outputs: Dict[str, object]) -> str:
240
+ try:
241
+ gold_class_names = outputs[self.label_field]
242
+ except KeyError as e:
243
+ raise RuntimeError(
244
+ f"Available outputs are {list(outputs.keys())}, missing required label field: '{self.label_field}'."
245
+ ) from e
246
+ if not isinstance(gold_class_names, list) or not gold_class_names:
247
+ raise RuntimeError(
248
+ f"Unexpected value for gold_class_names: '{gold_class_names}'. Expected a non-empty list."
249
  )
250
+ try:
251
+ queried_class_names = outputs[self.class_field]
252
+ except KeyError as e:
253
+ raise RuntimeError(
254
+ f"Available outputs are {list(outputs.keys())}, missing required class field: '{self.class_field}'."
255
+ ) from e
256
+ if (
257
+ not queried_class_names
258
+ or not isinstance(queried_class_names, list)
259
+ or not len(queried_class_names) == 1
260
+ ):
261
+ raise RuntimeError(
262
+ f"Unexpected value for queried_class_names: '{queried_class_names}'. Expected a list with one item."
263
+ )
264
+ queried_class_name = queried_class_names[0]
265
+ if queried_class_name in gold_class_names:
266
+ return self.yes_answer
267
+
268
+ return self.no_answer
269
 
270
  def get_postprocessors(self) -> List[str]:
271
  return self.postprocessors
 
278
  outputs_key_val_seperator: str = ": "
279
  use_keys_for_outputs: bool = False
280
 
281
+ postprocessors: List[str] = field(
282
+ default_factory=lambda: ["processors.to_string_stripped"]
283
+ )
284
 
285
+ def process_dict(
286
+ self, dic: Dict[str, object], key_val_sep, pairs_sep, use_keys
287
+ ) -> str:
288
+ dic = {
289
+ k: ", ".join([str(vi) for vi in v]) if isinstance(v, list) else v
290
+ for k, v in dic.items()
291
+ }
292
  pairs = []
293
  for key, val in dic.items():
294
  key_val = [key, val] if use_keys else [val]
 
318
  class OutputQuantizingTemplate(InputOutputTemplate):
319
  quantum: float = 0.1
320
 
321
+ def process_outputs(self, outputs: Dict[str, object]) -> str:
322
  quantized_outputs = {
323
+ key: round(input_float / self.quantum) * self.quantum
324
+ for key, input_float in outputs.items()
325
  }
326
  return super().process_outputs(quantized_outputs)
327
 
 
333
  output_format = "{labels}"
334
  empty_label = "None"
335
 
336
+ def process_outputs(self, outputs: Dict[str, object]) -> str:
337
  labels = outputs[self.labels_field]
338
  if len(labels) == 0:
339
  labels = [self.empty_label]
340
  labels_str = self.labels_seprator.join(labels)
341
+ return super().process_outputs({self.labels_field: labels_str})
342
+
343
+
344
+ class MultiReferenceTemplate(InputOutputTemplate):
345
+ references_field: str = "references"
346
+ is_multi_reference = True
347
+
348
+ def process_outputs(self, outputs: Dict[str, object]) -> List[str]:
349
+ references = outputs[self.references_field]
350
+ if not isoftype(references, List[str]):
351
+ raise ValueError(
352
+ f"MultiReferenceTemplate requires that references field {self.references_field} is of type List[str]."
353
+ )
354
+ return references
355
 
356
 
357
  def escape_chars(s, chars_to_escape):
 
407
 
408
 
409
  class SpanLabelingJsonTemplate(SpanLabelingBaseTemplate):
410
+ postprocessors = [
411
+ "processors.load_json",
412
+ "processors.dict_of_lists_to_value_key_pairs",
413
+ ]
414
 
415
  def span_label_pairs_to_targets(self, span_label_pairs):
416
  groups = {}
417
  for span, label in span_label_pairs:
418
  if label not in groups:
419
+ groups[label] = []
420
  groups[label].append(span)
421
  if len(groups) > 0:
422
  targets = [json.dumps(groups)]
 
429
  def infer_input_format(self, inputs):
430
  input_format = ""
431
  for key in inputs.keys():
432
+ name = " ".join(
433
+ word.lower().capitalize() for word in split_words(key) if word != " "
434
+ )
435
  input_format += name + ": " + "{" + key + "}" + "\n"
436
  self.input_format = input_format
437
 
 
448
  return self.input_format is not None and self.output_format is not None
449
 
450
 
 
 
 
451
  class TemplatesList(ListCollection):
452
  def verify(self):
453
  for template in self.items:
454
  assert isinstance(template, Template)
455
 
456
 
457
+ def outputs_inputs2templates(
458
+ inputs: Union[str, List], outputs: Union[str, List]
459
+ ) -> TemplatesList:
460
+ """Combines input and output formats into their dot product.
461
+
462
  :param inputs: list of input formats (or one)
463
  :param outputs: list of output formats (or one)
464
+ :return: TemplatesList of InputOutputTemplate.
465
  """
466
  templates = []
467
  if isinstance(inputs, str):
 
482
  def instructions2templates(
483
  instructions: List[TextualInstruction], templates: List[InputOutputTemplate]
484
  ) -> TemplatesList:
485
+ """Insert instructions into per demonstration templates.
486
+
487
  :param instructions:
488
  :param templates: strings containing {instuction} where the instruction should be placed
489
  :return:
 
493
  for template in templates:
494
  res_templates.append(
495
  InputOutputTemplate(
496
+ input_format=template.input_format.replace(
497
+ "{instruction}", instruction.text
498
+ ),
499
  output_format=template.output_format,
500
  )
501
  )
 
504
 
505
  class TemplatesDict(Dict):
506
  def verify(self):
507
+ for _key, template in self.items():
508
  assert isinstance(template, Template)