Elron committed on
Commit
dae2dfd
1 Parent(s): 48f5034

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +203 -0
templates.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .text_utils import split_words
2
+ from .artifact import Artifact
3
+ from .operator import StreamInstanceOperator, InstanceOperatorWithGlobalAccess
4
+ from .instructions import Instruction
5
+
6
+ import random
7
+ from typing import Dict, Any, List
8
+ from abc import ABC, abstractmethod
9
+
10
+
11
class Renderer(ABC):
    """Abstract interface for objects that can report their postprocessors."""

    @abstractmethod
    def get_postprocessors(self) -> List[str]:
        """Return the postprocessor identifiers as a list of strings.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns None.
        """
15
+
16
+
17
class Template(Artifact):
    """Base artifact declaring the three hooks every template provides.

    NOTE(review): the hooks are marked with ``@abstractmethod``; that is only
    enforced at instantiation time if ``Artifact`` uses an ABC-compatible
    metaclass -- confirm against ``Artifact``.
    """

    @abstractmethod
    def process_inputs(self, inputs: Dict[str, object]) -> Dict[str, object]:
        """Convert the ``inputs`` mapping of an instance into its rendered form.

        NOTE(review): the annotation promises a dict, but renderers in this
        module concatenate the result with strings -- confirm the intended
        return type.
        """

    @abstractmethod
    def process_outputs(self, outputs: Dict[str, object]) -> Dict[str, object]:
        """Convert an ``outputs`` mapping into its rendered form.

        NOTE(review): same return-type caveat as ``process_inputs``.
        """

    @abstractmethod
    def get_postprocessors(self) -> List[str]:
        """Return the postprocessor identifiers as a list of strings."""
29
+
30
+
31
class RenderFormatTemplate(Renderer, StreamInstanceOperator):
    """Instance operator that renders ``inputs``/``outputs`` through a template.

    Each instance must carry an ``"inputs"`` mapping and an ``"outputs"``
    mapping. The result keeps every other key of the instance and adds
    ``"source"``, ``"target"`` and ``"references"``.
    """

    # Template used for rendering; must be set to a ``Template`` instance.
    template: Template = None
    # When True the target is picked at random among the rendered references;
    # otherwise the first reference is used.
    random_reference: bool = False

    def verify(self):
        """Check that a template was configured and has the right type."""
        # The None check must come first: ``isinstance(None, Template)`` is
        # False, so in the opposite order a missing template would be reported
        # with the "must be an instance" message and this one could never fire.
        assert self.template is not None, "Template must be specified"
        assert isinstance(self.template, Template), "Template must be an instance of Template"

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Render a single stream instance; ``stream_name`` is not used."""
        return self.render(instance)

    def render(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Render ``instance`` into source, target and references.

        Only the first key of ``outputs`` is used. Its value may be a single
        target or a list of targets; each target is rendered separately through
        ``template.process_outputs`` to build the references.

        The given ``instance`` is not modified, so the same instance can be
        rendered more than once.

        Raises:
            KeyError: if ``instance`` lacks ``"inputs"`` or ``"outputs"``.
            ValueError: if ``outputs`` is empty or yields no references.
        """
        inputs = instance["inputs"]
        outputs = instance["outputs"]

        source = self.template.process_inputs(inputs)

        try:
            key, targets = next(iter(outputs.items()))
        except StopIteration:
            # A bare StopIteration escaping from here would be confusing (and
            # turns into RuntimeError inside a generator), so report it clearly.
            raise ValueError("No outputs found") from None
        if not isinstance(targets, list):
            targets = [targets]

        references = [self.template.process_outputs({key: target}) for target in targets]

        # Validate before choosing so both selection modes fail the same way;
        # random.choice on an empty list would otherwise raise IndexError.
        if not references:
            raise ValueError("No references found")

        if self.random_reference:
            target = random.choice(references)
        else:
            target = references[0]

        passthrough = {
            name: value
            for name, value in instance.items()
            if name not in ("inputs", "outputs")
        }

        return {
            **passthrough,
            "source": source,
            "target": target,
            "references": references,
        }

    def get_postprocessors(self) -> List[str]:
        """Return the postprocessors declared by the configured template."""
        return self.template.get_postprocessors()
70
+
71
+
72
class RenderAutoFormatTemplate(RenderFormatTemplate):
    """Renderer whose template can infer missing formats from the data.

    ``prepare`` replaces the configured template with an
    ``AutoInputOutputTemplate``; ``render`` lets that template infer any
    missing format from the instance being rendered before delegating to the
    parent renderer.
    """

    def prepare(self):
        """Swap the configured template for an ``AutoInputOutputTemplate``.

        Raises:
            ValueError: if a template is set but is not an
                ``InputOutputTemplate``.
        """
        configured = self.template

        if configured is None:
            self.template = AutoInputOutputTemplate()
            return

        if not isinstance(configured, InputOutputTemplate):
            raise ValueError(
                f"Template must be an instance of InputOutputTemplate or AutoInputOutputTemplate, got {type(self.template)}"
            )

        # Carry over whatever formats were already given; missing ones stay
        # None and are inferred later in ``render``.
        self.template = AutoInputOutputTemplate(
            input_format=configured.input_format,
            output_format=configured.output_format,
        )

    def render(self, instance: Dict[str, object]) -> Dict[str, object]:
        """Infer missing formats if needed, then render via the parent class."""
        template = self.template
        if not template.is_complete():
            template.infer_missing(instance["inputs"], instance["outputs"])

        # The parent receives a fresh outer dict and a copy of the inputs
        # instead of the caller's own objects.
        inputs_copy = dict(instance["inputs"].items())
        rendered_input = {**instance, "inputs": inputs_copy}

        return super().render(rendered_input)
93
+
94
+
95
class CharacterSizeLimiter(Artifact):
    """Size check that accepts a text while its length is within ``limit``."""

    # Maximum accepted ``len(text)``; the bound is inclusive.
    limit: int = 1000

    def check(self, text: str) -> bool:
        """Return True when ``text`` is no longer than ``limit`` characters."""
        size = len(text)
        return size <= self.limit
100
+
101
+
102
class RenderTemplatedICL(RenderAutoFormatTemplate):
    """Renderer that assembles an in-context-learning style prompt.

    The ``"source"`` of the result is built, in order, from: the optional
    instruction, every accepted demonstration (its input and its target), and
    finally the instance's own input followed by the output prefix.
    """

    # Called with no arguments to obtain the instruction text; skipped if None.
    instruction: Instruction = None
    # Text placed before every rendered input.
    input_prefix: str = "Input: "
    # Text placed before every target (and at the very end of the prompt).
    output_prefix: str = "Output: "
    # Text placed before the instruction.
    instruction_prefix: str = ""
    # Instance key under which the demonstrations are stored.
    demos_field: str = None
    # Optional object exposing ``check(text)``; demos failing it are skipped.
    size_limiter: Artifact = None
    # Text between an input and its output prefix.
    input_output_separator: str = "\n"
    # Text appended after the instruction and after each demonstration.
    demo_separator: str = "\n\n"
    # Demonstrations taken from the first rendered instance and reused for all
    # later ones.
    demos_cache = None

    def verify(self):
        """Require that no demonstrations have been cached yet."""
        assert self.demos_cache is None

    def render(self, instance: Dict[str, object]) -> Dict[str, object]:
        """Render ``instance`` and replace its source with the full prompt.

        The demonstrations are removed from ``instance`` in place. Only the
        ones found on the first call are kept; later calls reuse that cache.
        """
        if self.demos_cache is None:
            self.demos_cache = instance.pop(self.demos_field, [])
        else:
            instance.pop(self.demos_field, None)

        # Bind the parent implementation once; it is used for the instance
        # itself and for every demonstration.
        render_single = super().render

        example = render_single(instance)
        query = self.input_prefix + example["source"] + self.input_output_separator + self.output_prefix

        prompt = ""
        if self.instruction is not None:
            prompt = prompt + (self.instruction_prefix + self.instruction() + self.demo_separator)

        for demo_instance in self.demos_cache:
            demo = render_single(demo_instance)
            demo_text = (
                self.input_prefix
                + demo["source"]
                + self.input_output_separator
                + self.output_prefix
                + demo["target"]
                + self.demo_separator
            )

            # A demonstration is kept only if the prompt, including the final
            # query and the expected target, still passes the size limiter.
            limiter = self.size_limiter
            fits = limiter is None or limiter.check(prompt + demo_text + query + example["target"])
            if fits:
                prompt += demo_text

        return {
            **example,
            "source": prompt + query,
        }
154
+
155
+
156
class InputOutputTemplate(Template):
    """Template that renders inputs and outputs with ``str.format`` patterns."""

    # Format string filled with the inputs mapping (keys become placeholders).
    input_format: str = None
    # Format string filled with the outputs mapping.
    output_format: str = None

    def process_inputs(self, inputs: Dict[str, object]) -> str:
        """Return ``input_format`` formatted with ``inputs`` as keyword arguments.

        The result is a string (the value of ``str.format``), which is why the
        return annotation is ``str`` rather than a dict.

        Raises:
            AttributeError: if ``input_format`` is still None.
            KeyError: if the format names a placeholder missing from ``inputs``.
        """
        return self.input_format.format(**inputs)

    def process_outputs(self, outputs: Dict[str, object]) -> str:
        """Return ``output_format`` formatted with ``outputs`` as keyword arguments.

        Raises:
            AttributeError: if ``output_format`` is still None.
            KeyError: if the format names a placeholder missing from ``outputs``.
        """
        return self.output_format.format(**outputs)

    def get_postprocessors(self) -> List[str]:
        """Return the single postprocessor ``"to_string"``."""
        return ["to_string"]
168
+
169
+
170
class AutoInputOutputTemplate(InputOutputTemplate):
    """Input/output template able to derive its formats from example data."""

    def infer_input_format(self, inputs):
        """Build ``input_format`` with one ``Label: {key}`` line per input key.

        The label is made from the pieces returned by ``split_words(key)``
        (single-space pieces dropped), each lowercased then capitalized, and
        joined with spaces.
        """
        lines = []
        for key in inputs.keys():
            words = [piece.lower().capitalize() for piece in split_words(key) if piece != " "]
            label = " ".join(words)
            lines.append(label + ": " + "{" + key + "}" + "\n")
        self.input_format = "".join(lines)

    def infer_output_format(self, outputs):
        """Set ``output_format`` to a placeholder for the first output key."""
        first_key = next(iter(outputs.keys()))
        self.output_format = "{" + first_key + "}"

    def infer_missing(self, inputs, outputs):
        """Infer only the formats that are still None."""
        input_missing = self.input_format is None
        if input_missing:
            self.infer_input_format(inputs)

        output_missing = self.output_format is None
        if output_missing:
            self.infer_output_format(outputs)

    def is_complete(self):
        """Return True once both formats are set."""
        return not (self.input_format is None or self.output_format is None)
189
+
190
+
191
+ from .collections import ListCollection
192
+
193
+
194
class TemplatesList(ListCollection):
    """List collection whose items are all expected to be templates."""

    def verify(self):
        """Assert that every entry of ``self.items`` is a ``Template``."""
        assert all(isinstance(candidate, Template) for candidate in self.items)
198
+
199
+
200
class TemplatesDict(Dict):
    """Dictionary whose values are all expected to be templates."""

    def verify(self):
        """Assert that every value stored in the dict is a ``Template``."""
        assert all(isinstance(candidate, Template) for candidate in self.values())