| | """Utilities related to model visualization.""" |
| |
|
| | import os |
| | import sys |
| |
|
| | from keras.src import tree |
| | from keras.src.api_export import keras_export |
| | from keras.src.utils import io_utils |
| |
|
| | try: |
| | import pydot |
| | except ImportError: |
| | |
| | |
| | try: |
| | import pydot_ng as pydot |
| | except ImportError: |
| | try: |
| | import pydotplus as pydot |
| | except ImportError: |
| | pydot = None |
| |
|
| |
|
def check_pydot():
    """Report whether a PyDot-compatible module could be imported.

    Returns:
        `False` when the module-level `pydot` name was set to `None`
        (every import attempt failed), `True` otherwise.
    """
    if pydot is None:
        return False
    return True
| |
|
| |
|
def check_graphviz():
    """Returns True if both PyDot and Graphviz are available.

    Availability is probed by asking PyDot to render an empty graph.

    Returns:
        `False` when PyDot is missing or when rendering the empty graph
        raises `OSError` or one of the PyDot error classes; `True` when the
        rendering succeeds.
    """
    if not check_pydot():
        return False

    # The module bound to `pydot` may be `pydot`, `pydot_ng` or `pydotplus`
    # (see the import fallbacks at the top of the file) and these do not
    # necessarily define the same exception classes. Looking the names up
    # with `getattr` keeps a missing attribute from turning a failed probe
    # into an `AttributeError`.
    # NOTE(review): `InvocationException` is the class the fallback packages
    # are expected to raise when Graphviz is absent -- confirm against the
    # versions you support.
    pydot_errors = tuple(
        candidate
        for candidate in (
            getattr(pydot, name, None)
            for name in ("PydotException", "InvocationException")
        )
        if isinstance(candidate, type) and issubclass(candidate, BaseException)
    )

    try:
        # Attempt to create an image of a blank graph to check the
        # pydot/graphviz installation.
        pydot.Dot.create(pydot.Dot())
    except (OSError, *pydot_errors):
        return False
    return True
| |
|
| |
|
def add_edge(dot, src, dst):
    """Link the nodes standing for `src` and `dst`, unless already linked.

    Nodes are addressed by `str(id(obj))`. A new edge with pen width 2 is
    added to `dot` only when `dot.get_edge` reports nothing for that pair.

    Args:
        dot: The graph (or cluster) receiving the edge.
        src: Object whose node is the tail of the edge.
        dst: Object whose node is the head of the edge.
    """
    endpoints = (str(id(src)), str(id(dst)))
    if dot.get_edge(*endpoints):
        return
    link = pydot.Edge(*endpoints)
    link.set("penwidth", "2")
    dot.add_edge(link)
| |
|
| |
|
def get_layer_activation_name(layer):
    """Return a display name for `layer.activation`.

    The activation's `name` attribute is preferred, then its `__name__`;
    when it has neither, `str(layer.activation)` is returned.

    Args:
        layer: An object exposing an `activation` attribute.

    Returns:
        The value found by the lookup described above.
    """
    activation = layer.activation
    for attribute in ("name", "__name__"):
        if hasattr(activation, attribute):
            return getattr(activation, attribute)
    return str(activation)
| |
|
| |
|
def make_layer_label(layer, **kwargs):
    """Build the Graphviz HTML-like table label describing `layer`.

    The label is a table whose first row is a title (layer name and/or class
    name), optionally followed by an activation row and by one row of detail
    cells (input/output shapes, output dtype, trainable status).

    Args:
        layer: The object to describe. Its class name is always read;
            `name`, `activation`, `input`, `output`, `trainable` and
            `weights` are read only when the matching flag asks for them.
        **kwargs: Must contain exactly the boolean flags
            `show_layer_names`, `show_layer_activations`, `show_dtype`,
            `show_shapes` and `show_trainable`.

    Returns:
        A string of the form `<<table ...>...</table>>`.

    Raises:
        KeyError: If one of the five flags is missing from `kwargs`.
        ValueError: If `kwargs` contains anything besides the five flags.
    """
    from html import escape

    class_name = layer.__class__.__name__

    show_layer_names = kwargs.pop("show_layer_names")
    show_layer_activations = kwargs.pop("show_layer_activations")
    show_dtype = kwargs.pop("show_dtype")
    show_shapes = kwargs.pop("show_shapes")
    show_trainable = kwargs.pop("show_trainable")
    if kwargs:
        raise ValueError(f"Invalid kwargs: {kwargs}")

    def as_text(value):
        # Everything produced here is embedded in HTML-like markup, so a raw
        # `<`, `>` or `&` (e.g. an activation called `<lambda>`) would be
        # parsed as markup and corrupt the label. Quotes are left untouched
        # to keep the rendered text of shapes/dtypes unchanged.
        return escape(f"{value}", quote=False)

    def format_shape(shape):
        if shape is None:
            return "?"
        if isinstance(shape, dict):
            shape_str = ", ".join(f"{k}: {v}" for k, v in shape.items())
        else:
            shape_str = f"{shape}"
        return shape_str.replace("}", "").replace("{", "")

    def detail_cell(caption, value):
        return (
            '<td bgcolor="white"><font point-size="14">'
            f"{caption}: <b>{as_text(value)}</b>"
            "</font></td>"
        )

    # Detail cells are collected first so that the title and activation rows
    # can span exactly the number of columns that are really emitted.
    cols = []
    if show_shapes:
        input_shape = None
        output_shape = None
        try:
            input_shape = tree.map_structure(lambda x: x.shape, layer.input)
            output_shape = tree.map_structure(lambda x: x.shape, layer.output)
        except (ValueError, AttributeError):
            # Best effort: shapes that cannot be read are displayed as "?".
            pass

        if class_name != "InputLayer":
            cols.append(detail_cell("Input shape", format_shape(input_shape)))
        cols.append(detail_cell("Output shape", format_shape(output_shape)))
    if show_dtype:
        dtype = None
        try:
            dtype = tree.map_structure(lambda x: x.dtype, layer.output)
        except (ValueError, AttributeError):
            # Best effort: an unreadable dtype is displayed as "?".
            pass
        cols.append(detail_cell("Output dtype", dtype or "?"))
    if show_trainable and hasattr(layer, "trainable") and layer.weights:
        if layer.trainable:
            color, caption = "forestgreen", "Trainable"
        else:
            color, caption = "firebrick", "Non-trainable"
        cols.append(
            f'<td bgcolor="{color}">'
            '<font point-size="14" color="white">'
            f"<b>{caption}</b></font></td>"
        )

    colspan = max(1, len(cols))

    if show_layer_names:
        title = f"<b>{as_text(layer.name)}</b> ({as_text(class_name)})"
    else:
        title = f"<b>{as_text(class_name)}</b>"
    rows = [
        f'<tr><td colspan="{colspan}" bgcolor="black">'
        '<font point-size="16" color="white">'
        f"{title}"
        "</font></td></tr>"
    ]
    if (
        show_layer_activations
        and hasattr(layer, "activation")
        and layer.activation is not None
    ):
        activation_name = as_text(get_layer_activation_name(layer))
        rows.append(
            f'<tr><td bgcolor="white" colspan="{colspan}">'
            '<font point-size="14">'
            f"Activation: <b>{activation_name}</b>"
            "</font></td></tr>"
        )
    if cols:
        rows.append("<tr>" + "".join(cols) + "</tr>")

    return (
        '<<table border="0" cellborder="1" bgcolor="black" cellpadding="10">'
        + "".join(rows)
        + "</table>>"
    )
| |
|
| |
|
def make_node(layer, **kwargs):
    """Create the graph node that represents `layer`.

    The node is identified by `str(id(layer))` and labelled with the output
    of `make_layer_label(layer, **kwargs)`.

    Args:
        layer: The object to draw.
        **kwargs: Display flags forwarded unchanged to `make_layer_label`.

    Returns:
        The configured `pydot.Node`.
    """
    label = make_layer_label(layer, **kwargs)
    node = pydot.Node(str(id(layer)), label=label)
    for attribute, value in (
        ("fontname", "Helvetica"),
        ("border", "0"),
        ("margin", "0"),
    ):
        node.set(attribute, value)
    return node
| |
|
| |
|
@keras_export("keras.utils.model_to_dot")
def model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    subgraph=False,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs,
):
    """Convert a Keras model to dot format.

    Args:
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: `"TB"`
            creates a vertical plot; `"LR"` creates a horizontal plot.
        expand_nested: whether to expand nested Functional models
            into clusters.
        dpi: Image resolution in dots per inch.
        subgraph: whether to return a `pydot.Cluster` instance.
        show_layer_activations: Display layer activations (only for layers that
            have an `activation` property).
        show_trainable: whether to display if a layer is trainable.

    Returns:
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`.

    Raises:
        ValueError: If the model has not been built, if `layer_range` is
            passed with a non-`None` value, or if any other unexpected
            keyword argument is passed.
        ImportError: If no PyDot-compatible module is available.
    """
    from keras.src.ops.function import make_node_key

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    from keras.src.models import functional
    from keras.src.models import sequential

    if not check_pydot():
        raise ImportError(
            "You must install pydot (`pip install pydot`) for "
            "model_to_dot to work."
        )

    if subgraph:
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", rankdir)
        dot.set("concentrate", True)
        dot.set("dpi", dpi)
        dot.set("splines", "ortho")
        dot.set_node_defaults(shape="record")

    if kwargs.pop("layer_range", None) is not None:
        raise ValueError("Argument `layer_range` is no longer supported.")
    if kwargs:
        raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

    # Display flags forwarded to every `make_node` call.
    label_options = {
        "show_layer_names": show_layer_names,
        "show_layer_activations": show_layer_activations,
        "show_dtype": show_dtype,
        "show_shapes": show_shapes,
        "show_trainable": show_trainable,
    }
    # Model types that can be expanded into a cluster of their own.
    nested_types = (functional.Functional, sequential.Sequential)

    if isinstance(model, sequential.Sequential):
        layers = model.layers
    elif not isinstance(model, functional.Functional):
        # Neither Sequential nor Functional: the whole model is drawn as a
        # single node.
        node = make_node(model, **label_options)
        dot.add_node(node)
        return dot
    else:
        layers = model._operations

    # Create graph nodes.
    for layer in layers:
        if expand_nested and isinstance(layer, nested_types):
            # Nested model: render it recursively as a cluster.
            submodel = model_to_dot(
                layer,
                show_shapes=show_shapes,
                show_dtype=show_dtype,
                show_layer_names=show_layer_names,
                rankdir=rankdir,
                expand_nested=expand_nested,
                subgraph=True,
                show_layer_activations=show_layer_activations,
                show_trainable=show_trainable,
            )
            dot.add_subgraph(submodel)
        else:
            node = make_node(layer, **label_options)
            dot.add_node(node)

    # Create graph edges.
    if isinstance(model, sequential.Sequential):
        if not expand_nested:
            # Plain Sequential case: chain consecutive layers.
            for previous_layer, next_layer in zip(layers, layers[1:]):
                add_edge(dot, previous_layer, next_layer)
            return dot
        else:
            # The first layer is skipped here; presumably its incoming edge
            # is drawn by whatever connects to this Sequential model from
            # the outside -- TODO confirm.
            layers = model.layers[1:]

    # Functional and nested Sequential case.
    for layer in layers:
        # Walk from the current layer to each of its inbound nodes.
        for inbound_index, inbound_node in enumerate(layer._inbound_nodes):
            # For a Functional model, ignore inbound nodes that are not
            # registered in `model._nodes`.
            if (
                isinstance(model, functional.Functional)
                and make_node_key(layer, inbound_index) not in model._nodes
            ):
                continue

            # Walk from the inbound node to each tensor it consumes.
            for input_index, input_tensor in enumerate(
                inbound_node.input_tensors
            ):
                input_history = input_tensor._keras_history
                if input_history.operation is None:
                    # No producing operation recorded (presumably a graph
                    # input), so there is no edge to draw.
                    continue

                # Find the node of the producing operation that emitted this
                # tensor, and which of its outputs the tensor is.
                input_node = input_history.operation._inbound_nodes[
                    input_history.node_index
                ]
                output_index = input_history.tensor_index

                # Tentative source and destination of the edge.
                source = input_node.operation
                destination = layer

                if not expand_nested:
                    # No nesting, connect directly.
                    add_edge(dot, source, layer)
                    continue

                # ==== Potentially nested models case ====

                # ---- Resolve the source of the edge ----
                while isinstance(source, nested_types):
                    # `source` is itself a model: descend to the operation
                    # that produces the relevant output, repeating until a
                    # non-model operation is reached.
                    source, _, output_index = source.outputs[
                        output_index
                    ]._keras_history

                # ---- Resolve the destination of the edge ----
                while isinstance(destination, nested_types):
                    if isinstance(destination, functional.Functional):
                        # Point at the operation behind the matching input
                        # of the nested Functional model.
                        destination = destination.inputs[
                            input_index
                        ]._keras_history.operation
                    else:
                        # Nested Sequential: point at its first layer, which
                        # may itself be a model (hence the loop).
                        destination = destination.layers[0]

                add_edge(dot, source, destination)
    return dot
| |
|
| |
|
def _report_missing_dependency(message):
    """Print `message` inside IPython sessions; raise `ImportError` otherwise."""
    if "IPython.core.magics.namespace" in sys.modules:
        # Printing instead of raising keeps notebook sessions from failing
        # when the plotting dependencies are not installed.
        io_utils.print_msg(message)
        return
    raise ImportError(message)


@keras_export("keras.utils.plot_model")
def plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=False,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs,
):
    """Converts a Keras model to dot format and saves it to a file.

    Example:

    ```python
    inputs = ...
    outputs = ...
    model = keras.Model(inputs=inputs, outputs=outputs)

    dot_img_file = '/tmp/model_1.png'
    keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
    ```

    Args:
        model: A Keras model instance
        to_file: File name of the plot image. The output format is taken
            from the file extension; `"png"` is used when there is none.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: `"TB"`
            creates a vertical plot; `"LR"` creates a horizontal plot.
        expand_nested: whether to expand nested Functional models
            into clusters.
        dpi: Image resolution in dots per inch.
        show_layer_activations: Display layer activations (only for layers that
            have an `activation` property).
        show_trainable: whether to display if a layer is trainable.

    Returns:
        A Jupyter notebook Image object if Jupyter is installed.
        This enables in-line display of the model plots in notebooks.
        `None` is returned when the output format is `"pdf"`, when IPython
        cannot be imported, or when a plotting dependency is missing inside
        an IPython session (a message is printed in that case).

    Raises:
        ValueError: If the model has not been built, if `layer_range` is
            passed with a non-`None` value, or if any other unexpected
            keyword argument is passed.
        ImportError: If pydot or graphviz is unavailable outside of an
            IPython session.
    """

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )
    if not check_pydot():
        _report_missing_dependency(
            "You must install pydot (`pip install pydot`) "
            "for `plot_model` to work."
        )
        return
    if not check_graphviz():
        _report_missing_dependency(
            "You must install graphviz "
            "(see instructions at https://graphviz.gitlab.io/download/) "
            "for `plot_model` to work."
        )
        return

    if kwargs.pop("layer_range", None) is not None:
        raise ValueError("Argument `layer_range` is no longer supported.")
    if kwargs:
        raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

    dot = model_to_dot(
        model,
        show_shapes=show_shapes,
        show_dtype=show_dtype,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=expand_nested,
        dpi=dpi,
        show_layer_activations=show_layer_activations,
        show_trainable=show_trainable,
    )
    to_file = str(to_file)
    if dot is None:
        return
    _, extension = os.path.splitext(to_file)
    # Drop the leading dot of the extension; default to PNG without one.
    extension = extension[1:] if extension else "png"
    # Save the image to disk.
    dot.write(to_file, format=extension)
    # Hand the image back for in-line notebook display. PDF output is
    # excluded (presumably because it cannot be shown as an image), and a
    # missing IPython simply means nothing is returned.
    if extension != "pdf":
        try:
            from IPython import display

            return display.Image(filename=to_file)
        except ImportError:
            pass
| |
|