"""Debugging helpers that export TorchScript graphs and GraphExecutor states
as TensorBoard graph summaries (requires TensorFlow)."""
import time
from collections import defaultdict
from functools import partial
from typing import DefaultDict

import torch

# Unfortunately it doesn't seem as if there was any way to get TensorBoard to do
# anything without having TF installed, and so this file has a hard dependency on it
# as well. It really is a debugging tool, so it doesn't matter.
try:
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.util import event_pb2
    from tensorflow.python.summary.writer.writer import FileWriter
except ImportError:
    raise ImportError("TensorBoard visualization of GraphExecutors requires having "
                      "TensorFlow installed") from None
def dump_tensorboard_summary(graph_executor, logdir):
    """Write the visualization of ``graph_executor`` into ``logdir``.

    The graph produced by ``visualize`` is serialized and stored as the
    ``graph_def`` of a single TensorBoard event.

    Args:
        graph_executor: anything ``visualize`` accepts (a graph or a
            ``torch._C.GraphExecutorState``).
        logdir: log directory handed to TensorFlow's ``FileWriter``.
    """
    # The writer is opened first so that it exists even if visualization fails.
    with FileWriter(logdir) as writer:
        serialized_graph = visualize(graph_executor).SerializeToString()
        event = event_pb2.Event(wall_time=time.time(), graph_def=serialized_graph)
        writer.add_event(event)
def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
    """Visualize an independent graph, or a graph executor.

    Args:
        graph: either a ``torch._C.GraphExecutorState`` (handed off to
            ``visualize_graph_executor``) or a graph whose ``param_node()``
            outputs become the ``input`` node and whose ``return_node()``
            inputs feed the ``output`` node.
        name_prefix (str): prefix prepended to every emitted node name.
        pb_graph (GraphDef, optional): protobuf graph to append to. A fresh
            ``GraphDef`` is created when ``None``.
        executors_it (iterator, optional): executor states for the
            ``prim::GraphExecutor`` nodes of ``graph``; forwarded to
            ``visualize_rec``.

    Returns:
        GraphDef: ``pb_graph`` (or the newly created graph) with all nodes appended.
    """
    # Explicit None check instead of `pb_graph or ...`: nested callers pass their
    # GraphDef in and ignore the return value, so a caller-supplied graph must
    # never be swapped for a new one based on protobuf message truthiness.
    if pb_graph is None:
        pb_graph = graph_pb2.GraphDef()

    if isinstance(graph, torch._C.GraphExecutorState):
        visualize_graph_executor(graph, name_prefix, pb_graph,
                                 partial(visualize, pb_graph=pb_graph))
        return pb_graph

    # Set up an input node; the graph parameters are referenced as its
    # numbered outputs ("<prefix>input:<i>").
    pb_graph.node.add(op='input', name=name_prefix + 'input')
    value_map = {value.unique(): f'{name_prefix}input:{i}'
                 for i, value in enumerate(graph.param_node().outputs())}

    visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it)

    # Gather all outputs
    return_node = pb_graph.node.add(op='output', name=name_prefix + 'output')
    for value in graph.return_node().inputs():
        return_node.input.append(value_map[value.unique()])

    return pb_graph
def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
    """Append the state of a given GraphExecutor to the graph protobuf.

    Every execution plan, and its gradient executor when present, is drawn as
    an independent, disconnected subgraph. The autograd fallback graph is
    drawn too when it exists. Only the executor's original graph is inlined
    into the surrounding graph, so it is the one that actually produces values.

    Args:
        state (GraphExecutor or GraphExecutorState): executor to display.
        name_prefix (str): name prefix of the containing subgraph.
        pb_graph (GraphDef): protobuf graph to append to.
        inline_graph (Callable): called exactly once as
            ``inline_graph(graph, name_prefix)`` to draw the original graph.
            For the top-level executor this is ``visualize`` bound to
            ``pb_graph``. For nested executors it is the ``inline_graph``
            closure of ``visualize_rec``, which wires the values into the
            enclosing value map.

    Returns:
        Whatever ``inline_graph`` returns.
    """
    fallback_graph = state.autograd_fallback_graph
    if fallback_graph is not None:
        fallback_executors = iter(state.autograd_fallback.executors())
        visualize(graph=fallback_graph,
                  name_prefix=name_prefix + 'autograd_fallback/',
                  pb_graph=pb_graph,
                  executors_it=fallback_executors)

    for plan_index, (spec, plan) in enumerate(state.execution_plans.items()):
        plan_prefix = f'{name_prefix}plan{plan_index}/'

        # A disconnected marker node records the argument spec (input types)
        # of this plan; it is too verbose to be folded into the subgraph name.
        marker = pb_graph.node.add(op='INPUT_KIND', name=plan_prefix)
        marker.attr['inputs'].s = repr(spec).encode('ascii')

        visualize(plan.graph, plan_prefix, pb_graph, iter(plan.code.executors()))

        # The gradient is shown as its own subgraph nested under the plan.
        grad_executor = plan.grad_executor
        if grad_executor is not None:
            visualize(grad_executor, plan_prefix + 'grad/', pb_graph)

    return inline_graph(state.graph, name_prefix + 'original/')
def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
    """Recursive part of ``visualize``: emit protobuf nodes for ``graph.nodes()``.

    Unlike ``visualize``, this does not create the ``input``/``output`` nodes.
    The caller seeds ``value_map`` with names for the graph's inputs.

    Args:
        graph: graph whose ``nodes()`` are appended to ``pb_graph``.
        value_map (dict): maps ``value.unique()`` ids to the TensorBoard name
            that produces the value. Updated in place with every output
            produced by ``graph``'s nodes.
        name_prefix (str): prefix prepended to every emitted node name.
        pb_graph (GraphDef): protobuf graph that nodes are appended to.
        executors_it (iterator, optional): yields one executor state per
            ``prim::GraphExecutor`` node of ``graph``, in node order. When
            ``None``, such nodes are drawn as plain, opaque ops.
    """
    def inline_graph(subgraph, name, node):
        # Draw `subgraph` in place of `node`. Its inputs alias the values fed
        # into `node`, and `node`'s outputs alias the subgraph's outputs.
        rec_value_map = {inp.unique(): value_map[val.unique()]
                         for inp, val in zip(subgraph.inputs(), node.inputs())}
        visualize_rec(graph=subgraph,
                      value_map=rec_value_map,
                      name_prefix=name,
                      pb_graph=pb_graph)
        for out, val in zip(subgraph.outputs(), node.outputs()):
            value_map[val.unique()] = rec_value_map[out.unique()]

    # Per-kind counters so node names stay unique within this prefix.
    op_id_counter: DefaultDict[str, int] = defaultdict(int)

    def name_for(node):
        # Strip the namespace ('aten::add' -> 'add'); a kind without one is kept whole.
        kind = node.kind().split('::', 1)[-1]
        op_id_counter[kind] += 1
        return kind, f'{name_prefix}{kind}_{op_id_counter[kind]}'

    def add_op_node(node, op, name):
        # Emit one opaque op node wired to its inputs, and register its outputs.
        pb_node = pb_graph.node.add(op=op, name=name)
        for value in node.inputs():
            pb_node.input.append(value_map[value.unique()])
        # TODO: handle attrs
        for i, value in enumerate(node.outputs()):
            value_map[value.unique()] = f'{name}:{i}'

    def add_fusion_group(node):
        _, name = name_for(node)
        inline_graph(node.g('Subgraph'), name + '/', node)

    def add_graph_executor(node):
        op, name = name_for(node)
        if executors_it is None:
            # No executor states to expand: draw the node as a plain op.
            # Routing it through add_node would dispatch straight back here
            # and recurse forever.
            add_op_node(node, op, name)
        else:
            ge = next(executors_it)
            visualize_graph_executor(ge, name + '/', pb_graph,
                                     partial(inline_graph, node=node))

    def add_node(node):
        kind = node.kind()
        if kind == 'prim::FusionGroup':
            add_fusion_group(node)
        elif kind == 'prim::GraphExecutor':
            add_graph_executor(node)
        else:
            add_op_node(node, *name_for(node))

    for node in graph.nodes():
        add_node(node)