Elron commited on
Commit
6de46af
1 Parent(s): 2ec6f71

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +34 -1
templates.py CHANGED
@@ -1,4 +1,5 @@
1
  from abc import ABC, abstractmethod
 
2
  from typing import Any, Dict, List, Union
3
 
4
  from .artifact import Artifact
@@ -149,6 +150,7 @@ class RenderTemplatedICL(RenderAutoFormatTemplate):
149
  class InputOutputTemplate(Template):
150
  input_format: str = None
151
  output_format: str = None
 
152
 
153
  def process_template(self, template: str, data: Dict[str, object]) -> str:
154
  return template.format(**data)
@@ -170,7 +172,7 @@ class InputOutputTemplate(Template):
170
  )
171
 
172
  def get_postprocessors(self) -> List[str]:
173
- return ["to_string"]
174
 
175
 
176
  class OutputQuantizingTemplate(InputOutputTemplate):
@@ -183,6 +185,37 @@ class OutputQuantizingTemplate(InputOutputTemplate):
183
  return super().process_outputs(quantized_outputs)
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  class AutoInputOutputTemplate(InputOutputTemplate):
187
  def infer_input_format(self, inputs):
188
  input_format = ""
 
1
  from abc import ABC, abstractmethod
2
+ from dataclasses import field
3
  from typing import Any, Dict, List, Union
4
 
5
  from .artifact import Artifact
 
150
  class InputOutputTemplate(Template):
151
  input_format: str = None
152
  output_format: str = None
153
+ postprocessors: List[str] = field(default_factory=lambda: ["processors.to_string"])
154
 
155
  def process_template(self, template: str, data: Dict[str, object]) -> str:
156
  return template.format(**data)
 
172
  )
173
 
174
  def get_postprocessors(self) -> List[str]:
175
+ return self.postprocessors
176
 
177
 
178
  class OutputQuantizingTemplate(InputOutputTemplate):
 
185
  return super().process_outputs(quantized_outputs)
186
 
187
 
188
+ class SpanLabelingTemplate(InputOutputTemplate):
189
+ spans_starts_field: str = "spans_starts"
190
+ spans_ends_field: str = "spans_ends"
191
+ text_field: str = "text"
192
+ labels_field: str = "labels"
193
+ span_label_format: str = "{span}: {label}"
194
+ postprocessors = ["processors.to_span_label_pairs"]
195
+
196
+ def process_outputs(self, outputs: Dict[str, object]) -> Dict[str, object]:
197
+ spans_starts = outputs[self.spans_starts_field]
198
+ spans_ends = outputs[self.spans_ends_field]
199
+ text = outputs[self.text_field]
200
+ labels = outputs[self.labels_field]
201
+
202
+ spans = []
203
+ for span_start, span_end, label in zip(spans_starts, spans_ends, labels):
204
+ spans.append((span_start, span_end, label))
205
+
206
+ spans.sort(key=lambda span: span[0])
207
+
208
+ text_spans = []
209
+ for span in spans:
210
+ text_spans.append(text[span[0] : span[1]])
211
+
212
+ targets = []
213
+ for span, label in zip(text_spans, labels):
214
+ targets.append(self.span_label_format.format(span=span, label=label))
215
+
216
+ return super().process_outputs({"spans_and_labels": targets})
217
+
218
+
219
  class AutoInputOutputTemplate(InputOutputTemplate):
220
  def infer_input_format(self, inputs):
221
  input_format = ""