# to_delete / program_synthesis.py
# Fraser-Greenlee
# basics
# a3a8e6c
"""Program Synthesis dataset from dreamcoder. https://github.com/ellisk42/ec"""
from random import choice, shuffle
import datasets
from dreamcoder.domains.text.makeTextTasks import makeTasks as textMakeTasks
from dreamcoder.domains.list.main import main as listMakeTasks
# Human-readable summary passed to DatasetInfo.
_DESCRIPTION = "Generated program synthesis datasets used to train dreamcoder.\n"

# Every column of an example row is stored as a plain string.
_FEATURES = datasets.Features(
    {
        column: datasets.Value("string")
        for column in ("description", "input", "output", "types")
    }
)

_HOMEPAGE = "https://github.com/ellisk42/ec"

_LICENSE = "MIT License"

# Number of examples produced by one run of the example generator.
_MAX_STEPS = 10
class infIterator:
    """Endless, shuffled stream of example rows built from generated tasks.

    ``make_mthd`` is a zero-argument callable returning an iterable of task
    objects.  Each task must expose ``request``, ``name`` and ``examples``
    (an iterable of ``(input, output)`` pairs).  Every example becomes one
    row: a dict with the string keys ``input``, ``output``, ``types`` and
    ``description``.

    Rows are handed out one at a time by :meth:`step`.  Once every row has
    been handed out, the tasks are regenerated and reshuffled the next time a
    row is requested.
    """

    def __init__(self, make_mthd):
        """Store the task factory; no tasks are generated until first use."""
        self.make_mthd = make_mthd
        # Always defined, so the attribute exists even before the first reset.
        self.rows = []
        # ``None`` marks "never generated"; otherwise the index of the next row.
        self.i = None

    def reset(self):
        """Regenerate the tasks, flatten them into rows and shuffle them."""
        rows = []
        for task in self.make_mthd():
            base = {
                "types": str(task.request),
                "description": task.name,
            }
            for inp, outp in task.examples:
                rows.append(dict(input=str(inp), output=str(outp), **base))
        shuffle(rows)
        self.rows = rows
        self.i = 0

    def step(self):
        """Return the next row, regenerating the tasks when they run out.

        Raises:
            IndexError: if ``make_mthd`` yields no examples at all.
        """
        # Refill lazily: only regenerate when a row is actually requested,
        # instead of right after the last row was handed out.
        if self.i is None or self.i >= len(self.rows):
            self.reset()
        if not self.rows:
            raise IndexError("make_mthd produced no examples to iterate over")
        row = self.rows[self.i]
        self.i += 1
        return row

    def __iter__(self):
        """Return ``self``; the stream never ends on its own."""
        return self

    def __next__(self):
        """Iterator-protocol alias for :meth:`step`."""
        return self.step()
class ProgramSynthesis(datasets.GeneratorBasedBuilder):
    """Program Synthesis dataset from dreamcoder.

    Three configurations are offered: ``"text"``, ``"list"`` and ``"all"``
    (the default).  With ``"all"``, each example is drawn from one of the two
    task families chosen at random.
    """

    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="text", version=VERSION, description="Text tasks."),
        datasets.BuilderConfig(name="list", version=VERSION, description="List tasks."),
        datasets.BuilderConfig(name="all", version=VERSION, description="All tasks at once."),
    ]

    DEFAULT_CONFIG_NAME = "all"

    def _info(self):
        """Return the dataset metadata: description, schema, homepage, license."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=("input", "output"),
            homepage=_HOMEPAGE,
            license=_LICENSE,
        )

    def _split_generators(self, dl_manager):
        """Return a single TRAIN split.

        ``dl_manager`` is unused: the examples are generated locally rather
        than downloaded.
        """
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
            ),
        ]

    def _generate_examples(self):
        """Yield ``_MAX_STEPS`` ``(key, example)`` pairs for the active config.

        Keys are the integers ``0 .. _MAX_STEPS - 1``.  Each example is one
        row produced by the ``infIterator`` of the selected task family.
        """
        # NOTE(review): ``listMakeTasks`` is the list domain's ``main`` and is
        # called with no arguments by ``infIterator.reset`` -- confirm that it
        # accepts that call and returns task objects.
        task_samples = {
            'text': infIterator(textMakeTasks),
            'list': infIterator(listMakeTasks),
        }
        # random.choice indexes its argument, and a dict keys view is not
        # subscriptable on Python 3, so draw from a list built once up front.
        dataset_types = list(task_samples)
        for key in range(_MAX_STEPS):
            if self.config.name == 'all':
                dataset_type = choice(dataset_types)
            else:
                dataset_type = self.config.name
            yield key, task_samples[dataset_type].step()