|
|
|
|
|
|
|
import logging

import graphviz
|
|
|
|
|
def convert_to_visualize(graph_ir, vgraph): |
|
for name, graph in graph_ir.items(): |
|
if name == '_evaluator': |
|
continue |
|
with vgraph.subgraph(name='cluster'+name) as subgraph: |
|
subgraph.attr(color='blue') |
|
cell_node = {} |
|
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])), |
|
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))} |
|
subgraph.node(ioput['_inputs']) |
|
subgraph.node(ioput['_outputs']) |
|
for node_name, node_value in graph['nodes'].items(): |
|
value = node_value['operation'] |
|
if value['type'] == '_cell': |
|
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs'])) |
|
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs'])) |
|
cell_node[node_name] = (cell_input_name, cell_output_name) |
|
print('cell: ', node_name, cell_input_name, cell_output_name) |
|
else: |
|
subgraph.node(node_name) |
|
for edge in graph['edges']: |
|
src = edge['head'][0] |
|
if src == '_inputs': |
|
src = ioput['_inputs'] |
|
elif src in cell_node: |
|
src = cell_node[src][1] |
|
dst = edge['tail'][0] |
|
if dst == '_outputs': |
|
dst = ioput['_outputs'] |
|
elif dst in cell_node: |
|
dst = cell_node[dst][0] |
|
subgraph.edge(src, dst) |
|
|
|
|
|
def visualize_model(graph_ir, *, filename='vgraph', fmt='jpg'):
    """Render *graph_ir* to a file with graphviz and return its path.

    Builds a ``graphviz.Digraph`` named ``'G'``, fills it via
    ``convert_to_visualize`` and renders it.

    Args:
        graph_ir: graph IR mapping in the shape ``convert_to_visualize`` reads.
        filename: base name for the DOT source file that graphviz writes;
            the rendered output is written next to it.
        fmt: graphviz output format (e.g. ``'jpg'``, ``'png'``, ``'svg'``).

    Returns:
        The path of the rendered file, as returned by ``Digraph.render()``.
    """
    vgraph = graphviz.Digraph('G', filename=filename, format=fmt)
    convert_to_visualize(graph_ir, vgraph)
    return vgraph.render()
|
|