geored commited on
Commit
8b52ed6
1 Parent(s): 5c1306e

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. gtm/bin/convert-caffe2-to-onnx +8 -0
  3. gtm/bin/convert-onnx-to-caffe2 +8 -0
  4. gtm/bin/isympy +8 -0
  5. gtm/bin/torchrun +8 -0
  6. gtm/lib/python3.12/site-packages/__pycache__/isympy.cpython-312.pyc +0 -0
  7. gtm/lib/python3.12/site-packages/functorch/_C.cpython-312-darwin.so +0 -0
  8. gtm/lib/python3.12/site-packages/functorch/__init__.py +38 -0
  9. gtm/lib/python3.12/site-packages/functorch/__pycache__/__init__.cpython-312.pyc +0 -0
  10. gtm/lib/python3.12/site-packages/functorch/_src/__init__.py +0 -0
  11. gtm/lib/python3.12/site-packages/functorch/_src/__pycache__/__init__.cpython-312.pyc +0 -0
  12. gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__init__.py +8 -0
  13. gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-312.pyc +0 -0
  14. gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__init__.py +7 -0
  15. gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-312.pyc +0 -0
  16. gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__init__.py +4 -0
  17. gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-312.pyc +0 -0
  18. gtm/lib/python3.12/site-packages/functorch/_src/vmap/__init__.py +16 -0
  19. gtm/lib/python3.12/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-312.pyc +0 -0
  20. gtm/lib/python3.12/site-packages/functorch/compile/__init__.py +31 -0
  21. gtm/lib/python3.12/site-packages/functorch/compile/__pycache__/__init__.cpython-312.pyc +0 -0
  22. gtm/lib/python3.12/site-packages/functorch/dim/__init__.py +179 -0
  23. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/__init__.cpython-312.pyc +0 -0
  24. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-312.pyc +0 -0
  25. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-312.pyc +0 -0
  26. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/dim.cpython-312.pyc +0 -0
  27. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/magic_trace.cpython-312.pyc +0 -0
  28. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/op_properties.cpython-312.pyc +0 -0
  29. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/reference.cpython-312.pyc +0 -0
  30. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/tree_map.cpython-312.pyc +0 -0
  31. gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/wrap_type.cpython-312.pyc +0 -0
  32. gtm/lib/python3.12/site-packages/functorch/dim/batch_tensor.py +25 -0
  33. gtm/lib/python3.12/site-packages/functorch/dim/delayed_mul_tensor.py +77 -0
  34. gtm/lib/python3.12/site-packages/functorch/dim/dim.py +110 -0
  35. gtm/lib/python3.12/site-packages/functorch/dim/magic_trace.py +42 -0
  36. gtm/lib/python3.12/site-packages/functorch/dim/op_properties.py +311 -0
  37. gtm/lib/python3.12/site-packages/functorch/dim/reference.py +645 -0
  38. gtm/lib/python3.12/site-packages/functorch/dim/tree_map.py +14 -0
  39. gtm/lib/python3.12/site-packages/functorch/dim/wrap_type.py +71 -0
  40. gtm/lib/python3.12/site-packages/functorch/einops/__init__.py +3 -0
  41. gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/__init__.cpython-312.pyc +0 -0
  42. gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/_parsing.cpython-312.pyc +0 -0
  43. gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/rearrange.cpython-312.pyc +0 -0
  44. gtm/lib/python3.12/site-packages/functorch/einops/_parsing.py +302 -0
  45. gtm/lib/python3.12/site-packages/functorch/einops/rearrange.py +207 -0
  46. gtm/lib/python3.12/site-packages/functorch/experimental/__init__.py +6 -0
  47. gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/__init__.cpython-312.pyc +0 -0
  48. gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/control_flow.cpython-312.pyc +0 -0
  49. gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/ops.cpython-312.pyc +0 -0
  50. gtm/lib/python3.12/site-packages/functorch/experimental/control_flow.py +8 -0
.gitattributes CHANGED
@@ -55,3 +55,12 @@ gtm/lib/python3.12/site-packages/pandas/_libs/join.cpython-312-darwin.so filter=
55
  gtm/lib/python3.12/site-packages/pandas/_libs/tslibs/offsets.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
56
  gtm/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
57
  gtm/lib/python3.12/site-packages/safetensors/_safetensors_rust.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
55
  gtm/lib/python3.12/site-packages/pandas/_libs/tslibs/offsets.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
56
  gtm/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
57
  gtm/lib/python3.12/site-packages/safetensors/_safetensors_rust.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
58
+ gtm/lib/python3.12/site-packages/sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
59
+ gtm/lib/python3.12/site-packages/torch/.dylibs/libiomp5.dylib filter=lfs diff=lfs merge=lfs -text
60
+ gtm/lib/python3.12/site-packages/torch/bin/protoc filter=lfs diff=lfs merge=lfs -text
61
+ gtm/lib/python3.12/site-packages/torch/bin/protoc-3.13.0.0 filter=lfs diff=lfs merge=lfs -text
62
+ gtm/lib/python3.12/site-packages/torch/lib/libiomp5.dylib filter=lfs diff=lfs merge=lfs -text
63
+ gtm/lib/python3.12/site-packages/torch/lib/libtorch_cpu.dylib filter=lfs diff=lfs merge=lfs -text
64
+ gtm/lib/python3.12/site-packages/torch/lib/libtorch_python.dylib filter=lfs diff=lfs merge=lfs -text
65
+ gtm/lib/python3.12/site-packages/torchvision/.dylibs/libc++.1.0.dylib filter=lfs diff=lfs merge=lfs -text
66
+ gtm/lib/python3.12/site-packages/torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
gtm/bin/convert-caffe2-to-onnx ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
2
+ # -*- coding: utf-8 -*-
3
+ import re
4
+ import sys
5
+ from caffe2.python.onnx.bin.conversion import caffe2_to_onnx
6
+ if __name__ == '__main__':
7
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8
+ sys.exit(caffe2_to_onnx())
gtm/bin/convert-onnx-to-caffe2 ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
2
+ # -*- coding: utf-8 -*-
3
+ import re
4
+ import sys
5
+ from caffe2.python.onnx.bin.conversion import onnx_to_caffe2
6
+ if __name__ == '__main__':
7
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8
+ sys.exit(onnx_to_caffe2())
gtm/bin/isympy ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
2
+ # -*- coding: utf-8 -*-
3
+ import re
4
+ import sys
5
+ from isympy import main
6
+ if __name__ == '__main__':
7
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8
+ sys.exit(main())
gtm/bin/torchrun ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
2
+ # -*- coding: utf-8 -*-
3
+ import re
4
+ import sys
5
+ from torch.distributed.run import main
6
+ if __name__ == '__main__':
7
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8
+ sys.exit(main())
gtm/lib/python3.12/site-packages/__pycache__/isympy.cpython-312.pyc ADDED
Binary file (11 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/_C.cpython-312-darwin.so ADDED
Binary file (150 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import torch
7
+
8
+ from torch._functorch.deprecated import (
9
+ combine_state_for_ensemble,
10
+ functionalize,
11
+ grad,
12
+ grad_and_value,
13
+ hessian,
14
+ jacfwd,
15
+ jacrev,
16
+ jvp,
17
+ make_functional,
18
+ make_functional_with_buffers,
19
+ vjp,
20
+ vmap,
21
+ )
22
+
23
+ # utilities. Maybe these should go in their own namespace in the future?
24
+ from torch._functorch.make_functional import (
25
+ FunctionalModule,
26
+ FunctionalModuleWithBuffers,
27
+ )
28
+
29
+ # Top-level APIs. Please think carefully before adding something to the
30
+ # top-level namespace:
31
+ # - private helper functions should go into torch._functorch
32
+ # - very experimental things should go into functorch.experimental
33
+ # - compilation related things should go into functorch.compile
34
+
35
+ # Was never documented
36
+ from torch._functorch.python_key import make_fx
37
+
38
+ __version__ = torch.__version__
gtm/lib/python3.12/site-packages/functorch/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (755 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/_src/__init__.py ADDED
File without changes
gtm/lib/python3.12/site-packages/functorch/_src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (194 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # This file has moved to under torch/_functorch. It is not public API.
2
+ # If you are not a PyTorch developer and you are relying on the following
3
+ # imports, please file an issue.
4
+ from torch._functorch.aot_autograd import (
5
+ aot_autograd_decompositions,
6
+ KNOWN_TYPES,
7
+ PytreeThunk,
8
+ )
gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (342 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # This file has moved to under torch/_functorch. It is not public API.
2
+ # If you are not a PyTorch developer and you are relying on the following
3
+ # imports, please file an issue.
4
+ from torch._functorch.eager_transforms import (
5
+ _assert_wrapped_functional,
6
+ _unwrap_functional_tensor,
7
+ )
gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (341 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # This file has moved to under torch/_functorch. It is not public API.
2
+ # If you are not a PyTorch developer and you are relying on the following
3
+ # imports, please file an issue.
4
+ from torch._functorch.make_functional import _swap_state
gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (283 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/_src/vmap/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file has moved to under torch/_functorch. It is not public API.
2
+ # If you are not a PyTorch developer and you are relying on the following
3
+ # imports, please file an issue.
4
+ from torch._functorch.vmap import (
5
+ _add_batch_dim,
6
+ _broadcast_to_and_flatten,
7
+ _create_batched_inputs,
8
+ _get_name,
9
+ _process_batched_inputs,
10
+ _remove_batch_dim,
11
+ _unwrap_batched,
12
+ _validate_and_get_batch_size,
13
+ Tensor,
14
+ tree_flatten,
15
+ tree_unflatten,
16
+ )
gtm/lib/python3.12/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (560 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/compile/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch._functorch import config
2
+ from torch._functorch.aot_autograd import (
3
+ aot_function,
4
+ aot_module,
5
+ aot_module_simplified,
6
+ compiled_function,
7
+ compiled_module,
8
+ get_aot_compilation_context,
9
+ get_aot_graph_name,
10
+ get_graph_being_compiled,
11
+ make_boxed_compiler,
12
+ make_boxed_func,
13
+ )
14
+ from torch._functorch.compilers import (
15
+ debug_compile,
16
+ default_decompositions,
17
+ draw_graph_compile,
18
+ memory_efficient_fusion,
19
+ nnc_jit,
20
+ nop,
21
+ print_compile,
22
+ ts_compile,
23
+ )
24
+ from torch._functorch.fx_minifier import minifier
25
+ from torch._functorch.partitioners import (
26
+ default_partition,
27
+ draw_graph,
28
+ draw_joint_graph,
29
+ min_cut_rematerialization_partition,
30
+ )
31
+ from torch._functorch.python_key import pythonkey_decompose
gtm/lib/python3.12/site-packages/functorch/compile/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.15 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__init__.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dis
2
+ import inspect
3
+ from typing import Sequence, Union
4
+
5
+ import torch
6
+
7
+ import functorch._C
8
+ from functorch._C import dim as _C
9
+ from .tree_map import tree_flatten, tree_map
10
+ from .wrap_type import wrap_type
11
+
12
+ _C._patch_tensor_class()
13
+ dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
14
+
15
+
16
+ class DimensionMismatchError(Exception):
17
+ pass
18
+
19
+
20
+ class DimensionBindError(Exception):
21
+ pass
22
+
23
+
24
+ from . import op_properties
25
+
26
+ # use dict to avoid writing C++ bindings for set
27
+ pointwise = {t: True for t in op_properties.pointwise}
28
+
29
+ use_c = True
30
+ if not use_c:
31
+ from . import reference
32
+
33
+
34
+ class _Tensor:
35
+ # fast path around slow wrapping/unwrapping logic for simply queries used
36
+ # by the implementation...
37
+
38
+ @property
39
+ def dims(self):
40
+ return tuple(d for d in self._levels if isinstance(d, Dim))
41
+
42
+ def dim(self):
43
+ return self.ndim
44
+
45
+ if use_c:
46
+ __torch_function__ = classmethod(_C.__torch_function__)
47
+ expand = _C._instancemethod(_C.expand)
48
+ else:
49
+ __torch_function__ = reference.__torch_function__
50
+ expand = reference.expand
51
+
52
+ index = _C._instancemethod(_C.index)
53
+
54
+ def __repr__(self):
55
+ tensor, levels, ndim = self._tensor, self._levels, self.ndim
56
+ return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
57
+
58
+
59
+ TensorLike = (_Tensor, torch.Tensor)
60
+
61
+
62
+ class Dim(_C.Dim, _Tensor):
63
+ # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence.
64
+ # Tensor defines format, but we want to print Dims with special formatting
65
+ __format__ = object.__format__
66
+
67
+
68
+ class Tensor(_Tensor, _C.Tensor):
69
+ if not use_c:
70
+ from_batched = staticmethod(_C.Tensor_from_batched)
71
+ from_positional = staticmethod(_C.Tensor_from_positional)
72
+ sum = _C._instancemethod(_C.Tensor_sum)
73
+
74
+
75
+ def cat(tensors, dim, new_dim):
76
+ n = dims()
77
+ return stack(tensors, n, dim).index([n, dim], new_dim)
78
+
79
+
80
+ if use_c:
81
+ _wrap = _C._wrap
82
+
83
+ def _def(name, *args, **kwargs):
84
+ orig = getattr(torch.Tensor, name)
85
+ setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
86
+
87
+ t__getitem__ = _C._instancemethod(_C.__getitem__)
88
+ stack = _C.stack
89
+ split = _C._instancemethod(_C.split)
90
+ else:
91
+ _wrap, _def = reference._wrap, reference._def
92
+ t__getitem__ = reference.t__getitem__
93
+ stack = reference.stack
94
+ split = reference.split
95
+
96
+ # note: there is no python reference
97
+ t__setitem__ = _C._instancemethod(_C.__setitem__)
98
+ # this is patched in the C API because otherwise torch.Tensor will
99
+ # no longer be considered a sequence and things will break
100
+ # torch.Tensor.__getitem__ = t__getitem__
101
+
102
+ _Tensor.__getitem__ = t__getitem__
103
+ # torch.Tensor.__setitem__ = t__setitem__
104
+ _Tensor.__setitem__ = t__setitem__
105
+
106
+ torch.Tensor.split = split
107
+ _Tensor.split = split
108
+ torch.Tensor.expand = _C._instancemethod(_C.expand)
109
+ torch.Tensor.index = _C._instancemethod(_C.index)
110
+ wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
111
+ del _Tensor.ndim
112
+
113
+ if use_c:
114
+ _Tensor.order = _C._instancemethod(_C.order)
115
+ else:
116
+ _Tensor.order = reference.positional
117
+
118
+ _def("mean")
119
+ _def("sum")
120
+ _def("all")
121
+ _def("amax")
122
+ _def("amin")
123
+ _def("aminmax")
124
+ _def("any")
125
+ _def("count_nonzero")
126
+ _def("logsumexp")
127
+ _def("nanmean")
128
+ _def("nansum")
129
+ _def("prod")
130
+ _def("std", keepdim_offset=2)
131
+ _def("var", keepdim_offset=2)
132
+ _def("max", single_dim=True)
133
+ _def("min", single_dim=True)
134
+ _def("argmax", single_dim=True)
135
+ _def("argmin", single_dim=True)
136
+ _def("kthvalue", single_dim=True)
137
+ _def("median", single_dim=True)
138
+ _def("nanmedian", single_dim=True)
139
+ _def("mode", single_dim=True)
140
+ _def("sort", reduce=False)
141
+ _def("argsort", reduce=False)
142
+ _def("unbind", single_dim=True)
143
+ _def("chunk", dim_offset=1, reduce=False)
144
+ _def("cummax", single_dim=True, reduce=False)
145
+ _def("cummin", single_dim=True, reduce=False)
146
+ _def("cumprod", single_dim=True, reduce=False)
147
+ _def("cumprod_", single_dim=True, reduce=False)
148
+ _def("cumsum", single_dim=True, reduce=False)
149
+ _def("cumsum_", single_dim=True, reduce=False)
150
+ _def("logcumsumexp", single_dim=True, reduce=False)
151
+ _def("renorm", dim_offset=1, single_dim=True, reduce=False)
152
+ _def("softmax", single_dim=True, reduce=False)
153
+ softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
154
+
155
+ # stuff to handle in the future, because they require special
156
+ # binding logic for dims
157
+ # cross
158
+ # diag_embed
159
+ # diagonal
160
+ # diagonal_scatter
161
+ # diff
162
+ # nanquantile
163
+ # quantile
164
+ # roll
165
+ # rot90
166
+ # topk (new dimes on output)
167
+ # should these all be subsumed by inplace indexing?
168
+ # index_add_
169
+ # index_add
170
+ # index_copy
171
+ # index_copy_
172
+ # index_fill
173
+ # index_fill_
174
+ # index_select
175
+ # scatter
176
+ # scatter_
177
+ # scatter_add
178
+ # scatter_add_
179
+ # scatter_reduce
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (7.16 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-312.pyc ADDED
Binary file (1.13 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-312.pyc ADDED
Binary file (5.27 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/dim.cpython-312.pyc ADDED
Binary file (6.18 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/magic_trace.cpython-312.pyc ADDED
Binary file (2.26 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/op_properties.cpython-312.pyc ADDED
Binary file (17 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/reference.cpython-312.pyc ADDED
Binary file (27.8 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/tree_map.cpython-312.pyc ADDED
Binary file (695 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/wrap_type.cpython-312.pyc ADDED
Binary file (2.15 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/dim/batch_tensor.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from contextlib import contextmanager
7
+
8
+ from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
9
+
10
+ _enabled = False
11
+
12
+
13
+ @contextmanager
14
+ def _enable_layers(dims):
15
+ global _enabled
16
+ assert not _enabled
17
+ input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
18
+ n = len(input)
19
+ try:
20
+ _vmap_add_layers(input)
21
+ _enabled = True
22
+ yield
23
+ finally:
24
+ _enabled = False
25
+ _vmap_remove_layers(n)
gtm/lib/python3.12/site-packages/functorch/dim/delayed_mul_tensor.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import torch
7
+
8
+ from . import _Tensor, Tensor
9
+ from .reference import _dims, _enable_layers, llist, ltuple
10
+
11
+
12
+ class DelayedMulTensor(_Tensor):
13
+ def __init__(self, lhs, rhs):
14
+ self._lhs, self._rhs = lhs, rhs
15
+ self._data = None
16
+ self._levels_data = None
17
+ self._has_device = lhs._has_device or rhs._has_device
18
+ self._batchtensor_data = None
19
+ self._tensor_data = None
20
+
21
+ @property
22
+ def _levels(self):
23
+ if self._levels_data is None:
24
+ levels = llist(self._lhs._levels)
25
+ for l in self._rhs._levels:
26
+ if l not in levels:
27
+ levels.append(l)
28
+ self._levels_data = ltuple(levels)
29
+ return self._levels_data
30
+
31
+ @property
32
+ def _batchtensor(self):
33
+ if self._batchtensor_data is None:
34
+ with _enable_layers(self._levels):
35
+ print("bt multiply fallback")
36
+ self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
37
+ return self._batchtensor_data
38
+
39
+ @property
40
+ def _tensor(self):
41
+ if self._tensor_data is None:
42
+ self._tensor_data = Tensor.from_batched(
43
+ self._batchtensor, self._has_device
44
+ )._tensor
45
+ return self._tensor_data
46
+
47
+ @property
48
+ def ndim(self):
49
+ return self._batchtensor.ndim
50
+
51
+ @property
52
+ def dims(self):
53
+ return ltuple(super().dims)
54
+
55
+ def sum(self, dim):
56
+ dims = _dims(dim, 0, False, False)
57
+ n = ord("a")
58
+ all_levels = self._levels
59
+
60
+ def to_char(d):
61
+ return chr(n + all_levels.index(d))
62
+
63
+ plhs, levelslhs = self._lhs._tensor, self._lhs._levels
64
+ prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
65
+ new_dims = tuple(d for d in self.dims if d not in dims)
66
+ new_levels = [l for l in self._levels if l not in dims]
67
+ fmt = "".join(
68
+ [
69
+ *(to_char(d) for d in levelslhs),
70
+ ",",
71
+ *(to_char(d) for d in levelsrhs),
72
+ "->",
73
+ *(to_char(d) for d in new_levels),
74
+ ]
75
+ )
76
+ result_data = torch.einsum(fmt, (plhs, prhs))
77
+ return Tensor.from_positional(result_data, new_levels, True)
gtm/lib/python3.12/site-packages/functorch/dim/dim.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ _vmap_levels = []
7
+
8
+
9
+ @dataclass
10
+ class LevelInfo:
11
+ level: int
12
+ alive: bool = True
13
+
14
+
15
+ class Dim:
16
+ def __init__(self, name: str, size: Union[None, int] = None):
17
+ self.name = name
18
+ self._size = None
19
+ self._vmap_level = None
20
+ if size is not None:
21
+ self.size = size
22
+
23
+ def __del__(self):
24
+ if self._vmap_level is not None:
25
+ _vmap_active_levels[self._vmap_stack].alive = False
26
+ while (
27
+ not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
28
+ ):
29
+ _vmap_decrement_nesting()
30
+ _vmap_levels.pop()
31
+
32
+ @property
33
+ def size(self):
34
+ assert self.is_bound
35
+ return self._size
36
+
37
+ @size.setter
38
+ def size(self, size: int):
39
+ if self._size is None:
40
+ self._size = size
41
+ self._vmap_level = _vmap_increment_nesting(size, "same")
42
+ self._vmap_stack = len(_vmap_levels)
43
+ _vmap_levels.append(LevelInfo(self._vmap_level))
44
+
45
+ elif self._size != size:
46
+ raise DimensionBindError(
47
+ f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
48
+ )
49
+
50
+ @property
51
+ def is_bound(self):
52
+ return self._size is not None
53
+
54
+ def __repr__(self):
55
+ return self.name
56
+
57
+
58
+ def extract_name(inst):
59
+ assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
60
+ return inst.argval
61
+
62
+
63
+ _cache = {}
64
+
65
+
66
+ def dims(lists=0):
67
+ frame = inspect.currentframe()
68
+ assert frame is not None
69
+ calling_frame = frame.f_back
70
+ assert calling_frame is not None
71
+ code, lasti = calling_frame.f_code, calling_frame.f_lasti
72
+ key = (code, lasti)
73
+ if key not in _cache:
74
+ first = lasti // 2 + 1
75
+ instructions = list(dis.get_instructions(calling_frame.f_code))
76
+ unpack = instructions[first]
77
+
78
+ if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
79
+ # just a single dim, not a list
80
+ name = unpack.argval
81
+ ctor = Dim if lists == 0 else DimList
82
+ _cache[key] = lambda: ctor(name=name)
83
+ else:
84
+ assert unpack.opname == "UNPACK_SEQUENCE"
85
+ ndims = unpack.argval
86
+ names = tuple(
87
+ extract_name(instructions[first + 1 + i]) for i in range(ndims)
88
+ )
89
+ first_list = len(names) - lists
90
+ _cache[key] = lambda: tuple(
91
+ Dim(n) if i < first_list else DimList(name=n)
92
+ for i, n in enumerate(names)
93
+ )
94
+ return _cache[key]()
95
+
96
+
97
+ def _dim_set(positional, arg):
98
+ def convert(a):
99
+ if isinstance(a, Dim):
100
+ return a
101
+ else:
102
+ assert isinstance(a, int)
103
+ return positional[a]
104
+
105
+ if arg is None:
106
+ return positional
107
+ elif not isinstance(arg, (Dim, int)):
108
+ return tuple(convert(a) for a in arg)
109
+ else:
110
+ return (convert(arg),)
gtm/lib/python3.12/site-packages/functorch/dim/magic_trace.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+ import signal
8
+ import subprocess
9
+ from contextlib import contextmanager
10
+
11
+
12
+ @contextmanager
13
+ def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
14
+ pid = os.getpid()
15
+ if not os.path.exists(magic_trace_cache):
16
+ print(f"Downloading magic_trace to: {magic_trace_cache}")
17
+ subprocess.run(
18
+ [
19
+ "wget",
20
+ "-O",
21
+ magic_trace_cache,
22
+ "-q",
23
+ "https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
24
+ ]
25
+ )
26
+ subprocess.run(["chmod", "+x", magic_trace_cache])
27
+ args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
28
+ p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
29
+ while True:
30
+ x = p.stderr.readline()
31
+ print(x)
32
+ if "Attached" in x:
33
+ break
34
+ try:
35
+ yield
36
+ finally:
37
+ p.send_signal(signal.SIGINT)
38
+ r = p.wait()
39
+ print(p.stderr.read())
40
+ p.stderr.close()
41
+ if r != 0:
42
+ raise ValueError(f"magic_trace exited abnormally: {r}")
gtm/lib/python3.12/site-packages/functorch/dim/op_properties.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import torch
7
+
8
+ # pointwise operators can go through a faster pathway
9
+
10
+ tensor_magic_methods = ["add", ""]
11
+ pointwise_magic_methods_with_reverse = (
12
+ "add",
13
+ "sub",
14
+ "mul",
15
+ "floordiv",
16
+ "div",
17
+ "truediv",
18
+ "mod",
19
+ "pow",
20
+ "lshift",
21
+ "rshift",
22
+ "and",
23
+ "or",
24
+ "xor",
25
+ )
26
+ pointwise_magic_methods = (
27
+ *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
28
+ "eq",
29
+ "gt",
30
+ "le",
31
+ "lt",
32
+ "ge",
33
+ "gt",
34
+ "ne",
35
+ "neg",
36
+ "pos",
37
+ "abs",
38
+ "invert",
39
+ "iadd",
40
+ "isub",
41
+ "imul",
42
+ "ifloordiv",
43
+ "idiv",
44
+ "itruediv",
45
+ "imod",
46
+ "ipow",
47
+ "ilshift",
48
+ "irshift",
49
+ "iand",
50
+ "ior",
51
+ "ixor",
52
+ "int",
53
+ "long",
54
+ "float",
55
+ "complex",
56
+ )
57
+
58
+ pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
59
+
60
+ pointwise = (
61
+ *(getattr(torch.Tensor, m) for m in pointwise_methods),
62
+ torch.nn.functional.dropout,
63
+ torch.where,
64
+ torch.Tensor.abs,
65
+ torch.abs,
66
+ torch.Tensor.acos,
67
+ torch.acos,
68
+ torch.Tensor.acosh,
69
+ torch.acosh,
70
+ torch.Tensor.add,
71
+ torch.add,
72
+ torch.Tensor.addcdiv,
73
+ torch.addcdiv,
74
+ torch.Tensor.addcmul,
75
+ torch.addcmul,
76
+ torch.Tensor.addr,
77
+ torch.addr,
78
+ torch.Tensor.angle,
79
+ torch.angle,
80
+ torch.Tensor.asin,
81
+ torch.asin,
82
+ torch.Tensor.asinh,
83
+ torch.asinh,
84
+ torch.Tensor.atan,
85
+ torch.atan,
86
+ torch.Tensor.atan2,
87
+ torch.atan2,
88
+ torch.Tensor.atanh,
89
+ torch.atanh,
90
+ torch.Tensor.bitwise_and,
91
+ torch.bitwise_and,
92
+ torch.Tensor.bitwise_left_shift,
93
+ torch.bitwise_left_shift,
94
+ torch.Tensor.bitwise_not,
95
+ torch.bitwise_not,
96
+ torch.Tensor.bitwise_or,
97
+ torch.bitwise_or,
98
+ torch.Tensor.bitwise_right_shift,
99
+ torch.bitwise_right_shift,
100
+ torch.Tensor.bitwise_xor,
101
+ torch.bitwise_xor,
102
+ torch.Tensor.ceil,
103
+ torch.ceil,
104
+ torch.celu,
105
+ torch.nn.functional.celu,
106
+ torch.Tensor.clamp,
107
+ torch.clamp,
108
+ torch.Tensor.clamp_max,
109
+ torch.clamp_max,
110
+ torch.Tensor.clamp_min,
111
+ torch.clamp_min,
112
+ torch.Tensor.copysign,
113
+ torch.copysign,
114
+ torch.Tensor.cos,
115
+ torch.cos,
116
+ torch.Tensor.cosh,
117
+ torch.cosh,
118
+ torch.Tensor.deg2rad,
119
+ torch.deg2rad,
120
+ torch.Tensor.digamma,
121
+ torch.digamma,
122
+ torch.Tensor.div,
123
+ torch.div,
124
+ torch.dropout,
125
+ torch.nn.functional.dropout,
126
+ torch.nn.functional.elu,
127
+ torch.Tensor.eq,
128
+ torch.eq,
129
+ torch.Tensor.erf,
130
+ torch.erf,
131
+ torch.Tensor.erfc,
132
+ torch.erfc,
133
+ torch.Tensor.erfinv,
134
+ torch.erfinv,
135
+ torch.Tensor.exp,
136
+ torch.exp,
137
+ torch.Tensor.exp2,
138
+ torch.exp2,
139
+ torch.Tensor.expm1,
140
+ torch.expm1,
141
+ torch.feature_dropout,
142
+ torch.Tensor.float_power,
143
+ torch.float_power,
144
+ torch.Tensor.floor,
145
+ torch.floor,
146
+ torch.Tensor.floor_divide,
147
+ torch.floor_divide,
148
+ torch.Tensor.fmod,
149
+ torch.fmod,
150
+ torch.Tensor.frac,
151
+ torch.frac,
152
+ torch.Tensor.frexp,
153
+ torch.frexp,
154
+ torch.Tensor.gcd,
155
+ torch.gcd,
156
+ torch.Tensor.ge,
157
+ torch.ge,
158
+ torch.nn.functional.gelu,
159
+ torch.nn.functional.glu,
160
+ torch.Tensor.gt,
161
+ torch.gt,
162
+ torch.Tensor.hardshrink,
163
+ torch.hardshrink,
164
+ torch.nn.functional.hardshrink,
165
+ torch.nn.functional.hardsigmoid,
166
+ torch.nn.functional.hardswish,
167
+ torch.nn.functional.hardtanh,
168
+ torch.Tensor.heaviside,
169
+ torch.heaviside,
170
+ torch.Tensor.hypot,
171
+ torch.hypot,
172
+ torch.Tensor.i0,
173
+ torch.i0,
174
+ torch.Tensor.igamma,
175
+ torch.igamma,
176
+ torch.Tensor.igammac,
177
+ torch.igammac,
178
+ torch.Tensor.isclose,
179
+ torch.isclose,
180
+ torch.Tensor.isfinite,
181
+ torch.isfinite,
182
+ torch.Tensor.isinf,
183
+ torch.isinf,
184
+ torch.Tensor.isnan,
185
+ torch.isnan,
186
+ torch.Tensor.isneginf,
187
+ torch.isneginf,
188
+ torch.Tensor.isposinf,
189
+ torch.isposinf,
190
+ torch.Tensor.isreal,
191
+ torch.isreal,
192
+ torch.Tensor.kron,
193
+ torch.kron,
194
+ torch.Tensor.lcm,
195
+ torch.lcm,
196
+ torch.Tensor.ldexp,
197
+ torch.ldexp,
198
+ torch.Tensor.le,
199
+ torch.le,
200
+ torch.nn.functional.leaky_relu,
201
+ torch.Tensor.lerp,
202
+ torch.lerp,
203
+ torch.Tensor.lgamma,
204
+ torch.lgamma,
205
+ torch.Tensor.log,
206
+ torch.log,
207
+ torch.Tensor.log10,
208
+ torch.log10,
209
+ torch.Tensor.log1p,
210
+ torch.log1p,
211
+ torch.Tensor.log2,
212
+ torch.log2,
213
+ torch.nn.functional.logsigmoid,
214
+ torch.Tensor.logical_and,
215
+ torch.logical_and,
216
+ torch.Tensor.logical_not,
217
+ torch.logical_not,
218
+ torch.Tensor.logical_or,
219
+ torch.logical_or,
220
+ torch.Tensor.logical_xor,
221
+ torch.logical_xor,
222
+ torch.Tensor.logit,
223
+ torch.logit,
224
+ torch.Tensor.lt,
225
+ torch.lt,
226
+ torch.Tensor.maximum,
227
+ torch.maximum,
228
+ torch.Tensor.minimum,
229
+ torch.minimum,
230
+ torch.nn.functional.mish,
231
+ torch.Tensor.mvlgamma,
232
+ torch.mvlgamma,
233
+ torch.Tensor.nan_to_num,
234
+ torch.nan_to_num,
235
+ torch.Tensor.ne,
236
+ torch.ne,
237
+ torch.Tensor.neg,
238
+ torch.neg,
239
+ torch.Tensor.nextafter,
240
+ torch.nextafter,
241
+ torch.Tensor.outer,
242
+ torch.outer,
243
+ torch.polar,
244
+ torch.Tensor.polygamma,
245
+ torch.polygamma,
246
+ torch.Tensor.positive,
247
+ torch.positive,
248
+ torch.Tensor.pow,
249
+ torch.pow,
250
+ torch.Tensor.prelu,
251
+ torch.prelu,
252
+ torch.nn.functional.prelu,
253
+ torch.Tensor.rad2deg,
254
+ torch.rad2deg,
255
+ torch.Tensor.reciprocal,
256
+ torch.reciprocal,
257
+ torch.Tensor.relu,
258
+ torch.relu,
259
+ torch.nn.functional.relu,
260
+ torch.nn.functional.relu6,
261
+ torch.Tensor.remainder,
262
+ torch.remainder,
263
+ torch.Tensor.round,
264
+ torch.round,
265
+ torch.rrelu,
266
+ torch.nn.functional.rrelu,
267
+ torch.Tensor.rsqrt,
268
+ torch.rsqrt,
269
+ torch.rsub,
270
+ torch.selu,
271
+ torch.nn.functional.selu,
272
+ torch.Tensor.sgn,
273
+ torch.sgn,
274
+ torch.Tensor.sigmoid,
275
+ torch.sigmoid,
276
+ torch.nn.functional.sigmoid,
277
+ torch.Tensor.sign,
278
+ torch.sign,
279
+ torch.Tensor.signbit,
280
+ torch.signbit,
281
+ torch.nn.functional.silu,
282
+ torch.Tensor.sin,
283
+ torch.sin,
284
+ torch.Tensor.sinc,
285
+ torch.sinc,
286
+ torch.Tensor.sinh,
287
+ torch.sinh,
288
+ torch.nn.functional.softplus,
289
+ torch.nn.functional.softshrink,
290
+ torch.Tensor.sqrt,
291
+ torch.sqrt,
292
+ torch.Tensor.square,
293
+ torch.square,
294
+ torch.Tensor.sub,
295
+ torch.sub,
296
+ torch.Tensor.tan,
297
+ torch.tan,
298
+ torch.Tensor.tanh,
299
+ torch.tanh,
300
+ torch.nn.functional.tanh,
301
+ torch.threshold,
302
+ torch.nn.functional.threshold,
303
+ torch.trapz,
304
+ torch.Tensor.true_divide,
305
+ torch.true_divide,
306
+ torch.Tensor.trunc,
307
+ torch.trunc,
308
+ torch.Tensor.xlogy,
309
+ torch.xlogy,
310
+ torch.rand_like,
311
+ )
gtm/lib/python3.12/site-packages/functorch/dim/reference.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # reference python implementations for C ops
8
+ import torch
9
+
10
+ from functorch._C import dim as _C
11
+ from . import op_properties
12
+ from .batch_tensor import _enable_layers
13
+ from .tree_map import tree_flatten, tree_map
14
+
15
+ DimList = _C.DimList
16
+ import operator
17
+ from functools import reduce
18
+
19
+
20
+ # use dict to avoid writing C++ bindings for set
21
+ pointwise = set(op_properties.pointwise)
22
+
23
+
24
+ def prod(x):
25
+ return reduce(operator.mul, x, 1)
26
+
27
+
28
+ def _wrap_dim(d, N, keepdim):
29
+ from . import Dim
30
+
31
+ if isinstance(d, Dim):
32
+ assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
33
+ return d
34
+ elif d >= 0:
35
+ return d - N
36
+ else:
37
+ return d
38
+
39
+
40
+ def _dims(d, N, keepdim, single_dim):
41
+ from . import Dim
42
+
43
+ if isinstance(d, (Dim, int)):
44
+ return ltuple((_wrap_dim(d, N, keepdim),))
45
+ assert not single_dim, f"expected a single dimension or int but found: {d}"
46
+ return ltuple(_wrap_dim(x, N, keepdim) for x in d)
47
+
48
+
49
+ def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
50
+ from . import DimensionMismatchError
51
+
52
+ not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
53
+ if len(not_bound) == 1:
54
+ idx, d = not_bound[0]
55
+ rhs_so_far = prod(r.size for r in rhs if r.is_bound)
56
+ if lhs_size % rhs_so_far != 0:
57
+ rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
58
+ raise DimensionMismatchError(
59
+ f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
60
+ )
61
+ new_size = lhs_size // rhs_so_far
62
+ d.size = new_size
63
+ elif len(not_bound) > 1:
64
+ rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
65
+ raise DimensionMismatchError(
66
+ f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
67
+ )
68
+ else:
69
+ rhs_size = prod(r.size for r in rhs)
70
+ if lhs_size != rhs_size:
71
+ raise DimensionMismatchError(
72
+ f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
73
+ )
74
+
75
+
76
+ def _tensor_levels(inp):
77
+ from . import _Tensor
78
+
79
+ if isinstance(inp, _Tensor):
80
+ return inp._tensor, llist(inp._levels), inp._has_device
81
+ else:
82
+ return inp, llist(range(-inp.ndim, 0)), True
83
+
84
+
85
+ def _match_levels(v, from_levels, to_levels):
86
+ view = []
87
+ permute = []
88
+ requires_view = False
89
+ size = v.size()
90
+ for t in to_levels:
91
+ try:
92
+ idx = from_levels.index(t)
93
+ permute.append(idx)
94
+ view.append(size[idx])
95
+ except ValueError:
96
+ view.append(1)
97
+ requires_view = True
98
+ if permute != list(range(len(permute))):
99
+ v = v.permute(*permute)
100
+ if requires_view:
101
+ v = v.view(*view)
102
+ return v
103
+
104
+
105
+ # make a single dimension positional but do not permute it,
106
+ # used to do multi-tensor operators where the dim being acted on
107
+ # should not physically move if possible
108
+ def _positional_no_permute(self, dim, expand_dim=False):
109
+ from . import Tensor
110
+
111
+ ptensor, levels = self._tensor, llist(self._levels)
112
+ try:
113
+ idx = levels.index(dim)
114
+ except ValueError:
115
+ if not expand_dim:
116
+ raise
117
+ idx = 0
118
+ ptensor = ptensor.expand(dim.size, *ptensor.size())
119
+ levels.insert(0, 0)
120
+ idx_batched = 0
121
+ for i in range(idx):
122
+ if isinstance(levels[i], int):
123
+ levels[i] -= 1
124
+ idx_batched += 1
125
+ levels[idx] = -idx_batched - 1
126
+ return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
127
+
128
+
129
+ def seq(a, b):
130
+ from . import Dim
131
+
132
+ if isinstance(a, Dim) != isinstance(b, Dim):
133
+ return False
134
+ if isinstance(a, Dim):
135
+ return a is b
136
+ else:
137
+ return a == b
138
+
139
+
140
+ class isin:
141
+ def __contains__(self, item):
142
+ for x in self:
143
+ if seq(item, x):
144
+ return True
145
+ return False
146
+
147
+ def index(self, item):
148
+ for i, x in enumerate(self):
149
+ if seq(item, x):
150
+ return i
151
+ raise ValueError
152
+
153
+
154
+ class llist(isin, list):
155
+ pass
156
+
157
+
158
+ class ltuple(isin, tuple):
159
+ pass
160
+
161
+
162
+ empty_dict = {}
163
+
164
+
165
+ @classmethod
166
+ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
167
+ from . import _Tensor, Tensor, TensorLike
168
+ from .delayed_mul_tensor import DelayedMulTensor
169
+
170
+ if orig is torch.Tensor.__mul__:
171
+ lhs, rhs = args
172
+ if (
173
+ isinstance(lhs, _Tensor)
174
+ and isinstance(rhs, _Tensor)
175
+ and lhs.ndim == 0
176
+ and rhs.ndim == 0
177
+ ):
178
+ return DelayedMulTensor(lhs, rhs)
179
+ all_dims = llist()
180
+ flat_args, unflatten = tree_flatten((args, kwargs))
181
+ device_holding_tensor = None
182
+ for f in flat_args:
183
+ if isinstance(f, _Tensor):
184
+ if f._has_device:
185
+ device_holding_tensor = f._batchtensor
186
+ for d in f.dims:
187
+ if d not in all_dims:
188
+ all_dims.append(d)
189
+
190
+ def unwrap(t):
191
+ if isinstance(t, _Tensor):
192
+ r = t._batchtensor
193
+ if device_holding_tensor is not None and not t._has_device:
194
+ r = r.to(device=device_holding_tensor.device)
195
+ return r
196
+ return t
197
+
198
+ if orig in pointwise:
199
+ result_levels = llist()
200
+ arg_levels = llist()
201
+ to_expand = []
202
+ for i, f in enumerate(flat_args):
203
+ if isinstance(f, TensorLike):
204
+ ptensor, levels, _ = _tensor_levels(f)
205
+ if (
206
+ isinstance(f, _Tensor)
207
+ and not f._has_device
208
+ and device_holding_tensor is not None
209
+ ):
210
+ ptensor = ptensor.to(device=device_holding_tensor.device)
211
+ flat_args[i] = ptensor
212
+ for l in levels:
213
+ if l not in result_levels:
214
+ result_levels.append(l)
215
+ to_expand.append((i, levels))
216
+
217
+ for i, levels in to_expand:
218
+ flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
219
+ args, kwargs = unflatten(flat_args)
220
+ result = orig(*args, **kwargs)
221
+
222
+ def wrap(t):
223
+ if isinstance(t, TensorLike):
224
+ return Tensor.from_positional(
225
+ t, result_levels, device_holding_tensor is not None
226
+ )
227
+ return t
228
+
229
+ return tree_map(wrap, result)
230
+ else:
231
+
232
+ def wrap(t):
233
+ if isinstance(t, TensorLike):
234
+ return Tensor.from_batched(t, device_holding_tensor is not None)
235
+ return t
236
+
237
+ with _enable_layers(all_dims):
238
+ print(f"batch_tensor for {orig}")
239
+ args, kwargs = unflatten(unwrap(f) for f in flat_args)
240
+ result = orig(*args, **kwargs)
241
+ # print("END", orig)
242
+ return tree_map(wrap, result)
243
+
244
+
245
+ def positional(self, *dims):
246
+ from . import Dim, Tensor
247
+
248
+ ptensor, levels = self._tensor, llist(self._levels)
249
+ flat_dims = llist()
250
+ view = []
251
+ needs_view = False
252
+ ndim = self.ndim
253
+ for d in dims:
254
+ if isinstance(d, DimList):
255
+ flat_dims.extend(d)
256
+ view.extend(e.size for e in d)
257
+ elif isinstance(d, Dim):
258
+ flat_dims.append(d)
259
+ view.append(d.size)
260
+ elif isinstance(d, int):
261
+ d = _wrap_dim(d, ndim, False)
262
+ flat_dims.append(d)
263
+ view.append(ptensor.size(d))
264
+ else:
265
+ flat_dims.extend(d)
266
+ view.append(prod(e.size for e in d))
267
+ needs_view = True
268
+
269
+ permute = list(range(len(levels)))
270
+ nflat = len(flat_dims)
271
+ for i, d in enumerate(flat_dims):
272
+ try:
273
+ idx = levels.index(d)
274
+ except ValueError as e:
275
+ raise DimensionBindError(
276
+ f"tensor of dimensions {self.dims} does not contain dim {d}"
277
+ ) from e
278
+ p = permute[idx]
279
+ del levels[idx]
280
+ del permute[idx]
281
+ levels.insert(i, 0)
282
+ permute.insert(i, p)
283
+ ptensor = ptensor.permute(*permute)
284
+ seen = 0
285
+ for i in range(len(levels) - 1, -1, -1):
286
+ if isinstance(levels[i], int):
287
+ seen += 1
288
+ levels[i] = -seen
289
+ result = Tensor.from_positional(ptensor, levels, self._has_device)
290
+ if needs_view:
291
+ result = result.reshape(*view, *result.size()[len(flat_dims) :])
292
+ return result
293
+
294
+
295
+ def _contains_dim(input):
296
+ from . import Dim
297
+
298
+ for i in input:
299
+ if isinstance(i, Dim):
300
+ return True
301
+
302
+
303
+ def expand(self, *sizes):
304
+ if not _contains_dim(sizes):
305
+ return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
306
+ dims = sizes
307
+ sizes = [d.size for d in dims] + [-1] * self.ndim
308
+ self = self.expand(*sizes)
309
+ return self[dims]
310
+
311
+
312
+ _not_present = object()
313
+
314
+
315
+ def _getarg(name, offset, args, kwargs, default):
316
+ if len(args) > offset:
317
+ return args[offset]
318
+ return kwargs.get(name, default)
319
+
320
+
321
+ def _patcharg(name, offset, args, kwargs, value):
322
+ if len(args) > offset:
323
+ args[offset] = value
324
+ else:
325
+ kwargs[name] = value
326
+
327
+
328
+ def _wrap(
329
+ orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
330
+ ):
331
+ from . import Dim, Tensor, TensorLike
332
+
333
+ def fn(self, *args, **kwargs):
334
+ dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
335
+ if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
336
+ with _enable_layers(self.dims):
337
+ print(f"dim fallback batch_tensor for {orig}")
338
+ return Tensor.from_batched(
339
+ orig(self._batchtensor, *args, **kwargs), self._has_device
340
+ )
341
+ keepdim = (
342
+ _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
343
+ )
344
+ t, levels = self._tensor, llist(self._levels)
345
+ dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
346
+ dim_indices = tuple(levels.index(d) for d in dims)
347
+ if reduce and not keepdim:
348
+ new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
349
+ else:
350
+ new_levels = levels
351
+
352
+ if len(dim_indices) == 1:
353
+ dim_indices = dim_indices[
354
+ 0
355
+ ] # so that dims that really only take a single argument work...
356
+ args = list(args)
357
+ _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
358
+
359
+ def wrap(t):
360
+ if isinstance(t, TensorLike):
361
+ return Tensor.from_positional(t, new_levels, self._has_device)
362
+ return t
363
+
364
+ with _enable_layers(new_levels):
365
+ print(f"dim used batch_tensor for {orig}")
366
+ r = orig(t, *args, **kwargs)
367
+ return tree_map(wrap, r)
368
+
369
+ return fn
370
+
371
+
372
+ def _def(name, *args, **kwargs):
373
+ from . import _Tensor
374
+
375
+ orig = getattr(torch.Tensor, name)
376
+ setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
377
+
378
+
379
+ no_slice = slice(None)
380
+
381
+ _orig_getitem = torch.Tensor.__getitem__
382
+
383
+
384
+ class dim_tracker:
385
+ def __init__(self):
386
+ self.dims = llist()
387
+ self.count = []
388
+
389
+ def record(self, d):
390
+ if d not in self.dims:
391
+ self.dims.append(d)
392
+ self.count.append(1)
393
+
394
+ def __getitem__(self, d):
395
+ return self.count[self.dims.index(d)]
396
+
397
+
398
+ def t__getitem__(self, input):
399
+ from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
400
+
401
+ # * bail to original example if we have a single non-Dim tensor, or a non-tensor
402
+ # * locate ... or an unbound tensor list, and determine its size, bind dim list
403
+ # (remember that None does not count to the total dim count)
404
+ # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
405
+ # produce the re-view if needed
406
+ # * for each single-use dim index, replace with no_slice and mark that it will be added
407
+ # (keep track of whether we have to call super)
408
+ # * call super if needed
409
+ # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
410
+
411
+ # this handles bool indexing handling, as well as some other simple cases.
412
+
413
+ is_simple = (
414
+ not isinstance(input, Dim)
415
+ and not isinstance(input, (tuple, list))
416
+ and
417
+ # WAR for functorch bug where zero time tensors in getitem are not handled correctly.
418
+ not (isinstance(input, TensorLike) and input.ndim == 0)
419
+ )
420
+
421
+ if is_simple:
422
+ if isinstance(self, _Tensor):
423
+ return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
424
+ else:
425
+ return _orig_getitem(self, input)
426
+
427
+ # can further optimize this case
428
+ if not isinstance(input, tuple):
429
+ input = [input]
430
+ else:
431
+ input = list(input)
432
+
433
+ dims_indexed = 0
434
+ expanding_object = None
435
+ dimlists = []
436
+ for i, s in enumerate(input):
437
+ if s is ... or isinstance(s, DimList) and not s.is_bound:
438
+ if expanding_object is not None:
439
+ msg = (
440
+ "at most one ... or unbound dimension list can exist in indexing list but"
441
+ f" found 2 at offsets {i} and {expanding_object}"
442
+ )
443
+ raise DimensionBindError(msg)
444
+ expanding_object = i
445
+
446
+ if isinstance(s, DimList):
447
+ dims_indexed += len(s) if s.is_bound else 0
448
+ dimlists.append(i)
449
+ elif s is not None and s is not ...:
450
+ dims_indexed += 1
451
+
452
+ ndim = self.ndim
453
+ if dims_indexed > ndim:
454
+ raise IndexError(
455
+ f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
456
+ )
457
+ if expanding_object is not None:
458
+ expanding_ndims = ndim - dims_indexed
459
+ obj = input[expanding_object]
460
+ if obj is ...:
461
+ input[expanding_object : expanding_object + 1] = [
462
+ no_slice
463
+ ] * expanding_ndims
464
+ else:
465
+ obj.bind_len(expanding_ndims)
466
+ # flatten the dimslists into the indexing
467
+ for i in reversed(dimlists):
468
+ input[i : i + 1] = input[i]
469
+ dims_indexed = 0
470
+ requires_view = False
471
+ size = self.size()
472
+ view_sizes = []
473
+ dims_seen = dim_tracker()
474
+
475
+ def add_dims(t):
476
+ if not isinstance(t, _Tensor):
477
+ return
478
+ for d in t.dims:
479
+ dims_seen.record(d)
480
+
481
+ add_dims(self)
482
+ dim_packs = []
483
+ for i, idx in enumerate(input):
484
+ if idx is None:
485
+ input[i] = no_slice
486
+ view_sizes.append(1)
487
+ requires_view = True
488
+ else:
489
+ sz = size[dims_indexed]
490
+ if isinstance(idx, Dim):
491
+ idx.size = sz
492
+ dims_seen.record(idx)
493
+ view_sizes.append(sz)
494
+ elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
495
+ for d in idx:
496
+ dims_seen.record(idx)
497
+ _bind_dims_to_size(sz, idx, f"offset {i}")
498
+ view_sizes.extend(d.size for d in idx)
499
+ requires_view = True
500
+ dim_packs.append(i)
501
+ else:
502
+ add_dims(idx)
503
+ view_sizes.append(sz)
504
+ dims_indexed += 1
505
+ if requires_view:
506
+ self = self.view(*view_sizes)
507
+ for i in reversed(dim_packs):
508
+ input[i : i + 1] = input[i]
509
+
510
+ # currenty:
511
+ # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
512
+ # self may have first-class dims as well.
513
+
514
+ # to index:
515
+ # drop the first class dims from self, they just become direct indices of their positions
516
+
517
+ # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
518
+ # these dimensions will appear and need to be bound at the first place tensor occures
519
+
520
+ if isinstance(self, _Tensor):
521
+ ptensor_self, levels = self._tensor, list(self._levels)
522
+ # indices to ptensor rather than self which has first-class dimensions
523
+ input_it = iter(input)
524
+ flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
525
+ has_device = self._has_device
526
+ to_pad = 0
527
+ else:
528
+ ptensor_self, flat_inputs = self, input
529
+ to_pad = ptensor_self.ndim - len(flat_inputs)
530
+ has_device = True
531
+
532
+ result_levels = []
533
+ index_levels = []
534
+ tensor_insert_point = None
535
+ to_expand = {}
536
+ requires_getindex = False
537
+ for i, inp in enumerate(flat_inputs):
538
+ if isinstance(inp, Dim) and dims_seen[inp] == 1:
539
+ flat_inputs[i] = no_slice
540
+ result_levels.append(inp)
541
+ elif isinstance(inp, TensorLike):
542
+ requires_getindex = True
543
+ if tensor_insert_point is None:
544
+ tensor_insert_point = len(result_levels)
545
+ ptensor, levels, _ = _tensor_levels(inp)
546
+ to_expand[i] = levels
547
+ flat_inputs[i] = ptensor
548
+ for l in levels:
549
+ if l not in index_levels:
550
+ index_levels.append(l)
551
+ else:
552
+ requires_getindex = True
553
+ result_levels.append(0)
554
+
555
+ if tensor_insert_point is not None:
556
+ result_levels[tensor_insert_point:tensor_insert_point] = index_levels
557
+
558
+ for i, levels in to_expand.items():
559
+ flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)
560
+
561
+ if requires_getindex:
562
+ result = _orig_getitem(ptensor_self, flat_inputs)
563
+ else:
564
+ result = ptensor_self
565
+
566
+ next_positional = -1
567
+ if to_pad > 0:
568
+ result_levels.extend([0] * to_pad)
569
+ for i, r in enumerate(reversed(result_levels)):
570
+ if isinstance(r, int):
571
+ result_levels[-1 - i] = next_positional
572
+ next_positional -= 1
573
+
574
+ return Tensor.from_positional(result, result_levels, has_device)
575
+
576
+
577
+ # XXX - dim is optional and can be the outer-most dimension...
578
+ def stack(tensors, new_dim, dim=0, out=None):
579
+ if isinstance(dim, int):
580
+ return torch.stack(tensors, dim, out).index(dim, new_dim)
581
+ index = None
582
+ if out is not None:
583
+ out, index = _positional_no_permute(out, dim, expand_dim=True)
584
+ ptensors = []
585
+ for t in tensors:
586
+ pt, pi = _positional_no_permute(t, dim, expand_dim=True)
587
+ if index is not None and pi != index:
588
+ pt = pt.move_dim(pi, index)
589
+ else:
590
+ index = pi
591
+ ptensors.append(pt)
592
+ pr = torch.stack(ptensors, index, out=out)
593
+ return pr.index((index, index + 1), (new_dim, dim))
594
+
595
+
596
+ _orig_split = torch.Tensor.split
597
+
598
+
599
+ def split(self, split_size_or_sections, dim=0):
600
+ from . import _Tensor, Dim
601
+
602
+ if isinstance(split_size_or_sections, int) or any(
603
+ isinstance(t, int) for t in split_size_or_sections
604
+ ):
605
+ if isinstance(dim, Dim):
606
+ raise ValueError(
607
+ "when dim is specified as a Dim object, split sizes must also be dimensions."
608
+ )
609
+ return _orig_split(self, split_size_or_sections, dim=dim)
610
+
611
+ if isinstance(dim, Dim):
612
+ assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
613
+ self, dim = _positional_no_permute(self, dim)
614
+
615
+ size = self.size(dim)
616
+ total_bound_size = 0
617
+ unbound = []
618
+ sizes = []
619
+ for i, d in enumerate(split_size_or_sections):
620
+ if d.is_bound:
621
+ sizes.append(d.size)
622
+ total_bound_size += d.size
623
+ else:
624
+ sizes.append(0)
625
+ unbound.append(i)
626
+
627
+ if unbound:
628
+ assert (
629
+ total_bound_size <= size
630
+ ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
631
+ remaining_size = size - total_bound_size
632
+ chunk_size = -(-remaining_size // len(unbound))
633
+ for u in unbound:
634
+ sz = min(chunk_size, remaining_size)
635
+ split_size_or_sections[u].size = sz
636
+ sizes[u] = sz
637
+ remaining_size -= sz
638
+ else:
639
+ assert (
640
+ total_bound_size == size
641
+ ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
642
+ return tuple(
643
+ t.index(dim, d)
644
+ for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
645
+ )
gtm/lib/python3.12/site-packages/functorch/dim/tree_map.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functorch._C import dim
8
+
9
+ tree_flatten = dim.tree_flatten
10
+
11
+
12
+ def tree_map(fn, tree):
13
+ vs, unflatten = tree_flatten(tree)
14
+ return unflatten(fn(v) for v in vs)
gtm/lib/python3.12/site-packages/functorch/dim/wrap_type.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from types import (
8
+ BuiltinMethodType,
9
+ FunctionType,
10
+ GetSetDescriptorType,
11
+ MethodDescriptorType,
12
+ WrapperDescriptorType,
13
+ )
14
+
15
+ from functorch._C import dim as _C
16
+
17
+ _wrap_method = _C._wrap_method
18
+
19
+ FUNC_TYPES = (
20
+ FunctionType,
21
+ MethodDescriptorType,
22
+ BuiltinMethodType,
23
+ WrapperDescriptorType,
24
+ )
25
+ PROPERTY_TYPES = (GetSetDescriptorType, property)
26
+
27
+
28
+ def _py_wrap_method(orig, __torch_function__):
29
+ def impl(*args, **kwargs):
30
+ return __torch_function__(orig, None, args, kwargs)
31
+
32
+ return impl
33
+
34
+
35
+ def wrap_type(use_c, to_patch, pattern, __torch_function__):
36
+ if use_c:
37
+ wrap_method = _wrap_method
38
+ else:
39
+ wrap_method = _py_wrap_method
40
+
41
+ all = {}
42
+ for t in reversed(pattern.mro()[:-1]): # skip object
43
+ all.update(t.__dict__)
44
+
45
+ def wrap_attr(orig):
46
+ return property(wrap_method(orig.__get__, __torch_function__))
47
+
48
+ for name, obj in all.items():
49
+ if name in (
50
+ "__dict__",
51
+ "__new__",
52
+ "__init__",
53
+ "__repr__",
54
+ "__weakref__",
55
+ "__doc__",
56
+ "__module__",
57
+ "__dir__",
58
+ ):
59
+ continue
60
+
61
+ # skip things that have been overloaded
62
+ # things that come from object like `__eq__` still need to be patched, however.
63
+ if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
64
+ object, name, None
65
+ ):
66
+ continue
67
+
68
+ if isinstance(obj, FUNC_TYPES):
69
+ setattr(to_patch, name, wrap_method(obj, __torch_function__))
70
+ elif isinstance(obj, PROPERTY_TYPES):
71
+ setattr(to_patch, name, wrap_attr(obj))
gtm/lib/python3.12/site-packages/functorch/einops/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .rearrange import rearrange
2
+
3
+ __all__ = ["rearrange"]
gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (258 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/_parsing.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/rearrange.cpython-312.pyc ADDED
Binary file (9.82 kB). View file
 
gtm/lib/python3.12/site-packages/functorch/einops/_parsing.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.
2
+
3
+ MIT License
4
+
5
+ Copyright (c) 2018 Alex Rogozhnikov
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+ from __future__ import annotations
26
+
27
+ import keyword
28
+ import warnings
29
+ from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
30
+
31
+ _ellipsis: str = "…" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
32
+
33
+
34
+ class AnonymousAxis:
35
+ """Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier.
36
+
37
+ Note: Different instances of this class are not equal to each other, even if they have the same value.
38
+ """
39
+
40
+ def __init__(self, value: str) -> None:
41
+ self.value = int(value)
42
+ if self.value < 1:
43
+ raise ValueError(
44
+ f"Anonymous axis should have positive length, not {self.value}"
45
+ )
46
+
47
+ def __repr__(self) -> str:
48
+ return f"{self.value}-axis"
49
+
50
+
51
+ class ParsedExpression:
52
+ """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
53
+
54
+ def __init__(
55
+ self,
56
+ expression: str,
57
+ *,
58
+ allow_underscore: bool = False,
59
+ allow_duplicates: bool = False,
60
+ ) -> None:
61
+ """Parse the expression and store relevant metadata.
62
+
63
+ Args:
64
+ expression (str): the `einops`-pattern to parse
65
+ allow_underscore (bool): whether to allow axis identifier names to begin with an underscore
66
+ allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression
67
+ """
68
+ self.has_ellipsis: bool = False
69
+ self.has_ellipsis_parenthesized: Optional[bool] = None
70
+ self.identifiers: Set[Union[str, AnonymousAxis]] = set()
71
+ # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
72
+ self.has_non_unitary_anonymous_axes: bool = False
73
+ # composition keeps structure of composite axes, see how different corner cases are handled in tests
74
+ self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
75
+ if "." in expression:
76
+ if "..." not in expression:
77
+ raise ValueError(
78
+ "Expression may contain dots only inside ellipsis (...)"
79
+ )
80
+ if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
81
+ raise ValueError(
82
+ "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
83
+ )
84
+ expression = expression.replace("...", _ellipsis)
85
+ self.has_ellipsis = True
86
+
87
+ bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
88
+
89
+ def add_axis_name(x: str) -> None:
90
+ if x in self.identifiers:
91
+ if not (allow_underscore and x == "_") and not allow_duplicates:
92
+ raise ValueError(
93
+ f"Indexing expression contains duplicate dimension '{x}'"
94
+ )
95
+ if x == _ellipsis:
96
+ self.identifiers.add(_ellipsis)
97
+ if bracket_group is None:
98
+ self.composition.append(_ellipsis)
99
+ self.has_ellipsis_parenthesized = False
100
+ else:
101
+ bracket_group.append(_ellipsis)
102
+ self.has_ellipsis_parenthesized = True
103
+ else:
104
+ is_number = str.isdecimal(x)
105
+ if is_number and int(x) == 1:
106
+ # handling the case of anonymous axis of length 1
107
+ if bracket_group is None:
108
+ self.composition.append([])
109
+ else:
110
+ pass # no need to think about 1s inside parenthesis
111
+ return
112
+ is_axis_name, reason = self.check_axis_name_return_reason(
113
+ x, allow_underscore=allow_underscore
114
+ )
115
+ if not (is_number or is_axis_name):
116
+ raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
117
+ axis_name: Union[str, AnonymousAxis] = (
118
+ AnonymousAxis(x) if is_number else x
119
+ )
120
+ self.identifiers.add(axis_name)
121
+ if is_number:
122
+ self.has_non_unitary_anonymous_axes = True
123
+ if bracket_group is None:
124
+ self.composition.append([axis_name])
125
+ else:
126
+ bracket_group.append(axis_name)
127
+
128
+ current_identifier = None
129
+ for char in expression:
130
+ if char in "() ":
131
+ if current_identifier is not None:
132
+ add_axis_name(current_identifier)
133
+ current_identifier = None
134
+ if char == "(":
135
+ if bracket_group is not None:
136
+ raise ValueError(
137
+ "Axis composition is one-level (brackets inside brackets not allowed)"
138
+ )
139
+ bracket_group = []
140
+ elif char == ")":
141
+ if bracket_group is None:
142
+ raise ValueError("Brackets are not balanced")
143
+ self.composition.append(bracket_group)
144
+ bracket_group = None
145
+ elif str.isalnum(char) or char in ["_", _ellipsis]:
146
+ if current_identifier is None:
147
+ current_identifier = char
148
+ else:
149
+ current_identifier += char
150
+ else:
151
+ raise ValueError(f"Unknown character '{char}'")
152
+
153
+ if bracket_group is not None:
154
+ raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
155
+ if current_identifier is not None:
156
+ add_axis_name(current_identifier)
157
+
158
+ @staticmethod
159
+ def check_axis_name_return_reason(
160
+ name: str, allow_underscore: bool = False
161
+ ) -> Tuple[bool, str]:
162
+ """Check if the given axis name is valid, and a message explaining why if not.
163
+
164
+ Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
165
+
166
+ Args:
167
+ name (str): the axis name to check
168
+ allow_underscore (bool): whether axis names are allowed to start with an underscore
169
+
170
+ Returns:
171
+ Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
172
+ """
173
+ if not str.isidentifier(name):
174
+ return False, "not a valid python identifier"
175
+ elif name[0] == "_" or name[-1] == "_":
176
+ if name == "_" and allow_underscore:
177
+ return True, ""
178
+ return False, "axis name should should not start or end with underscore"
179
+ else:
180
+ if keyword.iskeyword(name):
181
+ warnings.warn(
182
+ f"It is discouraged to use axes names that are keywords: {name}",
183
+ RuntimeWarning,
184
+ )
185
+ if name in ["axis"]:
186
+ warnings.warn(
187
+ "It is discouraged to use 'axis' as an axis name and will raise an error in future",
188
+ FutureWarning,
189
+ )
190
+ return True, ""
191
+
192
+ @staticmethod
193
+ def check_axis_name(name: str) -> bool:
194
+ """Check if the name is a valid axis name.
195
+
196
+ Args:
197
+ name (str): the axis name to check
198
+
199
+ Returns:
200
+ bool: whether the axis name is valid
201
+ """
202
+ is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
203
+ return is_valid
204
+
205
+
206
+ def parse_pattern(
207
+ pattern: str, axes_lengths: Mapping[str, int]
208
+ ) -> Tuple[ParsedExpression, ParsedExpression]:
209
+ """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
210
+
211
+ Args:
212
+ pattern (str): the `einops`-style rearrangement pattern
213
+ axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
214
+
215
+ Returns:
216
+ Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
217
+ """
218
+ # adapted from einops.einops._prepare_transformation_recipe
219
+ # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
220
+ try:
221
+ left_str, right_str = pattern.split("->")
222
+ except ValueError:
223
+ raise ValueError("Pattern must contain a single '->' separator") from None
224
+
225
+ if _ellipsis in axes_lengths:
226
+ raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")
227
+
228
+ left = ParsedExpression(left_str)
229
+ right = ParsedExpression(right_str)
230
+
231
+ if not left.has_ellipsis and right.has_ellipsis:
232
+ raise ValueError(
233
+ f"Ellipsis found in right side, but not left side of a pattern {pattern}"
234
+ )
235
+ if left.has_ellipsis and left.has_ellipsis_parenthesized:
236
+ raise ValueError(
237
+ f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
238
+ )
239
+
240
+ return left, right
241
+
242
+
243
+ def validate_rearrange_expressions(
244
+ left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
245
+ ) -> None:
246
+ """Perform expression validations that are specific to the `rearrange` operation.
247
+
248
+ Args:
249
+ left (ParsedExpression): left-hand side expression
250
+ right (ParsedExpression): right-hand side expression
251
+ axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
252
+ """
253
+ for length in axes_lengths.values():
254
+ if (length_type := type(length)) is not int:
255
+ raise TypeError(
256
+ f"rearrange axis lengths must be integers, got: {length_type}"
257
+ )
258
+
259
+ if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
260
+ raise ValueError("rearrange only supports unnamed axes of size 1")
261
+
262
+ difference = set.symmetric_difference(left.identifiers, right.identifiers)
263
+ if len(difference) > 0:
264
+ raise ValueError(
265
+ f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
266
+ )
267
+
268
+ unmatched_axes = axes_lengths.keys() - left.identifiers
269
+ if len(unmatched_axes) > 0:
270
+ raise ValueError(
271
+ f"Identifiers not found in rearrange expression: {unmatched_axes}"
272
+ )
273
+
274
+
275
+ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
276
+ """Convert a collection of strings representing first class dims into a comma-separated string.
277
+
278
+ Args:
279
+ collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert
280
+
281
+ Returns:
282
+ str: the comma-separated string
283
+
284
+ Examples:
285
+ >>> comma_separate(('d0',))
286
+ 'd0'
287
+
288
+ >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
289
+ 'd0, d1, d2, d3'
290
+
291
+ >>> comma_separate([('d1', 'd4')])
292
+ '(d1, d4)'
293
+
294
+ >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
295
+ '(d0,), (), (d1,), (d2,), (d3, d4)'
296
+ """
297
+ return ", ".join(
298
+ item
299
+ if isinstance(item, str)
300
+ else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
301
+ for item in collection
302
+ )
gtm/lib/python3.12/site-packages/functorch/einops/rearrange.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from typing import Callable, Dict, List, Sequence, Tuple, Union
5
+
6
+ import torch
7
+
8
+ from functorch._C import dim as _C
9
+ from ._parsing import (
10
+ _ellipsis,
11
+ AnonymousAxis,
12
+ comma_separate,
13
+ parse_pattern,
14
+ validate_rearrange_expressions,
15
+ )
16
+
17
+ __all__ = ["rearrange"]
18
+
19
+ dims = _C.dims
20
+
21
+
22
+ @functools.lru_cache(256)
23
+ def _create_rearrange_callable(
24
+ tensor_ndim: int, pattern: str, **axes_lengths: int
25
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
26
+ r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.
27
+
28
+ Since the an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
29
+ specified axes lengths, this function can be memoized.
30
+
31
+ Args:
32
+ tensor_ndim (int): the number of dimensions in the tensor to rearrange
33
+ pattern (str): the `einops`-style rearrangement pattern
34
+ axes_lengths (int): any additional length specifications for dimensions
35
+
36
+ Returns:
37
+ Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
38
+ """
39
+ left, right = parse_pattern(pattern, axes_lengths)
40
+ validate_rearrange_expressions(left, right, axes_lengths)
41
+
42
+ n_anon_dims = sum(not dim for dim in left.composition)
43
+ if left.has_ellipsis:
44
+ n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
45
+ n_named_dims = len(left.identifiers) - 1
46
+
47
+ if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
48
+ raise ValueError(
49
+ f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
50
+ f"dimensions in the tensor ({tensor_ndim})"
51
+ )
52
+ else:
53
+ n_ellipsis_dims = 0
54
+ n_named_dims = len(left.identifiers)
55
+
56
+ if (pattern_ndim := len(left.composition)) != tensor_ndim:
57
+ raise ValueError(
58
+ f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
59
+ f"the tensor ({tensor_ndim})"
60
+ )
61
+ n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims
62
+
63
+ if n_dims == 0:
64
+ # an identity rearrangement on a 0-dimension tensor
65
+ return lambda tensor: tensor
66
+
67
+ first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
68
+ identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
69
+ anon_axes: List[AnonymousAxis] = []
70
+
71
+ # map the left-hand side identifiers to strings representing first class dims
72
+ dims_i = 0
73
+ for dimension in left.composition:
74
+ if isinstance(dimension, list):
75
+ for identifier in dimension:
76
+ # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
77
+ assert isinstance(identifier, str)
78
+ identifier_dim_map[identifier] = (first_class_dims[dims_i],)
79
+ dims_i += 1
80
+ if not dimension:
81
+ # unitary anonymous axis
82
+ anon_axis = AnonymousAxis("1")
83
+ identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
84
+ anon_axes.append(anon_axis)
85
+ dimension.append(anon_axis)
86
+ dims_i += 1
87
+ elif dimension == _ellipsis:
88
+ identifier = _ellipsis
89
+ identifier_dim_map[identifier] = tuple(
90
+ first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
91
+ )
92
+ dims_i += n_ellipsis_dims
93
+ else:
94
+ raise ValueError(f"Unexpected dimension: {dimension}")
95
+
96
+ def composition_to_dims(
97
+ composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
98
+ ) -> List[Union[str, Tuple[str, ...]]]:
99
+ """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
100
+ class dims."""
101
+ dim_composition: List[Union[str, Tuple[str, ...]]] = []
102
+ for dimension in composition:
103
+ if isinstance(dimension, list):
104
+ dim_composition.append(
105
+ tuple(
106
+ dim
107
+ for identifier in dimension
108
+ for dim in identifier_dim_map[identifier]
109
+ )
110
+ )
111
+ elif dimension == _ellipsis:
112
+ dim_composition.extend(identifier_dim_map[_ellipsis])
113
+ else:
114
+ raise ValueError(f"Unexpected dimension: {dimension}")
115
+ return dim_composition
116
+
117
+ left_dims = composition_to_dims(left.composition)
118
+ right_dims = composition_to_dims(right.composition)
119
+ anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
120
+ specified_lengths = tuple(
121
+ (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
122
+ )
123
+
124
+ custom_rearrange_callable_name = "do_rearrange"
125
+ custom_rearrange_callable_code = (
126
+ (
127
+ f"def {custom_rearrange_callable_name}(tensor):\n"
128
+ f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
129
+ )
130
+ + (
131
+ "".join(
132
+ f" {dim}.size = {length}\n" for (dim, length) in specified_lengths
133
+ )
134
+ if specified_lengths
135
+ else ""
136
+ )
137
+ + f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
138
+ + (
139
+ f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
140
+ if anon_dims
141
+ else " return tensor\n"
142
+ )
143
+ )
144
+
145
+ exec(custom_rearrange_callable_code)
146
+ return locals()[custom_rearrange_callable_name]
147
+
148
+
149
+ def rearrange(
150
+ tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
151
+ pattern: str,
152
+ **axes_lengths: int,
153
+ ) -> torch.Tensor:
154
+ r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
155
+ tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
156
+ stack, concatenate and other operations.
157
+
158
+ See: https://einops.rocks/api/rearrange/
159
+
160
+ Args:
161
+ tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
162
+ pattern (str): the rearrangement pattern
163
+ axes_lengths (int): any additional length specifications for dimensions
164
+
165
+ Returns:
166
+ Tensor: the rearranged tensor
167
+
168
+ Examples:
169
+ >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
170
+ >>> images = torch.randn((32, 30, 40, 3))
171
+
172
+ >>> # stack along first (batch) axis, output is a single array
173
+ >>> rearrange(images, 'b h w c -> b h w c').shape
174
+ torch.Size([32, 30, 40, 3])
175
+
176
+ >>> # concatenate images along height (vertical axis), 960 = 32 * 30
177
+ >>> rearrange(images, 'b h w c -> (b h) w c').shape
178
+ torch.Size([960, 40, 3])
179
+
180
+ >>> # concatenated images along horizontal axis, 1280 = 32 * 40
181
+ >>> rearrange(images, 'b h w c -> h (b w) c').shape
182
+ torch.Size([30, 1280, 3])
183
+
184
+ >>> # reordered axes to "b c h w" format for deep learning
185
+ >>> rearrange(images, 'b h w c -> b c h w').shape
186
+ torch.Size([32, 3, 30, 40])
187
+
188
+ >>> # flattened each image into a vector, 3600 = 30 * 40 * 3
189
+ >>> rearrange(images, 'b h w c -> b (c h w)').shape
190
+ torch.Size([32, 3600])
191
+
192
+ >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
193
+ >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
194
+ torch.Size([128, 15, 20, 3])
195
+
196
+ >>> # space-to-depth operation
197
+ >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
198
+ torch.Size([32, 15, 20, 12])
199
+ """
200
+ if not isinstance(tensor, torch.Tensor):
201
+ tensor = torch.stack(tensor)
202
+
203
+ rearrange_callable = _create_rearrange_callable(
204
+ tensor.ndim, pattern, **axes_lengths
205
+ )
206
+
207
+ return rearrange_callable(tensor)
gtm/lib/python3.12/site-packages/functorch/experimental/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # PyTorch forward-mode is not mature yet
2
+ from torch._functorch.apis import chunk_vmap
3
+ from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
4
+ from torch._functorch.eager_transforms import hessian, jacfwd, jvp
5
+
6
+ from functorch import functionalize
gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (505 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/control_flow.cpython-312.pyc ADDED
Binary file (443 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/ops.cpython-312.pyc ADDED
Binary file (254 Bytes). View file
 
gtm/lib/python3.12/site-packages/functorch/experimental/control_flow.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from torch import cond # noqa: F401
2
+ from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
3
+
4
+ from torch._higher_order_ops.map import ( # noqa: F401
5
+ _stack_pytree,
6
+ _unstack_pytree,
7
+ map,
8
+ )