from abc import ABC
from typing import Any, Dict, List, Optional

from .dataclass import InternalField
from .formats import Format, ICLFormat
from .instructions import Instruction
from .operator import Operator, SequentialOperator, StreamInstanceOperator
from .random_utils import get_random
from .templates import Template


class Renderer(ABC):
    """Common base class for the rendering operators in this module.

    It declares no behaviour of its own; subclasses combine it with an
    operator base class.
    """

    pass
    # Disabled abstract hook, kept for reference:
    # @abstractmethod
    # def get_postprocessors(self) -> List[str]:
    #     pass


class RenderTemplate(Renderer, StreamInstanceOperator):
    """Render an instance's ``inputs``/``outputs`` through a template.

    The rendered values are written back into the instance under the keys
    ``"source"``, ``"target"`` and ``"references"``.
    """

    # Template whose process_inputs / process_outputs produce the rendering.
    template: Template
    # For multi-reference templates: pick the target at random instead of
    # taking the first reference.
    random_reference: bool = False
    # Return instances that already look rendered without touching them.
    skip_rendered_instance: bool = True

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Render ``instance`` in place and return it.

        Args:
            instance: Mapping expected to hold ``"inputs"`` and ``"outputs"``.
            stream_name: Accepted for the operator interface; not used here.

        Returns:
            The same ``instance`` object, updated with ``"source"``,
            ``"target"`` and ``"references"``. When ``skip_rendered_instance``
            is set and the instance has those three keys but neither
            ``"inputs"`` nor ``"outputs"``, it is returned unchanged.

        Raises:
            KeyError: If ``"inputs"`` or ``"outputs"`` is missing and the
                instance was not skipped.
            AssertionError: If the template is multi-reference but its
                processed outputs are not a list.
            ValueError: If the template is multi-reference and produced no
                references.
        """
        if self.skip_rendered_instance:
            already_rendered = (
                "inputs" not in instance
                and "outputs" not in instance
                and "source" in instance
                and "target" in instance
                and "references" in instance
            )
            if already_rendered:
                return instance

        source = self.template.process_inputs(instance["inputs"])
        targets = self.template.process_outputs(instance["outputs"])

        if self.template.is_multi_reference:
            assert isinstance(targets, list), f"{targets} must be a list"
            references = targets
            # Guard before choosing a target so that both selection modes
            # report an empty reference list with the same ValueError; a random
            # choice from an empty list would otherwise fail with an unrelated
            # error.
            if not references:
                raise ValueError("No references found")
            if self.random_reference:
                # NOTE(review): get_random() presumably returns a shared
                # random.Random-like object -- confirm in random_utils.
                target = get_random().choice(references)
            else:
                target = references[0]
        else:
            # Single-reference template: the one output is both the target
            # and the only reference.
            references = [targets]
            target = targets

        instance.update(
            {
                "source": source,
                "target": target,
                "references": references,
            }
        )
        return instance


class RenderDemonstrations(RenderTemplate):
    """Apply :class:`RenderTemplate` rendering to each demo of an instance."""

    # Key in the instance under which the list of demo instances is stored.
    demos_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Render every demo found under ``demos_field`` and store the result.

        Args:
            instance: Mapping that may hold a list of demo instances under
                ``self.demos_field``.
            stream_name: Accepted for the operator interface; not used here.

        Returns:
            The same ``instance`` object, with ``self.demos_field`` set to the
            list of rendered demos. If the field was absent, an empty list is
            stored under it.
        """
        # Bind the parent implementation up front: zero-argument super() is
        # not reliable inside a comprehension on Python versions before 3.12.
        render_demo = super().process
        demos = instance.get(self.demos_field, [])
        instance[self.demos_field] = [render_demo(demo) for demo in demos]
        return instance


class RenderInstruction(Renderer, StreamInstanceOperator):
    """Attach the rendered instruction text to each instance."""

    # Callable instruction; its return value is stored on the instance.
    instruction: Instruction

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Set ``instance["instruction"]`` and return the instance.

        The value is the result of calling ``self.instruction``, or an empty
        string when ``self.instruction`` is ``None``.
        """
        instance["instruction"] = (
            self.instruction() if self.instruction is not None else ""
        )
        return instance


class RenderFormat(Renderer, StreamInstanceOperator):
    """Build the final ``"source"`` of an instance using a format object."""

    # Format whose ``format`` method produces the final source.
    format: Format
    # Key under which demo instances are stored on the instance, if any.
    demos_field: str = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Overwrite ``instance["source"]`` with the formatted source.

        The entry under ``self.demos_field`` is removed from the instance.
        When it held a value other than ``None``, that value is passed to the
        format as ``demos_instances``; otherwise the format is called with the
        instance alone.
        """
        demos_instances = instance.pop(self.demos_field, None)
        if demos_instances is None:
            instance["source"] = self.format.format(instance)
        else:
            instance["source"] = self.format.format(
                instance, demos_instances=demos_instances
            )
        return instance


class StandardRenderer(Renderer, SequentialOperator):
    """Sequential pipeline: template, demonstrations, instruction, format."""

    template: Template
    instruction: Instruction = None
    demos_field: str = None
    format: ICLFormat = None

    # Populated by prepare(); not meant to be supplied by the user.
    steps: List[Operator] = InternalField(default_factory=list)

    def prepare(self):
        """Assemble the rendering steps from this renderer's configuration."""
        self.steps = [
            RenderTemplate(template=self.template),
            RenderDemonstrations(template=self.template, demos_field=self.demos_field),
            RenderInstruction(instruction=self.instruction),
            RenderFormat(format=self.format, demos_field=self.demos_field),
        ]

    def get_postprocessors(self):
        """Return the postprocessors reported by the template."""
        return self.template.get_postprocessors()