from dreamcoder.type import * |
from dreamcoder.utilities import * |
from time import time |
import math |
class InferenceFailure(Exception):
    """Raised by ``Program.infer`` when a program's type cannot be unified."""
class ShiftFailure(Exception):
    """Raised when shifting a de Bruijn index would make it negative."""
class RunFailure(Exception):
    """Presumably signals a failure while running a program.

    NOTE(review): not raised anywhere in this part of the file; confirm usage.
    """
class Program(object):
    """Base class for programs: lambda-calculus terms using de Bruijn indices.

    Concrete node types implement per-node operations such as ``show``,
    ``inferType``, ``shift``, ``walk`` and ``visit``. This class supplies
    shared helpers built on top of them, plus the textual parsers.
    The subclasses are Application, Index, Abstraction, Primitive, Invented,
    FragmentVariable and Hole.
    """

    def __repr__(self): return str(self)

    def __ne__(self, o): return not (self == o)

    def __str__(self): return self.show(False)

    def canHaveType(self, t):
        """Return True if this program's inferred type unifies with ``t``."""
        try:
            context, actualType = self.inferType(Context.EMPTY, [], {})
            context, t = t.instantiate(context)
            context.unify(t, actualType)
            return True
        except UnificationFailure:
            return False

    def betaNormalForm(self):
        """Beta-reduce repeatedly until no redex remains and return the result."""
        current = self
        while True:
            reduced = current.betaReduce()
            if reduced is None:
                return current
            current = reduced

    def infer(self):
        """Infer and return the canonical type of this program.

        Raises:
            InferenceFailure: if type unification fails.
        """
        try:
            return self.inferType(Context.EMPTY, [], {})[1].canonical()
        except UnificationFailure as e:
            raise InferenceFailure(self, e)

    def uncurry(self):
        """Eta-expand so there is one leading lambda per argument of the type.

        The result is asserted to have the same inferred type as ``self``.
        """
        t = self.infer()
        a = len(t.functionArguments())
        e = self
        existingAbstractions = 0
        while e.isAbstraction:
            e = e.body
            existingAbstractions += 1
        newAbstractions = a - existingAbstractions
        assert newAbstractions >= 0
        # Make room for the new innermost binders, then apply the body to them.
        e = e.shift(newAbstractions)
        for n in reversed(range(newAbstractions)):
            e = Application(e, Index(n))
        for _ in range(a):
            e = Abstraction(e)
        assert self.infer() == e.infer(), \
            "FATAL: uncurry has a bug. %s : %s, but uncurried to %s : %s" % (self, self.infer(),
                                                                            e, e.infer())
        return e

    def wellTyped(self):
        """Return True if type inference succeeds for this program."""
        try:
            self.infer()
            return True
        except InferenceFailure:
            return False

    def runWithArguments(self, xs):
        """Evaluate in an empty environment and apply the result to each of ``xs`` in turn."""
        f = self.evaluate([])
        for x in xs:
            f = f(x)
        return f

    def applicationParses(self): yield self, []

    def applicationParse(self): return self, []

    @property
    def closed(self):
        """True if the program has no fragment variables and no free indices."""
        for surroundingAbstractions, child in self.walk():
            if isinstance(child, FragmentVariable):
                return False
            if isinstance(child, Index) and child.free(
                    surroundingAbstractions):
                return False
        return True

    @property
    def numberOfFreeVariables(expression):
        """One more than the largest free index, i.e. how many outer binders are needed."""
        n = 0
        for surroundingAbstractions, child in expression.walk():
            if isinstance(child, Index) and child.free(
                    surroundingAbstractions):
                n = max(n, child.i - surroundingAbstractions + 1)
        return n

    def freeVariables(self):
        """Return the set of free de Bruijn indices, relative to this node.

        Returns a set, consistent with the subclass overrides that are
        combined with ``|`` (e.g. ``Application.freeVariables``).
        """
        return {child.i - surroundingAbstractions
                for surroundingAbstractions, child in self.walk()
                if child.isIndex and child.i >= surroundingAbstractions}

    @property
    def isIndex(self): return False

    @property
    def isUnion(self): return False

    @property
    def isApplication(self): return False

    @property
    def isAbstraction(self): return False

    @property
    def isPrimitive(self): return False

    @property
    def isInvented(self): return False

    @property
    def isHole(self): return False

    @staticmethod
    def parse(s):
        """Parse the s-expression syntax produced by ``show``.

        Recognises ``(lambda b)``, ``['#', b]`` nodes from ``parseSExpression``
        (invented primitives), ``$i`` indices, registered primitive names,
        ``??``/``?`` (fragment variable) and ``<HOLE>``. Any other list is a
        left-nested application.

        Raises:
            ParseFailure: on an unrecognised atom.
        """
        s = parseSExpression(s)

        def p(e):
            if isinstance(e, list):
                if e[0] == '#':
                    assert len(e) == 2
                    return Invented(p(e[1]))
                if e[0] == 'lambda':
                    assert len(e) == 2
                    return Abstraction(p(e[1]))
                f = p(e[0])
                for x in e[1:]:
                    f = Application(f, p(x))
                return f
            assert isinstance(e, str)
            if e[0] == '$': return Index(int(e[1:]))
            if e in Primitive.GLOBALS: return Primitive.GLOBALS[e]
            if e == '??' or e == '?': return FragmentVariable.single
            if e == '<HOLE>': return Hole.single
            raise ParseFailure((s, e))
        return p(s)

    @staticmethod
    def _parse(s, n):
        """Parse one program from ``s`` starting at offset ``n``; return (program, new offset)."""
        while n < len(s) and s[n].isspace():
            n += 1
        # Try each node kind in turn; the first parser that succeeds wins.
        for p in [
                Application,
                Abstraction,
                Index,
                Invented,
                FragmentVariable,
                Hole,
                Primitive]:
            try:
                return p._parse(s, n)
            except ParseFailure:
                continue
        raise ParseFailure(s)

    @staticmethod
    def parseConstant(s, n, *constants):
        """Return the offset just past the first of ``constants`` found at ``s[n:]``.

        Raises:
            ParseFailure: if none of the constants occurs at offset ``n``.
        """
        for constant in constants:
            if s[n:n + len(constant)] == constant:
                return n + len(constant)
        raise ParseFailure(s)

    @staticmethod
    def parseHumanReadable(s):
        """Parse programs written with named variables: ``(lambda (x y) body)``.

        ``\\`` is accepted in place of ``lambda``. Names are resolved to de
        Bruijn indices, innermost binding first, and otherwise looked up in
        ``Primitive.GLOBALS``.
        """
        s = parseSExpression(s)

        def p(s, environment):
            if isinstance(s, list) and s[0] in ['lambda', '\\']:
                assert isinstance(s[1], list) and len(s) == 3
                newEnvironment = list(reversed(s[1])) + environment
                e = p(s[2], newEnvironment)
                for _ in s[1]: e = Abstraction(e)
                return e
            if isinstance(s, list):
                a = p(s[0], environment)
                for x in s[1:]:
                    a = Application(a, p(x, environment))
                return a
            for j, v in enumerate(environment):
                if s == v: return Index(j)
            if s in Primitive.GLOBALS: return Primitive.GLOBALS[s]
            assert False, f"could not parse {s}"
        return p(s, [])
class Application(Program):
    '''Apply the function ``f`` to the argument ``x``.

    Applications shaped like ``(if branch trueBranch falseBranch)`` are
    flagged as conditionals so that ``evaluate`` only runs the taken branch.
    '''

    def __init__(self, f, x):
        self.f = f
        self.x = x
        self.hashCode = None
        self._detectConditional()

    def _detectConditional(self):
        """Set ``isConditional`` and the three branch fields from ``self.f``/``self.x``."""
        head = self.f
        self.isConditional = (not isinstance(head, int)) and \
            head.isApplication and \
            head.f.isApplication and \
            head.f.f.isPrimitive and \
            head.f.f.name == "if"
        if not self.isConditional:
            self.falseBranch = None
            self.trueBranch = None
            self.branch = None
            return
        self.falseBranch = self.x
        self.trueBranch = head.x
        self.branch = head.f.x

    def betaReduce(self):
        """Perform one leftmost beta-reduction step, or return None if none applies."""
        reducedFunction = self.f.betaReduce()
        if reducedFunction is not None:
            return Application(reducedFunction, self.x)
        reducedArgument = self.x.betaReduce()
        if reducedArgument is not None:
            return Application(self.f, reducedArgument)
        if not self.f.isAbstraction:
            return None
        # (lambda b) v  ->  b[$0 := v], re-indexing for the removed binder.
        return self.f.body.substitute(Index(0), self.x.shift(1)).shift(-1)

    def isBetaLong(self):
        if self.f.isAbstraction:
            return False
        return self.f.isBetaLong() and self.x.isBetaLong()

    def freeVariables(self):
        return self.f.freeVariables() | self.x.freeVariables()

    def clone(self):
        return Application(self.f.clone(), self.x.clone())

    def annotateTypes(self, context, environment):
        self.f.annotateTypes(context, environment)
        self.x.annotateTypes(context, environment)
        result = context.makeVariable()
        context.unify(arrow(self.x.annotatedType, result), self.f.annotatedType)
        self.annotatedType = result.applyMutable(context)

    @property
    def isApplication(self):
        return True

    def __eq__(self, other):
        if not isinstance(other, Application):
            return False
        return self.f == other.f and self.x == other.x

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((hash(self.f), hash(self.x)))
        return self.hashCode

    # Python 3 randomizes string hashing per process, so the cached hash is
    # deliberately left out of the pickled state.
    def __getstate__(self):
        return self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch

    def __setstate__(self, state):
        try:
            (self.f, self.x, self.isConditional,
             self.falseBranch, self.trueBranch, self.branch) = state
        except ValueError:
            # Presumably an older pickle format holding a dict with 'f' and 'x'.
            assert 'x' in state
            assert 'f' in state
            self.f = state['f']
            self.x = state['x']
            self._detectConditional()
        self.hashCode = None

    def visit(self, visitor, *arguments, **keywords):
        return visitor.application(self, *arguments, **keywords)

    def show(self, isFunction):
        inner = "%s %s" % (self.f.show(True), self.x.show(False))
        return inner if isFunction else "(" + inner + ")"

    def evaluate(self, environment):
        if not self.isConditional:
            return self.f.evaluate(environment)(self.x.evaluate(environment))
        # Only the selected branch is evaluated.
        taken = self.trueBranch if self.branch.evaluate(environment) else self.falseBranch
        return taken.evaluate(environment)

    def inferType(self, context, environment, freeVariables):
        context, functionType = self.f.inferType(context, environment, freeVariables)
        context, argumentType = self.x.inferType(context, environment, freeVariables)
        context, resultType = context.makeVariable()
        context = context.unify(functionType, arrow(argumentType, resultType))
        return (context, resultType.apply(context))

    def applicationParses(self):
        yield self, []
        for head, arguments in self.f.applicationParses():
            yield head, arguments + [self.x]

    def applicationParse(self):
        head, arguments = self.f.applicationParse()
        return head, arguments + [self.x]

    def shift(self, offset, depth=0):
        shiftedFunction = self.f.shift(offset, depth)
        shiftedArgument = self.x.shift(offset, depth)
        return Application(shiftedFunction, shiftedArgument)

    def substitute(self, old, new):
        if self == old:
            return new
        substitutedFunction = self.f.substitute(old, new)
        substitutedArgument = self.x.substitute(old, new)
        return Application(substitutedFunction, substitutedArgument)

    def walkUncurried(self, d=0):
        yield d, self
        head, arguments = self.applicationParse()
        yield from head.walkUncurried(d)
        for argument in arguments:
            yield from argument.walkUncurried(d)

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self
        yield from self.f.walk(surroundingAbstractions)
        yield from self.x.walk(surroundingAbstractions)

    def size(self):
        return self.f.size() + self.x.size()

    @staticmethod
    def _parse(s, n):
        """Parse ``(e1 e2 ... ek)`` as a left-nested application."""
        while n < len(s) and s[n].isspace():
            n += 1
        if n == len(s) or s[n] != '(':
            raise ParseFailure(s)
        n += 1
        parts = []
        while True:
            part, n = Program._parse(s, n)
            parts.append(part)
            while n < len(s) and s[n].isspace():
                n += 1
            if n == len(s):
                raise ParseFailure(s)
            if s[n] == ")":
                break
        result = parts[0]
        for argument in parts[1:]:
            result = Application(result, argument)
        return result, n + 1
class Index(Program):
    '''
    A variable written as a de Bruijn index (https://en.wikipedia.org/wiki/De_Bruijn_index).
    ``$i`` is looked up as ``environment[i]``, where the environment lists the
    innermost binder first.
    '''

    def __init__(self, i):
        self.i = i

    def show(self, isFunction):
        return "$%d" % self.i

    def __eq__(self, o):
        return isinstance(o, Index) and o.i == self.i

    def __hash__(self):
        return self.i

    def visit(self, visitor, *arguments, **keywords):
        return visitor.index(self, *arguments, **keywords)

    def evaluate(self, environment):
        return environment[self.i]

    def inferType(self, context, environment, freeVariables):
        depth = len(environment)
        if self.bound(depth):
            return (context, environment[self.i].apply(context))
        # Free variables get one shared type per index (relative to the top level).
        offset = self.i - depth
        if offset in freeVariables:
            return (context, freeVariables[offset].apply(context))
        context, fresh = context.makeVariable()
        freeVariables[offset] = fresh
        return (context, fresh)

    def clone(self):
        return Index(self.i)

    def annotateTypes(self, context, environment):
        self.annotatedType = environment[self.i].applyMutable(context)

    def shift(self, offset, depth=0):
        if self.bound(depth):
            return self
        shifted = self.i + offset
        if shifted < 0:
            raise ShiftFailure()
        return Index(shifted)

    def betaReduce(self):
        return None

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return {self.i}

    def substitute(self, old, new):
        return new if old == self else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    def free(self, surroundingAbstractions):
        '''True when this index escapes the given number of enclosing lambdas.'''
        return self.i >= surroundingAbstractions

    def bound(self, surroundingAbstractions):
        '''True when one of the given number of enclosing lambdas binds this index.'''
        return self.i < surroundingAbstractions

    @property
    def isIndex(self):
        return True

    @staticmethod
    def _parse(s, n):
        """Parse ``$<digits>`` at offset ``n``."""
        while n < len(s) and s[n].isspace():
            n += 1
        if n == len(s) or s[n] != '$':
            raise ParseFailure(s)
        start = n + 1
        end = start
        while end < len(s) and s[end].isdigit():
            end += 1
        if end == start:
            raise ParseFailure(s)
        return Index(int(s[start:end])), end
class Abstraction(Program):
    '''A lambda abstraction: a one-argument function whose parameter is ``$0`` in ``body``.'''

    def __init__(self, body):
        self.body = body
        self.hashCode = None

    @property
    def isAbstraction(self):
        return True

    def __eq__(self, o):
        return isinstance(o, Abstraction) and o.body == self.body

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((hash(self.body),))
        return self.hashCode

    # Python 3 randomizes string hashing per process, so the cached hash is
    # deliberately left out of the pickled state.
    def __getstate__(self):
        return self.body

    def __setstate__(self, state):
        self.body = state
        self.hashCode = None

    def isBetaLong(self):
        return self.body.isBetaLong()

    def freeVariables(self):
        # Index 0 is bound here; every other free index moves out by one.
        return {v - 1 for v in self.body.freeVariables() if v > 0}

    def visit(self, visitor, *arguments, **keywords):
        return visitor.abstraction(self, *arguments, **keywords)

    def clone(self):
        return Abstraction(self.body.clone())

    def annotateTypes(self, context, environment):
        parameter = context.makeVariable()
        self.body.annotateTypes(context, [parameter] + environment)
        self.annotatedType = arrow(parameter.applyMutable(context), self.body.annotatedType)

    def show(self, isFunction):
        return "(lambda %s)" % (self.body.show(False))

    def evaluate(self, environment):
        def function(x):
            return self.body.evaluate([x] + environment)
        return function

    def betaReduce(self):
        reduced = self.body.betaReduce()
        return None if reduced is None else Abstraction(reduced)

    def inferType(self, context, environment, freeVariables):
        context, parameterType = context.makeVariable()
        context, resultType = self.body.inferType(
            context, [parameterType] + environment, freeVariables)
        return (context, arrow(parameterType, resultType).apply(context))

    def shift(self, offset, depth=0):
        # One more binder is now in scope for the body.
        return Abstraction(self.body.shift(offset, depth + 1))

    def substitute(self, old, new):
        if self == old:
            return new
        return Abstraction(self.body.substitute(old.shift(1), new.shift(1)))

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self
        yield from self.body.walk(surroundingAbstractions + 1)

    def walkUncurried(self, d=0):
        yield d, self
        yield from self.body.walkUncurried(d + 1)

    def size(self):
        return self.body.size()

    @staticmethod
    def _parse(s, n):
        """Parse ``(\\ body)``, ``(lambda body)`` or ``(λ body)`` at offset ``n``."""
        n = Program.parseConstant(s, n, '(\\', '(lambda', '(\u03bb')
        while n < len(s) and s[n].isspace():
            n += 1
        body, n = Program._parse(s, n)
        while n < len(s) and s[n].isspace():
            n += 1
        n = Program.parseConstant(s, n, ')')
        return Abstraction(body), n
class Primitive(Program):
    """A named built-in with a declared type ``tp`` and a Python ``value``."""

    # name -> the first Primitive constructed with that name.
    GLOBALS = {}

    def __init__(self, name, ty, value):
        self.tp = ty
        self.name = name
        self.value = value
        # Later constructions with the same name do not replace the registered one.
        Primitive.GLOBALS.setdefault(name, self)

    @property
    def isPrimitive(self):
        return True

    def __eq__(self, o):
        return isinstance(o, Primitive) and o.name == self.name

    def __hash__(self):
        return hash(self.name)

    def visit(self, visitor, *arguments, **keywords):
        return visitor.primitive(self, *arguments, **keywords)

    def show(self, isFunction):
        return self.name

    def clone(self):
        return Primitive(self.name, self.tp, self.value)

    def annotateTypes(self, context, environment):
        self.annotatedType = self.tp.instantiateMutable(context)

    def evaluate(self, environment):
        return self.value

    def betaReduce(self):
        return None

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return set()

    def inferType(self, context, environment, freeVariables):
        return self.tp.instantiate(context)

    def shift(self, offset, depth=0):
        return self

    def substitute(self, old, new):
        return new if self == old else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s, n):
        """Parse a registered primitive name (a run of non-space, non-paren characters)."""
        while n < len(s) and s[n].isspace():
            n += 1
        start = n
        while n < len(s) and not s[n].isspace() and s[n] not in '()':
            n += 1
        name = s[start:n]
        if name in Primitive.GLOBALS:
            return Primitive.GLOBALS[name], n
        raise ParseFailure(s)
class Invented(Program):
    '''An invented primitive: a program body reused as a single library component.

    Its type is inferred once at construction, and its body is evaluated in an
    empty environment.
    '''

    def __init__(self, body):
        self.body = body
        self.tp = self.body.infer()
        self.hashCode = None

    @property
    def isInvented(self):
        return True

    def show(self, isFunction):
        return "#%s" % (self.body.show(False))

    def visit(self, visitor, *arguments, **keywords):
        return visitor.invented(self, *arguments, **keywords)

    def __eq__(self, o):
        return isinstance(o, Invented) and o.body == self.body

    def __hash__(self):
        if self.hashCode is None:
            self.hashCode = hash((0, hash(self.body)))
        return self.hashCode

    # Python 3 randomizes string hashing per process, so the cached hash is
    # deliberately left out of the pickled state.
    def __getstate__(self):
        return self.body, self.tp

    def __setstate__(self, state):
        self.body, self.tp = state
        self.hashCode = None

    def clone(self):
        return Invented(self.body)

    def annotateTypes(self, context, environment):
        self.annotatedType = self.tp.instantiateMutable(context)

    def evaluate(self, e):
        return self.body.evaluate([])

    def betaReduce(self):
        # Inline the invented primitive.
        return self.body

    def isBetaLong(self):
        return True

    def freeVariables(self):
        return set()

    def inferType(self, context, environment, freeVariables):
        return self.tp.instantiate(context)

    def shift(self, offset, depth=0):
        return self

    def substitute(self, old, new):
        return new if self == old else self

    def walk(self, surroundingAbstractions=0):
        yield surroundingAbstractions, self

    def walkUncurried(self, d=0):
        yield d, self

    def size(self):
        return 1

    @staticmethod
    def _parse(s, n):
        """Parse ``#<program>`` at offset ``n``."""
        while n < len(s) and s[n].isspace():
            n += 1
        if n >= len(s) or s[n] != '#':
            raise ParseFailure(s)
        body, n = Program._parse(s, n + 1)
        return Invented(body), n
class FragmentVariable(Program):
    """Wildcard printed as ``??`` that matches an arbitrary subexpression (see ``match``).

    A single shared instance is kept in ``FragmentVariable.single``. Evaluating,
    beta-reducing or shifting one raises an Exception.
    """

    def __init__(self): pass

    def show(self, isFunction): return "??"

    def __eq__(self, o): return isinstance(o, FragmentVariable)

    def __hash__(self): return 42

    def visit(self, visitor, *arguments, **keywords):
        return visitor.fragmentVariable(self, *arguments, **keywords)

    def evaluate(self, e):
        raise Exception('Attempt to evaluate fragment variable')

    def betaReduce(self):
        raise Exception('Attempt to beta reduce fragment variable')

    def freeVariables(self):
        # A wildcard contains no indices. Return a set, not the base-class
        # generator, so Application.freeVariables can union it with ``|``.
        return set()

    def inferType(self, context, environment, freeVariables):
        return context.makeVariable()

    def shift(self, offset, depth=0):
        raise Exception('Attempt to shift fragment variable')

    def substitute(self, old, new):
        if self == old:
            return new
        else:
            return self

    def match(
            self,
            context,
            expression,
            holes,
            variableBindings,
            environment=None):
        """Bind this wildcard to ``expression``.

        ``expression`` is shifted out of the ``len(environment)`` enclosing
        binders and recorded in ``holes`` with a fresh type variable.

        Raises:
            MatchFailure: if ``expression`` refers to a variable bound inside
                the fragment, so it cannot be shifted out.
        """
        # ``None`` rather than a shared mutable ``[]`` default; it means no
        # enclosing binders.
        surroundingAbstractions = 0 if environment is None else len(environment)
        try:
            context, variable = context.makeVariable()
            holes.append(
                (variable, expression.shift(-surroundingAbstractions)))
            return context, variable
        except ShiftFailure:
            raise MatchFailure()

    def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self

    def walkUncurried(self, d=0): yield d, self

    def size(self): return 1

    @staticmethod
    def _parse(s, n):
        """Parse ``??`` or ``?`` at offset ``n``."""
        while n < len(s) and s[n].isspace(): n += 1
        n = Program.parseConstant(s, n, '??', '?')
        return FragmentVariable.single, n


FragmentVariable.single = FragmentVariable()
class Hole(Program):
    """Placeholder node printed as ``<HOLE>``.

    It presumably marks an unfinished part of a program. Evaluating,
    beta-reducing or shifting one raises an Exception. A single shared
    instance is kept in ``Hole.single``.
    """

    def __init__(self): pass

    def show(self, isFunction): return "<HOLE>"

    @property
    def isHole(self): return True

    def __eq__(self, o): return isinstance(o, Hole)

    def __hash__(self): return 42

    def evaluate(self, e):
        raise Exception('Attempt to evaluate hole')

    def betaReduce(self):
        raise Exception('Attempt to beta reduce hole')

    def freeVariables(self):
        # A hole contains no indices. Return a set, not the base-class
        # generator, so Application.freeVariables can union it with ``|``.
        return set()

    def inferType(self, context, environment, freeVariables):
        return context.makeVariable()

    def shift(self, offset, depth=0):
        raise Exception('Attempt to shift hole')

    def substitute(self, old, new):
        # Leaf node. Without this, Application.substitute fails on any tree
        # that contains a hole.
        if self == old:
            return new
        return self

    def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self

    def walkUncurried(self, d=0): yield d, self

    def size(self): return 1

    @staticmethod
    def _parse(s, n):
        """Parse ``<HOLE>`` at offset ``n``."""
        while n < len(s) and s[n].isspace(): n += 1
        n = Program.parseConstant(s, n,
                                  '<HOLE>')
        return Hole.single, n


Hole.single = Hole()
class ShareVisitor(object):
    """Rebuild a program so that repeated subtrees share a single object.

    Primitives are keyed by name and indices by number. Compound nodes are
    keyed by the ``id`` of their already-shared children, so the tables keep
    those children alive and the ids stay valid.
    """

    def __init__(self):
        self.primitiveTable = {}
        self.inventedTable = {}
        self.indexTable = {}
        self.applicationTable = {}
        self.abstractionTable = {}

    def invented(self, e):
        sharedBody = e.body.visit(self)
        key = id(sharedBody)
        if key not in self.inventedTable:
            self.inventedTable[key] = Invented(sharedBody)
        return self.inventedTable[key]

    def primitive(self, e):
        return self.primitiveTable.setdefault(e.name, e)

    def index(self, e):
        return self.indexTable.setdefault(e.i, e)

    def application(self, e):
        sharedFunction = e.f.visit(self)
        sharedArgument = e.x.visit(self)
        key = (id(sharedFunction), id(sharedArgument))
        if key not in self.applicationTable:
            self.applicationTable[key] = Application(sharedFunction, sharedArgument)
        return self.applicationTable[key]

    def abstraction(self, e):
        sharedBody = e.body.visit(self)
        key = id(sharedBody)
        if key not in self.abstractionTable:
            self.abstractionTable[key] = Abstraction(sharedBody)
        return self.abstractionTable[key]

    def execute(self, e):
        return e.visit(self)
class Mutator:
    """Perform local mutations to an expr, yielding the expr and the
    description length distance from the original program"""

    def __init__(self, grammar, fn):
        """Fn yields (expression, loglikelihood) from a type and loss.
        Therefore, loss+loglikelihood is the distance from the original program."""
        self.fn = fn
        self.grammar = grammar
        # Stack of one-hole contexts (expr -> expr), from the root down to the
        # subexpression currently being mutated.
        self.history = []

    def enclose(self, expr):
        """Plug ``expr`` into the recorded contexts, innermost context first."""
        for h in self.history[::-1]:
            expr = h(expr)
        return expr

    def _replacements(self, e, tp, env, is_lhs):
        """Yield (whole program, distance) for each replacement of ``e`` proposed by ``fn``."""
        deleted_ll = self.logLikelihood(tp, e, env)
        for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
            yield self.enclose(expr), deleted_ll + replaced_ll

    def invented(self, e, tp, env, is_lhs=False):
        yield from self._replacements(e, tp, env, is_lhs)

    def primitive(self, e, tp, env, is_lhs=False):
        yield from self._replacements(e, tp, env, is_lhs)

    def index(self, e, tp, env, is_lhs=False):
        yield from self._replacements(e, tp, env, is_lhs)

    def application(self, e, tp, env, is_lhs=False):
        # Mutate inside the function position, keeping the argument fixed.
        self.history.append(lambda expr: Application(expr, e.x))
        f_tp = arrow(e.x.infer(), tp)
        yield from e.f.visit(self, f_tp, env, is_lhs=True)
        # Then mutate inside the argument, keeping the function fixed.
        self.history[-1] = lambda expr: Application(e.f, expr)
        x_tp = inferArg(tp, e.f.infer())
        yield from e.x.visit(self, x_tp, env)
        self.history.pop()
        # Finally, replace the whole application.
        yield from self._replacements(e, tp, env, is_lhs)

    def abstraction(self, e, tp, env, is_lhs=False):
        # Assumes tp is an arrow type: arguments[0] is the parameter type and
        # arguments[1] is the body type.
        self.history.append(lambda expr: Abstraction(expr))
        yield from e.body.visit(self, tp.arguments[1], [tp.arguments[0]] + env)
        self.history.pop()
        yield from self._replacements(e, tp, env, is_lhs)

    def execute(self, e, tp):
        yield from e.visit(self, tp, [])

    def logLikelihood(self, tp, e, env):
        """Return the grammar log-likelihood of ``e`` at type ``tp`` in ``env``.

        If the grammar cannot score ``e`` directly, ``e`` is eta-expanded to
        take every argument of ``tp`` and scored again. If ``e`` already has
        at least that many leading lambdas, ``float('-inf')`` is returned.
        """
        summary = None
        try:
            _, summary = self.grammar.likelihoodSummary(Context.EMPTY, env,
                                                        tp, e, silent=True)
        except AssertionError:
            # Best effort: an unscoreable expression falls through to eta-expansion.
            pass
        if summary is not None:
            return summary.logLikelihood(self.grammar)
        body, depth = e, 0
        while isinstance(body, Abstraction):
            depth += 1
            body = body.body
        to_introduce = len(tp.functionArguments()) - depth
        if to_introduce <= 0:
            # Expansion cannot help. Recursing on the unchanged expression
            # would never terminate, so score it as impossible.
            return float('-inf')
        # Eta-expand the same way Program.uncurry does. Shift the body past
        # the new innermost binders, apply it to them, then rebuild every
        # lambda. The result has one lambda per argument of tp, so a further
        # failure ends at the guard above.
        body = body.shift(to_introduce)
        for i in reversed(range(to_introduce)):
            body = Application(body, Index(i))
        for _ in range(depth + to_introduce):
            body = Abstraction(body)
        return self.logLikelihood(tp, body, env)
class RegisterPrimitives(object):
    """Walk a program and register any primitive whose name is missing from ``Primitive.GLOBALS``."""

    def invented(self, e):
        e.body.visit(self)

    def primitive(self, e):
        if e.name in Primitive.GLOBALS:
            return
        # Constructing a Primitive registers it as a side effect.
        Primitive(e.name, e.tp, e.value)

    def index(self, e):
        return None

    def application(self, e):
        for child in (e.f, e.x):
            child.visit(self)

    def abstraction(self, e):
        e.body.visit(self)

    @staticmethod
    def register(e):
        e.visit(RegisterPrimitives())
class PrettyVisitor(object):
    """Render a program with named variables instead of de Bruijn indices.

    With ``Lisp=False``, abstractions are written in the ``λx.body`` style.
    With ``Lisp=True``, a run of abstractions is written as
    ``(λ (x y) body)``. Free indices get fresh names, reused on repeat.
    """

    def __init__(self, Lisp=False):
        self.Lisp = Lisp
        self.numberOfVariables = 0
        self.freeVariables = {}
        self.variableNames = ["x", "y", "z", "u", "v", "w"]
        self.variableNames += [chr(ord('a') + j)
                               for j in range(20)]
        self.toplevel = True

    def makeVariable(self):
        """Return a fresh variable name.

        The first 26 come from ``variableNames``. After that the names are
        numbered (``x26``, ``x27``, ...) instead of raising IndexError.
        """
        j = self.numberOfVariables
        self.numberOfVariables += 1
        if j < len(self.variableNames):
            return self.variableNames[j]
        return "x%d" % j

    def invented(self, e, environment, isFunction, isAbstraction):
        # The body of an invented primitive is rendered with an empty environment.
        s = e.body.visit(self, [], isFunction, isAbstraction)
        return s

    def primitive(self, e, environment, isVariable, isAbstraction): return e.name

    def index(self, e, environment, isVariable, isAbstraction):
        if e.i < len(environment):
            return environment[e.i]
        else:
            i = e.i - len(environment)
            if i in self.freeVariables:
                return self.freeVariables[i]
            else:
                v = self.makeVariable()
                self.freeVariables[i] = v
                return v

    def application(self, e, environment, isFunction, isAbstraction):
        self.toplevel = False
        s = "%s %s" % (e.f.visit(self, environment, True, False),
                       e.x.visit(self, environment, False, False))
        if isFunction:
            return s
        else:
            return "(" + s + ")"

    def abstraction(self, e, environment, isFunction, isAbstraction):
        toplevel = self.toplevel
        self.toplevel = False
        if not self.Lisp:
            v = self.makeVariable()
            body = e.body.visit(self,
                                [v] + environment,
                                False,
                                True)
            if not e.body.isAbstraction:
                body = "." + body
            body = v + body
            if not isAbstraction:
                body = "λ" + body
            if not toplevel:
                body = "(%s)" % body
            return body
        else:
            child = e
            newVariables = []
            # Collect a binder name for every directly nested lambda.
            while child.isAbstraction:
                newVariables = [self.makeVariable()] + newVariables
                child = child.body
            body = child.visit(self, newVariables + environment,
                               False, True)
            body = "(λ (%s) %s)" % (" ".join(reversed(newVariables)), body)
            return body
def prettyProgram(e, Lisp=False):
    """Return a human-readable rendering of ``e`` with named variables (see PrettyVisitor)."""
    visitor = PrettyVisitor(Lisp=Lisp)
    return e.visit(visitor, [], False, False)
class EtaExpandFailure(Exception):
    """Raised by EtaLongVisitor when an expression cannot be eta-expanded to the requested type."""
class EtaLongVisitor(object):
    """Rewrite a closed expression into eta-long form for a requested type."""

    def __init__(self, request=None):
        self.request = request
        self.context = None

    def makeLong(self, e, request):
        """Wrap ``e`` as ``(lambda (e' $0))`` when ``request`` is an arrow; otherwise return None."""
        if not request.isArrow():
            return None
        return Abstraction(Application(e.shift(1), Index(0)))

    def abstraction(self, e, request, environment):
        if not request.isArrow(): raise EtaExpandFailure()
        parameterType = request.arguments[0]
        bodyType = request.arguments[1]
        return Abstraction(e.body.visit(self, bodyType, [parameterType] + environment))

    def _application(self, e, request, environment):
        expanded = self.makeLong(e, request)
        if expanded is not None:
            return expanded.visit(self, request, environment)
        head, arguments = e.applicationParse()
        if head.isIndex:
            headType = environment[head.i].applyMutable(self.context)
        elif head.isInvented or head.isPrimitive:
            headType = head.tp.instantiateMutable(self.context)
        else:
            assert False, "Not in beta long form: %s" % e
        self.context.unify(request, headType.returns())
        headType = headType.applyMutable(self.context)
        argumentTypes = headType.functionArguments()
        if len(arguments) != len(argumentTypes): raise EtaExpandFailure()
        result = head
        for argument, argumentType in zip(arguments, argumentTypes):
            # Resolve the argument type against unifications made so far.
            resolved = argumentType.applyMutable(self.context)
            result = Application(result, argument.visit(self, resolved, environment))
        return result

    def application(self, e, request, environment):
        return self._application(e, request, environment)

    def index(self, e, request, environment):
        return self._application(e, request, environment)

    def primitive(self, e, request, environment):
        return self._application(e, request, environment)

    def invented(self, e, request, environment):
        return self._application(e, request, environment)

    def execute(self, e):
        """Return the eta-long form of ``e``, inferring the request type if none was given."""
        assert len(e.freeVariables()) == 0
        if self.request is None:
            eprint("WARNING: request not specified for etaexpansion")
            self.request = e.infer()
        self.context = MutableContext()
        expanded = e.visit(self, self.request, [])
        self.context = None
        return expanded
class StripPrimitiveVisitor():
    """Rebuild a program with every primitive's ``.value`` set to None.

    New nodes are constructed; the input program is left untouched.
    """

    def invented(self, e):
        strippedBody = e.body.visit(self)
        return Invented(strippedBody)

    def primitive(self, e):
        return Primitive(e.name, e.tp, None)

    def application(self, e):
        strippedFunction = e.f.visit(self)
        strippedArgument = e.x.visit(self)
        return Application(strippedFunction, strippedArgument)

    def abstraction(self, e):
        strippedBody = e.body.visit(self)
        return Abstraction(strippedBody)

    def index(self, e):
        return e
class ReplacePrimitiveValueVisitor():
    """Inverse of StripPrimitiveVisitor.

    Rebuilds a program whose primitives take their ``.value`` from the entry
    registered under the same name in ``Primitive.GLOBALS``.
    """

    def invented(self, e):
        restoredBody = e.body.visit(self)
        return Invented(restoredBody)

    def primitive(self, e):
        registered = Primitive.GLOBALS[e.name]
        return Primitive(e.name, e.tp, registered.value)

    def application(self, e):
        restoredFunction = e.f.visit(self)
        restoredArgument = e.x.visit(self)
        return Application(restoredFunction, restoredArgument)

    def abstraction(self, e):
        restoredBody = e.body.visit(self)
        return Abstraction(restoredBody)

    def index(self, e):
        return e
def strip_primitive_values(e):
    """Return a copy of ``e`` whose primitives carry ``value=None``."""
    stripper = StripPrimitiveVisitor()
    return e.visit(stripper)
def unstrip_primitive_values(e):
    """Return a copy of ``e`` whose primitive values are restored from ``Primitive.GLOBALS``."""
    restorer = ReplacePrimitiveValueVisitor()
    return e.visit(restorer)
class TokeniseVisitor(object):
    """Flatten a program into a list of tokens.

    Applications are bracketed by ``(``/``)`` and abstractions by
    ``(_lambda``/``)_lambda``.
    """

    def invented(self, e):
        # NOTE(review): this emits the body Program object itself, not a
        # string, so untokeniseProgram's " ".join cannot accept it. Confirm
        # this is intended.
        return [e.body]

    def primitive(self, e):
        return [e.name]

    def index(self, e):
        return ["$%s" % (e.i,)]

    def application(self, e):
        tokens = ["("]
        tokens += e.f.visit(self)
        tokens += e.x.visit(self)
        tokens.append(")")
        return tokens

    def abstraction(self, e):
        tokens = ["(_lambda"]
        tokens += e.body.visit(self)
        tokens.append(")_lambda")
        return tokens
def tokeniseProgram(e):
    """Return the token list for ``e`` produced by TokeniseVisitor."""
    tokeniser = TokeniseVisitor()
    return e.visit(tokeniser)
def untokeniseProgram(l):
    """Rebuild a program from tokens produced by tokeniseProgram.

    The lambda bracket tokens are mapped back to ``(lambda``/``)``, the tokens
    are joined with spaces, and the result is parsed with ``Program.parse``.
    """
    replacements = {"(_lambda": "(lambda", ")_lambda": ")"}
    text = " ".join(replacements.get(token, token) for token in l)
    return Program.parse(text)
if __name__ == "__main__":
    # Smoke test: load the arithmetic primitives, then parse and print an
    # expression that uses an invented primitive and fragment variables.
    from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
    example = "(#(lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) - * (+ +))"
    parsed = Program.parse(example)
    eprint(parsed)