from collections import defaultdict from dreamcoder.frontier import * from dreamcoder.program import * from dreamcoder.type import * from dreamcoder.utilities import * import time class GrammarFailure(Exception): pass class SketchEnumerationFailure(Exception): pass class NoCandidates(Exception): pass class Grammar(object): def __init__(self, logVariable, productions, continuationType=None): self.logVariable = logVariable self.productions = productions self.continuationType = continuationType self.expression2likelihood = dict((p, l) for l, _, p in productions) self.expression2likelihood[Index(0)] = self.logVariable def randomWeights(self, r): """returns a new grammar with random weights drawn from r. calls `r` w/ old weight""" return Grammar(logVariable=r(self.logVariable), productions=[(r(l),t,p) for l,t,p in self.productions ], continuationType=self.continuationType) def strip_primitive_values(self): return Grammar(logVariable=self.logVariable, productions=[(l,t,strip_primitive_values(p)) for l,t,p in self.productions ], continuationType=self.continuationType) def unstrip_primitive_values(self): return Grammar(logVariable=self.logVariable, productions=[(l,t,unstrip_primitive_values(p)) for l,t,p in self.productions ], continuationType=self.continuationType) def __setstate__(self, state): """ Legacy support for loading grammar objects without the imperative type filled in """ assert 'logVariable' in state assert 'productions' in state if 'continuationType' in state: continuationType = state['continuationType'] else: if any( 'turtle' in str(t) for l,t,p in state['productions'] ): continuationType = baseType("turtle") elif any( 'tower' in str(t) for l,t,p in state['productions'] ): continuationType = baseType("tower") else: continuationType = None self.__init__(state['logVariable'], state['productions'], continuationType=continuationType) @staticmethod def fromProductions(productions, logVariable=0.0, continuationType=None): """Make a grammar from primitives and their relative logpriors.""" return Grammar(logVariable, [(l, p.infer(), p) for l, p in productions], continuationType=continuationType) @staticmethod def uniform(primitives, continuationType=None): return Grammar(0.0, [(0.0, p.infer(), p) for p in primitives], continuationType=continuationType) def __len__(self): return len(self.productions) def __str__(self): def productionKey(xxx_todo_changeme): (l, t, p) = xxx_todo_changeme return not isinstance(p, Primitive), l is not None and -l if self.continuationType is not None: lines = ["continuation : %s"%self.continuationType] else: lines = [] lines += ["%f\tt0\t$_" % self.logVariable] for l, t, p in sorted(self.productions, key=productionKey): if l is not None: l = "%f\t%s\t%s" % (l, t, p) else: l = "-Inf\t%s\t%s" % (t, p) if not t.isArrow() and isinstance(p, Invented): try: l += "\teval = %s" % (p.evaluate([])) except BaseException: pass lines.append(l) return "\n".join(lines) def json(self): j = {"logVariable": self.logVariable, "productions": [{"expression": str(p), "logProbability": l} for l, _, p in self.productions]} if self.continuationType is not None: j["continuationType"] = self.continuationType.json() return j def _immutable_code(self): return self.logVariable, tuple(self.productions) def __eq__(self, o): return self._immutable_code() == o._immutable_code() def __ne__(self, o): return not (self == o) def __hash__(self): return hash(self._immutable_code()) @property def primitives(self): return [p for _, _, p in self.productions] def removeProductions(self, ps): return Grammar( self.logVariable, [ (l, t, p) for ( l, t, p) in self.productions if p not in ps], continuationType=self.continuationType) def buildCandidates(self, request, context, environment, # Should the log probabilities be normalized? normalize=True, # Should be returned a table mapping primitives to # their candidate entry? returnTable=False, # Should we return probabilities vs log probabilities? returnProbabilities=False, # Must be a leaf (have no arguments)? mustBeLeaf=False): """Primitives that are candidates for being used given a requested type If returnTable is false (default): returns [((log)likelihood, tp, primitive, context)] if returntable is true: returns {primitive: ((log)likelihood, tp, context)}""" if returnProbabilities: assert normalize candidates = [] variableCandidates = [] for l, t, p in self.productions: try: newContext, t = t.instantiate(context) newContext = newContext.unify(t.returns(), request) t = t.apply(newContext) if mustBeLeaf and t.isArrow(): continue candidates.append((l, t, p, newContext)) except UnificationFailure: continue for j, t in enumerate(environment): try: newContext = context.unify(t.returns(), request) t = t.apply(newContext) if mustBeLeaf and t.isArrow(): continue variableCandidates.append((t, Index(j), newContext)) except UnificationFailure: continue if self.continuationType == request: terminalIndices = [v.i for t,v,k in variableCandidates if not t.isArrow()] if terminalIndices: smallestIndex = Index(min(terminalIndices)) variableCandidates = [(t,v,k) for t,v,k in variableCandidates if t.isArrow() or v == smallestIndex] candidates += [(self.logVariable - log(len(variableCandidates)), t, p, k) for t, p, k in variableCandidates] if candidates == []: raise NoCandidates() #eprint("candidates inside buildCandidates before norm:") #eprint(candidates) if normalize: z = lse([l for l, t, p, k in candidates]) if returnProbabilities: candidates = [(exp(l - z), t, p, k) for l, t, p, k in candidates] else: candidates = [(l - z, t, p, k) for l, t, p, k in candidates] #eprint("candidates inside buildCandidates after norm:") #eprint(candidates) if returnTable: return {p: (l, t, k) for l, t, p, k in candidates} else: return candidates def sample(self, request, maximumDepth=6, maxAttempts=None): attempts = 0 while True: try: _, e = self._sample( request, Context.EMPTY, [], maximumDepth=maximumDepth) return e except NoCandidates: if maxAttempts is not None: attempts += 1 if attempts > maxAttempts: return None continue def _sample(self, request, context, environment, maximumDepth): if request.isArrow(): context, expression = self._sample( request.arguments[1], context, [ request.arguments[0]] + environment, maximumDepth) return context, Abstraction(expression) candidates = self.buildCandidates(request, context, environment, normalize=True, returnProbabilities=True, # Force it to terminate in a # leaf; a primitive with no # function arguments mustBeLeaf=maximumDepth <= 1) #eprint("candidates:") #eprint(candidates) newType, chosenPrimitive, context = sampleDistribution(candidates) # Sample the arguments xs = newType.functionArguments() returnValue = chosenPrimitive for x in xs: x = x.apply(context) context, x = self._sample(x, context, environment, maximumDepth - 1) returnValue = Application(returnValue, x) return context, returnValue def likelihoodSummary(self, context, environment, request, expression, silent=False): if request.isArrow(): if not isinstance(expression, Abstraction): if not silent: eprint("Request is an arrow but I got", expression) return context, None return self.likelihoodSummary(context, [request.arguments[0]] + environment, request.arguments[1], expression.body, silent=silent) # Build the candidates candidates = self.buildCandidates(request, context, environment, normalize=False, returnTable=True) # A list of everything that would have been possible to use here possibles = [p for p in candidates.keys() if not p.isIndex] numberOfVariables = sum(p.isIndex for p in candidates.keys()) if numberOfVariables > 0: possibles += [Index(0)] f, xs = expression.applicationParse() if f not in candidates: if self.continuationType is not None and f.isIndex: ls = LikelihoodSummary() ls.constant = NEGATIVEINFINITY return ls if not silent: eprint(f, "Not in candidates") eprint("Candidates is", candidates) #eprint("grammar:", grammar.productions) eprint("request is", request) eprint("xs", xs) eprint("environment", environment) assert False return context, None thisSummary = LikelihoodSummary() thisSummary.record(f, possibles, constant= -math.log(numberOfVariables) if f.isIndex else 0) _, tp, context = candidates[f] argumentTypes = tp.functionArguments() if len(xs) != len(argumentTypes): eprint("PANIC: not enough arguments for the type") eprint("request", request) eprint("tp", tp) eprint("expression", expression) eprint("xs", xs) eprint("argumentTypes", argumentTypes) # This should absolutely never occur raise GrammarFailure((context, environment, request, expression)) for argumentType, argument in zip(argumentTypes, xs): argumentType = argumentType.apply(context) context, newSummary = self.likelihoodSummary( context, environment, argumentType, argument, silent=silent) if newSummary is None: return context, None thisSummary.join(newSummary) return context, thisSummary def bestFirstEnumeration(self, request): from heapq import heappush, heappop pq = [] def choices(parentCost, xs): for c, x in xs: heappush(pq, (parentCost + c, x)) def g(parentCost, request, _=None, context=None, environment=[], k=None): """ k is a continuation. k: Expects to be called with MDL, context, expression. """ assert k is not None if context is None: context = Context.EMPTY if request.isArrow(): g(parentCost, request.arguments[1], context=context, environment=[request.arguments[0]] + environment, k=lambda MDL, newContext, p: k(MDL, newContext, Abstraction(p))) else: candidates = self.buildCandidates(request, context, environment, normalize=True, returnProbabilities=False, returnTable=True) choices(parentCost, [(-f_ll_tp_newContext[1][0], lambda: ga(parentCost - f_ll_tp_newContext[1][0], f_ll_tp_newContext[0], f_ll_tp_newContext[1][1].functionArguments(), context=f_ll_tp_newContext[1][2], environment=environment, k=k)) for f_ll_tp_newContext in iter(candidates.items())]) def ga(costSoFar, f, argumentTypes, _=None, context=None, environment=None, k=None): if argumentTypes == []: k(costSoFar, context, f) else: t1 = argumentTypes[0].apply(context) g(costSoFar, t1, context=context, environment=environment, k=lambda newCost, newContext, argument: ga(newCost, Application(f, argument), argumentTypes[1:], context=newContext, environment=environment, k=k)) def receiveResult(MDL, _, expression): heappush(pq, (MDL, expression)) g(0., request, context=Context.EMPTY, environment=[], k=receiveResult) frontier = [] while len(frontier) < 10**3: MDL, action = heappop(pq) if isinstance(action, Program): expression = action frontier.append(expression) #eprint("Enumerated program",expression,-MDL,self.closedLogLikelihood(request, expression)) else: action() def closedLikelihoodSummary(self, request, expression, silent=False): try: context, summary = self.likelihoodSummary(Context.EMPTY, [], request, expression, silent=silent) except GrammarFailure as e: failureExport = 'failures/grammarFailure%s.pickle' % ( time.time() + getPID()) eprint("PANIC: Grammar failure, exporting to ", failureExport) with open(failureExport, 'wb') as handle: pickle.dump((e, self, request, expression), handle) assert False return summary def logLikelihood(self, request, expression): summary = self.closedLikelihoodSummary(request, expression) if summary is None: eprint( "FATAL: program [ %s ] does not have a likelihood summary." % expression, "r = ", request, "\n", self) assert False return summary.logLikelihood(self) def rescoreFrontier(self, frontier): return Frontier([FrontierEntry(e.program, logPrior=self.logLikelihood(frontier.task.request, e.program), logLikelihood=e.logLikelihood) for e in frontier], frontier.task) def productionUses(self, frontiers): """Returns the expected number of times that each production was used. {production: expectedUses}""" frontiers = [self.rescoreFrontier(f).normalize() for f in frontiers if not f.empty] uses = {p: 0. for p in self.primitives} for f in frontiers: for e in f: summary = self.closedLikelihoodSummary(f.task.request, e.program) for p, u in summary.uses: uses[p] += u * math.exp(e.logPosterior) return uses def insideOutside(self, frontiers, pseudoCounts, iterations=1): # Replace programs with (likelihood summary, uses) frontiers = [ Frontier([ FrontierEntry((summary, summary.toUses()), logPrior=summary.logLikelihood(self), logLikelihood=e.logLikelihood) for e in f for summary in [self.closedLikelihoodSummary(f.task.request, e.program)] ], task=f.task) for f in frontiers ] g = self for i in range(iterations): u = Uses() for f in frontiers: f = f.normalize() for e in f: _, eu = e.program u += math.exp(e.logPosterior) * eu lv = math.log(u.actualVariables + pseudoCounts) - \ math.log(u.possibleVariables + pseudoCounts) g = Grammar(lv, [ (math.log(u.actualUses.get(p,0.) + pseudoCounts) - \ math.log(u.possibleUses.get(p,0.) + pseudoCounts), t,p) for _,t,p in g.productions ], continuationType=self.continuationType) if i < iterations - 1: frontiers = [Frontier([ FrontierEntry((summary, uses), logPrior=summary.logLikelihood(g), logLikelihood=e.logLikelihood) for e in f for (summary, uses) in [e.program] ], task=f.task) for f in frontiers ] return g def frontierMDL(self, frontier): return max( e.logLikelihood + self.logLikelihood(frontier.task.request, e.program) for e in frontier ) def enumeration(self,context,environment,request,upperBound, maximumDepth=20, lowerBound=0.): '''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound''' if upperBound < 0 or maximumDepth == 1: return if request.isArrow(): v = request.arguments[0] for l, newContext, b in self.enumeration(context, [v] + environment, request.arguments[1], upperBound=upperBound, lowerBound=lowerBound, maximumDepth=maximumDepth): yield l, newContext, Abstraction(b) else: candidates = self.buildCandidates(request, context, environment, normalize=True) for l, t, p, newContext in candidates: mdl = -l if not (mdl < upperBound): continue xs = t.functionArguments() for aL, aK, application in\ self.enumerateApplication(newContext, environment, p, xs, upperBound=upperBound + l, lowerBound=lowerBound + l, maximumDepth=maximumDepth - 1): yield aL + l, aK, application def enumerateApplication(self, context, environment, function, argumentRequests, # Upper bound on the description length of all of # the arguments upperBound, # Lower bound on the description length of all of # the arguments lowerBound=0., maximumDepth=20, originalFunction=None, argumentIndex=0): if upperBound < 0. or maximumDepth == 1: return if originalFunction is None: originalFunction = function if argumentRequests == []: if lowerBound <= 0. and 0. < upperBound: yield 0., context, function else: return else: argRequest = argumentRequests[0].apply(context) laterRequests = argumentRequests[1:] for argL, newContext, arg in self.enumeration(context, environment, argRequest, upperBound=upperBound, lowerBound=0., maximumDepth=maximumDepth): if violatesSymmetry(originalFunction, arg, argumentIndex): continue newFunction = Application(function, arg) for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction, laterRequests, upperBound=upperBound + argL, lowerBound=lowerBound + argL, maximumDepth=maximumDepth, originalFunction=originalFunction, argumentIndex=argumentIndex + 1): yield resultL + argL, resultK, result def sketchEnumeration(self,context,environment,request,sk,upperBound, maximumDepth=20, lowerBound=0.): '''Enumerates all sketch instantiations whose MDL satisfies: lowerBound <= MDL < upperBound''' if upperBound < 0. or maximumDepth == 1: return if sk.isHole: yield from self.enumeration(context, environment, request, upperBound, maximumDepth=maximumDepth, lowerBound=lowerBound) elif request.isArrow(): assert sk.isAbstraction v = request.arguments[0] for l, newContext, b in self.sketchEnumeration(context, [v] + environment, request.arguments[1], sk.body, upperBound=upperBound, lowerBound=lowerBound, maximumDepth=maximumDepth): yield l, newContext, Abstraction(b) else: f, xs = sk.applicationParse() if f.isIndex: ft = environment[f.i].apply(context) elif f.isInvented or f.isPrimitive: context, ft = f.tp.instantiate(context) elif f.isAbstraction: assert False, "sketch is not in beta longform" elif f.isHole: assert False, "hole as function not yet supported" elif f.isApplication: assert False, "should never happen - bug in applicationParse" else: assert False try: context = context.unify(ft.returns(), request) except UnificationFailure: print("Exception: sketch is ill-typed") return #so that we can continue evaluating # raise SketchEnumerationFailure() #"sketch is ill-typed" ft = ft.apply(context) argumentRequests = ft.functionArguments() assert len(argumentRequests) == len(xs) yield from self.sketchApplication(context, environment, f, xs, argumentRequests, upperBound=upperBound, lowerBound=lowerBound, maximumDepth=maximumDepth - 1) def sketchApplication(self, context, environment, function, arguments, argumentRequests, # Upper bound on the description length of all of # the arguments upperBound, # Lower bound on the description length of all of # the arguments lowerBound=0., maximumDepth=20): if upperBound < 0. or maximumDepth == 1: return if argumentRequests == []: if lowerBound <= 0. and 0. < upperBound: yield 0., context, function else: return else: argRequest = argumentRequests[0].apply(context) laterRequests = argumentRequests[1:] firstSketch = arguments[0] laterSketches = arguments[1:] for argL, newContext, arg in self.sketchEnumeration(context, environment, argRequest, firstSketch, upperBound=upperBound, lowerBound=0., maximumDepth=maximumDepth): newFunction = Application(function, arg) for resultL, resultK, result in self.sketchApplication(newContext, environment, newFunction, laterSketches, laterRequests, upperBound=upperBound + argL, lowerBound=lowerBound + argL, maximumDepth=maximumDepth): yield resultL + argL, resultK, result def sketchLogLikelihood(self, request, full, sk, context=Context.EMPTY, environment=[]): """ calculates mdl of full program 'full' from sketch 'sk' """ if sk.isHole: _, summary = self.likelihoodSummary(context, environment, request, full) if summary is None: eprint( "FATAL: program [ %s ] does not have a likelihood summary." % full, "r = ", request, "\n", self) assert False return summary.logLikelihood(self), context elif request.isArrow(): assert sk.isAbstraction and full.isAbstraction #assert sk.f == full.f #is this right? or do i need to recurse? v = request.arguments[0] return self.sketchLogLikelihood(request.arguments[1], full.body, sk.body, context=context, environment=[v] + environment) else: sk_f, sk_xs = sk.applicationParse() full_f, full_xs = full.applicationParse() if sk_f.isIndex: assert sk_f == full_f, "sketch and full program don't match on an index" ft = environment[sk_f.i].apply(context) elif sk_f.isInvented or sk_f.isPrimitive: assert sk_f == full_f, "sketch and full program don't match on a primitive" context, ft = sk_f.tp.instantiate(context) elif sk_f.isAbstraction: assert False, "sketch is not in beta longform" elif sk_f.isHole: assert False, "hole as function not yet supported" elif sk_f.isApplication: assert False, "should never happen - bug in applicationParse" else: assert False try: context = context.unify(ft.returns(), request) except UnificationFailure: assert False, "sketch is ill-typed" ft = ft.apply(context) argumentRequests = ft.functionArguments() assert len(argumentRequests) == len(sk_xs) == len(full_xs) #this might not be true if holes?? return self.sketchllApplication(context, environment, sk_f, sk_xs, full_f, full_xs, argumentRequests) def sketchllApplication(self, context, environment, sk_function, sk_arguments, full_function, full_arguments, argumentRequests): if argumentRequests == []: return torch.tensor([0.]).cuda(), context #does this make sense? else: argRequest = argumentRequests[0].apply(context) laterRequests = argumentRequests[1:] sk_firstSketch = sk_arguments[0] full_firstSketch = full_arguments[0] sk_laterSketches = sk_arguments[1:] full_laterSketches = full_arguments[1:] argL, newContext = self.sketchLogLikelihood(argRequest, full_firstSketch, sk_firstSketch, context=context, environment=environment) #finish this... sk_newFunction = Application(sk_function, sk_firstSketch) # is this redundant? maybe full_newFunction = Application(full_function, full_firstSketch) resultL, context = self.sketchllApplication(newContext, environment, sk_newFunction, sk_laterSketches, full_newFunction, full_laterSketches, laterRequests) return resultL + argL, context def enumerateNearby(self, request, expr, distance=3.0): """Enumerate programs with local mutations in subtrees with small description length""" if distance <= 0: yield expr else: def mutations(tp, loss): for l, _, expr in self.enumeration( Context.EMPTY, [], tp, distance - loss): yield expr, l yield from Mutator(self, mutations).execute(expr, request) def enumerateHoles(self, request, expr, k=3, return_obj=Hole): """Enumerate programs with a single hole within mdl distance""" #TODO: make it possible to enumerate sketches with multiple holes def mutations(tp, loss, is_left_application=False): """ to allow applications lhs to become a hole, remove the condition below and ignore all the is_left_application kwds """ if not is_left_application: yield return_obj(), 0 top_k = [] for expr, l in Mutator(self, mutations).execute(expr, request): if len(top_k) > 0: i, v = min(enumerate(top_k), key=lambda x:x[1][1]) if l > v[1]: if len(top_k) >= k: top_k[i] = (expr, l) else: top_k.append((expr, l)) elif len(top_k) < k: top_k.append((expr, l)) else: top_k.append((expr, l)) return sorted(top_k, key=lambda x:-x[1]) def untorch(self): return Grammar(self.logVariable.data.tolist()[0], [ (l.data.tolist()[0], t, p) for l, t, p in self.productions], continuationType=self.continuationType) class LikelihoodSummary(object): '''Summarizes the terms that will be used in a likelihood calculation''' def __init__(self): self.uses = {} self.normalizers = {} self.constant = 0. def __str__(self): return """LikelihoodSummary(constant = %f, uses = {%s}, normalizers = {%s})""" % (self.constant, ", ".join( "%s: %d" % (k, v) for k, v in self.uses.items()), ", ".join( "%s: %d" % (k, v) for k, v in self.normalizers.items())) def record(self, actual, possibles, constant=0.): # Variables are all normalized to be $0 if isinstance(actual, Index): actual = Index(0) # Make it something that we can hash possibles = frozenset(sorted(possibles, key=hash)) self.constant += constant self.uses[actual] = self.uses.get(actual, 0) + 1 self.normalizers[possibles] = self.normalizers.get(possibles, 0) + 1 def join(self, other): self.constant += other.constant for k, v in other.uses.items(): self.uses[k] = self.uses.get(k, 0) + v for k, v in other.normalizers.items(): self.normalizers[k] = self.normalizers.get(k, 0) + v def logLikelihood(self, grammar): return self.constant + \ sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \ sum(count * lse([grammar.expression2likelihood[p] for p in ps]) for ps, count in self.normalizers.items()) def logLikelihood_overlyGeneral(self, grammar): """Calculates log likelihood of this summary, given that the summary might refer to productions that don't occur in the grammar""" return self.constant + \ sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \ sum(count * lse([grammar.expression2likelihood.get(p,NEGATIVEINFINITY) for p in ps]) for ps, count in self.normalizers.items()) def numerator(self, grammar): return self.constant + \ sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) def denominator(self, grammar): return \ sum(count * lse([grammar.expression2likelihood[p] for p in ps]) for ps, count in self.normalizers.items()) def toUses(self): from collections import Counter possibleVariables = sum( count if Index(0) in ps else 0 for ps, count in self.normalizers.items() ) actualVariables = self.uses.get(Index(0), 0.) actualUses = {k: v for k, v in self.uses.items() if not k.isIndex } possibleUses = dict(Counter(p for ps, count in self.normalizers.items() for p_ in ps if not p_.isIndex for p in [p_]*count )) return Uses(possibleVariables, actualVariables, possibleUses, actualUses) class Uses(object): '''Tracks uses of different grammar productions''' def __init__(self, possibleVariables=0., actualVariables=0., possibleUses={}, actualUses={}): self.actualVariables = actualVariables self.possibleVariables = possibleVariables self.possibleUses = possibleUses self.actualUses = actualUses def __str__(self): return "Uses(actualVariables = %f, possibleVariables = %f, actualUses = %s, possibleUses = %s)" %\ (self.actualVariables, self.possibleVariables, self.actualUses, self.possibleUses) def __repr__(self): return str(self) def __mul__(self, a): return Uses(a * self.possibleVariables, a * self.actualVariables, {p: a * u for p, u in self.possibleUses.items()}, {p: a * u for p, u in self.actualUses.items()}) def __imul__(self, a): self.possibleVariables *= a self.actualVariables *= a for p in self.possibleUses: self.possibleUses[p] *= a for p in self.actualUses: self.actualUses[p] *= a return self def __rmul__(self, a): return self * a def __radd__(self, o): if o == 0: return self return self + o def __add__(self, o): if o == 0: return self def merge(x, y): z = x.copy() for k, v in y.items(): z[k] = v + x.get(k, 0.) return z return Uses(self.possibleVariables + o.possibleVariables, self.actualVariables + o.actualVariables, merge(self.possibleUses, o.possibleUses), merge(self.actualUses, o.actualUses)) def __iadd__(self, o): self.possibleVariables += o.possibleVariables self.actualVariables += o.actualVariables for k, v in o.possibleUses.items(): self.possibleUses[k] = self.possibleUses.get(k, 0.) + v for k, v in o.actualUses.items(): self.actualUses[k] = self.actualUses.get(k, 0.) + v return self @staticmethod def join(z, *weightedUses): """Consumes weightedUses""" if not weightedUses: Uses.empty if len(weightedUses) == 1: return weightedUses[0][1] for w, u in weightedUses: u *= exp(w - z) total = Uses() total.possibleVariables = sum( u.possibleVariables for _, u in weightedUses) total.actualVariables = sum(u.actualVariables for _, u in weightedUses) total.possibleUses = defaultdict(float) total.actualUses = defaultdict(float) for _, u in weightedUses: for k, v in u.possibleUses.items(): total.possibleUses[k] += v for k, v in u.actualUses.items(): total.actualUses[k] += v return total Uses.empty = Uses() class ContextualGrammar: def __init__(self, noParent, variableParent, library): self.noParent, self.variableParent, self.library = noParent, variableParent, library self.productions = [(None,t,p) for _,t,p in self.noParent.productions ] self.primitives = [p for _,_2,p in self.productions ] self.continuationType = noParent.continuationType assert variableParent.continuationType == self.continuationType assert set(noParent.primitives) == set(variableParent.primitives) assert set(variableParent.primitives) == set(library.keys()) for e,gs in library.items(): assert len(gs) == len(e.infer().functionArguments()) for g in gs: assert set(g.primitives) == set(library.keys()) assert g.continuationType == self.continuationType def untorch(self): return ContextualGrammar(self.noParent.untorch(), self.variableParent.untorch(), {e: [g.untorch() for g in gs ] for e,gs in self.library.items() }) def randomWeights(self, r): """returns a new grammar with random weights drawn from r. calls `r` w/ old weight""" return ContextualGrammar(self.noParent.randomWeights(r), self.variableParent.randomWeights(r), {e: [g.randomWeights(r) for g in gs] for e,gs in self.library.items() }) def __str__(self): lines = ["No parent:",str(self.noParent),"", "Variable parent:",str(self.variableParent),"", ""] for e,gs in self.library.items(): for j,g in enumerate(gs): lines.extend(["Parent %s, argument index %s"%(e,j), str(g), ""]) return "\n".join(lines) def json(self): return {"noParent": self.noParent.json(), "variableParent": self.variableParent.json(), "productions": [{"program": str(e), "arguments": [gp.json() for gp in gs ]} for e,gs in self.library.items() ]} @staticmethod def fromGrammar(g): return ContextualGrammar(g, g, {e: [g]*len(e.infer().functionArguments()) for e in g.primitives }) class LS: # likelihood summary def __init__(self, owner): self.noParent = LikelihoodSummary() self.variableParent = LikelihoodSummary() self.library = {e: [LikelihoodSummary() for _ in gs] for e,gs in owner.library.items() } def record(self, parent, parentIndex, actual, possibles, constant): if parent is None: ls = self.noParent elif parent.isIndex: ls = self.variableParent else: ls = self.library[parent][parentIndex] ls.record(actual, possibles, constant=constant) def join(self, other): self.noParent.join(other.noParent) self.variableParent.join(other.variableParent) for e,gs in self.library.items(): for g1,g2 in zip(gs, other.library[e]): g1.join(g2) def logLikelihood(self, owner): return self.noParent.logLikelihood(owner.noParent) + \ self.variableParent.logLikelihood(owner.variableParent) + \ sum(r.logLikelihood(g) for e, rs in self.library.items() for r,g in zip(rs, owner.library[e]) ) def numerator(self, owner): return self.noParent.numerator(owner.noParent) + \ self.variableParent.numerator(owner.variableParent) + \ sum(r.numerator(g) for e, rs in self.library.items() for r,g in zip(rs, owner.library[e]) ) def denominator(self, owner): return self.noParent.denominator(owner.noParent) + \ self.variableParent.denominator(owner.variableParent) + \ sum(r.denominator(g) for e, rs in self.library.items() for r,g in zip(rs, owner.library[e]) ) def likelihoodSummary(self, parent, parentIndex, context, environment, request, expression): if request.isArrow(): assert expression.isAbstraction return self.likelihoodSummary(parent, parentIndex, context, [request.arguments[0]] + environment, request.arguments[1], expression.body) if parent is None: g = self.noParent elif parent.isIndex: g = self.variableParent else: g = self.library[parent][parentIndex] candidates = g.buildCandidates(request, context, environment, normalize=False, returnTable=True) # A list of everything that would have been possible to use here possibles = [p for p in candidates.keys() if not p.isIndex] numberOfVariables = sum(p.isIndex for p in candidates.keys()) if numberOfVariables > 0: possibles += [Index(0)] f, xs = expression.applicationParse() assert f in candidates thisSummary = self.LS(self) thisSummary.record(parent, parentIndex, f, possibles, constant= -math.log(numberOfVariables) if f.isIndex else 0) _, tp, context = candidates[f] argumentTypes = tp.functionArguments() assert len(xs) == len(argumentTypes) for i, (argumentType, argument) in enumerate(zip(argumentTypes, xs)): argumentType = argumentType.apply(context) context, newSummary = self.likelihoodSummary(f, i, context, environment, argumentType, argument) thisSummary.join(newSummary) return context, thisSummary def closedLikelihoodSummary(self, request, expression): return self.likelihoodSummary(None,None, Context.EMPTY,[], request, expression)[1] def logLikelihood(self, request, expression): return self.closedLikelihoodSummary(request, expression).logLikelihood(self) def sample(self, request, maximumDepth=8, maxAttempts=None): attempts = 0 while True: try: _, e = self._sample(None, None, Context.EMPTY, [], request, maximumDepth) return e except NoCandidates: if maxAttempts is not None: attempts += 1 if attempts > maxAttempts: return None continue def _sample(self, parent, parentIndex, context, environment, request, maximumDepth): if request.isArrow(): context, body = self._sample(parent, parentIndex, context, [request.arguments[0]] + environment, request.arguments[1], maximumDepth) return context, Abstraction(body) if parent is None: g = self.noParent elif parent.isIndex: g = self.variableParent else: g = self.library[parent][parentIndex] candidates = g.buildCandidates(request, context, environment, normalize=True, returnProbabilities=True, mustBeLeaf=(maximumDepth <= 1)) newType, chosenPrimitive, context = sampleDistribution(candidates) xs = newType.functionArguments() returnValue = chosenPrimitive for j,x in enumerate(xs): x = x.apply(context) context, x = self._sample(chosenPrimitive, j, context, environment, x, maximumDepth - 1) returnValue = Application(returnValue, x) return context, returnValue def expectedUsesMonteCarlo(self, request, debug=None): import numpy as np n = 0 u = [0.]*len(self.primitives) primitives = list(sorted(self.primitives, key=str)) noInventions = all( not p.isInvented for p in primitives ) primitive2index = {primitive: i for i, primitive in enumerate(primitives) if primitive.isInvented or noInventions } eprint(primitive2index) ns = 10000 with timing(f"calculated expected uses using Monte Carlo simulation w/ {ns} samples"): for _ in range(ns): p = self.sample(request, maxAttempts=0) if p is None: continue n += 1 if debug and n < 10: eprint(debug, p) for _, child in p.walk(): if child not in primitive2index: continue u[primitive2index[child]] += 1.0 u = np.array(u)/n if debug: eprint(f"Got {n} samples. Feature vector:\n{u}") eprint(f"Likely used primitives: {[p for p,i in primitive2index.items() if u[i] > 0.5]}") eprint(f"Likely used primitive indices: {[i for p,i in primitive2index.items() if u[i] > 0.5]}") return u def featureVector(self, _=None, requests=None, onlyInventions=True, normalize=True): """ Returns the probabilities licensed by the type system. This is like the grammar productions, but with irrelevant junk removed. Its intended use case is for clustering; it should be strictly better than the raw transition matrix. """ if requests is None: if self.continuationType: requests = {self.continuationType} elif any( 'REAL' == str(p) for p in self.primitives ): requests = set() elif any( 'STRING' == str(p) for p in self.primitives ): requests = {tlist(tcharacter)} else: requests = set() requests = {r.returns() for r in requests} features = [] logWeights = [] for l,t,p in sorted(self.noParent.productions, key=lambda z: str(z[2])): if onlyInventions and not p.isInvented: continue if any( canUnify(r, t.returns()) for r in requests ) or len(requests) == 0: logWeights.append(l) features.append(logWeights) for parent in sorted(self.primitives, key=str): if onlyInventions and not parent.isInvented: continue if parent not in self.library: continue argumentTypes = parent.infer().functionArguments() for j,g in enumerate(self.library[parent]): argumentType = argumentTypes[j] logWeights = [] for l,t,p in sorted(g.productions, key=lambda z: str(z[2])): if onlyInventions and not p.isInvented: continue if canUnify(argumentType.returns(), t.returns()): logWeights.append(l) features.append(logWeights) if normalize: features = [ [math.exp(w - z) for w in lw ] for lw in features if lw for z in [lse(lw)] ] import numpy as np return np.array([f for lw in features for f in lw]) def enumeration(self,context,environment,request,upperBound, parent=None, parentIndex=None, maximumDepth=20, lowerBound=0.): '''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound''' if upperBound < 0 or maximumDepth == 1: return if request.isArrow(): v = request.arguments[0] for l, newContext, b in self.enumeration(context, [v] + environment, request.arguments[1], parent=parent, parentIndex=parentIndex, upperBound=upperBound, lowerBound=lowerBound, maximumDepth=maximumDepth): yield l, newContext, Abstraction(b) else: if parent is None: g = self.noParent elif parent.isIndex: g = self.variableParent else: g = self.library[parent][parentIndex] candidates = g.buildCandidates(request, context, environment, normalize=True) for l, t, p, newContext in candidates: mdl = -l if not (mdl < upperBound): continue xs = t.functionArguments() for aL, aK, application in\ self.enumerateApplication(newContext, environment, p, xs, parent=p, upperBound=upperBound + l, lowerBound=lowerBound + l, maximumDepth=maximumDepth - 1): yield aL + l, aK, application def enumerateApplication(self, context, environment, function, argumentRequests, # Upper bound on the description length of all of # the arguments upperBound, # Lower bound on the description length of all of # the arguments lowerBound=0., maximumDepth=20, parent=None, originalFunction=None, argumentIndex=0): assert parent is not None if upperBound < 0. or maximumDepth == 1: return if originalFunction is None: originalFunction = function if argumentRequests == []: if lowerBound <= 0. and 0. < upperBound: yield 0., context, function else: return else: argRequest = argumentRequests[0].apply(context) laterRequests = argumentRequests[1:] for argL, newContext, arg in self.enumeration(context, environment, argRequest, parent=parent, parentIndex=argumentIndex, upperBound=upperBound, lowerBound=0., maximumDepth=maximumDepth): if violatesSymmetry(originalFunction, arg, argumentIndex): continue newFunction = Application(function, arg) for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction, laterRequests, parent=parent, upperBound=upperBound + argL, lowerBound=lowerBound + argL, maximumDepth=maximumDepth, originalFunction=originalFunction, argumentIndex=argumentIndex + 1): yield resultL + argL, resultK, result def violatesSymmetry(f, x, argumentIndex): if not f.isPrimitive: return False while x.isApplication: x = x.f if not x.isPrimitive: return False f = f.name x = x.name if f == "car": return x == "cons" or x == "empty" if f == "cdr": return x == "cons" or x == "empty" if f == "+": return x == "0" or (argumentIndex == 1 and x == "+") if f == "-": return argumentIndex == 1 and x == "0" if f == "empty?": return x == "cons" or x == "empty" if f == "zero?": return x == "0" or x == "1" if f == "index" or f == "map" or f == "zip": return x == "empty" if f == "range": return x == "0" if f == "fold": return argumentIndex == 1 and x == "empty" return False def batchLikelihood(jobs): """Takes as input a set of (program, request, grammar) and returns a dictionary mapping each of these to its likelihood under the grammar""" superGrammar = Grammar.uniform(list({p for _1,_2,g in jobs for p in g.primitives}), continuationType=list(jobs)[0][-1].continuationType) programsAndRequests = {(program, request) for program, request, grammar in jobs} with timing(f"Calculated {len(programsAndRequests)} likelihood summaries"): summary = {(program, request): superGrammar.closedLikelihoodSummary(request, program) for program, request in programsAndRequests} with timing(f"Calculated log likelihoods from summaries"): response = {} for program, request, grammar in jobs: fast = summary[(program, request)].logLikelihood_overlyGeneral(grammar) if False: # debugging slow = grammar.logLikelihood(request, program) print(program) eprint(grammar.closedLikelihoodSummary(request, program)) eprint(superGrammar.closedLikelihoodSummary(request, program)) print() assert abs(fast - slow) < 0.0001 response[(program, request, grammar)] = fast return response if __name__ == "__main__": from dreamcoder.domains.arithmetic.arithmeticPrimitives import * g = ContextualGrammar.fromGrammar(Grammar.uniform([k0,k1,addition, subtraction])) g = g.randomWeights(lambda *a: random.random()) #p = Program.parse("(lambda (+ 1 $0))") request = arrow(tint,tint) for ll,_,p in g.enumeration(Context.EMPTY,[],request, 12.): ll_ = g.logLikelihood(request,p) print(ll,p,ll_) d = abs(ll - ll_) assert d < 0.0001