# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast

import numpy as np
import torch
from torch.nn.utils.rnn import PackedSequence

"""Useful functions to deal with tensor types with other Python container types."""


def apply_to_type(
    type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set, NamedTuple]
) -> Any:
    """Recursively apply ``fn`` to all objects in different kinds of container types that match ``type_fn``."""

    def _apply(x: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]) -> Any:
        if type_fn(x):
            return fn(x)
        elif isinstance(x, OrderedDict):
            # Rebuild with the same (sub)class to preserve ordering semantics.
            od = x.__class__()
            for key, value in x.items():
                od[key] = _apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # Apply to the underlying data for its side effect; the sequence
            # object itself is returned unchanged.
            _apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(value) for value in x]
        elif isinstance(x, tuple):
            f = getattr(x, "_fields", None)
            if f is None:
                return tuple(_apply(value) for value in x)
            else:
                assert isinstance(f, tuple), "This needs to be a namedtuple"
                # Convert the namedtuple to a dict, _apply() to its values,
                # then rebuild a copy of the namedtuple.
                x = cast(NamedTuple, x)
                _dict: Dict[str, Any] = x._asdict()
                _dict = {key: _apply(value) for key, value in _dict.items()}
                return type(x)(**_dict)
        elif isinstance(x, set):
            return {_apply(value) for value in x}
        else:
            return x

    return _apply(container)
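
# Illustrative example (a sketch, not part of the original file): move every
# tensor in a nested container to CPU. The container shape is preserved because
# dicts, lists, tuples, sets, and namedtuples are rebuilt around the results.
#
#   batch = {"x": torch.ones(2), "meta": [torch.zeros(1), "label"]}
#   cpu_batch = apply_to_type(torch.is_tensor, lambda t: t.cpu(), batch)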


def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Recursively apply ``fn`` to all tensors in different kinds of container types."""
    return apply_to_type(torch.is_tensor, fn, container)
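
# Illustrative example (hypothetical usage): detach every tensor in a tuple of
# outputs, leaving the non-tensor entries alone.
#
#   outputs = (torch.ones(2, requires_grad=True), {"loss_scale": 1.0})
#   detached = apply_to_tensors(lambda t: t.detach(), outputs)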


def to_np(tensor_or_container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Convert a tensor or a container of tensors to numpy."""
    return apply_to_type(torch.is_tensor, lambda x: x.cpu().numpy(), tensor_or_container)


def from_np(ndarray_or_container: Union[np.ndarray, Dict, List, Tuple, Set]) -> Any:
    """Convert an ndarray or a container of ndarrays to tensor."""
    return apply_to_type(lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x), ndarray_or_container)
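
# Illustrative round trip (a sketch, not from the original file): tensors become
# numpy arrays and back; non-tensor leaves such as strings pass through untouched.
#
#   data = {"a": torch.arange(3), "b": (torch.zeros(2), "tag")}
#   as_np = to_np(data)    # tensors -> np.ndarray (moved to CPU first)
#   back = from_np(as_np)  # ndarrays -> CPU tensors again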


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn an argument list into separate key list and value list (unpack_kwargs does the opposite).

    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(kwarg_keys), tuple(flat_args)


def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See pack_kwargs."""
    assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])}
    return args, kwargs


def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Split a tuple into a tuple of tensors and the rest, with the information
    needed for later reconstruction.

    When called with a single tensor x, returns ((x,), None).

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors: List[torch.Tensor] = []
    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensors


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """See split_non_tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
    mixed: List[Any] = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]
    assert len(tensors) + len(objects) == len(is_tensor_list), (
        f"len(tensors) {len(tensors)} len(objects) {len(objects)} len(is_tensor_list) {len(is_tensor_list)}"
    )
    obj_i = tnsr_i = 0
    # Interleave tensors and non-tensor objects back into their original order.
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)
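

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original file): flatten a mixed
    # argument list into tensors plus metadata, then reconstruct it. This is the
    # kind of round trip these helpers are designed for.
    kwarg_keys, flat_args = pack_kwargs(torch.ones(1), mode="sum")
    tensors, packed = split_non_tensors(flat_args)
    args, kwargs = unpack_kwargs(kwarg_keys, unpack_non_tensors(tensors, packed))
    assert kwargs == {"mode": "sum"}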