File size: 3,980 Bytes
8435bb6 00a2077 8435bb6 00a2077 1f634e4 00a2077 8435bb6 1f634e4 8435bb6 00a2077 8435bb6 00a2077 8435bb6 00a2077 8435bb6 00a2077 1f634e4 00a2077 8435bb6 071352b 00a2077 8435bb6 00a2077 8435bb6 00a2077 8435bb6 00a2077 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from abc import ABC
from typing import Any, Dict, List, Optional
from .dataclass import InternalField
from .formats import Format, ICLFormat
from .instructions import Instruction
from .operator import Operator, SequentialOperator, StreamInstanceOperator
from .random_utils import get_random
from .templates import Template
class Renderer(ABC):
    """Common base class for the rendering operators defined in this module.

    It currently declares no abstract methods, so it acts purely as a marker
    type (and can be instantiated directly).
    """

    # NOTE(review): an abstract ``get_postprocessors(self) -> List[str]`` was
    # sketched here but left disabled; StandardRenderer implements a
    # ``get_postprocessors`` method of its own.
class RenderTemplate(Renderer, StreamInstanceOperator):
    """Render an instance's ``inputs``/``outputs`` through ``template``.

    Writes ``source``, ``target`` and ``references`` into the instance (in
    place) and returns it.

    Attributes:
        template: Template whose ``process_inputs``/``process_outputs`` produce
            the source text and the target(s).
        random_reference: For multi-reference templates, pick the target at
            random among the references instead of taking the first one.
        skip_rendered_instance: Return the instance untouched when it looks
            already rendered (no ``inputs``/``outputs``, but ``source``,
            ``target`` and ``references`` present).
    """

    template: Template
    random_reference: bool = False
    skip_rendered_instance: bool = True

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Render ``instance`` and return it.

        Raises:
            KeyError: If ``inputs`` or ``outputs`` is missing from an instance
                that is not considered already rendered.
            ValueError: If a multi-reference template yields no references
                (regardless of ``random_reference``).
        """
        if self.skip_rendered_instance:
            already_rendered = (
                "inputs" not in instance
                and "outputs" not in instance
                and "source" in instance
                and "target" in instance
                and "references" in instance
            )
            if already_rendered:
                return instance

        inputs = instance["inputs"]
        outputs = instance["outputs"]

        source = self.template.process_inputs(inputs)
        targets = self.template.process_outputs(outputs)

        if self.template.is_multi_reference:
            assert isinstance(targets, list), f"{targets} must be a list"
            references = targets
            # Validate before choosing: previously the random branch skipped
            # this check and surfaced an opaque IndexError from choice([]).
            if len(references) == 0:
                raise ValueError("No references found")
            if self.random_reference:
                target = get_random().choice(references)
            else:
                target = references[0]
        else:
            # Single-reference template: the one target is also the only reference.
            references = [targets]
            target = targets

        instance.update(
            {
                "source": source,
                "target": target,
                "references": references,
            }
        )
        return instance
class RenderDemonstrations(RenderTemplate):
    """Render every demonstration stored under ``demos_field``.

    Each demonstration is passed through the inherited RenderTemplate logic,
    and the rendered list is written back under ``demos_field``.
    """

    demos_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # Bind the parent implementation up front: zero-argument super() is not
        # usable inside a comprehension scope on Python versions before 3.12.
        render_demo = super().process
        # NOTE(review): when ``demos_field`` is absent, an empty list is stored
        # under it (the key is created).
        instance[self.demos_field] = [
            render_demo(demo) for demo in instance.get(self.demos_field, [])
        ]
        return instance
class RenderInstruction(Renderer, StreamInstanceOperator):
    """Store the rendered instruction text under the instance's ``instruction`` key.

    Attributes:
        instruction: Callable instruction object. ``None`` means "no
            instruction", in which case an empty string is stored.
            StandardRenderer passes ``None`` by default, so the field is
            optional and defaults to ``None``.
    """

    instruction: Optional[Instruction] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance["instruction"] = (
            "" if self.instruction is None else self.instruction()
        )
        return instance
class RenderFormat(Renderer, StreamInstanceOperator):
    """Produce the final ``source`` string by applying ``format`` to the instance.

    Any demonstrations stored under ``demos_field`` are removed from the
    instance and handed to the formatter as ``demos_instances``. When there are
    none (or the key is absent), the formatter is called with the instance only.
    """

    format: Format
    demos_field: str = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # Pop first so the demonstrations never remain in the instance passed on.
        demos = instance.pop(self.demos_field, None)
        extra_kwargs = {} if demos is None else {"demos_instances": demos}
        instance["source"] = self.format.format(instance, **extra_kwargs)
        return instance
class StandardRenderer(Renderer, SequentialOperator):
    """Sequential pipeline that renders an instance end to end.

    Steps, in order: render the instance through ``template``, render its
    demonstrations, attach the instruction, then apply ``format``.
    """

    template: Template
    instruction: Instruction = None
    demos_field: str = None
    format: ICLFormat = None

    steps: List[Operator] = InternalField(default_factory=list)

    def prepare(self):
        """Build the fixed four-step rendering pipeline from this renderer's fields."""
        template = self.template
        demos_field = self.demos_field
        self.steps = [
            RenderTemplate(template=template),
            RenderDemonstrations(template=template, demos_field=demos_field),
            RenderInstruction(instruction=self.instruction),
            RenderFormat(format=self.format, demos_field=demos_field),
        ]

    def get_postprocessors(self):
        """Delegate to the template's post-processor list."""
        return self.template.get_postprocessors()
|