| | |
| |
|
| | import os |
| | import torch |
| |
|
| | from detectron2.utils.file_io import PathManager |
| |
|
| | from .torchscript_patch import freeze_training_mode, patch_instances |
| |
|
# Public API of this module: the names exported by ``from <this module> import *``.
__all__ = ["scripting_with_instances", "dump_torchscript_IR"]
| |
|
| |
|
def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for scripting to support it out of the box. This function is made to support scripting
    a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes being "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force the scripting compiler to
       use it when trying to compile ``Instances``.

    After this function returns, the process is reverted, so the user should be able to
    script another model using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::

            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only supports models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding type that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data type not defined in detectron2 is not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format

    Raises:
        AssertionError: if ``model.training`` is True.
    """
    # An explicit check instead of ``assert`` so the guard is not stripped under
    # ``python -O``. AssertionError is kept so callers see the same exception type.
    if model.training:
        raise AssertionError(
            "Currently we only support exporting models in evaluation mode to torchscript"
        )

    # Both helpers are used as context managers, so their effect is scoped to the
    # scripting call below.
    with freeze_training_mode(model), patch_instances(fields):
        scripted_model = torch.jit.script(model)
        return scripted_model
| |
|
| |
|
# Second name bound to ``scripting_with_instances``; intentionally not listed in ``__all__``.
# NOTE(review): presumably kept for backward compatibility with an older API name — confirm.
export_torchscript_with_instances = scripting_with_instances
| |
|
| |
|
def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/ScriptFunction in various formats (code, graph,
    inlined graph). Useful for debugging.

    The following files are written into ``dir``:

    * ``model_ts_code.txt``: TorchScript code of the model. For modules, the code of
      every child module is dumped recursively, one section per module.
    * ``model_ts_IR.txt``: the IR graph of the model.
    * ``model_ts_IR_inlined.txt``: ``str(model.inlined_graph)``.
    * ``model.txt``: ``str(model)``; only written when ``model`` is not a ScriptFunction.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
            or function.
        dir (str): output directory to dump files. ``~`` is expanded and the directory
            is created via ``PathManager.mkdirs``.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        # A TracedModule wraps its ScriptModule in ``_actual_script_module``; unwrap it
        # so the private ``_c`` accessors below also work for traced models.
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try the private C++ object first, then the public ``code`` property.
            # Returns None when neither is available.
            try:
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                # Leading "\n": the previous section ends with the "-" * 80 separator
                # and no newline, so without it this message would be glued onto it.
                f.write(f"\nCould not find code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            # Recurse into children; a distinct loop name avoids shadowing ``name``.
            for child_name, child in mod.named_children():
                dump_code(prefix + "." + child_name, child)

        if isinstance(model, torch.jit.ScriptFunction):
            code = get_code(model)
            # f.write(None) would raise TypeError; record the failure instead.
            if code is None:
                f.write("Could not find code for the script function\n")
            else:
                f.write(code)
        else:
            dump_code("", model)

    def _get_graph(mod):
        try:
            # NOTE(review): private ``dump_to_str`` with positional boolean flags;
            # presumably dumps the module and its submodules — verify against torch.
            return _get_script_mod(mod)._c.dump_to_str(True, False, False)
        except AttributeError:
            return mod.graph.str()

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Module structure as printed by ``str(model)``.
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))
| |
|