import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Union

import sympy
import torch
from torch import SymInt
import torch.fx as fx
import torch.nn as nn
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)
import torch.utils._pytree as pytree


log = logging.getLogger(__name__)


# These canonicalizations are needed here (and not as decompositions), as the
# ops we're trying to canonicalize to are CompositeImplicitAutograd.
def _canonicalize(fx_g):
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten._to_copy:
            node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g


@contextmanager
def _disable_jit_autocast():
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)


@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """

    Compiles the :attr:`fx_g` with Torchscript compiler.



    .. warning::

        This API is experimental and likely to change.



    Args:

        fx_g(fx.GraphModule): The input Fx graph module to be compiled.



    Returns:

        Torch scripted model.

    """

    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.nodes:
            if (
                node.target == torch.ops.aten._to_copy
                and len(node.args) == 1
                and len(node.kwargs) == 1
                and "dtype" in node.kwargs
            ):
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()

        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f
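
# A minimal usage sketch (illustrative; `fn` is a made-up example): ts_compile
# is typically handed to aot_function as the forward/backward compiler rather
# than called directly.
#
#   from functorch.compile import aot_function
#
#   def fn(x):
#       return torch.sin(x) + x
#
#   compiled_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
#   out = compiled_fn(torch.randn(8, requires_grad=True))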


def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    return make_boxed_compiler(
        partial(_draw_graph_compile, name=name)
    )


@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """

    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler

    and can be used to check accuracy.



    .. warning::

        This API is experimental and likely to change.



    """
    return fx_g
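
# A usage sketch (illustrative; `fn` and `x` are hypothetical): compiling with
# `nop` on both sides runs the traced graphs as-is, giving an eager-accuracy
# baseline to compare other compilers against.
#
#   baseline_fn = aot_function(fn, fw_compiler=nop, bw_compiler=nop)
#   assert torch.allclose(baseline_fn(x), fn(x))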

class DebugInterpreter(fx.Interpreter):
    def run(self, *args):
        self.symbol_mapping = bind_symbols(self.module, *args)
        return super().run(*args)

    def run_node(self, n):

        def subst_symint(ni):
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert r.is_number, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if subst_symint(a.stride(idx)) != b.stride(idx) and subst_symint(a.size(idx)) > 1:
                        return False
            return True

        def check(nv, rv, desc):
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert subst_symint_tuple(nv.size()) == rv.size(), \
                f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            same_strides = check_significant_strides(nv, rv)
            assert same_strides, f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"

        r = super().run_node(n)
        if 'val' in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta['val'])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here, the error itself is
            # harmless enough as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r


@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """

    Returns a (slow) interpreter over the FX graph module that also checks

    various debugging properties (e.g., that tracing strides matched real

    strides.)

    """
    return DebugInterpreter(fx_g).run
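
# A usage sketch (illustrative; `fn` is hypothetical): with debug_nop, every
# call re-runs the graph through DebugInterpreter, which asserts that the
# dtypes, sizes, and strides recorded at trace time match the real outputs.
#
#   checked_fn = aot_function(fn, fw_compiler=debug_nop, bw_compiler=debug_nop)
#   checked_fn(torch.randn(5, 3))  # raises AssertionError on a metadata mismatch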

@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f):
    return aot_function(f, simple_ts_compile)
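
# A usage sketch (illustrative): nnc_jit is a one-call wrapper that traces a
# function through AOT Autograd and scripts it with TorchScript.
#
#   pointwise = nnc_jit(lambda x: torch.sin(x) + x)
#   pointwise(torch.randn(16))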


aten = torch.ops.aten
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}

default_decompositions = get_decompositions(default_decompositions)
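
# A sketch of how this table is consumed (illustrative; `fn` is hypothetical):
# AOT Autograd entry points accept it through the `decompositions` kwarg, so
# the ops above are decomposed before the graph reaches the compiler.
#
#   compiled = aot_function(fn, fw_compiler=nop, decompositions=default_decompositions)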


@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g


def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    **kwargs,
):
    """

    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform

    memory efficient fusion. It uses the

    :func:`min_cut_rematerialization_partition` partitioner to perform efficient

    recomputation. It uses NVFuser to compile the generated forward and backward

    graphs.



    .. warning::

        This API is experimental and likely to change.



    Args:

        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``

            that takes one ore more arguments. Must return one or more Tensors.

        **kwargs: Any other overrides you want to make to the settings



    Returns:

        Returns a ``Callable``  or ``nn.Module`` that retains the eager behavior

        of the original :attr:`fn`, but whose forward and backward graphs have

        gone through recomputation optimizations, and the graphs have been

        compiled with nvfuser.



    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)
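
# A minimal usage sketch (illustrative; assumes a CUDA device, since the
# TorchScript/NVFuser path used by ts_compile targets CUDA):
#
#   def fn(x, bias):
#       return torch.nn.functional.gelu(x + bias)
#
#   fused_fn = memory_efficient_fusion(fn)
#   x = torch.randn(8, 8, device="cuda", requires_grad=True)
#   bias = torch.randn(8, device="cuda", requires_grad=True)
#   fused_fn(x, bias).sum().backward()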


def debug_compile(fx_g, inps):
    fx_g.to_folder("foo")
    print(
        f"""

##############################################################

# To minimize FX graph, copy and paste the below and run it  #

##############################################################



import torch

import torch.fx as fx

from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess



inps = {[(i.shape, i.dtype) for i in inps]}

inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]

from foo import FxModule

mod = FxModule().cuda()



with torch.jit.fuser("fuser2"):

  # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess

  minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)

"""
    )
    from foo import FxModule

    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)


graph_index = 0


def get_inputs(input_data_path):
    """

    Return a random input for the given inputs meta generated from _save_fx_default.

    """
    inputs = []
    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)
        for meta in inputs_meta:
            if len(meta) == 1:
                # Scalar input: meta is just (type,)
                type = meta[0]
                input = type(random.random())
            else:
                type, shape, stride, dtype, device = meta
                if dtype in {
                    torch.int,
                    torch.int32,
                    torch.int64,
                    torch.bool,
                    torch.uint8,
                    int,
                    float,
                }:
                    input = torch.randint(0, 1, shape, dtype=dtype, device=device)
                else:
                    input = torch.rand(shape, dtype=dtype, device=device)
            inputs.append(input)
    return inputs
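
# A usage sketch (illustrative; the path below is hypothetical but follows the
# layout written by _save_fx_default / graph_dumper_aot):
#
#   inps = get_inputs("dumps/mymodel/mymodel_forward_0/mymodel_forward_0.input")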


def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """

    The forward, backward, and joint computation graph will be stored in

    {folder_name}/{current_name}/{current_name}_forward_{graph_index},

    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and

    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.

    The input shape of the graphs will be stored in the .input files.

    These files can be loaded with pickle,

    and is a list of format (type, shape, stride, dtype, device).

    In the case of type = int or float, it is just (type,).

    For joint graph input, it is a nested list [[],[]]

    where the two inner lists have the same format.

    If dump_example_input is True, example_inputs will be stored in .pt file.

    Since each function might produce multiple graphs,

    the graph_index is used to distinguish difference graphs

    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta

    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.warning(
                "No nodes in graph %s_%s_%s.",
                current_name,
                type_name,
                graph_index,
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
        gm.to_folder(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
        )
        with open(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
            "wb",
        ) as f:
            pickle.dump(input_meta, f)
        if dump_example_input:
            torch.save(
                args,
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
            )  # noqa: E501

    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args):
        graph_saver_helper(gm, joint_args, "joint")
        return default_partition(gm, joint_args)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )


# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """

    Dump the forward, backward, and joint computation graph.

    Example Usage:

    save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False)

    optimize_ctx = torchdynamo.optimize(

        save_fx_func

    )

    with torch.enable_grad():

        with optimize_ctx:

            result = forward_and_backward_pass(model, example_inputs)

    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)