# Copyright (c) Facebook, Inc. and its affiliates.

import os
import torch

from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager

from .torchscript_patch import freeze_training_mode, patch_instances

__all__ = ["scripting_with_instances", "dump_torchscript_IR"]


def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for scripting to support this behavior out of the box. This function is made to support
    scripting a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes being "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force the scripting compiler to
       use it when trying to compile ``Instances``.

    After this function returns, the patching is reverted, so the user is able to script
    another model using a different set of fields.

    Example:
        Assume that ``Instances`` in the model consists of two attributes named
        ``proposal_boxes`` and ``objectness_logits``, with types :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only supports models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and their corresponding types that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data types not defined in detectron2 are not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format
    """
    assert TORCH_VERSION >= (1, 8), "This feature is not available in PyTorch < 1.8"
    assert (
        not model.training
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    with freeze_training_mode(model), patch_instances(fields):
        scripted_model = torch.jit.script(model)
        return scripted_model


# alias for old name
export_torchscript_with_instances = scripting_with_instances
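
# Example usage (a sketch; `cfg` is the caller's config object, and the field
# names below are illustrative, matching the docstring example above):
#
#   from detectron2.modeling import build_model
#   from detectron2.structures import Boxes
#
#   model = build_model(cfg).eval()
#   fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
#   ts_model = scripting_with_instances(model, fields)
#   ts_model.save("model.ts")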


def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/ScriptFunction in various formats
    (code, graph, inlined graph). Useful for debugging.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files.
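
    Example:
        A minimal usage sketch (the directory name is illustrative):
        ::
            ts_model = scripting_with_instances(model, fields)
            dump_torchscript_IR(ts_model, "./torchscript_dump")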
    """
    PathManager.mkdirs(dir)
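
    # The dump produces up to four files under `dir`:
    #   model_ts_code.txt       - pretty-printed TorchScript code of each submodule
    #   model_ts_IR.txt         - IR graph, dumped recursively per module
    #   model_ts_IR_inlined.txt - IR of the entire graph with all submodules inlined
    #   model.txt               - PyTorch-style repr of the model structure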

    def _get_script_mod(mod):
        # A TracedModule wraps the underlying ScriptModule in `_actual_script_module`;
        # unwrap it so the private `_c` APIs used below are reachable.
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try a few ways to get code using private attributes.
            try:
                # This contains more information than just `mod.code`
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                f.write(f"Could not found code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            for name, m in mod.named_children():
                dump_code(prefix + "." + name, m)

        if isinstance(model, torch.jit.ScriptFunction):
            f.write(get_code(model))
        else:
            dump_code("", model)

    def _get_graph(model):
        try:
            # Recursively dump IR of all modules; the boolean flags of the private
            # `dump_to_str` select how much detail (code/attrs/params) is printed.
            return _get_script_mod(model)._c.dump_to_str(True, False, False)
        except AttributeError:
            # ScriptFunction has no `_c`; fall back to its graph.
            return model.graph.str()

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))