import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
_wrap_all_tensors_to_functional,
functionalize,
)
from torch._higher_order_ops.cond import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
# TODO: We add this to prevent dynamo from tracing into map_wrapper;
# remove the wrapper call once dynamo is ready to handle it.
class MapWrapper(HigherOrderOperator):
    def __call__(self, f, xs, *args):
        return map_wrapper(f, xs, *args)
map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
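# create_joint below requires an AOTConfig, but the compiler/partitioner fields are
# unused when we only trace the joint graph, so a mostly-empty placeholder config
# suffices (an assumption based on how dummy_aot_config is used in this file).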
dummy_aot_config = AOTConfig(
fw_compiler=None,
bw_compiler=None,
partition_fn=None,
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
def create_fw_bw_graph(f, num_mapped_args, *args):
mapped_xs = args[:num_mapped_args]
pos_args = args[num_mapped_args:]
    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
    # between the Autograd and Python keys. Currently, we only suspend functionalization,
    # but more can be added when required. Without suspending functionalization, we run
    # into two problems:
    #
    # 1. make_fx fails to capture operations on inputs: the inputs are wrapped as
    # _to_functional_tensor_wrapper, but they are unwrapped before entering
    # ProxyTorchDispatchMode as part of dispatching. However, it is the outer wrapper
    # that the tracer creates proxies for, so the tracer fails to fetch the proxies
    # for the inputs and therefore fails to capture any operations on them.
    #
    # 2. make_fx fails to capture outputs: after ProxyTorchDispatchMode, the outputs are
    # further wrapped as FunctionalTensorWrapper in the Functionalize key on return.
    # However, the tracer only associates the inner tensor with a proxy in
    # ProxyTorchDispatchMode. Therefore, when creating the output node, it fails to
    # associate the wrapped tensor with its proxy and instead creates a
    # _tensor_constant as the output.
with suspend_functionalization():
with disable_proxy_modes_tracing():
def from_fun(t):
if isinstance(t, torch.Tensor):
if t.dtype != torch.bool:
return torch.empty_strided(
t.size(),
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
)
else:
return t.clone()
return t
example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
example_pos_args = [
from_fun(arg) if isinstance(arg, torch.Tensor) else arg
for arg in pos_args
]
example_flat_out = pytree.tree_map(
from_fun, f(*example_xs, *example_pos_args)
)
if any(
not isinstance(out, torch.Tensor)
for out in example_flat_out
if out is not None
):
raise RuntimeError(
"Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in example_flat_out]}."
)
example_grad = [from_fun(out) for out in example_flat_out]
fw_graph = make_fx(f)(*example_xs, *example_pos_args)
def joint_f(*example_args):
joint_mapped_args = example_args[:joint_num_mapped]
args = example_args[joint_num_mapped:]
mapped_input = joint_mapped_args[:num_mapped_args]
mapped_grads = joint_mapped_args[num_mapped_args:]
def fw_with_masks(*args):
fw_out = f(*args)
return fw_out, [
True
if isinstance(ret, torch.Tensor) and ret.requires_grad
else False
for ret in fw_out
]
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
_, grads = joint(
list(mapped_input) + list(args),
[
grad
for grad in mapped_grads
if grad is not None and grad.requires_grad
],
)
                # In order to keep map functional for the backward graph,
                # we clone outputs that alias inputs.
input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in example_args
if isinstance(arg, torch.Tensor)
}
def maybe_clone(t):
if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone()
return t
return pytree.tree_map(maybe_clone, grads)
joint_num_mapped = len(example_grad) + len(example_xs)
joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
return fw_graph, joint_graph
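# A rough sketch (an illustration, not an API contract) of how the two graphs
# returned by create_fw_bw_graph line up for a single mapped slice:
#
#     fw_graph:    (*x_slices, *pos_args)                    -> flat outputs
#     joint_graph: (*x_slices, *grad_out_slices, *pos_args)  -> grads w.r.t. (*x_slices, *pos_args)
#
# map_impl later runs joint_graph over the leading dimension, so the backward of map
# is itself expressed as a map over per-slice joint computations (see MapAutogradOp).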
def map_wrapper(f, xs, *args):
flat_xs, xs_spec = pytree.tree_flatten(xs)
if not all(isinstance(t, torch.Tensor) for t in flat_xs):
raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")
num_mapped_args = len(flat_xs)
shapes = [xs.shape for xs in flat_xs]
leading_dim_size = shapes[0][0]
if leading_dim_size == 0:
raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
raise RuntimeError(
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
)
out_spec = None
def flat_fn(*flat_args):
xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
unflattened_out = f(xs, *flat_args[num_mapped_args:])
flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)
nonlocal out_spec
out_spec = tmp_out_spec
return flat_out
return pytree.tree_unflatten(
map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
)
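# Usage sketch (illustrative only; assumes eager execution and the `map` wrapper above):
#
#     def body(x, y):
#         return x.sin() + y
#
#     xs = torch.randn(4, 3)   # the leading dimension of size 4 is mapped over
#     y = torch.ones(3)        # positional args are passed unchanged to every slice
#     out = map(body, xs, y)   # semantically torch.stack([body(x, y) for x in xs])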
class MapAutogradOp(torch.autograd.Function):
@staticmethod
def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
ctx.save_for_backward(*flat_args)
ctx._joint_graph = joint_graph
ctx._num_mapped_args = num_mapped_args
with torch._C._AutoDispatchBelowAutograd():
return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
@staticmethod
def backward(ctx, *flat_grads):
fw_args = ctx.saved_tensors
fw_mapped_args = fw_args[: ctx._num_mapped_args]
pos_args = fw_args[ctx._num_mapped_args :]
grads = map_impl(
ctx._joint_graph,
ctx._num_mapped_args + len(flat_grads),
*fw_mapped_args,
*flat_grads,
*pos_args,
)
return None, None, None, *grads
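# Note: the three leading Nones returned by MapAutogradOp.backward correspond to the
# non-tensor inputs of forward (fw_graph, joint_graph, num_mapped_args); the remaining
# entries are the grads that map_impl computed by mapping joint_graph over the slices.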
def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
xs = list(args[:num_mapped])
pos_args = list(args[num_mapped:])
leading_dim_size = xs[0].shape[0]
example_input = _unstack_pytree(xs)[0]
body_graph = f
if not isinstance(body_graph, torch.fx.GraphModule):
body_graph = make_fx(body_graph)(*example_input, *pos_args)
with disable_proxy_modes_tracing():
example_outs = body_graph(*example_input, *pos_args)
def expand_tensor(t):
if isinstance(t, torch.Tensor):
return t.expand(leading_dim_size, *t.shape)
return t
expanded_outs = pytree.tree_map(expand_tensor, example_outs)
next_name = None
i = 0
while not next_name:
candidate = f"body_graph_{i}"
if hasattr(proxy_mode.tracer.root, candidate):
i += 1
else:
next_name = candidate
proxy_mode.tracer.root.register_module(next_name, body_graph)
node_args = (body_graph, num_mapped, *args)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="map_impl"
)
return track_tensor_tree(
expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
)
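# After trace_map runs, the enclosing FX graph roughly looks like the sketch below
# (illustrative; exact node names depend on the tracer state):
#
#     # body_graph_0 is registered as a submodule on the tracer root
#     %map_impl = call_function[target=map_impl](args=(%body_graph_0, num_mapped, %xs, ...))
#
# The expanded example outputs give the output proxy its correct shapes (with the
# leading dimension prepended) without actually running the body once per slice.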
def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(x, torch.Tensor) for x in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensors. Got {flat_xs}.")
    if not all(x.shape[0] == flat_xs[0].shape[0] for x in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have the same leading dimension size. "
            f"Got shapes {[x.shape for x in flat_xs]}."
        )
    a = zip(*flat_xs)
    pytrees = []
    for leaves in a:
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
return pytrees
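# Example sketch (illustrative): _unstack_pytree splits a pytree of stacked tensors
# into one pytree per slice of the leading dimension:
#
#     _unstack_pytree({"a": torch.zeros(2, 3), "b": torch.zeros(2)})
#     # -> [{"a": <shape (3,)>, "b": <0-d tensor>}, {"a": <shape (3,)>, "b": <0-d tensor>}]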
def _stack_pytree(pytrees):
flat_out = []
out_spec = None
for pt in pytrees:
flat_pt, out_spec = pytree.tree_flatten(pt)
flat_out.append(flat_pt)
b = zip(*flat_out)
stacked_out = []
for leaves in b:
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
stacked_out.append(torch.stack(leaves))
elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when the forward inputs
            # don't require grad. When we eagerly execute the backward graph, we need
            # to call _stack_pytree on its output, so we must handle None here.
stacked_out.append(None)
else:
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)
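# _stack_pytree is the inverse of _unstack_pytree: it stacks a list of per-slice pytrees
# back into one pytree whose tensor leaves regain the leading dimension, so (illustratively)
# _stack_pytree(_unstack_pytree(xs)) reproduces xs for an all-tensor pytree xs.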
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
xs = args[:num_mapped_args]
pos_args = args[num_mapped_args:]
pytrees = []
for inp in _unstack_pytree(xs):
pytrees.append(f(*inp, *pos_args))
return _stack_pytree(pytrees)
@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, num_mapped_args, *args):
fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
return flat_out
@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
mode = _get_current_dispatch_mode()
assert mode is not None, "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode:
if mode.enable_tracing:
return trace_map(mode, map_impl, f, num_mapped, *args)
else:
return map_impl(f, num_mapped, *args)
@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, num_mapped, *args):
return map_dense(f, num_mapped, *args)
@map_impl.py_impl(DispatchKey.Functionalize)
def map_func(f, num_mapped, *args):
reapply_views = torch._C._functionalization_reapply_views_tls()
xs = args[:num_mapped]
pos_args = args[num_mapped:]
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(
pos_args, reapply_views=reapply_views
)
mode = "mutations_and_views" if reapply_views else "mutations"
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_map_fn = functionalize(f, remove=mode)
with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
map_return = map_impl(
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
)
return _wrap_all_tensors_to_functional(map_return, level=0)
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
"""
Functionalization implementation for torch.map. Currently:
1. We don't allow any input mutation inside the map function
2. Our check for above condition is not exhaustive
"""
xs = args[:num_mapped]
pos_args = args[num_mapped:]
reapply_views = interpreter.functionalize_add_back_views()
mode = "mutations_and_views" if reapply_views else "mutations"
# At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(
pos_args, reapply_views=reapply_views
)
functional_map_fn = functionalize(f, remove=mode)
with interpreter.lower():
with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
map_return = map_impl(
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
)
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
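# Both functionalization entries above follow the same recipe: unwrap the functional
# tensor wrappers, reject bodies that would mutate or alias their inputs, re-dispatch
# to map_impl on the unwrapped tensors with a functionalized body, then re-wrap the
# results at the appropriate level.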
# TODO(voz): Make this automatic for keys; this is very ugly atm.
map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
map_impl.fallthrough(DispatchKey.ADInplaceOrView)
map_impl.fallthrough(DispatchKey.BackendSelect)
map_impl.fallthrough(DispatchKey.AutocastCPU)