|
|
|
|
|
from dreamcoder.type import * |
|
from dreamcoder.utilities import * |
|
|
|
from time import time |
|
import math |
|
|
|
|
|
class InferenceFailure(Exception):
    """Raised by ``Program.infer`` when type inference hits a unification failure."""
|
|
|
|
|
class ShiftFailure(Exception):
    """Raised by ``Index.shift`` when a shift would yield a negative de Bruijn index."""
|
|
|
class RunFailure(Exception):
    """Exception type for failures while running a program.

    NOTE(review): not raised anywhere in this module; presumably used by
    callers that execute programs.
    """
|
|
|
|
|
class Program(object):
    """Base class for program ASTs (lambda calculus with de Bruijn indices).

    The node types in this module (``Application``, ``Index``, ``Abstraction``,
    ``Primitive``, ``Invented``, ``FragmentVariable``, ``Hole``) subclass this
    and provide ``show``, ``walk``, ``inferType``, ``evaluate``, ``shift``,
    ``betaReduce`` and friends; the helpers below are written in terms of them.
    """

    def __repr__(self): return str(self)

    def __ne__(self, o): return not (self == o)

    def __str__(self): return self.show(False)

    def canHaveType(self, t):
        """Return True if the inferred type of this program unifies with ``t``."""
        try:
            context, actualType = self.inferType(Context.EMPTY, [], {})
            context, t = t.instantiate(context)
            context.unify(t, actualType)
            return True
        except UnificationFailure:
            return False

    def betaNormalForm(self):
        """Apply ``betaReduce`` steps until none applies and return the result.

        Does not terminate for programs that have no beta normal form.
        """
        n = self
        while True:
            np = n.betaReduce()
            if np is None: return n
            n = np

    def infer(self):
        """Return the canonical inferred type of this program.

        Free indices are given fresh type variables (see ``Index.inferType``).

        Raises:
            InferenceFailure: carrying the program and the UnificationFailure.
        """
        try:
            return self.inferType(Context.EMPTY, [], {})[1].canonical()
        except UnificationFailure as e:
            raise InferenceFailure(self, e) from e

    def uncurry(self):
        """Eta-expand so there is one outer lambda per argument of the type.

        The body under the existing lambdas is shifted past the missing
        binders and applied to them, then every argument is re-bound.
        """
        t = self.infer()
        a = len(t.functionArguments())
        e = self
        existingAbstractions = 0
        while e.isAbstraction:
            e = e.body
            existingAbstractions += 1
        newAbstractions = a - existingAbstractions
        assert newAbstractions >= 0

        # The new binders are the innermost ones, so shift the body past them
        # and apply it to them in order ($(k-1) ... $0).
        e = e.shift(newAbstractions)
        for n in reversed(range(newAbstractions)):
            e = Application(e, Index(n))
        for _ in range(a):
            e = Abstraction(e)

        assert self.infer() == e.infer(), \
            "FATAL: uncurry has a bug. %s : %s, but uncurried to %s : %s" % (self, self.infer(),
                                                                             e, e.infer())
        return e

    def wellTyped(self):
        """Return True if ``infer`` succeeds on this program."""
        try:
            self.infer()
            return True
        except InferenceFailure:
            return False

    def runWithArguments(self, xs):
        """Evaluate the closed program and apply the result to each of ``xs`` in turn."""
        f = self.evaluate([])
        for x in xs:
            f = f(x)
        return f

    def applicationParses(self): yield self, []

    def applicationParse(self): return self, []

    @property
    def closed(self):
        """True if the program has no fragment variables and no free indices."""
        for surroundingAbstractions, child in self.walk():
            if isinstance(child, FragmentVariable):
                return False
            if isinstance(child, Index) and child.free(
                    surroundingAbstractions):
                return False
        return True

    @property
    def numberOfFreeVariables(expression):
        """One more than the largest free index (relative to the root); 0 if none."""
        n = 0
        for surroundingAbstractions, child in expression.walk():
            if isinstance(child, Index) and child.free(
                    surroundingAbstractions):
                n = max(n, child.i - surroundingAbstractions + 1)
        return n

    def freeVariables(self):
        """Return the set of free de Bruijn indices, relative to this node.

        Most subclasses override this; this generic fallback walks the tree.
        It returns a set, like every override, so that results can be combined
        with ``|`` (as ``Application.freeVariables`` does).
        """
        return {child.i - surroundingAbstractions
                for surroundingAbstractions, child in self.walk()
                if child.isIndex and child.i >= surroundingAbstractions}

    @property
    def isIndex(self): return False

    @property
    def isUnion(self): return False

    @property
    def isApplication(self): return False

    @property
    def isAbstraction(self): return False

    @property
    def isPrimitive(self): return False

    @property
    def isInvented(self): return False

    @property
    def isHole(self): return False

    @staticmethod
    def parse(s):
        """Parse the s-expression syntax produced by ``show``.

        Understands ``$i`` indices, ``(# body)`` inventions, ``(lambda body)``
        abstractions, applications, names registered in ``Primitive.GLOBALS``,
        ``??``/``?`` fragment variables and ``<HOLE>``.

        Raises:
            ParseFailure: on an unknown atom or an empty list.
        """
        s = parseSExpression(s)

        def p(e):
            if isinstance(e,list):
                if not e:
                    # "()" has no head; report it as a parse failure rather
                    # than letting e[0] raise IndexError.
                    raise ParseFailure((s,e))
                if e[0] == '#':
                    assert len(e) == 2
                    return Invented(p(e[1]))
                if e[0] == 'lambda':
                    assert len(e) == 2
                    return Abstraction(p(e[1]))
                f = p(e[0])
                for x in e[1:]:
                    f = Application(f,p(x))
                return f
            assert isinstance(e,str)
            if e[0] == '$': return Index(int(e[1:]))
            if e in Primitive.GLOBALS: return Primitive.GLOBALS[e]
            if e == '??' or e == '?': return FragmentVariable.single
            if e == '<HOLE>': return Hole.single
            raise ParseFailure((s,e))

        return p(s)

    @staticmethod
    def _parse(s,n):
        """Character-level parser: try each node type's ``_parse`` at offset ``n``.

        Returns ``(program, newOffset)`` from the first node type that succeeds.
        """
        while n < len(s) and s[n].isspace():
            n += 1
        for p in [
                Application,
                Abstraction,
                Index,
                Invented,
                FragmentVariable,
                Hole,
                Primitive]:
            try:
                return p._parse(s,n)
            except ParseFailure:
                continue
        raise ParseFailure(s)

    @staticmethod
    def parseConstant(s,n,*constants):
        """Match the first of ``constants`` occurring in ``s`` at offset ``n``.

        Returns the offset just past the matched constant.

        Raises:
            ParseFailure: if none of the constants matches.
        """
        for constant in constants:
            try:
                for i,c in enumerate(constant):
                    if i + n >= len(s) or s[i + n] != c: raise ParseFailure(s)
                return n + len(constant)
            except ParseFailure: continue
        raise ParseFailure(s)

    @staticmethod
    def parseHumanReadable(s):
        """Parse a program whose lambdas bind named variables.

        Example: ``(lambda (x y) (+ x y))``; the binder keyword may also be a
        backslash.  Names are converted to de Bruijn indices, and any other
        atom must be a registered primitive.
        """
        s = parseSExpression(s)

        def p(s, environment):
            if isinstance(s, list) and s[0] in ['lambda','\\']:
                assert isinstance(s[1], list) and len(s) == 3
                # The last-bound name is the innermost, i.e. index 0.
                newEnvironment = list(reversed(s[1])) + environment
                e = p(s[2], newEnvironment)
                for _ in s[1]: e = Abstraction(e)
                return e
            if isinstance(s, list):
                a = p(s[0], environment)
                for x in s[1:]:
                    a = Application(a, p(x, environment))
                return a
            for j,v in enumerate(environment):
                if s == v: return Index(j)
            if s in Primitive.GLOBALS: return Primitive.GLOBALS[s]
            assert False, f"could not parse {s}"

        return p(s, [])
|
|
|
|
|
|
|
|
|
class Application(Program):
    '''Function application ``(f x)``.

    Saturated calls of the form ``(if branch trueBranch falseBranch)`` are
    detected at construction and evaluated lazily (only the taken branch).
    '''

    def __init__(self, f, x):
        self.f = f
        self.x = x
        self.hashCode = None
        self._setConditionalFields()

    def _setConditionalFields(self):
        """Compute ``isConditional`` and the cached branch fields from f and x.

        Shared by ``__init__`` and the legacy-pickle path of ``__setstate__``
        so the two can never disagree.
        """
        f, x = self.f, self.x
        # NOTE(review): the int guard is kept from the original code; it is
        # unclear from this file when f could be an int.
        self.isConditional = (not isinstance(f,int)) and \
            f.isApplication and \
            f.f.isApplication and \
            f.f.f.isPrimitive and \
            f.f.f.name == "if"
        if self.isConditional:
            self.falseBranch = x
            self.trueBranch = f.x
            self.branch = f.f.x
        else:
            self.falseBranch = None
            self.trueBranch = None
            self.branch = None

    def betaReduce(self):
        """Perform one beta-reduction step, or return None if none applies.

        Reduces inside ``f`` first, then inside ``x``, and only then contracts
        this node when ``f`` is an abstraction.
        """
        f = self.f.betaReduce()
        if f is not None: return Application(f,self.x)
        x = self.x.betaReduce()
        if x is not None: return Application(self.f,x)

        if not self.f.isAbstraction: return None

        b = self.f.body
        v = self.x
        # Substitute x for $0: shift x up so its free variables survive the
        # binder, then shift the result down to remove that binder.
        return b.substitute(Index(0), v.shift(1)).shift(-1)

    def isBetaLong(self):
        """True if no abstraction appears in function position anywhere below."""
        return (not self.f.isAbstraction) and self.f.isBetaLong() and self.x.isBetaLong()

    def freeVariables(self):
        """Union of the free indices of ``f`` and ``x``."""
        return self.f.freeVariables() | self.x.freeVariables()

    def clone(self): return Application(self.f.clone(), self.x.clone())

    def annotateTypes(self, context, environment):
        """Set ``annotatedType`` on this node and its children using a mutable context."""
        self.f.annotateTypes(context, environment)
        self.x.annotateTypes(context, environment)
        r = context.makeVariable()
        context.unify(arrow(self.x.annotatedType, r), self.f.annotatedType)
        self.annotatedType = r.applyMutable(context)

    @property
    def isApplication(self): return True

    def __eq__(self, other):
        return isinstance(other, Application) and self.f == other.f and self.x == other.x

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((hash(self.f), hash(self.x)))
        return self.hashCode

    # Because Python3 randomizes the hash function, we need to never pickle
    # the hash: the cached hashCode is left out of the pickled state.
    def __getstate__(self):
        return self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch

    def __setstate__(self, state):
        """Restore from the tuple written by ``__getstate__`` or a legacy dict.

        Legacy pickles stored the instance ``__dict__``.  They are detected
        explicitly rather than via a failed tuple unpack, because a dict with
        exactly six keys would otherwise unpack its key strings into the fields.
        """
        if isinstance(state, dict):
            assert 'x' in state
            assert 'f' in state
            self.f = state['f']
            self.x = state['x']
            self._setConditionalFields()
        else:
            self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch = state
        self.hashCode = None

    def visit(self, visitor, *arguments, **keywords):
        return visitor.application(self, *arguments, **keywords)

    def show(self, isFunction):
        """Render as ``(f x)``, or as ``f x`` when already in function position."""
        if isFunction:
            return "%s %s" % (self.f.show(True), self.x.show(False))
        else:
            return "(%s %s)" % (self.f.show(True), self.x.show(False))

    def evaluate(self, environment):
        """Evaluate the application; ``if`` calls evaluate only the taken branch."""
        if self.isConditional:
            if self.branch.evaluate(environment):
                return self.trueBranch.evaluate(environment)
            else:
                return self.falseBranch.evaluate(environment)
        else:
            return self.f.evaluate(environment)(self.x.evaluate(environment))

    def inferType(self, context, environment, freeVariables):
        (context, ft) = self.f.inferType(context, environment, freeVariables)
        (context, xt) = self.x.inferType(context, environment, freeVariables)
        (context, returnType) = context.makeVariable()
        context = context.unify(ft, arrow(xt, returnType))
        return (context, returnType.apply(context))

    def applicationParses(self):
        """Yield every (head, arguments) split of this application spine."""
        yield self, []
        for f, xs in self.f.applicationParses():
            yield f, xs + [self.x]

    def applicationParse(self):
        """Return (head, arguments) for the full application spine."""
        f, xs = self.f.applicationParse()
        return f, xs + [self.x]

    def shift(self, offset, depth=0):
        return Application(self.f.shift(offset, depth),
                           self.x.shift(offset, depth))

    def substitute(self, old, new):
        if self == old:
            return new
        return Application(
            self.f.substitute(old, new),
            self.x.substitute(old, new))

    def walkUncurried(self, d=0):
        yield d, self
        f, xs = self.applicationParse()
        yield from f.walkUncurried(d)
        for x in xs:
            yield from x.walkUncurried(d)

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self
        yield from self.f.walk(surroundingAbstractions)
        yield from self.x.walk(surroundingAbstractions)

    def size(self):
        """Number of leaves; application nodes themselves add nothing."""
        return self.f.size() + self.x.size()

    @staticmethod
    def _parse(s,n):
        """Parse ``(e1 e2 ... ek)`` starting at offset ``n`` as a left-nested application."""
        while n < len(s) and s[n].isspace(): n += 1
        if n == len(s) or s[n] != '(': raise ParseFailure(s)
        n += 1

        xs = []
        while True:
            x, n = Program._parse(s, n)
            xs.append(x)
            while n < len(s) and s[n].isspace(): n += 1
            if n == len(s):
                raise ParseFailure(s)
            if s[n] == ")":
                n += 1
                break
        e = xs[0]
        for x in xs[1:]:
            e = Application(e, x)
        return e, n
|
|
|
|
|
class Index(Program):
    '''
    Variable reference by de Bruijn index
    (https://en.wikipedia.org/wiki/De_Bruijn_index).

    ``$0`` names the nearest enclosing binder, ``$1`` the next one out, and
    so on; indices at or beyond the number of enclosing binders are free.
    '''

    def __init__(self, i):
        self.i = i

    def show(self, isFunction):
        return "$%d" % (self.i,)

    def __eq__(self, o):
        if not isinstance(o, Index):
            return False
        return o.i == self.i

    def __hash__(self):
        return self.i

    def visit(self, visitor, *arguments, **keywords):
        return visitor.index(self, *arguments, **keywords)

    def evaluate(self, environment):
        return environment[self.i]

    def inferType(self, context, environment, freeVariables):
        """Look up a bound index in ``environment``; give free ones a shared type variable."""
        depth = len(environment)
        if not self.bound(depth):
            # Free variables are keyed by their index relative to the root.
            offset = self.i - depth
            if offset in freeVariables:
                return (context, freeVariables[offset].apply(context))
            context, fresh = context.makeVariable()
            freeVariables[offset] = fresh
            return (context, fresh)
        return (context, environment[self.i].apply(context))

    def clone(self):
        return Index(self.i)

    def annotateTypes(self, context, environment):
        self.annotatedType = environment[self.i].applyMutable(context)

    def shift(self, offset, depth=0):
        """Shift free indices by ``offset``; indices bound within ``depth`` stay put.

        Raises:
            ShiftFailure: if the shifted index would be negative.
        """
        if not self.bound(depth):
            shifted = self.i + offset
            if shifted < 0:
                raise ShiftFailure()
            return Index(shifted)
        return self

    def betaReduce(self):
        return None

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return {self.i}

    def substitute(self, old, new):
        return new if old == self else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    def free(self, surroundingAbstractions):
        '''True if this index is free under ``surroundingAbstractions`` lambdas.'''
        return self.i >= surroundingAbstractions

    def bound(self, surroundingAbstractions):
        '''True if this index is bound by one of ``surroundingAbstractions`` lambdas.'''
        return self.i < surroundingAbstractions

    @property
    def isIndex(self):
        return True

    @staticmethod
    def _parse(s,n):
        """Parse ``$<digits>`` at offset ``n``; return (Index, newOffset)."""
        while n < len(s) and s[n].isspace(): n += 1
        if n == len(s) or s[n] != '$':
            raise ParseFailure(s)
        start = n + 1
        end = start
        while end < len(s) and s[end].isdigit():
            end += 1
        if end == start:
            raise ParseFailure(s)
        return Index(int(s[start:end])), end
|
|
|
|
|
class Abstraction(Program):
    '''Lambda abstraction: a one-argument function whose parameter is ``$0`` in ``body``.'''

    def __init__(self, body):
        self.body = body
        self.hashCode = None

    @property
    def isAbstraction(self):
        return True

    def __eq__(self, o):
        if not isinstance(o, Abstraction):
            return False
        return o.body == self.body

    def __hash__(self):
        cached = self.hashCode
        if cached is None:
            cached = self.hashCode = hash((hash(self.body),))
        return cached

    # Python 3 randomizes string hashing, so only the body is pickled and the
    # cached hash is recomputed lazily after unpickling.
    def __getstate__(self):
        return self.body

    def __setstate__(self, state):
        self.body = state
        self.hashCode = None

    def isBetaLong(self):
        return self.body.isBetaLong()

    def freeVariables(self):
        """Free indices of the body, minus the one this lambda binds, shifted down by one."""
        return {v - 1 for v in self.body.freeVariables() if v > 0}

    def visit(self, visitor, *arguments, **keywords):
        return visitor.abstraction(self, *arguments, **keywords)

    def clone(self):
        return Abstraction(self.body.clone())

    def annotateTypes(self, context, environment):
        argument = context.makeVariable()
        self.body.annotateTypes(context, [argument] + environment)
        self.annotatedType = arrow(argument.applyMutable(context), self.body.annotatedType)

    def show(self, isFunction):
        return "(lambda %s)" % self.body.show(False)

    def evaluate(self, environment):
        """Return a Python closure that evaluates the body with its argument as ``$0``."""
        def closure(x):
            return self.body.evaluate([x] + environment)
        return closure

    def betaReduce(self):
        reduced = self.body.betaReduce()
        return None if reduced is None else Abstraction(reduced)

    def inferType(self, context, environment, freeVariables):
        context, argumentType = context.makeVariable()
        context, returnType = self.body.inferType(
            context, [argumentType] + environment, freeVariables)
        return (context, arrow(argumentType, returnType).apply(context))

    def shift(self, offset, depth=0):
        # One more binder now encloses the body.
        return Abstraction(self.body.shift(offset, depth + 1))

    def substitute(self, old, new):
        if self == old:
            return new
        # Under the binder, every free index in old/new must move up by one.
        return Abstraction(self.body.substitute(old.shift(1), new.shift(1)))

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self
        yield from self.body.walk(surroundingAbstractions + 1)

    def walkUncurried(self, d=0):
        yield d, self
        yield from self.body.walkUncurried(d + 1)

    def size(self):
        return self.body.size()

    @staticmethod
    def _parse(s,n):
        """Parse ``(lambda body)`` (also spelled with a backslash or a lambda sign)."""
        n = Program.parseConstant(s, n, '(\\', '(lambda', '(\u03bb')
        while n < len(s) and s[n].isspace(): n += 1
        body, n = Program._parse(s, n)
        while n < len(s) and s[n].isspace(): n += 1
        n = Program.parseConstant(s, n, ')')
        return Abstraction(body), n
|
|
|
|
|
class Primitive(Program):
    """A named built-in with type ``tp`` and Python value ``value``.

    The first Primitive constructed for a given name is recorded in
    ``Primitive.GLOBALS``, which the parsers consult to resolve names.
    Equality and hashing depend only on the name.
    """
    GLOBALS = {}

    def __init__(self, name, ty, value):
        self.tp = ty
        self.name = name
        self.value = value
        # Only the first registration of a name is kept.
        Primitive.GLOBALS.setdefault(name, self)

    @property
    def isPrimitive(self):
        return True

    def __eq__(self, o):
        if not isinstance(o, Primitive):
            return False
        return o.name == self.name

    def __hash__(self):
        return hash(self.name)

    def visit(self, visitor, *arguments, **keywords):
        return visitor.primitive(self, *arguments, **keywords)

    def show(self, isFunction):
        return self.name

    def clone(self):
        return Primitive(self.name, self.tp, self.value)

    def annotateTypes(self, context, environment):
        self.annotatedType = self.tp.instantiateMutable(context)

    def evaluate(self, environment):
        return self.value

    def betaReduce(self):
        return None

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return set()

    def inferType(self, context, environment, freeVariables):
        return self.tp.instantiate(context)

    def shift(self, offset, depth=0):
        return self

    def substitute(self, old, new):
        return new if self == old else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s,n):
        """Parse a registered primitive name (a run of non-space, non-paren characters)."""
        while n < len(s) and s[n].isspace(): n += 1
        start = n
        while n < len(s) and not s[n].isspace() and s[n] not in '()':
            n += 1
        name = "".join(s[start:n])
        if name not in Primitive.GLOBALS:
            raise ParseFailure(s)
        return Primitive.GLOBALS[name], n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Invented(Program):
    '''An invented primitive wrapping the program ``body``.

    Its type is inferred once at construction.  It is treated as closed: it
    is evaluated in an empty environment, reports no free variables and is
    unaffected by shifting.
    '''

    def __init__(self, body):
        self.body = body
        self.tp = self.body.infer()
        self.hashCode = None

    @property
    def isInvented(self):
        return True

    def show(self, isFunction):
        return "#%s" % (self.body.show(False),)

    def visit(self, visitor, *arguments, **keywords):
        return visitor.invented(self, *arguments, **keywords)

    def __eq__(self, o):
        if not isinstance(o, Invented):
            return False
        return o.body == self.body

    def __hash__(self):
        cached = self.hashCode
        if cached is None:
            cached = self.hashCode = hash((0, hash(self.body)))
        return cached

    # Python 3 randomizes string hashing, so only (body, tp) is pickled and the
    # cached hash is recomputed lazily after unpickling.
    def __getstate__(self):
        return self.body, self.tp

    def __setstate__(self, state):
        self.body, self.tp = state
        self.hashCode = None

    def clone(self):
        return Invented(self.body)

    def annotateTypes(self, context, environment):
        self.annotatedType = self.tp.instantiateMutable(context)

    def evaluate(self, e):
        return self.body.evaluate([])

    def betaReduce(self):
        # Reducing an invention unfolds it into its definition.
        return self.body

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return set()

    def inferType(self, context, environment, freeVariables):
        return self.tp.instantiate(context)

    def shift(self, offset, depth=0):
        return self

    def substitute(self, old, new):
        return new if self == old else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s,n):
        """Parse ``#<program>`` at offset ``n``."""
        while n < len(s) and s[n].isspace(): n += 1
        if n >= len(s) or s[n] != '#':
            raise ParseFailure(s)
        body, n = Program._parse(s, n + 1)
        return Invented(body), n
|
|
|
|
|
class FragmentVariable(Program):
    """Wildcard written ``??`` (or ``?``); a single shared instance is used.

    It cannot be evaluated, beta-reduced or shifted.  ``match`` binds it to an
    arbitrary expression by recording ``(typeVariable, expression)`` in ``holes``.
    """

    def __init__(self): pass

    def show(self, isFunction): return "??"

    def __eq__(self, o): return isinstance(o, FragmentVariable)

    def __hash__(self): return 42

    def visit(self, visitor, *arguments, **keywords):
        return visitor.fragmentVariable(self, *arguments, **keywords)

    def evaluate(self, e):
        raise Exception('Attempt to evaluate fragment variable')

    def betaReduce(self):
        raise Exception('Attempt to beta reduce fragment variable')

    def freeVariables(self):
        # Return a set like every other node type, so that
        # Application.freeVariables can combine results with "|".
        return set()

    def inferType(self, context, environment, freeVariables):
        return context.makeVariable()

    def shift(self, offset, depth=0):
        raise Exception('Attempt to shift fragment variable')

    def substitute(self, old, new):
        if self == old:
            return new
        else:
            return self

    def match(
            self,
            context,
            expression,
            holes,
            variableBindings,
            environment=None):
        """Bind this wildcard to ``expression`` under ``environment``'s binders.

        The expression is shifted out of the ``len(environment)`` enclosing
        binders and appended to ``holes`` with a fresh type variable.

        Returns:
            ``(context, typeVariable)``.

        Raises:
            MatchFailure: if ``expression`` refers to one of those enclosing
                binders (the shift would go negative).
        """
        surroundingAbstractions = 0 if environment is None else len(environment)
        try:
            context, variable = context.makeVariable()
            holes.append(
                (variable, expression.shift(-surroundingAbstractions)))
            return context, variable
        except ShiftFailure:
            raise MatchFailure()

    def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self

    def walkUncurried(self, d=0): yield d, self

    def size(self): return 1

    @staticmethod
    def _parse(s,n):
        """Parse ``??`` or ``?`` at offset ``n``."""
        while n < len(s) and s[n].isspace(): n += 1
        n = Program.parseConstant(s,n,'??','?')
        return FragmentVariable.single, n


FragmentVariable.single = FragmentVariable()
|
|
|
|
|
class Hole(Program):
    """Placeholder printed as ``<HOLE>``; a single shared instance is used.

    It cannot be evaluated, beta-reduced or shifted.
    NOTE(review): presumably marks a not-yet-synthesized subprogram.
    """

    def __init__(self): pass

    def show(self, isFunction): return "<HOLE>"

    @property
    def isHole(self): return True

    def __eq__(self, o): return isinstance(o, Hole)

    def __hash__(self): return 42

    def evaluate(self, e):
        raise Exception('Attempt to evaluate hole')

    def betaReduce(self):
        raise Exception('Attempt to beta reduce hole')

    def freeVariables(self):
        # Same contract as the other leaves: a set, so "|" unions work.
        return set()

    def inferType(self, context, environment, freeVariables):
        return context.makeVariable()

    def shift(self, offset, depth=0):
        # The message previously said "fragment variable" (copy-paste error).
        raise Exception('Attempt to shift hole')

    def substitute(self, old, new):
        # Mirrors FragmentVariable so Application.substitute can traverse holes.
        if self == old:
            return new
        else:
            return self

    def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self

    def walkUncurried(self, d=0): yield d, self

    def size(self): return 1

    @staticmethod
    def _parse(s,n):
        """Parse ``<HOLE>`` at offset ``n``."""
        while n < len(s) and s[n].isspace(): n += 1
        n = Program.parseConstant(s,n,
                                  '<HOLE>')
        return Hole.single, n


Hole.single = Hole()
|
|
|
|
|
class ShareVisitor(object):
    """Rebuild a program so structurally equal subtrees are the same object.

    Children are shared first.  Each node is then looked up by its
    children's ``id``s (or by name / index for leaves), so one visitor
    instance returns a single object per distinct subtree.
    """

    def __init__(self):
        self.primitiveTable = {}
        self.inventedTable = {}
        self.indexTable = {}
        self.applicationTable = {}
        self.abstractionTable = {}

    def invented(self, e):
        body = e.body.visit(self)
        key = id(body)
        shared = self.inventedTable.get(key)
        if shared is None:
            shared = self.inventedTable[key] = Invented(body)
        return shared

    def primitive(self, e):
        return self.primitiveTable.setdefault(e.name, e)

    def index(self, e):
        return self.indexTable.setdefault(e.i, e)

    def application(self, e):
        f = e.f.visit(self)
        x = e.x.visit(self)
        key = (id(f), id(x))
        shared = self.applicationTable.get(key)
        if shared is None:
            shared = self.applicationTable[key] = Application(f, x)
        return shared

    def abstraction(self, e):
        body = e.body.visit(self)
        key = id(body)
        shared = self.abstractionTable.get(key)
        if shared is None:
            shared = self.abstractionTable[key] = Abstraction(body)
        return shared

    def execute(self, e):
        return e.visit(self)
|
|
|
|
|
class Mutator:
    """Perform local mutations to an expr, yielding the expr and the
    description length distance from the original program"""

    def __init__(self, grammar, fn):
        """Fn yields (expression, loglikelihood) from a type and loss.
        Therefore, loss+loglikelihood is the distance from the original program."""
        self.fn = fn
        self.grammar = grammar
        # Stack of callbacks that rebuild each enclosing node around a new
        # child; the most recently pushed (innermost) one is applied first.
        self.history = []

    def enclose(self, expr):
        """Rebuild the whole program around the replacement subexpression ``expr``."""
        for h in self.history[::-1]:
            expr = h(expr)
        return expr

    def _replaceSubtree(self, e, tp, env, is_lhs):
        """Yield ``(program, distance)`` for every replacement of subtree ``e``.

        ``fn`` receives the log-likelihood of the deleted subtree; the yielded
        distance is that value plus the replacement's log-likelihood.
        """
        deleted_ll = self.logLikelihood(tp, e, env)
        for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
            yield self.enclose(expr), deleted_ll + replaced_ll

    def invented(self, e, tp, env, is_lhs=False):
        # Previously this passed the undefined name ``deleted`` to fn (NameError).
        yield from self._replaceSubtree(e, tp, env, is_lhs)

    def primitive(self, e, tp, env, is_lhs=False):
        yield from self._replaceSubtree(e, tp, env, is_lhs)

    def index(self, e, tp, env, is_lhs=False):
        yield from self._replaceSubtree(e, tp, env, is_lhs)

    def application(self, e, tp, env, is_lhs=False):
        # Mutate inside the function position while holding the argument fixed...
        self.history.append(lambda expr: Application(expr, e.x))
        f_tp = arrow(e.x.infer(), tp)
        yield from e.f.visit(self, f_tp, env, is_lhs=True)
        # ...then inside the argument while holding the function fixed...
        self.history[-1] = lambda expr: Application(e.f, expr)
        x_tp = inferArg(tp, e.f.infer())
        yield from e.x.visit(self, x_tp, env)
        self.history.pop()
        # ...and finally replace the whole application.
        yield from self._replaceSubtree(e, tp, env, is_lhs)

    def abstraction(self, e, tp, env, is_lhs=False):
        # Assumes tp is an arrow type: arguments[0] is the parameter type and
        # arguments[1] the body type.
        self.history.append(lambda expr: Abstraction(expr))
        yield from e.body.visit(self, tp.arguments[1], [tp.arguments[0]]+env)
        self.history.pop()
        yield from self._replaceSubtree(e, tp, env, is_lhs)

    def execute(self, e, tp):
        """Yield every (mutated program, distance) pair for ``e`` at type ``tp``."""
        yield from e.visit(self, tp, [])

    def logLikelihood(self, tp, e, env):
        """Log-likelihood of ``e`` at type ``tp`` under ``self.grammar``.

        If the grammar cannot score ``e`` (``likelihoodSummary`` returns None
        or raises AssertionError), ``e`` is eta-expanded once so it takes every
        argument of ``tp`` and is scored again.  If it already takes that many
        (or more) arguments, NEGATIVEINFINITY is returned.  The expanded form
        has no missing arguments, so this recursion terminates.
        """
        summary = None
        try:
            _, summary = self.grammar.likelihoodSummary(Context.EMPTY, env,
                                                        tp, e, silent=True)
        except AssertionError:
            # Deliberately best-effort: fall through to eta-expansion below.
            pass
        if summary is not None:
            return summary.logLikelihood(self.grammar)

        body, depth = e, 0
        while body.isAbstraction:
            depth += 1
            body = body.body
        to_introduce = len(tp.functionArguments()) - depth
        if to_introduce <= 0:
            # Nothing left to eta-expand (a negative count used to recurse
            # forever on the same expression).
            return NEGATIVEINFINITY
        # Same scheme as Program.uncurry: make room for the new innermost
        # binders, apply the body to them, then re-bind every argument.
        body = body.shift(to_introduce)
        for i in reversed(range(to_introduce)):
            body = Application(body, Index(i))
        for _ in range(depth + to_introduce):
            body = Abstraction(body)
        return self.logLikelihood(tp, body, env)
|
|
|
|
|
class RegisterPrimitives(object):
    """Visitor that adds every primitive found in a program to ``Primitive.GLOBALS``.

    Names that are already registered are left untouched.
    """

    def invented(self, e):
        e.body.visit(self)

    def primitive(self, e):
        if e.name in Primitive.GLOBALS:
            return
        # Constructing a Primitive registers it under its name.
        Primitive(e.name, e.tp, e.value)

    def index(self, e):
        pass

    def application(self, e):
        for child in (e.f, e.x):
            child.visit(self)

    def abstraction(self, e):
        e.body.visit(self)

    @staticmethod
    def register(e):
        """Register every primitive reachable from ``e``."""
        e.visit(RegisterPrimitives())
|
|
|
|
|
class PrettyVisitor(object):
    """Render a program with variable names instead of de Bruijn indices.

    With ``Lisp=False``, abstractions print as ``λx.body`` and directly
    nested binders merge (``λxy.body``).  With ``Lisp=True`` they print as
    ``(λ (x y) body)``.  Free indices get their own fresh names.
    """

    def __init__(self, Lisp=False):
        self.Lisp = Lisp
        self.numberOfVariables = 0
        # Names given to free indices, keyed by index relative to the root.
        self.freeVariables = {}

        self.variableNames = ["x", "y", "z", "u", "v", "w"]
        self.variableNames += [chr(ord('a') + j)
                               for j in range(20)]
        self.toplevel = True

    def makeVariable(self):
        """Return a fresh name: x y z u v w, then a..t, then x26, x27, ...

        Programs with more than 26 variables used to raise IndexError.  The
        numbered fallback names cannot collide with the single-letter ones.
        """
        n = self.numberOfVariables
        self.numberOfVariables += 1
        if n < len(self.variableNames):
            return self.variableNames[n]
        return "x%d" % n

    def invented(self, e, environment, isFunction, isAbstraction):
        s = e.body.visit(self, [], isFunction, isAbstraction)
        return s

    def primitive(self, e, environment, isVariable, isAbstraction): return e.name

    def index(self, e, environment, isVariable, isAbstraction):
        if e.i < len(environment):
            return environment[e.i]
        else:
            i = e.i - len(environment)
            if i in self.freeVariables:
                return self.freeVariables[i]
            else:
                v = self.makeVariable()
                self.freeVariables[i] = v
                return v

    def application(self, e, environment, isFunction, isAbstraction):
        self.toplevel = False
        s = "%s %s" % (e.f.visit(self, environment, True, False),
                       e.x.visit(self, environment, False, False))
        if isFunction:
            return s
        else:
            return "(" + s + ")"

    def abstraction(self, e, environment, isFunction, isAbstraction):
        toplevel = self.toplevel
        self.toplevel = False
        if not self.Lisp:
            v = self.makeVariable()
            # Only a body that really is an abstraction is rendered as a
            # continuation of this binder list.  Passing True unconditionally
            # let an invented body drop its own "λ".
            body = e.body.visit(self,
                                [v] + environment,
                                False,
                                e.body.isAbstraction)
            if not e.body.isAbstraction:
                body = "." + body
            body = v + body
            if isAbstraction:
                # Continuation of the enclosing binder list: the enclosing
                # abstraction supplies the "λ" and any parentheses.  Adding
                # parentheses here produced malformed "λx(y.body)".
                return body
            body = "λ" + body
            if not toplevel:
                body = "(%s)" % body
            return body
        else:
            child = e
            newVariables = []
            while child.isAbstraction:
                # Innermost binder first, matching de Bruijn order.
                newVariables = [self.makeVariable()] + newVariables
                child = child.body
            body = child.visit(self, newVariables + environment,
                               False, True)
            body = "(λ (%s) %s)"%(" ".join(reversed(newVariables)), body)
            return body
|
|
|
|
|
|
|
def prettyProgram(e, Lisp=False):
    """Return a human-readable rendering of ``e`` with named variables."""
    visitor = PrettyVisitor(Lisp=Lisp)
    return e.visit(visitor, [], False, False)
|
|
|
class EtaExpandFailure(Exception):
    """Raised by ``EtaLongVisitor`` when a program does not fit the requested type's arity."""
|
class EtaLongVisitor(object):
    """Converts an expression into eta-longform.

    Every subterm requested at an arrow type is wrapped in lambdas until its
    body is requested at a non-arrow type.  Types are tracked in a
    MutableContext that exists only for the duration of ``execute``.
    """
    def __init__(self, request=None):
        # Requested type for the whole program; inferred in execute if None.
        self.request = request
        self.context = None

    def makeLong(self, e, request):
        """One eta-expansion step.

        If ``request`` is an arrow, return ``(lambda ((e shifted by 1) $0))``;
        otherwise return None.
        """
        if request.isArrow():
            # Shift e so its free indices skip the new binder, then apply it
            # to that binder.
            return Abstraction(Application(e.shift(1),
                                           Index(0)))
        return None

    def abstraction(self, e, request, environment):
        """Visit the body at the arrow's return type, with its argument type bound as $0."""
        if not request.isArrow(): raise EtaExpandFailure()

        return Abstraction(e.body.visit(self,
                                        request.arguments[1],
                                        [request.arguments[0]] + environment))

    def _application(self, e, request, environment):
        """Shared handler for applications and leaves (index, primitive, invented)."""
        # While the request is still an arrow, eta-expand and re-visit; the
        # resulting Abstraction is handled by ``abstraction`` above.
        l = self.makeLong(e, request)
        if l is not None: return l.visit(self, request, environment)

        f, xs = e.applicationParse()

        # Type of the head: from the environment for a bound index, or a fresh
        # instance of the declared type for a primitive/invention.
        if f.isIndex:
            ft = environment[f.i].applyMutable(self.context)
        elif f.isInvented or f.isPrimitive:
            ft = f.tp.instantiateMutable(self.context)
        else: assert False, "Not in beta long form: %s"%e

        # Unify the (non-arrow) request with the head's final return type
        # before reading off the argument types.
        self.context.unify(request, ft.returns())
        ft = ft.applyMutable(self.context)

        xt = ft.functionArguments()
        if len(xs) != len(xt): raise EtaExpandFailure()

        # Rebuild the spine, visiting each argument at its resolved type.
        returnValue = f
        for x,t in zip(xs,xt):
            t = t.applyMutable(self.context)
            returnValue = Application(returnValue,
                                      x.visit(self, t, environment))
        return returnValue

    def application(self, e, request, environment): return self._application(e, request, environment)

    def index(self, e, request, environment): return self._application(e, request, environment)

    def primitive(self, e, request, environment): return self._application(e, request, environment)

    def invented(self, e, request, environment): return self._application(e, request, environment)

    def execute(self, e):
        """Return closed program ``e`` converted to eta-long form at ``self.request``."""
        assert len(e.freeVariables()) == 0

        if self.request is None:
            eprint("WARNING: request not specified for etaexpansion")
            self.request = e.infer()
        # Fresh mutable context for this conversion only.
        self.context = MutableContext()
        el = e.visit(self, self.request, [])
        self.context = None

        return el
|
|
|
|
|
|
|
class StripPrimitiveVisitor():
    """Return a copy of a program whose primitives carry ``value=None``.

    The input program is not modified; nodes are rebuilt (indices are reused).
    NOTE: constructing ``Primitive`` registers a name in ``Primitive.GLOBALS``
    if it is not already there, here with a None value.
    """
    def invented(self, e):
        stripped = e.body.visit(self)
        return Invented(stripped)

    def primitive(self, e):
        return Primitive(e.name, e.tp, None)

    def application(self, e):
        f = e.f.visit(self)
        x = e.x.visit(self)
        return Application(f, x)

    def abstraction(self, e):
        stripped = e.body.visit(self)
        return Abstraction(stripped)

    def index(self, e):
        return e
|
|
|
class ReplacePrimitiveValueVisitor():
    """Intended to be used after StripPrimitiveVisitor.

    Rebuilds the program, giving every primitive the ``value`` of the
    same-named entry in ``Primitive.GLOBALS``.  A missing name raises KeyError.
    """
    def invented(self, e):
        restored = e.body.visit(self)
        return Invented(restored)

    def primitive(self, e):
        name, tp = e.name, e.tp
        value = Primitive.GLOBALS[name].value
        return Primitive(name, tp, value)

    def application(self, e):
        f = e.f.visit(self)
        x = e.x.visit(self)
        return Application(f, x)

    def abstraction(self, e):
        restored = e.body.visit(self)
        return Abstraction(restored)

    def index(self, e):
        return e
|
|
|
def strip_primitive_values(e):
    """Return a copy of ``e`` whose primitives have their values set to None."""
    visitor = StripPrimitiveVisitor()
    return e.visit(visitor)
|
def unstrip_primitive_values(e):
    """Return a copy of ``e`` with primitive values restored from ``Primitive.GLOBALS``."""
    visitor = ReplacePrimitiveValueVisitor()
    return e.visit(visitor)
|
|
|
|
|
|
|
class TokeniseVisitor(object):
    """Flatten a program into a flat list of tokens.

    Applications become ``(`` f x ``)`` and abstractions become
    ``(_lambda`` body ``)_lambda``.  Primitives become their names and
    indices become ``$i``.
    """
    def invented(self, e):
        # NOTE: the token is the invented body itself (a Program), not a string.
        return [e.body]

    def primitive(self, e):
        return [e.name]

    def index(self, e):
        return ["$%s" % (e.i,)]

    def application(self, e):
        return ["(", *e.f.visit(self), *e.x.visit(self), ")"]

    def abstraction(self, e):
        return ["(_lambda", *e.body.visit(self), ")_lambda"]
|
|
|
|
|
def tokeniseProgram(e):
    """Return the token list for ``e`` (see ``TokeniseVisitor``)."""
    visitor = TokeniseVisitor()
    return e.visit(visitor)
|
|
|
|
|
def untokeniseProgram(l):
    """Inverse of ``tokeniseProgram``: rebuild a Program from a token list.

    String tokens are joined with spaces, with the lambda markers mapped back
    to ``(lambda`` and ``)``, and parsed with ``Program.parse``.
    ``TokeniseVisitor`` emits an invented primitive as its *body* Program
    rather than a string, which used to make the join fail with TypeError.
    Such tokens are rendered back as ``#<body>``.
    """
    lookup = {
        "(_lambda": "(lambda",
        ")_lambda": ")"
    }

    def render(token):
        if isinstance(token, Invented):
            # Already shows as "#<body>".
            return token.show(False)
        if isinstance(token, Program):
            # Body of an invention, as emitted by TokeniseVisitor.invented.
            return "#" + token.show(False)
        return lookup.get(token, token)

    s = " ".join(render(x) for x in l)
    return Program.parse(s)
|
|
|
if __name__ == "__main__":
    # Manual smoke test: parse a program containing an invented primitive and
    # fragment variables, then print it to stderr.
    from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
    example = "(#(lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) - * (+ +))"
    eprint(Program.parse(example))
|
|