| |
|
|
| import os |
| import torch |
|
|
| from detectron2.utils.file_io import PathManager |
|
|
| from .torchscript_patch import freeze_training_mode, patch_instances |
|
|
| __all__ = ["scripting_with_instances", "dump_torchscript_IR"] |
|
|
|
|
def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for scripting to support it out of the box. This function is made to support scripting
    a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes being "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force the scripting compiler to
       use it when trying to compile ``Instances``.

    After this function returns, the patching is reverted, so the user should be able to
    script another model using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only supports models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding type that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data type not defined in detectron2 is not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format

    Raises:
        AssertionError: if ``model.training`` is truthy.
    """
    # Raise explicitly rather than using ``assert`` so that the check is not
    # stripped when Python runs with ``-O``. AssertionError is kept so callers
    # observe the same exception type and message as before.
    if model.training:
        raise AssertionError(
            "Currently we only support exporting models in evaluation mode to torchscript"
        )

    # Both context managers are active only for the duration of the scripting call.
    with freeze_training_mode(model), patch_instances(fields):
        return torch.jit.script(model)


# Second name bound to the same function (presumably kept for older callers;
# note that it is not listed in ``__all__``).
export_torchscript_with_instances = scripting_with_instances
|
|
|
|
def dump_torchscript_IR(model, dir):
    """
    Write several textual views of a traced or scripted model into a directory.
    Useful for debugging.

    The following files are created inside ``dir``:

    * ``model_ts_code.txt``: code of the model and, for modules, of every
      child module found recursively through ``named_children()``.
    * ``model_ts_IR.txt``: the graph of the model.
    * ``model_ts_IR_inlined.txt``: ``str(model.inlined_graph)``.
    * ``model.txt``: ``str(model)``; skipped when ``model`` is a
      ``torch.jit.ScriptFunction``.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files. ``~`` is expanded and the
            directory is created through ``PathManager.mkdirs``.
    """
    out_dir = os.path.expanduser(dir)
    PathManager.mkdirs(out_dir)

    def _unwrap(module):
        # A TracedModule keeps the underlying script module in a private attribute.
        if not isinstance(module, torch.jit.TracedModule):
            return module
        return module._actual_script_module

    def _lookup_code(module):
        # Try the private ``_c.code`` first, then the public ``.code``.
        # Returns None when neither attribute is available.
        attempts = (
            lambda: _unwrap(module)._c.code,
            lambda: module.code,
        )
        for fetch in attempts:
            try:
                return fetch()
            except AttributeError:
                continue
        return None

    def _write_code(stream, prefix, module):
        # Emit the code of ``module`` and then recurse into its children,
        # extending the dotted prefix with each child's name.
        source = _lookup_code(module)
        label = prefix if prefix else "root model"
        if source is not None:
            stream.write(f"\nCode for {label}, type={module.original_name}:\n")
            stream.write(source)
            stream.write("\n")
            stream.write("-" * 80)
        else:
            stream.write(f"Could not found code for {label} (type={module.original_name})\n")
            stream.write("\n")

        for child_name, child in module.named_children():
            _write_code(stream, prefix + "." + child_name, child)

    def _render_graph(module):
        # Prefer the private dump routine; fall back to the public graph
        # when that attribute chain is not available.
        try:
            return _unwrap(module)._c.dump_to_str(True, False, False)
        except AttributeError:
            return module.graph.str()

    # Pretty-printed code.
    code_path = os.path.join(out_dir, "model_ts_code.txt")
    with PathManager.open(code_path, "w") as stream:
        if not isinstance(model, torch.jit.ScriptFunction):
            _write_code(stream, "", model)
        else:
            # A function has no children, so only its own code is written.
            stream.write(_lookup_code(model))

    # Graph IR.
    ir_path = os.path.join(out_dir, "model_ts_IR.txt")
    with PathManager.open(ir_path, "w") as stream:
        stream.write(_render_graph(model))

    # Inlined graph IR.
    inlined_path = os.path.join(out_dir, "model_ts_IR_inlined.txt")
    with PathManager.open(inlined_path, "w") as stream:
        stream.write(str(model.inlined_graph))

    # String representation of the model itself; not written for functions.
    if not isinstance(model, torch.jit.ScriptFunction):
        structure_path = os.path.join(out_dir, "model.txt")
        with PathManager.open(structure_path, "w") as stream:
            stream.write(str(model))
|
|