Spaces:
Sleeping
Sleeping
File size: 2,128 Bytes
9bdaa77 c46567d 9bdaa77 c46567d 9bdaa77 c46567d 9bdaa77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting a RaspExpr to a graph."""
import dataclasses
import queue
from typing import List
import networkx as nx
from tracr.compiler import nodes
from tracr.rasp import rasp
# Shorthand aliases for the node types defined in tracr.compiler.nodes.
Node, NodeID = nodes.Node, nodes.NodeID
@dataclasses.dataclass
class ExtractRaspGraphOutput:
    """Result of turning a RASP program into a graph.

    Attributes:
      graph: Directed graph of the program. Nodes are keyed by expression
        label; each edge runs from a sub-expression to the expression that
        uses it.
      sink: Graph node of the program's output expression.
      sources: Graph nodes of the primitive S-Ops (expressions that have no
        children).
    """

    graph: nx.DiGraph
    sink: Node
    sources: List[Node]
def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
    """Converts a RASP program into a graph representation.

    Walks the expression DAG rooted at `tip` breadth-first. Every expression
    becomes a graph node keyed by its label and carrying the attributes
    `nodes.ID` (the label) and `nodes.EXPR` (the expression). For each child of
    each expression, an edge `child -> parent` is added.

    Each node is expanded exactly once, even when it is shared by several
    parents. Without this bookkeeping, a sub-expression reachable along k
    paths would be re-expanded k times, which can grow exponentially for
    diamond-shaped programs. A shared primitive would also be listed k times
    in `sources`.

    Args:
      tip: The output S-Op of the RASP program.

    Returns:
      An ExtractRaspGraphOutput with the graph, the node for `tip` (`sink`),
      and the nodes of all childless expressions (`sources`) in BFS order.
    """
    expr_queue = queue.Queue()
    graph = nx.DiGraph()
    # Holds graph.nodes[...] entries (Node), not bare IDs.
    sources: List[Node] = []
    # IDs of nodes already put on the queue, so each is expanded only once.
    enqueued = set()

    def ensure_node(expr: rasp.RASPExpr) -> NodeID:
        """Finds or creates a graph node corresponding to expr; returns its ID."""
        node_id = expr.label
        if node_id not in graph:
            graph.add_node(node_id, **{nodes.ID: node_id, nodes.EXPR: expr})
        return node_id

    def enqueue(expr: rasp.RASPExpr) -> NodeID:
        """Ensures expr has a node and schedules it for expansion once."""
        node_id = ensure_node(expr)
        if node_id not in enqueued:
            enqueued.add(node_id)
            expr_queue.put(expr)
        return node_id

    def visit_raspexpr(expr: rasp.RASPExpr):
        """Links expr to its children; records it as a source if childless."""
        parent_id = ensure_node(expr)
        for child_expr in expr.children:
            child_id = enqueue(child_expr)
            graph.add_edge(child_id, parent_id)
        if not expr.children:
            sources.append(graph.nodes[parent_id])

    # Breadth-first search over the RASP expression graph.
    sink = graph.nodes[enqueue(tip)]
    while not expr_queue.empty():
        visit_raspexpr(expr_queue.get())

    return ExtractRaspGraphOutput(graph=graph, sink=sink, sources=sources)
|