Elron commited on
Commit
c77cd1f
1 Parent(s): f60252a

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +57 -15
templates.py CHANGED
@@ -1,6 +1,7 @@
 
1
  from abc import ABC, abstractmethod
2
  from dataclasses import field
3
- from typing import Any, Dict, List, Union
4
 
5
  from .artifact import Artifact
6
  from .dataclass import NonPositionalField
@@ -101,7 +102,8 @@ class CharacterSizeLimiter(Artifact):
101
  class RenderTemplatedICL(RenderAutoFormatTemplate):
102
  instruction: Instruction = None
103
  input_prefix: str = "Input: "
104
- output_prefix: str = "Output: "
 
105
  instruction_prefix: str = ""
106
  demos_field: str = None
107
  size_limiter: Artifact = None
@@ -127,6 +129,7 @@ class RenderTemplatedICL(RenderAutoFormatTemplate):
127
  + demo_example["source"]
128
  + self.input_output_separator
129
  + self.output_prefix
 
130
  + demo_example["target"]
131
  + self.demo_separator
132
  )
@@ -151,6 +154,7 @@ class InputOutputTemplate(Template):
151
  postprocessors: List[str] = field(default_factory=lambda: ["processors.to_string"])
152
 
153
  def process_template(self, template: str, data: Dict[str, object]) -> str:
 
154
  return template.format(**data)
155
 
156
  def process_inputs(self, inputs: Dict[str, object]) -> str:
@@ -198,14 +202,19 @@ class MultiLabelTemplate(InputOutputTemplate):
198
  return super().process_outputs({"labels": labels_str})
199
 
200
 
201
- class SpanLabelingTemplate(MultiLabelTemplate):
 
 
 
 
 
 
202
  spans_starts_field: str = "spans_starts"
203
  spans_ends_field: str = "spans_ends"
204
  text_field: str = "text"
205
- span_label_format: str = "{span}: {label}"
206
- postprocessors = ["processors.to_span_label_pairs"]
207
 
208
- def process_outputs(self, outputs: Dict[str, object]) -> Dict[str, object]:
209
  spans_starts = outputs[self.spans_starts_field]
210
  spans_ends = outputs[self.spans_ends_field]
211
  text = outputs[self.text_field]
@@ -213,19 +222,52 @@ class SpanLabelingTemplate(MultiLabelTemplate):
213
 
214
  spans = []
215
  for span_start, span_end, label in zip(spans_starts, spans_ends, labels):
216
- spans.append((span_start, span_end, label))
 
217
 
218
- spans.sort(key=lambda span: span[0])
 
 
219
 
220
- text_spans = []
221
- for span in spans:
222
- text_spans.append(text[span[0] : span[1]])
 
223
 
224
- targets = []
225
- for span, label in zip(text_spans, labels):
226
- targets.append(self.span_label_format.format(span=span, label=label))
227
 
228
- return super().process_outputs({"labels": targets})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
  class AutoInputOutputTemplate(InputOutputTemplate):
 
1
+ import json
2
  from abc import ABC, abstractmethod
3
  from dataclasses import field
4
+ from typing import Any, Dict, List, Optional, Union
5
 
6
  from .artifact import Artifact
7
  from .dataclass import NonPositionalField
 
102
  class RenderTemplatedICL(RenderAutoFormatTemplate):
103
  instruction: Instruction = None
104
  input_prefix: str = "Input: "
105
+ output_prefix: str = "Output:"
106
+ target_prefix: str = " "
107
  instruction_prefix: str = ""
108
  demos_field: str = None
109
  size_limiter: Artifact = None
 
129
  + demo_example["source"]
130
  + self.input_output_separator
131
  + self.output_prefix
132
+ + self.target_prefix
133
  + demo_example["target"]
134
  + self.demo_separator
135
  )
 
154
  postprocessors: List[str] = field(default_factory=lambda: ["processors.to_string"])
155
 
156
  def process_template(self, template: str, data: Dict[str, object]) -> str:
157
+ data = {k: ", ".join(v) if isinstance(v, list) else v for k, v in data.items()}
158
  return template.format(**data)
159
 
160
  def process_inputs(self, inputs: Dict[str, object]) -> str:
 
202
  return super().process_outputs({"labels": labels_str})
203
 
204
 
205
+ def escape_chars(s, chars_to_escape):
206
+ for char in chars_to_escape:
207
+ s = s.replace(char, f"\\{char}")
208
+ return s
209
+
210
+
211
+ class SpanLabelingBaseTemplate(MultiLabelTemplate):
212
  spans_starts_field: str = "spans_starts"
213
  spans_ends_field: str = "spans_ends"
214
  text_field: str = "text"
215
+ labels_support: list = None
 
216
 
217
+ def extract_span_label_pairs(self, outputs):
218
  spans_starts = outputs[self.spans_starts_field]
219
  spans_ends = outputs[self.spans_ends_field]
220
  text = outputs[self.text_field]
 
222
 
223
  spans = []
224
  for span_start, span_end, label in zip(spans_starts, spans_ends, labels):
225
+ if self.labels_support is None or label in self.labels_support:
226
+ spans.append((span_start, span_end, text[span_start:span_end], label))
227
 
228
+ for span in sorted(spans):
229
+ if self.labels_support is None or span[3] in self.labels_support:
230
+ yield span[2], span[3]
231
 
232
+ def process_outputs(self, outputs: Dict[str, object]) -> Dict[str, object]:
233
+ span_lables_pairs = self.extract_span_label_pairs(outputs)
234
+ targets = self.span_label_pairs_to_targets(span_lables_pairs)
235
+ return super().process_outputs({"labels": targets})
236
 
237
+ @abstractmethod
238
+ def span_label_pairs_to_targets(self, pairs):
239
+ pass
240
 
241
+
242
+ class SpanLabelingTemplate(SpanLabelingBaseTemplate):
243
+ span_label_format: str = "{span}: {label}"
244
+ escape_characters: List[str] = [":", ","]
245
+ postprocessors = ["processors.to_span_label_pairs"]
246
+
247
+ def span_label_pairs_to_targets(self, span_label_pairs):
248
+ targets = []
249
+ for span, label in span_label_pairs:
250
+ if self.escape_characters is not None:
251
+ span = escape_chars(span, self.escape_characters)
252
+ target = self.span_label_format.format(span=span, label=label)
253
+ targets.append(target)
254
+ return targets
255
+
256
+
257
+ class SpanLabelingJsonTemplate(SpanLabelingBaseTemplate):
258
+ postprocessors = ["processors.load_json", "processors.dict_of_lists_to_value_key_pairs"]
259
+
260
+ def span_label_pairs_to_targets(self, span_label_pairs):
261
+ groups = {}
262
+ for span, label in span_label_pairs:
263
+ if label not in groups:
264
+ groups[label] = list()
265
+ groups[label].append(span)
266
+ if len(groups) > 0:
267
+ targets = [json.dumps(groups)]
268
+ else:
269
+ targets = []
270
+ return targets
271
 
272
 
273
  class AutoInputOutputTemplate(InputOutputTemplate):