# to_delete / dreamcoder / program.py
# Fraser-Greenlee
# add dreamcoder codebase
# e1c1753
# -*- coding: utf-8 -*-
from dreamcoder.type import *
from dreamcoder.utilities import *
from time import time
import math
class InferenceFailure(Exception):
    """Raised by Program.infer when type inference hits a unification failure."""
class ShiftFailure(Exception):
    """Raised by Index.shift when shifting would produce a negative index."""
class RunFailure(Exception):
    """Error type for failures while running a program (not raised in this file)."""
class Program(object):
    """Base class for terms of the lambda-calculus program representation.

    Concrete node types in this file (Application, Index, Abstraction,
    Primitive, Invented, FragmentVariable, Hole) supply show / inferType /
    evaluate / walk / betaReduce / shift.  This class only provides the
    operations that can be written in terms of those, plus the parsers.
    """

    def __repr__(self):
        return str(self)

    def __ne__(self, o):
        return not (self == o)

    def __str__(self):
        return self.show(False)

    def canHaveType(self, t):
        """Return True when the inferred type of this program unifies with `t`."""
        try:
            context, actualType = self.inferType(Context.EMPTY, [], {})
            context, t = t.instantiate(context)
            context.unify(t, actualType)
            return True
        except UnificationFailure:
            return False

    def betaNormalForm(self):
        """Repeatedly beta-reduce until betaReduce() reports nothing left to do."""
        n = self
        while True:
            np = n.betaReduce()
            if np is None:
                return n
            n = np

    def infer(self):
        """Return the canonical inferred type.

        Raises:
            InferenceFailure: when inference runs into a UnificationFailure.
        """
        try:
            return self.inferType(Context.EMPTY, [], {})[1].canonical()
        except UnificationFailure as e:
            raise InferenceFailure(self, e)

    def uncurry(self):
        """Return a program with one leading lambda per argument of its type.

        Existing leading abstractions are kept; the missing ones are added by
        applying the stripped body to fresh indices and re-wrapping.
        """
        t = self.infer()
        a = len(t.functionArguments())
        e = self
        existingAbstractions = 0
        while e.isAbstraction:
            e = e.body
            existingAbstractions += 1
        newAbstractions = a - existingAbstractions
        assert newAbstractions >= 0

        # e is the body stripped of abstractions. we are going to pile
        # some more lambdas at the front, so free variables in e
        # (which were bound to the stripped abstractions) need to be
        # shifted by the number of abstractions that we will be adding
        e = e.shift(newAbstractions)
        for n in reversed(range(newAbstractions)):
            e = Application(e, Index(n))
        for _ in range(a):
            e = Abstraction(e)

        assert self.infer() == e.infer(), \
            "FATAL: uncurry has a bug. %s : %s, but uncurried to %s : %s" % (self, self.infer(),
                                                                             e, e.infer())
        return e

    def wellTyped(self):
        """Return True when infer() succeeds."""
        try:
            self.infer()
            return True
        except InferenceFailure:
            return False

    def runWithArguments(self, xs):
        """Evaluate in the empty environment and apply the result to each of `xs` in turn."""
        f = self.evaluate([])
        for x in xs:
            f = f(x)
        return f

    def applicationParses(self):
        """Yield (function, arguments) decompositions; a non-application has only itself."""
        yield self, []

    def applicationParse(self):
        """Return the (function, arguments) decomposition; trivial for non-applications."""
        return self, []

    @property
    def closed(self):
        """True when the term has no fragment variables and no free indices."""
        for surroundingAbstractions, child in self.walk():
            if isinstance(child, FragmentVariable):
                return False
            if isinstance(child, Index) and child.free(
                    surroundingAbstractions):
                return False
        return True

    @property
    def numberOfFreeVariables(expression):
        """One more than the largest free index (relative to the root), or 0 if none."""
        n = 0
        for surroundingAbstractions, child in expression.walk():
            # Free variable
            if isinstance(child, Index) and child.free(
                    surroundingAbstractions):
                n = max(n, child.i - surroundingAbstractions + 1)
        return n

    def freeVariables(self):
        """Return the set of free de Bruijn indices, expressed relative to the root.

        This used to be a generator, unlike every override in the subclasses,
        which return sets.  Callers in this file combine results with `|` and
        take `len()`, neither of which works on a generator, so nodes that
        inherit this default (Hole, FragmentVariable) broke them.
        """
        return {child.i - surroundingAbstractions
                for surroundingAbstractions, child in self.walk()
                if child.isIndex and child.i >= surroundingAbstractions}

    @property
    def isIndex(self):
        return False

    @property
    def isUnion(self):
        return False

    @property
    def isApplication(self):
        return False

    @property
    def isAbstraction(self):
        return False

    @property
    def isPrimitive(self):
        return False

    @property
    def isInvented(self):
        return False

    @property
    def isHole(self):
        return False

    @staticmethod
    def parse(s):
        """Parse an s-expression string into a Program.

        Recognises `(# body)`, `(lambda body)`, applications, `$N` indices,
        names registered in Primitive.GLOBALS, `??` / `?` and `<HOLE>`.

        Raises:
            ParseFailure: for an atom that is none of the above.
        """
        s = parseSExpression(s)

        def p(e):
            if isinstance(e, list):
                if e[0] == '#':
                    assert len(e) == 2
                    return Invented(p(e[1]))
                if e[0] == 'lambda':
                    assert len(e) == 2
                    return Abstraction(p(e[1]))
                # Anything else is a left-nested application chain.
                f = p(e[0])
                for x in e[1:]:
                    f = Application(f, p(x))
                return f
            assert isinstance(e, str)
            if e[0] == '$':
                return Index(int(e[1:]))
            if e in Primitive.GLOBALS:
                return Primitive.GLOBALS[e]
            if e == '??' or e == '?':
                return FragmentVariable.single
            if e == '<HOLE>':
                return Hole.single
            raise ParseFailure((s, e))

        return p(s)

    @staticmethod
    def _parse(s, n):
        """Try each node parser in turn at offset `n`; return (program, next offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        # Order matters: the first parser that succeeds wins.
        for p in [
                Application,
                Abstraction,
                Index,
                Invented,
                FragmentVariable,
                Hole,
                Primitive]:
            try:
                return p._parse(s, n)
            except ParseFailure:
                continue
        raise ParseFailure(s)

    # parser helpers

    @staticmethod
    def parseConstant(s, n, *constants):
        """Match the first of `constants` found at offset `n` of `s`.

        Returns:
            The offset just past the matched constant.

        Raises:
            ParseFailure: when none of the constants occurs at `n`.
        """
        for constant in constants:
            if all(i + n < len(s) and s[i + n] == c
                   for i, c in enumerate(constant)):
                return n + len(constant)
        raise ParseFailure(s)

    @staticmethod
    def parseHumanReadable(s):
        """Parse an s-expression with named lambda parameters.

        `(lambda (x y) body)` (or `\\` instead of `lambda`) binds names that
        are translated to de Bruijn indices; other atoms must be in
        Primitive.GLOBALS.
        """
        s = parseSExpression(s)

        def p(s, environment):
            if isinstance(s, list) and s[0] in ['lambda', '\\']:
                assert isinstance(s[1], list) and len(s) == 3
                # Innermost binder first, so position in the list is the index.
                newEnvironment = list(reversed(s[1])) + environment
                e = p(s[2], newEnvironment)
                for _ in s[1]:
                    e = Abstraction(e)
                return e
            if isinstance(s, list):
                a = p(s[0], environment)
                for x in s[1:]:
                    a = Application(a, p(x, environment))
                return a
            for j, v in enumerate(environment):
                if s == v:
                    return Index(j)
            if s in Primitive.GLOBALS:
                return Primitive.GLOBALS[s]
            assert False, f"could not parse {s}"

        return p(s, [])
class Application(Program):
    '''Function application: the program `f` applied to the argument `x`.'''

    def __init__(self, f, x):
        self.f = f
        self.x = x
        self.hashCode = None
        self._cacheConditional()

    def _cacheConditional(self):
        """Detect the shape ((if c) t) e and cache its three parts on self."""
        f = self.f
        x = self.x
        self.isConditional = (not isinstance(f, int)) and \
            f.isApplication and \
            f.f.isApplication and \
            f.f.f.isPrimitive and \
            f.f.f.name == "if"
        if self.isConditional:
            self.falseBranch = x
            self.trueBranch = f.x
            self.branch = f.f.x
        else:
            self.falseBranch = None
            self.trueBranch = None
            self.branch = None

    def betaReduce(self):
        """Perform one reduction step, or return None if nothing can be reduced."""
        # Reduce inside the function first, then inside the argument.
        reducedF = self.f.betaReduce()
        if reducedF is not None:
            return Application(reducedF, self.x)
        reducedX = self.x.betaReduce()
        if reducedX is not None:
            return Application(self.f, reducedX)

        # Both sides are irreducible; only a lambda applied to something is a redex.
        if not self.f.isAbstraction:
            return None

        # Substitute the argument for index 0 of the body.
        body = self.f.body
        argument = self.x
        return body.substitute(Index(0), argument.shift(1)).shift(-1)

    def isBetaLong(self):
        return (not self.f.isAbstraction) and self.f.isBetaLong() and self.x.isBetaLong()

    def freeVariables(self):
        return self.f.freeVariables() | self.x.freeVariables()

    def clone(self):
        return Application(self.f.clone(), self.x.clone())

    def annotateTypes(self, context, environment):
        """Annotate both children, then store this node's type in annotatedType."""
        self.f.annotateTypes(context, environment)
        self.x.annotateTypes(context, environment)
        r = context.makeVariable()
        context.unify(arrow(self.x.annotatedType, r), self.f.annotatedType)
        self.annotatedType = r.applyMutable(context)

    @property
    def isApplication(self):
        return True

    def __eq__(self, other):
        return isinstance(other, Application) and \
            self.f == other.f and self.x == other.x

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((hash(self.f), hash(self.x)))
        return self.hashCode

    # Because Python3 randomizes the hash function, we need to never pickle the hash

    def __getstate__(self):
        return self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch

    def __setstate__(self, state):
        try:
            (self.f, self.x, self.isConditional,
             self.falseBranch, self.trueBranch, self.branch) = state
        except ValueError:
            # backward compatibility: older pickles stored a mapping with
            # 'f' and 'x'; recompute the derived conditional fields.
            assert 'x' in state
            assert 'f' in state
            self.f = state['f']
            self.x = state['x']
            self._cacheConditional()
        self.hashCode = None

    def visit(self, visitor, *arguments, **keywords):
        return visitor.application(self, *arguments, **keywords)

    def show(self, isFunction):
        """Render as `f x`; parenthesised unless printed in function position."""
        rendered = "%s %s" % (self.f.show(True), self.x.show(False))
        if isFunction:
            return rendered
        return "(%s)" % rendered

    def evaluate(self, environment):
        """Evaluate; a conditional only evaluates the branch that is taken."""
        if not self.isConditional:
            return self.f.evaluate(environment)(self.x.evaluate(environment))
        if self.branch.evaluate(environment):
            return self.trueBranch.evaluate(environment)
        return self.falseBranch.evaluate(environment)

    def inferType(self, context, environment, freeVariables):
        (context, ft) = self.f.inferType(context, environment, freeVariables)
        (context, xt) = self.x.inferType(context, environment, freeVariables)
        (context, returnType) = context.makeVariable()
        context = context.unify(ft, arrow(xt, returnType))
        return (context, returnType.apply(context))

    def applicationParses(self):
        """Yield every split of the application spine into (head, arguments)."""
        yield self, []
        for head, arguments in self.f.applicationParses():
            yield head, arguments + [self.x]

    def applicationParse(self):
        """Return the (head, arguments) split of the whole application spine."""
        head, arguments = self.f.applicationParse()
        return head, arguments + [self.x]

    def shift(self, offset, depth=0):
        return Application(self.f.shift(offset, depth),
                           self.x.shift(offset, depth))

    def substitute(self, old, new):
        if self == old:
            return new
        return Application(self.f.substitute(old, new),
                           self.x.substitute(old, new))

    def walkUncurried(self, d=0):
        yield d, self
        head, arguments = self.applicationParse()
        yield from head.walkUncurried(d)
        for argument in arguments:
            yield from argument.walkUncurried(d)

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self
        yield from self.f.walk(surroundingAbstractions)
        yield from self.x.walk(surroundingAbstractions)

    def size(self):
        return self.f.size() + self.x.size()

    @staticmethod
    def _parse(s, n):
        """Parse `( e1 e2 ... )` at offset `n`; return (program, next offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        if n == len(s) or s[n] != '(':
            raise ParseFailure(s)
        n += 1

        xs = []
        while True:
            x, n = Program._parse(s, n)
            xs.append(x)
            while n < len(s) and s[n].isspace():
                n += 1
            if n == len(s):
                raise ParseFailure(s)
            if s[n] == ")":
                n += 1
                break

        # Fold the parsed terms into a left-nested application.
        e = xs[0]
        for x in xs[1:]:
            e = Application(e, x)
        return e, n
class Index(Program):
    '''
    deBruijn index: https://en.wikipedia.org/wiki/De_Bruijn_index
    These indices encode variables.
    '''

    def __init__(self, i):
        self.i = i

    def show(self, isFunction):
        return "$%d" % self.i

    def __eq__(self, o):
        return isinstance(o, Index) and o.i == self.i

    def __hash__(self):
        return self.i

    def visit(self, visitor, *arguments, **keywords):
        return visitor.index(self, *arguments, **keywords)

    def evaluate(self, environment):
        return environment[self.i]

    def inferType(self, context, environment, freeVariables):
        """Look the type up in `environment`, or in `freeVariables` if unbound.

        A free index seen for the first time gets a fresh type variable, which
        is recorded in `freeVariables`.
        """
        if self.bound(len(environment)):
            return (context, environment[self.i].apply(context))

        i = self.i - len(environment)
        if i in freeVariables:
            return (context, freeVariables[i].apply(context))
        context, variable = context.makeVariable()
        freeVariables[i] = variable
        return (context, variable)

    def clone(self):
        return Index(self.i)

    def annotateTypes(self, context, environment):
        self.annotatedType = environment[self.i].applyMutable(context)

    def shift(self, offset, depth=0):
        """Add `offset` to the index if it is free at `depth`.

        Raises:
            ShiftFailure: when the shifted index would be negative.
        """
        if self.bound(depth):
            # bound variable: untouched
            return self
        # free variable
        shifted = self.i + offset
        if shifted < 0:
            raise ShiftFailure()
        return Index(shifted)

    def betaReduce(self):
        return None

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return {self.i}

    def substitute(self, old, new):
        return new if old == self else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    def free(self, surroundingAbstractions):
        '''Is this index a free variable, given that it has surroundingAbstractions lambda's around it?'''
        return self.i >= surroundingAbstractions

    def bound(self, surroundingAbstractions):
        '''Is this index a bound variable, given that it has surroundingAbstractions lambda's around it?'''
        return self.i < surroundingAbstractions

    @property
    def isIndex(self):
        return True

    @staticmethod
    def _parse(s, n):
        """Parse `$<digits>` at offset `n`; return (Index, next offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        if n == len(s) or s[n] != '$':
            raise ParseFailure(s)
        n += 1

        start = n
        while n < len(s) and s[n].isdigit():
            n += 1
        if n == start:
            # '$' with no digits after it
            raise ParseFailure(s)
        return Index(int(s[start:n])), n
class Abstraction(Program):
    '''Lambda abstraction. Creates a new function.'''

    def __init__(self, body):
        self.body = body
        self.hashCode = None

    @property
    def isAbstraction(self):
        return True

    def __eq__(self, o):
        return isinstance(o, Abstraction) and o.body == self.body

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((hash(self.body),))
        return self.hashCode

    # Because Python3 randomizes the hash function, we need to never pickle the hash

    def __getstate__(self):
        return self.body

    def __setstate__(self, state):
        self.body = state
        self.hashCode = None

    def isBetaLong(self):
        return self.body.isBetaLong()

    def freeVariables(self):
        # Index 0 of the body is bound here; everything else moves down by one.
        return {f - 1 for f in self.body.freeVariables() if f > 0}

    def visit(self, visitor, *arguments, **keywords):
        return visitor.abstraction(self, *arguments, **keywords)

    def clone(self):
        return Abstraction(self.body.clone())

    def annotateTypes(self, context, environment):
        v = context.makeVariable()
        self.body.annotateTypes(context, [v] + environment)
        self.annotatedType = arrow(v.applyMutable(context), self.body.annotatedType)

    def show(self, isFunction):
        return "(lambda %s)" % (self.body.show(False))

    def evaluate(self, environment):
        """Return a Python closure that evaluates the body with the argument bound."""
        return lambda x: self.body.evaluate([x] + environment)

    def betaReduce(self):
        reducedBody = self.body.betaReduce()
        if reducedBody is None:
            return None
        return Abstraction(reducedBody)

    def inferType(self, context, environment, freeVariables):
        (context, argumentType) = context.makeVariable()
        (context, returnType) = self.body.inferType(
            context, [argumentType] + environment, freeVariables)
        return (context, arrow(argumentType, returnType).apply(context))

    def shift(self, offset, depth=0):
        return Abstraction(self.body.shift(offset, depth + 1))

    def substitute(self, old, new):
        if self == old:
            return new
        # Going under a binder: both the pattern and the replacement shift by one.
        return Abstraction(self.body.substitute(old.shift(1), new.shift(1)))

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self
        yield from self.body.walk(surroundingAbstractions + 1)

    def walkUncurried(self, d=0):
        yield d, self
        yield from self.body.walkUncurried(d + 1)

    def size(self):
        return self.body.size()

    @staticmethod
    def _parse(s, n):
        """Parse `(\\ body)`, `(lambda body)` or `(\u03bb body)`; return (program, next offset)."""
        n = Program.parseConstant(s, n, '(\\', '(lambda', '(\u03bb')

        while n < len(s) and s[n].isspace():
            n += 1
        body, n = Program._parse(s, n)

        while n < len(s) and s[n].isspace():
            n += 1
        n = Program.parseConstant(s, n, ')')
        return Abstraction(body), n
class Primitive(Program):
    """A named constant with a type (`tp`) and a Python `value`.

    The first primitive constructed under a given name is recorded in
    Primitive.GLOBALS, which the parsers consult.
    """

    GLOBALS = {}

    def __init__(self, name, ty, value):
        self.tp = ty
        self.name = name
        self.value = value
        # First registration wins; later primitives with the same name are not stored.
        if name not in Primitive.GLOBALS:
            Primitive.GLOBALS[name] = self

    @property
    def isPrimitive(self):
        return True

    def __eq__(self, o):
        return isinstance(o, Primitive) and o.name == self.name

    def __hash__(self):
        return hash(self.name)

    def visit(self, visitor, *arguments, **keywords):
        return visitor.primitive(self, *arguments, **keywords)

    def show(self, isFunction):
        return self.name

    def clone(self):
        return Primitive(self.name, self.tp, self.value)

    def annotateTypes(self, context, environment):
        self.annotatedType = self.tp.instantiateMutable(context)

    def evaluate(self, environment):
        return self.value

    def betaReduce(self):
        return None

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return set()

    def inferType(self, context, environment, freeVariables):
        return self.tp.instantiate(context)

    def shift(self, offset, depth=0):
        return self

    def substitute(self, old, new):
        return new if self == old else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s, n):
        """Read a name up to whitespace or a parenthesis and look it up in GLOBALS."""
        while n < len(s) and s[n].isspace():
            n += 1
        start = n
        while n < len(s) and not s[n].isspace() and s[n] not in '()':
            n += 1
        name = s[start:n]
        if name in Primitive.GLOBALS:
            return Primitive.GLOBALS[name], n
        raise ParseFailure(s)

    # TODO(@mtensor): needs to be fixed to handle both pickling lambda functions and unpickling in general.
    # def __getstate__(self):
    #     return self.name

    # def __setstate__(self, state):
    #     #for backwards compatibility:
    #     if type(state) == dict:
    #         self.__dict__ = state
    #     else:
    #         p = Primitive.GLOBALS[state]
    #         self.__init__(p.name, p.tp, p.value)
class Invented(Program):
    '''New invented primitives'''

    def __init__(self, body):
        self.body = body
        # The type is inferred once, at construction time.
        self.tp = self.body.infer()
        self.hashCode = None

    @property
    def isInvented(self):
        return True

    def show(self, isFunction):
        return "#%s" % (self.body.show(False))

    def visit(self, visitor, *arguments, **keywords):
        return visitor.invented(self, *arguments, **keywords)

    def __eq__(self, o):
        return isinstance(o, Invented) and o.body == self.body

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((0, hash(self.body)))
        return self.hashCode

    # Because Python3 randomizes the hash function, we need to never pickle the hash

    def __getstate__(self):
        return self.body, self.tp

    def __setstate__(self, state):
        self.body, self.tp = state
        self.hashCode = None

    def clone(self):
        return Invented(self.body)

    def annotateTypes(self, context, environment):
        self.annotatedType = self.tp.instantiateMutable(context)

    def evaluate(self, e):
        # The body is always evaluated in the empty environment; `e` is ignored.
        return self.body.evaluate([])

    def betaReduce(self):
        return self.body

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return set()

    def inferType(self, context, environment, freeVariables):
        return self.tp.instantiate(context)

    def shift(self, offset, depth=0):
        return self

    def substitute(self, old, new):
        return new if self == old else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s, n):
        """Parse `#<program>` at offset `n`; return (Invented, next offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        if not (n < len(s) and s[n] == '#'):
            raise ParseFailure(s)
        n += 1
        body, n = Program._parse(s, n)
        return Invented(body), n
class FragmentVariable(Program):
    """Placeholder node printed as `??`; all instances compare equal.

    It cannot be evaluated, beta-reduced or shifted.
    """

    def __init__(self):
        pass

    def show(self, isFunction):
        return "??"

    def __eq__(self, o):
        return isinstance(o, FragmentVariable)

    def __hash__(self):
        return 42

    def visit(self, visitor, *arguments, **keywords):
        return visitor.fragmentVariable(self, *arguments, **keywords)

    def evaluate(self, e):
        raise Exception('Attempt to evaluate fragment variable')

    def betaReduce(self):
        raise Exception('Attempt to beta reduce fragment variable')

    def inferType(self, context, environment, freeVariables):
        return context.makeVariable()

    def shift(self, offset, depth=0):
        raise Exception('Attempt to shift fragment variable')

    def substitute(self, old, new):
        return new if self == old else self

    def match(self, context, expression, holes, variableBindings, environment=[]):
        """Match anything: record `expression` as a hole and return a fresh type variable.

        The expression is shifted out of the surrounding abstractions first.

        Raises:
            MatchFailure: when that shift raises ShiftFailure.
        """
        surroundingAbstractions = len(environment)
        try:
            context, variable = context.makeVariable()
            shifted = expression.shift(-surroundingAbstractions)
            holes.append((variable, shifted))
            return context, variable
        except ShiftFailure:
            raise MatchFailure()

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s, n):
        """Parse `??` or `?` at offset `n`; return (the shared instance, next offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        n = Program.parseConstant(s, n, '??', '?')
        return FragmentVariable.single, n


# Shared instance returned by the parsers.
FragmentVariable.single = FragmentVariable()
class Hole(Program):
    """Placeholder node printed as `<HOLE>`; all instances compare equal.

    It cannot be evaluated, beta-reduced or shifted.
    """

    def __init__(self):
        pass

    def show(self, isFunction):
        return "<HOLE>"

    @property
    def isHole(self):
        return True

    def __eq__(self, o):
        return isinstance(o, Hole)

    def __hash__(self):
        return 42

    def evaluate(self, e):
        raise Exception('Attempt to evaluate hole')

    def betaReduce(self):
        raise Exception('Attempt to beta reduce hole')

    def inferType(self, context, environment, freeVariables):
        # A hole can have any type: hand back a fresh type variable.
        return context.makeVariable()

    def shift(self, offset, depth=0):
        # The message used to say "fragment variable", copied from FragmentVariable.shift.
        raise Exception('Attempt to shift hole')

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s, n):
        """Parse `<HOLE>` at offset `n`; return (the shared instance, next offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        n = Program.parseConstant(s, n, '<HOLE>')
        return Hole.single, n


# Shared instance returned by the parsers.
Hole.single = Hole()
class ShareVisitor(object):
    """Visitor that rebuilds a program so that equal subterms are shared objects.

    Leaves are keyed by name / index value; compound nodes are keyed by the
    id() of their already-shared children.
    """

    def __init__(self):
        self.primitiveTable = {}
        self.inventedTable = {}
        self.indexTable = {}
        self.applicationTable = {}
        self.abstractionTable = {}

    def invented(self, e):
        sharedBody = e.body.visit(self)
        key = id(sharedBody)
        if key not in self.inventedTable:
            self.inventedTable[key] = Invented(sharedBody)
        return self.inventedTable[key]

    def primitive(self, e):
        # The first primitive seen under a name becomes the representative.
        if e.name not in self.primitiveTable:
            self.primitiveTable[e.name] = e
        return self.primitiveTable[e.name]

    def index(self, e):
        # The first index seen with a given value becomes the representative.
        if e.i not in self.indexTable:
            self.indexTable[e.i] = e
        return self.indexTable[e.i]

    def application(self, e):
        sharedF = e.f.visit(self)
        sharedX = e.x.visit(self)
        key = (id(sharedF), id(sharedX))
        if key not in self.applicationTable:
            self.applicationTable[key] = Application(sharedF, sharedX)
        return self.applicationTable[key]

    def abstraction(self, e):
        sharedBody = e.body.visit(self)
        key = id(sharedBody)
        if key not in self.abstractionTable:
            self.abstractionTable[key] = Abstraction(sharedBody)
        return self.abstractionTable[key]

    def execute(self, e):
        """Run the visitor over `e` and return the shared program."""
        return e.visit(self)
class Mutator:
    """Perform local mutations to an expr, yielding the expr and the
    description length distance from the original program"""

    def __init__(self, grammar, fn):
        """Fn yields (expression, loglikelihood) from a type and loss.
        Therefore, loss+loglikelihood is the distance from the original program."""
        self.fn = fn
        self.grammar = grammar
        # Stack of functions that rebuild the enclosing program around a subterm.
        self.history = []

    def enclose(self, expr):
        """Wrap `expr` back into its surrounding program, innermost wrapper first."""
        for h in self.history[::-1]:
            expr = h(expr)
        return expr

    def _replacements(self, e, tp, env, is_lhs):
        """Yield (whole program, distance) for every replacement of the subterm `e`."""
        deleted_ll = self.logLikelihood(tp, e, env)
        for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
            yield self.enclose(expr), deleted_ll + replaced_ll

    def invented(self, e, tp, env, is_lhs=False):
        # This used to pass the undefined name `deleted` to self.fn (NameError).
        yield from self._replacements(e, tp, env, is_lhs)

    def primitive(self, e, tp, env, is_lhs=False):
        yield from self._replacements(e, tp, env, is_lhs)

    def index(self, e, tp, env, is_lhs=False):
        yield from self._replacements(e, tp, env, is_lhs)

    def application(self, e, tp, env, is_lhs=False):
        """Mutate inside the function, then inside the argument, then the node itself."""
        self.history.append(lambda expr: Application(expr, e.x))
        f_tp = arrow(e.x.infer(), tp)
        yield from e.f.visit(self, f_tp, env, is_lhs=True)
        # Reuse the same history slot for the argument position.
        self.history[-1] = lambda expr: Application(e.f, expr)
        x_tp = inferArg(tp, e.f.infer())
        yield from e.x.visit(self, x_tp, env)
        self.history.pop()
        yield from self._replacements(e, tp, env, is_lhs)

    def abstraction(self, e, tp, env, is_lhs=False):
        """Mutate inside the body (with the argument type pushed on env), then the node itself."""
        self.history.append(lambda expr: Abstraction(expr))
        yield from e.body.visit(self, tp.arguments[1], [tp.arguments[0]] + env)
        self.history.pop()
        yield from self._replacements(e, tp, env, is_lhs)

    def execute(self, e, tp):
        """Yield every (mutated program, distance) pair for `e` at request type `tp`."""
        yield from e.visit(self, tp, [])

    def logLikelihood(self, tp, e, env):
        """Return the grammar's log likelihood of `e` at type `tp`.

        If the likelihood summary fails with an AssertionError, `e` is
        eta-expanded to the arity of `tp` and scored again;
        NEGATIVEINFINITY is returned when no expansion is possible.
        """
        summary = None
        try:
            _, summary = self.grammar.likelihoodSummary(Context.EMPTY, env,
                                                        tp, e, silent=True)
        except AssertionError:
            # Best effort: fall through to the eta-expansion retry below.
            pass
        if summary is not None:
            return summary.logLikelihood(self.grammar)

        tmpE, depth = e, 0
        while isinstance(tmpE, Abstraction):
            depth += 1
            tmpE = tmpE.body
        to_introduce = len(tp.functionArguments()) - depth
        # `<= 0`, not `== 0`: with a negative count the loops below leave `e`
        # unchanged, so the retry would recurse forever on identical arguments.
        if to_introduce <= 0:
            return NEGATIVEINFINITY
        for i in reversed(range(to_introduce)):
            e = Application(e, Index(i))
        for _ in range(to_introduce):
            e = Abstraction(e)
        return self.logLikelihood(tp, e, env)
class RegisterPrimitives(object):
    """Visitor that makes sure every primitive in a program is in Primitive.GLOBALS."""

    def invented(self, e):
        e.body.visit(self)

    def primitive(self, e):
        # Constructing a Primitive registers it as a side effect of __init__.
        if e.name not in Primitive.GLOBALS:
            Primitive(e.name, e.tp, e.value)

    def index(self, e):
        pass

    def application(self, e):
        e.f.visit(self)
        e.x.visit(self)

    def abstraction(self, e):
        e.body.visit(self)

    @staticmethod
    def register(e):
        """Walk `e` with a fresh RegisterPrimitives visitor."""
        visitor = RegisterPrimitives()
        e.visit(visitor)
class PrettyVisitor(object):
    """Visitor that renders a program with named variables instead of indices.

    With Lisp=False abstractions print as `λx y.body`; with Lisp=True as
    `(λ (x y) body)`.
    """

    def __init__(self, Lisp=False):
        self.Lisp = Lisp
        self.numberOfVariables = 0
        # Names already handed out to free indices, keyed by index.
        self.freeVariables = {}
        self.variableNames = ["x", "y", "z", "u", "v", "w"]
        self.variableNames += [chr(ord('a') + j)
                               for j in range(20)]
        self.toplevel = True

    def makeVariable(self):
        """Return a fresh variable name.

        The first len(variableNames) names come straight from the list.  Past
        that the list is cycled with a numeric suffix (x1, y1, ...) instead of
        raising IndexError as it used to.
        """
        n = self.numberOfVariables
        self.numberOfVariables += 1
        count = len(self.variableNames)
        if n < count:
            return self.variableNames[n]
        return "%s%d" % (self.variableNames[n % count], n // count)

    def invented(self, e, environment, isFunction, isAbstraction):
        # An invented primitive prints as its body, in an empty environment.
        s = e.body.visit(self, [], isFunction, isAbstraction)
        return s

    def primitive(self, e, environment, isVariable, isAbstraction):
        return e.name

    def index(self, e, environment, isVariable, isAbstraction):
        if e.i < len(environment):
            return environment[e.i]
        # Free index: reuse its name if it was seen before, else invent one.
        i = e.i - len(environment)
        if i in self.freeVariables:
            return self.freeVariables[i]
        v = self.makeVariable()
        self.freeVariables[i] = v
        return v

    def application(self, e, environment, isFunction, isAbstraction):
        self.toplevel = False
        s = "%s %s" % (e.f.visit(self, environment, True, False),
                       e.x.visit(self, environment, False, False))
        if isFunction:
            return s
        return "(" + s + ")"

    def abstraction(self, e, environment, isFunction, isAbstraction):
        toplevel = self.toplevel
        self.toplevel = False
        if not self.Lisp:
            # Invent a new variable
            v = self.makeVariable()
            body = e.body.visit(self,
                                [v] + environment,
                                False,
                                True)
            if not e.body.isAbstraction:
                body = "." + body
            body = v + body
            # Only the outermost lambda of a run prints the λ.
            if not isAbstraction:
                body = "λ" + body
            if not toplevel:
                body = "(%s)" % body
            return body

        # Lisp style: collect the whole run of nested lambdas into one binder list.
        child = e
        newVariables = []
        while child.isAbstraction:
            newVariables = [self.makeVariable()] + newVariables
            child = child.body
        body = child.visit(self, newVariables + environment,
                           False, True)
        body = "(λ (%s) %s)" % (" ".join(reversed(newVariables)), body)
        return body
def prettyProgram(e, Lisp=False):
    """Render `e` with a fresh PrettyVisitor, starting from an empty environment."""
    printer = PrettyVisitor(Lisp=Lisp)
    return e.visit(printer, [], False, False)
class EtaExpandFailure(Exception):
    """Raised by EtaLongVisitor when a term cannot be put into eta-long form."""
class EtaLongVisitor(object):
    """Converts an expression into eta-longform"""

    def __init__(self, request=None):
        self.request = request
        # Only set while execute() is running.
        self.context = None

    def makeLong(self, e, request):
        """Eta-expand `e` once if `request` is an arrow type; otherwise return None."""
        if not request.isArrow():
            return None
        # eta expansion
        return Abstraction(Application(e.shift(1), Index(0)))

    def abstraction(self, e, request, environment):
        if not request.isArrow():
            raise EtaExpandFailure()
        argumentType = request.arguments[0]
        returnType = request.arguments[1]
        return Abstraction(e.body.visit(self,
                                        returnType,
                                        [argumentType] + environment))

    def _application(self, e, request, environment):
        """Shared handler for applications, indices, primitives and invented nodes."""
        expanded = self.makeLong(e, request)
        if expanded is not None:
            return expanded.visit(self, request, environment)

        f, xs = e.applicationParse()

        if f.isIndex:
            ft = environment[f.i].applyMutable(self.context)
        elif f.isInvented or f.isPrimitive:
            ft = f.tp.instantiateMutable(self.context)
        else:
            assert False, "Not in beta long form: %s"%e

        self.context.unify(request, ft.returns())
        ft = ft.applyMutable(self.context)

        xt = ft.functionArguments()
        # The head must be applied to exactly as many arguments as its type takes.
        if len(xs) != len(xt):
            raise EtaExpandFailure()

        returnValue = f
        for x, t in zip(xs, xt):
            t = t.applyMutable(self.context)
            returnValue = Application(returnValue,
                                      x.visit(self, t, environment))
        return returnValue

    # This procedure works by recapitulating the generative process
    # applications indices and primitives are all generated identically

    def application(self, e, request, environment):
        return self._application(e, request, environment)

    def index(self, e, request, environment):
        return self._application(e, request, environment)

    def primitive(self, e, request, environment):
        return self._application(e, request, environment)

    def invented(self, e, request, environment):
        return self._application(e, request, environment)

    def execute(self, e):
        """Return the eta-long form of the closed program `e`.

        When no request type was given to the constructor, a warning is
        printed and the inferred type of `e` is used.
        """
        assert len(e.freeVariables()) == 0

        if self.request is None:
            eprint("WARNING: request not specified for etaexpansion")
            self.request = e.infer()
        self.context = MutableContext()
        el = e.visit(self, self.request, [])
        self.context = None
        # assert el.infer().canonical() == e.infer().canonical(), \
        #     f"Types are not preserved by ETA expansion: {e} : {e.infer().canonical()} vs {el} : {el.infer().canonical()}"
        return el
class StripPrimitiveVisitor():
    """Replaces all primitives .value's w/ None. Does not destructively modify anything"""

    def invented(self, e):
        strippedBody = e.body.visit(self)
        return Invented(strippedBody)

    def primitive(self, e):
        return Primitive(e.name, e.tp, None)

    def application(self, e):
        strippedF = e.f.visit(self)
        strippedX = e.x.visit(self)
        return Application(strippedF, strippedX)

    def abstraction(self, e):
        strippedBody = e.body.visit(self)
        return Abstraction(strippedBody)

    def index(self, e):
        # Indices carry no value; returned as is.
        return e
class ReplacePrimitiveValueVisitor():
    """Intended to be used after StripPrimitiveVisitor.
    Replaces all primitive.value's with their corresponding entry in Primitive.GLOBALS"""

    def invented(self, e):
        restoredBody = e.body.visit(self)
        return Invented(restoredBody)

    def primitive(self, e):
        registered = Primitive.GLOBALS[e.name]
        return Primitive(e.name, e.tp, registered.value)

    def application(self, e):
        restoredF = e.f.visit(self)
        restoredX = e.x.visit(self)
        return Application(restoredF, restoredX)

    def abstraction(self, e):
        restoredBody = e.body.visit(self)
        return Abstraction(restoredBody)

    def index(self, e):
        # Indices carry no value; returned as is.
        return e
def strip_primitive_values(e):
    """Return a copy of `e` in which every primitive's value is None."""
    stripper = StripPrimitiveVisitor()
    return e.visit(stripper)
def unstrip_primitive_values(e):
    """Return a copy of `e` with primitive values restored from Primitive.GLOBALS."""
    restorer = ReplacePrimitiveValueVisitor()
    return e.visit(restorer)
# from luke
class TokeniseVisitor(object):
    """Visitor that flattens a program into a list of tokens.

    Applications are bracketed by "(" / ")" and abstractions by
    "(_lambda" / ")_lambda".
    """

    def invented(self, e):
        # NOTE(review): this emits the body object itself, not a string token.
        return [e.body]

    def primitive(self, e):
        return [e.name]

    def index(self, e):
        return ["$" + str(e.i)]

    def application(self, e):
        functionTokens = e.f.visit(self)
        argumentTokens = e.x.visit(self)
        return ["("] + functionTokens + argumentTokens + [")"]

    def abstraction(self, e):
        bodyTokens = e.body.visit(self)
        return ["(_lambda"] + bodyTokens + [")_lambda"]
def tokeniseProgram(e):
    """Return the token list produced by walking `e` with a TokeniseVisitor."""
    tokeniser = TokeniseVisitor()
    return e.visit(tokeniser)
def untokeniseProgram(l):
    """Rebuild a Program from a token list by joining the tokens and parsing them.

    The lambda bracket tokens are mapped back to "(lambda" and ")" first.
    """
    lookup = {"(_lambda": "(lambda", ")_lambda": ")"}
    translated = [lookup.get(token, token) for token in l]
    return Program.parse(" ".join(translated))
if __name__ == "__main__":
    # Smoke test: load the arithmetic primitives, then parse and print a program.
    from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
    source = "(#(lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) - * (+ +))"
    e = Program.parse(source)
    eprint(e)