Elron committed on
Commit
129744e
·
verified ·
1 Parent(s): 3c38cbc

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +92 -75
templates.py CHANGED
@@ -9,6 +9,15 @@ from .random_utils import new_random_generator
9
  from .type_utils import isoftype
10
 
11
 
 
 
 
 
 
 
 
 
 
12
  class Template(StreamInstanceOperator):
13
  """The role of template is to take the fields of every instance and verbalize it.
14
 
@@ -26,8 +35,18 @@ class Template(StreamInstanceOperator):
26
  postprocessors: List[str] = NonPositionalField(
27
  default_factory=lambda: ["processors.to_string_stripped"]
28
  )
29
- instruction: str = NonPositionalField(default_factory=lambda: "")
30
- target_prefix: str = NonPositionalField(default_factory=lambda: "")
 
 
 
 
 
 
 
 
 
 
31
 
32
  def process(
33
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -43,7 +62,12 @@ class Template(StreamInstanceOperator):
43
  inputs = instance.get("inputs")
44
  outputs = instance.get("outputs")
45
 
46
- source, instruction = self.inputs_to_source(inputs)
 
 
 
 
 
47
  target, references = self.outputs_to_target_and_references(outputs)
48
 
49
  return {
@@ -52,13 +76,17 @@ class Template(StreamInstanceOperator):
52
  "target": target,
53
  "references": references,
54
  "instruction": instruction,
55
- "target_prefix": self.target_prefix.format(**inputs),
56
  }
57
 
58
  @abstractmethod
59
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
60
  pass
61
 
 
 
 
 
62
  @abstractmethod
63
  def outputs_to_target_and_references(
64
  self, outputs: Dict[str, object]
@@ -68,6 +96,24 @@ class Template(StreamInstanceOperator):
68
  def get_postprocessors(self) -> List[str]:
69
  return self.postprocessors
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  class InputOutputTemplate(Template):
73
  """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
@@ -78,30 +124,15 @@ class InputOutputTemplate(Template):
78
  input_format: str = None
79
  output_format: str = None
80
 
81
- def process_template(self, template: str, data: Dict[str, object]) -> str:
82
- data = {k: ", ".join(v) if isinstance(v, list) else v for k, v in data.items()}
83
- return template.format(**data)
84
-
85
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
86
- formatted = []
87
- for formatting in [self.input_format, self.instruction]:
88
- try:
89
- formatted.append(self.process_template(formatting, inputs))
90
- except KeyError as e:
91
- raise KeyError(
92
- f"Available inputs are {list(inputs.keys())} but input format requires a different ones: '{formatting}'"
93
- ) from e
94
-
95
- return tuple(formatted)
96
 
97
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
98
- try:
99
- target = self.process_template(self.output_format, outputs)
100
- except KeyError as e:
101
- raise KeyError(
102
- f"Available outputs are {outputs.keys()} but output format requires a different one: {self.output_format}"
103
- ) from e
104
-
105
  references = [target]
106
  return target, references
107
 
@@ -110,19 +141,13 @@ class InputOutputReferenceTemplate(InputOutputTemplate):
110
  reference: str
111
 
112
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
113
- output_fields = {}
114
- for name, val in [
115
- ("target", self.output_format),
116
- ("reference", self.reference),
117
- ]:
118
- try:
119
- result = self.process_template(val, outputs)
120
- output_fields[name] = result
121
- except KeyError as e:
122
- raise KeyError(
123
- f"Available outputs are {outputs.keys()} but {name} requires a different one: {val}"
124
- ) from e
125
- return output_fields["target"], [output_fields["reference"]]
126
 
127
 
128
  class MultipleChoiceTemplate(Template):
@@ -135,7 +160,6 @@ class MultipleChoiceTemplate(Template):
135
  choices_seperator: str = ", "
136
  source_choice_format: str = "{choice_numeral}. {choice_text}"
137
  target_choice_format: str = "{choice_numeral}"
138
- add_numerals_as_field: str = None
139
  enumerator: str = "capitals"
140
 
141
  def prepare(self):
@@ -170,7 +194,7 @@ class MultipleChoiceTemplate(Template):
170
  "XX",
171
  ]
172
 
173
- def get_choices(self, data: Dict[str, object], choice_format: str) -> str:
174
  choices = data[self.choices_field]
175
  enumrated_choices = []
176
  for i, choice in enumerate(choices):
@@ -182,22 +206,28 @@ class MultipleChoiceTemplate(Template):
182
  )
183
  return enumrated_choices
184
 
185
- def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
186
- choices = self.get_choices(inputs, self.source_choice_format)
187
- inputs = {
188
- "numerals": ",".join(self.get_choices(inputs, "{choice_numeral}")),
 
 
 
 
 
189
  **inputs,
190
  self.choices_field: self.choices_seperator.join(choices),
191
  }
192
- formatted = []
193
- for formatting in [self.input_format, self.instruction]:
194
- try:
195
- formatted.append(formatting.format(**inputs))
196
- except KeyError as e:
197
- raise KeyError(
198
- f"Available inputs are {inputs.keys()} but input format requires a different one: {formatting}"
199
- ) from e
200
- return tuple(formatted)
 
201
 
202
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
203
  target = outputs[self.target_field]
@@ -210,7 +240,7 @@ class MultipleChoiceTemplate(Template):
210
  f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {outputs[self.choices_field]}"
211
  ) from e
212
 
213
- choices = self.get_choices(outputs, self.target_choice_format)
214
 
215
  try:
216
  target = choices[target]
@@ -226,7 +256,7 @@ class MultipleChoiceTemplate(Template):
226
  ) -> Dict[str, Any]:
227
  result = super().process(instance, stream_name)
228
  if "options" not in result["outputs"]:
229
- result["outputs"]["options"] = self.get_choices(
230
  instance["outputs"], self.target_choice_format
231
  )
232
  return result
@@ -259,18 +289,9 @@ class YesNoTemplate(Template):
259
  no_answer: str = "No"
260
 
261
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
262
- data = {
263
- k: ", ".join(v) if isinstance(v, list) else v for k, v in inputs.items()
264
- }
265
- formatted = []
266
- for formatting in [self.input_format, self.instruction]:
267
- try:
268
- formatted.append(formatting.format(**data))
269
- except KeyError as e:
270
- raise RuntimeError(
271
- f"Available inputs are {list(inputs.keys())} but input format requires a different one: {formatting}"
272
- ) from e
273
- return tuple(formatted)
274
 
275
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
276
  try:
@@ -316,26 +337,22 @@ class KeyValTemplate(Template):
316
  use_keys_for_outputs: bool = False
317
 
318
  def process_dict(
319
- self, dic: Dict[str, object], key_val_sep, pairs_sep, use_keys
320
  ) -> str:
321
- dic = {
322
- k: ", ".join([str(vi) for vi in v]) if isinstance(v, list) else v
323
- for k, v in dic.items()
324
- }
325
  pairs = []
326
- for key, val in dic.items():
327
  key_val = [key, str(val)] if use_keys else [str(val)]
328
  pairs.append(key_val_sep.join(key_val))
329
  return pairs_sep.join(pairs)
330
 
331
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
332
- ret = self.process_dict(
333
  inputs,
334
  key_val_sep=self.key_val_seperator,
335
  pairs_sep=self.pairs_seperator,
336
  use_keys=self.use_keys_for_inputs,
337
  )
338
- return (ret, ret)
339
 
340
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
341
  target = self.process_dict(
 
9
  from .type_utils import isoftype
10
 
11
 
12
+ class TemplateFormatKeyError(KeyError):
13
+ def __init__(self, template, data, data_type, format_str, format_name):
14
+ keys = ", ".join(data.keys())
15
+ super().__init__(
16
+ f"Available {data_type}s are [{keys}] "
17
+ f"but {template.__class__.__name__}.{format_name} format requires a different ones: '{format_str}'"
18
+ )
19
+
20
+
21
  class Template(StreamInstanceOperator):
22
  """The role of template is to take the fields of every instance and verbalize it.
23
 
 
35
  postprocessors: List[str] = NonPositionalField(
36
  default_factory=lambda: ["processors.to_string_stripped"]
37
  )
38
+ instruction: str = NonPositionalField(default="")
39
+ target_prefix: str = NonPositionalField(default="")
40
+ title_fields: List[str] = NonPositionalField(default_factory=list)
41
+
42
+ def inputs_to_instruction_and_target_prefix(self, inputs):
43
+ instruction = self.apply_formatting(
44
+ inputs, "input", self.instruction, "instruction", serialize=True
45
+ )
46
+ target_prefix = self.apply_formatting(
47
+ inputs, "input", self.target_prefix, "target_prefix", serialize=True
48
+ )
49
+ return instruction, target_prefix
50
 
51
  def process(
52
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
62
  inputs = instance.get("inputs")
63
  outputs = instance.get("outputs")
64
 
65
+ self.set_titles(inputs)
66
+
67
+ source = self.inputs_to_source(inputs)
68
+ instruction, target_prefix = self.inputs_to_instruction_and_target_prefix(
69
+ inputs
70
+ )
71
  target, references = self.outputs_to_target_and_references(outputs)
72
 
73
  return {
 
76
  "target": target,
77
  "references": references,
78
  "instruction": instruction,
79
+ "target_prefix": target_prefix,
80
  }
81
 
82
  @abstractmethod
83
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
84
  pass
85
 
86
+ def set_titles(self, data):
87
+ for field in self.title_fields:
88
+ data[field] = data[field].title()
89
+
90
  @abstractmethod
91
  def outputs_to_target_and_references(
92
  self, outputs: Dict[str, object]
 
96
  def get_postprocessors(self) -> List[str]:
97
  return self.postprocessors
98
 
99
+ def serialize_data(self, data):
100
+ return {
101
+ k: ", ".join(str(t) for t in v) if isinstance(v, list) else v
102
+ for k, v in data.items()
103
+ }
104
+
105
+ def apply_formatting(
106
+ self, data, data_type, format_str, format_name, serialize=False
107
+ ) -> str:
108
+ if serialize:
109
+ data = self.serialize_data(data)
110
+ try:
111
+ return format_str.format(**data)
112
+ except KeyError as e:
113
+ raise TemplateFormatKeyError(
114
+ self, data, data_type, format_str, format_name
115
+ ) from e
116
+
117
 
118
  class InputOutputTemplate(Template):
119
  """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
 
124
  input_format: str = None
125
  output_format: str = None
126
 
 
 
 
 
127
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
128
+ return self.apply_formatting(
129
+ inputs, "input", self.input_format, "input_format", serialize=True
130
+ )
 
 
 
 
 
 
 
131
 
132
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
133
+ target = self.apply_formatting(
134
+ outputs, "output", self.output_format, "output_format", serialize=True
135
+ )
 
 
 
 
136
  references = [target]
137
  return target, references
138
 
 
141
  reference: str
142
 
143
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
144
+ target = self.apply_formatting(
145
+ outputs, "output", self.output_format, "output_format", serialize=True
146
+ )
147
+ reference = self.apply_formatting(
148
+ outputs, "output", self.reference, "reference", serialize=True
149
+ )
150
+ return target, [reference]
 
 
 
 
 
 
151
 
152
 
153
  class MultipleChoiceTemplate(Template):
 
160
  choices_seperator: str = ", "
161
  source_choice_format: str = "{choice_numeral}. {choice_text}"
162
  target_choice_format: str = "{choice_numeral}"
 
163
  enumerator: str = "capitals"
164
 
165
  def prepare(self):
 
194
  "XX",
195
  ]
196
 
197
+ def inputs_to_choices(self, data: Dict[str, object], choice_format: str) -> str:
198
  choices = data[self.choices_field]
199
  enumrated_choices = []
200
  for i, choice in enumerate(choices):
 
206
  )
207
  return enumrated_choices
208
 
209
+ def inputs_to_numerals(self, inputs: Dict[str, object]) -> Tuple[str, str]:
210
+ return self.inputs_to_choices(inputs, "{choice_numeral}")
211
+
212
+ def prepare_multiple_choice_inputs(
213
+ self, inputs: Dict[str, object]
214
+ ) -> Dict[str, object]:
215
+ choices = self.inputs_to_choices(inputs, self.source_choice_format)
216
+ return {
217
+ "numerals": self.inputs_to_numerals(inputs),
218
  **inputs,
219
  self.choices_field: self.choices_seperator.join(choices),
220
  }
221
+
222
+ def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
223
+ inputs = self.prepare_multiple_choice_inputs(inputs)
224
+ return self.apply_formatting(
225
+ inputs, "input", self.input_format, "input_format", serialize=True
226
+ )
227
+
228
+ def inputs_to_instruction_and_target_prefix(self, inputs):
229
+ inputs = self.prepare_multiple_choice_inputs(inputs)
230
+ return super().inputs_to_instruction_and_target_prefix(inputs)
231
 
232
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
233
  target = outputs[self.target_field]
 
240
  f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {outputs[self.choices_field]}"
241
  ) from e
242
 
243
+ choices = self.inputs_to_choices(outputs, self.target_choice_format)
244
 
245
  try:
246
  target = choices[target]
 
256
  ) -> Dict[str, Any]:
257
  result = super().process(instance, stream_name)
258
  if "options" not in result["outputs"]:
259
+ result["outputs"]["options"] = self.inputs_to_choices(
260
  instance["outputs"], self.target_choice_format
261
  )
262
  return result
 
289
  no_answer: str = "No"
290
 
291
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
292
+ return self.apply_formatting(
293
+ inputs, "input", self.input_format, "input_format", serialize=True
294
+ )
 
 
 
 
 
 
 
 
 
295
 
296
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
297
  try:
 
337
  use_keys_for_outputs: bool = False
338
 
339
  def process_dict(
340
+ self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
341
  ) -> str:
342
+ data = self.serialize_data(data)
 
 
 
343
  pairs = []
344
+ for key, val in data.items():
345
  key_val = [key, str(val)] if use_keys else [str(val)]
346
  pairs.append(key_val_sep.join(key_val))
347
  return pairs_sep.join(pairs)
348
 
349
  def inputs_to_source(self, inputs: Dict[str, object]) -> Tuple[str, str]:
350
+ return self.process_dict(
351
  inputs,
352
  key_val_sep=self.key_val_seperator,
353
  pairs_sep=self.pairs_seperator,
354
  use_keys=self.use_keys_for_inputs,
355
  )
 
356
 
357
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
358
  target = self.process_dict(