"""Program Synthesis dataset from dreamcoder. https://github.com/ellisk42/ec"""
from random import choice, shuffle

import datasets

from dreamcoder.domains.text.makeTextTasks import makeTasks as textMakeTasks
# NOTE(review): `main` in dreamcoder.domains.list.main is called below with no
# arguments and is expected to return an iterable of tasks (objects exposing
# `.request`, `.name`, `.examples`). Confirm it is a task factory and not a
# training entry point that requires arguments.
from dreamcoder.domains.list.main import main as listMakeTasks

_DESCRIPTION = """\
Generated program synthesis datasets used to train dreamcoder.
"""

_FEATURES = datasets.Features(
    {
        "description": datasets.Value("string"),
        "input": datasets.Value("string"),
        "output": datasets.Value("string"),
        "types": datasets.Value("string"),
    }
)

_HOMEPAGE = "https://github.com/ellisk42/ec"

_LICENSE = "MIT License"

# Number of examples yielded by `_generate_examples` for the TRAIN split.
_MAX_STEPS = 10


class infIterator:
    """Endlessly yield shuffled rows built from a task factory.

    ``make_mthd`` is called with no arguments whenever a fresh batch of rows
    is needed. Each (input, output) example of each returned task becomes one
    row dict with the keys ``input``, ``output``, ``types`` and
    ``description``. All values are stringified to match ``_FEATURES``.
    """

    def __init__(self, make_mthd):
        self.make_mthd = make_mthd
        self.rows = []
        # None means "not populated yet"; the first step() triggers reset().
        self.i = None

    def reset(self):
        """Regenerate tasks, flatten their examples into rows, and shuffle them.

        Raises:
            ValueError: if the factory produced no examples at all. Without
                this check, sampling would fail with an opaque IndexError.
        """
        rows = []
        for task in self.make_mthd():
            base = {
                "types": str(task.request),
                "description": task.name,
            }
            for inp, outp in task.examples:
                rows.append(dict(input=str(inp), output=str(outp), **base))
        if not rows:
            raise ValueError(
                f"Task factory {self.make_mthd!r} produced no examples."
            )
        shuffle(rows)
        self.rows = rows
        self.i = 0

    def step(self):
        """Return the next row, regenerating the buffer when it is exhausted."""
        # Regenerate lazily: only when a row is actually requested. This avoids
        # an expensive task rebuild right after the last row is handed out.
        if self.i is None or self.i >= len(self.rows):
            self.reset()
        row = self.rows[self.i]
        self.i += 1
        return row


class ProgramSynthesis(datasets.GeneratorBasedBuilder):
    """Program Synthesis dataset from dreamcoder."""

    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="text", version=VERSION, description="Text tasks."),
        datasets.BuilderConfig(name="list", version=VERSION, description="List tasks."),
        datasets.BuilderConfig(name="all", version=VERSION, description="All tasks at once."),
    ]

    DEFAULT_CONFIG_NAME = "all"

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=("input", "output"),
            homepage=_HOMEPAGE,
            license=_LICENSE,
        )

    def _split_generators(self, dl_manager):
        # Everything is generated on the fly; there is nothing to download.
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
            ),
        ]

    def _generate_examples(self):
        """Yield ``_MAX_STEPS`` (key, row) pairs.

        For the "all" config each row's source ("text" or "list") is picked
        uniformly at random. Otherwise the config name selects the source.
        """
        task_samples = {
            "text": infIterator(textMakeTasks),
            "list": infIterator(listMakeTasks),
        }
        # random.choice needs an indexable sequence; dict_keys is not
        # subscriptable in Python 3, so materialise the names once.
        task_names = list(task_samples)
        for key in range(_MAX_STEPS):
            if self.config.name == "all":
                dataset_type = choice(task_names)
            else:
                dataset_type = self.config.name
            yield key, task_samples[dataset_type].step()